#!/usr/bin/env python3
"""
Script de vérification de la qualité du dataset
"""

import json
import cv2
import numpy as np
from pathlib import Path
import argparse
from collections import defaultdict
import matplotlib.pyplot as plt

class DatasetValidator:
    """Check a COCO-style detection dataset and print a quality report.

    The validator reads an annotation JSON file containing ``images``,
    ``annotations`` and ``categories`` lists, compares it with the image
    files found in ``images_dir`` and collects problems as it goes.

    Attributes:
        annotations_file: Path of the annotation JSON file.
        images_dir: Directory the image ``file_name`` entries are relative to.
        errors: Messages for serious problems (5 points each off the score).
        warnings: Messages for minor problems (1 point each off the score).
        stats: Counters and derived statistics keyed by name. It is a
            ``defaultdict(int)``, so a counter that was never set reads as 0.
    """

    def __init__(self, annotations_file, images_dir):
        """Store the input locations and start with empty results.

        Args:
            annotations_file: Path (str or Path) of the annotation JSON file.
            images_dir: Path (str or Path) of the directory holding images.
        """
        self.annotations_file = Path(annotations_file)
        self.images_dir = Path(images_dir)
        self.errors = []
        self.warnings = []
        self.stats = defaultdict(int)

    def load_annotations(self):
        """Load and return the parsed annotation JSON.

        The file is decoded as UTF-8 instead of the platform default so that
        non-ASCII file or category names load identically on every system.

        Returns:
            The decoded JSON document.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the content is not valid JSON.
        """
        with open(self.annotations_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def validate(self):
        """Run every validation step in order and print the final report."""
        print("🔍 Validation du dataset...")
        print("=" * 60)

        data = self.load_annotations()

        # Basic counts; later steps and the report read them from stats.
        self.stats['num_images'] = len(data['images'])
        self.stats['num_annotations'] = len(data['annotations'])
        self.stats['num_classes'] = len(data['categories'])

        print("📊 Statistiques de base:")
        print(f"   Images: {self.stats['num_images']}")
        print(f"   Annotations: {self.stats['num_annotations']}")
        print(f"   Classes: {self.stats['num_classes']}")

        self.validate_images(data)
        self.validate_annotations(data)
        self.analyze_distribution(data)
        self.print_report()

    def validate_images(self, data):
        """Check that each listed image exists, decodes and matches its size.

        Missing or undecodable files are recorded as errors. A mismatch
        between declared and decoded dimensions, or a resolution below
        640x480, is recorded as a warning.

        Args:
            data: Parsed annotation document; ``data['images']`` entries must
                provide ``file_name``, ``width`` and ``height``.
        """
        print("\n📸 Validation des images...")

        for img_info in data['images']:
            file_name = img_info['file_name']
            img_path = self.images_dir / file_name

            if not img_path.exists():
                self.errors.append(f"Image manquante: {file_name}")
                continue

            # cv2.imread signals a decode failure by returning None.
            # NOTE(review): imread may apply EXIF orientation while decoding,
            # which could swap width/height relative to the stored values --
            # confirm against the OpenCV version in use.
            img = cv2.imread(str(img_path))
            if img is None:
                self.errors.append(f"Image corrompue: {file_name}")
                continue

            h, w = img.shape[:2]

            if h != img_info['height'] or w != img_info['width']:
                self.warnings.append(
                    f"Dimensions incorrectes: {file_name} "
                    f"(déclaré: {img_info['width']}x{img_info['height']}, "
                    f"réel: {w}x{h})"
                )

            if w < 640 or h < 480:
                self.warnings.append(
                    f"Résolution faible: {file_name} ({w}x{h})"
                )

            self.stats['total_pixels'] += w * h

    def validate_annotations(self, data):
        """Check every bounding box against its image's declared size.

        Boxes use the ``[x, y, width, height]`` layout. Negative origins and
        boxes extending past the declared image size are errors; boxes with
        a side under 10 px or covering more than half of the image are
        warnings. Annotations pointing at an ``image_id`` that is not in
        ``data['images']`` are reported as one error per missing image id.

        Results are stored in ``stats`` under ``class_counts``,
        ``bbox_sizes``, ``images_without_annotations`` and
        ``avg_annotations_per_image``.

        Args:
            data: Parsed annotation document with ``images`` and
                ``annotations`` lists.
        """
        print("🏷️  Validation des annotations...")

        # Group annotations by the image they claim to belong to.
        annotations_by_image = defaultdict(list)
        for anno in data['annotations']:
            annotations_by_image[anno['image_id']].append(anno)

        class_counts = defaultdict(int)
        bbox_sizes = []
        known_image_ids = set()

        for img_info in data['images']:
            img_id = img_info['id']
            known_image_ids.add(img_id)
            # .get() avoids inserting empty lists into the defaultdict, which
            # keeps the orphan check below accurate.
            annos = annotations_by_image.get(img_id, [])

            if not annos:
                self.stats['images_without_annotations'] += 1
                continue

            img_width = img_info['width']
            img_height = img_info['height']
            img_area = img_width * img_height

            for anno in annos:
                class_counts[anno['category_id']] += 1

                x, y, w, h = anno['bbox']

                if x < 0 or y < 0:
                    self.errors.append(
                        f"Bbox hors limites (négatif): image {img_id}"
                    )

                if x + w > img_width or y + h > img_height:
                    self.errors.append(
                        f"Bbox hors limites (dépassement): image {img_id}"
                    )

                if w < 10 or h < 10:
                    self.warnings.append(
                        f"Bbox très petite ({w}x{h}): image {img_id}"
                    )

                # A box covering most of the frame is probably a labelling
                # mistake. Skip the ratio when the declared area is zero:
                # dividing would crash, and the overflow check above already
                # reports any box on such an image.
                if img_area > 0 and (w * h) / img_area > 0.5:
                    self.warnings.append(
                        f"Bbox très grande (>50% image): image {img_id}"
                    )

                bbox_sizes.append((w, h))

        # Annotations whose image is not declared would otherwise be ignored
        # silently even though they are counted in num_annotations.
        for image_id, orphans in annotations_by_image.items():
            if image_id not in known_image_ids:
                self.errors.append(
                    f"Annotations orphelines ({len(orphans)}): "
                    f"image {image_id} absente de la liste des images"
                )

        self.stats['class_counts'] = dict(class_counts)
        self.stats['bbox_sizes'] = bbox_sizes

        if self.stats['num_images'] > 0:
            self.stats['avg_annotations_per_image'] = \
                self.stats['num_annotations'] / self.stats['num_images']

    def analyze_distribution(self, data):
        """Print the per-class annotation share and bounding-box size stats.

        Args:
            data: Parsed annotation document; ``data['categories']`` entries
                provide the ``id`` to ``name`` mapping used for display.
        """
        print("\n📈 Analyse de la distribution...")

        # stats is a defaultdict(int): fall back to empty containers when
        # validate_annotations has not populated these keys yet.
        class_counts = self.stats.get('class_counts') or {}
        bbox_sizes = self.stats.get('bbox_sizes') or []

        # Build the id -> name lookup once; setdefault keeps the first name
        # when an id is declared twice.
        class_names = {}
        for category in data['categories']:
            class_names.setdefault(category['id'], category['name'])

        # Percentages are relative to the declared annotation count; fall
        # back to the counted total so a zero count cannot divide by zero.
        total = self.stats['num_annotations'] or sum(class_counts.values())

        print("\n   Distribution par classe:")
        for class_id, count in class_counts.items():
            class_name = class_names.get(class_id, f"Unknown ({class_id})")
            percentage = (count / total) * 100
            print(f"      {class_name}: {count} ({percentage:.1f}%)")

        if bbox_sizes:
            widths = [w for w, _ in bbox_sizes]
            heights = [h for _, h in bbox_sizes]

            print("\n   Taille des bbox:")
            print(f"      Largeur moyenne: {np.mean(widths):.1f}px")
            print(f"      Hauteur moyenne: {np.mean(heights):.1f}px")
            print(f"      Largeur médiane: {np.median(widths):.1f}px")
            print(f"      Hauteur médiane: {np.median(heights):.1f}px")

    def print_report(self):
        """Print the summary, the first issues, advice and a 0-100 score.

        The score starts at 100, loses 5 points per error and 1 per warning,
        and is clamped at 0. Only the first 10 errors and the first 10
        warnings are listed. An empty dataset (0 images) is reported without
        raising.
        """
        print("\n" + "=" * 60)
        print("📋 RAPPORT FINAL")
        print("=" * 60)

        num_images = self.stats['num_images']
        unannotated = self.stats.get('images_without_annotations', 0)
        # Guard the ratio: with no images there is nothing to divide by.
        unannotated_ratio = unannotated / num_images if num_images > 0 else 0.0

        print(f"\n✅ Images validées: {num_images}")
        print(f"✅ Annotations validées: {self.stats['num_annotations']}")
        print(f"📊 Moyenne annotations/image: {self.stats.get('avg_annotations_per_image', 0):.2f}")

        if unannotated > 0:
            pct = unannotated_ratio * 100
            print(f"⚠️  Images sans annotation: {unannotated} ({pct:.1f}%)")

        if self.errors:
            print(f"\n❌ ERREURS ({len(self.errors)}):")
            for error in self.errors[:10]:
                print(f"   • {error}")
            if len(self.errors) > 10:
                print(f"   ... et {len(self.errors) - 10} autres")
        else:
            print("\n✅ Aucune erreur détectée")

        if self.warnings:
            print(f"\n⚠️  AVERTISSEMENTS ({len(self.warnings)}):")
            for warning in self.warnings[:10]:
                print(f"   • {warning}")
            if len(self.warnings) > 10:
                print(f"   ... et {len(self.warnings) - 10} autres")
        else:
            print("\n✅ Aucun avertissement")

        print("\n💡 RECOMMANDATIONS:")

        if num_images < 500:
            print("   ⚠️  Dataset petit (<500 images), collectez plus de données")
        elif num_images < 1000:
            print("   ℹ️  Dataset correct, 1000+ images recommandé pour production")
        else:
            print("   ✅ Taille du dataset excellente")

        if unannotated_ratio > 0.1:
            print("   ⚠️  Trop d'images sans annotation (>10%), vérifier la qualité")

        score = max(0, 100 - len(self.errors) * 5 - len(self.warnings))

        print(f"\n🎯 Score de qualité: {score}/100")

        if score >= 90:
            print("   ✅ Excellent dataset, prêt pour l'entraînement")
        elif score >= 70:
            print("   ⚠️  Dataset acceptable, corrections recommandées")
        else:
            print("   ❌ Dataset problématique, corrections nécessaires")

    def visualize_distribution(self):
        """Save histograms of the collected statistics as a PNG file.

        Writes ``dataset_distribution.png`` in the current working directory
        with box width and height histograms, a width/height scatter plot
        and, when class counts are available, a per-class bar chart. Does
        nothing when no bounding-box sizes have been collected.
        """
        bbox_sizes = self.stats.get('bbox_sizes')
        if not bbox_sizes:
            return

        widths = [w for w, _ in bbox_sizes]
        heights = [h for _, h in bbox_sizes]

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            axes[0, 0].hist(widths, bins=30, edgecolor='black')
            axes[0, 0].set_title('Distribution des largeurs de bbox')
            axes[0, 0].set_xlabel('Largeur (px)')
            axes[0, 0].set_ylabel('Fréquence')

            axes[0, 1].hist(heights, bins=30, edgecolor='black')
            axes[0, 1].set_title('Distribution des hauteurs de bbox')
            axes[0, 1].set_xlabel('Hauteur (px)')
            axes[0, 1].set_ylabel('Fréquence')

            axes[1, 0].scatter(widths, heights, alpha=0.5)
            axes[1, 0].set_title('Largeur vs Hauteur')
            axes[1, 0].set_xlabel('Largeur (px)')
            axes[1, 0].set_ylabel('Hauteur (px)')

            class_counts = self.stats.get('class_counts')
            if class_counts:
                classes = list(class_counts.keys())
                counts = list(class_counts.values())
                axes[1, 1].bar(classes, counts)
                axes[1, 1].set_title('Distribution par classe')
                axes[1, 1].set_xlabel('Classe')
                axes[1, 1].set_ylabel('Nombre d\'annotations')

            plt.tight_layout()
            plt.savefig('dataset_distribution.png', dpi=150)
        finally:
            # Release the figure so repeated calls do not accumulate open
            # figures in pyplot's global registry.
            plt.close(fig)

        print("\n📊 Graphiques sauvegardés: dataset_distribution.png")

def main():
    """Parse the command line, run the validation and optionally plot.

    ``--annotations`` and ``--images`` are required; ``--visualize`` also
    writes the distribution charts once validation has finished.
    """
    cli = argparse.ArgumentParser(
        description='Validation de la qualité du dataset'
    )

    # The two required path options share the same shape.
    required_paths = (
        ('--annotations', 'Fichier annotations.json'),
        ('--images', 'Dossier des images'),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument('--visualize', action='store_true',
                     help='Créer des visualisations')

    options = cli.parse_args()

    checker = DatasetValidator(options.annotations, options.images)
    checker.validate()

    if not options.visualize:
        return
    checker.visualize_distribution()

# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
    main()
