"""
Inference Demo
Run a trained queen bee detection model on single images, image directories,
or a live webcam feed.
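
Example usage (script filename assumed):
    python inference_demo.py --source path/to/image.jpg --show
    python inference_demo.py --source path/to/images/ --output results/
    python inference_demo.py --source webcam --conf 0.4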
"""

import os
import cv2
import time
import argparse
from pathlib import Path
from ultralytics import YOLO
import numpy as np


class QueenDetectorDemo:
    def __init__(self, model_path, conf_threshold=0.5):
        """
        Initialize demo
        
        Args:
            model_path: Path to trained model
            conf_threshold: Confidence threshold
        """
        print(f"Loading model: {model_path}")
        self.model = YOLO(model_path)
        self.conf_threshold = conf_threshold
        
        # Per-class box colors for visualization (OpenCV BGR order)
        self.colors = {
            0: (0, 255, 0),    # Queen - Green
            1: (255, 0, 0),    # Worker - Blue
        }
        
        self.class_names = {
            0: 'Queen',
            1: 'Worker'
        }
        
    def detect_image(self, image_path, save_path=None, show=True):
        """
        Detect on single image
        
        Args:
            image_path: Path to image
            save_path: Optional path to save result
            show: Whether to display result
        """
        # Read image
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"Error: Cannot read image {image_path}")
            return None
        
        # Run detection
        start_time = time.time()
        results = self.model.predict(
            img,
            conf=self.conf_threshold,
            verbose=False
        )
        inference_time = (time.time() - start_time) * 1000
        
        # Draw results (predict() returns one Results object per input image)
        result = results[0]
        annotated = img.copy()
        
        detections = []
        for box in result.boxes:
            # Box coordinates in absolute pixel units (xyxy format)
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            cls = int(box.cls[0])
            
            detections.append({
                'class': cls,
                'confidence': conf,
                'bbox': (x1, y1, x2, y2)
            })
            
            # Draw box
            color = self.colors.get(cls, (255, 255, 255))
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 3)
            
            # Draw label
            label = f"{self.class_names[cls]}: {conf:.2f}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.8
            thickness = 2
            
            (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
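            # Draw a filled background behind the label so it stays readable on busy frames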
            cv2.rectangle(annotated, (x1, y1 - label_h - 10), (x1 + label_w, y1), color, -1)
            cv2.putText(annotated, label, (x1, y1 - 5), font, font_scale, (0, 0, 0), thickness)
            
            # Add an extra "QUEEN" label below queen detections for emphasis
            if cls == 0:
                cv2.putText(annotated, "QUEEN", (x1, y2 + 30), font, 1.0, color, 3)
        
        # Overlay summary stats (guard against zero inference time before computing FPS)
        fps_text = f"FPS: {1000 / inference_time:.1f}" if inference_time > 0 else "FPS: N/A"
        stats_text = [
            f"Detections: {len(detections)}",
            f"Queens: {sum(1 for d in detections if d['class'] == 0)}",
            f"Inference: {inference_time:.1f}ms",
            fps_text
        ]
        
        y_offset = 30
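        # Each line is drawn twice (thick white, then thin black) to give an outlined, readable overlay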
        for text in stats_text:
            cv2.putText(annotated, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 
                       0.7, (255, 255, 255), 2)
            cv2.putText(annotated, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 
                       0.7, (0, 0, 0), 1)
            y_offset += 30
        
        # Save if requested
        if save_path:
            cv2.imwrite(str(save_path), annotated)
            print(f"Saved result to: {save_path}")
        
        # Show if requested
        if show:
            cv2.imshow('Queen Bee Detection', annotated)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        
        # Print results
        print(f"\nImage: {Path(image_path).name}")
        print(f"Detections: {len(detections)}")
        print(f"Queens found: {sum(1 for d in detections if d['class'] == 0)}")
        print(f"Inference time: {inference_time:.1f}ms")
        
        return detections
    
    def detect_directory(self, dir_path, output_dir=None):
        """
        Detect on all images in directory
        
        Args:
            dir_path: Directory with images
            output_dir: Optional output directory for results
        """
        dir_path = Path(dir_path)
        
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
        
        # Find images; compare extensions case-insensitively so case-insensitive
        # filesystems (e.g. Windows, macOS) don't yield duplicate matches
        image_extensions = {'.jpg', '.jpeg', '.png'}
        images = sorted(p for p in dir_path.iterdir()
                        if p.is_file() and p.suffix.lower() in image_extensions)
        
        print(f"\nProcessing {len(images)} images from {dir_path}")
        
        total_queens = 0
        times = []
        
        for img_path in images:
            save_path = None
            if output_dir:
                save_path = output_dir / f"result_{img_path.name}"
            
            start = time.time()
            detections = self.detect_image(img_path, save_path=save_path, show=False)
            times.append((time.time() - start) * 1000)
            
            if detections:
                queens = sum(1 for d in detections if d['class'] == 0)
                total_queens += queens
        
        # Summary
        print("\n" + "="*50)
        print("SUMMARY")
        print("="*50)
        print(f"Images processed: {len(images)}")
        print(f"Total queens found: {total_queens}")
        if times:
            print(f"Average time per image (incl. I/O): {np.mean(times):.1f}ms")
            print(f"Average FPS: {1000 / np.mean(times):.1f}")
        print("="*50)
    
    def detect_webcam(self, camera_id=0):
        """
        Run detection on webcam feed
        
        Args:
            camera_id: Camera device ID
        """
        cap = cv2.VideoCapture(camera_id)
        
        if not cap.isOpened():
            print(f"Error: Cannot open camera {camera_id}")
            return
        
        print("\nStarting webcam detection...")
        print("Press 'q' to quit, 's' to save frame")
        
        frame_count = 0
        fps_list = []
        
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            # Run detection
            start_time = time.time()
            results = self.model.predict(
                frame,
                conf=self.conf_threshold,
                verbose=False
            )
            inference_time = (time.time() - start_time) * 1000
            
            # Draw results
            result = results[0]
            
            queen_count = 0
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])
                cls = int(box.cls[0])
                
                color = self.colors.get(cls, (255, 255, 255))
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                
                label = f"{self.class_names[cls]}: {conf:.2f}"
                cv2.putText(frame, label, (x1, y1 - 10), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
                
                if cls == 0:
                    queen_count += 1
            
            # Overlay rolling stats (FPS averaged over the most recent 30 frames)
            fps = 1000 / inference_time if inference_time > 0 else 0
            fps_list.append(fps)
            fps_list = fps_list[-30:]  # keep the averaging window bounded for long sessions
            avg_fps = np.mean(fps_list)
            
            cv2.putText(frame, f"FPS: {avg_fps:.1f}", (10, 30), 
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(frame, f"Queens: {queen_count}", (10, 70), 
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            # Show frame
            cv2.imshow('Queen Bee Detection - Webcam', frame)
            
            # Handle keys
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('s'):
                filename = f"webcam_frame_{frame_count}.jpg"
                cv2.imwrite(filename, frame)
                print(f"Saved: {filename}")
            
            frame_count += 1
        
        cap.release()
        cv2.destroyAllWindows()


def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description='Queen Bee Detection Demo')
    parser.add_argument('--model', type=str, default='runs/detect/train/weights/best.pt',
                       help='Path to trained model')
    parser.add_argument('--conf', type=float, default=0.5,
                       help='Confidence threshold (default: 0.5)')
    parser.add_argument('--source', type=str, required=True,
                       help='Source: image path, directory, or "webcam"')
    parser.add_argument('--output', type=str, default=None,
                       help='Output path: annotated image file for a single image, '
                            'or directory for batch results')
    parser.add_argument('--show', action='store_true',
                       help='Show results (for single image)')
    
    args = parser.parse_args()
    
    # Check model exists
    if not os.path.exists(args.model):
        print(f"Error: Model not found at {args.model}")
        print("Please train the model first using train_yolo.py")
        return
    
    # Initialize detector
    detector = QueenDetectorDemo(args.model, args.conf)
    
    # Run detection based on source type
    if args.source.lower() == 'webcam':
        detector.detect_webcam()
    elif os.path.isfile(args.source):
        detector.detect_image(args.source, save_path=args.output, show=args.show)
    elif os.path.isdir(args.source):
        detector.detect_directory(args.source, output_dir=args.output)
    else:
        print(f"Error: Invalid source: {args.source}")


if __name__ == '__main__':
    main()
