"""
YOLO dataset utilities for converting SAM2 annotations to YOLO format.
"""

import json
import os
import random
import shutil
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import cv2
import numpy as np
import yaml
from tqdm import tqdm


class YOLODatasetBuilder:
    """Build a YOLO-format detection dataset from SAM2 annotations.

    The builder creates and owns this layout under ``output_dir``::

        images/train   images/val
        labels/train   labels/val
        data.yaml

    Frame images are copied into ``images/<split>`` and one ``.txt`` label
    file per copied image is written to ``labels/<split>``. Each label line
    has the form ``class x_center y_center width height`` with all four
    geometry values normalized by the image size.
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[List[str]] = None
    ):
        """
        Initialize YOLO dataset builder.

        Creates the four ``images``/``labels`` split directories immediately
        (existing directories are reused).

        Args:
            output_dir: Root directory for YOLO dataset
            class_names: List of class names (default: ['object'])
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or ['object']

        # Create directory structure
        self.images_train = self.output_dir / 'images' / 'train'
        self.images_val = self.output_dir / 'images' / 'val'
        self.labels_train = self.output_dir / 'labels' / 'train'
        self.labels_val = self.output_dir / 'labels' / 'val'

        for dir_path in (self.images_train, self.images_val,
                         self.labels_train, self.labels_val):
            dir_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _split_frames(
        items: List[Any],
        val_split: float,
        seed: int
    ) -> Tuple[List[Any], List[Any]]:
        """Shuffle ``items`` reproducibly and split them into (train, val).

        A private ``random.Random(seed)`` is used so the process-wide random
        state is left untouched; it yields the same order that
        ``random.seed(seed); random.shuffle(items)`` would. The input list is
        not modified.

        Raises:
            ValueError: If ``val_split`` is outside [0, 1].
        """
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be within [0, 1], got {val_split}")

        shuffled = list(items)
        random.Random(seed).shuffle(shuffled)

        split_idx = int(len(shuffled) * (1 - val_split))
        return shuffled[:split_idx], shuffled[split_idx:]

    def mask_to_bbox(
        self,
        mask: np.ndarray,
        image_width: int,
        image_height: int
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Convert binary mask to normalized YOLO bounding box.

        Args:
            mask: Binary mask array (rows index y, columns index x)
            image_width: Width used to normalize x values
            image_height: Height used to normalize y values

        Returns:
            Tuple of (x_center, y_center, width, height) normalized to [0, 1],
            or None when the mask contains no foreground pixel.
        """
        # Rows / columns that contain at least one foreground pixel
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        if not rows.any() or not cols.any():
            return None

        y_idx = np.flatnonzero(rows)
        x_idx = np.flatnonzero(cols)

        # Pixel indices are inclusive: a pixel at index i covers [i, i + 1).
        # The far edge is therefore last_index + 1, otherwise a one-pixel
        # mask would produce a zero-sized box.
        x_min, x_max = int(x_idx[0]), int(x_idx[-1]) + 1
        y_min, y_max = int(y_idx[0]), int(y_idx[-1]) + 1

        # Calculate YOLO format (normalized center x, center y, width, height)
        x_center = ((x_min + x_max) / 2) / image_width
        y_center = ((y_min + y_max) / 2) / image_height
        width = (x_max - x_min) / image_width
        height = (y_max - y_min) / image_height

        return (x_center, y_center, width, height)

    def bbox_xywh_to_yolo(
        self,
        bbox: List[int],
        image_width: int,
        image_height: int
    ) -> Tuple[float, float, float, float]:
        """
        Convert [x, y, w, h] bbox to normalized YOLO format.

        The box is clipped to the image rectangle first, so every returned
        value lies in [0, 1]. Boxes already inside the image are unchanged.
        A box entirely outside the image collapses to zero width/height.

        Args:
            bbox: Bounding box [x, y, width, height] in pixels (top-left origin)
            image_width: Image width
            image_height: Image height

        Returns:
            Tuple of (x_center, y_center, width, height) normalized
        """
        x, y, w, h = bbox

        # Clip both corners to the image bounds.
        x0 = min(max(x, 0), image_width)
        y0 = min(max(y, 0), image_height)
        x1 = min(max(x + w, 0), image_width)
        y1 = min(max(y + h, 0), image_height)

        x_center = ((x0 + x1) / 2) / image_width
        y_center = ((y0 + y1) / 2) / image_height
        width = (x1 - x0) / image_width
        height = (y1 - y0) / image_height

        return (x_center, y_center, width, height)

    def convert_sam2_annotations(
        self,
        frames_dir: str,
        annotations_file: str,
        masks_dir: Optional[str] = None,
        class_id: int = 0,
        min_area: int = 100,
        min_bbox_size: float = 0.01,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Convert SAM2 annotations to YOLO format dataset.

        The annotations JSON is read as a mapping of frame file name to a
        list of annotation dicts; each dict may carry ``bbox``
        ([x, y, w, h] in pixels) and ``area``.

        Args:
            frames_dir: Directory containing frame images
            annotations_file: Path to SAM2 annotations JSON
            masks_dir: Directory containing mask images (optional; accepted
                for interface compatibility, not read by this method)
            class_id: Class ID for all objects
            min_area: Minimum annotation area in pixels
            min_bbox_size: Minimum bbox dimension (normalized)
            val_split: Validation set ratio, within [0, 1]
            seed: Random seed for split

        Returns:
            Statistics dictionary. The ``*_frames`` counts are the number of
            frames listed in the annotations for each split; frames whose
            image is missing or unreadable are still counted there but
            contribute no objects.

        Raises:
            ValueError: If ``val_split`` is outside [0, 1].
        """
        frames_path = Path(frames_dir)

        # Load annotations
        with open(annotations_file, 'r', encoding='utf-8') as f:
            annotations = json.load(f)

        # Shuffle and split the frame names
        frame_files = list(annotations.keys())
        train_files, val_files = self._split_frames(frame_files, val_split, seed)

        stats = {
            'total_frames': len(frame_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        # Process both splits through the same code path
        for split_name, frame_names, images_dir, labels_dir in [
            ('train', train_files, self.images_train, self.labels_train),
            ('val', val_files, self.images_val, self.labels_val)
        ]:
            for frame_name in tqdm(frame_names, desc=f"Processing {split_name}"):
                count = self._process_frame(
                    frames_path / frame_name,
                    # A JSON null entry is treated as "no annotations"
                    annotations.get(frame_name) or [],
                    images_dir,
                    labels_dir,
                    class_id,
                    min_area,
                    min_bbox_size
                )
                stats[f'{split_name}_objects'] += count

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        # Create data.yaml
        self._create_data_yaml()

        print(f"\nDataset created at {self.output_dir}")
        print(f" Train: {stats['train_frames']} images, {stats['train_objects']} objects")
        print(f" Val: {stats['val_frames']} images, {stats['val_objects']} objects")

        return stats

    def _process_frame(
        self,
        image_path: Path,
        frame_annotations: List[Dict],
        images_dir: Path,
        labels_dir: Path,
        class_id: int,
        min_area: int,
        min_bbox_size: float
    ) -> int:
        """Process a single frame and create YOLO label file.

        Copies the image into ``images_dir`` and writes ``<stem>.txt`` into
        ``labels_dir`` (an empty file when no annotation survives filtering).

        Returns:
            Number of label lines written; 0 when the image is missing or
            cannot be decoded (nothing is copied or written in that case).
        """
        if not image_path.exists():
            return 0

        # Read image to get dimensions
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        # Copy image to dataset
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for ann in frame_annotations or []:
            bbox = ann.get('bbox') or []
            if len(bbox) != 4:
                continue

            # Fall back to the bbox area when 'area' is absent; defaulting to
            # 0 would silently drop every such annotation once min_area > 0.
            area = ann.get('area')
            if area is None:
                area = bbox[2] * bbox[3]
            if area < min_area:
                continue

            x_center, y_center, w, h = self.bbox_xywh_to_yolo(bbox, width, height)

            # Filter degenerate and small boxes
            if w <= 0 or h <= 0 or w < min_bbox_size or h < min_bbox_size:
                continue

            # YOLO format: class x_center y_center width height
            labels.append(
                f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
            )

        # Write label file
        label_path = labels_dir / (image_path.stem + '.txt')
        with open(label_path, 'w') as f:
            f.write('\n'.join(labels))

        return len(labels)

    def _create_data_yaml(self):
        """Create YOLO data.yaml configuration file in the dataset root."""
        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': {i: name for i, name in enumerate(self.class_names)},
            'nc': len(self.class_names)
        }

        yaml_path = self.output_dir / 'data.yaml'
        with open(yaml_path, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print(f"Created {yaml_path}")

    def create_from_masks_directory(
        self,
        frames_dir: str,
        masks_dir: str,
        class_id: int = 0,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Create YOLO dataset directly from mask images.

        For a frame ``<stem>.jpg``/``<stem>.png`` the masks are looked up as
        ``masks_dir/<stem>/*.png``, one file per object.

        Args:
            frames_dir: Directory containing original frames
            masks_dir: Directory containing mask images (one per object)
            class_id: Class ID for all objects
            val_split: Validation split ratio, within [0, 1]
            seed: Random seed

        Returns:
            Statistics dictionary

        Raises:
            ValueError: If ``val_split`` is outside [0, 1].
        """
        frames_path = Path(frames_dir)
        masks_path = Path(masks_dir)

        # Find all frame images (sorted so the seeded shuffle is reproducible)
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) +
            list(frames_path.glob("*.png"))
        )

        train_files, val_files = self._split_frames(frame_files, val_split, seed)

        stats = {
            'total_frames': len(frame_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        # Process frames
        for frame_list, images_dir, labels_dir, key in [
            (train_files, self.images_train, self.labels_train, 'train_objects'),
            (val_files, self.images_val, self.labels_val, 'val_objects')
        ]:
            for frame_file in tqdm(frame_list, desc=f"Processing {key.split('_')[0]}"):
                # Find corresponding masks; sorted for a stable label order
                frame_masks_dir = masks_path / frame_file.stem
                if frame_masks_dir.is_dir():
                    mask_files = sorted(frame_masks_dir.glob("*.png"))
                else:
                    mask_files = []

                count = self._process_frame_with_masks(
                    frame_file,
                    mask_files,
                    images_dir,
                    labels_dir,
                    class_id
                )
                stats[key] += count

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']
        self._create_data_yaml()

        return stats

    def _process_frame_with_masks(
        self,
        image_path: Path,
        mask_files: List[Path],
        images_dir: Path,
        labels_dir: Path,
        class_id: int
    ) -> int:
        """Process frame with mask files.

        Copies the image, derives one bounding box per readable, non-empty
        mask and writes the label file.

        Returns:
            Number of label lines written; 0 when the image cannot be decoded.
        """
        # Decode once to make sure the frame is a readable image
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        # Copy image
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for mask_file in mask_files:
            mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue

            mask = mask > 127  # Convert to binary

            # Normalize by the mask's own size: identical to using the image
            # size when they match, and still within [0, 1] when the mask was
            # stored at a different resolution than the frame.
            mask_height, mask_width = mask.shape[:2]
            bbox = self.mask_to_bbox(mask, mask_width, mask_height)

            if bbox is not None:
                x_center, y_center, w, h = bbox
                labels.append(
                    f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
                )

        # Write labels
        label_path = labels_dir / (image_path.stem + '.txt')
        with open(label_path, 'w') as f:
            f.write('\n'.join(labels))

        return len(labels)


def export_for_kaggle(
    dataset_dir: str,
    output_zip: str,
    include_checkpoints: bool = False
) -> str:
    """
    Export YOLO dataset as zip file for Kaggle upload.

    Archive member names are relative to the parent of ``dataset_dir``, so
    the dataset directory name is the top-level folder inside the zip.

    Args:
        dataset_dir: Path to YOLO dataset directory
        output_zip: Output zip file path
        include_checkpoints: Include model checkpoints if present. When
            False, every sub-directory whose name contains ``checkpoints``
            is skipped together with its contents.

    Returns:
        Path to created zip file
    """
    import zipfile

    dataset_path = Path(dataset_dir)
    # Resolved once so the archive can be recognised if it lives inside the
    # dataset directory (it must not be added to itself).
    archive_path = Path(output_zip).resolve()

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(dataset_path):
            # Skip checkpoints unless requested. The test is applied to
            # directory names below the dataset root only: matching against
            # the full path would drop the whole dataset whenever a parent
            # directory happens to contain "checkpoints". Pruning ``dirs``
            # in place also stops os.walk from descending into them.
            if not include_checkpoints:
                dirs[:] = [d for d in dirs if 'checkpoints' not in d]
            dirs.sort()  # deterministic traversal / archive order

            for file in sorted(files):
                file_path = Path(root) / file
                if file_path.resolve() == archive_path:
                    continue
                arcname = file_path.relative_to(dataset_path.parent)
                zipf.write(file_path, arcname)

    print(f"Dataset exported to {output_zip}")
    return output_zip


def validate_yolo_dataset(dataset_dir: str) -> Dict[str, Any]:
    """
    Validate YOLO dataset structure and contents.

    Checks that ``data.yaml`` exists and parses to a mapping, that the
    ``images/<split>`` and ``labels/<split>`` directories exist for the
    ``train`` and ``val`` splits, and compares image stems (``*.jpg``,
    ``*.png``) against label stems (``*.txt``).

    Args:
        dataset_dir: Path to YOLO dataset

    Returns:
        Validation results dictionary with keys ``valid`` (bool), ``errors``
        (list of str), ``warnings`` (list of str) and ``stats`` (dict).
        Errors make the dataset invalid; warnings do not.
    """
    dataset_path = Path(dataset_dir)

    results: Dict[str, Any] = {
        'valid': True,
        'errors': [],
        'warnings': [],
        'stats': {}
    }

    def record_error(message: str) -> None:
        # Every error invalidates the dataset.
        results['errors'].append(message)
        results['valid'] = False

    # Check data.yaml
    yaml_path = dataset_path / 'data.yaml'
    if not yaml_path.exists():
        record_error("Missing data.yaml")
    else:
        config = None
        try:
            with open(yaml_path) as f:
                config = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            record_error(f"Invalid data.yaml: {exc}")
        else:
            # safe_load returns None for an empty file and may return a
            # scalar or list; only a mapping can be queried for nc / names.
            if not isinstance(config, dict):
                record_error("data.yaml is empty or not a mapping")
                config = None

        if config is not None:
            results['stats']['num_classes'] = config.get('nc', 0)
            results['stats']['class_names'] = config.get('names', {})

    # Check directories
    for split in ['train', 'val']:
        images_dir = dataset_path / 'images' / split
        labels_dir = dataset_path / 'labels' / split

        if not images_dir.exists():
            record_error(f"Missing images/{split} directory")
            continue

        if not labels_dir.exists():
            record_error(f"Missing labels/{split} directory")
            continue

        # Count files
        image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
        label_files = list(labels_dir.glob("*.txt"))

        results['stats'][f'{split}_images'] = len(image_files)
        results['stats'][f'{split}_labels'] = len(label_files)

        # Check matching in both directions
        image_stems = {f.stem for f in image_files}
        label_stems = {f.stem for f in label_files}

        missing_labels = image_stems - label_stems
        if missing_labels:
            results['warnings'].append(
                f"{len(missing_labels)} images in {split} missing labels"
            )

        orphan_labels = label_stems - image_stems
        if orphan_labels:
            results['warnings'].append(
                f"{len(orphan_labels)} labels in {split} have no matching image"
            )

    return results