#!/usr/bin/env python3
"""
Script de test et démonstration du modèle de détection de reine
Peut traiter des images, vidéos ou flux webcam

Usage:
    python test_model.py --model best.pt --source image.jpg
    python test_model.py --model best.pt --source video.mp4
    python test_model.py --model best.pt --source 0  # webcam
"""

import argparse
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import time

class QueenDetector:
    """YOLO-based queen bee detector for images, videos and webcam streams."""

    def __init__(self, model_path, conf_threshold=0.5):
        """
        Initialize the detector.

        Args:
            model_path: Path to the YOLO model weights (.pt).
            conf_threshold: Minimum confidence for a detection to be kept.
        """
        self.model = YOLO(model_path)
        self.conf_threshold = conf_threshold
        # Per-class drawing colors. NOTE(review): OpenCV expects BGR, so
        # (255, 165, 0) renders blue-ish, not orange — confirm intent
        # (only 'queen_bee' is currently used by the drawing code).
        self.colors = {
            'queen_bee': (0, 255, 0),    # green for the queen
            'worker_bee': (255, 165, 0)  # intended orange for workers
        }

        print(f"Modèle chargé: {model_path}")
        print(f"Seuil de confiance: {conf_threshold}")

    def detect_image(self, image_path, save_path=None):
        """Run detection on a single image.

        Args:
            image_path: Path to the input image.
            save_path: Optional path where the annotated image is written.

        Returns:
            (annotated_image, results) on success, or None if the image
            could not be loaded.
        """
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"Erreur: Impossible de charger {image_path}")
            return None

        # Inference
        results = self.model.predict(
            img,
            conf=self.conf_threshold,
            verbose=False
        )

        # Render the detections
        annotated = self.draw_detections(img, results[0])

        # Save if requested
        if save_path:
            cv2.imwrite(str(save_path), annotated)
            print(f"Résultat sauvegardé: {save_path}")

        return annotated, results[0]

    def detect_video(self, video_source, output_path=None, show=True):
        """
        Run detection with tracking on a video file or webcam stream.

        Args:
            video_source: Video path (str or Path) or webcam index (0, 1, ...).
            output_path: Optional path for the annotated output video.
            show: Display annotated frames in a window ('q' quits,
                's' saves a screenshot).
        """
        # Open the source. str() first: the original called .isdigit()
        # directly, which raised AttributeError for pathlib.Path inputs
        # (and main() passes a Path).
        if isinstance(video_source, int) or str(video_source).isdigit():
            cap = cv2.VideoCapture(int(video_source))
            is_webcam = True
        else:
            cap = cv2.VideoCapture(str(video_source))
            is_webcam = False

        if not cap.isOpened():
            print(f"Erreur: Impossible d'ouvrir {video_source}")
            return

        # Video properties; some backends report 0 FPS, fall back to 30.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Writer for the optional annotated output video.
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))

        print(f"\nTraitement de: {video_source}")
        print(f"Résolution: {width}x{height} @ {fps} FPS")
        print("\nAppuyez sur 'q' pour quitter")

        frame_count = 0
        queens_detected = 0
        start_time = time.time()

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    # A failed read ends the loop for both sources.
                    print("Erreur de lecture webcam" if is_webcam
                          else "Fin de la vidéo")
                    break

                frame_count += 1

                # Detection with the model's built-in multi-frame tracker.
                results = self.model.track(
                    frame,
                    conf=self.conf_threshold,
                    persist=True,  # keep track IDs across frames
                    verbose=False
                )

                annotated = self.draw_detections_with_tracking(frame, results[0])

                # Per-frame stats
                num_queens = len(results[0].boxes) if results[0].boxes is not None else 0
                if num_queens > 0:
                    queens_detected += 1

                self._overlay_stats(annotated, frame_count, queens_detected,
                                    num_queens, start_time)

                if writer:
                    writer.write(annotated)

                if show:
                    cv2.imshow('Queen Bee Detection', annotated)

                    key = cv2.waitKey(1) & 0xFF
                    if key == ord('q'):
                        break
                    elif key == ord('s'):  # screenshot of the current frame
                        screenshot_path = f"screenshot_{frame_count}.jpg"
                        cv2.imwrite(screenshot_path, annotated)
                        print(f"Screenshot sauvegardé: {screenshot_path}")
        finally:
            # Release resources even if processing raised.
            cap.release()
            if writer:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        # Final stats. Guard the ratios: with zero frames the original
        # raised ZeroDivisionError.
        elapsed = time.time() - start_time
        print(f"\n{'='*50}")
        print(f"Traitement terminé:")
        print(f"  Frames traitées: {frame_count}")
        print(f"  Frames avec reine: {queens_detected}")
        if frame_count > 0 and elapsed > 0:
            print(f"  Taux de détection: {queens_detected/frame_count*100:.1f}%")
            print(f"  FPS moyen: {frame_count/elapsed:.1f}")
        if output_path:
            print(f"  Vidéo sauvegardée: {output_path}")

    def _overlay_stats(self, frame, frame_count, queens_detected, num_queens,
                       start_time):
        """Draw running stats (frame count, FPS, detection rate) onto *frame*."""
        elapsed = time.time() - start_time
        current_fps = frame_count / elapsed if elapsed > 0 else 0

        stats_text = [
            f"Frame: {frame_count}",
            f"FPS: {current_fps:.1f}",
            f"Reines detectees: {num_queens}",
            f"Detection rate: {queens_detected/frame_count*100:.1f}%"
        ]

        y_offset = 30
        for text in stats_text:
            # White pass then thin dark pass for readability on any background.
            cv2.putText(frame, text, (10, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            cv2.putText(frame, text, (10, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)
            y_offset += 25

    def draw_detections(self, img, results):
        """Draw boxes, confidence labels and center crosses on a copy of *img*."""
        annotated = img.copy()

        if results.boxes is None or len(results.boxes) == 0:
            return annotated

        for box in results.boxes:
            # Box coordinates and confidence
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])

            color = self.colors.get('queen_bee', (0, 255, 0))

            # Bounding box
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 3)

            # Confidence label on a filled background
            label = f"Queen {conf:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(annotated, (x1, y1 - label_h - 10), (x1 + label_w, y1), color, -1)
            cv2.putText(annotated, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)

            # Cross marker at the box center
            center_x = (x1 + x2) // 2
            center_y = (y1 + y2) // 2
            cross_size = 20
            cv2.line(annotated, (center_x - cross_size, center_y),
                     (center_x + cross_size, center_y), color, 2)
            cv2.line(annotated, (center_x, center_y - cross_size),
                     (center_x, center_y + cross_size), color, 2)

            # Queen marker. OpenCV's Hershey fonts are ASCII-only, so the
            # original '♛' glyph was drawn as '?'; use 'Q' instead.
            cv2.putText(annotated, "Q", (x1 + 5, y2 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 215, 0), 2)

        return annotated

    def draw_detections_with_tracking(self, img, results):
        """Draw detections with their tracking IDs and a stable per-ID color."""
        annotated = img.copy()

        if results.boxes is None or len(results.boxes) == 0:
            return annotated

        for box in results.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])

            # Tracking ID if the tracker produced one
            track_id = int(box.id[0]) if box.id is not None else -1

            # Deterministic per-ID color. A local generator avoids the
            # original np.random.seed() call, which clobbered the global
            # NumPy RNG state on every box.
            if track_id >= 0:
                rng = np.random.default_rng(track_id)
                color = tuple(int(c) for c in rng.integers(100, 255, 3))
            else:
                color = (0, 255, 0)

            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 3)

            # Label with tracking ID when available
            label = f"Queen #{track_id} {conf:.2f}" if track_id >= 0 else f"Queen {conf:.2f}"

            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(annotated, (x1, y1 - label_h - 10), (x1 + label_w, y1), color, -1)
            cv2.putText(annotated, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Crosshair at the box center
            center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
            cv2.drawMarker(annotated, (center_x, center_y), color,
                           cv2.MARKER_CROSS, 30, 2)

        return annotated

    def batch_process(self, input_dir, output_dir, pattern='*.jpg'):
        """Run detection on every image in *input_dir* matching *pattern*.

        Annotated copies are written to *output_dir* as 'detected_<name>'.
        """
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        # parents=True so a nested output path does not raise FileNotFoundError
        output_path.mkdir(parents=True, exist_ok=True)

        # sorted() for a deterministic processing order (glob order is not)
        images = sorted(input_path.glob(pattern))
        print(f"Traitement de {len(images)} images...")

        for i, img_path in enumerate(images, 1):
            print(f"[{i}/{len(images)}] {img_path.name}")

            save_path = output_path / f"detected_{img_path.name}"
            self.detect_image(img_path, save_path)

        print(f"\nTerminé! Résultats dans: {output_path}")


def main():
    """CLI entry point: dispatch to image, video, batch or webcam mode."""
    parser = argparse.ArgumentParser(description='Test du modèle de détection de reine')
    parser.add_argument('--model', type=str, required=True, help='Chemin vers le modèle (.pt)')
    parser.add_argument('--source', type=str, required=True,
                       help='Source: image, vidéo, webcam (0), ou dossier')
    parser.add_argument('--conf', type=float, default=0.5, help='Seuil de confiance')
    parser.add_argument('--output', type=str, help='Chemin de sortie')
    parser.add_argument('--no-show', action='store_true', help='Ne pas afficher (pour batch)')

    args = parser.parse_args()

    detector = QueenDetector(args.model, args.conf)

    source_path = Path(args.source)

    # Dispatch on the source type
    if source_path.is_file():
        suffix = source_path.suffix.lower()
        if suffix in ('.jpg', '.jpeg', '.png', '.bmp'):
            print("Mode: Image")
            detector.detect_image(source_path, args.output)
        elif suffix in ('.mp4', '.avi', '.mov', '.mkv'):
            print("Mode: Vidéo")
            detector.detect_video(source_path, args.output, show=not args.no_show)
        else:
            # Previously an unrecognized extension fell through silently
            # and the script did nothing; report it explicitly.
            print(f"Erreur: extension non supportée: {suffix}")

    elif source_path.is_dir():
        # Batch over a directory of images
        print("Mode: Batch")
        output = args.output or source_path / 'detected'
        detector.batch_process(source_path, output)

    else:
        # Not an existing path: treat as a webcam index / device string
        print("Mode: Webcam")
        detector.detect_video(args.source, args.output, show=not args.no_show)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
