| |
| """ |
| 8-Class Pose Classifier Training |
| ================================ |
| Train a classifier for animal pose relative to camera. |
| |
| Classes: |
| front, front-left, front-right, left, right, back-left, back-right, back |
| |
| Usage: |
| python train_pose_classifier.py --data_dir ./pose_labels --epochs 30 |
| python train_pose_classifier.py --train_csv train.csv --val_csv val.csv --epochs 30 |
| """ |
|
|
| import argparse |
| import os |
| from pathlib import Path |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
| from torchvision import transforms |
| import pandas as pd |
|
|
| |
| |
| |
|
|
# The 8 pose labels; list position is the class index used by the model.
POSE_CLASSES = ['front', 'front-left', 'front-right', 'left', 'right', 'back-left', 'back-right', 'back']
NUM_CLASSES = len(POSE_CLASSES)
IDX_TO_CLASS = dict(enumerate(POSE_CLASSES))
CLASS_TO_IDX = {name: idx for idx, name in IDX_TO_CLASS.items()}

# Mirror-image label pairs: a horizontal flip turns each member into the other.
_LATERAL_PAIRS = (
    ('front-left', 'front-right'),
    ('left', 'right'),
    ('back-left', 'back-right'),
)

# Label that an image receives after a horizontal flip. Lateral poses swap
# sides; 'front' and 'back' are symmetric and map to themselves.
FLIP_PAIRS = {
    src: dst
    for left_name, right_name in _LATERAL_PAIRS
    for src, dst in ((left_name, right_name), (right_name, left_name))
}
FLIP_PAIRS.update({'front': 'front', 'back': 'back'})

# --model_size choice -> (torch.hub entry name, feature dimension fed to the head).
DINO_MODELS = {
    size: (f'dinov2_vit{letter}14', dim)
    for size, letter, dim in (('small', 's', 384), ('base', 'b', 768), ('large', 'l', 1024))
}
|
|
|
|
| |
| |
| |
|
|
class PoseDataset(Dataset):
    """Pose-labelled image dataset backed by class folders or a CSV file.

    Samples are stored as ``(image_path, pose_name)`` tuples in ``self.samples``;
    ``__getitem__`` returns ``(image, class_index)`` with the index taken from
    ``CLASS_TO_IDX``.
    """

    # File extensions accepted when scanning class folders (compared lower-cased).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_source, transform=None, augment_flip=True):
        """
        Args:
            data_source: Either a directory containing one sub-folder per pose
                class (named as in ``POSE_CLASSES``), or a path to a CSV file.
                For a CSV the image column is ``image_path`` (else the first
                column) and the label column is ``pose`` (else the second).
            transform: Optional callable applied to the PIL image.
            augment_flip: If True, mirror the image horizontally with
                probability 0.5 and swap the label via ``FLIP_PAIRS``.
        """
        self.transform = transform
        self.augment_flip = augment_flip
        self.samples = []

        data_path = Path(data_source)
        if data_path.is_dir():
            self._load_from_folders(data_path)
        else:
            self._load_from_csv(data_path)

        print(f"Loaded {len(self.samples)} samples")
        self._print_distribution()

    def _load_from_folders(self, root):
        """Collect images from ``root/<class>/`` sub-folders."""
        for cls in POSE_CLASSES:
            cls_dir = root / cls
            if not cls_dir.exists():
                continue
            # Sorted so sample order (and hence any seeded split) does not
            # depend on the filesystem's directory listing order.
            for img_path in sorted(cls_dir.glob('*')):
                if img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                    self.samples.append((str(img_path), cls))

    def _load_from_csv(self, csv_path):
        """Collect ``(path, label)`` rows from a CSV, skipping unknown labels."""
        df = pd.read_csv(csv_path)
        img_col = 'image_path' if 'image_path' in df.columns else df.columns[0]
        label_col = 'pose' if 'pose' in df.columns else df.columns[1]

        valid_labels = set(POSE_CLASSES)
        skipped = 0
        for img_path, label in zip(df[img_col], df[label_col]):
            if label in valid_labels:
                self.samples.append((img_path, label))
            else:
                skipped += 1
        # Unknown/missing labels are dropped; report it instead of doing so silently.
        if skipped:
            print(f"Skipped {skipped} rows with labels not in POSE_CLASSES")

    def _print_distribution(self):
        """Print the number of samples per pose class."""
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        print("Class distribution:")
        for cls in POSE_CLASSES:
            print(f" {cls}: {counts.get(cls, 0)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, class_index)`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        # The context manager closes the file handle. convert() loads the
        # pixels and returns a new image, so it stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Mirroring swaps the lateral pose (e.g. left <-> right), so the
        # label is remapped together with the pixels.
        do_flip = self.augment_flip and torch.rand(1) < 0.5
        if do_flip:
            image = transforms.functional.hflip(image)
            label = FLIP_PAIRS[label]

        if self.transform:
            image = self.transform(image)

        return image, CLASS_TO_IDX[label]

    def get_sample_weights(self):
        """Per-sample weights of 1 / class_count, for class-balanced sampling."""
        from collections import Counter
        counts = Counter(s[1] for s in self.samples)
        weights = [1.0 / counts[s[1]] for s in self.samples]
        return torch.DoubleTensor(weights)
|
|
|
|
| |
| |
| |
|
|
class PoseClassifier(nn.Module):
    """Frozen DINOv2 backbone plus a trainable MLP head for 8-class pose classification.

    Args:
        model_size: Key into ``DINO_MODELS`` ('small', 'base' or 'large').
        dropout: Dropout probability used inside the head.
    """

    def __init__(self, model_size='small', dropout=0.3):
        super().__init__()

        model_name, feat_dim = DINO_MODELS[model_size]

        # Backbone fetched through torch.hub and frozen: only the head is trained.
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Head input size must match the backbone feature dimension in DINO_MODELS.
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, NUM_CLASSES)
        )

    def train(self, mode=True):
        """Set training mode on the head only; the frozen backbone always stays in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, any ``model.train()`` call would put the frozen backbone back
        into training mode.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits (last dimension ``NUM_CLASSES``).

        Assumes ``x`` is an image batch (batch, 3, H, W) with H and W
        multiples of the DINOv2 patch size (14) — TODO confirm for other inputs.
        """
        # No autograd graph through the frozen backbone.
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)

    def predict_proba(self, x):
        """Return softmax class probabilities over the last dimension."""
        logits = self.forward(x)
        return F.softmax(logits, dim=-1)
