"""
YOLO dataset utilities for converting SAM2 annotations to YOLO format.
"""

import json
import os
import random
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
import yaml
from tqdm import tqdm

# Normalized YOLO box: (x_center, y_center, width, height), each relative to
# the image size.
_YoloBox = Tuple[float, float, float, float]


class YOLODatasetBuilder:
    """Build a YOLO format dataset from SAM2 annotations or per-object masks.

    The builder writes into ``output_dir``::

        images/train/  images/val/  labels/train/  labels/val/  data.yaml

    Each label file contains one ``class x_center y_center width height``
    line per object, with coordinates normalized by the image size.
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[List[str]] = None
    ):
        """
        Initialize YOLO dataset builder and create its directory tree.

        Args:
            output_dir: Root directory for YOLO dataset
            class_names: List of class names (default: ['object'];
                an empty list also falls back to the default)
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or ['object']

        # Create directory structure
        self.images_train = self.output_dir / 'images' / 'train'
        self.images_val = self.output_dir / 'images' / 'val'
        self.labels_train = self.output_dir / 'labels' / 'train'
        self.labels_val = self.output_dir / 'labels' / 'val'

        for dir_path in [self.images_train, self.images_val,
                         self.labels_train, self.labels_val]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def mask_to_bbox(
        self,
        mask: np.ndarray,
        image_width: int,
        image_height: int
    ) -> Optional[_YoloBox]:
        """
        Convert a 2-D binary mask to a normalized YOLO bounding box.

        Args:
            mask: 2-D mask array; any truthy pixel belongs to the object
            image_width: Image width in pixels
            image_height: Image height in pixels

        Returns:
            Tuple of (x_center, y_center, width, height) normalized to [0, 1],
            or None when the mask is empty.
        """
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        if not rows.any() or not cols.any():
            return None

        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]

        # Pixel i covers [i, i + 1), so the box spans [min, max + 1).
        # Using `max - min` would give single-pixel-wide masks zero size and
        # shrink/shift every box by one pixel.
        box_w = x_max - x_min + 1
        box_h = y_max - y_min + 1

        x_center = (x_min + box_w / 2) / image_width
        y_center = (y_min + box_h / 2) / image_height
        width = box_w / image_width
        height = box_h / image_height

        return (float(x_center), float(y_center), float(width), float(height))

    def bbox_xywh_to_yolo(
        self,
        bbox: Sequence[float],
        image_width: int,
        image_height: int
    ) -> _YoloBox:
        """
        Convert an [x, y, w, h] pixel bbox to normalized YOLO format.

        No clipping is performed; boxes outside the image produce values
        outside [0, 1].

        Args:
            bbox: Bounding box [x, y, width, height] (top-left corner + size)
            image_width: Image width in pixels
            image_height: Image height in pixels

        Returns:
            Tuple of (x_center, y_center, width, height) normalized
        """
        x, y, w, h = bbox

        x_center = (x + w / 2) / image_width
        y_center = (y + h / 2) / image_height
        width = w / image_width
        height = h / image_height

        return (x_center, y_center, width, height)

    @staticmethod
    def _split_files(
        items: List[Any],
        val_split: float,
        seed: int
    ) -> Tuple[List[Any], List[Any]]:
        """Shuffle a copy of *items* deterministically and split it into (train, val)."""
        shuffled = list(items)
        # A private generator yields exactly the same order as
        # random.seed(seed); random.shuffle(...) but does not reset the
        # process-wide RNG that other code may depend on.
        random.Random(seed).shuffle(shuffled)
        split_idx = int(len(shuffled) * (1 - val_split))
        return shuffled[:split_idx], shuffled[split_idx:]

    @staticmethod
    def _new_stats(num_train: int, num_val: int) -> Dict[str, int]:
        """Return a fresh statistics dict for a train/val split of the given sizes."""
        return {
            'total_frames': num_train + num_val,
            'train_frames': num_train,
            'val_frames': num_val,
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

    @staticmethod
    def _clip_xywh(
        bbox: Sequence[float],
        image_width: int,
        image_height: int
    ) -> Tuple[float, float, float, float]:
        """Clip an [x, y, w, h] pixel box to the image; width/height never go below 0."""
        x, y, w, h = (float(v) for v in bbox)
        x1 = min(max(x, 0.0), image_width)
        y1 = min(max(y, 0.0), image_height)
        x2 = min(max(x + w, 0.0), image_width)
        y2 = min(max(y + h, 0.0), image_height)
        return (x1, y1, max(x2 - x1, 0.0), max(y2 - y1, 0.0))

    @staticmethod
    def _format_label(class_id: int, box: _YoloBox) -> str:
        """Format one YOLO label line: ``class x_center y_center width height``."""
        x_center, y_center, w, h = box
        return f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"

    @staticmethod
    def _write_labels(label_path: Path, labels: List[str]) -> None:
        """Write label lines to *label_path* (an empty file when there are none)."""
        with open(label_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(labels))

    def convert_sam2_annotations(
        self,
        frames_dir: str,
        annotations_file: str,
        masks_dir: Optional[str] = None,
        class_id: int = 0,
        min_area: int = 100,
        min_bbox_size: float = 0.01,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Convert SAM2 annotations to YOLO format dataset.

        The annotations JSON maps frame file names (relative to
        ``frames_dir``) to lists of annotation dicts; each dict is read for
        a ``'bbox'`` ([x, y, w, h] in pixels) and an optional ``'area'``.

        Args:
            frames_dir: Directory containing frame images
            annotations_file: Path to SAM2 annotations JSON
            masks_dir: Currently unused; accepted for API compatibility
            class_id: Class ID for all objects
            min_area: Minimum object area in pixels (the annotation's
                'area', or the bbox area when 'area' is absent)
            min_bbox_size: Minimum bbox dimension (normalized, after clipping)
            val_split: Validation set ratio
            seed: Random seed for split

        Returns:
            Statistics dictionary. Frame counts are the split sizes; frames
            whose image is missing or unreadable are counted but contribute
            no objects.
        """
        frames_path = Path(frames_dir)

        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(annotations_file, 'r', encoding='utf-8') as f:
            annotations = json.load(f)

        train_files, val_files = self._split_files(
            list(annotations.keys()), val_split, seed
        )
        stats = self._new_stats(len(train_files), len(val_files))

        for split, frame_names, images_dir, labels_dir in (
            ('train', train_files, self.images_train, self.labels_train),
            ('val', val_files, self.images_val, self.labels_val),
        ):
            for frame_name in tqdm(frame_names, desc=f"Processing {split}"):
                stats[f'{split}_objects'] += self._process_frame(
                    frames_path / frame_name,
                    annotations.get(frame_name, []),
                    images_dir,
                    labels_dir,
                    class_id,
                    min_area,
                    min_bbox_size
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        # Create data.yaml
        self._create_data_yaml()

        print(f"\nDataset created at {self.output_dir}")
        print(f" Train: {stats['train_frames']} images, {stats['train_objects']} objects")
        print(f" Val: {stats['val_frames']} images, {stats['val_objects']} objects")

        return stats

    def _process_frame(
        self,
        image_path: Path,
        frame_annotations: List[Dict],
        images_dir: Path,
        labels_dir: Path,
        class_id: int,
        min_area: int,
        min_bbox_size: float
    ) -> int:
        """
        Copy one frame into the dataset and write its YOLO label file.

        Returns:
            Number of labels written. Returns 0 without copying or writing
            anything when the image is missing or cannot be decoded.
        """
        if not image_path.exists():
            return 0

        # Read image to get dimensions (also confirms it decodes).
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for ann in frame_annotations:
            bbox = ann.get('bbox', [])
            if len(bbox) != 4:
                continue

            area = ann.get('area')
            if area is None:
                # Without an explicit area, use the box area instead of
                # silently discarding the annotation.
                area = bbox[2] * bbox[3]
            if area < min_area:
                continue

            # YOLO trainers reject label files with coordinates outside
            # [0, 1] (the whole image is dropped), so clip to the image first.
            x, y, w, h = self._clip_xywh(bbox, width, height)
            if w <= 0 or h <= 0:
                continue  # box lies entirely outside the image

            yolo_box = self.bbox_xywh_to_yolo([x, y, w, h], width, height)

            # Filter small boxes
            if yolo_box[2] < min_bbox_size or yolo_box[3] < min_bbox_size:
                continue

            labels.append(self._format_label(class_id, yolo_box))

        self._write_labels(labels_dir / (image_path.stem + '.txt'), labels)

        return len(labels)

    def _create_data_yaml(self):
        """Write ``data.yaml`` (dataset root, split dirs, class names) to output_dir."""
        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': {i: name for i, name in enumerate(self.class_names)},
            'nc': len(self.class_names)
        }

        yaml_path = self.output_dir / 'data.yaml'
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print(f"Created {yaml_path}")

    def create_from_masks_directory(
        self,
        frames_dir: str,
        masks_dir: str,
        class_id: int = 0,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Create YOLO dataset directly from mask images.

        For a frame ``<frames_dir>/<stem>.jpg`` (or ``.png``), the masks are
        read from ``<masks_dir>/<stem>/*.png``, one object per mask file.

        Args:
            frames_dir: Directory containing original frames (*.jpg, *.png)
            masks_dir: Directory containing one sub-directory of masks per frame
            class_id: Class ID for all objects
            val_split: Validation split ratio
            seed: Random seed

        Returns:
            Statistics dictionary
        """
        frames_path = Path(frames_dir)
        masks_path = Path(masks_dir)

        # Sort before shuffling so the split depends only on the seed,
        # not on filesystem listing order.
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) + list(frames_path.glob("*.png"))
        )

        train_files, val_files = self._split_files(frame_files, val_split, seed)
        stats = self._new_stats(len(train_files), len(val_files))

        for split, frame_list, images_dir, labels_dir in (
            ('train', train_files, self.images_train, self.labels_train),
            ('val', val_files, self.images_val, self.labels_val),
        ):
            for frame_file in tqdm(frame_list, desc=f"Processing {split}"):
                frame_masks_dir = masks_path / frame_file.stem
                # Sorted so label line order is reproducible across runs.
                if frame_masks_dir.exists():
                    mask_files = sorted(frame_masks_dir.glob("*.png"))
                else:
                    mask_files = []

                stats[f'{split}_objects'] += self._process_frame_with_masks(
                    frame_file, mask_files, images_dir, labels_dir, class_id
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        self._create_data_yaml()

        return stats

    def _process_frame_with_masks(
        self,
        image_path: Path,
        mask_files: List[Path],
        images_dir: Path,
        labels_dir: Path,
        class_id: int
    ) -> int:
        """
        Copy one frame into the dataset and write one label per non-empty mask.

        Masks are read as grayscale and binarized at > 127. Unreadable or
        empty masks are skipped.

        Returns:
            Number of labels written (0 if the image cannot be decoded, in
            which case nothing is copied or written).
        """
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for mask_file in mask_files:
            mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue

            # NOTE(review): assumes masks share the frame's resolution, since
            # boxes are normalized by the frame size — confirm upstream.
            bbox = self.mask_to_bbox(mask > 127, width, height)
            if bbox is not None:
                labels.append(self._format_label(class_id, bbox))

        self._write_labels(labels_dir / (image_path.stem + '.txt'), labels)

        return len(labels)


def export_for_kaggle(
    dataset_dir: str,
    output_zip: str,
    include_checkpoints: bool = False
) -> str:
    """
    Export YOLO dataset as zip file for Kaggle upload.

    Archive member names are relative to the dataset's parent directory,
    so the dataset folder name is the top-level entry in the zip.

    Args:
        dataset_dir: Path to YOLO dataset directory
        output_zip: Output zip file path
        include_checkpoints: Include directories named ``checkpoints``
            (at any depth below dataset_dir) if present

    Returns:
        Path to created zip file (``output_zip`` as given)
    """
    import zipfile

    dataset_path = Path(dataset_dir)
    output_path = Path(output_zip).resolve()

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(dataset_path):
            # Prune in place so os.walk never descends into excluded dirs.
            # Matching on the directory *name* (not a substring of the full
            # path) keeps datasets stored under e.g. /data/checkpoints/ intact.
            if include_checkpoints:
                dirs.sort()
            else:
                dirs[:] = sorted(d for d in dirs if d != 'checkpoints')

            for file in sorted(files):
                file_path = Path(root) / file
                # Never add the archive currently being written to itself.
                if file_path.resolve() == output_path:
                    continue
                arcname = file_path.relative_to(dataset_path.parent)
                zipf.write(file_path, arcname)

    print(f"Dataset exported to {output_zip}")
    return output_zip


def validate_yolo_dataset(dataset_dir: str) -> Dict[str, Any]:
    """
    Validate YOLO dataset structure.

    Checks that ``data.yaml`` exists and parses to a mapping, that
    ``images/<split>`` and ``labels/<split>`` exist for train and val, and
    that image and label file stems match up.

    Args:
        dataset_dir: Path to YOLO dataset

    Returns:
        Validation results dictionary with keys 'valid' (bool),
        'errors' (list of str), 'warnings' (list of str) and 'stats' (dict).
    """
    dataset_path = Path(dataset_dir)
    results = {
        'valid': True,
        'errors': [],
        'warnings': [],
        'stats': {}
    }

    def _fail(message: str) -> None:
        """Record a fatal problem and mark the dataset invalid."""
        results['errors'].append(message)
        results['valid'] = False

    # Check data.yaml
    yaml_path = dataset_path / 'data.yaml'
    if not yaml_path.exists():
        _fail("Missing data.yaml")
    else:
        try:
            with open(yaml_path, encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            config = None
            _fail(f"Invalid data.yaml: {exc}")
        else:
            # safe_load returns None for an empty file; anything other than
            # a mapping cannot describe a dataset.
            if not isinstance(config, dict):
                config = None
                _fail("data.yaml is empty or not a mapping")

        if config is not None:
            results['stats']['num_classes'] = config.get('nc', 0)
            results['stats']['class_names'] = config.get('names', {})

    # Check directories
    for split in ['train', 'val']:
        images_dir = dataset_path / 'images' / split
        labels_dir = dataset_path / 'labels' / split

        if not images_dir.exists():
            _fail(f"Missing images/{split} directory")
            continue

        if not labels_dir.exists():
            _fail(f"Missing labels/{split} directory")
            continue

        # Count files
        image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
        label_files = list(labels_dir.glob("*.txt"))

        results['stats'][f'{split}_images'] = len(image_files)
        results['stats'][f'{split}_labels'] = len(label_files)

        # Check matching in both directions
        image_stems = {f.stem for f in image_files}
        label_stems = {f.stem for f in label_files}

        missing_labels = image_stems - label_stems
        if missing_labels:
            results['warnings'].append(
                f"{len(missing_labels)} images in {split} missing labels"
            )

        orphan_labels = label_stems - image_stems
        if orphan_labels:
            results['warnings'].append(
                f"{len(orphan_labels)} labels in {split} have no matching image"
            )

    return results