#!/usr/bin/env python3
"""
Script de préparation du dataset avec augmentation de données
"""

import cv2
import numpy as np
import json
import albumentations as A
from pathlib import Path
import argparse
from tqdm import tqdm
import shutil

class DatasetPreparator:
    """Prepare a YOLO-format dataset from COCO-style annotations.

    Reads images plus an ``annotations.json`` file (COCO layout with
    ``images`` and ``annotations`` keys), splits them into
    train/val/test, optionally augments the training split with
    albumentations, and writes YOLO label files plus a ``data.yaml``
    config usable by YOLO trainers.
    """

    def __init__(self, input_dir, output_dir, augment=True):
        """
        Args:
            input_dir: Folder containing ``images/`` and ``annotations.json``.
            output_dir: Destination root for the prepared dataset.
            augment: When True, augmented copies are generated for the
                training split only.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.augment = augment

        # Create the output layout: {train,val,test} x {images,labels}.
        for split in ('train', 'val', 'test'):
            for sub in ('images', 'labels'):
                (self.output_dir / split / sub).mkdir(parents=True, exist_ok=True)

        # Augmentation pipeline. bbox_params makes albumentations keep the
        # YOLO-normalized boxes consistent with the geometric transforms.
        self.augmentation_pipeline = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.5
            ),
            A.RandomRotate90(p=0.5),
            A.Blur(blur_limit=3, p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=20,
                val_shift_limit=10,
                p=0.3
            ),
            A.RandomScale(scale_limit=0.1, p=0.3),
        ], bbox_params=A.BboxParams(
            format='yolo',
            label_fields=['class_labels']
        ))

    def load_annotations(self):
        """Load the COCO-style JSON annotation file.

        Returns:
            The parsed dict from ``annotations.json`` (expects at least
            ``images`` and ``annotations`` keys).

        Raises:
            FileNotFoundError: If ``annotations.json`` is missing.
        """
        anno_file = self.input_dir / 'annotations.json'
        if not anno_file.exists():
            raise FileNotFoundError(f"Fichier d'annotations introuvable: {anno_file}")

        with open(anno_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def yolo_to_bbox(self, x_center, y_center, width, height, img_w, img_h):
        """Convert a normalized YOLO box to absolute pixel corner coordinates.

        Args:
            x_center, y_center, width, height: Box in YOLO format (all in [0, 1]).
            img_w, img_h: Image size in pixels.

        Returns:
            Tuple ``(x1, y1, x2, y2)`` of integer pixel coordinates.
        """
        x1 = int((x_center - width / 2) * img_w)
        y1 = int((y_center - height / 2) * img_h)
        x2 = int((x_center + width / 2) * img_w)
        y2 = int((y_center + height / 2) * img_h)
        return x1, y1, x2, y2

    def augment_image(self, img, bboxes, class_labels, num_augmentations=3):
        """Produce up to ``num_augmentations`` augmented copies of one image.

        A transform that raises (e.g. a box degenerating out of range) is
        logged and skipped, so the result may contain fewer entries than
        requested — this is deliberate best-effort behavior.

        Returns:
            List of dicts with ``image``, ``bboxes`` and ``class_labels`` keys.
        """
        augmented_images = []
        for _ in range(num_augmentations):
            try:
                transformed = self.augmentation_pipeline(
                    image=img,
                    bboxes=bboxes,
                    class_labels=class_labels
                )
            except Exception as e:
                # One failed transform must not abort the whole run.
                print(f"⚠️  Erreur augmentation: {e}")
                continue
            augmented_images.append({
                'image': transformed['image'],
                'bboxes': transformed['bboxes'],
                'class_labels': transformed['class_labels']
            })
        return augmented_images

    def split_dataset(self, images, train_ratio=0.7, val_ratio=0.15):
        """Shuffle and split the dataset into train/val/test subsets.

        The input list is not modified; a shuffled copy is partitioned.
        The test ratio is the remainder ``1 - train_ratio - val_ratio``.

        Returns:
            Dict with ``train``, ``val`` and ``test`` lists.
        """
        shuffled = list(images)  # copy so the caller's list is untouched
        np.random.shuffle(shuffled)

        n = len(shuffled)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))

        return {
            'train': shuffled[:train_end],
            'val': shuffled[train_end:val_end],
            'test': shuffled[val_end:]
        }

    def save_yolo_annotation(self, filepath, bboxes, class_labels):
        """Write one YOLO label file: ``class cx cy w h`` per line (normalized)."""
        with open(filepath, 'w', encoding='utf-8') as f:
            for bbox, class_id in zip(bboxes, class_labels):
                x_center, y_center, width, height = bbox
                f.write(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

    def prepare_dataset(self):
        """Run the full pipeline: load, index, split, write and (optionally) augment."""
        print("🔍 Chargement des annotations...")
        annotations = self.load_annotations()

        # Index annotations by image_id once, instead of rescanning the
        # whole annotation list for every image (O(n*m) -> O(n+m)).
        annos_by_image = {}
        for anno in annotations['annotations']:
            annos_by_image.setdefault(anno['image_id'], []).append(anno)

        # Pair each existing image file with its annotations; images with
        # no annotations are dropped.
        image_data = []
        for img_info in annotations['images']:
            img_path = self.input_dir / 'images' / img_info['file_name']

            if not img_path.exists():
                print(f"⚠️  Image introuvable: {img_path}")
                continue

            img_annotations = annos_by_image.get(img_info['id'], [])
            if img_annotations:
                image_data.append({
                    'path': img_path,
                    'info': img_info,
                    'annotations': img_annotations
                })

        print(f"📊 {len(image_data)} images avec annotations")

        splits = self.split_dataset(image_data)

        for split_name, split_data in splits.items():
            print(f"\n📁 Traitement {split_name}: {len(split_data)} images")

            for img_item in tqdm(split_data):
                img = cv2.imread(str(img_item['path']))
                if img is None:
                    # cv2.imread returns None on unreadable/corrupt files;
                    # skip instead of crashing on img.shape below.
                    print(f"⚠️  Image illisible: {img_item['path']}")
                    continue
                h, w = img.shape[:2]

                # COCO bbox [x, y, w, h] (pixels) -> YOLO [cx, cy, w, h] (normalized).
                bboxes = []
                class_labels = []
                for anno in img_item['annotations']:
                    x, y, bbox_w, bbox_h = anno['bbox']
                    bboxes.append([
                        (x + bbox_w / 2) / w,
                        (y + bbox_h / 2) / h,
                        bbox_w / w,
                        bbox_h / h
                    ])
                    class_labels.append(anno['category_id'])

                # Save the original image and its label file.
                img_name = img_item['path'].name
                cv2.imwrite(
                    str(self.output_dir / split_name / 'images' / img_name),
                    img
                )
                self.save_yolo_annotation(
                    self.output_dir / split_name / 'labels' / f"{img_item['path'].stem}.txt",
                    bboxes,
                    class_labels
                )

                # Augment the training split only, keeping val/test
                # distributions untouched.
                if self.augment and split_name == 'train':
                    augmented = self.augment_image(img, bboxes, class_labels, num_augmentations=3)

                    for idx, aug_data in enumerate(augmented):
                        aug_name = f"{img_item['path'].stem}_aug{idx}{img_item['path'].suffix}"

                        cv2.imwrite(
                            str(self.output_dir / split_name / 'images' / aug_name),
                            aug_data['image']
                        )
                        self.save_yolo_annotation(
                            self.output_dir / split_name / 'labels' / f"{img_item['path'].stem}_aug{idx}.txt",
                            aug_data['bboxes'],
                            aug_data['class_labels']
                        )

        # Emit the YOLO training config.
        self.create_yaml_config()

        print("\n✅ Dataset préparé avec succès!")
        print(f"📊 Statistiques:")
        for split in ['train', 'val', 'test']:
            n_images = len(list((self.output_dir / split / 'images').glob('*')))
            print(f"   {split}: {n_images} images")

    def create_yaml_config(self):
        """Write the ``data.yaml`` YOLO config pointing at the prepared splits."""
        config = f"""# Queen Bee Detection Dataset
path: {self.output_dir.absolute()}
train: train/images
val: val/images
test: test/images

# Classes
nc: 2
names: ['queen', 'worker']
"""
        with open(self.output_dir / 'data.yaml', 'w', encoding='utf-8') as f:
            f.write(config)

        print(f"📄 Configuration YAML créée: {self.output_dir / 'data.yaml'}")

def main():
    """CLI entry point: parse command-line options and run the preparation."""
    parser = argparse.ArgumentParser(description='Préparation du dataset')
    parser.add_argument('--input', type=str, required=True,
                        help='Dossier des annotations')
    parser.add_argument('--output', type=str, default='prepared_dataset',
                        help='Dossier de sortie')
    parser.add_argument('--no-augment', action='store_true',
                        help='Désactiver l\'augmentation de données')
    opts = parser.parse_args()

    # --no-augment flag disables augmentation, hence the negation.
    preparator = DatasetPreparator(opts.input, opts.output, augment=not opts.no_augment)
    preparator.prepare_dataset()

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
