#!/usr/bin/env python3
"""
Outil d'annotation pour créer le dataset de détection de reines d'abeilles
Usage: python annotation_tool.py /chemin/vers/images
"""

import cv2
import json
import os
import sys
from pathlib import Path

class BeeQueenAnnotator:
    """Interactive OpenCV tool for drawing queen-bee bounding boxes on images.

    Annotations live in a dict keyed by image file name and are persisted to
    ``annotations.json`` inside the image directory. Each entry holds a list
    of pixel-coordinate boxes plus the image size, and may carry a
    ``no_queen`` flag for images explicitly marked as having no visible queen.
    """

    # File extensions (compared case-insensitively) that are treated as images
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # A drawn rectangle must be wider AND taller than this (pixels) to be kept
    MIN_BOX_SIZE = 20
    WINDOW_NAME = 'Annotation Tool'
    # Delay given to cv2.waitKey per frame; a 1 ms delay made the loop spin hot
    FRAME_DELAY_MS = 20

    def __init__(self, image_dir):
        """Collect the images of *image_dir* and load previously saved annotations.

        Args:
            image_dir: Directory containing the .jpg/.jpeg/.png images.
        """
        self.image_dir = Path(image_dir)
        # Sorted so the browsing order is the same from one session to the next
        self.images = sorted(
            path for path in self.image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )
        self.current_index = 0
        self.annotations = {}
        self.annotation_file = self.image_dir / 'annotations.json'
        # Pixels of the image currently displayed (set by run())
        self.current_image = None

        # Resume from the existing annotations, if any
        if self.annotation_file.exists():
            with open(self.annotation_file, 'r', encoding='utf-8') as f:
                self.annotations = json.load(f)

        # Mouse-drawing state
        self.drawing = False
        self.start_point = None
        self.current_box = None

        print(f"Trouvé {len(self.images)} images à annoter")
        print("\nControles:")
        print("  - Cliquer et glisser pour dessiner un rectangle autour de la reine")
        print("  - 'n' : Image suivante")
        print("  - 'p' : Image précédente")
        print("  - 'd' : Supprimer la dernière annotation")
        print("  - 's' : Sauvegarder")
        print("  - 'q' : Quitter (sauvegarde automatique)")
        print("  - ESPACE : Marquer comme 'pas de reine visible'")

    def _clamp_to_image(self, x, y):
        """Clip a window coordinate to the pixel bounds of the current image."""
        if self.current_image is None:
            return x, y
        height, width = self.current_image.shape[:2]
        return min(max(x, 0), width - 1), min(max(y, 0), height - 1)

    def mouse_callback(self, event, x, y, flags, param):
        """OpenCV mouse handler: click-and-drag draws a box on the current image.

        ``flags`` and ``param`` are unused but required by the OpenCV callback
        signature. Coordinates are clipped to the image because OpenCV may
        report positions outside it (e.g. when dragging past the window edge).
        """
        x, y = self._clamp_to_image(x, y)

        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.start_point = (x, y)
            self.current_box = None

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.drawing:
                self.current_box = (self.start_point[0], self.start_point[1], x, y)

        elif event == cv2.EVENT_LBUTTONUP:
            # Only a release that ends a drag started in this window counts;
            # otherwise a stale start point would produce a bogus box.
            if self.drawing and self.start_point is not None:
                x1, x2 = sorted((self.start_point[0], x))
                y1, y2 = sorted((self.start_point[1], y))

                # Ignore accidental clicks / tiny rectangles
                if x2 - x1 > self.MIN_BOX_SIZE and y2 - y1 > self.MIN_BOX_SIZE:
                    self.add_annotation(x1, y1, x2, y2)

            self.drawing = False
            self.start_point = None
            self.current_box = None

    def add_annotation(self, x1, y1, x2, y2):
        """Append a queen box (pixel coordinates) to the current image's entry.

        Creates the entry if needed and clears any earlier 'no queen' mark,
        since the image is now known to contain a queen.
        """
        image_name = self.images[self.current_index].name

        entry = self.annotations.setdefault(image_name, {
            'boxes': [],
            'image_width': self.current_image.shape[1],
            'image_height': self.current_image.shape[0],
        })
        entry.pop('no_queen', None)

        entry['boxes'].append({
            'x1': int(x1),
            'y1': int(y1),
            'x2': int(x2),
            'y2': int(y2),
            'class': 'queen_bee',
        })

        print(f"Annotation ajoutée: ({x1}, {y1}) -> ({x2}, {y2})")

    def draw_annotations(self, image):
        """Return a copy of *image* showing saved boxes and the box being drawn.

        Saved boxes are drawn in green and the in-progress box in blue
        (OpenCV colors are BGR). *image* itself is not modified.
        """
        image_name = self.images[self.current_index].name
        display = image.copy()

        entry = self.annotations.get(image_name)
        if entry:
            for box in entry['boxes']:
                cv2.rectangle(display,
                              (box['x1'], box['y1']),
                              (box['x2'], box['y2']),
                              (0, 255, 0), 2)
                # Keep the label visible for boxes touching the top edge
                label_y = max(box['y1'] - 10, 15)
                cv2.putText(display, 'Queen',
                            (box['x1'], label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        if self.current_box is not None:
            x1, y1, x2, y2 = self.current_box
            cv2.rectangle(display, (x1, y1), (x2, y2), (255, 0, 0), 2)

        return display

    def save_annotations(self):
        """Write the annotations to annotations.json.

        The JSON is written to a temporary file first and then moved over the
        target, so an interrupted save cannot leave a truncated file behind.
        """
        tmp_file = self.annotation_file.with_name(self.annotation_file.name + '.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.annotations, f, indent=2)
        os.replace(tmp_file, self.annotation_file)
        print(f"\nAnnotations sauvegardées: {self.annotation_file}")

    @staticmethod
    def _box_to_yolo(box, img_w, img_h):
        """Convert a pixel box dict to normalized YOLO (x_center, y_center, width, height).

        Corners are reordered if needed and clipped to the image so that the
        normalized values stay within [0, 1].
        """
        x1, x2 = sorted((box['x1'], box['x2']))
        y1, y2 = sorted((box['y1'], box['y2']))
        x1, x2 = max(x1, 0), min(x2, img_w)
        y1, y2 = max(y1, 0), min(y2, img_h)
        return (
            (x1 + x2) / 2 / img_w,
            (y1 + y2) / 2 / img_h,
            (x2 - x1) / img_w,
            (y2 - y1) / img_h,
        )

    def export_to_yolo(self, output_dir):
        """Write one YOLO label file (``<image stem>.txt``) per annotated image.

        Each line is ``0 x_center y_center width height`` with values
        normalized to [0, 1]; class 0 is queen_bee. Entries without boxes
        (e.g. images marked 'no queen') produce an empty file.

        Args:
            output_dir: Destination directory, created (with parents) if missing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for image_name, data in self.annotations.items():
            img_w = data['image_width']
            img_h = data['image_height']

            lines = []
            for box in data['boxes']:
                x_center, y_center, width, height = self._box_to_yolo(box, img_w, img_h)
                if width <= 0 or height <= 0:
                    continue  # box lies entirely outside the image
                lines.append(f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

            label_file = output_dir / f"{Path(image_name).stem}.txt"
            label_file.write_text(''.join(lines), encoding='utf-8')

        print(f"Format YOLO exporté vers: {output_dir}")

    def run(self):
        """Main annotation loop: display images and handle mouse/keyboard until 'q'.

        Images are decoded only when the displayed index changes. Unreadable
        images are skipped in the current browsing direction, and the loop
        stops if no image can be read at all.
        """
        if not self.images:
            print("Aucune image trouvée!")
            return

        cv2.namedWindow(self.WINDOW_NAME)
        cv2.setMouseCallback(self.WINDOW_NAME, self.mouse_callback)

        n_images = len(self.images)
        loaded_index = None  # index whose pixels are held in self.current_image
        step = 1             # browsing direction, used to skip unreadable images
        unreadable = set()

        try:
            while True:
                # Decode from disk only when the image changes, not every frame
                if loaded_index != self.current_index:
                    image_path = self.images[self.current_index]
                    image = cv2.imread(str(image_path))

                    if image is None:
                        print(f"Erreur de chargement: {image_path}")
                        unreadable.add(self.current_index)
                        if len(unreadable) == n_images:
                            # Nothing can be shown; avoid spinning forever
                            print("Aucune image lisible!")
                            break
                        self.current_index = (self.current_index + step) % n_images
                        continue

                    self.current_image = image
                    loaded_index = self.current_index
                    # Discard a rectangle started on the previous image
                    self.drawing = False
                    self.start_point = None
                    self.current_box = None

                display = self.draw_annotations(self.current_image)

                # Status line
                image_name = self.images[self.current_index].name
                num_boxes = len(self.annotations.get(image_name, {}).get('boxes', []))
                info_text = f"Image {self.current_index + 1}/{n_images} | {image_name} | Annotations: {num_boxes}"
                cv2.putText(display, info_text, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)

                cv2.imshow(self.WINDOW_NAME, display)

                key = cv2.waitKey(self.FRAME_DELAY_MS) & 0xFF

                if key == ord('q'):
                    self.save_annotations()
                    break

                elif key == ord('n'):  # next image
                    step = 1
                    self.current_index = (self.current_index + 1) % n_images

                elif key == ord('p'):  # previous image
                    step = -1
                    self.current_index = (self.current_index - 1) % n_images

                elif key == ord('d'):  # delete the most recent box
                    entry = self.annotations.get(image_name)
                    if entry and entry['boxes']:
                        entry['boxes'].pop()
                        print("Dernière annotation supprimée")
                        # With no box and no 'no queen' mark the image is simply
                        # unannotated: drop the entry so the YOLO export does not
                        # turn it into an empty (negative) label file.
                        if not entry['boxes'] and not entry.get('no_queen'):
                            del self.annotations[image_name]

                elif key == ord('s'):  # save
                    self.save_annotations()

                elif key == ord(' '):  # mark as "no queen" (replaces any boxes)
                    self.annotations[image_name] = {
                        'boxes': [],
                        'no_queen': True,
                        'image_width': self.current_image.shape[1],
                        'image_height': self.current_image.shape[0],
                    }
                    print("Marqué comme 'pas de reine'")
                    step = 1
                    self.current_index = (self.current_index + 1) % n_images
        finally:
            cv2.destroyAllWindows()


if __name__ == '__main__':
    # Script entry point: annotate the directory given on the command line,
    # then optionally export the result as YOLO label files.
    if len(sys.argv) < 2:
        print("Usage: python annotation_tool.py /chemin/vers/images")
        sys.exit(1)

    image_dir = sys.argv[1]

    if not os.path.isdir(image_dir):
        print(f"Erreur: {image_dir} n'est pas un répertoire valide")
        sys.exit(1)

    annotator = BeeQueenAnnotator(image_dir)
    annotator.run()

    # Offer the YOLO export only when there is something to export
    # (e.g. not after "no image found").
    if annotator.annotations:
        # Strip answers so stray spaces (" o", "path ") are not misread
        export_choice = input("\nExporter au format YOLO ? (o/n): ").strip()
        if export_choice.lower() == 'o':
            output_dir = input("Dossier de sortie (défaut: ./yolo_labels): ").strip() or "./yolo_labels"
            annotator.export_to_yolo(output_dir)
