#!/usr/bin/env python3
"""
Script de test automatisé du modèle de détection
Permet de vérifier le fonctionnement et les performances
"""

import cv2
import numpy as np
from pathlib import Path
import argparse
import json
from ultralytics import YOLO
import time
from tqdm import tqdm

class ModelTester:
    """Runs a YOLO detection model over a set of test images and reports
    inference-time, confidence and per-class statistics.

    Typical usage: ``load_model()`` then ``test_directory()``.
    """

    # Accepted image extensions for test_directory (matched case-insensitively).
    IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}

    def __init__(self, model_path, test_images_dir):
        """
        Args:
            model_path: Path to the YOLO weights file (.pt).
            test_images_dir: Directory containing the test images.
        """
        self.model_path = Path(model_path)
        self.test_images_dir = Path(test_images_dir)
        self.model = None  # set by load_model()
        # Aggregated statistics, filled in by test_directory().
        self.results = {
            'total_images': 0,
            'total_detections': 0,
            'avg_inference_time': 0,
            'avg_confidence': 0,
            'images_with_detections': 0,
            'per_image_results': []
        }

    def load_model(self):
        """Load the YOLO model from ``self.model_path``."""
        print(f"📦 Chargement du modèle: {self.model_path}")
        self.model = YOLO(str(self.model_path))
        print("✅ Modèle chargé")

    def test_single_image(self, image_path, visualize=False, save_path=None):
        """Run inference on a single image.

        Args:
            image_path: Path to the image file.
            visualize: If True, display the annotated image (blocks on keypress).
            save_path: If given, write the annotated image to this path.

        Returns:
            A dict with the image name, inference time (ms) and the list of
            detections, or None if the image could not be read.
        """
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"❌ Impossible de charger: {image_path}")
            return None

        # Inference (first result only: single-image input)
        start_time = time.time()
        results = self.model(img, verbose=False)[0]
        inference_time = (time.time() - start_time) * 1000  # ms

        # Extract detections as plain Python types (JSON-serialisable)
        detections = []
        for box in results.boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            conf = float(box.conf[0])
            cls = int(box.cls[0])

            detections.append({
                'bbox': [float(x1), float(y1), float(x2), float(y2)],
                'confidence': conf,
                'class': cls,
                'class_name': self.model.names[cls]
            })

        result = {
            'image': image_path.name,
            'inference_time_ms': inference_time,
            'num_detections': len(detections),
            'detections': detections
        }

        # Optional visualization / saving of the annotated image
        if visualize or save_path:
            vis_img = self.visualize_detections(img, detections)

            if visualize:
                cv2.imshow('Détections', vis_img)
                cv2.waitKey(0)

            if save_path:
                cv2.imwrite(str(save_path), vis_img)

        return result

    def visualize_detections(self, img, detections):
        """Return a copy of ``img`` with bounding boxes and labels drawn.

        Args:
            img: BGR image (as returned by ``cv2.imread``).
            detections: List of detection dicts from ``test_single_image``.
        """
        vis_img = img.copy()

        for det in detections:
            x1, y1, x2, y2 = [int(v) for v in det['bbox']]
            conf = det['confidence']
            class_name = det['class_name']

            # Class-dependent color (BGR): green for 'queen', blue otherwise
            color = (0, 255, 0) if class_name == 'queen' else (255, 0, 0)

            # Bounding box
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)

            # Filled label background plus white text above the box
            label = f"{class_name}: {conf:.2f}"
            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(vis_img, (x1, y1-h-10), (x1+w, y1), color, -1)
            cv2.putText(vis_img, label, (x1, y1-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return vis_img

    def test_directory(self, visualize=False, save_results=True):
        """Run the model on every image in ``self.test_images_dir``.

        Aggregates statistics into ``self.results``, prints a report, and
        optionally saves annotated images plus a JSON summary under
        ``test_results/``.

        Args:
            visualize: Forwarded to ``test_single_image`` for each image.
            save_results: If True, save annotated images and the JSON report.
        """
        print(f"\n🧪 Test sur le dossier: {self.test_images_dir}")

        # Collect images case-insensitively (.JPG, .Png, ...) and sort for a
        # deterministic processing order.
        image_files = sorted(
            p for p in self.test_images_dir.glob('*')
            if p.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if not image_files:
            print("❌ Aucune image trouvée")
            return

        print(f"📊 {len(image_files)} images à tester")

        # Output directory for annotated images and the JSON report
        output_dir = Path('test_results')
        output_dir.mkdir(exist_ok=True)

        # Per-image accumulators
        inference_times = []
        confidences = []

        for img_path in tqdm(image_files, desc="Test en cours"):
            save_path = output_dir / f"result_{img_path.name}" if save_results else None

            result = self.test_single_image(img_path, visualize=visualize, save_path=save_path)

            if result:
                self.results['per_image_results'].append(result)
                inference_times.append(result['inference_time_ms'])

                if result['num_detections'] > 0:
                    self.results['images_with_detections'] += 1
                    self.results['total_detections'] += result['num_detections']

                    for det in result['detections']:
                        confidences.append(det['confidence'])

        # Aggregate statistics. Cast numpy scalars to plain floats so the
        # dict stays JSON-serialisable (json.dump rejects np.float64).
        self.results['total_images'] = len(image_files)

        if inference_times:
            self.results['avg_inference_time'] = float(np.mean(inference_times))
            self.results['min_inference_time'] = float(np.min(inference_times))
            self.results['max_inference_time'] = float(np.max(inference_times))

        if confidences:
            self.results['avg_confidence'] = float(np.mean(confidences))
            self.results['min_confidence'] = float(np.min(confidences))
            self.results['max_confidence'] = float(np.max(confidences))

        # Print the summary report
        self.print_report()

        # Persist the full results as JSON
        if save_results:
            with open(output_dir / 'test_results.json', 'w') as f:
                json.dump(self.results, f, indent=2)
            print(f"\n💾 Résultats sauvegardés dans: {output_dir}")

    def print_report(self):
        """Print a human-readable summary of ``self.results``.

        Safe to call even when no image was processed or no detection was
        made (guards against division by zero and missing min/max keys).
        """
        print("\n" + "="*60)
        print("📋 RAPPORT DE TEST")
        print("="*60)

        print(f"\n📊 Statistiques globales:")
        print(f"   Images testées: {self.results['total_images']}")
        print(f"   Images avec détections: {self.results['images_with_detections']} "
              f"({100*self.results['images_with_detections']/max(1,self.results['total_images']):.1f}%)")
        print(f"   Détections totales: {self.results['total_detections']}")

        # images_with_detections > 0 is implied by total_detections > 0
        if self.results['total_detections'] > 0:
            avg_per_image = self.results['total_detections'] / self.results['images_with_detections']
            print(f"   Détections par image (avg): {avg_per_image:.2f}")

        avg_ms = self.results.get('avg_inference_time', 0)
        print(f"\n⏱️  Performance:")
        print(f"   Temps d'inférence moyen: {avg_ms:.2f} ms")
        print(f"   Min: {self.results.get('min_inference_time', 0):.2f} ms")
        print(f"   Max: {self.results.get('max_inference_time', 0):.2f} ms")

        # 'avg_inference_time' is initialised to 0 in __init__, so a key-presence
        # check is always true; guard on the value to avoid division by zero
        # when no image was processed.
        if avg_ms > 0:
            print(f"   FPS théorique: {1000 / avg_ms:.1f}")

        print(f"\n🎯 Confiance:")
        # 'min_confidence'/'max_confidence' only exist after at least one
        # detection, so gate on the detection count rather than on the
        # always-present 'avg_confidence' key.
        if self.results['total_detections'] > 0:
            print(f"   Confiance moyenne: {self.results['avg_confidence']:.3f}")
            print(f"   Min: {self.results['min_confidence']:.3f}")
            print(f"   Max: {self.results['max_confidence']:.3f}")
        else:
            print("   Aucune détection")

        # Per-class detection counts across all images
        class_counts = {}
        for img_result in self.results['per_image_results']:
            for det in img_result['detections']:
                class_name = det['class_name']
                class_counts[class_name] = class_counts.get(class_name, 0) + 1

        if class_counts:
            print(f"\n📈 Distribution par classe:")
            for class_name, count in sorted(class_counts.items(), key=lambda x: -x[1]):
                pct = 100 * count / self.results['total_detections']
                print(f"   {class_name}: {count} ({pct:.1f}%)")

        # Qualitative evaluation against fixed thresholds
        print(f"\n💡 Évaluation:")

        if self.results.get('avg_inference_time', 999) < 100:
            print("   ✅ Performance excellente (<100ms)")
        elif self.results.get('avg_inference_time', 999) < 150:
            print("   ✅ Performance bonne (<150ms)")
        else:
            print("   ⚠️  Performance à améliorer (>150ms)")

        detection_rate = self.results['images_with_detections'] / max(1, self.results['total_images'])
        if detection_rate > 0.7:
            print("   ✅ Taux de détection élevé (>70%)")
        elif detection_rate > 0.5:
            print("   ⚠️  Taux de détection moyen (50-70%)")
        else:
            print("   ❌ Taux de détection faible (<50%)")

        if self.results.get('avg_confidence', 0) > 0.7:
            print("   ✅ Confiance élevée (>0.7)")
        elif self.results.get('avg_confidence', 0) > 0.5:
            print("   ⚠️  Confiance moyenne (0.5-0.7)")
        else:
            print("   ❌ Confiance faible (<0.5)")

def main():
    """Parse command-line options and run the detection-model test suite."""
    arg_parser = argparse.ArgumentParser(
        description='Test automatisé du modèle de détection'
    )
    arg_parser.add_argument('--model', type=str, required=True,
                            help='Chemin vers le modèle (.pt)')
    arg_parser.add_argument('--images', type=str, required=True,
                            help='Dossier contenant les images de test')
    arg_parser.add_argument('--visualize', action='store_true',
                            help='Afficher les résultats image par image')
    arg_parser.add_argument('--no-save', action='store_true',
                            help='Ne pas sauvegarder les résultats')
    options = arg_parser.parse_args()

    # Load the model, then run it over the whole image directory.
    tester = ModelTester(options.model, options.images)
    tester.load_model()
    tester.test_directory(visualize=options.visualize,
                          save_results=not options.no_save)

    # Close any OpenCV windows opened during interactive visualization.
    if options.visualize:
        cv2.destroyAllWindows()

# Standard entry-point guard: run main() only when executed as a script.
if __name__ == '__main__':
    main()
