"""
Data Augmentation Script
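Expand an object-detection dataset (YOLO-format labels) by writing randomly
transformed copies of each image together with correspondingly adjusted boxes.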
"""

import os
import cv2
import numpy as np
import albumentations as A
from pathlib import Path
from tqdm import tqdm
import shutil


class DataAugmenter:
    """Expand a YOLO-format detection dataset with randomly transformed copies.

    ``input_dir`` is expected to contain ``images/`` and ``labels/``
    subdirectories; each label file shares its image's stem and holds one
    ``class x y w h`` line per box. ``output_dir`` receives the original files
    plus ``augmentations_per_image`` augmented versions of every image (saved
    as JPEG) with matching label files.
    """

    # Image suffixes, compared case-insensitively so ``.JPG`` and ``.jpg`` are
    # both found exactly once regardless of filesystem case sensitivity.
    _IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png'})

    def __init__(self, input_dir, output_dir, augmentations_per_image=5):
        """
        Initialize data augmenter

        Args:
            input_dir: Directory with images/ and labels/ subdirectories
            output_dir: Output directory for augmented data (created if missing)
            augmentations_per_image: Number of augmented versions per image
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.augmentations_per_image = augmentations_per_image

        # Create output directories
        (self.output_dir / 'images').mkdir(parents=True, exist_ok=True)
        (self.output_dir / 'labels').mkdir(parents=True, exist_ok=True)

        # Define augmentation pipeline. Boxes are in YOLO format and the class
        # ids travel alongside them via the 'class_labels' field.
        # NOTE(review): newer albumentations releases deprecate
        # GaussNoise(var_limit=...) and ShiftScaleRotate — confirm against the
        # pinned albumentations version.
        self.transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.Blur(blur_limit=3, p=0.1),
            A.CLAHE(clip_limit=2.0, p=0.2),
            A.RandomRotate90(p=0.2),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
        ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

    @staticmethod
    def _read_labels(label_path):
        """Parse a YOLO label file into ``(bboxes, class_labels)``.

        Lines without exactly five whitespace-separated fields are ignored.
        A missing file yields two empty lists (an image with no boxes).

        Raises:
            ValueError: if a five-field line contains a non-numeric value.
            OSError: if the file exists but cannot be read.
        """
        bboxes = []
        class_labels = []
        label_path = Path(label_path)
        if not label_path.exists():
            return bboxes, class_labels
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                cls, x, y, w, h = map(float, parts)
                bboxes.append([x, y, w, h])
                class_labels.append(int(cls))
        return bboxes, class_labels

    def _find_images(self):
        """Return the sorted, de-duplicated image files in ``input_dir/images``."""
        image_dir = self.input_dir / 'images'
        if not image_dir.is_dir():
            return []
        return sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in self._IMAGE_EXTENSIONS
        )

    def augment_image(self, image_path, label_path):
        """
        Augment a single image

        Writes ``augmentations_per_image`` copies to
        ``output_dir/images/<stem>_aug_<i>.jpg`` with their transformed boxes in
        ``output_dir/labels/<stem>_aug_<i>.txt``. Problems are reported and
        skipped so that one bad image does not abort the whole dataset run.

        Args:
            image_path: Path to image
            label_path: Path to corresponding label file (may be missing,
                meaning the image has no boxes)
        """
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None
            # rather than raising; skip instead of crashing in cvtColor.
            print(f"Error augmenting {image_path}: could not read image")
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        try:
            bboxes, class_labels = self._read_labels(label_path)
        except (OSError, ValueError) as e:
            # Augmenting with partially parsed boxes would produce wrong
            # labels, so skip the image entirely.
            print(f"Error reading labels {label_path}: {e}")
            return

        base_name = Path(image_path).stem
        images_out = self.output_dir / 'images'
        labels_out = self.output_dir / 'labels'

        for i in range(self.augmentations_per_image):
            output_image_path = images_out / f"{base_name}_aug_{i}.jpg"
            output_label_path = labels_out / f"{base_name}_aug_{i}.txt"
            try:
                transformed = self.transform(
                    image=image,
                    bboxes=bboxes,
                    class_labels=class_labels,
                )

                aug_image_bgr = cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)
                if not cv2.imwrite(str(output_image_path), aug_image_bgr):
                    # imwrite reports failure via its return value; don't
                    # leave a label file behind for a missing image.
                    print(f"Error augmenting {image_path}: could not write {output_image_path}")
                    continue

                with open(output_label_path, 'w') as f:
                    for (x, y, w, h), label in zip(transformed['bboxes'], transformed['class_labels']):
                        f.write(f"{label} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")

            except Exception as e:
                # Best-effort per copy (e.g. the transform rejecting a
                # malformed box): report and continue with the next copy.
                print(f"Error augmenting {image_path}: {e}")

    def augment_dataset(self):
        """Augment entire dataset

        Copies the original images/labels into ``output_dir``, generates the
        augmented copies, and prints a summary. Returns early (after reporting)
        when ``input_dir/images`` contains no images.
        """
        print("🐝 Starting data augmentation...")
        print(f"Input: {self.input_dir}")
        print(f"Output: {self.output_dir}")
        print(f"Augmentations per image: {self.augmentations_per_image}")

        images = self._find_images()
        print(f"\nFound {len(images)} images")
        if not images:
            # Nothing to do; also avoids a division by zero in the summary.
            print("No images to augment.")
            return

        images_out = self.output_dir / 'images'
        labels_out = self.output_dir / 'labels'

        # First, copy original images and labels
        print("\nCopying original files...")
        for img_path in tqdm(images):
            shutil.copy(img_path, images_out / img_path.name)

            label_path = self.input_dir / 'labels' / f"{img_path.stem}.txt"
            if label_path.exists():
                shutil.copy(label_path, labels_out / label_path.name)

        # Generate augmentations
        print("\nGenerating augmentations...")
        for img_path in tqdm(images):
            label_path = self.input_dir / 'labels' / f"{img_path.stem}.txt"
            self.augment_image(img_path, label_path)

        # Count results (files only, so stray subdirectories don't inflate it)
        final_images = sum(1 for p in images_out.iterdir() if p.is_file())
        final_labels = sum(1 for p in labels_out.glob('*.txt') if p.is_file())

        print("\n" + "="*50)
        print("AUGMENTATION COMPLETE")
        print("="*50)
        print(f"Original images: {len(images)}")
        print(f"Final images: {final_images}")
        print(f"Final labels: {final_labels}")
        print(f"Increase factor: {final_images / len(images):.1f}x")
        print("="*50)


def main():
    """Command-line entry point.

    Reads ``--input``/``--output`` directories and an optional
    ``--augmentations`` count, then augments the dataset with DataAugmenter.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Augment queen bee dataset')
    cli.add_argument(
        '--input',
        type=str,
        required=True,
        help='Input directory with images/ and labels/ subdirectories',
    )
    cli.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output directory for augmented data',
    )
    cli.add_argument(
        '--augmentations',
        type=int,
        default=5,
        help='Number of augmentations per image (default: 5)',
    )
    options = cli.parse_args()

    runner = DataAugmenter(
        input_dir=options.input,
        output_dir=options.output,
        augmentations_per_image=options.augmentations,
    )
    runner.augment_dataset()

    print("\n✅ Done! Dataset ready for training.")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
