# Source: mmla-dino-pose / train_pose_classifier.py
# Author: jennamk14
# Commit message: Add README, training/inference code, and trained DINOv2-small pose-classifier checkpoint (#1)
# Commit: 58b3e34
#!/usr/bin/env python3
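"""
8-Class Pose Classifier Training
================================
Train a classifier that predicts an animal's pose (orientation) relative to the camera.
Classes:
front, front-left, front-right, left, right, back-left, back-right, back
Data:
--data_dir expects one sub-folder per class name containing .jpg/.jpeg/.png images;
an 80/20 train/validation split is made automatically.
--train_csv / --val_csv expect `image_path` and `pose` columns (if absent, the
first and second columns are used instead).
Usage:
python train_pose_classifier.py --data_dir ./pose_labels --epochs 30
python train_pose_classifier.py --train_csv train.csv --val_csv val.csv --epochs 30
"""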
import argparse
import os
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import pandas as pd
# ============================================================
# Configuration
# ============================================================
# Pose labels in class-index order. CLASS_TO_IDX turns a label into the training
# target, so this order fixes the meaning of each output logit of the head.
POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
CLASS_TO_IDX = {name: idx for idx, name in enumerate(POSE_CLASSES)}
IDX_TO_CLASS = dict(enumerate(POSE_CLASSES))
NUM_CLASSES = len(POSE_CLASSES)

# Horizontal flip swaps these pairs: mirroring the image exchanges the two
# members of each lateral pair, while 'front' and 'back' map to themselves.
_MIRROR_PAIRS = (
    ('front-left', 'front-right'),
    ('left', 'right'),
    ('back-left', 'back-right'),
)
FLIP_PAIRS = {
    src: dst
    for first, second in _MIRROR_PAIRS
    for src, dst in ((first, second), (second, first))
}
FLIP_PAIRS.update({'front': 'front', 'back': 'back'})

# DINOv2 model sizes: size name -> (torch.hub entry point, embedding width
# that the classifier head's first layer is built for).
DINO_MODELS = {
    'small': ('dinov2_vits14', 384),
    'base': ('dinov2_vitb14', 768),
    'large': ('dinov2_vitl14', 1024),
}
# ============================================================
# Dataset
# ============================================================
class PoseDataset(Dataset):
    """Dataset of (image, pose-class-index) pairs from a folder tree or a CSV.

    Folder mode expects ``<data_source>/<pose>/`` sub-directories named after
    the entries of ``POSE_CLASSES``, containing .jpg/.jpeg/.png files.
    CSV mode reads the ``image_path`` and ``pose`` columns (falling back to the
    first and second columns); image paths are used exactly as written, so
    relative paths resolve against the current working directory.
    """

    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_source, transform=None, augment_flip=True):
        """
        Args:
            data_source: Either a directory path (folder structure) or CSV path
            transform: Callable applied to the PIL image after the optional flip
            augment_flip: Whether to apply a random horizontal flip (p=0.5)
                together with the matching left/right label swap
        """
        self.transform = transform
        self.augment_flip = augment_flip
        self.samples = []  # list of (image_path, pose_label) tuples
        data_path = Path(data_source)
        if data_path.is_dir():
            self._load_folder(data_path)
        else:
            self._load_csv(data_path)
        print(f"Loaded {len(self.samples)} samples")
        self._print_distribution()

    def _load_folder(self, root):
        """Collect (path, label) pairs from the ``root/<pose>/`` sub-directories."""
        for cls in POSE_CLASSES:
            cls_dir = root / cls
            if not cls_dir.exists():
                continue
            # glob() order is filesystem-dependent; sorting keeps the sample
            # order (and therefore any index-based split) reproducible.
            for img_path in sorted(cls_dir.glob('*')):
                if img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    self.samples.append((str(img_path), cls))

    def _load_csv(self, csv_path):
        """Collect (path, label) pairs from a CSV, skipping rows with unknown labels."""
        df = pd.read_csv(csv_path)
        img_col = 'image_path' if 'image_path' in df.columns else df.columns[0]
        label_col = 'pose' if 'pose' in df.columns else df.columns[1]
        known_labels = set(POSE_CLASSES)
        skipped = 0
        for img_path, label in zip(df[img_col], df[label_col]):
            if label in known_labels:
                self.samples.append((img_path, label))
            else:
                skipped += 1
        # Report dropped rows instead of discarding them silently (typos,
        # stray whitespace or missing labels would otherwise go unnoticed).
        if skipped:
            print(f"Skipped {skipped} rows with labels not in POSE_CLASSES")

    def _print_distribution(self):
        """Print the number of samples per pose class."""
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        print("Class distribution:")
        for cls in POSE_CLASSES:
            print(f" {cls}: {counts.get(cls, 0)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        # The context manager guarantees the file handle is released, even if
        # decoding fails; convert() produces an independent, loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Horizontal flip augmentation with label swap: mirroring the image
        # turns e.g. a 'left' pose into a 'right' pose.
        if self.augment_flip and torch.rand(1).item() < 0.5:
            image = transforms.functional.hflip(image)
            label = FLIP_PAIRS[label]
        if self.transform:
            image = self.transform(image)
        return image, CLASS_TO_IDX[label]

    def get_sample_weights(self):
        """Per-sample weights (1 / class frequency) for balanced sampling."""
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        weights = [1.0 / counts[s[1]] for s in self.samples]
        return torch.DoubleTensor(weights)
# ============================================================
# Model
# ============================================================
class PoseClassifier(nn.Module):
    """Frozen DINOv2 backbone + trainable MLP head for 8-class pose classification.

    Args:
        model_size: Key into ``DINO_MODELS`` ('small', 'base' or 'large').
        dropout: Dropout probability between the head's hidden layers.
    """

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        model_name, feat_dim = DINO_MODELS[model_size]
        # Load frozen DINOv2 backbone (fetched via torch.hub)
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        # Trainable MLP head: feat_dim -> 256 -> 128 -> NUM_CLASSES
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES)
        )

    def train(self, mode=True):
        """Set train/eval mode on the head while keeping the backbone in eval mode.

        ``nn.Module.train`` recurses into every child module, so without this
        override a plain ``model.train()`` would also switch the frozen backbone
        into training mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits (last dimension NUM_CLASSES) for an image batch."""
        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)

    def predict_proba(self, x):
        """Return softmax class probabilities for an image batch."""
        logits = self.forward(x)
        return F.softmax(logits, dim=-1)
# ============================================================
# Training
# ============================================================
def get_transforms(train=True):
    """Build the torchvision preprocessing pipeline.

    Both variants resize to 256, take a 224x224 crop, convert to a tensor and
    normalise with the standard ImageNet mean/std. The training variant crops
    at a random position and adds colour jitter plus a random rotation of up
    to 15 degrees; the evaluation variant uses a deterministic centre crop.
    """
    imagenet_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    if not train:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            imagenet_norm,
        ])
    augmentations = [
        transforms.RandomCrop(224),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomRotation(15),
    ]
    return transforms.Compose(
        [transforms.Resize(256), *augmentations, transforms.ToTensor(), imagenet_norm]
    )
def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None):
    """Train for one pass over ``dataloader``, showing running metrics in a progress bar.

    ``model.train()`` is followed by ``model.backbone.eval()`` so the frozen
    backbone runs in inference mode while the head trains. When ``scaler`` is
    given, the forward pass runs under CUDA autocast and the backward pass and
    optimizer step go through the gradient scaler.

    Returns:
        Tuple of (mean per-batch loss, accuracy as a fraction of samples).
    """
    model.train()
    model.backbone.eval()  # Keep backbone frozen
    running_loss = 0
    n_correct = 0
    n_seen = 0
    progress = tqdm(dataloader, desc='Training')
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        if not scaler:
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
        else:
            with torch.cuda.amp.autocast():
                logits = model(inputs)
                loss = criterion(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        step_loss = loss.item()
        running_loss += step_loss
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()
        progress.set_postfix({'loss': f'{step_loss:.4f}', 'acc': f'{100*n_correct/n_seen:.1f}%'})
    return running_loss / len(dataloader), n_correct / n_seen
@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    Returns:
        Tuple ``(mean_loss, accuracy, preds, labels)``: the loss averaged over
        batches, accuracy as a fraction of samples, and flat lists of predicted
        and true class indices (NumPy integer scalars) in loader order.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    pred_history = []
    label_history = []
    for inputs, targets in tqdm(dataloader, desc='Evaluating'):
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        running_loss += criterion(logits, targets).item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()
        pred_history.extend(predicted.cpu().numpy())
        label_history.extend(targets.cpu().numpy())
    return running_loss / len(dataloader), n_correct / n_seen, pred_history, label_history
def print_confusion_matrix(preds, labels, class_names=None):
    """Print a confusion matrix (rows=true, cols=pred) and per-class accuracy.

    Args:
        preds: Predicted class indices (ints or NumPy integer scalars).
        labels: Ground-truth class indices, aligned element-wise with ``preds``.
        class_names: Class names ordered by class index; defaults to
            ``POSE_CLASSES``.
    """
    from collections import Counter
    if class_names is None:
        class_names = POSE_CLASSES
    n_classes = len(class_names)
    # (true_idx, pred_idx) -> count; int() normalises NumPy integer scalars.
    counts = Counter((int(t), int(p)) for p, t in zip(preds, labels))
    # Size columns to the longest name: truncating names (e.g. to 6 chars)
    # makes 'front-left' and 'front-right' render identically as 'front-'.
    width = max(12, max((len(c) for c in class_names), default=0) + 1)
    print("\nConfusion Matrix (rows=true, cols=pred):")
    # Header
    print(" " * width + "".join(f"{name:>{width}}" for name in class_names))
    for t, true_class in enumerate(class_names):
        cells = "".join(f"{counts[(t, p)]:>{width}}" for p in range(n_classes))
        print(f"{true_class:>{width}}{cells}")
    # Per-class accuracy: share of each true class that was predicted correctly
    print("\nPer-class accuracy:")
    for t, cls in enumerate(class_names):
        correct = counts[(t, t)]
        total = sum(counts[(t, p)] for p in range(n_classes))
        acc = correct / total * 100 if total > 0 else 0
        print(f" {cls:>12}: {acc:5.1f}% ({correct}/{total})")
def export_onnx(model, output_path, device='cpu'):
    """Export ``model`` to an ONNX file (opset 14) at ``output_path``.

    The model is put in eval mode and moved to ``device`` (it stays there after
    the call). A single random 1x3x224x224 example input drives the export,
    and the batch dimension of the ``image`` input and ``logits`` output is
    declared dynamic.
    """
    model.eval()
    model.to(device)
    example_input = torch.randn(1, 3, 224, 224).to(device)
    export_kwargs = dict(
        export_params=True,
        opset_version=14,
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={'image': {0: 'batch'}, 'logits': {0: 'batch'}},
    )
    torch.onnx.export(model, example_input, output_path, **export_kwargs)
    print(f"Exported to {output_path}")
def _build_datasets(args, train_transform, val_transform):
    """Build the training and validation datasets from parsed CLI arguments.

    Returns:
        ``(train_dataset, val_dataset)`` — ``val_dataset`` is ``None`` in CSV
        mode without ``--val_csv`` — or ``None`` if neither ``--train_csv`` nor
        ``--data_dir`` was given.
    """
    if args.train_csv:
        train_dataset = PoseDataset(args.train_csv, train_transform, augment_flip=True)
        val_dataset = PoseDataset(args.val_csv, val_transform, augment_flip=False) if args.val_csv else None
        return train_dataset, val_dataset
    if not args.data_dir:
        return None
    import copy
    full_dataset = PoseDataset(args.data_dir, train_transform, augment_flip=True)
    # Split 80/20
    n_val = int(0.2 * len(full_dataset))
    n_train = len(full_dataset) - n_val
    train_dataset, val_split = torch.utils.data.random_split(full_dataset, [n_train, n_val])
    # Both Subsets from random_split wrap the same full_dataset object, so
    # switching its transform/flip settings for validation would also silently
    # disable augmentation for training. Validation therefore indexes a shallow
    # copy with eval-time settings (the samples list is shared and never mutated).
    val_source = copy.copy(full_dataset)
    val_source.transform = val_transform
    val_source.augment_flip = False
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)
    return train_dataset, val_dataset


def main():
    """Command-line entry point: train the pose head, evaluate it, optionally export ONNX."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, help='Directory with class folders')
    parser.add_argument('--train_csv', type=str, help='Training CSV')
    parser.add_argument('--val_csv', type=str, help='Validation CSV')
    parser.add_argument('--model_size', type=str, default='small', choices=['small', 'base', 'large'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--output_dir', type=str, default='./checkpoints')
    parser.add_argument('--export_onnx', action='store_true')
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_transform = get_transforms(train=True)
    val_transform = get_transforms(train=False)
    datasets = _build_datasets(args, train_transform, val_transform)
    if datasets is None:
        print("Provide --data_dir or --train_csv")
        return
    train_dataset, val_dataset = datasets

    # Weighted sampler for class balance.
    # NOTE: only PoseDataset exposes get_sample_weights; the random_split Subset
    # used in --data_dir mode does not, so that mode falls back to shuffling.
    if hasattr(train_dataset, 'get_sample_weights'):
        weights = train_dataset.get_sample_weights()
        sampler = WeightedRandomSampler(weights, len(weights))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) if val_dataset else None

    # Model
    print(f"\nLoading DINOv2-{args.model_size}...")
    model = PoseClassifier(model_size=args.model_size).to(device)
    trainable = sum(p.numel() for p in model.head.parameters())
    print(f"Trainable parameters: {trainable:,}")

    # Training (only the head's parameters are optimised)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    ckpt_path = f'{args.output_dir}/best_pose_model.pth'
    # Start below any attainable accuracy so the first validated epoch always
    # writes a checkpoint; otherwise a run stuck at 0% never saves one and the
    # final torch.load below fails.
    best_acc = -1.0
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
        if val_loader:
            val_loss, val_acc, preds, labels = evaluate(model, val_loader, criterion, device)
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.1f}%")
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'head_state_dict': model.head.state_dict(),
                    'val_acc': val_acc,
                    'classes': POSE_CLASSES,
                }, ckpt_path)
                print(f" → Saved (acc: {val_acc*100:.1f}%)")
        else:
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")
        scheduler.step()

    # Final evaluation with the best checkpoint
    if val_loader:
        print("\n" + "="*60)
        print("Final Evaluation")
        print("="*60)
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        _, acc, preds, labels = evaluate(model, val_loader, criterion, device)
        print(f"Best Accuracy: {acc*100:.1f}%")
        print_confusion_matrix(preds, labels)

    # Export
    if args.export_onnx:
        export_onnx(model, f'{args.output_dir}/pose_classifier.onnx')
    print("\nDone!")


if __name__ == '__main__':
    main()