#!/usr/bin/env python3
"""
Conversion du modèle YOLOv8 vers TensorFlow Lite pour déploiement mobile
"""

import argparse
import tensorflow as tf
import onnx
from onnx_tf.backend import prepare
import numpy as np
from pathlib import Path
import shutil

class ModelConverter:
    def __init__(self, model_path, output_dir='converted_models'):
        self.model_path = Path(model_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
    def yolo_to_onnx(self):
        """Convertit YOLO vers ONNX"""
        print("🔄 Conversion YOLO -> ONNX...")
        
        from ultralytics import YOLO
        
        model = YOLO(str(self.model_path))
        
        # Export ONNX
        onnx_path = model.export(
            format='onnx',
            imgsz=640,
            dynamic=False,
            simplify=True,
            opset=12
        )
        
        print(f"✅ ONNX sauvegardé: {onnx_path}")
        return Path(onnx_path)
        
    def onnx_to_tflite(self, onnx_path, quantize=True):
        """Convertit ONNX vers TensorFlow Lite"""
        print("🔄 Conversion ONNX -> TFLite...")
        
        # Charger le modèle ONNX
        onnx_model = onnx.load(str(onnx_path))
        
        # Convertir en TensorFlow
        print("   Conversion en TensorFlow...")
        tf_rep = prepare(onnx_model)
        
        # Sauvegarder le modèle TensorFlow
        tf_model_dir = self.output_dir / 'tf_model'
        tf_rep.export_graph(str(tf_model_dir))
        
        # Convertir en TFLite
        print("   Conversion en TFLite...")
        converter = tf.lite.TFLiteConverter.from_saved_model(str(tf_model_dir))
        
        # Optimisations
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        
        if quantize:
            print("   Application de la quantification INT8...")
            
            # Quantification complète INT8
            def representative_dataset():
                """Génère un dataset représentatif pour la quantification"""
                for _ in range(100):
                    # Images aléatoires représentatives
                    data = np.random.rand(1, 640, 640, 3).astype(np.float32)
                    yield [data]
            
            converter.representative_dataset = representative_dataset
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]
            converter.inference_input_type = tf.uint8
            converter.inference_output_type = tf.uint8
        
        # Convertir
        tflite_model = converter.convert()
        
        # Sauvegarder
        suffix = '_quant' if quantize else ''
        tflite_path = self.output_dir / f'queen_detector{suffix}.tflite'
        
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)
        
        # Afficher la taille
        size_mb = tflite_path.stat().st_size / (1024 * 1024)
        print(f"✅ TFLite sauvegardé: {tflite_path} ({size_mb:.2f} MB)")
        
        # Nettoyer
        shutil.rmtree(tf_model_dir)
        
        return tflite_path
        
    def create_metadata(self, tflite_path):
        """Crée le fichier de métadonnées pour le modèle"""
        metadata = {
            "name": "Queen Bee Detector",
            "description": "Détecteur de reines d'abeilles basé sur YOLOv8",
            "version": "1.0.0",
            "author": "Team AURA - Jerome",
            "input": {
                "shape": [1, 640, 640, 3],
                "type": "uint8",
                "mean": 127.5,
                "std": 127.5
            },
            "output": {
                "num_classes": 2,
                "classes": ["queen", "worker"],
                "confidence_threshold": 0.5,
                "iou_threshold": 0.45
            }
        }
        
        import json
        metadata_path = tflite_path.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
            
        print(f"✅ Métadonnées créées: {metadata_path}")
        return metadata_path
        
    def test_tflite_model(self, tflite_path, test_image=None):
        """Teste le modèle TFLite"""
        print("\n🧪 Test du modèle TFLite...")
        
        # Charger l'interpréteur
        interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
        interpreter.allocate_tensors()
        
        # Détails du modèle
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        
        print("\n📋 Détails du modèle:")
        print(f"   Input shape: {input_details[0]['shape']}")
        print(f"   Input type: {input_details[0]['dtype']}")
        print(f"   Output shape: {output_details[0]['shape']}")
        print(f"   Output type: {output_details[0]['dtype']}")
        
        # Test avec image aléatoire si pas d'image fournie
        if test_image is None:
            print("\n   Test avec image aléatoire...")
            test_data = np.random.randint(0, 255, size=(1, 640, 640, 3), dtype=np.uint8)
        else:
            import cv2
            img = cv2.imread(test_image)
            img = cv2.resize(img, (640, 640))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            test_data = np.expand_dims(img, axis=0).astype(np.uint8)
        
        # Inférence
        interpreter.set_tensor(input_details[0]['index'], test_data)
        
        import time
        start = time.time()
        interpreter.invoke()
        inference_time = (time.time() - start) * 1000
        
        output_data = interpreter.get_tensor(output_details[0]['index'])
        
        print(f"   ⏱️  Temps d'inférence: {inference_time:.2f} ms")
        print(f"   ✅ Test réussi!")
        
        return True
        
    def convert_all(self, quantize=True, test=True):
        """Pipeline complet de conversion"""
        print("🚀 Début de la conversion complète")
        print("=" * 60)
        
        # 1. YOLO -> ONNX
        onnx_path = self.yolo_to_onnx()
        
        # 2. ONNX -> TFLite (quantifié)
        if quantize:
            tflite_path_quant = self.onnx_to_tflite(onnx_path, quantize=True)
            self.create_metadata(tflite_path_quant)
            
            if test:
                self.test_tflite_model(tflite_path_quant)
        
        # 3. ONNX -> TFLite (non quantifié, pour comparaison)
        tflite_path = self.onnx_to_tflite(onnx_path, quantize=False)
        self.create_metadata(tflite_path)
        
        if test:
            self.test_tflite_model(tflite_path)
        
        print("\n" + "=" * 60)
        print("✅ Conversion terminée avec succès!")
        print(f"📁 Modèles sauvegardés dans: {self.output_dir}")
        
        # Afficher les tailles
        print("\n📊 Tailles des modèles:")
        for model_file in self.output_dir.glob('*.tflite'):
            size_mb = model_file.stat().st_size / (1024 * 1024)
            print(f"   {model_file.name}: {size_mb:.2f} MB")

def main():
    parser = argparse.ArgumentParser(
        description='Conversion du modèle YOLO vers TensorFlow Lite'
    )
    
    parser.add_argument('--model', type=str, required=True,
                       help='Chemin vers le modèle YOLO (.pt)')
    parser.add_argument('--output', type=str, default='converted_models',
                       help='Dossier de sortie')
    parser.add_argument('--no-quantize', action='store_true',
                       help='Désactiver la quantification')
    parser.add_argument('--no-test', action='store_true',
                       help='Ne pas tester le modèle après conversion')
    parser.add_argument('--test-image', type=str, default=None,
                       help='Image de test (optionnel)')
    
    args = parser.parse_args()
    
    converter = ModelConverter(args.model, args.output)
    converter.convert_all(
        quantize=not args.no_quantize,
        test=not args.no_test
    )

if __name__ == '__main__':
    main()
