#!/usr/bin/env python3
"""
8-Class Pose Classifier Training
================================

Train a classifier for animal pose relative to the camera. A frozen DINOv2
backbone (loaded via torch.hub) extracts image features and a small MLP head
is trained on top of them.

Classes: front, front-left, front-right, left, right, back-left, back-right, back

Data can be supplied either as a directory with one sub-folder per class
(split 80/20 into train/val), or as CSV files containing an image-path column
and a pose-label column.

Usage:
    python train_pose_classifier.py --data_dir ./pose_labels --epochs 30
    python train_pose_classifier.py --train_csv train.csv --val_csv val.csv --epochs 30
"""

import argparse
import copy
import os
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler
from torchvision import transforms
from tqdm import tqdm

# ============================================================
# Configuration
# ============================================================

POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right',
                'back-left', 'back-right', 'back']
CLASS_TO_IDX = {c: i for i, c in enumerate(POSE_CLASSES)}
IDX_TO_CLASS = {i: c for c, i in CLASS_TO_IDX.items()}
NUM_CLASSES = len(POSE_CLASSES)

# A horizontal flip mirrors the image, so left/right pose labels swap while
# 'front' and 'back' map to themselves.
FLIP_PAIRS = {
    'front-left': 'front-right',
    'front-right': 'front-left',
    'left': 'right',
    'right': 'left',
    'back-left': 'back-right',
    'back-right': 'back-left',
    'front': 'front',
    'back': 'back',
}

# File extensions (lower-case) picked up when loading a class-folder layout.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# DINOv2 model sizes: size name -> (torch.hub entry point, feature dimension)
DINO_MODELS = {
    'small': ('dinov2_vits14', 384),
    'base': ('dinov2_vitb14', 768),
    'large': ('dinov2_vitl14', 1024),
}

# ============================================================
# Dataset
# ============================================================


