#!/usr/bin/env python3
"""
Script d'entraînement du modèle YOLOv8 pour la détection de reines d'abeilles
"""

from ultralytics import YOLO
import argparse
from pathlib import Path
import yaml
import torch

def train_model(data_yaml, epochs=100, batch_size=16, img_size=640, 
                model_size='n', device='cuda', project='runs/train'):
    """
    Entraîne le modèle YOLOv8
    
    Args:
        data_yaml: Chemin vers le fichier de configuration du dataset
        epochs: Nombre d'époques
        batch_size: Taille du batch
        img_size: Taille des images
        model_size: Taille du modèle (n, s, m, l, x)
        device: Device (cuda, cpu, mps)
        project: Dossier de sortie
    """
    
    print("🐝 Queen Bee Detector - Training")
    print("=" * 60)
    
    # Vérifier CUDA
    if device == 'cuda' and not torch.cuda.is_available():
        print("⚠️  CUDA non disponible, utilisation du CPU")
        device = 'cpu'
    
    print(f"📊 Configuration:")
    print(f"   Modèle: YOLOv8{model_size}")
    print(f"   Époques: {epochs}")
    print(f"   Batch size: {batch_size}")
    print(f"   Image size: {img_size}")
    print(f"   Device: {device}")
    print("=" * 60)
    
    # Charger le modèle pré-entraîné
    model = YOLO(f'yolov8{model_size}.pt')
    
    # Configuration de l'entraînement
    results = model.train(
        data=data_yaml,
        epochs=epochs,
        batch=batch_size,
        imgsz=img_size,
        device=device,
        project=project,
        name='queen_detector',
        
        # Optimisations
        optimizer='AdamW',
        lr0=0.001,  # Learning rate initial
        lrf=0.01,   # Learning rate final (lr0 * lrf)
        momentum=0.937,
        weight_decay=0.0005,
        warmup_epochs=3.0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        
        # Augmentation
        hsv_h=0.015,  # Hue augmentation
        hsv_s=0.7,    # Saturation
        hsv_v=0.4,    # Value
        degrees=10.0,  # Rotation
        translate=0.1,
        scale=0.5,
        shear=0.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
        mosaic=1.0,
        mixup=0.0,
        
        # Callbacks et monitoring
        save=True,
        save_period=10,  # Sauvegarder tous les 10 epochs
        plots=True,
        verbose=True,
        
        # Early stopping
        patience=50,
        
        # Validation
        val=True,
        split='val',
        
        # Performance
        workers=8,
        cache=True,  # Cache images en RAM pour plus de vitesse
    )
    
    print("\n✅ Entraînement terminé!")
    print(f"📁 Résultats sauvegardés dans: {results.save_dir}")
    
    # Afficher les métriques finales
    print("\n📊 Métriques finales:")
    print(f"   mAP50: {results.results_dict.get('metrics/mAP50(B)', 'N/A')}")
    print(f"   mAP50-95: {results.results_dict.get('metrics/mAP50-95(B)', 'N/A')}")
    
    return results

def evaluate_model(model_path, data_yaml, img_size=640, device='cuda'):
    """Évalue le modèle sur le dataset de test"""
    print("\n🔍 Évaluation du modèle...")
    
    model = YOLO(model_path)
    
    results = model.val(
        data=data_yaml,
        split='test',
        imgsz=img_size,
        device=device,
        plots=True,
        save_json=True,
        save_hybrid=True
    )
    
    print("\n📊 Résultats sur le test set:")
    print(f"   mAP50: {results.box.map50:.4f}")
    print(f"   mAP50-95: {results.box.map:.4f}")
    print(f"   Précision: {results.box.mp:.4f}")
    print(f"   Rappel: {results.box.mr:.4f}")
    
    return results

def export_model(model_path, formats=['onnx', 'torchscript']):
    """Exporte le modèle dans différents formats"""
    print(f"\n📤 Export du modèle...")
    
    model = YOLO(model_path)
    
    for fmt in formats:
        print(f"   Exportation en {fmt}...")
        try:
            model.export(format=fmt, imgsz=640, half=False)
            print(f"   ✅ {fmt} exporté")
        except Exception as e:
            print(f"   ❌ Erreur {fmt}: {e}")

def main():
    parser = argparse.ArgumentParser(
        description='Entraînement du détecteur de reines d\'abeilles'
    )
    
    # Arguments principaux
    parser.add_argument('--data', type=str, required=True,
                       help='Chemin vers le fichier data.yaml')
    parser.add_argument('--epochs', type=int, default=100,
                       help='Nombre d\'époques (défaut: 100)')
    parser.add_argument('--batch-size', type=int, default=16,
                       help='Taille du batch (défaut: 16)')
    parser.add_argument('--img-size', type=int, default=640,
                       help='Taille des images (défaut: 640)')
    parser.add_argument('--model-size', type=str, default='n',
                       choices=['n', 's', 'm', 'l', 'x'],
                       help='Taille du modèle (n=nano, s=small, m=medium, l=large, x=xlarge)')
    parser.add_argument('--device', type=str, default='cuda',
                       choices=['cuda', 'cpu', 'mps'],
                       help='Device d\'entraînement')
    parser.add_argument('--project', type=str, default='runs/train',
                       help='Dossier de sortie')
    
    # Options supplémentaires
    parser.add_argument('--evaluate', action='store_true',
                       help='Évaluer le modèle après entraînement')
    parser.add_argument('--export', action='store_true',
                       help='Exporter le modèle après entraînement')
    parser.add_argument('--resume', type=str, default=None,
                       help='Reprendre depuis un checkpoint')
    
    args = parser.parse_args()
    
    # Entraînement
    if args.resume:
        print(f"🔄 Reprise de l'entraînement depuis: {args.resume}")
        model = YOLO(args.resume)
        results = model.train(resume=True)
    else:
        results = train_model(
            data_yaml=args.data,
            epochs=args.epochs,
            batch_size=args.batch_size,
            img_size=args.img_size,
            model_size=args.model_size,
            device=args.device,
            project=args.project
        )
    
    # Chemin du meilleur modèle
    best_model = Path(results.save_dir) / 'weights' / 'best.pt'
    
    # Évaluation
    if args.evaluate:
        evaluate_model(str(best_model), args.data, args.img_size, args.device)
    
    # Export
    if args.export:
        export_model(str(best_model), formats=['onnx', 'torchscript'])
    
    print(f"\n🎉 Processus terminé!")
    print(f"📁 Meilleur modèle: {best_model}")

if __name__ == '__main__':
    main()
