Files
2026-02-04 15:29:36 +07:00

344 lines
11 KiB
Python

"""
SAM2 annotation utilities for automatic mask generation.
"""
import os
import cv2
import json
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
class SAM2Annotator:
    """Automatic annotation using the SAM2 model.

    Bundles checkpoint download, model construction, automatic mask
    generation, prompt-driven prediction and batch annotation of a directory
    of frames. The model is loaded lazily the first time a method needs it.
    """

    # Available SAM2 model variants: config file, checkpoint file name and
    # download URL for each size.
    MODEL_CONFIGS = {
        'tiny': {
            'config': 'sam2_hiera_t.yaml',
            'checkpoint': 'sam2_hiera_tiny.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'
        },
        'small': {
            'config': 'sam2_hiera_s.yaml',
            'checkpoint': 'sam2_hiera_small.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'
        },
        'base_plus': {
            'config': 'sam2_hiera_b+.yaml',
            'checkpoint': 'sam2_hiera_base_plus.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'
        },
        'large': {
            'config': 'sam2_hiera_l.yaml',
            'checkpoint': 'sam2_hiera_large.pt',
            'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
        }
    }

    def __init__(
        self,
        model_size: str = 'large',
        checkpoint_dir: str = './checkpoints',
        device: Optional[str] = None
    ):
        """
        Initialize SAM2 annotator.

        Creates ``checkpoint_dir`` if it does not exist. No model is loaded
        here; see ``load_model``.

        Args:
            model_size: Model variant ('tiny', 'small', 'base_plus', 'large')
            checkpoint_dir: Directory to store model checkpoints
            device: Device to run model on ('cuda', 'cpu', or None for auto)
        """
        self.model_size = model_size
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Populated by load_model().
        self.model = None
        self.predictor = None
        self.mask_generator = None

    def _model_config(self) -> Dict[str, str]:
        """Return the MODEL_CONFIGS entry for ``self.model_size``.

        Raises:
            KeyError: If ``self.model_size`` is not a key of MODEL_CONFIGS.
        """
        try:
            return self.MODEL_CONFIGS[self.model_size]
        except KeyError:
            valid = ', '.join(repr(name) for name in self.MODEL_CONFIGS)
            raise KeyError(
                f"Unknown model_size {self.model_size!r}; expected one of: {valid}"
            ) from None

    @staticmethod
    def _array_or_none(values) -> Optional[np.ndarray]:
        """Convert a non-empty sequence to an array; None or empty gives None."""
        # len() instead of truthiness so NumPy arrays are accepted too.
        if values is None or len(values) == 0:
            return None
        return np.array(values)

    def download_checkpoint(self) -> str:
        """Download the SAM2 checkpoint if it is not already present.

        Returns:
            Path of the checkpoint file inside ``checkpoint_dir``.

        Raises:
            KeyError: If ``self.model_size`` is not a known variant.
        """
        config = self._model_config()
        checkpoint_path = self.checkpoint_dir / config['checkpoint']

        if checkpoint_path.exists():
            print(f"Checkpoint exists: {checkpoint_path}")
            return str(checkpoint_path)

        print(f"Downloading SAM2 {self.model_size} checkpoint...")
        import urllib.request

        # Download next to the target and rename once complete, so an
        # interrupted transfer never leaves a truncated file that a later
        # call would mistake for a valid checkpoint.
        partial_path = checkpoint_path.with_name(checkpoint_path.name + '.part')
        try:
            urllib.request.urlretrieve(config['url'], str(partial_path))
            os.replace(str(partial_path), str(checkpoint_path))
        finally:
            # Only still present when the download or rename failed.
            if partial_path.exists():
                partial_path.unlink()

        print(f"Downloaded to {checkpoint_path}")
        return str(checkpoint_path)

    def load_model(self):
        """Load the SAM2 model and build the predictor and mask generator.

        Downloads the checkpoint first when needed, then sets ``self.model``,
        ``self.predictor`` and ``self.mask_generator``.
        """
        # Imported lazily so the module can be imported without sam2.
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        checkpoint_path = self.download_checkpoint()
        config = self._model_config()

        print(f"Loading SAM2 {self.model_size} model on {self.device}...")
        self.model = build_sam2(
            config['config'],
            checkpoint_path,
            device=self.device
        )

        # Create predictor for interactive annotation
        self.predictor = SAM2ImagePredictor(self.model)

        # Create automatic mask generator
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=1.0,
            box_nms_thresh=0.7,
            crop_n_layers=1,
            crop_nms_thresh=0.7,
            crop_overlap_ratio=0.34,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100
        )
        print("Model loaded successfully!")

    def generate_masks_auto(
        self,
        image: np.ndarray,
        min_area: int = 100,
        max_area: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Automatically generate masks for all objects in image.

        Loads the model on first use.

        Args:
            image: Input image (BGR format from OpenCV)
            min_area: Minimum mask area in pixels
            max_area: Maximum mask area in pixels (None for no upper bound)

        Returns:
            List of mask dictionaries from the mask generator, filtered by
            area and sorted largest first. Keys used by this class:
                - segmentation: Binary mask
                - bbox: Bounding box [x, y, w, h]
                - area: Mask area
                - predicted_iou: Confidence score
                - stability_score: Stability score
        """
        if self.mask_generator is None:
            self.load_model()

        # The generator is fed RGB; OpenCV loads BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = self.mask_generator.generate(image_rgb)

        kept = [
            mask for mask in masks
            if mask['area'] >= min_area
            and (max_area is None or mask['area'] <= max_area)
        ]
        # Largest first; the sort is stable for equal areas.
        kept.sort(key=lambda m: m['area'], reverse=True)
        return kept

    def generate_masks_with_prompts(
        self,
        image: np.ndarray,
        points: Optional[List[Tuple[int, int]]] = None,
        point_labels: Optional[List[int]] = None,
        boxes: Optional[List[List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate masks using point or box prompts.

        Loads the model on first use. Prompts may be lists or NumPy arrays;
        None and empty prompts are both passed to the predictor as None.

        Args:
            image: Input image (BGR format)
            points: List of (x, y) point coordinates
            point_labels: Labels for points (1=foreground, 0=background)
            boxes: List of boxes [x1, y1, x2, y2]

        Returns:
            Tuple of (masks, scores, logits) as returned by the predictor.
        """
        if self.predictor is None:
            self.load_model()

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image_rgb)

        masks, scores, logits = self.predictor.predict(
            point_coords=self._array_or_none(points),
            point_labels=self._array_or_none(point_labels),
            box=self._array_or_none(boxes),
            multimask_output=True
        )
        return masks, scores, logits

    def annotate_frames(
        self,
        frames_dir: str,
        output_dir: str,
        min_area: int = 100,
        max_area: Optional[int] = None,
        save_visualizations: bool = True
    ) -> Dict[str, List[Dict]]:
        """
        Annotate all frames in a directory.

        For every ``*.jpg`` / ``*.png`` file in ``frames_dir`` this writes one
        PNG per mask to ``<output_dir>/masks/<frame stem>/``, optionally an
        overlay image to ``<output_dir>/visualizations/``, and finally
        ``<output_dir>/annotations.json`` with the per-frame annotation lists.
        Frames that OpenCV cannot read are reported and skipped.

        Args:
            frames_dir: Directory containing frame images
            output_dir: Directory to save annotations
            min_area: Minimum object area
            max_area: Maximum object area
            save_visualizations: Save annotated visualization images

        Returns:
            Dictionary mapping frame names to annotation lists

        Raises:
            ValueError: If no ``*.jpg`` or ``*.png`` frames are found.
        """
        frames_path = Path(frames_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        if save_visualizations:
            vis_path = output_path / 'visualizations'
            vis_path.mkdir(exist_ok=True)

        # Find all frame images
        frame_files = sorted(
            list(frames_path.glob("*.jpg")) +
            list(frames_path.glob("*.png"))
        )
        if not frame_files:
            raise ValueError(f"No frames found in {frames_dir}")
        print(f"Found {len(frame_files)} frames to annotate")

        all_annotations = {}
        for frame_file in tqdm(frame_files, desc="Annotating frames"):
            image = cv2.imread(str(frame_file))
            if image is None:
                # cv2.imread signals failure with None rather than raising.
                print(f"Warning: could not read {frame_file}, skipping")
                continue

            masks = self.generate_masks_auto(image, min_area, max_area)

            # One mask directory per frame, created once and only when there
            # is something to write into it.
            mask_dir = output_path / 'masks' / frame_file.stem
            if masks:
                mask_dir.mkdir(parents=True, exist_ok=True)

            annotations = []
            for i, mask_data in enumerate(masks):
                # Coerce every value to a native Python type so json.dump
                # does not fail on NumPy scalars.
                bbox = [
                    v.item() if isinstance(v, np.generic) else v
                    for v in mask_data['bbox']
                ]
                annotations.append({
                    'id': i,
                    'bbox': bbox,  # [x, y, w, h]
                    'area': int(mask_data['area']),
                    'predicted_iou': float(mask_data['predicted_iou']),
                    'stability_score': float(mask_data['stability_score'])
                })

                # Save the mask as a 0/255 single-channel PNG.
                mask_filename = f"{frame_file.stem}_mask_{i:03d}.png"
                cv2.imwrite(
                    str(mask_dir / mask_filename),
                    mask_data['segmentation'].astype(np.uint8) * 255
                )

            all_annotations[frame_file.name] = annotations

            if save_visualizations:
                vis_image = self._visualize_masks(image, masks)
                cv2.imwrite(str(vis_path / frame_file.name), vis_image)

        # Save annotations JSON
        annotations_file = output_path / 'annotations.json'
        with open(annotations_file, 'w', encoding='utf-8') as f:
            json.dump(all_annotations, f, indent=2)
        print(f"Annotations saved to {annotations_file}")
        return all_annotations

    def _visualize_masks(
        self,
        image: np.ndarray,
        masks: List[Dict],
        alpha: float = 0.5
    ) -> np.ndarray:
        """Create visualization of masks overlaid on image.

        Each mask is blended onto a copy of ``image`` in a random color with
        opacity ``alpha`` and outlined by its bounding box. The input image
        is not modified.
        """
        vis_image = image.copy()
        for mask_data in masks:
            # Boolean indexing is required below; a 0/1 integer mask would
            # otherwise be interpreted as row indices.
            mask = np.asarray(mask_data['segmentation']).astype(bool)
            color = np.random.randint(0, 255, 3).tolist()

            # Create colored overlay
            overlay = vis_image.copy()
            overlay[mask] = color
            vis_image = cv2.addWeighted(vis_image, 1 - alpha, overlay, alpha, 0)

            # Draw bounding box. OpenCV drawing calls need integer pixel
            # coordinates, and bbox values may arrive as floats.
            x, y, w, h = (int(round(v)) for v in mask_data['bbox'])
            cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)
        return vis_image