class PoseDataset(Dataset):
    """Pose-labelled images loaded from a class-folder tree or a CSV file."""

    def __init__(self, data_source, transform=None, augment_flip=True):
        """
        Args:
            data_source: Either a directory containing one sub-folder per pose
                class, or the path of a CSV file. For a CSV the image column is
                ``image_path`` (otherwise the first column) and the label
                column is ``pose`` (otherwise the second column); rows whose
                label is not one of POSE_CLASSES are skipped.
            transform: Optional transform applied to each PIL image.
            augment_flip: If True, each sample is horizontally flipped with
                probability 0.5 and its label swapped via FLIP_PAIRS.
        """
        self.transform = transform
        self.augment_flip = augment_flip
        self.samples = []  # list of (image path, pose class name)

        data_path = Path(data_source)
        if data_path.is_dir():
            self._load_folder(data_path)
        else:
            self._load_csv(data_path)

        print(f"Loaded {len(self.samples)} samples")
        self._print_distribution()

    def _load_folder(self, root):
        """Collect (path, label) pairs from ``root/<class>/<image file>``."""
        for cls in POSE_CLASSES:
            cls_dir = root / cls
            if not cls_dir.is_dir():
                continue
            # Sorted so sample order, and therefore any seeded split, is reproducible.
            for img_path in sorted(cls_dir.glob('*')):
                if img_path.suffix.lower() in IMG_EXTENSIONS:
                    self.samples.append((str(img_path), cls))

    def _load_csv(self, csv_path):
        """Collect (path, label) pairs from a CSV file, skipping unknown labels."""
        df = pd.read_csv(csv_path)
        img_col = 'image_path' if 'image_path' in df.columns else df.columns[0]
        label_col = 'pose' if 'pose' in df.columns else df.columns[1]

        valid = set(POSE_CLASSES)
        skipped = 0
        for img_path, label in zip(df[img_col], df[label_col]):
            # Tolerate stray whitespace around labels; NaN/unknown labels are skipped.
            if isinstance(label, str):
                label = label.strip()
            if label in valid:
                self.samples.append((str(img_path), label))
            else:
                skipped += 1
        if skipped:
            print(f"Skipped {skipped} rows with unknown pose labels")

    def _print_distribution(self):
        counts = Counter(label for _, label in self.samples)
        print("Class distribution:")
        for cls in POSE_CLASSES:
            print(f"  {cls}: {counts.get(cls, 0)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # Context manager releases the file handle promptly (matters with many workers).
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Horizontal flip augmentation with label swap (front-left <-> front-right, ...).
        if self.augment_flip and torch.rand(1).item() < 0.5:
            image = transforms.functional.hflip(image)
            label = FLIP_PAIRS[label]

        if self.transform is not None:
            image = self.transform(image)
        return image, CLASS_TO_IDX[label]

    def get_sample_weights(self, indices=None):
        """Inverse-class-frequency weights for class-balanced sampling.

        Args:
            indices: Optional sequence of sample indices (e.g. the training part
                of a split). Class frequencies are then counted over that subset
                only, and one weight per index is returned in the same order.

        Returns:
            torch.DoubleTensor with one weight per (selected) sample.
        """
        if indices is None:
            indices = range(len(self.samples))
        labels = [self.samples[i][1] for i in indices]
        counts = Counter(labels)
        return torch.DoubleTensor([1.0 / counts[label] for label in labels])


# ============================================================
# Model
# ============================================================


class PoseClassifier(nn.Module):
    """Frozen DINOv2 backbone + trainable MLP head for 8-class pose classification."""

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()
        model_name, feat_dim = DINO_MODELS[model_size]

        # Load the DINOv2 backbone and freeze it; only the head is trained.
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Trainable MLP head
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES),
        )

    def train(self, mode=True):
        """Set train/eval mode on the head; the frozen backbone always stays in eval mode."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        # No autograd graph is needed through the frozen backbone.
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)

    def predict_proba(self, x):
        logits = self.forward(x)
        return F.softmax(logits, dim=-1)


# ============================================================
# Training
# ============================================================


def get_transforms(train=True):
    """Return the image preprocessing pipeline (augmented when ``train`` is True)."""
    # ImageNet mean/std normalization.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if train:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])


def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None):
    """Run one training epoch.

    Args:
        scaler: Optional GradScaler; when given, the forward pass runs under
            CUDA autocast and the loss is scaled before backward.

    Returns:
        (mean per-batch loss, accuracy in [0, 1]). Both are 0 for an empty loader.
    """
    model.train()
    model.backbone.eval()  # Keep the frozen backbone in eval mode

    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            with torch.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.1f}%'})

    return total_loss / max(num_batches, 1), correct / max(total, 1)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate the model on a dataloader.

    Returns:
        (mean per-batch loss, accuracy in [0, 1], predicted class indices,
        true class indices). Loss/accuracy are 0 for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    all_preds, all_labels = [], []

    for images, labels in tqdm(dataloader, desc='Evaluating'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    return total_loss / max(num_batches, 1), correct / max(total, 1), all_preds, all_labels


def print_confusion_matrix(preds, labels):
    """Print a confusion matrix (rows = true, cols = predicted) and per-class accuracy.

    Args:
        preds: Iterable of predicted class indices.
        labels: Iterable of true class indices, aligned with ``preds``.
    """
    matrix = defaultdict(Counter)
    for p, t in zip(preds, labels):
        matrix[IDX_TO_CLASS[int(t)]][IDX_TO_CLASS[int(p)]] += 1

    # Size columns to the full class names: truncating them made e.g.
    # 'front-left' and 'front-right' print identically.
    width = max(len(c) for c in POSE_CLASSES) + 2

    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(f"{'':>{width}}" + "".join(f"{c:>{width}}" for c in POSE_CLASSES))
    for true_class in POSE_CLASSES:
        row = matrix[true_class]
        print(f"{true_class:>{width}}"
              + "".join(f"{row[pred_class]:>{width}}" for pred_class in POSE_CLASSES))

    print("\nPer-class accuracy:")
    for cls in POSE_CLASSES:
        correct = matrix[cls][cls]
        total = sum(matrix[cls].values())
        acc = correct / total * 100 if total > 0 else 0.0
        print(f"  {cls:>{width}}: {acc:5.1f}% ({correct}/{total})")


def export_onnx(model, output_path, device='cpu'):
    """Export the model to ONNX with a dynamic batch dimension (moves model to ``device``)."""
    model.eval()
    model.to(device)
    dummy = torch.randn(1, 3, 224, 224).to(device)

    torch.onnx.export(
        model,
        dummy,
        output_path,
        export_params=True,
        opset_version=14,
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={'image': {0: 'batch'}, 'logits': {0: 'batch'}},
    )
    print(f"Exported to {output_path}")


def _build_datasets(args, train_transform, val_transform):
    """Create ``(train_dataset, val_dataset_or_None)`` from CSVs or a class-folder dir."""
    if args.train_csv:
        train_dataset = PoseDataset(args.train_csv, train_transform, augment_flip=True)
        val_dataset = (PoseDataset(args.val_csv, val_transform, augment_flip=False)
                       if args.val_csv else None)
        return train_dataset, val_dataset

    full_dataset = PoseDataset(args.data_dir, train_transform, augment_flip=True)
    # Separate view for validation: subsets of one dataset share its transform
    # and flip settings, so changing them for validation would also disable
    # augmentation for training. A shallow copy shares the (read-only) sample
    # list but has its own transform/flip attributes.
    val_view = copy.copy(full_dataset)
    val_view.transform = val_transform
    val_view.augment_flip = False

    # Seeded 80/20 split so runs are reproducible.
    n_val = int(0.2 * len(full_dataset))
    n_train = len(full_dataset) - n_val
    generator = torch.Generator().manual_seed(args.seed)
    perm = torch.randperm(len(full_dataset), generator=generator).tolist()
    train_idx, val_idx = perm[:n_train], perm[n_train:]

    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(val_view, val_idx) if val_idx else None
    return train_dataset, val_dataset


def _make_train_loader(train_dataset, batch_size, num_workers):
    """Build the training DataLoader, class-balanced when the dataset supports it."""
    if isinstance(train_dataset, Subset):
        base, indices = train_dataset.dataset, train_dataset.indices
    else:
        base, indices = train_dataset, None

    if hasattr(base, 'get_sample_weights'):
        weights = (base.get_sample_weights(indices) if indices is not None
                   else base.get_sample_weights())
        sampler = WeightedRandomSampler(weights, len(weights))
        return DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                          num_workers=num_workers)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)


def _save_checkpoint(model, path, epoch, val_acc):
    """Save full and head-only weights plus metadata to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'head_state_dict': model.head.state_dict(),
        'val_acc': val_acc,
        'classes': POSE_CLASSES,
    }, path)


