add sam2 yolo auto annotation
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# SAM2-YOLO Pipeline Utilities
|
||||
from .video_utils import VideoProcessor
|
||||
from .sam2_utils import SAM2Annotator
|
||||
from .yolo_utils import YOLODatasetBuilder
|
||||
|
||||
__all__ = ['VideoProcessor', 'SAM2Annotator', 'YOLODatasetBuilder']
|
||||
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
SAM2 annotation utilities for automatic mask generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class SAM2Annotator:
    """Automatic annotation using SAM2 model.

    ``model``, ``predictor`` and ``mask_generator`` start as ``None`` and are
    created by :meth:`load_model`, which the mask-generation methods call on
    first use.
    """

    # Available SAM2 model variants: config name, checkpoint file name and
    # download URL for each supported ``model_size``.
    MODEL_CONFIGS = {
        'tiny': {
            'config': 'sam2_hiera_t.yaml',
            'checkpoint': 'sam2_hiera_tiny.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'
        },
        'small': {
            'config': 'sam2_hiera_s.yaml',
            'checkpoint': 'sam2_hiera_small.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'
        },
        'base_plus': {
            'config': 'sam2_hiera_b+.yaml',
            'checkpoint': 'sam2_hiera_base_plus.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'
        },
        'large': {
            'config': 'sam2_hiera_l.yaml',
            'checkpoint': 'sam2_hiera_large.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
        }
    }

    def __init__(
        self,
        model_size: str = 'large',
        checkpoint_dir: str = './checkpoints',
        device: Optional[str] = None
    ):
        """
        Initialize SAM2 annotator.

        Args:
            model_size: Model variant ('tiny', 'small', 'base_plus', 'large')
            checkpoint_dir: Directory to store model checkpoints (created if missing)
            device: Device to run model on ('cuda', 'cpu', or None for auto)

        Raises:
            ValueError: If ``model_size`` is not a key of ``MODEL_CONFIGS``.
        """
        # Fail fast with a readable message instead of a KeyError at download time.
        if model_size not in self.MODEL_CONFIGS:
            raise ValueError(
                f"Unknown model_size {model_size!r}; "
                f"expected one of {sorted(self.MODEL_CONFIGS)}"
            )

        self.model_size = model_size
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model = None
        self.predictor = None
        self.mask_generator = None

    def download_checkpoint(self) -> str:
        """Download SAM2 checkpoint if not exists.

        Returns:
            Path of the checkpoint file as a string.
        """
        config = self.MODEL_CONFIGS[self.model_size]
        checkpoint_path = self.checkpoint_dir / config['checkpoint']

        if checkpoint_path.exists():
            print(f"Checkpoint exists: {checkpoint_path}")
            return str(checkpoint_path)

        print(f"Downloading SAM2 {self.model_size} checkpoint...")
        import urllib.request

        # Download to a sibling ".part" file and rename only on success, so an
        # interrupted download never leaves a truncated file that a later call
        # would mistake for a complete checkpoint.
        partial_path = checkpoint_path.with_name(checkpoint_path.name + '.part')
        try:
            urllib.request.urlretrieve(config['url'], str(partial_path))
            os.replace(partial_path, checkpoint_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()
        print(f"Downloaded to {checkpoint_path}")

        return str(checkpoint_path)

    def load_model(self):
        """Load SAM2 model for inference.

        Ensures the checkpoint is present, builds the model on ``self.device``
        and populates ``self.model``, ``self.predictor`` and
        ``self.mask_generator``.
        """
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        checkpoint_path = self.download_checkpoint()
        config = self.MODEL_CONFIGS[self.model_size]

        print(f"Loading SAM2 {self.model_size} model on {self.device}...")

        self.model = build_sam2(
            config['config'],
            checkpoint_path,
            device=self.device
        )

        # Create predictor for interactive annotation
        self.predictor = SAM2ImagePredictor(self.model)

        # Create automatic mask generator
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=1.0,
            box_nms_thresh=0.7,
            crop_n_layers=1,
            crop_nms_thresh=0.7,
            crop_overlap_ratio=0.34,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100
        )

        print("Model loaded successfully!")

    def generate_masks_auto(
        self,
        image: np.ndarray,
        min_area: int = 100,
        max_area: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Automatically generate masks for all objects in image.

        Args:
            image: Input image (BGR format from OpenCV)
            min_area: Minimum mask area in pixels
            max_area: Maximum mask area in pixels (None = no upper limit)

        Returns:
            Mask dictionaries as produced by the mask generator, filtered by
            area and sorted largest first. This module reads the keys
            ``segmentation``, ``bbox`` ([x, y, w, h]), ``area``,
            ``predicted_iou`` and ``stability_score`` from them.
        """
        if self.mask_generator is None:
            self.load_model()

        # The generator is fed RGB; OpenCV images are BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = self.mask_generator.generate(image_rgb)

        filtered_masks = [
            mask for mask in masks
            if mask['area'] >= min_area
            and (max_area is None or mask['area'] <= max_area)
        ]

        # Sort by area (largest first)
        filtered_masks.sort(key=lambda m: m['area'], reverse=True)
        return filtered_masks

    @staticmethod
    def _optional_array(values) -> Optional[np.ndarray]:
        """Return ``values`` as an array, or None when missing or empty."""
        # len() instead of truthiness so NumPy array inputs work as well as lists.
        if values is None or len(values) == 0:
            return None
        return np.array(values)

    def generate_masks_with_prompts(
        self,
        image: np.ndarray,
        points: Optional[List[Tuple[int, int]]] = None,
        point_labels: Optional[List[int]] = None,
        boxes: Optional[List[List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate masks using point or box prompts.

        Args:
            image: Input image (BGR format)
            points: List of (x, y) point coordinates
            point_labels: Labels for points (1=foreground, 0=background)
            boxes: List of boxes [x1, y1, x2, y2]

        Returns:
            Tuple of (masks, scores, logits) as returned by the predictor
        """
        if self.predictor is None:
            self.load_model()

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image_rgb)

        # Missing or empty prompts are passed to the predictor as None.
        masks, scores, logits = self.predictor.predict(
            point_coords=self._optional_array(points),
            point_labels=self._optional_array(point_labels),
            box=self._optional_array(boxes),
            multimask_output=True
        )

        return masks, scores, logits

    @staticmethod
    def _to_builtin(value):
        """Convert a NumPy scalar to a plain Python number; pass others through."""
        return value.item() if isinstance(value, np.generic) else value

    def annotate_frames(
        self,
        frames_dir: str,
        output_dir: str,
        min_area: int = 100,
        max_area: Optional[int] = None,
        save_visualizations: bool = True
    ) -> Dict[str, List[Dict]]:
        """
        Annotate all frames in a directory.

        Writes ``annotations.json`` into ``output_dir``, one PNG per mask under
        ``masks/<frame stem>/`` and, optionally, overlay images under
        ``visualizations/``.

        Args:
            frames_dir: Directory containing frame images (*.jpg, *.png)
            output_dir: Directory to save annotations
            min_area: Minimum object area
            max_area: Maximum object area
            save_visualizations: Save annotated visualization images

        Returns:
            Dictionary mapping frame names to annotation lists

        Raises:
            ValueError: If no frames are found in ``frames_dir``.
        """
        frames_path = Path(frames_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        if save_visualizations:
            vis_path = output_path / 'visualizations'
            vis_path.mkdir(exist_ok=True)

        # Find all frame images
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) +
            list(frames_path.glob("*.png"))
        )

        if not frame_files:
            raise ValueError(f"No frames found in {frames_dir}")

        print(f"Found {len(frame_files)} frames to annotate")

        all_annotations = {}

        for frame_file in tqdm(frame_files, desc="Annotating frames"):
            image = cv2.imread(str(frame_file))
            if image is None:
                # Report the skip; tqdm.write keeps the progress bar intact.
                tqdm.write(f"Warning: could not read {frame_file}, skipping")
                continue

            masks = self.generate_masks_auto(image, min_area, max_area)

            # One mask directory per frame, created once and only when needed.
            mask_dir = output_path / 'masks' / frame_file.stem
            if masks:
                mask_dir.mkdir(parents=True, exist_ok=True)

            # Convert to serializable format
            annotations = []
            for i, mask_data in enumerate(masks):
                annotations.append({
                    'id': i,
                    # [x, y, w, h]; NumPy scalars are unwrapped so json.dump accepts them
                    'bbox': [self._to_builtin(v) for v in mask_data['bbox']],
                    'area': int(mask_data['area']),
                    'predicted_iou': float(mask_data['predicted_iou']),
                    'stability_score': float(mask_data['stability_score'])
                })

                # Save the mask as a 0/255 single-channel PNG
                mask_filename = f"{frame_file.stem}_mask_{i:03d}.png"
                cv2.imwrite(
                    str(mask_dir / mask_filename),
                    mask_data['segmentation'].astype(np.uint8) * 255
                )

            all_annotations[frame_file.name] = annotations

            # Save visualization
            if save_visualizations:
                vis_image = self._visualize_masks(image, masks)
                cv2.imwrite(str(vis_path / frame_file.name), vis_image)

        # Save annotations JSON
        annotations_file = output_path / 'annotations.json'
        with open(annotations_file, 'w') as f:
            json.dump(all_annotations, f, indent=2)

        print(f"Annotations saved to {annotations_file}")
        return all_annotations

    def _visualize_masks(
        self,
        image: np.ndarray,
        masks: List[Dict],
        alpha: float = 0.5
    ) -> np.ndarray:
        """Create visualization of masks overlaid on image.

        Each mask is blended in with a random color at opacity ``alpha`` and
        its bounding box is drawn in the same color. The input image is not
        modified.
        """
        vis_image = image.copy()

        for mask_data in masks:
            # Boolean dtype guarantees mask (not integer) indexing below.
            mask = np.asarray(mask_data['segmentation']).astype(bool)
            color = np.random.randint(0, 255, 3).tolist()

            # Create colored overlay
            overlay = vis_image.copy()
            overlay[mask] = color
            vis_image = cv2.addWeighted(vis_image, 1 - alpha, overlay, alpha, 0)

            # Draw bounding box; cv2.rectangle needs integer pixel coordinates.
            x, y, w, h = (int(round(v)) for v in mask_data['bbox'])
            cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)

        return vis_image
