"""
Model Evaluation Script
Evaluate trained model performance and generate metrics
"""

import json
import os
import time
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from ultralytics import YOLO


class ModelEvaluator:
    """Evaluate an Ultralytics YOLO model: test-set metrics, speed, visuals, JSON report."""

    def __init__(self, model_path, data_yaml):
        """
        Initialize evaluator

        Args:
            model_path: Path to trained model weights
            data_yaml: Path to dataset YAML
        """
        self.model = YOLO(model_path)
        self.data_yaml = data_yaml
        # Populated by evaluate_on_test_set(); generate_report() requires it.
        self.results = None

    @staticmethod
    def _summary_metrics(results):
        """Return overall precision / recall / mAP from a validation metrics object.

        Uses the class-averaged ``box.mp`` / ``box.mr`` rather than indexing
        ``box.p[0]`` / ``box.r[0]``: the indexed form reports only the first class
        that has statistics and can raise IndexError when none were produced.
        For a single-class model the two forms give the same value.
        """
        box = results.box
        return {
            'precision': float(box.mp),
            'recall': float(box.mr),
            'map50': float(box.map50),
            'map50_95': float(box.map),
        }

    def evaluate_on_test_set(self, conf_threshold=0.5):
        """Run validation on the dataset's 'test' split and print summary metrics.

        Args:
            conf_threshold: Confidence threshold passed to ``model.val``.
                NOTE(review): Ultralytics' own validation default is conf=0.001;
                a 0.5 threshold likely truncates the PR curve and lowers mAP —
                confirm this is intended.

        Returns:
            The metrics object returned by ``model.val`` (also kept in ``self.results``).
        """
        print("Evaluating model on test set...")

        results = self.model.val(
            data=self.data_yaml,
            split='test',
            conf=conf_threshold,
            save_json=True,
            plots=True,
        )
        self.results = results

        metrics = self._summary_metrics(results)
        print("\n" + "=" * 50)
        print("EVALUATION RESULTS")
        print("=" * 50)
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"mAP@50: {metrics['map50']:.4f}")
        print(f"mAP@50-95: {metrics['map50_95']:.4f}")
        print("=" * 50)

        return results

    def inference_speed_test(self, image_path, num_runs=100, warmup_runs=10):
        """Measure single-image inference latency.

        Args:
            image_path: Image file fed repeatedly to ``model.predict``.
            num_runs: Number of timed iterations (must be >= 1).
            warmup_runs: Untimed iterations run first so one-off setup cost
                is excluded from the measurement.

        Returns:
            ``(avg_time_ms, fps)``.

        Raises:
            ValueError: If ``num_runs`` < 1 (the mean of zero samples is undefined).
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        print(f"\nRunning speed test ({num_runs} iterations)...")

        img = cv2.imread(str(image_path))
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        for _ in range(warmup_runs):
            self.model.predict(img, verbose=False)

        times = []
        for _ in range(num_runs):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            self.model.predict(img, verbose=False)
            times.append(time.perf_counter() - start)

        avg_time = np.mean(times) * 1000  # Convert to ms
        std_time = np.std(times) * 1000
        fps = 1000 / avg_time if avg_time > 0 else float('inf')

        print("\n" + "=" * 50)
        print("INFERENCE SPEED")
        print("=" * 50)
        print(f"Average time: {avg_time:.2f} ± {std_time:.2f} ms")
        print(f"FPS: {fps:.1f}")
        print("=" * 50)

        return avg_time, fps

    def visualize_predictions(self, image_dir, output_dir='evaluation/predictions', num_samples=10,
                              conf=0.5, patterns=('*.jpg',)):
        """Save annotated predictions for up to ``num_samples`` images from ``image_dir``.

        Args:
            image_dir: Directory searched (non-recursively) for images.
            output_dir: Directory the annotated images are written to (created if missing).
            num_samples: Maximum number of images to process.
            conf: Confidence threshold passed to ``model.predict``.
            patterns: Glob pattern or patterns selecting image files; defaults to ``*.jpg``.

        Returns:
            List of output paths that were written successfully.
        """
        print("\nGenerating prediction visualizations...")

        os.makedirs(output_dir, exist_ok=True)

        if isinstance(patterns, str):
            patterns = (patterns,)
        image_root = Path(image_dir)
        # Sort (and de-duplicate overlapping patterns) so the same subset is chosen
        # on every run; raw glob order depends on the filesystem.
        candidates = sorted({path for pattern in patterns for path in image_root.glob(pattern)})
        images = candidates[:num_samples]

        saved = []
        for img_path in images:
            results = self.model.predict(source=str(img_path), conf=conf, save=False)
            annotated = results[0].plot()

            output_path = os.path.join(output_dir, img_path.name)
            # cv2.imwrite reports failure through its return value rather than raising.
            if cv2.imwrite(output_path, annotated):
                saved.append(output_path)
            else:
                print(f"Warning: failed to write {output_path}")

        print(f"Saved {len(saved)} visualizations to {output_dir}")
        return saved

    def analyze_errors(self, predictions, ground_truth):
        """Placeholder for false-positive / false-negative analysis.

        NOTE(review): not implemented. ``predictions`` and ``ground_truth`` are
        ignored and every printed count is always 0. A real version would match
        predictions to ground-truth boxes by IoU.
        """
        print("\nAnalyzing prediction errors...")

        false_positives = 0
        false_negatives = 0
        true_positives = 0

        # TODO: match predictions to ground truth by IoU to populate the counts above.

        print(f"False Positives: {false_positives}")
        print(f"False Negatives: {false_negatives}")
        print(f"True Positives: {true_positives}")

    def generate_report(self, output_path='evaluation/report.json'):
        """Write summary metrics from the last evaluation to a JSON file.

        Args:
            output_path: Destination JSON path; missing parent directories are created.

        Returns:
            The report dict, or None if evaluate_on_test_set() has not been run.
        """
        if self.results is None:
            print("Run evaluate_on_test_set() first")
            return None

        report = self._summary_metrics(self.results)
        # ``results.results_dict`` maps metric names to values, so its length is a
        # metric count, not an image count. No image count is read from the metrics
        # object here, so record null rather than a misleading number.
        report['total_images_tested'] = None

        output_file = Path(output_path)
        # The parent of a bare filename is '.', so this is a no-op there, whereas
        # os.makedirs(os.path.dirname(...)) would be called with '' and raise.
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        print(f"\nReport saved to {output_path}")

        return report


def main():
    """Run the evaluation pipeline: test-set metrics, optional speed test, JSON report.

    Returns early (printing an error) when the trained weights or the dataset
    YAML are missing.
    """

    # Paths
    MODEL_PATH = 'runs/detect/train/weights/best.pt'
    DATA_YAML = 'data/queen_bee.yaml'
    TEST_IMAGE = 'test_images/sample.jpg'
    TEST_IMAGE_DIR = 'dataset/images/test'
    # Prediction visualizations are opt-in; set to True to write annotated samples.
    RUN_VISUALIZATIONS = False

    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model not found at {MODEL_PATH}")
        print("Please train the model first using train_yolo.py")
        return

    # Fail fast with a clear message before loading the model and starting validation.
    if not os.path.exists(DATA_YAML):
        print(f"Error: Dataset config not found at {DATA_YAML}")
        return

    print("🐝 Queen Bee Detector - Model Evaluation 🐝\n")

    # Initialize evaluator
    evaluator = ModelEvaluator(MODEL_PATH, DATA_YAML)

    # Run evaluation
    evaluator.evaluate_on_test_set()

    # Speed test (if test image exists)
    if os.path.exists(TEST_IMAGE):
        evaluator.inference_speed_test(TEST_IMAGE)

    # Generate visualizations (only when enabled and the test image folder exists)
    if RUN_VISUALIZATIONS and os.path.isdir(TEST_IMAGE_DIR):
        evaluator.visualize_predictions(TEST_IMAGE_DIR)

    # Generate report
    evaluator.generate_report()

    print("\n✅ Evaluation complete!")


if __name__ == "__main__":
    # Execute the pipeline only when run as a script, not when imported.
    main()