def load_sam2_video_predictor(
    model_size: str = 'large',
    checkpoint_dir: str = './checkpoints',
    device: Optional[str] = 'cuda'
):
    """
    Load SAM2 video predictor for tracking objects across frames.

    Downloads the checkpoint for ``model_size`` into ``checkpoint_dir`` when
    it is not already there, then builds the predictor from it.

    Args:
        model_size: Model size variant (a key of SAM2Annotator.MODEL_CONFIGS)
        checkpoint_dir: Checkpoint directory
        device: Device to use; None selects 'cuda' when available, else
            'cpu', the same rule SAM2Annotator applies

    Returns:
        The predictor returned by ``sam2.build_sam.build_sam2_video_predictor``

    Raises:
        KeyError: If ``model_size`` is not a key of MODEL_CONFIGS.
    """
    # Imported lazily so the module can be imported without sam2.
    from sam2.build_sam import build_sam2_video_predictor

    # The annotator is used only for its checkpoint download and for its
    # device resolution; no image model is loaded.
    annotator = SAM2Annotator(model_size, checkpoint_dir, device)
    checkpoint_path = annotator.download_checkpoint()
    config = SAM2Annotator.MODEL_CONFIGS[model_size]

    return build_sam2_video_predictor(
        config['config'],
        checkpoint_path,
        # Resolved device, so device=None behaves as it does in the class.
        device=annotator.device
    )