|
|
|
|
| |
| |
| |
|
|
def get_transforms(train=True):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: If True, return the augmenting training pipeline (random crop,
            color jitter, small rotation); otherwise the deterministic
            center-crop pipeline used for evaluation.

    Returns:
        A ``transforms.Compose`` that yields normalized 224x224 tensors.
    """
    # Pieces shared by both pipelines.
    resize = transforms.Resize(256)
    to_tensor = transforms.ToTensor()
    imagenet_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if not train:
        return transforms.Compose([resize, transforms.CenterCrop(224), to_tensor, imagenet_norm])

    augmentations = [
        transforms.RandomCrop(224),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomRotation(15),
    ]
    return transforms.Compose([resize, *augmentations, to_tensor, imagenet_norm])
|
|
|
|
def train_epoch(model, dataloader, optimizer, criterion, device, scaler=None):
    """Run one training epoch over ``dataloader``.

    Args:
        model: Module exposing a ``backbone`` submodule. It is kept in eval
            mode while the rest of the model trains.
        dataloader: Yields ``(images, labels)`` batches.
        optimizer: Optimizer over the trainable parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.
        scaler: Optional ``GradScaler``. When given, the forward pass runs
            under CUDA autocast and the backward pass is loss-scaled.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged per sample (assumes
        ``criterion`` returns the batch mean, as ``nn.CrossEntropyLoss`` does
        by default) and accuracy is in ``[0, 1]``. An empty loader returns
        ``(nan, 0.0)`` instead of raising ``ZeroDivisionError``.
    """
    model.train()
    # Frozen feature extractor: keep it in inference behaviour.
    model.backbone.eval()

    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by batch size so a short final batch
        # does not skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.1f}%'})

    if total == 0:
        return float('nan'), 0.0
    return total_loss / total, correct / total
|
|
|
|
@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Model to evaluate (switched to eval mode).
        dataloader: Yields ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy, preds, labels)``. The loss is averaged per
        sample (assumes ``criterion`` returns the batch mean), accuracy is in
        ``[0, 1]``, and ``preds``/``labels`` are lists of class indices in
        loader order. An empty loader returns ``(nan, 0.0, [], [])``.
    """
    model.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    all_preds, all_labels = [], []

    for images, labels in tqdm(dataloader, desc='Evaluating'):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        batch_size = labels.size(0)
        # Per-sample weighting so a short final batch does not skew the mean.
        total_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += batch_size

        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    if total == 0:
        return float('nan'), 0.0, all_preds, all_labels
    return total_loss / total, correct / total, all_preds, all_labels
