688 lines
22 KiB
Plaintext
688 lines
22 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# SAM2 Video Annotation Pipeline\n",
|
|
"\n",
|
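"This notebook uses SAM2 (Segment Anything Model 2) to automatically annotate objects in frames sampled from a video. Each frame is segmented independently with SAM2's automatic mask generator; objects are not tracked across frames.\n",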
|
|
"\n",
|
|
"## Features\n",
|
|
"- Install SAM2 and dependencies on Kaggle\n",
|
|
"- Download pretrained model weights\n",
|
|
"- Extract frames from video\n",
|
"- Auto-generate class-agnostic masks for every object SAM2 finds in each frame\n",
|
|
"- Save annotations for YOLO conversion\n",
|
|
"\n",
|
|
"**Platform:** Kaggle GPU (P100/T4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1. Setup Environment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Check GPU availability: list the NVIDIA GPU(s), driver and memory visible to this session\n",
|
|
"!nvidia-smi"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Install SAM2 and dependencies.\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"# NOTE(review): SAM2 is installed from the repository's default branch; pin a commit\n",
"# (git+https://...@<sha>) if runs must be reproducible.\n",
"%pip install -q git+https://github.com/facebookresearch/segment-anything-2.git\n",
"%pip install -q opencv-python-headless supervision tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"import json\n",
"import os\n",
"import sys\n",
"import urllib.request\n",
"from pathlib import Path\n",
"\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from IPython.display import display, Image as IPImage\n",
"from tqdm.notebook import tqdm\n",
|
|
"\n",
|
"# Report the runtime environment (one 'label: value' line each)\n",
"for label, value in (\n",
"    (\"Python\", sys.version),\n",
"    (\"PyTorch\", torch.__version__),\n",
"    (\"CUDA available\", torch.cuda.is_available()),\n",
"):\n",
"    print(f\"{label}: {value}\")\n",
"\n",
"# The GPU name can only be queried when CUDA is present\n",
"if torch.cuda.is_available():\n",
"    print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. Download SAM2 Pretrained Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Configuration\n",
"MODEL_SIZE = 'large'  # Options: 'tiny', 'small', 'base_plus', 'large'\n",
"CHECKPOINT_DIR = './checkpoints'\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# Model configurations: config name passed to build_sam2, checkpoint file name,\n",
"# and download URL for each SAM2 variant (072824 checkpoint release).\n",
"MODEL_CONFIGS = {\n",
"    'tiny': {\n",
"        'config': 'sam2_hiera_t.yaml',\n",
"        'checkpoint': 'sam2_hiera_tiny.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'\n",
"    },\n",
"    'small': {\n",
"        'config': 'sam2_hiera_s.yaml',\n",
"        'checkpoint': 'sam2_hiera_small.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'\n",
"    },\n",
"    'base_plus': {\n",
"        'config': 'sam2_hiera_b+.yaml',\n",
"        'checkpoint': 'sam2_hiera_base_plus.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'\n",
"    },\n",
"    'large': {\n",
"        'config': 'sam2_hiera_l.yaml',\n",
"        'checkpoint': 'sam2_hiera_large.pt',\n",
"        'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'\n",
"    }\n",
"}\n",
"\n",
"# Fail fast on a typo here instead of with a bare KeyError in the download cell.\n",
"if MODEL_SIZE not in MODEL_CONFIGS:\n",
"    raise ValueError(f\"Unknown MODEL_SIZE {MODEL_SIZE!r}; expected one of {list(MODEL_CONFIGS)}\")\n",
"\n",
"print(f\"Using model: SAM2 {MODEL_SIZE}\")\n",
"print(f\"Device: {DEVICE}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Download checkpoint\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"config = MODEL_CONFIGS[MODEL_SIZE]\n",
"checkpoint_path = os.path.join(CHECKPOINT_DIR, config['checkpoint'])\n",
"\n",
"# A zero-byte file is treated as missing: it is most likely left over from a failed download.\n",
"if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) == 0:\n",
"    print(f\"Downloading {config['checkpoint']}...\")\n",
"    # Download to a temporary name and rename only on success, so a failed or\n",
"    # interrupted transfer never leaves a truncated file that a re-run would accept.\n",
"    tmp_path = checkpoint_path + '.part'\n",
"    try:\n",
"        urllib.request.urlretrieve(config['url'], tmp_path)\n",
"        os.replace(tmp_path, checkpoint_path)\n",
"    finally:\n",
"        if os.path.exists(tmp_path):\n",
"            os.remove(tmp_path)\n",
"    print(\"Download complete!\")\n",
"else:\n",
"    print(f\"Checkpoint exists: {checkpoint_path}\")\n",
"\n",
"print(f\"Checkpoint size: {os.path.getsize(checkpoint_path) / 1024 / 1024:.1f} MB\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 3. Load SAM2 Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"from sam2.build_sam import build_sam2\n",
"from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
"\n",
"print(\"Loading SAM2 model...\")\n",
"\n",
"# Build model from the config name and checkpoint selected in section 2\n",
"sam2_model = build_sam2(\n",
"    config['config'],\n",
"    checkpoint_path,\n",
"    device=DEVICE\n",
")\n",
"\n",
"# Automatic mask generator: prompts the model with a grid of points and keeps\n",
"# the resulting masks that pass the quality / de-duplication filters below.\n",
"mask_generator = SAM2AutomaticMaskGenerator(\n",
"    model=sam2_model,\n",
"    points_per_side=32,                # Prompt grid of 32 x 32 points per image (and per crop)\n",
"    points_per_batch=64,               # Points run through the model at once (speed vs GPU memory)\n",
"    pred_iou_thresh=0.7,               # Drop masks whose model-predicted quality (IoU) is below this\n",
"    stability_score_thresh=0.92,       # Drop masks that change a lot when the binarization cutoff shifts\n",
"    stability_score_offset=1.0,        # Cutoff shift used when computing the stability score\n",
"    box_nms_thresh=0.7,                # Box-IoU cutoff for NMS that removes duplicate masks\n",
"    crop_n_layers=1,                   # Also run on one layer of image crops (multi-scale, slower)\n",
"    crop_nms_thresh=0.7,               # Box-IoU cutoff for NMS between masks from different crops\n",
"    crop_overlap_ratio=0.34,           # Fraction of the image length by which crops overlap\n",
"    crop_n_points_downscale_factor=2,  # Crop layer n uses points_per_side / 2**n points per side\n",
"    # NOT a per-mask size filter (that is MIN_MASK_AREA in section 5): post-processing\n",
"    # that removes disconnected regions and holes smaller than this many pixels.\n",
"    min_mask_region_area=100\n",
")\n",
"\n",
"# Predictor for interactive (prompted) mode; not used elsewhere in this notebook\n",
"predictor = SAM2ImagePredictor(sam2_model)\n",
"\n",
"print(\"Model loaded successfully!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 4. Video Frame Extraction"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Configuration - UPDATE THESE PATHS\n",
"VIDEO_PATH = '/kaggle/input/your-video-dataset/video.mp4'  # Change to your video path\n",
"FRAMES_DIR = './frames'       # Extracted frames are written here\n",
"OUTPUT_DIR = './annotations'  # annotations.json, masks/ and visualizations/ are written here\n",
"\n",
"# Frame extraction settings\n",
"SAMPLE_FPS = 2     # Target number of frames kept per second of video (adjust based on video)\n",
"MAX_FRAMES = 500   # Upper bound on extracted frames (None for all)\n",
"START_TIME = 0     # Start of the extracted range, in seconds\n",
"END_TIME = None    # End of the extracted range, in seconds (None = until the end of the video)\n",
"RESIZE = None      # (width, height) to resize frames to, or None to keep the original size\n",
"\n",
"for _out_dir in (FRAMES_DIR, OUTPUT_DIR):\n",
"    os.makedirs(_out_dir, exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def get_video_info(video_path):\n",
"    \"\"\"Return basic metadata for a video file.\n",
"\n",
"    Returns a dict with 'fps', 'frame_count', 'width' and 'height' as reported by\n",
"    cv2.VideoCapture, plus 'duration' in seconds (0.0 when the reported FPS is not\n",
"    positive).\n",
"\n",
"    Raises:\n",
"        ValueError: If OpenCV cannot open the video.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    try:\n",
"        if not cap.isOpened():\n",
"            raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"        info = {\n",
"            'fps': cap.get(cv2.CAP_PROP_FPS),\n",
"            'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),\n",
"            'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n",
"            'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
"        }\n",
"    finally:\n",
"        # Release the capture on every path, including the error path above.\n",
"        cap.release()\n",
"\n",
"    # OpenCV returns 0 for properties it cannot read; avoid ZeroDivisionError.\n",
"    info['duration'] = info['frame_count'] / info['fps'] if info['fps'] > 0 else 0.0\n",
"    return info\n",
"\n",
"# Display video info\n",
"if os.path.exists(VIDEO_PATH):\n",
"    video_info = get_video_info(VIDEO_PATH)\n",
"    print(\"Video Information:\")\n",
"    print(f\" Resolution: {video_info['width']}x{video_info['height']}\")\n",
"    print(f\" FPS: {video_info['fps']:.2f}\")\n",
"    print(f\" Duration: {video_info['duration']:.2f}s\")\n",
"    print(f\" Total frames: {video_info['frame_count']}\")\n",
"else:\n",
"    print(f\"Video not found: {VIDEO_PATH}\")\n",
"    print(\"Please upload your video or update VIDEO_PATH\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def extract_frames(video_path, output_dir, sample_fps=None, max_frames=None,\n",
"                   start_time=0, end_time=None, resize=None):\n",
"    \"\"\"Extract (optionally sub-sampled) frames from a video and save them as JPEGs.\n",
"\n",
"    Args:\n",
"        video_path: Path of the input video.\n",
"        output_dir: Directory the frames are written to (created if missing). Files are\n",
"            named frame_<index>.jpg after the frame's index in the video.\n",
"        sample_fps: Target sampling rate in frames per second. None (or a value >= the\n",
"            video FPS) keeps every frame.\n",
"        max_frames: Stop after this many frames have been saved (None = no limit).\n",
"        start_time: Start of the extracted range, in seconds.\n",
"        end_time: End of the extracted range, in seconds (None = end of video).\n",
"        resize: Optional (width, height) passed to cv2.resize.\n",
"\n",
"    Returns:\n",
"        List of the saved frame paths, in video order.\n",
"\n",
"    Raises:\n",
"        ValueError: If the video cannot be opened.\n",
"        IOError: If a frame cannot be written.\n",
"    \"\"\"\n",
"    os.makedirs(output_dir, exist_ok=True)\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    saved_paths = []\n",
"    try:\n",
"        if not cap.isOpened():\n",
"            raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"        fps = cap.get(cv2.CAP_PROP_FPS)\n",
"        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
"        # Keep every n-th frame. Round instead of truncating so that e.g. 29.97 fps\n",
"        # sampled at 2 fps keeps every 15th frame (~2.0 fps), not every 14th (~2.1 fps).\n",
"        if sample_fps and sample_fps < fps:\n",
"            frame_interval = max(1, round(fps / sample_fps))\n",
"        else:\n",
"            frame_interval = 1\n",
"\n",
"        # Frame range. Compare against None so an explicit 0 is not read as \"unset\".\n",
"        start_frame = int(start_time * fps)\n",
"        end_frame = int(end_time * fps) if end_time is not None else frame_count\n",
"        end_frame = min(end_frame, frame_count)\n",
"\n",
"        # Sampled indices in [start_frame, end_frame) number ceil(span / interval).\n",
"        span = max(0, end_frame - start_frame)\n",
"        total_to_extract = -(-span // frame_interval)\n",
"        if max_frames is not None:\n",
"            total_to_extract = min(total_to_extract, max_frames)\n",
"\n",
"        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n",
"\n",
"        with tqdm(total=total_to_extract, desc=\"Extracting frames\") as pbar:\n",
"            for frame_idx in range(start_frame, end_frame):\n",
"                if max_frames is not None and len(saved_paths) >= max_frames:\n",
"                    break\n",
"\n",
"                if (frame_idx - start_frame) % frame_interval != 0:\n",
"                    # Skipped frame: grab() advances the stream without decoding it.\n",
"                    if not cap.grab():\n",
"                        break\n",
"                    continue\n",
"\n",
"                ret, frame = cap.read()\n",
"                if not ret:\n",
"                    break\n",
"\n",
"                if resize:\n",
"                    frame = cv2.resize(frame, resize)\n",
"\n",
"                frame_path = os.path.join(output_dir, f\"frame_{frame_idx:06d}.jpg\")\n",
"                # cv2.imwrite reports failure through its return value, not an exception.\n",
"                if not cv2.imwrite(frame_path, frame):\n",
"                    raise IOError(f\"Failed to write frame: {frame_path}\")\n",
"                saved_paths.append(frame_path)\n",
"                pbar.update(1)\n",
"    finally:\n",
"        cap.release()\n",
"\n",
"    print(f\"Extracted {len(saved_paths)} frames to {output_dir}\")\n",
"    return saved_paths\n",
"\n",
"# Extract frames\n",
"if os.path.exists(VIDEO_PATH):\n",
"    frame_paths = extract_frames(\n",
"        VIDEO_PATH,\n",
"        FRAMES_DIR,\n",
"        sample_fps=SAMPLE_FPS,\n",
"        max_frames=MAX_FRAMES,\n",
"        start_time=START_TIME,\n",
"        end_time=END_TIME,\n",
"        resize=RESIZE\n",
"    )\n",
"else:\n",
"    print(f\"Video not found: {VIDEO_PATH} - skipping frame extraction\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Preview the first extracted frames\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))[:6]\n",
"\n",
"if frame_files:\n",
"    fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
"    for ax in axes.flat:\n",
"        ax.axis('off')  # also hides unused panels when fewer than 6 frames exist\n",
"    for ax, frame_file in zip(axes.flat, frame_files):\n",
"        img = cv2.imread(str(frame_file))\n",
"        if img is None:  # cv2.imread returns None for unreadable files\n",
"            ax.set_title(f\"{frame_file.name} (unreadable)\")\n",
"            continue\n",
"        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))\n",
"        ax.set_title(frame_file.name)\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"else:\n",
"    print(\"No frames found. Please extract frames first.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 5. Automatic Mask Generation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Annotation settings: per-mask filters applied after SAM2 generation, plus output switches\n",
"MIN_MASK_AREA, MAX_MASK_AREA = 500, None  # Keep masks with MIN <= area (pixels) <= MAX; None = no upper limit\n",
"SAVE_MASKS, SAVE_VIS = True, True         # Write per-object mask PNGs / overlay visualization images"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def visualize_masks(image, masks, alpha=0.5, seed=0):\n",
"    \"\"\"Return a copy of ``image`` with each mask tinted and its bounding box drawn.\n",
"\n",
"    Args:\n",
"        image: Image array (BGR when it comes from annotate_single_image).\n",
"        masks: Mask dicts from the automatic mask generator; uses ``'segmentation'``\n",
"            (used as a boolean pixel index) and ``'bbox'`` ([x, y, w, h]).\n",
"        alpha: Opacity of the mask tint (0 = invisible, 1 = solid color).\n",
"        seed: Seed for the random mask colors so re-runs give the same colors;\n",
"            pass None for different colors on every call.\n",
"\n",
"    Returns:\n",
"        The annotated copy; ``image`` itself is not modified.\n",
"    \"\"\"\n",
"    vis = image.copy()\n",
"    rng = np.random.default_rng(seed)\n",
"\n",
"    for mask_data in masks:\n",
"        mask = mask_data['segmentation']\n",
"        color = rng.integers(0, 256, 3)\n",
"\n",
"        # Blend only the masked pixels instead of re-blending the whole image per mask.\n",
"        blended = vis[mask] * (1 - alpha) + color * alpha\n",
"        vis[mask] = np.clip(np.rint(blended), 0, 255).astype(vis.dtype)\n",
"\n",
"        # Draw bbox (cv2.rectangle needs integer pixel coordinates)\n",
"        x, y, w, h = (int(round(v)) for v in mask_data['bbox'])\n",
"        cv2.rectangle(vis, (x, y), (x + w, y + h), color.tolist(), 2)\n",
"\n",
"    return vis\n",
"\n",
"def annotate_single_image(image_path, min_area=100, max_area=None):\n",
"    \"\"\"Run the SAM2 automatic mask generator on one image file and filter by area.\n",
"\n",
"    Uses the module-level ``mask_generator`` built in section 3.\n",
"\n",
"    Args:\n",
"        image_path: Path of the image to annotate.\n",
"        min_area: Keep masks with area >= min_area (pixels).\n",
"        max_area: Keep masks with area <= max_area; None = no upper limit.\n",
"\n",
"    Returns:\n",
"        (masks, image): the kept mask dicts sorted by area (largest first), and the\n",
"        image as loaded by cv2.imread (BGR).\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: If the image cannot be read.\n",
"    \"\"\"\n",
"    image = cv2.imread(image_path)\n",
"    if image is None:\n",
"        # cv2.imread signals failure by returning None rather than raising.\n",
"        raise FileNotFoundError(f\"Cannot read image: {image_path}\")\n",
"    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # the generator is fed RGB\n",
"\n",
"    # Generate masks\n",
"    masks = mask_generator.generate(image_rgb)\n",
"\n",
"    # Filter by area, then sort largest first\n",
"    filtered = [\n",
"        m for m in masks\n",
"        if m['area'] >= min_area and (max_area is None or m['area'] <= max_area)\n",
"    ]\n",
"    filtered.sort(key=lambda m: m['area'], reverse=True)\n",
"\n",
"    return filtered, image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Test on a single frame\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))\n",
"\n",
"if frame_files:\n",
"    test_frame = str(frame_files[0])\n",
"    print(f\"Testing on: {test_frame}\")\n",
"\n",
"    masks, image = annotate_single_image(\n",
"        test_frame,\n",
"        min_area=MIN_MASK_AREA,\n",
"        max_area=MAX_MASK_AREA\n",
"    )\n",
"    print(f\"Found {len(masks)} objects\")\n",
"\n",
"    # Visualize: original next to the mask overlay\n",
"    vis = visualize_masks(image, masks)\n",
"    fig, axes = plt.subplots(1, 2, figsize=(15, 7))\n",
"    panels = [(image, 'Original'), (vis, f'SAM2 Annotations ({len(masks)} objects)')]\n",
"    for ax, (img_bgr, title) in zip(axes, panels):\n",
"        ax.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))  # OpenCV images are BGR\n",
"        ax.set_title(title)\n",
"        ax.axis('off')\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"    # Show bbox details for the first 10 (largest) objects\n",
"    print(\"\\nDetected objects:\")\n",
"    for i, m in enumerate(masks[:10]):\n",
"        print(f\" {i}: bbox={m['bbox']}, area={m['area']}, iou={m['predicted_iou']:.3f}\")\n",
"else:\n",
"    print(\"No frames found. Please extract frames first.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Process all frames\n",
"def annotate_all_frames(frames_dir, output_dir, min_area=100, max_area=None,\n",
"                        save_masks=True, save_vis=True):\n",
"    \"\"\"Annotate every ``*.jpg`` in ``frames_dir`` with SAM2 and save the results.\n",
"\n",
"    Writes to ``output_dir``:\n",
"      * annotations.json: {frame_file_name: [annotation, ...]}; each annotation has\n",
"        id, bbox ([x, y, w, h] ints), area, predicted_iou and stability_score.\n",
"      * masks/<frame_stem>/mask_<id>.png: one 0/255 mask image per object (if save_masks).\n",
"      * visualizations/<frame_file_name>: overlay image (if save_vis).\n",
"\n",
"    Args:\n",
"        frames_dir: Directory containing the extracted frames.\n",
"        output_dir: Output directory (created if missing).\n",
"        min_area, max_area: Mask-area filter forwarded to annotate_single_image.\n",
"        save_masks: Save one mask PNG per object.\n",
"        save_vis: Save overlay visualizations.\n",
"\n",
"    Returns:\n",
"        The annotations dict that was written to annotations.json.\n",
"    \"\"\"\n",
"    frames_path = Path(frames_dir)\n",
"    output_path = Path(output_dir)\n",
"    output_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    masks_dir = output_path / 'masks'\n",
"    vis_dir = output_path / 'visualizations'\n",
"    if save_masks:\n",
"        masks_dir.mkdir(exist_ok=True)\n",
"    if save_vis:\n",
"        vis_dir.mkdir(exist_ok=True)\n",
"\n",
"    frame_files = sorted(frames_path.glob(\"*.jpg\"))\n",
"    print(f\"Processing {len(frame_files)} frames...\")\n",
"\n",
"    all_annotations = {}\n",
"    total_objects = 0\n",
"\n",
"    for frame_file in tqdm(frame_files, desc=\"Annotating\"):\n",
"        masks, image = annotate_single_image(\n",
"            str(frame_file),\n",
"            min_area=min_area,\n",
"            max_area=max_area\n",
"        )\n",
"\n",
"        # Convert to serializable format (plain Python ints/floats for json.dump)\n",
"        frame_annotations = [\n",
"            {\n",
"                'id': i,\n",
"                'bbox': [int(x) for x in m['bbox']],  # [x, y, w, h]\n",
"                'area': int(m['area']),\n",
"                'predicted_iou': float(m['predicted_iou']),\n",
"                'stability_score': float(m['stability_score'])\n",
"            }\n",
"            for i, m in enumerate(masks)\n",
"        ]\n",
"        all_annotations[frame_file.name] = frame_annotations\n",
"        total_objects += len(frame_annotations)\n",
"\n",
"        # Save individual masks; the per-frame directory is created once, not once per mask\n",
"        if save_masks and masks:\n",
"            frame_masks_dir = masks_dir / frame_file.stem\n",
"            frame_masks_dir.mkdir(exist_ok=True)\n",
"            for i, m in enumerate(masks):\n",
"                mask_path = frame_masks_dir / f\"mask_{i:03d}.png\"\n",
"                cv2.imwrite(str(mask_path), m['segmentation'].astype(np.uint8) * 255)\n",
"\n",
"        # Save visualization\n",
"        if save_vis:\n",
"            vis = visualize_masks(image, masks)\n",
"            cv2.imwrite(str(vis_dir / frame_file.name), vis)\n",
"\n",
"    # Save annotations JSON\n",
"    annotations_file = output_path / 'annotations.json'\n",
"    with open(annotations_file, 'w') as f:\n",
"        json.dump(all_annotations, f, indent=2)\n",
"\n",
"    # Guard the average so an empty frames_dir does not raise ZeroDivisionError\n",
"    avg_objects = total_objects / len(frame_files) if frame_files else 0.0\n",
"    print(\"\\nAnnotation complete!\")\n",
"    print(f\" Frames: {len(frame_files)}\")\n",
"    print(f\" Total objects: {total_objects}\")\n",
"    print(f\" Avg objects/frame: {avg_objects:.1f}\")\n",
"    print(f\" Saved to: {output_path}\")\n",
"\n",
"    return all_annotations\n",
"\n",
"# Run annotation\n",
"annotations = annotate_all_frames(\n",
"    FRAMES_DIR,\n",
"    OUTPUT_DIR,\n",
"    min_area=MIN_MASK_AREA,\n",
"    max_area=MAX_MASK_AREA,\n",
"    save_masks=SAVE_MASKS,\n",
"    save_vis=SAVE_VIS\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 6. Review Annotations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Preview some annotated frames\n",
"vis_dir = Path(OUTPUT_DIR) / 'visualizations'\n",
"vis_files = sorted(vis_dir.glob(\"*.jpg\"))[:6]\n",
"\n",
"if vis_files:\n",
"    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
"    for ax in axes.flat:\n",
"        ax.axis('off')  # also hides unused panels when fewer than 6 images exist\n",
"    for ax, vis_file in zip(axes.flat, vis_files):\n",
"        img = cv2.imread(str(vis_file))\n",
"        if img is None:  # cv2.imread returns None for unreadable files\n",
"            ax.set_title(f\"{vis_file.name} (unreadable)\")\n",
"            continue\n",
"        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))\n",
"\n",
"        # Visualizations are saved under the frame's file name, which keys `annotations`\n",
"        count = len(annotations.get(vis_file.name, []))\n",
"        ax.set_title(f\"{vis_file.name} ({count} objects)\")\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"else:\n",
"    print(f\"No visualizations found in {vis_dir}. Run the annotation cell with SAVE_VIS = True.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Statistics\n",
"object_counts = [len(ann) for ann in annotations.values()]\n",
"\n",
"if object_counts:\n",
"    print(\"Annotation Statistics:\")\n",
"    print(f\" Total frames: {len(annotations)}\")\n",
"    print(f\" Total objects: {sum(object_counts)}\")\n",
"    print(f\" Min objects/frame: {min(object_counts)}\")\n",
"    print(f\" Max objects/frame: {max(object_counts)}\")\n",
"    print(f\" Avg objects/frame: {np.mean(object_counts):.1f}\")\n",
"\n",
"    # Histogram\n",
"    fig, ax = plt.subplots(figsize=(10, 4))\n",
"    ax.hist(object_counts, bins=30, edgecolor='black')\n",
"    ax.set(xlabel='Objects per frame', ylabel='Frequency',\n",
"           title='Distribution of Objects per Frame')\n",
"    plt.show()\n",
"else:\n",
"    # min()/max() of an empty list would raise ValueError\n",
"    print(\"No annotated frames - nothing to summarize.\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 7. Save Output for Next Notebook"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Summary of output files\n",
"print(\"Output files for next notebook:\")\n",
"print(f\" Frames: {FRAMES_DIR}/\")\n",
"print(f\" Annotations: {OUTPUT_DIR}/annotations.json\")\n",
"print(f\" Masks: {OUTPUT_DIR}/masks/\")\n",
"print(f\" Visualizations: {OUTPUT_DIR}/visualizations/\")\n",
"\n",
"# Check sizes on disk\n",
"!du -sh {FRAMES_DIR} {OUTPUT_DIR}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Optional: Create archive for download\n",
"ARCHIVE_NAME = 'sam2_annotations.zip'\n",
"\n",
"# `zip -r` updates an existing archive in place, so entries from a previous run would\n",
"# linger; start from a fresh file. -q avoids printing one line per file (one PNG per object).\n",
"!rm -f {ARCHIVE_NAME}\n",
"!zip -r -q {ARCHIVE_NAME} {FRAMES_DIR} {OUTPUT_DIR}/annotations.json {OUTPUT_DIR}/masks/\n",
"print(f\"\\nArchive created: {ARCHIVE_NAME}\")\n",
"!ls -lh {ARCHIVE_NAME}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"\n",
|
|
"## Next Steps\n",
|
|
"\n",
|
|
"1. **Run `02_create_yolo_dataset.ipynb`** to convert annotations to YOLO format\n",
|
|
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t on the dataset\n",
|
|
"\n",
|
|
"### Tips\n",
|
|
"\n",
|
|
"- Adjust `MIN_MASK_AREA` to filter small detections\n",
|
|
"- Use `MAX_MASK_AREA` to filter very large objects (e.g., background)\n",
|
|
"- Lower `SAMPLE_FPS` for faster processing or if frames are similar\n",
|
|
"- Review visualizations to verify annotation quality"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "3.10.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|