Files
2026-02-04 15:29:36 +07:00

479 lines
15 KiB
Python

"""
YOLO dataset utilities for converting SAM2 annotations to YOLO format.
"""
import os
import cv2
import json
import shutil
import yaml
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
import random
class YOLODatasetBuilder:
    """Build a YOLO-format detection dataset from SAM2 annotations or masks.

    The builder owns the following layout under ``output_dir``::

        images/train   images/val
        labels/train   labels/val
        data.yaml

    Each image gets a ``.txt`` label file with the same stem, one object per
    line: ``class x_center y_center width height`` (all normalized).
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[List[str]] = None
    ):
        """
        Initialize YOLO dataset builder and create the directory layout.

        Args:
            output_dir: Root directory for YOLO dataset
            class_names: List of class names (default: ['object'])
        """
        self.output_dir = Path(output_dir)
        # Copy so later mutation of the caller's list cannot change the
        # names written to data.yaml.
        self.class_names = list(class_names) if class_names else ['object']

        # Create directory structure
        self.images_train = self.output_dir / 'images' / 'train'
        self.images_val = self.output_dir / 'images' / 'val'
        self.labels_train = self.output_dir / 'labels' / 'train'
        self.labels_val = self.output_dir / 'labels' / 'val'

        for dir_path in (self.images_train, self.images_val,
                         self.labels_train, self.labels_val):
            dir_path.mkdir(parents=True, exist_ok=True)

    def mask_to_bbox(
        self,
        mask: np.ndarray,
        image_width: int,
        image_height: int
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Convert binary mask to normalized YOLO bounding box.

        Args:
            mask: Binary mask array (rows are y, columns are x)
            image_width: Width used to normalize x coordinates
            image_height: Height used to normalize y coordinates

        Returns:
            Tuple of (x_center, y_center, width, height) normalized by the
            given width/height, or None when the mask has no foreground.
        """
        # Find which rows / columns contain any foreground pixel
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        if not rows.any() or not cols.any():
            return None

        y_min, y_max = (int(v) for v in np.where(rows)[0][[0, -1]])
        x_min, x_max = (int(v) for v in np.where(cols)[0][[0, -1]])

        # x_max / y_max are *inclusive* pixel indices, so the box covers
        # [min, max + 1). Without the +1 a one-pixel mask would get a
        # zero-sized box and every box would be one pixel too small.
        box_w = x_max - x_min + 1
        box_h = y_max - y_min + 1

        x_center = (x_min + box_w / 2) / image_width
        y_center = (y_min + box_h / 2) / image_height
        width = box_w / image_width
        height = box_h / image_height

        return (x_center, y_center, width, height)

    def bbox_xywh_to_yolo(
        self,
        bbox: List[int],
        image_width: int,
        image_height: int
    ) -> Tuple[float, float, float, float]:
        """
        Convert [x, y, w, h] bbox to normalized YOLO format.

        Args:
            bbox: Bounding box [x, y, width, height] with (x, y) the
                top-left corner, in pixels
            image_width: Image width
            image_height: Image height

        Returns:
            Tuple of (x_center, y_center, width, height) normalized
        """
        x, y, w, h = bbox

        x_center = (x + w / 2) / image_width
        y_center = (y + h / 2) / image_height
        width = w / image_width
        height = h / image_height

        return (x_center, y_center, width, height)

    @staticmethod
    def _clip_bbox_xywh(
        bbox: List[float],
        image_width: int,
        image_height: int
    ) -> Tuple[float, float, float, float]:
        """Clip an [x, y, w, h] pixel box to the image rectangle."""
        x, y, w, h = bbox
        x0 = min(max(x, 0), image_width)
        y0 = min(max(y, 0), image_height)
        x1 = min(max(x + w, 0), image_width)
        y1 = min(max(y + h, 0), image_height)
        return (x0, y0, x1 - x0, y1 - y0)

    def convert_sam2_annotations(
        self,
        frames_dir: str,
        annotations_file: str,
        masks_dir: Optional[str] = None,
        class_id: int = 0,
        min_area: int = 100,
        min_bbox_size: float = 0.01,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Convert SAM2 annotations to YOLO format dataset.

        The annotations JSON is read as a mapping from frame file name to a
        list of annotation dicts; each dict is expected to carry a ``bbox``
        ([x, y, w, h] in pixels) and optionally an ``area``.

        Args:
            frames_dir: Directory containing frame images
            annotations_file: Path to SAM2 annotations JSON
            masks_dir: Directory containing mask images (optional;
                currently unused, accepted for API compatibility)
            class_id: Class ID for all objects
            min_area: Minimum object area in pixels
            min_bbox_size: Minimum bbox dimension (normalized)
            val_split: Validation set ratio
            seed: Random seed for split

        Returns:
            Statistics dictionary with frame and object counts per split.
            Frame counts reflect the split itself; frames whose image is
            missing or unreadable are skipped and contribute 0 objects.
        """
        frames_path = Path(frames_dir)

        # Load annotations
        with open(annotations_file, 'r') as f:
            annotations = json.load(f)

        # Get all frame files
        frame_files = list(annotations.keys())

        # Shuffle with a private generator: produces the same order as
        # random.seed(seed) + random.shuffle(), but does not reseed the
        # process-wide RNG as a side effect.
        random.Random(seed).shuffle(frame_files)

        split_idx = int(len(frame_files) * (1 - val_split))
        train_files = frame_files[:split_idx]
        val_files = frame_files[split_idx:]

        stats = {
            'total_frames': len(frame_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        splits = (
            ('train', train_files, self.images_train, self.labels_train),
            ('val', val_files, self.images_val, self.labels_val),
        )
        for split, names, images_dir, labels_dir in splits:
            for frame_name in tqdm(names, desc=f"Processing {split}"):
                stats[f'{split}_objects'] += self._process_frame(
                    frames_path / frame_name,
                    # A JSON null for a frame is treated as "no objects".
                    annotations.get(frame_name) or [],
                    images_dir,
                    labels_dir,
                    class_id,
                    min_area,
                    min_bbox_size
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        # Create data.yaml
        self._create_data_yaml()

        print(f"\nDataset created at {self.output_dir}")
        print(f" Train: {stats['train_frames']} images, {stats['train_objects']} objects")
        print(f" Val: {stats['val_frames']} images, {stats['val_objects']} objects")

        return stats

    def _process_frame(
        self,
        image_path: Path,
        frame_annotations: List[Dict],
        images_dir: Path,
        labels_dir: Path,
        class_id: int,
        min_area: int,
        min_bbox_size: float
    ) -> int:
        """
        Process a single frame and create YOLO label file.

        Copies the image into ``images_dir`` and writes the matching label
        file into ``labels_dir`` (an empty file when no annotation
        survives filtering).

        Returns:
            Number of label lines written; 0 when the image is missing or
            cannot be decoded (nothing is copied or written in that case).
        """
        if not image_path.exists():
            return 0

        # Read image to get dimensions
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        # Copy image to dataset
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for ann in frame_annotations:
            bbox = ann.get('bbox') or []
            if len(bbox) != 4:
                continue

            # Fall back to the bbox area when 'area' is absent; defaulting
            # to 0 would silently drop every such annotation.
            area = ann.get('area')
            if area is None:
                area = bbox[2] * bbox[3]
            if area < min_area:
                continue

            # Keep the box inside the image so normalized values stay in
            # [0, 1].
            clipped = self._clip_bbox_xywh(bbox, width, height)
            x_center, y_center, w, h = self.bbox_xywh_to_yolo(
                clipped, width, height
            )

            # Drop degenerate boxes (e.g. entirely outside the image)
            if w <= 0 or h <= 0:
                continue

            # Filter small boxes
            if w < min_bbox_size or h < min_bbox_size:
                continue

            # YOLO format: class x_center y_center width height
            labels.append(
                f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
            )

        # Write label file
        label_path = labels_dir / (image_path.stem + '.txt')
        with open(label_path, 'w') as f:
            f.write('\n'.join(labels))

        return len(labels)

    def _create_data_yaml(self):
        """Create YOLO data.yaml configuration file in the dataset root."""
        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': {i: name for i, name in enumerate(self.class_names)},
            'nc': len(self.class_names)
        }

        yaml_path = self.output_dir / 'data.yaml'
        with open(yaml_path, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print(f"Created {yaml_path}")

    def create_from_masks_directory(
        self,
        frames_dir: str,
        masks_dir: str,
        class_id: int = 0,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Create YOLO dataset directly from mask images.

        For a frame ``<stem>.jpg`` / ``<stem>.png`` the masks are looked up
        as ``masks_dir/<stem>/*.png``; a frame without such a directory
        gets an empty label file.

        Args:
            frames_dir: Directory containing original frames
            masks_dir: Directory containing mask images (one per object)
            class_id: Class ID for all objects
            val_split: Validation split ratio
            seed: Random seed

        Returns:
            Statistics dictionary
        """
        frames_path = Path(frames_dir)
        masks_path = Path(masks_dir)

        # Find all frame images
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) +
            list(frames_path.glob("*.png"))
        )

        # Private generator: same order as seeding the global RNG, without
        # the global side effect.
        random.Random(seed).shuffle(frame_files)

        split_idx = int(len(frame_files) * (1 - val_split))
        train_files = frame_files[:split_idx]
        val_files = frame_files[split_idx:]

        stats = {
            'total_frames': len(frame_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        # Process frames
        splits = (
            ('train', train_files, self.images_train, self.labels_train),
            ('val', val_files, self.images_val, self.labels_val),
        )
        for split, frame_list, images_dir, labels_dir in splits:
            for frame_file in tqdm(frame_list, desc=f"Processing {split}"):
                # Find corresponding masks; sorted so the label line order
                # does not depend on filesystem enumeration order.
                frame_masks_dir = masks_path / frame_file.stem
                if frame_masks_dir.exists():
                    mask_files = sorted(frame_masks_dir.glob("*.png"))
                else:
                    mask_files = []

                stats[f'{split}_objects'] += self._process_frame_with_masks(
                    frame_file,
                    mask_files,
                    images_dir,
                    labels_dir,
                    class_id
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        self._create_data_yaml()

        return stats

    def _process_frame_with_masks(
        self,
        image_path: Path,
        mask_files: List[Path],
        images_dir: Path,
        labels_dir: Path,
        class_id: int
    ) -> int:
        """
        Process frame with mask files.

        Copies the image, derives one bounding box per readable, non-empty
        mask and writes the label file.

        Returns:
            Number of label lines written; 0 when the image cannot be
            decoded (nothing is copied or written in that case).
        """
        # Decode only to confirm the frame is a readable image.
        if cv2.imread(str(image_path)) is None:
            return 0

        # Copy image
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for mask_file in mask_files:
            mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue

            mask = mask > 127  # Convert to binary

            # Normalize by the mask's own resolution: the pixel indices come
            # from the mask, so this stays correct even if the mask was
            # saved at a different size than the frame.
            # NOTE(review): assumes a mask covers the same field of view as
            # its frame — confirm against the mask producer.
            mask_height, mask_width = mask.shape[:2]
            bbox = self.mask_to_bbox(mask, mask_width, mask_height)
            if bbox is None:
                continue

            x_center, y_center, w, h = bbox
            labels.append(
                f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
            )

        # Write labels
        label_path = labels_dir / (image_path.stem + '.txt')
        with open(label_path, 'w') as f:
            f.write('\n'.join(labels))

        return len(labels)
def export_for_kaggle(
    dataset_dir: str,
    output_zip: str,
    include_checkpoints: bool = False
) -> str:
    """
    Export YOLO dataset as zip file for Kaggle upload.

    Archive members are stored relative to the dataset's parent directory,
    so the archive unpacks into a single folder named after the dataset.

    Args:
        dataset_dir: Path to YOLO dataset directory
        output_zip: Output zip file path (parent directories are created
            if needed; may live inside ``dataset_dir``)
        include_checkpoints: Include model checkpoints if present. When
            False, every directory inside the dataset whose name contains
            'checkpoints' is left out together with its contents.

    Returns:
        Path to created zip file
    """
    import zipfile

    dataset_path = Path(dataset_dir)
    zip_path = Path(output_zip)
    zip_path.parent.mkdir(parents=True, exist_ok=True)
    # Resolved once so the archive can be recognized (and skipped) if it is
    # being written inside the dataset directory itself.
    zip_resolved = zip_path.resolve()

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(dataset_path):
            # Skip checkpoints unless requested. Only directory names
            # *inside* the dataset are inspected, so a dataset stored under
            # e.g. ".../checkpoints_run/" is not exported as an empty
            # archive. Pruning dirs in place stops os.walk from descending.
            if not include_checkpoints:
                dirs[:] = [d for d in dirs if 'checkpoints' not in d]

            for file in files:
                file_path = Path(root) / file
                if file_path.resolve() == zip_resolved:
                    # Never add the archive to itself.
                    continue
                arcname = file_path.relative_to(dataset_path.parent)
                zipf.write(file_path, arcname)

    print(f"Dataset exported to {output_zip}")
    return output_zip
def validate_yolo_dataset(dataset_dir: str) -> Dict[str, Any]:
    """
    Validate YOLO dataset structure and contents.

    Checks that ``data.yaml`` exists and parses to a mapping, that the
    ``images/<split>`` and ``labels/<split>`` directories exist for the
    'train' and 'val' splits, and that image and label file stems match.

    Args:
        dataset_dir: Path to YOLO dataset

    Returns:
        Validation results dictionary with keys:
            'valid': False if any error was recorded
            'errors': list of structural problems
            'warnings': list of non-fatal inconsistencies
            'stats': class info from data.yaml and per-split file counts
    """
    # Matched case-insensitively against the file suffix.
    image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}

    dataset_path = Path(dataset_dir)
    results = {
        'valid': True,
        'errors': [],
        'warnings': [],
        'stats': {}
    }

    # Check data.yaml
    yaml_path = dataset_path / 'data.yaml'
    if not yaml_path.exists():
        results['errors'].append("Missing data.yaml")
        results['valid'] = False
    else:
        with open(yaml_path) as f:
            config = yaml.safe_load(f)
        # An empty file parses to None and a scalar/list file to a non-dict;
        # report that instead of failing on .get().
        if not isinstance(config, dict):
            results['errors'].append("data.yaml is empty or not a mapping")
            results['valid'] = False
            config = {}
        results['stats']['num_classes'] = config.get('nc', 0)
        results['stats']['class_names'] = config.get('names', {})

    # Check directories
    for split in ['train', 'val']:
        images_dir = dataset_path / 'images' / split
        labels_dir = dataset_path / 'labels' / split

        # Report every missing directory, not just the first one found.
        split_complete = True
        if not images_dir.exists():
            results['errors'].append(f"Missing images/{split} directory")
            split_complete = False
        if not labels_dir.exists():
            results['errors'].append(f"Missing labels/{split} directory")
            split_complete = False
        if not split_complete:
            results['valid'] = False
            continue

        # Count files
        image_files = [
            p for p in images_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        ]
        label_files = list(labels_dir.glob("*.txt"))

        results['stats'][f'{split}_images'] = len(image_files)
        results['stats'][f'{split}_labels'] = len(label_files)

        # Check matching in both directions
        image_stems = {f.stem for f in image_files}
        label_stems = {f.stem for f in label_files}

        missing_labels = image_stems - label_stems
        if missing_labels:
            results['warnings'].append(
                f"{len(missing_labels)} images in {split} missing labels"
            )

        orphan_labels = label_stems - image_stems
        if orphan_labels:
            results['warnings'].append(
                f"{len(orphan_labels)} labels in {split} missing images"
            )

    return results