|
|
|
|
def print_confusion_matrix(preds, labels, class_names=None):
    """Print a text confusion matrix (rows = true, cols = predicted) and per-class accuracy.

    Args:
        preds: Iterable of predicted class indices.
        labels: Iterable of ground-truth class indices, aligned with ``preds``.
        class_names: Optional sequence of class names indexed by class id.
            Defaults to ``POSE_CLASSES``.

    Raises:
        KeyError: If an index in ``preds`` or ``labels`` is not a valid class id.
    """
    from collections import Counter

    classes = list(POSE_CLASSES if class_names is None else class_names)
    idx_to_name = dict(enumerate(classes))

    # Count (true_name, pred_name) pairs.
    matrix = Counter((idx_to_name[l], idx_to_name[p]) for p, l in zip(preds, labels))

    # Size columns to the longest name. Truncating to 6 characters made
    # 'front-left' and 'front-right' print as the same header ('front-').
    width = max((len(c) for c in classes), default=0) + 2

    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(f"{'':>{width}}" + "".join(f"{c:>{width}}" for c in classes))
    for true_class in classes:
        cells = "".join(f"{matrix[(true_class, pred_class)]:>{width}}" for pred_class in classes)
        print(f"{true_class:>{width}}" + cells)

    print("\nPer-class accuracy:")
    for cls in classes:
        correct = matrix[(cls, cls)]
        total = sum(matrix[(cls, pred_class)] for pred_class in classes)
        acc = correct / total * 100 if total > 0 else 0
        print(f" {cls:>12}: {acc:5.1f}% ({correct}/{total})")
|
|
|
|
def export_onnx(model, output_path, device='cpu'):
    """Export ``model`` to an ONNX file at ``output_path``.

    Side effects: ``model`` is switched to eval mode and moved to ``device``
    in place. The graph takes an ``image`` input shaped (batch, 3, 224, 224)
    with a dynamic batch axis and produces ``logits``, also batch-dynamic.
    """
    model.eval()
    model.to(device)

    example_batch = torch.randn(1, 3, 224, 224).to(device)
    export_kwargs = dict(
        export_params=True,
        opset_version=14,
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={name: {0: 'batch'} for name in ('image', 'logits')},
    )
    torch.onnx.export(model, example_batch, output_path, **export_kwargs)
    print(f"Exported to {output_path}")
|
|
|
|
def _split_folder_dataset(data_dir, train_transform, val_transform, val_fraction=0.2):
    """Load a class-folder dataset and split it into train/val subsets.

    Both subsets of ``random_split`` share their parent dataset, so setting
    the validation transform on the parent would also change training.
    Validation therefore indexes a shallow copy with its own settings.
    """
    import copy

    train_view = PoseDataset(data_dir, train_transform, augment_flip=True)
    # Shallow copy: shares the read-only sample list, owns transform/flip flags.
    val_view = copy.copy(train_view)
    val_view.transform = val_transform
    val_view.augment_flip = False

    n_val = int(val_fraction * len(train_view))
    n_train = len(train_view) - n_val
    train_subset, val_subset = torch.utils.data.random_split(train_view, [n_train, n_val])
    return train_subset, torch.utils.data.Subset(val_view, val_subset.indices)


