#!/usr/bin/env python3
"""
Script d'entraînement du modèle de détection de reine d'abeille
Utilise YOLOv8 de Ultralytics

Installation requise:
pip install ultralytics opencv-python pillow

Usage:
python train_model.py --data /chemin/vers/dataset --epochs 100
"""

import argparse
from pathlib import Path
import yaml
from ultralytics import YOLO
import torch

class QueenBeeTrainer:
    def __init__(self, data_path, model_size='n', epochs=100, batch_size=16, img_size=640):
        """
        Initialiser l'entraînement
        
        Args:
            data_path: Chemin vers le dataset
            model_size: Taille du modèle ('n', 's', 'm', 'l', 'x')
            epochs: Nombre d'époques
            batch_size: Taille des batches
            img_size: Taille des images
        """
        self.data_path = Path(data_path)
        self.model_size = model_size
        self.epochs = epochs
        self.batch_size = batch_size
        self.img_size = img_size
        
        # Vérifier CUDA
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device utilisé: {self.device}")
        
        # Créer la structure de dossiers
        self.setup_directories()
    
    def setup_directories(self):
        """Créer la structure de répertoires pour l'entraînement"""
        dirs = ['train/images', 'train/labels', 'val/images', 'val/labels', 'test/images', 'test/labels']
        
        for dir_path in dirs:
            (self.data_path / dir_path).mkdir(parents=True, exist_ok=True)
        
        print(f"Structure de répertoires créée dans: {self.data_path}")
    
    def create_yaml_config(self):
        """Créer le fichier de configuration YAML pour YOLO"""
        config = {
            'path': str(self.data_path.absolute()),
            'train': 'train/images',
            'val': 'val/images',
            'test': 'test/images',
            'nc': 1,  # Nombre de classes
            'names': ['queen_bee']  # Noms des classes
        }
        
        yaml_path = self.data_path / 'dataset.yaml'
        with open(yaml_path, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        
        print(f"Configuration YAML créée: {yaml_path}")
        return yaml_path
    
    def split_dataset(self, train_ratio=0.8, val_ratio=0.1):
        """
        Diviser le dataset en train/val/test
        
        Args:
            train_ratio: Ratio d'entraînement (0.8 = 80%)
            val_ratio: Ratio de validation (0.1 = 10%)
        """
        import random
        import shutil
        
        # Trouver toutes les images annotées
        images_dir = self.data_path / 'images'
        labels_dir = self.data_path / 'labels'
        
        if not images_dir.exists() or not labels_dir.exists():
            print("Erreur: Dossiers 'images' et 'labels' requis!")
            return False
        
        image_files = list(images_dir.glob('*.jpg')) + list(images_dir.glob('*.png'))
        
        # Mélanger
        random.shuffle(image_files)
        
        # Calculer les splits
        total = len(image_files)
        train_count = int(total * train_ratio)
        val_count = int(total * val_ratio)
        
        train_images = image_files[:train_count]
        val_images = image_files[train_count:train_count + val_count]
        test_images = image_files[train_count + val_count:]
        
        # Copier les fichiers
        for split_name, image_list in [('train', train_images), ('val', val_images), ('test', test_images)]:
            for img_path in image_list:
                # Image
                shutil.copy(img_path, self.data_path / split_name / 'images' / img_path.name)
                
                # Label correspondant
                label_path = labels_dir / f"{img_path.stem}.txt"
                if label_path.exists():
                    shutil.copy(label_path, self.data_path / split_name / 'labels' / label_path.name)
        
        print(f"\nDataset divisé:")
        print(f"  Train: {len(train_images)} images")
        print(f"  Val: {len(val_images)} images")
        print(f"  Test: {len(test_images)} images")
        
        return True
    
    def train(self):
        """Lancer l'entraînement du modèle"""
        
        # Créer le fichier YAML
        yaml_path = self.create_yaml_config()
        
        # Charger le modèle pré-entraîné
        model_name = f'yolov8{self.model_size}.pt'
        print(f"\nChargement du modèle: {model_name}")
        model = YOLO(model_name)
        
        # Paramètres d'entraînement
        print(f"\nDémarrage de l'entraînement...")
        print(f"  Époques: {self.epochs}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Image size: {self.img_size}")
        
        # Entraîner
        results = model.train(
            data=str(yaml_path),
            epochs=self.epochs,
            batch=self.batch_size,
            imgsz=self.img_size,
            device=self.device,
            project='queen_bee_detection',
            name='experiment',
            pretrained=True,
            optimizer='AdamW',
            lr0=0.001,
            patience=20,
            save=True,
            save_period=10,
            plots=True,
            verbose=True,
            # Data augmentation
            hsv_h=0.015,
            hsv_s=0.7,
            hsv_v=0.4,
            degrees=10.0,
            translate=0.1,
            scale=0.5,
            shear=0.0,
            perspective=0.0,
            flipud=0.0,
            fliplr=0.5,
            mosaic=1.0,
            mixup=0.0
        )
        
        print("\nEntraînement terminé!")
        print(f"Résultats sauvegardés dans: queen_bee_detection/experiment")
        
        return results
    
    def validate(self, model_path):
        """Valider le modèle entraîné"""
        model = YOLO(model_path)
        
        yaml_path = self.data_path / 'dataset.yaml'
        
        print("\nValidation du modèle...")
        results = model.val(
            data=str(yaml_path),
            split='test'
        )
        
        return results
    
    def export_to_tensorflowjs(self, model_path, output_dir='./tfjs_model'):
        """Exporter le modèle vers TensorFlow.js pour l'app web"""
        model = YOLO(model_path)
        
        print(f"\nExportation vers TensorFlow.js...")
        
        # Export en ONNX puis conversion
        model.export(format='onnx', simplify=True)
        
        print(f"Modèle exporté! Utilisez onnx2tf pour convertir en TF.js:")
        print(f"  pip install onnx-tf tensorflow")
        print(f"  onnx-tf convert -i model.onnx -o {output_dir}")
        
        return True
    
    def export_to_tflite(self, model_path, output_dir='./tflite_model'):
        """Exporter vers TFLite pour Android/iOS"""
        model = YOLO(model_path)
        
        print(f"\nExportation vers TFLite...")
        model.export(format='tflite', imgsz=640, int8=False)
        
        print(f"Modèle TFLite créé!")
        return True


def main():
    parser = argparse.ArgumentParser(description='Entraîner un modèle de détection de reine d\'abeille')
    parser.add_argument('--data', type=str, required=True, help='Chemin vers le dataset')
    parser.add_argument('--model', type=str, default='n', choices=['n', 's', 'm', 'l', 'x'],
                       help='Taille du modèle (n=nano, s=small, m=medium, l=large, x=extra large)')
    parser.add_argument('--epochs', type=int, default=100, help='Nombre d\'époques')
    parser.add_argument('--batch', type=int, default=16, help='Taille du batch')
    parser.add_argument('--img-size', type=int, default=640, help='Taille des images')
    parser.add_argument('--split', action='store_true', help='Diviser le dataset automatiquement')
    parser.add_argument('--validate', type=str, help='Valider un modèle existant')
    parser.add_argument('--export-tfjs', type=str, help='Exporter vers TensorFlow.js')
    parser.add_argument('--export-tflite', type=str, help='Exporter vers TFLite')
    
    args = parser.parse_args()
    
    trainer = QueenBeeTrainer(
        data_path=args.data,
        model_size=args.model,
        epochs=args.epochs,
        batch_size=args.batch,
        img_size=args.img_size
    )
    
    # Split du dataset si demandé
    if args.split:
        success = trainer.split_dataset()
        if not success:
            print("Erreur lors du split du dataset")
            return
    
    # Validation d'un modèle existant
    if args.validate:
        trainer.validate(args.validate)
        return
    
    # Export TensorFlow.js
    if args.export_tfjs:
        trainer.export_to_tensorflowjs(args.export_tfjs)
        return
    
    # Export TFLite
    if args.export_tflite:
        trainer.export_to_tflite(args.export_tflite)
        return
    
    # Entraînement
    results = trainer.train()
    
    print("\n" + "="*50)
    print("PROCHAINES ÉTAPES:")
    print("="*50)
    print("1. Valider le modèle:")
    print("   python train_model.py --data /chemin/dataset --validate queen_bee_detection/experiment/weights/best.pt")
    print("\n2. Exporter pour l'app web:")
    print("   python train_model.py --data /chemin/dataset --export-tfjs queen_bee_detection/experiment/weights/best.pt")
    print("\n3. Copier le modèle exporté dans app/models/")


if __name__ == '__main__':
    main()
