add sam2 yolo auto annotation
This commit is contained in:
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
SAM2 annotation utilities for automatic mask generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class SAM2Annotator:
|
||||
"""Automatic annotation using SAM2 model."""
|
||||
|
||||
# Available SAM2 model variants
|
||||
MODEL_CONFIGS = {
|
||||
'tiny': {
|
||||
'config': 'sam2_hiera_t.yaml',
|
||||
'checkpoint': 'sam2_hiera_tiny.pt',
|
||||
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'
|
||||
},
|
||||
'small': {
|
||||
'config': 'sam2_hiera_s.yaml',
|
||||
'checkpoint': 'sam2_hiera_small.pt',
|
||||
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'
|
||||
},
|
||||
'base_plus': {
|
||||
'config': 'sam2_hiera_b+.yaml',
|
||||
'checkpoint': 'sam2_hiera_base_plus.pt',
|
||||
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'
|
||||
},
|
||||
'large': {
|
||||
'config': 'sam2_hiera_l.yaml',
|
||||
'checkpoint': 'sam2_hiera_large.pt',
|
||||
'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_size: str = 'large',
|
||||
checkpoint_dir: str = './checkpoints',
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize SAM2 annotator.
|
||||
|
||||
Args:
|
||||
model_size: Model variant ('tiny', 'small', 'base_plus', 'large')
|
||||
checkpoint_dir: Directory to store model checkpoints
|
||||
device: Device to run model on ('cuda', 'cpu', or None for auto)
|
||||
"""
|
||||
self.model_size = model_size
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
self.model = None
|
||||
self.predictor = None
|
||||
self.mask_generator = None
|
||||
|
||||
def download_checkpoint(self) -> str:
|
||||
"""Download SAM2 checkpoint if not exists."""
|
||||
config = self.MODEL_CONFIGS[self.model_size]
|
||||
checkpoint_path = self.checkpoint_dir / config['checkpoint']
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
print(f"Downloading SAM2 {self.model_size} checkpoint...")
|
||||
import urllib.request
|
||||
urllib.request.urlretrieve(config['url'], str(checkpoint_path))
|
||||
print(f"Downloaded to {checkpoint_path}")
|
||||
else:
|
||||
print(f"Checkpoint exists: {checkpoint_path}")
|
||||
|
||||
return str(checkpoint_path)
|
||||
|
||||
def load_model(self):
|
||||
"""Load SAM2 model for inference."""
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
||||
|
||||
checkpoint_path = self.download_checkpoint()
|
||||
config = self.MODEL_CONFIGS[self.model_size]
|
||||
|
||||
print(f"Loading SAM2 {self.model_size} model on {self.device}...")
|
||||
|
||||
self.model = build_sam2(
|
||||
config['config'],
|
||||
checkpoint_path,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Create predictor for interactive annotation
|
||||
self.predictor = SAM2ImagePredictor(self.model)
|
||||
|
||||
# Create automatic mask generator
|
||||
self.mask_generator = SAM2AutomaticMaskGenerator(
|
||||
model=self.model,
|
||||
points_per_side=32,
|
||||
points_per_batch=64,
|
||||
pred_iou_thresh=0.7,
|
||||
stability_score_thresh=0.92,
|
||||
stability_score_offset=1.0,
|
||||
box_nms_thresh=0.7,
|
||||
crop_n_layers=1,
|
||||
crop_nms_thresh=0.7,
|
||||
crop_overlap_ratio=0.34,
|
||||
crop_n_points_downscale_factor=2,
|
||||
min_mask_region_area=100
|
||||
)
|
||||
|
||||
print("Model loaded successfully!")
|
||||
|
||||
def generate_masks_auto(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
min_area: int = 100,
|
||||
max_area: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Automatically generate masks for all objects in image.
|
||||
|
||||
Args:
|
||||
image: Input image (BGR format from OpenCV)
|
||||
min_area: Minimum mask area in pixels
|
||||
max_area: Maximum mask area in pixels
|
||||
|
||||
Returns:
|
||||
List of mask dictionaries with keys:
|
||||
- segmentation: Binary mask
|
||||
- bbox: Bounding box [x, y, w, h]
|
||||
- area: Mask area
|
||||
- predicted_iou: Confidence score
|
||||
"""
|
||||
if self.mask_generator is None:
|
||||
self.load_model()
|
||||
|
||||
# Convert BGR to RGB
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Generate masks
|
||||
masks = self.mask_generator.generate(image_rgb)
|
||||
|
||||
# Filter by area
|
||||
filtered_masks = []
|
||||
for mask in masks:
|
||||
area = mask['area']
|
||||
if area >= min_area:
|
||||
if max_area is None or area <= max_area:
|
||||
filtered_masks.append(mask)
|
||||
|
||||
# Sort by area (largest first)
|
||||
filtered_masks.sort(key=lambda x: x['area'], reverse=True)
|
||||
|
||||
return filtered_masks
|
||||
|
||||
def generate_masks_with_prompts(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
points: Optional[List[Tuple[int, int]]] = None,
|
||||
point_labels: Optional[List[int]] = None,
|
||||
boxes: Optional[List[List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Generate masks using point or box prompts.
|
||||
|
||||
Args:
|
||||
image: Input image (BGR format)
|
||||
points: List of (x, y) point coordinates
|
||||
point_labels: Labels for points (1=foreground, 0=background)
|
||||
boxes: List of boxes [x1, y1, x2, y2]
|
||||
|
||||
Returns:
|
||||
Tuple of (masks, scores, logits)
|
||||
"""
|
||||
if self.predictor is None:
|
||||
self.load_model()
|
||||
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.predictor.set_image(image_rgb)
|
||||
|
||||
# Prepare inputs
|
||||
point_coords = np.array(points) if points else None
|
||||
point_labels_arr = np.array(point_labels) if point_labels else None
|
||||
box_arr = np.array(boxes) if boxes else None
|
||||
|
||||
masks, scores, logits = self.predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels_arr,
|
||||
box=box_arr,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
return masks, scores, logits
|
||||
|
||||
def annotate_frames(
|
||||
self,
|
||||
frames_dir: str,
|
||||
output_dir: str,
|
||||
min_area: int = 100,
|
||||
max_area: Optional[int] = None,
|
||||
save_visualizations: bool = True
|
||||
) -> Dict[str, List[Dict]]:
|
||||
"""
|
||||
Annotate all frames in a directory.
|
||||
|
||||
Args:
|
||||
frames_dir: Directory containing frame images
|
||||
output_dir: Directory to save annotations
|
||||
min_area: Minimum object area
|
||||
max_area: Maximum object area
|
||||
save_visualizations: Save annotated visualization images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping frame names to annotation lists
|
||||
"""
|
||||
frames_path = Path(frames_dir)
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_visualizations:
|
||||
vis_path = output_path / 'visualizations'
|
||||
vis_path.mkdir(exist_ok=True)
|
||||
|
||||
# Find all frame images
|
||||
frame_files = sorted(
|
||||
list(frames_path.glob("*.jpg")) +
|
||||
list(frames_path.glob("*.png"))
|
||||
)
|
||||
|
||||
if not frame_files:
|
||||
raise ValueError(f"No frames found in {frames_dir}")
|
||||
|
||||
print(f"Found {len(frame_files)} frames to annotate")
|
||||
|
||||
all_annotations = {}
|
||||
|
||||
for frame_file in tqdm(frame_files, desc="Annotating frames"):
|
||||
image = cv2.imread(str(frame_file))
|
||||
if image is None:
|
||||
continue
|
||||
|
||||
# Generate masks
|
||||
masks = self.generate_masks_auto(image, min_area, max_area)
|
||||
|
||||
# Convert to serializable format
|
||||
annotations = []
|
||||
for i, mask_data in enumerate(masks):
|
||||
ann = {
|
||||
'id': i,
|
||||
'bbox': mask_data['bbox'], # [x, y, w, h]
|
||||
'area': int(mask_data['area']),
|
||||
'predicted_iou': float(mask_data['predicted_iou']),
|
||||
'stability_score': float(mask_data['stability_score'])
|
||||
}
|
||||
annotations.append(ann)
|
||||
|
||||
# Save mask as separate file if needed
|
||||
mask_filename = f"{frame_file.stem}_mask_{i:03d}.png"
|
||||
mask_path = output_path / 'masks' / frame_file.stem
|
||||
mask_path.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(
|
||||
str(mask_path / mask_filename),
|
||||
mask_data['segmentation'].astype(np.uint8) * 255
|
||||
)
|
||||
|
||||
all_annotations[frame_file.name] = annotations
|
||||
|
||||
# Save visualization
|
||||
if save_visualizations:
|
||||
vis_image = self._visualize_masks(image, masks)
|
||||
cv2.imwrite(str(vis_path / frame_file.name), vis_image)
|
||||
|
||||
# Save annotations JSON
|
||||
annotations_file = output_path / 'annotations.json'
|
||||
with open(annotations_file, 'w') as f:
|
||||
json.dump(all_annotations, f, indent=2)
|
||||
|
||||
print(f"Annotations saved to {annotations_file}")
|
||||
return all_annotations
|
||||
|
||||
def _visualize_masks(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
masks: List[Dict],
|
||||
alpha: float = 0.5
|
||||
) -> np.ndarray:
|
||||
"""Create visualization of masks overlaid on image."""
|
||||
vis_image = image.copy()
|
||||
|
||||
for mask_data in masks:
|
||||
mask = mask_data['segmentation']
|
||||
color = np.random.randint(0, 255, 3).tolist()
|
||||
|
||||
# Create colored overlay
|
||||
overlay = vis_image.copy()
|
||||
overlay[mask] = color
|
||||
vis_image = cv2.addWeighted(vis_image, 1 - alpha, overlay, alpha, 0)
|
||||
|
||||
# Draw bounding box
|
||||
x, y, w, h = mask_data['bbox']
|
||||
cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)
|
||||
|
||||
return vis_image
|
||||
|
||||
|
||||
def load_sam2_video_predictor(
|
||||
model_size: str = 'large',
|
||||
checkpoint_dir: str = './checkpoints',
|
||||
device: str = 'cuda'
|
||||
):
|
||||
"""
|
||||
Load SAM2 video predictor for tracking objects across frames.
|
||||
|
||||
Args:
|
||||
model_size: Model size variant
|
||||
checkpoint_dir: Checkpoint directory
|
||||
device: Device to use
|
||||
|
||||
Returns:
|
||||
SAM2VideoPredictor instance
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
annotator = SAM2Annotator(model_size, checkpoint_dir, device)
|
||||
checkpoint_path = annotator.download_checkpoint()
|
||||
config = SAM2Annotator.MODEL_CONFIGS[model_size]
|
||||
|
||||
predictor = build_sam2_video_predictor(
|
||||
config['config'],
|
||||
checkpoint_path,
|
||||
device=device
|
||||
)
|
||||
|
||||
return predictor
|
||||
Reference in New Issue
Block a user