def _balanced_sample_weights(dataset):
    """Return per-sample inverse-class-frequency weights, or None if labels are unavailable.

    Handles a ``PoseDataset`` directly (``get_sample_weights``) and a
    ``Subset`` of one, where class counts are taken over the subset only.
    """
    if hasattr(dataset, 'get_sample_weights'):
        return dataset.get_sample_weights()
    parent = getattr(dataset, 'dataset', None)
    indices = getattr(dataset, 'indices', None)
    if parent is None or indices is None or not hasattr(parent, 'samples'):
        return None
    from collections import Counter
    subset_labels = [parent.samples[i][1] for i in indices]
    counts = Counter(subset_labels)
    return torch.DoubleTensor([1.0 / counts[label] for label in subset_labels])


def main():
    """CLI entry point: train the pose head, validate, save checkpoints, optionally export ONNX."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, help='Directory with class folders')
    parser.add_argument('--train_csv', type=str, help='Training CSV')
    parser.add_argument('--val_csv', type=str, help='Validation CSV')
    parser.add_argument('--model_size', type=str, default='small', choices=['small', 'base', 'large'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--num_workers', type=int, default=4, help='DataLoader worker processes')
    parser.add_argument('--output_dir', type=str, default='./checkpoints')
    parser.add_argument('--export_onnx', action='store_true')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_dir, exist_ok=True)
    best_ckpt_path = os.path.join(args.output_dir, 'best_pose_model.pth')

    train_transform = get_transforms(train=True)
    val_transform = get_transforms(train=False)

    if args.train_csv:
        train_dataset = PoseDataset(args.train_csv, train_transform, augment_flip=True)
        val_dataset = PoseDataset(args.val_csv, val_transform, augment_flip=False) if args.val_csv else None
    elif args.data_dir:
        train_dataset, val_dataset = _split_folder_dataset(args.data_dir, train_transform, val_transform)
    else:
        print("Provide --data_dir or --train_csv")
        return

    if len(train_dataset) == 0:
        print("No training samples found")
        return

    # Class-balanced sampling in both CSV and folder mode.
    weights = _balanced_sample_weights(train_dataset)
    if weights is not None:
        sampler = WeightedRandomSampler(weights, len(weights))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers)

    # An empty validation set is treated the same as no validation set.
    val_loader = None
    if val_dataset is not None and len(val_dataset) > 0:
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers)

    print(f"\nLoading DINOv2-{args.model_size}...")
    model = PoseClassifier(model_size=args.model_size).to(device)

    trainable = sum(p.numel() for p in model.head.parameters())
    print(f"Trainable parameters: {trainable:,}")

    # Only the head is optimized; the backbone is frozen.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.head.parameters(), lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    # Start below any reachable accuracy so the first validated epoch always
    # writes a checkpoint. Starting at 0 meant a 0%-accuracy run saved
    # nothing and the final torch.load crashed.
    best_acc = -1.0

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.1f}%")

        if val_loader is not None:
            val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.1f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'head_state_dict': model.head.state_dict(),
                    'val_acc': val_acc,
                    'classes': POSE_CLASSES,
                }, best_ckpt_path)
                print(f" → Saved (acc: {val_acc*100:.1f}%)")

        scheduler.step()

    if val_loader is None:
        # Without validation there is no "best" checkpoint; keep the final
        # weights so the training run is not lost.
        if args.epochs > 0:
            final_path = os.path.join(args.output_dir, 'final_pose_model.pth')
            torch.save({
                'epoch': args.epochs - 1,
                'model_state_dict': model.state_dict(),
                'head_state_dict': model.head.state_dict(),
                'classes': POSE_CLASSES,
            }, final_path)
            print(f"Saved final model to {final_path}")
    elif best_acc >= 0:
        # best_acc >= 0 means a checkpoint was written during this run.
        print("\n" + "="*60)
        print("Final Evaluation")
        print("="*60)

        ckpt = torch.load(best_ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])

        _, acc, preds, labels = evaluate(model, val_loader, criterion, device)
        print(f"Best Accuracy: {acc*100:.1f}%")
        print_confusion_matrix(preds, labels)

    if args.export_onnx:
        export_onnx(model, os.path.join(args.output_dir, 'pose_classifier.onnx'))

    print("\nDone!")
|
|
|
|
# Run the training CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|