#!/usr/bin/env python3
"""
Outil d'annotation interactif pour créer le dataset de reines d'abeilles.
Permet d'annoter rapidement des images avec des boîtes englobantes.
"""

import cv2
import os
import json
import numpy as np
from pathlib import Path
import argparse

class BeeAnnotationTool:
    """Interactive OpenCV tool for drawing bounding boxes on a folder of images.

    Boxes are held in memory per image file name as ``[x1, y1, x2, y2,
    class_id]`` lists in pixel coordinates (x1 <= x2, y1 <= y2). Each save
    exports them as COCO JSON (``annotations.json``), one YOLO label file per
    annotated image (``labels/<stem>.txt``) and a YOLO ``data.yaml``.
    """

    # Extensions (compared case-insensitively) of the images to annotate.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Boxes narrower or shorter than this many pixels are treated as stray
    # clicks and discarded.
    MIN_BOX_SIZE = 2
    # Name of the OpenCV window used for display and mouse events.
    WINDOW_NAME = 'Annotation'
    # Delay passed to cv2.waitKey: long enough to avoid a busy loop, short
    # enough for the box preview to follow the mouse smoothly.
    FRAME_DELAY_MS = 20

    def __init__(self, image_dir, output_dir):
        """Scan *image_dir* for images and load annotations saved in *output_dir*.

        Args:
            image_dir: Directory containing the images to annotate.
            output_dir: Directory receiving the COCO/YOLO exports (created if
                missing).
        """
        self.image_dir = Path(image_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Class id -> name
        self.classes = {
            0: 'queen',
            1: 'worker'  # optional
        }
        self.current_class = 0

        # Annotation state. Sorted so navigation order and COCO image ids are
        # stable between runs (directory listing order is filesystem-dependent).
        self.images = self._find_images(self.image_dir)
        self.current_idx = 0
        # Image file name -> list of [x1, y1, x2, y2, class_id]
        self.annotations = {}
        # Image file name -> (width, height); avoids decoding every image on
        # every save.
        self._image_sizes = {}

        # Drawing state
        self.drawing = False
        self.start_point = None
        self.current_end = None
        self.current_boxes = []
        # (height, width) of the image currently displayed, used to clamp boxes.
        self._current_shape = None

        # Resume from a previous session, if any.
        self.load_annotations()

    @classmethod
    def _find_images(cls, image_dir):
        """Return the sorted image files directly inside *image_dir* ([] if it is not a directory)."""
        if not image_dir.is_dir():
            return []
        return sorted(
            path for path in image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    def load_annotations(self):
        """Load annotations saved by a previous session from ``annotations.json``.

        save_annotations writes that file in COCO format, so it is converted
        back to the per-image ``[x1, y1, x2, y2, class_id]`` lists used
        internally (COCO ``bbox`` is ``[x, y, width, height]``). A plain
        ``{file_name: boxes}`` mapping is accepted unchanged. The image sizes
        recorded in the COCO file seed the size cache.
        """
        anno_file = self.output_dir / 'annotations.json'
        if not anno_file.exists():
            return
        with open(anno_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        is_coco = isinstance(data, dict) and 'images' in data and 'annotations' in data
        if not is_coco:
            # Legacy layout: already a file_name -> boxes mapping.
            self.annotations = data
            return

        id_to_name = {img['id']: img['file_name'] for img in data['images']}
        # Images saved with zero boxes are kept, so they stay marked as reviewed.
        annotations = {name: [] for name in id_to_name.values()}
        for img in data['images']:
            self._image_sizes[img['file_name']] = (img['width'], img['height'])

        for ann in data['annotations']:
            name = id_to_name.get(ann['image_id'])
            if name is None:
                continue  # dangling annotation: no matching image entry
            x, y, bbox_w, bbox_h = ann['bbox']
            annotations[name].append([x, y, x + bbox_w, y + bbox_h, ann['category_id']])

        self.annotations = annotations

    def _image_size(self, img_path):
        """Return ``(width, height)`` of *img_path*, or None if it cannot be read.

        Results are cached by file name so saving does not re-decode images.
        """
        name = img_path.name
        if name not in self._image_sizes:
            img = cv2.imread(str(img_path))
            if img is None:
                return None
            h, w = img.shape[:2]
            self._image_sizes[name] = (w, h)
        return self._image_sizes[name]

    def save_annotations(self):
        """Export every annotated image as COCO JSON, YOLO labels and data.yaml.

        All outputs are rewritten from ``self.annotations`` on each call, so
        repeated saves are idempotent. Images whose size cannot be determined
        are skipped with a warning instead of aborting the session.
        """
        # COCO JSON format
        coco_format = {
            'images': [],
            'annotations': [],
            'categories': [{'id': k, 'name': v} for k, v in self.classes.items()]
        }
        labels_dir = self.output_dir / 'labels'
        labels_dir.mkdir(parents=True, exist_ok=True)

        annotation_id = 0
        for img_idx, img_path in enumerate(self.images):
            img_name = img_path.name
            boxes = self.annotations.get(img_name)
            if boxes is None:
                continue  # never annotated

            size = self._image_size(img_path)
            if size is None:
                print(f"⚠️  Impossible de charger {img_path}")
                continue
            w, h = size

            # Image info
            coco_format['images'].append({
                'id': img_idx,
                'file_name': img_name,
                'width': w,
                'height': h
            })

            yolo_lines = []
            for x1, y1, x2, y2, class_id in boxes:
                bbox_w = x2 - x1
                bbox_h = y2 - y1

                coco_format['annotations'].append({
                    'id': annotation_id,
                    'image_id': img_idx,
                    'category_id': class_id,
                    'bbox': [x1, y1, bbox_w, bbox_h],
                    'area': bbox_w * bbox_h,
                    'iscrowd': 0
                })
                annotation_id += 1

                # YOLO: class_id x_center y_center width height (normalised by image size)
                x_center = ((x1 + x2) / 2) / w
                y_center = ((y1 + y2) / 2) / h
                yolo_lines.append(
                    f"{class_id} {x_center:.6f} {y_center:.6f} "
                    f"{bbox_w / w:.6f} {bbox_h / h:.6f}\n"
                )

            # 'w', not 'a': the label file must hold exactly the current boxes.
            # An empty file marks an image reviewed as containing no objects.
            with open(labels_dir / f"{img_path.stem}.txt", 'w', encoding='utf-8') as f:
                f.writelines(yolo_lines)

        # Save COCO JSON
        with open(self.output_dir / 'annotations.json', 'w', encoding='utf-8') as f:
            json.dump(coco_format, f, indent=2)

        # YOLO dataset configuration
        with open(self.output_dir / 'data.yaml', 'w', encoding='utf-8') as f:
            f.write(f"path: {self.output_dir.absolute()}\n")
            f.write("train: images\n")
            f.write("val: images\n")
            f.write(f"nc: {len(self.classes)}\n")
            f.write(f"names: {list(self.classes.values())}\n")

        print(f"✅ Annotations sauvegardées : {len(coco_format['annotations'])} boîtes")

    def _make_box(self, p1, p2):
        """Return ``[x1, y1, x2, y2]`` from two corners, ordered and clamped, or None.

        Coordinates are clamped to the displayed image because drags that
        leave the window may report out-of-image positions. Boxes smaller than
        MIN_BOX_SIZE (e.g. a plain click) yield None.
        """
        (xa, ya), (xb, yb) = p1, p2
        x1, x2 = min(xa, xb), max(xa, xb)
        y1, y2 = min(ya, yb), max(ya, yb)
        if self._current_shape is not None:
            h, w = self._current_shape
            x1, x2 = max(0, x1), min(w, x2)
            y1, y2 = max(0, y1), min(h, y2)
        if x2 - x1 < self.MIN_BOX_SIZE or y2 - y1 < self.MIN_BOX_SIZE:
            return None
        return [x1, y1, x2, y2]

    def mouse_callback(self, event, x, y, flags, param):
        """OpenCV mouse callback: a left-button drag draws a box of the current class."""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.start_point = (x, y)
            # Reset so the preview does not flash the previous box's corner.
            self.current_end = (x, y)

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.drawing:
                self.current_end = (x, y)

        elif event == cv2.EVENT_LBUTTONUP:
            self.drawing = False
            if self.start_point:
                box = self._make_box(self.start_point, (x, y))
                if box is not None:
                    self.current_boxes.append(box + [self.current_class])
            # Consume the start point so a stray button-up cannot reuse it.
            self.start_point = None

    def draw_boxes(self, img):
        """Return a copy of *img* with the current boxes and the in-progress box drawn."""
        img_copy = img.copy()

        # Committed boxes (green for class 0, blue otherwise; colours are BGR)
        for x1, y1, x2, y2, class_id in self.current_boxes:
            color = (0, 255, 0) if class_id == 0 else (255, 0, 0)
            pt1, pt2 = (int(x1), int(y1)), (int(x2), int(y2))
            cv2.rectangle(img_copy, pt1, pt2, color, 2)
            label = self.classes.get(class_id, str(class_id))
            # Keep the label on-screen for boxes touching the top edge.
            cv2.putText(img_copy, label, (pt1[0], max(pt1[1] - 10, 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # Box currently being dragged
        if self.drawing and self.start_point:
            end = self.current_end or self.start_point
            cv2.rectangle(img_copy, self.start_point, end, (0, 255, 255), 2)

        return img_copy

    def _render_frame(self, img, img_name):
        """Return *img* with the boxes and the status overlay drawn on a copy."""
        display_img = self.draw_boxes(img)
        # ASCII only: OpenCV's Hershey fonts render other characters as '?'.
        info_text = (f"Image {self.current_idx + 1}/{len(self.images)} | "
                     f"Classe: {self.classes[self.current_class]} | "
                     f"Boites: {len(self.current_boxes)}")
        cv2.putText(display_img, info_text, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        cv2.putText(display_img, img_name, (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
        return display_img

    def _commit(self, img_name):
        """Store the boxes being edited for *img_name* and export everything."""
        self.annotations[img_name] = self.current_boxes
        self.save_annotations()

    @staticmethod
    def _print_controls():
        """Print the keyboard/mouse help to the console."""
        print("\n🐝 Queen Bee Annotation Tool")
        print("=" * 50)
        print("Contrôles :")
        print("  Clic gauche : Dessiner une boîte")
        print("  Q : Reine (classe 0)")
        print("  W : Ouvrière (classe 1)")
        print("  Z : Supprimer dernière boîte")
        print("  S : Sauvegarder et image suivante")
        print("  A : Image précédente")
        print("  ESC : Quitter")
        print("=" * 50)

    def run(self):
        """Open the annotation window and process input until all images are done or ESC."""
        if not self.images:
            print("❌ Aucune image trouvée dans", self.image_dir)
            return

        cv2.namedWindow(self.WINDOW_NAME)
        cv2.setMouseCallback(self.WINDOW_NAME, self.mouse_callback)
        self._print_controls()

        while self.current_idx < len(self.images):
            img_path = self.images[self.current_idx]
            img = cv2.imread(str(img_path))

            if img is None:
                print(f"⚠️  Impossible de charger {img_path}")
                self.current_idx += 1
                continue

            img_name = img_path.name
            h, w = img.shape[:2]
            self._current_shape = (h, w)
            self._image_sizes[img_name] = (w, h)
            # Edit a copy; it is written back to self.annotations on S / A / ESC.
            self.current_boxes = list(self.annotations.get(img_name, []))
            # Drop any drag left over from the previous image.
            self.drawing = False
            self.start_point = None

            while True:
                cv2.imshow(self.WINDOW_NAME, self._render_frame(img, img_name))
                key = cv2.waitKey(self.FRAME_DELAY_MS) & 0xFF

                if key == ord('q'):  # Queen
                    self.current_class = 0
                elif key == ord('w'):  # Worker
                    self.current_class = 1
                elif key == ord('z'):  # Undo last box
                    if self.current_boxes:
                        self.current_boxes.pop()
                elif key == ord('s'):  # Save and go to next image
                    self._commit(img_name)
                    self.current_idx += 1
                    break
                elif key == ord('a'):  # Save and go to previous image
                    self._commit(img_name)
                    self.current_idx = max(0, self.current_idx - 1)
                    break
                elif key == 27:  # ESC: save and quit
                    self._commit(img_name)
                    cv2.destroyAllWindows()
                    return

        cv2.destroyAllWindows()
        print("\n✅ Annotation terminée !")
        self.save_annotations()

def main():
    """Command-line entry point: parse arguments and launch the annotation UI."""
    cli = argparse.ArgumentParser(
        description="Outil d'annotation pour reines d'abeilles"
    )
    cli.add_argument(
        '--images',
        type=str,
        required=True,
        help='Dossier contenant les images à annoter',
    )
    cli.add_argument(
        '--output',
        type=str,
        default='annotations',
        help='Dossier de sortie pour les annotations',
    )
    options = cli.parse_args()

    BeeAnnotationTool(options.images, options.output).run()


if __name__ == "__main__":
    main()
