"""
Dataset Validator
Check dataset integrity and quality
"""

import sys
from pathlib import Path

import numpy as np
from PIL import Image


class DatasetValidator:
    def __init__(self, dataset_dir):
        self.dataset_dir = Path(dataset_dir)
        self.issues = []
        
    def validate(self):
        """Run all validation checks"""
        print("🔍 Validating dataset...")
        print(f"Directory: {self.dataset_dir}")
        print("="*60)
        
        self.check_structure()
        self.check_image_label_pairs()
        self.check_image_quality()
        self.check_annotation_quality()
        self.generate_statistics()
        
        print("\n" + "="*60)
        if not self.issues:
            print("✅ Dataset validation passed!")
        else:
            print(f"⚠️  Found {len(self.issues)} issues:")
            for issue in self.issues[:10]:  # Show first 10
                print(f"  - {issue}")
            if len(self.issues) > 10:
                print(f"  ... and {len(self.issues) - 10} more")
        
        return len(self.issues) == 0
    
    def check_structure(self):
        """Check directory structure"""
        print("\n1. Checking directory structure...")
        
        required_dirs = [
            'train/images', 'train/labels',
            'val/images', 'val/labels'
        ]
        
        for dir_path in required_dirs:
            full_path = self.dataset_dir / dir_path
            if not full_path.exists():
                self.issues.append(f"Missing directory: {dir_path}")
                print(f"  ❌ Missing: {dir_path}")
            else:
                count = len(list(full_path.glob('*')))
                print(f"  ✓ {dir_path}: {count} files")
    
    def check_image_label_pairs(self):
        """Check that each image has a corresponding label"""
        print("\n2. Checking image-label pairs...")
        
        for split in ['train', 'val', 'test']:
            img_dir = self.dataset_dir / split / 'images'
            lbl_dir = self.dataset_dir / split / 'labels'
            
            if not img_dir.exists():
                continue
                
            images = set(f.stem for f in img_dir.glob('*') if f.suffix.lower() in ['.jpg', '.jpeg', '.png'])
            labels = set(f.stem for f in lbl_dir.glob('*.txt'))
            
            missing_labels = images - labels
            orphan_labels = labels - images
            
            if missing_labels:
                self.issues.append(f"{split}: {len(missing_labels)} images without labels")
                print(f"  ⚠️  {split}: {len(missing_labels)} images missing labels")
            
            if orphan_labels:
                self.issues.append(f"{split}: {len(orphan_labels)} labels without images")
                print(f"  ⚠️  {split}: {len(orphan_labels)} orphan labels")
            
            if not missing_labels and not orphan_labels:
                print(f"  ✓ {split}: All pairs match ({len(images)} images)")
    
    def check_image_quality(self):
        """Check image dimensions and quality"""
        print("\n3. Checking image quality...")
        
        min_size = 416  # Minimum recommended side length (a common YOLO input resolution)
        sizes = []
        
        for split in ['train', 'val', 'test']:
            img_dir = self.dataset_dir / split / 'images'
            if not img_dir.exists():
                continue
                
            for img_path in img_dir.glob('*'):
                if img_path.suffix.lower() not in ['.jpg', '.jpeg', '.png']:
                    continue
                    
                try:
                    # Use a context manager so file handles are released promptly
                    with Image.open(img_path) as img:
                        w, h = img.size
                    sizes.append((w, h))

                    if w < min_size or h < min_size:
                        self.issues.append(f"Image too small: {img_path.name} ({w}x{h})")

                except Exception as e:
                    self.issues.append(f"Cannot open image: {img_path.name} - {e}")
        
        if sizes:
            sizes = np.array(sizes)
            print(f"  ✓ Image sizes: {sizes.shape[0]} images")
            print(f"    - Mean: {sizes.mean(axis=0).astype(int)}")
            print(f"    - Min: {sizes.min(axis=0)}")
            print(f"    - Max: {sizes.max(axis=0)}")
    
    def check_annotation_quality(self):
        """Check annotation format and validity"""
        print("\n4. Checking annotation quality...")
        
        total_annotations = 0
        invalid_annotations = 0
        
        for split in ['train', 'val', 'test']:
            lbl_dir = self.dataset_dir / split / 'labels'
            if not lbl_dir.exists():
                continue
                
            for lbl_path in lbl_dir.glob('*.txt'):
                with open(lbl_path, 'r') as f:
                    for line_num, line in enumerate(f, 1):
                        parts = line.strip().split()

                        # Skip blank lines (e.g. a trailing newline)
                        if not parts:
                            continue

                        total_annotations += 1

                        if len(parts) != 5:
                            self.issues.append(f"{lbl_path.name}:{line_num} - Invalid format")
                            invalid_annotations += 1
                            continue

                        try:
                            cls, x, y, w, h = map(float, parts)
                        except ValueError:
                            self.issues.append(f"{lbl_path.name}:{line_num} - Invalid numbers")
                            invalid_annotations += 1
                            continue

                        # Check that the normalized box lies within [0, 1]
                        if not (0 <= x <= 1 and 0 <= y <= 1 and 0 <= w <= 1 and 0 <= h <= 1):
                            self.issues.append(f"{lbl_path.name}:{line_num} - Out of bounds")
                            invalid_annotations += 1
                            continue

                        # Check class id (0 = queen, 1 = worker)
                        if cls not in (0, 1):
                            self.issues.append(f"{lbl_path.name}:{line_num} - Invalid class")
                            invalid_annotations += 1
        
        print(f"  ✓ Total annotations: {total_annotations}")
        if invalid_annotations > 0:
            print(f"  ⚠️  Invalid annotations: {invalid_annotations}")
        else:
            print(f"  ✓ All annotations valid")
    
    def generate_statistics(self):
        """Generate dataset statistics"""
        print("\n5. Dataset Statistics:")
        
        stats = {
            'train': {'images': 0, 'queens': 0, 'workers': 0},
            'val': {'images': 0, 'queens': 0, 'workers': 0},
            'test': {'images': 0, 'queens': 0, 'workers': 0}
        }
        
        for split in ['train', 'val', 'test']:
            img_dir = self.dataset_dir / split / 'images'
            lbl_dir = self.dataset_dir / split / 'labels'
            
            if not img_dir.exists():
                continue
            
            # Count only image files, consistent with the pair check above
            stats[split]['images'] = len([f for f in img_dir.glob('*')
                                          if f.suffix.lower() in ['.jpg', '.jpeg', '.png']])
            
            if lbl_dir.exists():
                for lbl_path in lbl_dir.glob('*.txt'):
                    with open(lbl_path, 'r') as f:
                        for line in f:
                            parts = line.strip().split()
                            if len(parts) == 5:
                                cls = int(float(parts[0]))
                                # Class 0 = queen; any other class id is counted as a worker
                                if cls == 0:
                                    stats[split]['queens'] += 1
                                else:
                                    stats[split]['workers'] += 1
        
        total_images = sum(s['images'] for s in stats.values())
        total_queens = sum(s['queens'] for s in stats.values())
        total_workers = sum(s['workers'] for s in stats.values())
        
        print(f"\n  SPLIT     IMAGES    QUEENS    WORKERS")
        print(f"  -----     ------    ------    -------")
        for split in ['train', 'val', 'test']:
            s = stats[split]
            print(f"  {split:5s}     {s['images']:6d}    {s['queens']:6d}    {s['workers']:7d}")
        print(f"  -----     ------    ------    -------")
        print(f"  TOTAL     {total_images:6d}    {total_queens:6d}    {total_workers:7d}")
        
        # Recommendations
        print("\n6. Recommendations:")
        
        if total_images < 500:
            print("  ⚠️  Dataset too small. Recommended: 500+ images")
        elif total_images < 1000:
            print("  ⚡ Decent size. Better: 1000+ images")
        else:
            print("  ✅ Good dataset size!")
        
        if total_queens < total_images * 0.5:
            print("  ⚠️  Low queen count. Ensure most images have queens annotated")
        
        if stats['train']['images'] < 300:
            print("  ⚠️  Training set too small. Minimum: 300 images")


def main():
    """Main entry point"""
    import argparse
    
    parser = argparse.ArgumentParser(description='Validate queen bee dataset')
    parser.add_argument('--dataset', type=str, default='../dataset',
                       help='Path to dataset directory')
    
    args = parser.parse_args()
    
    validator = DatasetValidator(args.dataset)
    is_valid = validator.validate()
    
    sys.exit(0 if is_valid else 1)


if __name__ == '__main__':
    main()
