344 lines
11 KiB
Python
344 lines
11 KiB
Python
"""
|
|
SAM2 annotation utilities for automatic mask generation.
|
|
"""
|
|
|
|
import os
|
|
import cv2
|
|
import json
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from tqdm import tqdm
|
|
|
|
|
|
class SAM2Annotator:
    """Automatic and prompt-based image annotation using a SAM2 model.

    The model is loaded lazily: the first call to ``generate_masks_auto`` or
    ``generate_masks_with_prompts`` triggers ``load_model``, which downloads
    the checkpoint into ``checkpoint_dir`` if it is not already there.
    """

    # Available SAM2 model variants: Hydra config name, checkpoint file name
    # and download URL for each size.
    MODEL_CONFIGS = {
        'tiny': {
            'config': 'sam2_hiera_t.yaml',
            'checkpoint': 'sam2_hiera_tiny.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'
        },
        'small': {
            'config': 'sam2_hiera_s.yaml',
            'checkpoint': 'sam2_hiera_small.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'
        },
        'base_plus': {
            'config': 'sam2_hiera_b+.yaml',
            'checkpoint': 'sam2_hiera_base_plus.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'
        },
        'large': {
            'config': 'sam2_hiera_l.yaml',
            'checkpoint': 'sam2_hiera_large.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
        }
    }

    # Default keyword arguments for SAM2AutomaticMaskGenerator. Individual
    # entries can be overridden per instance via ``mask_generator_kwargs``.
    DEFAULT_MASK_GENERATOR_KWARGS = {
        'points_per_side': 32,
        'points_per_batch': 64,
        'pred_iou_thresh': 0.7,
        'stability_score_thresh': 0.92,
        'stability_score_offset': 1.0,
        'box_nms_thresh': 0.7,
        'crop_n_layers': 1,
        'crop_nms_thresh': 0.7,
        'crop_overlap_ratio': 0.34,
        'crop_n_points_downscale_factor': 2,
        'min_mask_region_area': 100,
    }

    # File extensions (matched case-insensitively) that annotate_frames
    # treats as frame images.
    FRAME_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(
        self,
        model_size: str = 'large',
        checkpoint_dir: str = './checkpoints',
        device: Optional[str] = None,
        mask_generator_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize SAM2 annotator (the model itself is loaded lazily).

        Args:
            model_size: Model variant ('tiny', 'small', 'base_plus', 'large')
            checkpoint_dir: Directory to store model checkpoints; created if
                missing.
            device: Device to run model on ('cuda', 'cpu', or None to use
                CUDA when available and fall back to CPU)
            mask_generator_kwargs: Optional overrides merged on top of
                DEFAULT_MASK_GENERATOR_KWARGS when the automatic mask
                generator is built. annotate_frames and _visualize_masks
                expect binary-mask segmentations, so do not change the
                generator's output format when using those methods.
        """
        self.model_size = model_size
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Copy so later mutation of the caller's dict does not leak in.
        self.mask_generator_kwargs = dict(mask_generator_kwargs or {})

        self.model = None
        self.predictor = None
        self.mask_generator = None

    def _model_config(self) -> Dict[str, str]:
        """Return the MODEL_CONFIGS entry for ``self.model_size``.

        Raises:
            KeyError: If ``model_size`` is unknown; the message lists the
                valid sizes.
        """
        try:
            return self.MODEL_CONFIGS[self.model_size]
        except KeyError:
            valid = ', '.join(self.MODEL_CONFIGS)
            raise KeyError(
                f"Unknown SAM2 model_size {self.model_size!r}; "
                f"expected one of: {valid}"
            ) from None

    def download_checkpoint(self) -> str:
        """Download the SAM2 checkpoint for ``model_size`` unless it exists.

        The file is downloaded to a temporary ``.part`` file and only moved
        to its final name once complete, so an interrupted download never
        leaves a truncated checkpoint that a later run would accept.

        Returns:
            Path to the checkpoint file, as a string.

        Raises:
            KeyError: If ``model_size`` is not a key of MODEL_CONFIGS.
        """
        config = self._model_config()
        checkpoint_path = self.checkpoint_dir / config['checkpoint']

        if checkpoint_path.exists():
            print(f"Checkpoint exists: {checkpoint_path}")
            return str(checkpoint_path)

        print(f"Downloading SAM2 {self.model_size} checkpoint...")
        import urllib.request

        partial_path = checkpoint_path.with_name(checkpoint_path.name + '.part')
        try:
            urllib.request.urlretrieve(config['url'], str(partial_path))
            # Atomic rename: the final path only ever holds a complete file.
            os.replace(partial_path, checkpoint_path)
        finally:
            # Only present if the download or rename failed.
            if partial_path.exists():
                partial_path.unlink()
        print(f"Downloaded to {checkpoint_path}")

        return str(checkpoint_path)

    def load_model(self):
        """Load SAM2 and build its image predictor and automatic mask generator."""
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        checkpoint_path = self.download_checkpoint()
        config = self._model_config()

        print(f"Loading SAM2 {self.model_size} model on {self.device}...")

        self.model = build_sam2(
            config['config'],
            checkpoint_path,
            device=self.device
        )

        # Create predictor for interactive (point/box prompted) annotation
        self.predictor = SAM2ImagePredictor(self.model)

        # Create automatic mask generator: class defaults + instance overrides
        generator_kwargs = dict(self.DEFAULT_MASK_GENERATOR_KWARGS)
        generator_kwargs.update(self.mask_generator_kwargs)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **generator_kwargs
        )

        print("Model loaded successfully!")

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV-style image (BGR, BGRA or grayscale) to RGB.

        Returns a C-contiguous copy, like ``cv2.cvtColor`` does, rather than
        a negatively-strided view (which ``torch.from_numpy`` rejects).

        Raises:
            ValueError: If ``image`` is not 2-D, or 3-D with 1, 3 or 4
                channels.
        """
        image = np.asarray(image)
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.ndim != 3 or image.shape[2] not in (1, 3, 4):
            raise ValueError(
                f"Expected a grayscale, BGR or BGRA image, got shape {image.shape}"
            )
        if image.shape[2] == 1:
            rgb = np.repeat(image, 3, axis=2)
        else:
            # Channels 2, 1, 0 -> R, G, B; drops alpha when present.
            rgb = image[..., 2::-1]
        return np.ascontiguousarray(rgb)

    @staticmethod
    def _optional_array(values) -> Optional[np.ndarray]:
        """Return ``values`` as a new ndarray, or None if it is None or empty.

        Unlike ``np.array(x) if x else None`` this also accepts ndarray
        inputs, whose truth value is ambiguous.
        """
        if values is None:
            return None
        arr = np.array(values)
        return arr if arr.size > 0 else None

    def generate_masks_auto(
        self,
        image: np.ndarray,
        min_area: int = 100,
        max_area: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Automatically generate masks for all objects in image.

        Args:
            image: Input image in OpenCV channel order (BGR; BGRA and
                single-channel grayscale are also accepted)
            min_area: Minimum mask area in pixels (inclusive)
            max_area: Maximum mask area in pixels (inclusive); None = no limit

        Returns:
            Mask dictionaries from the SAM2 automatic mask generator, sorted
            by area (largest first). Keys used elsewhere in this class:
                - segmentation: Binary mask
                - bbox: Bounding box [x, y, w, h]
                - area: Mask area
                - predicted_iou: Confidence score
                - stability_score: Mask stability score
        """
        if self.mask_generator is None:
            self.load_model()

        masks = self.mask_generator.generate(self._to_rgb(image))

        filtered_masks = [
            mask for mask in masks
            if mask['area'] >= min_area
            and (max_area is None or mask['area'] <= max_area)
        ]

        # Sort by area (largest first)
        filtered_masks.sort(key=lambda m: m['area'], reverse=True)
        return filtered_masks

    def generate_masks_with_prompts(
        self,
        image: np.ndarray,
        points: Optional[List[Tuple[int, int]]] = None,
        point_labels: Optional[List[int]] = None,
        boxes: Optional[List[List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate masks using point or box prompts.

        Args:
            image: Input image (BGR format; BGRA/grayscale also accepted)
            points: (x, y) point coordinates (list or ndarray)
            point_labels: Labels for points (1=foreground, 0=background).
                If omitted while points are given, all points are treated
                as foreground.
            boxes: Boxes [x1, y1, x2, y2] (list or ndarray)

        Empty prompt sequences are treated the same as None.

        Returns:
            Tuple of (masks, scores, logits) from the predictor, requested
            with multimask_output=True
        """
        if self.predictor is None:
            self.load_model()

        self.predictor.set_image(self._to_rgb(image))

        # Prepare inputs
        point_coords = self._optional_array(points)
        point_labels_arr = self._optional_array(point_labels)
        if point_coords is not None and point_labels_arr is None:
            # One foreground label per (x, y) pair, preserving any batch dims.
            point_labels_arr = np.ones(point_coords.shape[:-1], dtype=np.int32)
        box_arr = self._optional_array(boxes)

        masks, scores, logits = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels_arr,
            box=box_arr,
            multimask_output=True
        )

        return masks, scores, logits

    @classmethod
    def _find_frames(cls, frames_path: Path) -> List[Path]:
        """Return sorted image files directly inside ``frames_path``.

        Returns an empty list if the directory does not exist.
        """
        if not frames_path.is_dir():
            return []
        extensions = {ext.lower() for ext in cls.FRAME_EXTENSIONS}
        return sorted(
            p for p in frames_path.iterdir()
            if p.is_file() and p.suffix.lower() in extensions
        )

    def annotate_frames(
        self,
        frames_dir: str,
        output_dir: str,
        min_area: int = 100,
        max_area: Optional[int] = None,
        save_visualizations: bool = True
    ) -> Dict[str, List[Dict]]:
        """
        Annotate all frames in a directory.

        Frames are the files directly in ``frames_dir`` whose extension is in
        FRAME_EXTENSIONS (case-insensitive), processed in sorted order. For
        every readable frame this writes one 0/255 PNG per mask to
        ``<output_dir>/masks/<frame stem>/`` and, optionally, an overlay to
        ``<output_dir>/visualizations/<frame name>``. All metadata is written
        to ``<output_dir>/annotations.json`` at the end. Unreadable images
        are reported and skipped.

        Args:
            frames_dir: Directory containing frame images
            output_dir: Directory to save annotations (created if missing)
            min_area: Minimum object area
            max_area: Maximum object area
            save_visualizations: Save annotated visualization images

        Returns:
            Dictionary mapping frame file names to annotation lists

        Raises:
            ValueError: If no frame images are found in ``frames_dir``.
        """
        frames_path = Path(frames_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        if save_visualizations:
            vis_path = output_path / 'visualizations'
            vis_path.mkdir(exist_ok=True)

        frame_files = self._find_frames(frames_path)
        if not frame_files:
            raise ValueError(f"No frames found in {frames_dir}")

        print(f"Found {len(frame_files)} frames to annotate")

        masks_root = output_path / 'masks'
        all_annotations = {}

        for frame_file in tqdm(frame_files, desc="Annotating frames"):
            image = cv2.imread(str(frame_file))
            if image is None:
                # Report (without breaking the progress bar) and move on so
                # one corrupt frame does not abort the whole run.
                tqdm.write(f"Warning: could not read {frame_file}, skipping")
                continue

            masks = self.generate_masks_auto(image, min_area, max_area)

            frame_mask_dir = masks_root / frame_file.stem
            if masks:
                frame_mask_dir.mkdir(parents=True, exist_ok=True)

            # Convert to a JSON-serializable format and save each mask image
            annotations = []
            for i, mask_data in enumerate(masks):
                annotations.append({
                    'id': i,
                    # [x, y, w, h] as plain ints for JSON
                    'bbox': [int(round(v)) for v in mask_data['bbox']],
                    'area': int(mask_data['area']),
                    'predicted_iou': float(mask_data['predicted_iou']),
                    'stability_score': float(mask_data['stability_score'])
                })

                mask_image = np.asarray(
                    mask_data['segmentation'], dtype=bool
                ).astype(np.uint8) * 255
                cv2.imwrite(
                    str(frame_mask_dir / f"{frame_file.stem}_mask_{i:03d}.png"),
                    mask_image
                )

            all_annotations[frame_file.name] = annotations

            if save_visualizations:
                vis_image = self._visualize_masks(image, masks)
                cv2.imwrite(str(vis_path / frame_file.name), vis_image)

        annotations_file = output_path / 'annotations.json'
        with open(annotations_file, 'w') as f:
            json.dump(all_annotations, f, indent=2)

        print(f"Annotations saved to {annotations_file}")
        return all_annotations

    def _visualize_masks(
        self,
        image: np.ndarray,
        masks: List[Dict],
        alpha: float = 0.5,
        seed: Optional[int] = None
    ) -> np.ndarray:
        """Create visualization of masks overlaid on image.

        Each mask is filled with a random color blended at opacity ``alpha``
        and outlined by its bounding box. The input image is not modified.

        Args:
            image: BGR (or single-channel) image, as returned by cv2.imread
            masks: Mask dicts with a 'segmentation' mask of the image's
                height/width and a 'bbox' [x, y, w, h]
            alpha: Opacity of the color fill, in [0, 1]
            seed: Optional seed for reproducible colors. Colors come from a
                private generator, so the global NumPy RNG is left untouched.

        Returns:
            The annotated BGR image.
        """
        vis_image = image.copy()
        if vis_image.ndim == 2:
            vis_image = cv2.cvtColor(vis_image, cv2.COLOR_GRAY2BGR)
        rng = np.random.default_rng(seed)

        for mask_data in masks:
            mask = np.asarray(mask_data['segmentation'], dtype=bool)
            color = rng.integers(0, 256, size=3)

            # Blend only the masked pixels; all other pixels are unchanged,
            # so there is no need to copy/blend the full frame per mask.
            blended = vis_image[mask] * (1.0 - alpha) + color * alpha
            vis_image[mask] = np.rint(blended).astype(vis_image.dtype)

            # Draw bounding box (cv2 requires integer coordinates)
            x, y, w, h = (int(round(v)) for v in mask_data['bbox'])
            cv2.rectangle(vis_image, (x, y), (x + w, y + h), color.tolist(), 2)

        return vis_image
|
|
|
|
|
|
def load_sam2_video_predictor(
    model_size: str = 'large',
    checkpoint_dir: str = './checkpoints',
    device: Optional[str] = None
):
    """
    Load SAM2 video predictor for tracking objects across frames.

    The checkpoint is downloaded into ``checkpoint_dir`` first if it is not
    already present (via SAM2Annotator.download_checkpoint).

    Args:
        model_size: Model size variant, a key of SAM2Annotator.MODEL_CONFIGS
            ('tiny', 'small', 'base_plus', 'large')
        checkpoint_dir: Checkpoint directory (created if missing)
        device: Device to use ('cuda', 'cpu', ...). None selects CUDA when
            available and falls back to CPU, matching SAM2Annotator.

    Returns:
        The predictor built by sam2.build_sam.build_sam2_video_predictor
        (a SAM2VideoPredictor instance)

    Raises:
        KeyError: If ``model_size`` is not a known variant.
    """
    from sam2.build_sam import build_sam2_video_predictor

    # Reuse SAM2Annotator for checkpoint download and device resolution so
    # both entry points behave the same on CPU-only machines.
    annotator = SAM2Annotator(model_size, checkpoint_dir, device)
    checkpoint_path = annotator.download_checkpoint()
    config = SAM2Annotator.MODEL_CONFIGS[model_size]

    predictor = build_sam2_video_predictor(
        config['config'],
        checkpoint_path,
        device=annotator.device
    )

    return predictor
|