add sam2 yolo auto annotation

This commit is contained in:
2026-02-04 15:29:36 +07:00
parent 7e56948ece
commit 5a951d8812
2061 changed files with 316473 additions and 0 deletions
+6
View File
@@ -0,0 +1,6 @@
# SAM2-YOLO Pipeline Utilities
from .video_utils import VideoProcessor
from .sam2_utils import SAM2Annotator
from .yolo_utils import YOLODatasetBuilder
__all__ = ['VideoProcessor', 'SAM2Annotator', 'YOLODatasetBuilder']
+343
View File
@@ -0,0 +1,343 @@
"""
SAM2 annotation utilities for automatic mask generation.
"""
import os
import cv2
import json
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
class SAM2Annotator:
"""Automatic annotation using SAM2 model."""
# Available SAM2 model variants
MODEL_CONFIGS = {
'tiny': {
'config': 'sam2_hiera_t.yaml',
'checkpoint': 'sam2_hiera_tiny.pt',
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'
},
'small': {
'config': 'sam2_hiera_s.yaml',
'checkpoint': 'sam2_hiera_small.pt',
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'
},
'base_plus': {
'config': 'sam2_hiera_b+.yaml',
'checkpoint': 'sam2_hiera_base_plus.pt',
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'
},
'large': {
'config': 'sam2_hiera_l.yaml',
'checkpoint': 'sam2_hiera_large.pt',
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
}
}
def __init__(
self,
model_size: str = 'large',
checkpoint_dir: str = './checkpoints',
device: Optional[str] = None
):
"""
Initialize SAM2 annotator.
Args:
model_size: Model variant ('tiny', 'small', 'base_plus', 'large')
checkpoint_dir: Directory to store model checkpoints
device: Device to run model on ('cuda', 'cpu', or None for auto)
"""
self.model_size = model_size
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
self.model = None
self.predictor = None
self.mask_generator = None
def download_checkpoint(self) -> str:
"""Download SAM2 checkpoint if not exists."""
config = self.MODEL_CONFIGS[self.model_size]
checkpoint_path = self.checkpoint_dir / config['checkpoint']
if not checkpoint_path.exists():
print(f"Downloading SAM2 {self.model_size} checkpoint...")
import urllib.request
urllib.request.urlretrieve(config['url'], str(checkpoint_path))
print(f"Downloaded to {checkpoint_path}")
else:
print(f"Checkpoint exists: {checkpoint_path}")
return str(checkpoint_path)
def load_model(self):
"""Load SAM2 model for inference."""
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
checkpoint_path = self.download_checkpoint()
config = self.MODEL_CONFIGS[self.model_size]
print(f"Loading SAM2 {self.model_size} model on {self.device}...")
self.model = build_sam2(
config['config'],
checkpoint_path,
device=self.device
)
# Create predictor for interactive annotation
self.predictor = SAM2ImagePredictor(self.model)
# Create automatic mask generator
self.mask_generator = SAM2AutomaticMaskGenerator(
model=self.model,
points_per_side=32,
points_per_batch=64,
pred_iou_thresh=0.7,
stability_score_thresh=0.92,
stability_score_offset=1.0,
box_nms_thresh=0.7,
crop_n_layers=1,
crop_nms_thresh=0.7,
crop_overlap_ratio=0.34,
crop_n_points_downscale_factor=2,
min_mask_region_area=100
)
print("Model loaded successfully!")
def generate_masks_auto(
self,
image: np.ndarray,
min_area: int = 100,
max_area: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Automatically generate masks for all objects in image.
Args:
image: Input image (BGR format from OpenCV)
min_area: Minimum mask area in pixels
max_area: Maximum mask area in pixels
Returns:
List of mask dictionaries with keys:
- segmentation: Binary mask
- bbox: Bounding box [x, y, w, h]
- area: Mask area
- predicted_iou: Confidence score
"""
if self.mask_generator is None:
self.load_model()
# Convert BGR to RGB
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Generate masks
masks = self.mask_generator.generate(image_rgb)
# Filter by area
filtered_masks = []
for mask in masks:
area = mask['area']
if area >= min_area:
if max_area is None or area <= max_area:
filtered_masks.append(mask)
# Sort by area (largest first)
filtered_masks.sort(key=lambda x: x['area'], reverse=True)
return filtered_masks
def generate_masks_with_prompts(
self,
image: np.ndarray,
points: Optional[List[Tuple[int, int]]] = None,
point_labels: Optional[List[int]] = None,
boxes: Optional[List[List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Generate masks using point or box prompts.
Args:
image: Input image (BGR format)
points: List of (x, y) point coordinates
point_labels: Labels for points (1=foreground, 0=background)
boxes: List of boxes [x1, y1, x2, y2]
Returns:
Tuple of (masks, scores, logits)
"""
if self.predictor is None:
self.load_model()
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
self.predictor.set_image(image_rgb)
# Prepare inputs
point_coords = np.array(points) if points else None
point_labels_arr = np.array(point_labels) if point_labels else None
box_arr = np.array(boxes) if boxes else None
masks, scores, logits = self.predictor.predict(
point_coords=point_coords,
point_labels=point_labels_arr,
box=box_arr,
multimask_output=True
)
return masks, scores, logits
def annotate_frames(
self,
frames_dir: str,
output_dir: str,
min_area: int = 100,
max_area: Optional[int] = None,
save_visualizations: bool = True
) -> Dict[str, List[Dict]]:
"""
Annotate all frames in a directory.
Args:
frames_dir: Directory containing frame images
output_dir: Directory to save annotations
min_area: Minimum object area
max_area: Maximum object area
save_visualizations: Save annotated visualization images
Returns:
Dictionary mapping frame names to annotation lists
"""
frames_path = Path(frames_dir)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
if save_visualizations:
vis_path = output_path / 'visualizations'
vis_path.mkdir(exist_ok=True)
# Find all frame images
frame_files = sorted(
list(frames_path.glob("*.jpg")) +
list(frames_path.glob("*.png"))
)
if not frame_files:
raise ValueError(f"No frames found in {frames_dir}")
print(f"Found {len(frame_files)} frames to annotate")
all_annotations = {}
for frame_file in tqdm(frame_files, desc="Annotating frames"):
image = cv2.imread(str(frame_file))
if image is None:
continue
# Generate masks
masks = self.generate_masks_auto(image, min_area, max_area)
# Convert to serializable format
annotations = []
for i, mask_data in enumerate(masks):
ann = {
'id': i,
'bbox': mask_data['bbox'], # [x, y, w, h]
'area': int(mask_data['area']),
'predicted_iou': float(mask_data['predicted_iou']),
'stability_score': float(mask_data['stability_score'])
}
annotations.append(ann)
# Save mask as separate file if needed
mask_filename = f"{frame_file.stem}_mask_{i:03d}.png"
mask_path = output_path / 'masks' / frame_file.stem
mask_path.mkdir(parents=True, exist_ok=True)
cv2.imwrite(
str(mask_path / mask_filename),
mask_data['segmentation'].astype(np.uint8) * 255
)
all_annotations[frame_file.name] = annotations
# Save visualization
if save_visualizations:
vis_image = self._visualize_masks(image, masks)
cv2.imwrite(str(vis_path / frame_file.name), vis_image)
# Save annotations JSON
annotations_file = output_path / 'annotations.json'
with open(annotations_file, 'w') as f:
json.dump(all_annotations, f, indent=2)
print(f"Annotations saved to {annotations_file}")
return all_annotations
def _visualize_masks(
self,
image: np.ndarray,
masks: List[Dict],
alpha: float = 0.5
) -> np.ndarray:
"""Create visualization of masks overlaid on image."""
vis_image = image.copy()
for mask_data in masks:
mask = mask_data['segmentation']
color = np.random.randint(0, 255, 3).tolist()
# Create colored overlay
overlay = vis_image.copy()
overlay[mask] = color
vis_image = cv2.addWeighted(vis_image, 1 - alpha, overlay, alpha, 0)
# Draw bounding box
x, y, w, h = mask_data['bbox']
cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)
return vis_image
def load_sam2_video_predictor(
model_size: str = 'large',
checkpoint_dir: str = './checkpoints',
device: str = 'cuda'
):
"""
Load SAM2 video predictor for tracking objects across frames.
Args:
model_size: Model size variant
checkpoint_dir: Checkpoint directory
device: Device to use
Returns:
SAM2VideoPredictor instance
"""
from sam2.build_sam import build_sam2_video_predictor
annotator = SAM2Annotator(model_size, checkpoint_dir, device)
checkpoint_path = annotator.download_checkpoint()
config = SAM2Annotator.MODEL_CONFIGS[model_size]
predictor = build_sam2_video_predictor(
config['config'],
checkpoint_path,
device=device
)
return predictor
+202
View File
@@ -0,0 +1,202 @@
"""
Video processing utilities for frame extraction.
"""
import os
import cv2
import numpy as np
from pathlib import Path
from typing import Generator, Tuple, Optional, List
from tqdm import tqdm
class VideoProcessor:
"""Extract frames from video files for annotation."""
def __init__(self, video_path: str):
"""
Initialize video processor.
Args:
video_path: Path to the video file
"""
self.video_path = Path(video_path)
if not self.video_path.exists():
raise FileNotFoundError(f"Video not found: {video_path}")
self.cap = cv2.VideoCapture(str(self.video_path))
if not self.cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.duration = self.frame_count / self.fps if self.fps > 0 else 0
def __del__(self):
if hasattr(self, 'cap') and self.cap is not None:
self.cap.release()
def get_info(self) -> dict:
"""Get video information."""
return {
'path': str(self.video_path),
'fps': self.fps,
'frame_count': self.frame_count,
'width': self.width,
'height': self.height,
'duration_seconds': self.duration
}
def extract_frames(
self,
output_dir: str,
sample_fps: Optional[float] = None,
max_frames: Optional[int] = None,
start_time: float = 0.0,
end_time: Optional[float] = None,
resize: Optional[Tuple[int, int]] = None
) -> List[str]:
"""
Extract frames from video and save to directory.
Args:
output_dir: Directory to save extracted frames
sample_fps: Target FPS for sampling (None = use all frames)
max_frames: Maximum number of frames to extract
start_time: Start time in seconds
end_time: End time in seconds (None = until end)
resize: Resize frames to (width, height)
Returns:
List of saved frame paths
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Calculate frame interval for sampling
if sample_fps and sample_fps < self.fps:
frame_interval = int(self.fps / sample_fps)
else:
frame_interval = 1
# Calculate frame range
start_frame = int(start_time * self.fps)
end_frame = int(end_time * self.fps) if end_time else self.frame_count
end_frame = min(end_frame, self.frame_count)
# Reset video position
self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
saved_paths = []
frame_idx = start_frame
extracted_count = 0
pbar = tqdm(total=min((end_frame - start_frame) // frame_interval, max_frames or float('inf')),
desc="Extracting frames")
while frame_idx < end_frame:
if max_frames and extracted_count >= max_frames:
break
ret, frame = self.cap.read()
if not ret:
break
if (frame_idx - start_frame) % frame_interval == 0:
if resize:
frame = cv2.resize(frame, resize)
# Save frame with zero-padded index
frame_name = f"frame_{frame_idx:06d}.jpg"
frame_path = output_path / frame_name
cv2.imwrite(str(frame_path), frame)
saved_paths.append(str(frame_path))
extracted_count += 1
pbar.update(1)
frame_idx += 1
pbar.close()
print(f"Extracted {len(saved_paths)} frames to {output_dir}")
return saved_paths
def iterate_frames(
self,
sample_fps: Optional[float] = None,
start_time: float = 0.0,
end_time: Optional[float] = None
) -> Generator[Tuple[int, np.ndarray], None, None]:
"""
Iterate through video frames as a generator.
Args:
sample_fps: Target FPS for sampling
start_time: Start time in seconds
end_time: End time in seconds
Yields:
Tuple of (frame_index, frame_array)
"""
if sample_fps and sample_fps < self.fps:
frame_interval = int(self.fps / sample_fps)
else:
frame_interval = 1
start_frame = int(start_time * self.fps)
end_frame = int(end_time * self.fps) if end_time else self.frame_count
self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
frame_idx = start_frame
while frame_idx < end_frame:
ret, frame = self.cap.read()
if not ret:
break
if (frame_idx - start_frame) % frame_interval == 0:
yield frame_idx, frame
frame_idx += 1
def frames_to_video(
frames_dir: str,
output_path: str,
fps: float = 30.0,
codec: str = 'mp4v'
) -> str:
"""
Convert frames directory back to video.
Args:
frames_dir: Directory containing frame images
output_path: Output video path
fps: Frames per second
codec: Video codec
Returns:
Path to created video
"""
frames_path = Path(frames_dir)
frame_files = sorted(frames_path.glob("*.jpg")) + sorted(frames_path.glob("*.png"))
if not frame_files:
raise ValueError(f"No frames found in {frames_dir}")
# Read first frame to get dimensions
first_frame = cv2.imread(str(frame_files[0]))
height, width = first_frame.shape[:2]
# Create video writer
fourcc = cv2.VideoWriter_fourcc(*codec)
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
for frame_file in tqdm(frame_files, desc="Creating video"):
frame = cv2.imread(str(frame_file))
out.write(frame)
out.release()
print(f"Video saved to {output_path}")
return output_path
+478
View File
@@ -0,0 +1,478 @@
"""
YOLO dataset utilities for converting SAM2 annotations to YOLO format.
"""
import os
import cv2
import json
import shutil
import yaml
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
import random
class YOLODatasetBuilder:
"""Build YOLO format dataset from SAM2 annotations."""
def __init__(
self,
output_dir: str,
class_names: Optional[List[str]] = None
):
"""
Initialize YOLO dataset builder.
Args:
output_dir: Root directory for YOLO dataset
class_names: List of class names (default: ['object'])
"""
self.output_dir = Path(output_dir)
self.class_names = class_names or ['object']
# Create directory structure
self.images_train = self.output_dir / 'images' / 'train'
self.images_val = self.output_dir / 'images' / 'val'
self.labels_train = self.output_dir / 'labels' / 'train'
self.labels_val = self.output_dir / 'labels' / 'val'
for dir_path in [self.images_train, self.images_val,
self.labels_train, self.labels_val]:
dir_path.mkdir(parents=True, exist_ok=True)
def mask_to_bbox(
self,
mask: np.ndarray,
image_width: int,
image_height: int
) -> Tuple[float, float, float, float]:
"""
Convert binary mask to normalized YOLO bounding box.
Args:
mask: Binary mask array
image_width: Image width
image_height: Image height
Returns:
Tuple of (x_center, y_center, width, height) normalized to [0, 1]
"""
# Find mask coordinates
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
if not rows.any() or not cols.any():
return None
y_min, y_max = np.where(rows)[0][[0, -1]]
x_min, x_max = np.where(cols)[0][[0, -1]]
# Calculate YOLO format (normalized center x, center y, width, height)
x_center = ((x_min + x_max) / 2) / image_width
y_center = ((y_min + y_max) / 2) / image_height
width = (x_max - x_min) / image_width
height = (y_max - y_min) / image_height
return (x_center, y_center, width, height)
def bbox_xywh_to_yolo(
self,
bbox: List[int],
image_width: int,
image_height: int
) -> Tuple[float, float, float, float]:
"""
Convert [x, y, w, h] bbox to normalized YOLO format.
Args:
bbox: Bounding box [x, y, width, height]
image_width: Image width
image_height: Image height
Returns:
Tuple of (x_center, y_center, width, height) normalized
"""
x, y, w, h = bbox
x_center = (x + w / 2) / image_width
y_center = (y + h / 2) / image_height
width = w / image_width
height = h / image_height
return (x_center, y_center, width, height)
def convert_sam2_annotations(
self,
frames_dir: str,
annotations_file: str,
masks_dir: Optional[str] = None,
class_id: int = 0,
min_area: int = 100,
min_bbox_size: float = 0.01,
val_split: float = 0.2,
seed: int = 42
) -> Dict[str, int]:
"""
Convert SAM2 annotations to YOLO format dataset.
Args:
frames_dir: Directory containing frame images
annotations_file: Path to SAM2 annotations JSON
masks_dir: Directory containing mask images (optional)
class_id: Class ID for all objects
min_area: Minimum bbox area in pixels
min_bbox_size: Minimum bbox dimension (normalized)
val_split: Validation set ratio
seed: Random seed for split
Returns:
Statistics dictionary
"""
frames_path = Path(frames_dir)
# Load annotations
with open(annotations_file, 'r') as f:
annotations = json.load(f)
# Get all frame files
frame_files = list(annotations.keys())
# Shuffle and split
random.seed(seed)
random.shuffle(frame_files)
split_idx = int(len(frame_files) * (1 - val_split))
train_files = frame_files[:split_idx]
val_files = frame_files[split_idx:]
stats = {
'total_frames': len(frame_files),
'train_frames': len(train_files),
'val_frames': len(val_files),
'total_objects': 0,
'train_objects': 0,
'val_objects': 0
}
# Process training files
for frame_name in tqdm(train_files, desc="Processing train"):
count = self._process_frame(
frames_path / frame_name,
annotations.get(frame_name, []),
self.images_train,
self.labels_train,
class_id,
min_area,
min_bbox_size
)
stats['train_objects'] += count
# Process validation files
for frame_name in tqdm(val_files, desc="Processing val"):
count = self._process_frame(
frames_path / frame_name,
annotations.get(frame_name, []),
self.images_val,
self.labels_val,
class_id,
min_area,
min_bbox_size
)
stats['val_objects'] += count
stats['total_objects'] = stats['train_objects'] + stats['val_objects']
# Create data.yaml
self._create_data_yaml()
print(f"\nDataset created at {self.output_dir}")
print(f" Train: {stats['train_frames']} images, {stats['train_objects']} objects")
print(f" Val: {stats['val_frames']} images, {stats['val_objects']} objects")
return stats
def _process_frame(
self,
image_path: Path,
frame_annotations: List[Dict],
images_dir: Path,
labels_dir: Path,
class_id: int,
min_area: int,
min_bbox_size: float
) -> int:
"""Process a single frame and create YOLO label file."""
if not image_path.exists():
return 0
# Read image to get dimensions
image = cv2.imread(str(image_path))
if image is None:
return 0
height, width = image.shape[:2]
# Copy image to dataset
dest_image = images_dir / image_path.name
shutil.copy2(image_path, dest_image)
# Create label file
label_name = image_path.stem + '.txt'
label_path = labels_dir / label_name
labels = []
for ann in frame_annotations:
bbox = ann.get('bbox', [])
area = ann.get('area', 0)
if area < min_area:
continue
if len(bbox) == 4:
x_center, y_center, w, h = self.bbox_xywh_to_yolo(
bbox, width, height
)
# Filter small boxes
if w < min_bbox_size or h < min_bbox_size:
continue
# YOLO format: class x_center y_center width height
label_line = f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
labels.append(label_line)
# Write label file
with open(label_path, 'w') as f:
f.write('\n'.join(labels))
return len(labels)
def _create_data_yaml(self):
"""Create YOLO data.yaml configuration file."""
data_config = {
'path': str(self.output_dir.absolute()),
'train': 'images/train',
'val': 'images/val',
'names': {i: name for i, name in enumerate(self.class_names)},
'nc': len(self.class_names)
}
yaml_path = self.output_dir / 'data.yaml'
with open(yaml_path, 'w') as f:
yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)
print(f"Created {yaml_path}")
def create_from_masks_directory(
self,
frames_dir: str,
masks_dir: str,
class_id: int = 0,
val_split: float = 0.2,
seed: int = 42
) -> Dict[str, int]:
"""
Create YOLO dataset directly from mask images.
Args:
frames_dir: Directory containing original frames
masks_dir: Directory containing mask images (one per object)
class_id: Class ID for all objects
val_split: Validation split ratio
seed: Random seed
Returns:
Statistics dictionary
"""
frames_path = Path(frames_dir)
masks_path = Path(masks_dir)
# Find all frame images
frame_files = sorted(
list(frames_path.glob("*.jpg")) +
list(frames_path.glob("*.png"))
)
random.seed(seed)
random.shuffle(frame_files)
split_idx = int(len(frame_files) * (1 - val_split))
train_files = frame_files[:split_idx]
val_files = frame_files[split_idx:]
stats = {
'total_frames': len(frame_files),
'train_frames': len(train_files),
'val_frames': len(val_files),
'total_objects': 0,
'train_objects': 0,
'val_objects': 0
}
# Process frames
for frame_list, images_dir, labels_dir, key in [
(train_files, self.images_train, self.labels_train, 'train_objects'),
(val_files, self.images_val, self.labels_val, 'val_objects')
]:
for frame_file in tqdm(frame_list, desc=f"Processing {key.split('_')[0]}"):
# Find corresponding masks
frame_masks_dir = masks_path / frame_file.stem
if frame_masks_dir.exists():
mask_files = list(frame_masks_dir.glob("*.png"))
else:
mask_files = []
count = self._process_frame_with_masks(
frame_file,
mask_files,
images_dir,
labels_dir,
class_id
)
stats[key] += count
stats['total_objects'] = stats['train_objects'] + stats['val_objects']
self._create_data_yaml()
return stats
def _process_frame_with_masks(
self,
image_path: Path,
mask_files: List[Path],
images_dir: Path,
labels_dir: Path,
class_id: int
) -> int:
"""Process frame with mask files."""
image = cv2.imread(str(image_path))
if image is None:
return 0
height, width = image.shape[:2]
# Copy image
shutil.copy2(image_path, images_dir / image_path.name)
labels = []
for mask_file in mask_files:
mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
if mask is None:
continue
mask = mask > 127 # Convert to binary
bbox = self.mask_to_bbox(mask, width, height)
if bbox is not None:
x_center, y_center, w, h = bbox
label_line = f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
labels.append(label_line)
# Write labels
label_path = labels_dir / (image_path.stem + '.txt')
with open(label_path, 'w') as f:
f.write('\n'.join(labels))
return len(labels)
def export_for_kaggle(
dataset_dir: str,
output_zip: str,
include_checkpoints: bool = False
) -> str:
"""
Export YOLO dataset as zip file for Kaggle upload.
Args:
dataset_dir: Path to YOLO dataset directory
output_zip: Output zip file path
include_checkpoints: Include model checkpoints if present
Returns:
Path to created zip file
"""
import zipfile
dataset_path = Path(dataset_dir)
with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(dataset_path):
# Skip checkpoints unless requested
if not include_checkpoints and 'checkpoints' in root:
continue
for file in files:
file_path = Path(root) / file
arcname = file_path.relative_to(dataset_path.parent)
zipf.write(file_path, arcname)
print(f"Dataset exported to {output_zip}")
return output_zip
def validate_yolo_dataset(dataset_dir: str) -> Dict[str, Any]:
"""
Validate YOLO dataset structure and contents.
Args:
dataset_dir: Path to YOLO dataset
Returns:
Validation results dictionary
"""
dataset_path = Path(dataset_dir)
results = {
'valid': True,
'errors': [],
'warnings': [],
'stats': {}
}
# Check data.yaml
yaml_path = dataset_path / 'data.yaml'
if not yaml_path.exists():
results['errors'].append("Missing data.yaml")
results['valid'] = False
else:
with open(yaml_path) as f:
config = yaml.safe_load(f)
results['stats']['num_classes'] = config.get('nc', 0)
results['stats']['class_names'] = config.get('names', {})
# Check directories
for split in ['train', 'val']:
images_dir = dataset_path / 'images' / split
labels_dir = dataset_path / 'labels' / split
if not images_dir.exists():
results['errors'].append(f"Missing images/{split} directory")
results['valid'] = False
continue
if not labels_dir.exists():
results['errors'].append(f"Missing labels/{split} directory")
results['valid'] = False
continue
# Count files
image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
label_files = list(labels_dir.glob("*.txt"))
results['stats'][f'{split}_images'] = len(image_files)
results['stats'][f'{split}_labels'] = len(label_files)
# Check matching
image_stems = {f.stem for f in image_files}
label_stems = {f.stem for f in label_files}
missing_labels = image_stems - label_stems
if missing_labels:
results['warnings'].append(
f"{len(missing_labels)} images in {split} missing labels"
)
return results