dataset-yolo-script/sam2-cpu/notebooks/01_sam2_video_annotation.ipynb
2026-02-04 15:29:36 +07:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SAM2 Video Annotation Pipeline\n",
"\n",
"This notebook sets up SAM2 (Segment Anything Model 2) for automatic object annotation from video.\n",
"\n",
"## Features\n",
"- Install SAM2 and dependencies on Kaggle\n",
"- Download pretrained model weights\n",
"- Extract frames from video\n",
"- Auto-generate masks for all objects\n",
"- Save annotations for YOLO conversion\n",
"\n",
"**Platform:** Kaggle GPU (P100/T4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup Environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check GPU availability\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install SAM2 and dependencies\n",
"!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git\n",
"!pip install -q opencv-python-headless supervision tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import cv2\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from tqdm.notebook import tqdm\n",
"from IPython.display import display, Image as IPImage\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
"    print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Download SAM2 Pretrained Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration\n",
"MODEL_SIZE = 'large'  # Options: 'tiny', 'small', 'base_plus', 'large'\n",
"CHECKPOINT_DIR = './checkpoints'\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# Model configurations\n",
"MODEL_CONFIGS = {\n",
"    'tiny': {\n",
"        'config': 'sam2_hiera_t.yaml',\n",
"        'checkpoint': 'sam2_hiera_tiny.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'\n",
"    },\n",
"    'small': {\n",
"        'config': 'sam2_hiera_s.yaml',\n",
"        'checkpoint': 'sam2_hiera_small.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'\n",
"    },\n",
"    'base_plus': {\n",
"        'config': 'sam2_hiera_b+.yaml',\n",
"        'checkpoint': 'sam2_hiera_base_plus.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'\n",
"    },\n",
"    'large': {\n",
"        'config': 'sam2_hiera_l.yaml',\n",
"        'checkpoint': 'sam2_hiera_large.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'\n",
"    }\n",
"}\n",
"\n",
"print(f\"Using model: SAM2 {MODEL_SIZE}\")\n",
"print(f\"Device: {DEVICE}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download checkpoint\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"config = MODEL_CONFIGS[MODEL_SIZE]\n",
"checkpoint_path = os.path.join(CHECKPOINT_DIR, config['checkpoint'])\n",
"\n",
"if not os.path.exists(checkpoint_path):\n",
"    print(f\"Downloading {config['checkpoint']}...\")\n",
"    !wget -q -O {checkpoint_path} {config['url']}\n",
"    print(\"Download complete!\")\n",
"else:\n",
"    print(f\"Checkpoint exists: {checkpoint_path}\")\n",
"\n",
"print(f\"Checkpoint size: {os.path.getsize(checkpoint_path) / 1024 / 1024:.1f} MB\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load SAM2 Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sam2.build_sam import build_sam2\n",
"from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
"\n",
"print(\"Loading SAM2 model...\")\n",
"\n",
"# Build model\n",
"sam2_model = build_sam2(\n",
"    config['config'],\n",
"    checkpoint_path,\n",
"    device=DEVICE\n",
")\n",
"\n",
"# Create automatic mask generator\n",
"mask_generator = SAM2AutomaticMaskGenerator(\n",
"    model=sam2_model,\n",
"    points_per_side=32,            # Grid density for point prompts (32 x 32 points)\n",
"    points_per_batch=64,           # Number of point prompts processed per batch\n",
"    pred_iou_thresh=0.7,           # Drop masks whose predicted IoU (quality score) is lower\n",
"    stability_score_thresh=0.92,   # Drop masks whose stability score is lower\n",
"    stability_score_offset=1.0,\n",
"    box_nms_thresh=0.7,            # Box IoU cutoff for NMS between duplicate masks\n",
"    crop_n_layers=1,               # Extra layers of image crops for multi-scale prompts\n",
"    crop_nms_thresh=0.7,\n",
"    crop_overlap_ratio=0.34,\n",
"    crop_n_points_downscale_factor=2,\n",
"    min_mask_region_area=100       # Minimum mask area in pixels\n",
")\n",
"\n",
"# Create predictor for interactive mode\n",
"predictor = SAM2ImagePredictor(sam2_model)\n",
"\n",
"print(\"Model loaded successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Video Frame Extraction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration - UPDATE THESE PATHS\n",
"VIDEO_PATH = '/kaggle/input/your-video-dataset/video.mp4' # Change to your video path\n",
"FRAMES_DIR = './frames'\n",
"OUTPUT_DIR = './annotations'\n",
"\n",
"# Frame extraction settings\n",
"SAMPLE_FPS = 2 # Extract 2 frames per second (adjust based on video)\n",
"MAX_FRAMES = 500 # Maximum frames to extract (None for all)\n",
"START_TIME = 0 # Start time in seconds\n",
"END_TIME = None # End time in seconds (None for entire video)\n",
"RESIZE = None # Resize frames: (width, height) or None\n",
"\n",
"os.makedirs(FRAMES_DIR, exist_ok=True)\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_video_info(video_path):\n",
"    \"\"\"Get video information.\"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    if not cap.isOpened():\n",
"        raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"    info = {\n",
"        'fps': cap.get(cv2.CAP_PROP_FPS),\n",
"        'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),\n",
"        'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n",
"        'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
"    }\n",
"    cap.release()\n",
"    # Some files report an FPS of 0; avoid dividing by zero in that case\n",
"    info['duration'] = info['frame_count'] / info['fps'] if info['fps'] > 0 else 0.0\n",
"    return info\n",
"\n",
"# Display video info\n",
"if os.path.exists(VIDEO_PATH):\n",
"    video_info = get_video_info(VIDEO_PATH)\n",
"    print(\"Video Information:\")\n",
"    print(f\"  Resolution: {video_info['width']}x{video_info['height']}\")\n",
"    print(f\"  FPS: {video_info['fps']:.2f}\")\n",
"    print(f\"  Duration: {video_info['duration']:.2f}s\")\n",
"    print(f\"  Total frames: {video_info['frame_count']}\")\n",
"else:\n",
"    print(f\"Video not found: {VIDEO_PATH}\")\n",
"    print(\"Please upload your video or update VIDEO_PATH\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def extract_frames(video_path, output_dir, sample_fps=None, max_frames=None,\n",
"                   start_time=0, end_time=None, resize=None):\n",
"    \"\"\"Extract frames from video.\"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    if not cap.isOpened():\n",
"        raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"    fps = cap.get(cv2.CAP_PROP_FPS)\n",
"    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
"    # Calculate frame interval\n",
"    if sample_fps and sample_fps < fps:\n",
"        frame_interval = int(fps / sample_fps)\n",
"    else:\n",
"        frame_interval = 1\n",
"\n",
"    # Calculate frame range\n",
"    start_frame = int(start_time * fps)\n",
"    # Compare against None so that end_time=0 is not mistaken for \"no limit\"\n",
"    end_frame = int(end_time * fps) if end_time is not None else frame_count\n",
"    end_frame = min(end_frame, frame_count)\n",
"\n",
"    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n",
"\n",
"    saved_paths = []\n",
"    frame_idx = start_frame\n",
"    extracted = 0\n",
"\n",
"    # Frames are saved at start_frame, start_frame + interval, ..., so round up\n",
"    span = max(end_frame - start_frame, 0)\n",
"    total_to_extract = (span + frame_interval - 1) // frame_interval\n",
"    if max_frames:\n",
"        total_to_extract = min(total_to_extract, max_frames)\n",
"\n",
"    pbar = tqdm(total=total_to_extract, desc=\"Extracting frames\")\n",
"\n",
"    while frame_idx < end_frame:\n",
"        if max_frames and extracted >= max_frames:\n",
"            break\n",
"\n",
"        ret, frame = cap.read()\n",
"        if not ret:\n",
"            break\n",
"\n",
"        if (frame_idx - start_frame) % frame_interval == 0:\n",
"            if resize:\n",
"                frame = cv2.resize(frame, resize)\n",
"\n",
"            frame_name = f\"frame_{frame_idx:06d}.jpg\"\n",
"            frame_path = os.path.join(output_dir, frame_name)\n",
"            cv2.imwrite(frame_path, frame)\n",
"            saved_paths.append(frame_path)\n",
"            extracted += 1\n",
"            pbar.update(1)\n",
"\n",
"        frame_idx += 1\n",
"\n",
"    pbar.close()\n",
"    cap.release()\n",
"\n",
"    print(f\"Extracted {len(saved_paths)} frames to {output_dir}\")\n",
"    return saved_paths\n",
"\n",
"# Extract frames\n",
"if os.path.exists(VIDEO_PATH):\n",
"    frame_paths = extract_frames(\n",
"        VIDEO_PATH,\n",
"        FRAMES_DIR,\n",
"        sample_fps=SAMPLE_FPS,\n",
"        max_frames=MAX_FRAMES,\n",
"        start_time=START_TIME,\n",
"        end_time=END_TIME,\n",
"        resize=RESIZE\n",
"    )"
]
},
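{
"cell_type": "markdown",
"metadata": {},
"source": [
"The extraction cell above keeps every `int(fps / sample_fps)`-th frame, so the effective sampling rate can end up slightly higher than `SAMPLE_FPS` when the source frame rate is not an exact multiple of it. A minimal sketch of that arithmetic follows; the helper name `effective_sampling` and the example frame rates are illustrative assumptions, not part of the pipeline.\n",
"\n",
"```python\n",
"def effective_sampling(fps, sample_fps):\n",
"    # Same interval rule as extract_frames: truncate fps / sample_fps to an integer\n",
"    interval = int(fps / sample_fps) if sample_fps and sample_fps < fps else 1\n",
"    return interval, fps / interval\n",
"\n",
"\n",
"# Hypothetical source frame rates, sampled at 2 frames per second\n",
"for fps in (30.0, 29.97, 25.0):\n",
"    interval, rate = effective_sampling(fps, 2)\n",
"    print(f'{fps} fps -> every {interval} frames, about {rate:.2f} frames/s')\n",
"```\n",
"\n",
"For a 25 fps source the interval is 12 frames, so roughly 2.08 frames per second are saved rather than exactly 2."
]
},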
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview extracted frames\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))[:6]\n",
"\n",
"if frame_files:\n",
"    fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
"    for ax, frame_file in zip(axes.flat, frame_files):\n",
"        img = cv2.imread(str(frame_file))\n",
"        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"        ax.imshow(img)\n",
"        ax.set_title(frame_file.name)\n",
"        ax.axis('off')\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"else:\n",
"    print(\"No frames found. Please extract frames first.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Automatic Mask Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Annotation settings\n",
"MIN_MASK_AREA = 500 # Minimum mask area in pixels\n",
"MAX_MASK_AREA = None # Maximum mask area (None for no limit)\n",
"SAVE_MASKS = True # Save individual mask images\n",
"SAVE_VIS = True # Save visualization images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def visualize_masks(image, masks, alpha=0.5):\n",
"    \"\"\"Create visualization with masks overlaid.\"\"\"\n",
"    vis = image.copy()\n",
"\n",
"    for mask_data in masks:\n",
"        mask = mask_data['segmentation']\n",
"        color = np.random.randint(0, 255, 3).tolist()\n",
"\n",
"        overlay = vis.copy()\n",
"        overlay[mask] = color\n",
"        vis = cv2.addWeighted(vis, 1 - alpha, overlay, alpha, 0)\n",
"\n",
"        # Draw bbox; cv2.rectangle needs integer pixel coordinates\n",
"        x, y, w, h = (int(v) for v in mask_data['bbox'])\n",
"        cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)\n",
"\n",
"    return vis\n",
"\n",
"def annotate_single_image(image_path, min_area=100, max_area=None):\n",
"    \"\"\"Generate masks for a single image.\"\"\"\n",
"    image = cv2.imread(image_path)\n",
"    if image is None:\n",
"        raise ValueError(f\"Cannot read image: {image_path}\")\n",
"    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"\n",
"    # Generate masks\n",
"    masks = mask_generator.generate(image_rgb)\n",
"\n",
"    # Filter by area\n",
"    filtered = []\n",
"    for m in masks:\n",
"        area = m['area']\n",
"        if area >= min_area:\n",
"            if max_area is None or area <= max_area:\n",
"                filtered.append(m)\n",
"\n",
"    # Sort by area (largest first)\n",
"    filtered.sort(key=lambda x: x['area'], reverse=True)\n",
"\n",
"    return filtered, image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test on a single frame\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))\n",
"\n",
"if frame_files:\n",
"    test_frame = str(frame_files[0])\n",
"    print(f\"Testing on: {test_frame}\")\n",
"\n",
"    masks, image = annotate_single_image(\n",
"        test_frame,\n",
"        min_area=MIN_MASK_AREA,\n",
"        max_area=MAX_MASK_AREA\n",
"    )\n",
"\n",
"    print(f\"Found {len(masks)} objects\")\n",
"\n",
"    # Visualize\n",
"    vis = visualize_masks(image, masks)\n",
"    vis_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)\n",
"\n",
"    fig, axes = plt.subplots(1, 2, figsize=(15, 7))\n",
"    axes[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
"    axes[0].set_title('Original')\n",
"    axes[0].axis('off')\n",
"\n",
"    axes[1].imshow(vis_rgb)\n",
"    axes[1].set_title(f'SAM2 Annotations ({len(masks)} objects)')\n",
"    axes[1].axis('off')\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"    # Show bbox details\n",
"    print(\"\\nDetected objects:\")\n",
"    for i, m in enumerate(masks[:10]):\n",
"        print(f\"  {i}: bbox={m['bbox']}, area={m['area']}, iou={m['predicted_iou']:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Process all frames\n",
"def annotate_all_frames(frames_dir, output_dir, min_area=100, max_area=None,\n",
"                        save_masks=True, save_vis=True):\n",
"    \"\"\"Annotate all frames and save results.\"\"\"\n",
"    frames_path = Path(frames_dir)\n",
"    output_path = Path(output_dir)\n",
"    output_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    if save_masks:\n",
"        masks_dir = output_path / 'masks'\n",
"        masks_dir.mkdir(exist_ok=True)\n",
"\n",
"    if save_vis:\n",
"        vis_dir = output_path / 'visualizations'\n",
"        vis_dir.mkdir(exist_ok=True)\n",
"\n",
"    frame_files = sorted(frames_path.glob(\"*.jpg\"))\n",
"    # Bail out early so the per-frame average below never divides by zero\n",
"    if not frame_files:\n",
"        print(f\"No frames found in {frames_path}. Please extract frames first.\")\n",
"        return {}\n",
"    print(f\"Processing {len(frame_files)} frames...\")\n",
"\n",
"    all_annotations = {}\n",
"    total_objects = 0\n",
"\n",
"    for frame_file in tqdm(frame_files, desc=\"Annotating\"):\n",
"        masks, image = annotate_single_image(\n",
"            str(frame_file),\n",
"            min_area=min_area,\n",
"            max_area=max_area\n",
"        )\n",
"\n",
"        # Convert to serializable format\n",
"        annotations = []\n",
"        for i, m in enumerate(masks):\n",
"            ann = {\n",
"                'id': i,\n",
"                'bbox': [int(x) for x in m['bbox']],  # [x, y, w, h]\n",
"                'area': int(m['area']),\n",
"                'predicted_iou': float(m['predicted_iou']),\n",
"                'stability_score': float(m['stability_score'])\n",
"            }\n",
"            annotations.append(ann)\n",
"\n",
"            # Save individual mask\n",
"            if save_masks:\n",
"                frame_masks_dir = masks_dir / frame_file.stem\n",
"                frame_masks_dir.mkdir(exist_ok=True)\n",
"                mask_path = frame_masks_dir / f\"mask_{i:03d}.png\"\n",
"                cv2.imwrite(\n",
"                    str(mask_path),\n",
"                    m['segmentation'].astype(np.uint8) * 255\n",
"                )\n",
"\n",
"        all_annotations[frame_file.name] = annotations\n",
"        total_objects += len(annotations)\n",
"\n",
"        # Save visualization\n",
"        if save_vis:\n",
"            vis = visualize_masks(image, masks)\n",
"            cv2.imwrite(str(vis_dir / frame_file.name), vis)\n",
"\n",
"    # Save annotations JSON\n",
"    annotations_file = output_path / 'annotations.json'\n",
"    with open(annotations_file, 'w') as f:\n",
"        json.dump(all_annotations, f, indent=2)\n",
"\n",
"    print(\"\\nAnnotation complete!\")\n",
"    print(f\"  Frames: {len(frame_files)}\")\n",
"    print(f\"  Total objects: {total_objects}\")\n",
"    print(f\"  Avg objects/frame: {total_objects/len(frame_files):.1f}\")\n",
"    print(f\"  Saved to: {output_path}\")\n",
"\n",
"    return all_annotations\n",
"\n",
"# Run annotation\n",
"annotations = annotate_all_frames(\n",
"    FRAMES_DIR,\n",
"    OUTPUT_DIR,\n",
"    min_area=MIN_MASK_AREA,\n",
"    max_area=MAX_MASK_AREA,\n",
"    save_masks=SAVE_MASKS,\n",
"    save_vis=SAVE_VIS\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Review Annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview some annotated frames\n",
"vis_dir = Path(OUTPUT_DIR) / 'visualizations'\n",
"vis_files = sorted(vis_dir.glob(\"*.jpg\"))[:6]\n",
"\n",
"if vis_files:\n",
"    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
"    for ax, vis_file in zip(axes.flat, vis_files):\n",
"        img = cv2.imread(str(vis_file))\n",
"        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"        ax.imshow(img)\n",
"\n",
"        # Get object count (0 if the frame has no entry in annotations)\n",
"        count = len(annotations.get(vis_file.name, []))\n",
"        ax.set_title(f\"{vis_file.name} ({count} objects)\")\n",
"        ax.axis('off')\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Statistics\n",
"object_counts = [len(ann) for ann in annotations.values()]\n",
"\n",
"# min() and max() raise on an empty list, so only summarize when frames were annotated\n",
"if object_counts:\n",
"    print(\"Annotation Statistics:\")\n",
"    print(f\"  Total frames: {len(annotations)}\")\n",
"    print(f\"  Total objects: {sum(object_counts)}\")\n",
"    print(f\"  Min objects/frame: {min(object_counts)}\")\n",
"    print(f\"  Max objects/frame: {max(object_counts)}\")\n",
"    print(f\"  Avg objects/frame: {np.mean(object_counts):.1f}\")\n",
"\n",
"    # Histogram\n",
"    plt.figure(figsize=(10, 4))\n",
"    plt.hist(object_counts, bins=30, edgecolor='black')\n",
"    plt.xlabel('Objects per frame')\n",
"    plt.ylabel('Frequency')\n",
"    plt.title('Distribution of Objects per Frame')\n",
"    plt.show()\n",
"else:\n",
"    print(\"No annotations to summarize.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Save Output for Next Notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summary of output files\n",
"print(\"Output files for next notebook:\")\n",
"print(f\"  Frames: {FRAMES_DIR}/\")\n",
"print(f\"  Annotations: {OUTPUT_DIR}/annotations.json\")\n",
"print(f\"  Masks: {OUTPUT_DIR}/masks/\")\n",
"print(f\"  Visualizations: {OUTPUT_DIR}/visualizations/\")\n",
"\n",
"# Check sizes\n",
"!du -sh {FRAMES_DIR} {OUTPUT_DIR}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: Create archive for download\n",
"ARCHIVE_NAME = 'sam2_annotations.zip'\n",
"\n",
"!zip -r {ARCHIVE_NAME} {FRAMES_DIR} {OUTPUT_DIR}/annotations.json {OUTPUT_DIR}/masks/\n",
"print(f\"\\nArchive created: {ARCHIVE_NAME}\")\n",
"!ls -lh {ARCHIVE_NAME}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"## Next Steps\n",
"\n",
"1. **Run `02_create_yolo_dataset.ipynb`** to convert annotations to YOLO format\n",
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t on the dataset\n",
"\n",
"### Tips\n",
"\n",
"- Adjust `MIN_MASK_AREA` to filter small detections\n",
"- Use `MAX_MASK_AREA` to filter very large objects (e.g., background)\n",
"- Lower `SAMPLE_FPS` for faster processing or if frames are similar\n",
"- Review visualizations to verify annotation quality"
]
}
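,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each entry in `annotations.json` stores `bbox` as `[x, y, w, h]` in pixels, while YOLO label files expect `class x_center y_center width height` with all four values normalized by the image size. As a rough sketch of that conversion (the function name `xywh_to_yolo`, the fixed `class_id=0`, and the example image size are assumptions made for illustration; SAM2 masks carry no class labels, and `02_create_yolo_dataset.ipynb` may handle this differently):\n",
"\n",
"```python\n",
"def xywh_to_yolo(bbox, img_w, img_h, class_id=0):\n",
"    # bbox is [x, y, w, h] in pixels, with (x, y) the top-left corner\n",
"    x, y, w, h = bbox\n",
"    x_center = (x + w / 2) / img_w\n",
"    y_center = (y + h / 2) / img_h\n",
"    return f'{class_id} {x_center:.6f} {y_center:.6f} {w / img_w:.6f} {h / img_h:.6f}'\n",
"\n",
"\n",
"# Hypothetical 200x100 box at (100, 50) in a 1000x500 frame\n",
"print(xywh_to_yolo([100, 50, 200, 100], img_w=1000, img_h=500))\n",
"```\n",
"\n",
"The image width and height must be the size of the saved frame, which differs from the source video resolution when `RESIZE` is set."
]
}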
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}