|
||||
|
||||
|
||||
def load_sam2_video_predictor(
    model_size: str = 'large',
    checkpoint_dir: str = './checkpoints',
    device: Optional[str] = 'cuda'
):
    """
    Load SAM2 video predictor for tracking objects across frames.

    Args:
        model_size: Model size variant (a key of ``SAM2Annotator.MODEL_CONFIGS``)
        checkpoint_dir: Checkpoint directory
        device: Device to use; None lets ``SAM2Annotator`` pick cuda/cpu

    Returns:
        The predictor built by ``build_sam2_video_predictor``

    Raises:
        ValueError: If ``model_size`` is not a known variant.
    """
    # Validate before any side effect (directory creation, download).
    if model_size not in SAM2Annotator.MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model_size {model_size!r}; "
            f"expected one of {sorted(SAM2Annotator.MODEL_CONFIGS)}"
        )

    from sam2.build_sam import build_sam2_video_predictor

    # The annotator is used only for its checkpoint handling and device resolution.
    annotator = SAM2Annotator(model_size, checkpoint_dir, device)
    checkpoint_path = annotator.download_checkpoint()
    config = SAM2Annotator.MODEL_CONFIGS[model_size]

    # Use the annotator's resolved device so device=None is never forwarded raw.
    predictor = build_sam2_video_predictor(
        config['config'],
        checkpoint_path,
        device=annotator.device
    )

    return predictor