def main():
    parser = argparse.ArgumentParser(
        description='Train an 8-class animal pose classifier on frozen DINOv2 features.')
    parser.add_argument('--data_dir', type=str, help='Directory with class folders')
    parser.add_argument('--train_csv', type=str, help='Training CSV')
    parser.add_argument('--val_csv', type=str, help='Validation CSV')
    parser.add_argument('--model_size', type=str, default='small',
                        choices=['small', 'base', 'large'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--output_dir', type=str, default='./checkpoints')
    parser.add_argument('--export_onnx', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='DataLoader worker processes')
    parser.add_argument('--seed', type=int, default=42,
                        help='Seed for the 80/20 train/val split of --data_dir')
    args = parser.parse_args()

    if not (args.train_csv or args.data_dir):
        parser.error("Provide --data_dir or --train_csv")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_dir, exist_ok=True)
    output_dir = Path(args.output_dir)
    best_path = output_dir / 'best_pose_model.pth'

    # Data
    train_dataset, val_dataset = _build_datasets(
        args, get_transforms(train=True), get_transforms(train=False))
    if len(train_dataset) == 0:
        raise SystemExit("No training samples found")

    train_loader = _make_train_loader(train_dataset, args.batch_size, args.num_workers)
    val_loader = None
    if val_dataset is not None and len(val_dataset) > 0:
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers)

    # Model
    print(f"\nLoading DINOv2-{args.model_size}...")
    model = PoseClassifier(model_size=args.model_size).to(device)
    trainable = sum(p.numel() for p in model.head.parameters())
    print(f"Trainable parameters: {trainable:,}")

    # Training
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    best_acc = 0.0
    saved_best = False  # tracks whether best_path was written during this run
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion,
                                            device, scaler)
        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")

        if val_loader is not None:
            val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.1f}%")

            # Always save on the first validated epoch so a best checkpoint exists
            # even if accuracy never rises above 0.
            if not saved_best or val_acc > best_acc:
                best_acc = val_acc
                _save_checkpoint(model, best_path, epoch, val_acc)
                saved_best = True
                print(f"  → Saved (acc: {val_acc*100:.1f}%)")

        scheduler.step()

    # Keep the final weights too; without a validation set this is the only checkpoint.
    _save_checkpoint(model, output_dir / 'last_pose_model.pth', args.epochs - 1, None)

    # Final evaluation with the best checkpoint
    if val_loader is not None and saved_best:
        print("\n" + "=" * 60)
        print("Final Evaluation")
        print("=" * 60)

        ckpt = torch.load(best_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        _, acc, preds, labels = evaluate(model, val_loader, criterion, device)
        print(f"Best Accuracy: {acc*100:.1f}%")
        print_confusion_matrix(preds, labels)

    # Export
    if args.export_onnx:
        export_onnx(model, str(output_dir / 'pose_classifier.onnx'))

    print("\nDone!")


if __name__ == '__main__':
    main()