"""
Dataset Splitter
Split dataset into train/validation/test sets
"""

import os
import shutil
import random
from pathlib import Path
from tqdm import tqdm


class DatasetSplitter:
    """Split an ``images/`` + ``labels/`` dataset into train/val/test folders.

    The source directory is expected to contain an ``images/`` folder and,
    optionally, a ``labels/`` folder holding one ``<image stem>.txt`` file per
    image. Images are shuffled with a seeded RNG and copied, together with
    their label files when present, into ``<output>/<split>/images`` and
    ``<output>/<split>/labels``.
    """

    # Recognised image suffixes, compared case-insensitively.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png'})
    # Split names, in the order shuffled images are allocated to them.
    SPLITS = ('train', 'val', 'test')

    def __init__(self, source_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
        """
        Initialize dataset splitter

        Args:
            source_dir: Directory with images/ and labels/
            output_dir: Output directory for split dataset
            train_ratio: Proportion for training (default: 0.7)
            val_ratio: Proportion for validation (default: 0.2)
            test_ratio: Proportion for testing (default: 0.1)

        Raises:
            ValueError: If any ratio lies outside [0, 1], or the ratios do
                not sum to 1.0 (within a 0.01 tolerance).
        """
        self.source_dir = Path(source_dir)
        self.output_dir = Path(output_dir)
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        # Each ratio must be a proportion on its own. A sum-only check would
        # accept e.g. (1.2, -0.2, 0.0) and later produce nonsense slices.
        for name, ratio in (('train', train_ratio), ('val', val_ratio), ('test', test_ratio)):
            if not 0.0 <= ratio <= 1.0:
                raise ValueError(f"{name} ratio must be between 0 and 1, got {ratio}")
        # Small tolerance so values like 0.7/0.2/0.1 survive float rounding.
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.01:
            raise ValueError("Ratios must sum to 1.0")

    def split_dataset(self, seed=42):
        """Shuffle the source images and copy them into train/val/test folders.

        Args:
            seed: Seed for the shuffle. A private ``random.Random`` is used, so
                the global ``random`` state is left untouched. Images are also
                sorted before shuffling, so the same seed yields the same split
                regardless of the filesystem's directory-listing order.

        Raises:
            FileNotFoundError: If ``<source_dir>/images`` is not a directory.
        """
        print("🐝 Splitting dataset...")
        print(f"Source: {self.source_dir}")
        print(f"Output: {self.output_dir}")
        print(f"Ratios - Train: {self.train_ratio}, Val: {self.val_ratio}, Test: {self.test_ratio}")

        images_dir = self.source_dir / 'images'
        # Fail loudly instead of creating empty output folders and reporting
        # success when the source path is wrong.
        if not images_dir.is_dir():
            raise FileNotFoundError(f"Images directory not found: {images_dir}")

        # Create output directories
        for split in self.SPLITS:
            (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
            (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

        images = self._collect_images(images_dir)
        random.Random(seed).shuffle(images)
        splits = self._partition(images)

        print(f"\nTotal images: {len(images)}")
        print(f"Train: {len(splits['train'])}")
        print(f"Val: {len(splits['val'])}")
        print(f"Test: {len(splits['test'])}")

        print("\nCopying files...")
        for split_name, split_images in splits.items():
            print(f"\nProcessing {split_name} set...")
            self._copy_split(split_name, split_images)

        self.print_statistics()

        print("\n✅ Dataset split complete!")

    def _collect_images(self, images_dir):
        """Return the image files directly inside *images_dir*, sorted by path.

        Suffixes are matched case-insensitively, so ``a.JPG`` and ``b.Png``
        are found. Each file appears exactly once, even on case-insensitive
        filesystems where separate ``*.jpg`` and ``*.JPG`` globs would both
        match it.
        """
        return sorted(
            path for path in images_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def _partition(self, images):
        """Slice an already-shuffled image list into train/val/test lists.

        The train and val sizes are ``int(total * ratio)`` (truncated). The
        test split receives every remaining image.
        """
        total = len(images)
        train_end = int(total * self.train_ratio)
        val_end = train_end + int(total * self.val_ratio)
        return {
            'train': images[:train_end],
            'val': images[train_end:val_end],
            'test': images[val_end:],
        }

    def _copy_split(self, split_name, split_images):
        """Copy one split's images, plus each image's label file when present."""
        images_out = self.output_dir / split_name / 'images'
        labels_out = self.output_dir / split_name / 'labels'
        labels_src = self.source_dir / 'labels'
        for img_path in tqdm(split_images):
            shutil.copy(img_path, images_out / img_path.name)
            label_path = labels_src / f"{img_path.stem}.txt"
            if label_path.exists():
                shutil.copy(label_path, labels_out / label_path.name)

    def print_statistics(self):
        """Print per-split image, label and annotation counts.

        Only lines with exactly five whitespace-separated fields and an
        integer class id first are counted as annotations. Class ``0`` is
        tallied as a queen and any other class as a worker. Other lines
        (blank, wrong field count, non-integer class id) are skipped rather
        than aborting the report.
        """
        print("\n" + "="*50)
        print("DATASET STATISTICS")
        print("="*50)

        for split in self.SPLITS:
            images_dir = self.output_dir / split / 'images'
            labels_dir = self.output_dir / split / 'labels'
            # Count image files the same way the splitter selects them.
            images = self._collect_images(images_dir) if images_dir.is_dir() else []
            labels = sorted(labels_dir.glob('*.txt')) if labels_dir.is_dir() else []

            queen_count = 0
            worker_count = 0
            for label_file in labels:
                with open(label_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.split()
                        if len(parts) != 5:
                            continue
                        try:
                            cls = int(parts[0])
                        except ValueError:
                            continue  # malformed class id; skip this line
                        if cls == 0:
                            queen_count += 1
                        else:
                            worker_count += 1
            total_annotations = queen_count + worker_count

            print(f"\n{split.upper()}:")
            print(f"  Images: {len(images)}")
            print(f"  Labels: {len(labels)}")
            print(f"  Total annotations: {total_annotations}")
            print(f"  Queens: {queen_count}")
            print(f"  Workers: {worker_count}")
            if images:
                print(f"  Avg annotations/image: {total_annotations / len(images):.2f}")

        print("="*50)


def main():
    """Command-line entry point: parse arguments and split the dataset."""
    import argparse

    cli = argparse.ArgumentParser(description='Split dataset into train/val/test')
    cli.add_argument('--source', type=str, required=True,
                     help='Source directory with images/ and labels/')
    cli.add_argument('--output', type=str, required=True,
                     help='Output directory for split dataset')

    # (flag, default, help) for each split proportion, registered in order.
    ratio_flags = (
        ('--train', 0.7, 'Training set ratio (default: 0.7)'),
        ('--val', 0.2, 'Validation set ratio (default: 0.2)'),
        ('--test', 0.1, 'Test set ratio (default: 0.1)'),
    )
    for flag, default, help_text in ratio_flags:
        cli.add_argument(flag, type=float, default=default, help=help_text)

    cli.add_argument('--seed', type=int, default=42,
                     help='Random seed (default: 42)')

    opts = cli.parse_args()

    DatasetSplitter(
        source_dir=opts.source,
        output_dir=opts.output,
        train_ratio=opts.train,
        val_ratio=opts.val,
        test_ratio=opts.test,
    ).split_dataset(seed=opts.seed)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