|
||||
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Video processing utilities for frame extraction.
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Generator, Tuple, Optional, List
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class VideoProcessor:
    """Extract frames from video files for annotation.

    The underlying capture is released by :meth:`close`, on garbage
    collection, or on leaving a ``with VideoProcessor(...) as vp:`` block.
    """

    def __init__(self, video_path: str):
        """
        Initialize video processor.

        Args:
            video_path: Path to the video file

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If the capture cannot be opened.
        """
        self.video_path = Path(video_path)
        if not self.video_path.exists():
            raise FileNotFoundError(f"Video not found: {video_path}")

        self.cap = cv2.VideoCapture(str(self.video_path))
        if not self.cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Guard against a zero/unknown FPS reported by the capture.
        self.duration = self.frame_count / self.fps if self.fps > 0 else 0

    def close(self) -> None:
        """Release the capture; safe to call repeatedly or on a half-built instance."""
        cap = getattr(self, 'cap', None)
        if cap is not None:
            cap.release()
            self.cap = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def get_info(self) -> dict:
        """Get video information."""
        return {
            'path': str(self.video_path),
            'fps': self.fps,
            'frame_count': self.frame_count,
            'width': self.width,
            'height': self.height,
            'duration_seconds': self.duration
        }

    def _sampling_interval(self, sample_fps: Optional[float]) -> int:
        """Return how many source frames to advance per sampled frame (>= 1)."""
        if sample_fps and sample_fps < self.fps:
            # max() keeps the modulo below safe for nonsensical (negative) rates.
            return max(1, int(self.fps / sample_fps))
        return 1

    @staticmethod
    def _expected_frame_total(
        start_frame: int,
        end_frame: int,
        frame_interval: int,
        max_frames: Optional[int]
    ) -> int:
        """Count the frames the sampling loop will visit, for the progress bar."""
        span = max(0, end_frame - start_frame)
        # Ceiling division: offsets 0, k, 2k, ... below span are all sampled.
        total = -(-span // frame_interval)
        if max_frames:
            total = min(total, max_frames)
        return total

    def extract_frames(
        self,
        output_dir: str,
        sample_fps: Optional[float] = None,
        max_frames: Optional[int] = None,
        start_time: float = 0.0,
        end_time: Optional[float] = None,
        resize: Optional[Tuple[int, int]] = None
    ) -> List[str]:
        """
        Extract frames from video and save to directory.

        Frames are written as ``frame_<index>.jpg`` where the index is the
        zero-padded source frame number.

        Args:
            output_dir: Directory to save extracted frames
            sample_fps: Target FPS for sampling (None = use all frames)
            max_frames: Maximum number of frames to extract
            start_time: Start time in seconds
            end_time: End time in seconds (None = until end)
            resize: Resize frames to (width, height)

        Returns:
            List of saved frame paths

        Raises:
            OSError: If a frame cannot be written to ``output_dir``.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        frame_interval = self._sampling_interval(sample_fps)

        # Calculate frame range, clamped to the video
        start_frame = max(0, int(start_time * self.fps))
        end_frame = int(end_time * self.fps) if end_time else self.frame_count
        end_frame = min(end_frame, self.frame_count)

        # Reset video position
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        saved_paths = []
        frame_idx = start_frame
        total = self._expected_frame_total(
            start_frame, end_frame, frame_interval, max_frames
        )

        # Context manager closes the bar even if reading or writing raises.
        with tqdm(total=total, desc="Extracting frames") as pbar:
            while frame_idx < end_frame:
                if max_frames and len(saved_paths) >= max_frames:
                    break

                ret, frame = self.cap.read()
                if not ret:
                    break

                if (frame_idx - start_frame) % frame_interval == 0:
                    if resize:
                        frame = cv2.resize(frame, resize)

                    # Save frame with zero-padded index
                    frame_path = output_path / f"frame_{frame_idx:06d}.jpg"
                    # imwrite signals failure through its return value only.
                    if not cv2.imwrite(str(frame_path), frame):
                        raise OSError(f"Failed to write frame: {frame_path}")
                    saved_paths.append(str(frame_path))
                    pbar.update(1)

                frame_idx += 1

        print(f"Extracted {len(saved_paths)} frames to {output_dir}")
        return saved_paths

    def iterate_frames(
        self,
        sample_fps: Optional[float] = None,
        start_time: float = 0.0,
        end_time: Optional[float] = None
    ) -> Generator[Tuple[int, np.ndarray], None, None]:
        """
        Iterate through video frames as a generator.

        Args:
            sample_fps: Target FPS for sampling
            start_time: Start time in seconds
            end_time: End time in seconds

        Yields:
            Tuple of (frame_index, frame_array)
        """
        frame_interval = self._sampling_interval(sample_fps)

        start_frame = max(0, int(start_time * self.fps))
        end_frame = int(end_time * self.fps) if end_time else self.frame_count

        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frame_idx = start_frame
        while frame_idx < end_frame:
            ret, frame = self.cap.read()
            if not ret:
                # The capture ran out of frames before end_frame.
                break

            if (frame_idx - start_frame) % frame_interval == 0:
                yield frame_idx, frame

            frame_idx += 1
|
||||
|
||||
|
||||
def frames_to_video(
    frames_dir: str,
    output_path: str,
    fps: float = 30.0,
    codec: str = 'mp4v'
) -> str:
    """
    Convert frames directory back to video.

    Frames (*.jpg and *.png) are written in sorted path order. The first
    readable frame fixes the video size; frames of a different size are
    resized to it so the writer does not drop them.

    Args:
        frames_dir: Directory containing frame images
        output_path: Output video path
        fps: Frames per second
        codec: Video codec (FourCC characters)

    Returns:
        Path to created video

    Raises:
        ValueError: If no frames are found, the first frame cannot be read,
            or the video writer cannot be opened.
    """
    frames_path = Path(frames_dir)
    # Sort the combined list so mixed jpg/png sequences stay in frame order.
    frame_files = sorted(
        list(frames_path.glob("*.jpg")) +
        list(frames_path.glob("*.png"))
    )

    if not frame_files:
        raise ValueError(f"No frames found in {frames_dir}")

    # Read first frame to get dimensions
    first_frame = cv2.imread(str(frame_files[0]))
    if first_frame is None:
        raise ValueError(f"Cannot read frame: {frame_files[0]}")
    height, width = first_frame.shape[:2]

    # Create video writer
    fourcc = cv2.VideoWriter_fourcc(*codec)
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
    if not out.isOpened():
        out.release()
        raise ValueError(f"Cannot open video writer for {output_path}")

    try:
        for frame_file in tqdm(frame_files, desc="Creating video"):
            frame = cv2.imread(str(frame_file))
            if frame is None:
                tqdm.write(f"Warning: could not read {frame_file}, skipping")
                continue
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            out.write(frame)
    finally:
        # Always finalize the container, even if a frame fails mid-way.
        out.release()

    print(f"Video saved to {output_path}")
    return output_path
|
||||
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
YOLO dataset utilities for converting SAM2 annotations to YOLO format.
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import shutil
|
||||
import yaml
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
|
||||
class YOLODatasetBuilder:
    """Build YOLO format dataset from SAM2 annotations.

    The dataset layout is ``images/{train,val}``, ``labels/{train,val}`` and
    a ``data.yaml`` under ``output_dir``.
    """

    def __init__(
        self,
        output_dir: str,
        class_names: Optional[List[str]] = None
    ):
        """
        Initialize YOLO dataset builder.

        Args:
            output_dir: Root directory for YOLO dataset
            class_names: List of class names (default: ['object'])
        """
        self.output_dir = Path(output_dir)
        self.class_names = class_names or ['object']

        # Create directory structure
        self.images_train = self.output_dir / 'images' / 'train'
        self.images_val = self.output_dir / 'images' / 'val'
        self.labels_train = self.output_dir / 'labels' / 'train'
        self.labels_val = self.output_dir / 'labels' / 'val'

        for dir_path in [self.images_train, self.images_val,
                         self.labels_train, self.labels_val]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def mask_to_bbox(
        self,
        mask: np.ndarray,
        image_width: int,
        image_height: int
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Convert binary mask to normalized YOLO bounding box.

        Args:
            mask: Binary mask array (2-D; nonzero = object)
            image_width: Image width
            image_height: Image height

        Returns:
            Tuple of (x_center, y_center, width, height) normalized by the
            image size, or None when the mask is empty
        """
        # Find occupied rows and columns
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        if not rows.any() or not cols.any():
            return None

        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]

        # Min/max are inclusive pixel indices, so the covered extent is
        # max - min + 1 (a single-pixel mask is one pixel wide, not zero).
        box_w = int(x_max) - int(x_min) + 1
        box_h = int(y_max) - int(y_min) + 1

        # YOLO format (normalized center x, center y, width, height)
        x_center = (int(x_min) + box_w / 2) / image_width
        y_center = (int(y_min) + box_h / 2) / image_height
        width = box_w / image_width
        height = box_h / image_height

        return (x_center, y_center, width, height)

    def bbox_xywh_to_yolo(
        self,
        bbox: List[int],
        image_width: int,
        image_height: int
    ) -> Tuple[float, float, float, float]:
        """
        Convert [x, y, w, h] bbox to normalized YOLO format.

        Args:
            bbox: Bounding box [x, y, width, height] (top-left corner + size)
            image_width: Image width
            image_height: Image height

        Returns:
            Tuple of (x_center, y_center, width, height) normalized
        """
        x, y, w, h = bbox

        x_center = (x + w / 2) / image_width
        y_center = (y + h / 2) / image_height
        width = w / image_width
        height = h / image_height

        return (x_center, y_center, width, height)

    @staticmethod
    def _split_frames(frame_files: List[Any], val_split: float, seed: int):
        """Shuffle a copy of ``frame_files`` with ``seed`` and split into (train, val)."""
        shuffled = list(frame_files)
        # A private generator yields the same order as random.seed(seed) +
        # random.shuffle(), without reseeding the process-wide RNG.
        random.Random(seed).shuffle(shuffled)
        split_idx = int(len(shuffled) * (1 - val_split))
        return shuffled[:split_idx], shuffled[split_idx:]

    @staticmethod
    def _write_label_file(labels_dir: Path, stem: str, labels: List[str]) -> None:
        """Write one YOLO label line per object to ``<labels_dir>/<stem>.txt``."""
        with open(labels_dir / (stem + '.txt'), 'w') as f:
            f.write('\n'.join(labels))

    def convert_sam2_annotations(
        self,
        frames_dir: str,
        annotations_file: str,
        masks_dir: Optional[str] = None,
        class_id: int = 0,
        min_area: int = 100,
        min_bbox_size: float = 0.01,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Convert SAM2 annotations to YOLO format dataset.

        Args:
            frames_dir: Directory containing frame images
            annotations_file: Path to SAM2 annotations JSON (frame name -> list
                of annotations with ``bbox`` and ``area``)
            masks_dir: Accepted for interface compatibility; not used here
            class_id: Class ID for all objects
            min_area: Minimum annotation ``area``; smaller ones are skipped
            min_bbox_size: Minimum bbox dimension (normalized)
            val_split: Validation set ratio
            seed: Random seed for split

        Returns:
            Statistics dictionary. Frame counts describe the split itself,
            including frames that turn out to be missing or unreadable.
        """
        frames_path = Path(frames_dir)

        # Load annotations
        with open(annotations_file, 'r') as f:
            annotations = json.load(f)

        # Shuffle and split the annotated frame names
        train_files, val_files = self._split_frames(
            list(annotations.keys()), val_split, seed
        )

        stats = {
            'total_frames': len(train_files) + len(val_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        # Process training files, then validation files
        for split_files, images_dir, labels_dir, key in [
            (train_files, self.images_train, self.labels_train, 'train_objects'),
            (val_files, self.images_val, self.labels_val, 'val_objects')
        ]:
            for frame_name in tqdm(split_files, desc=f"Processing {key.split('_')[0]}"):
                stats[key] += self._process_frame(
                    frames_path / frame_name,
                    annotations.get(frame_name, []),
                    images_dir,
                    labels_dir,
                    class_id,
                    min_area,
                    min_bbox_size
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']

        # Create data.yaml
        self._create_data_yaml()

        print(f"\nDataset created at {self.output_dir}")
        print(f" Train: {stats['train_frames']} images, {stats['train_objects']} objects")
        print(f" Val: {stats['val_frames']} images, {stats['val_objects']} objects")

        return stats

    def _process_frame(
        self,
        image_path: Path,
        frame_annotations: List[Dict],
        images_dir: Path,
        labels_dir: Path,
        class_id: int,
        min_area: int,
        min_bbox_size: float
    ) -> int:
        """Process a single frame and create YOLO label file.

        Copies the image into ``images_dir`` and writes its label file into
        ``labels_dir``. Returns the number of label lines written, or 0 when
        the image is missing or unreadable (nothing is copied then).
        """
        if not image_path.exists():
            return 0

        # Read image to get dimensions
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        # Copy image to dataset
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for ann in frame_annotations:
            bbox = ann.get('bbox', [])

            if ann.get('area', 0) < min_area:
                continue
            if len(bbox) != 4:
                continue

            x_center, y_center, w, h = self.bbox_xywh_to_yolo(bbox, width, height)

            # Filter small boxes
            if w < min_bbox_size or h < min_bbox_size:
                continue

            # YOLO format: class x_center y_center width height
            labels.append(f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")

        # An empty label file is still written for frames without objects.
        self._write_label_file(labels_dir, image_path.stem, labels)

        return len(labels)

    def _create_data_yaml(self):
        """Create YOLO data.yaml configuration file."""
        data_config = {
            'path': str(self.output_dir.absolute()),
            'train': 'images/train',
            'val': 'images/val',
            'names': {i: name for i, name in enumerate(self.class_names)},
            'nc': len(self.class_names)
        }

        yaml_path = self.output_dir / 'data.yaml'
        with open(yaml_path, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

        print(f"Created {yaml_path}")

    def create_from_masks_directory(
        self,
        frames_dir: str,
        masks_dir: str,
        class_id: int = 0,
        val_split: float = 0.2,
        seed: int = 42
    ) -> Dict[str, int]:
        """
        Create YOLO dataset directly from mask images.

        Masks for a frame are looked up as ``<masks_dir>/<frame stem>/*.png``.

        Args:
            frames_dir: Directory containing original frames
            masks_dir: Directory containing mask images (one per object)
            class_id: Class ID for all objects
            val_split: Validation split ratio
            seed: Random seed

        Returns:
            Statistics dictionary
        """
        frames_path = Path(frames_dir)
        masks_path = Path(masks_dir)

        # Find all frame images
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) +
            list(frames_path.glob("*.png"))
        )

        train_files, val_files = self._split_frames(frame_files, val_split, seed)

        stats = {
            'total_frames': len(frame_files),
            'train_frames': len(train_files),
            'val_frames': len(val_files),
            'total_objects': 0,
            'train_objects': 0,
            'val_objects': 0
        }

        # Process frames
        for frame_list, images_dir, labels_dir, key in [
            (train_files, self.images_train, self.labels_train, 'train_objects'),
            (val_files, self.images_val, self.labels_val, 'val_objects')
        ]:
            for frame_file in tqdm(frame_list, desc=f"Processing {key.split('_')[0]}"):
                # Find corresponding masks; sorted so label order is reproducible.
                frame_masks_dir = masks_path / frame_file.stem
                if frame_masks_dir.exists():
                    mask_files = sorted(frame_masks_dir.glob("*.png"))
                else:
                    mask_files = []

                stats[key] += self._process_frame_with_masks(
                    frame_file,
                    mask_files,
                    images_dir,
                    labels_dir,
                    class_id
                )

        stats['total_objects'] = stats['train_objects'] + stats['val_objects']
        self._create_data_yaml()

        return stats

    def _process_frame_with_masks(
        self,
        image_path: Path,
        mask_files: List[Path],
        images_dir: Path,
        labels_dir: Path,
        class_id: int
    ) -> int:
        """Process frame with mask files.

        Copies the image, derives one bounding box per readable, non-empty
        mask and writes the label file. Returns the number of labels written,
        or 0 when the image cannot be read.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            return 0

        height, width = image.shape[:2]

        # Copy image
        shutil.copy2(image_path, images_dir / image_path.name)

        labels = []
        for mask_file in mask_files:
            mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue

            # Threshold the 0-255 grayscale mask to binary
            bbox = self.mask_to_bbox(mask > 127, width, height)
            if bbox is None:
                continue

            x_center, y_center, w, h = bbox
            labels.append(f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")

        # Write labels
        self._write_label_file(labels_dir, image_path.stem, labels)

        return len(labels)
|
||||
|
||||
|
||||
def export_for_kaggle(
    dataset_dir: str,
    output_zip: str,
    include_checkpoints: bool = False
) -> str:
    """
    Export YOLO dataset as zip file for Kaggle upload.

    Archive entries are stored relative to the parent of ``dataset_dir``, so
    they are prefixed with the dataset directory's own name.

    Args:
        dataset_dir: Path to YOLO dataset directory
        output_zip: Output zip file path
        include_checkpoints: Include directories whose name contains
            'checkpoints' (skipped by default)

    Returns:
        Path to created zip file
    """
    import zipfile

    dataset_path = Path(dataset_dir)
    # Resolved once so the archive can be recognized if it lives inside the dataset.
    zip_target = Path(output_zip).resolve()

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(dataset_path):
            if not include_checkpoints:
                # Prune by directory *name* below the dataset root. Matching on
                # the full path would drop everything whenever an ancestor
                # directory happens to contain 'checkpoints'.
                dirs[:] = [d for d in dirs if 'checkpoints' not in d]
            # Sorted traversal gives a reproducible archive order.
            dirs.sort()

            for file in sorted(files):
                file_path = Path(root) / file
                # Never add the archive being written to itself.
                if file_path.resolve() == zip_target:
                    continue
                arcname = file_path.relative_to(dataset_path.parent)
                zipf.write(file_path, arcname)

    print(f"Dataset exported to {output_zip}")
    return output_zip
|
||||
|
||||
|
||||
def validate_yolo_dataset(dataset_dir: str) -> Dict[str, Any]:
    """
    Validate YOLO dataset structure and contents.

    Checks that ``data.yaml`` exists and holds a mapping, that the
    ``images/<split>`` and ``labels/<split>`` directories exist for the
    'train' and 'val' splits, and that image and label file stems match.

    Args:
        dataset_dir: Path to YOLO dataset

    Returns:
        Validation results dictionary with keys ``valid`` (bool), ``errors``
        and ``warnings`` (lists of messages) and ``stats`` (counts and
        class information)
    """
    dataset_path = Path(dataset_dir)

    results = {
        'valid': True,
        'errors': [],
        'warnings': [],
        'stats': {}
    }

    # Check data.yaml
    yaml_path = dataset_path / 'data.yaml'
    if not yaml_path.exists():
        results['errors'].append("Missing data.yaml")
        results['valid'] = False
    else:
        with open(yaml_path) as f:
            config = yaml.safe_load(f)
        # safe_load returns None for an empty file and may return a non-mapping.
        if isinstance(config, dict):
            results['stats']['num_classes'] = config.get('nc', 0)
            results['stats']['class_names'] = config.get('names', {})
        else:
            results['errors'].append("Invalid data.yaml: expected a mapping")
            results['valid'] = False

    # Check directories
    for split in ['train', 'val']:
        images_dir = dataset_path / 'images' / split
        labels_dir = dataset_path / 'labels' / split

        if not images_dir.exists():
            results['errors'].append(f"Missing images/{split} directory")
            results['valid'] = False
            continue

        if not labels_dir.exists():
            results['errors'].append(f"Missing labels/{split} directory")
            results['valid'] = False
            continue

        # Count files
        image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
        label_files = list(labels_dir.glob("*.txt"))

        results['stats'][f'{split}_images'] = len(image_files)
        results['stats'][f'{split}_labels'] = len(label_files)

        # Check matching in both directions; mismatches are warnings only.
        image_stems = {f.stem for f in image_files}
        label_stems = {f.stem for f in label_files}

        missing_labels = image_stems - label_stems
        if missing_labels:
            results['warnings'].append(
                f"{len(missing_labels)} images in {split} missing labels"
            )

        orphan_labels = label_stems - image_stems
        if orphan_labels:
            results['warnings'].append(
                f"{len(orphan_labels)} labels in {split} have no matching image"
            )

    return results
|
||||
Reference in New Issue
Block a user