"""
Queen Bee Annotation Tool
Graphical interface for annotating queen bees in images
"""

import os
import json
import tkinter as tk
from tkinter import filedialog, messagebox, Canvas
from PIL import Image, ImageTk
import glob


class AnnotationTool:
    """Tkinter application for drawing YOLO-format bounding boxes on images.

    The user picks a directory of images; labels are stored next to them in a
    ``labels`` sub-directory, one ``<image stem>.txt`` file per image, each line
    being ``class x_center y_center width height`` with coordinates normalised
    to the original image size.

    Class ids: 0 = queen, 1 = worker (selected with the toolbar radio buttons).

    Key bindings: Delete removes the last box, Left/Right navigate, ``s`` saves.
    """

    # Lower-case extensions accepted when scanning a directory for images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    # Outline colours: class 0 (queen) versus every other class (worker).
    QUEEN_COLOR = '#f39c12'
    WORKER_COLOR = '#3498db'

    # Canvas tags. The rubber-band rectangle must NOT be tagged 'current':
    # Tk reserves that tag for the item under the mouse pointer, so deleting
    # 'current' can remove the image or an existing box instead.
    ANNOTATION_TAG = 'annotation'
    PREVIEW_TAG = 'preview'

    # Boxes smaller than this (in canvas pixels, either dimension) are treated
    # as accidental clicks and discarded.
    MIN_BOX_PIXELS = 2

    def __init__(self, root):
        """Initialise the application state and build the widgets.

        Args:
            root: The Tk root window that hosts the tool.
        """
        self.root = root
        self.root.title("🐝 Queen Bee Annotator")
        self.root.geometry("1400x900")

        # State
        self.image_dir = None          # directory currently being annotated
        self.output_dir = None         # '<image_dir>/labels'
        self.images = []               # sorted list of image paths
        self.current_index = 0         # index into self.images
        self.current_image = None      # PIL image at original resolution
        self.current_photo = None      # Tk photo; kept referenced so Tk can draw it
        self.annotations = []          # dicts with keys class/x/y/w/h (YOLO, normalised)
        self.current_annotation = None
        self.drawing = False           # True between mouse-down and mouse-up
        self.start_x = None            # canvas coordinates of the drag origin
        self.start_y = None
        self.scale = 1.0               # displayed size / original size

        # Current class (0=queen, 1=worker)
        self.current_class = 0

        self.setup_ui()

    def setup_ui(self):
        """Build the toolbar, the drawing canvas, the bindings and the help bar."""
        toolbar_bg = '#2c3e50'

        # Top menu
        menu_frame = tk.Frame(self.root, bg=toolbar_bg, height=60)
        menu_frame.pack(side=tk.TOP, fill=tk.X)

        # Buttons: (text, callback, background colour, horizontal padding),
        # listed in left-to-right display order.
        button_specs = (
            ("📁 Ouvrir Dossier", self.load_directory, '#3498db', 10),
            ("💾 Sauvegarder", self.save_annotation, '#2ecc71', 5),
            ("⏭️ Suivant", self.next_image, '#e74c3c', 5),
            ("⏮️ Précédent", self.previous_image, '#e74c3c', 5),
        )
        for text, callback, colour, pad_x in button_specs:
            tk.Button(
                menu_frame,
                text=text,
                command=callback,
                bg=colour,
                fg='white',
                font=('Arial', 12, 'bold'),
                padx=20,
                pady=10,
            ).pack(side=tk.LEFT, padx=pad_x, pady=10)

        # Class selector
        class_frame = tk.Frame(menu_frame, bg=toolbar_bg)
        class_frame.pack(side=tk.LEFT, padx=20)

        tk.Label(
            class_frame,
            text="Classe:",
            bg=toolbar_bg,
            fg='white',
            font=('Arial', 10),
        ).pack(side=tk.LEFT, padx=5)

        self.class_var = tk.IntVar(value=0)
        tk.Radiobutton(
            class_frame,
            text="👑 Reine",
            variable=self.class_var,
            value=0,
            bg=toolbar_bg,
            fg='#f39c12',
            selectcolor='#34495e',
            font=('Arial', 10, 'bold'),
        ).pack(side=tk.LEFT)

        tk.Radiobutton(
            class_frame,
            text="🐝 Ouvrière",
            variable=self.class_var,
            value=1,
            bg=toolbar_bg,
            fg='white',
            selectcolor='#34495e',
            font=('Arial', 10),
        ).pack(side=tk.LEFT)

        # Info label
        self.info_label = tk.Label(
            menu_frame,
            text="Aucune image chargée",
            bg=toolbar_bg,
            fg='white',
            font=('Arial', 10),
        )
        self.info_label.pack(side=tk.RIGHT, padx=20)

        # Main container
        main_frame = tk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True)

        # Canvas for image
        self.canvas = Canvas(main_frame, bg='#34495e', cursor='cross')
        self.canvas.pack(fill=tk.BOTH, expand=True)

        # Bind events
        self.canvas.bind('<Button-1>', self.on_mouse_down)
        self.canvas.bind('<B1-Motion>', self.on_mouse_drag)
        self.canvas.bind('<ButtonRelease-1>', self.on_mouse_up)
        self.root.bind('<Delete>', self.delete_last_annotation)
        self.root.bind('<Left>', lambda e: self.previous_image())
        self.root.bind('<Right>', lambda e: self.next_image())
        self.root.bind('<s>', lambda e: self.save_annotation())

        # Instructions
        instructions = tk.Label(
            self.root,
            text="Instructions: Cliquez-glissez pour annoter | Suppr: effacer dernier | ←/→: naviguer | S: sauvegarder",
            bg='#ecf0f1',
            font=('Arial', 9),
            pady=5,
        )
        instructions.pack(side=tk.BOTTOM, fill=tk.X)

    @classmethod
    def _find_images(cls, directory):
        """Return the sorted paths of the image files directly inside *directory*.

        Extensions are compared case-insensitively, so each file is listed
        exactly once on every filesystem. Hidden entries (leading dot) and
        sub-directories are ignored. The directory name is used literally, so
        characters such as ``[`` in the path are not interpreted as patterns.

        Raises:
            OSError: If the directory cannot be listed.
        """
        paths = []
        for name in os.listdir(directory):
            if name.startswith('.'):
                continue
            if os.path.splitext(name)[1].lower() not in cls.IMAGE_EXTENSIONS:
                continue
            path = os.path.join(directory, name)
            if os.path.isfile(path):
                paths.append(path)
        paths.sort()
        return paths

    def _label_path(self):
        """Return the label-file path that belongs to the current image."""
        stem = os.path.splitext(os.path.basename(self.images[self.current_index]))[0]
        return os.path.join(self.output_dir, stem + '.txt')

    def _class_color(self, class_id):
        """Return the outline colour used for boxes of *class_id*."""
        return self.QUEEN_COLOR if class_id == 0 else self.WORKER_COLOR

    def load_directory(self):
        """Ask for an image directory and open its first image.

        The current session is left untouched when the dialog is cancelled,
        when the directory cannot be read, or when it contains no images.
        """
        directory = filedialog.askdirectory(title="Sélectionner le dossier d'images")

        if not directory:
            return

        try:
            images = self._find_images(directory)
        except OSError as exc:
            messagebox.showerror("Erreur", f"Impossible de lire le dossier:\n{exc}")
            return

        if not images:
            messagebox.showwarning("Aucune image", "Aucune image trouvée dans ce dossier")
            return

        output_dir = os.path.join(directory, 'labels')
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as exc:
            messagebox.showerror("Erreur", f"Impossible de créer le dossier des labels:\n{exc}")
            return

        # Commit the new session only once everything above succeeded.
        self.image_dir = directory
        self.output_dir = output_dir
        self.images = images
        self.current_index = 0
        self.load_image()

    def load_image(self):
        """Display the image at ``current_index`` together with its saved boxes.

        The image is scaled down (never up) to fit the canvas. If the file
        cannot be opened an error dialog is shown and the tool is left with no
        active image, so nothing can be drawn or saved for that entry.
        """
        if not self.images:
            return

        image_path = self.images[self.current_index]
        filename = os.path.basename(image_path)

        # Drop everything that belongs to the previous image up front, so a
        # failed load can never save old boxes under the new file name.
        self.canvas.delete('all')
        self.current_image = None
        self.current_photo = None
        self.annotations = []
        self.drawing = False

        # Update info
        self.info_label.config(
            text=f"{self.current_index + 1}/{len(self.images)} - {filename}"
        )

        try:
            image = Image.open(image_path)
            # Force decoding now so truncated/corrupt files fail here.
            image.load()
        except OSError as exc:
            messagebox.showerror("Erreur", f"Impossible d'ouvrir l'image {filename}:\n{exc}")
            return

        self.current_image = image
        img_width, img_height = image.size

        # Make sure pending geometry changes are applied before measuring.
        self.canvas.update_idletasks()
        canvas_width = self.canvas.winfo_width()
        canvas_height = self.canvas.winfo_height()

        if canvas_width > 1 and canvas_height > 1:
            self.scale = min(canvas_width / img_width, canvas_height / img_height, 1.0)
        else:
            # Canvas not laid out yet: show the image at its native size.
            self.scale = 1.0

        # max(1, ...) keeps extremely thin images from collapsing to 0 pixels.
        new_width = max(1, int(img_width * self.scale))
        new_height = max(1, int(img_height * self.scale))

        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        self.current_photo = ImageTk.PhotoImage(resized)
        self.canvas.create_image(0, 0, anchor=tk.NW, image=self.current_photo)

        # Load existing annotations
        self.load_existing_annotations()

    def load_existing_annotations(self):
        """Replace ``self.annotations`` with the boxes stored for the current image.

        Lines that do not consist of exactly five numeric fields are skipped.
        """
        if not self.images:
            return

        label_path = self._label_path()
        self.annotations = []

        if os.path.exists(label_path):
            with open(label_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.split()
                    if len(parts) != 5:
                        continue
                    try:
                        cls, x, y, w, h = (float(part) for part in parts)
                        class_id = int(cls)
                    except (ValueError, OverflowError):
                        # Malformed line: ignore it rather than abort loading.
                        continue
                    self.annotations.append({
                        'class': class_id,
                        'x': x,
                        'y': y,
                        'w': w,
                        'h': h
                    })

        self.draw_annotations()

    def draw_annotations(self):
        """Redraw every stored box on the canvas, replacing the previous ones."""
        # Clear first so repeated calls do not pile up duplicate rectangles.
        self.canvas.delete(self.ANNOTATION_TAG)

        if not self.current_image:
            return

        img_width, img_height = self.current_image.size

        for ann in self.annotations:
            # Convert from YOLO format to canvas pixel coordinates
            center_x = ann['x'] * img_width * self.scale
            center_y = ann['y'] * img_height * self.scale
            half_w = ann['w'] * img_width * self.scale / 2
            half_h = ann['h'] * img_height * self.scale / 2

            self.canvas.create_rectangle(
                center_x - half_w, center_y - half_h,
                center_x + half_w, center_y + half_h,
                outline=self._class_color(ann['class']),
                width=3,
                tags=self.ANNOTATION_TAG
            )

    def on_mouse_down(self, event):
        """Start a new box at the pointer position (only when an image is shown)."""
        if not self.current_image:
            return

        self.drawing = True
        self.start_x = event.x
        self.start_y = event.y

    def on_mouse_drag(self, event):
        """Update the rubber-band rectangle while the button is held down."""
        if not self.drawing:
            return

        # Remove previous rectangle
        self.canvas.delete(self.PREVIEW_TAG)

        # Draw new rectangle
        self.canvas.create_rectangle(
            self.start_x, self.start_y,
            event.x, event.y,
            outline=self._class_color(self.class_var.get()),
            width=3,
            tags=self.PREVIEW_TAG
        )

    def on_mouse_up(self, event):
        """Finish the drag and store the box in YOLO format.

        The box is clipped to the displayed image so all stored values stay in
        the range 0..1; boxes that end up smaller than ``MIN_BOX_PIXELS`` in
        either dimension (e.g. a plain click) are discarded.
        """
        if not self.drawing:
            return

        self.drawing = False
        self.canvas.delete(self.PREVIEW_TAG)

        if not self.current_image:
            return

        img_width, img_height = self.current_image.size

        # Extent of the displayed image in canvas pixels.
        shown_width = img_width * self.scale
        shown_height = img_height * self.scale

        left = min(max(min(self.start_x, event.x), 0), shown_width)
        right = min(max(max(self.start_x, event.x), 0), shown_width)
        top = min(max(min(self.start_y, event.y), 0), shown_height)
        bottom = min(max(max(self.start_y, event.y), 0), shown_height)

        if right - left < self.MIN_BOX_PIXELS or bottom - top < self.MIN_BOX_PIXELS:
            return

        # Back to original-image pixels
        x1 = left / self.scale
        y1 = top / self.scale
        x2 = right / self.scale
        y2 = bottom / self.scale

        # Convert to YOLO format (normalized center x, y, width, height)
        self.annotations.append({
            'class': self.class_var.get(),
            'x': ((x1 + x2) / 2) / img_width,
            'y': ((y1 + y2) / 2) / img_height,
            'w': (x2 - x1) / img_width,
            'h': (y2 - y1) / img_height
        })

        # Redraw
        self.draw_annotations()

    def delete_last_annotation(self, event=None):
        """Remove the most recently added box, if any.

        Args:
            event: Ignored; present so the method can be used as a Tk binding.
        """
        if self.annotations:
            self.annotations.pop()
            self.draw_annotations()

    def save_annotation(self, notify=True):
        """Write the boxes of the current image to its label file.

        Args:
            notify: Show a confirmation dialog after a successful save. The
                navigation methods pass ``False`` so that moving between
                images is not interrupted by a modal dialog each time.

        Returns:
            ``False`` if the label file could not be written (an error dialog
            is shown), ``True`` otherwise, including when there was nothing to
            save because no image is active.
        """
        if not self.images or not self.current_image:
            return True

        try:
            with open(self._label_path(), 'w', encoding='utf-8') as f:
                f.writelines(
                    f"{ann['class']} {ann['x']:.6f} {ann['y']:.6f} {ann['w']:.6f} {ann['h']:.6f}\n"
                    for ann in self.annotations
                )
        except OSError as exc:
            messagebox.showerror("Erreur", f"Impossible de sauvegarder les annotations:\n{exc}")
            return False

        if notify:
            messagebox.showinfo("Sauvegardé", f"Annotations sauvegardées: {len(self.annotations)}")
        return True

    def _step(self, offset):
        """Save the current image, then move *offset* images (wrapping around)."""
        if not self.images:
            return

        # Stay on the image if its boxes could not be written, so they are
        # not silently lost.
        if not self.save_annotation(notify=False):
            return

        self.current_index = (self.current_index + offset) % len(self.images)
        self.load_image()

    def next_image(self):
        """Save the current boxes and show the next image (wraps to the first)."""
        self._step(1)

    def previous_image(self):
        """Save the current boxes and show the previous image (wraps to the last)."""
        self._step(-1)


def main():
    """Create the Tk root window, attach the annotation tool and run the event loop."""
    window = tk.Tk()
    # Keep a reference to the tool for as long as the event loop runs.
    tool = AnnotationTool(window)
    window.mainloop()


# Start the GUI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
