add sam2 yolo auto annotation
This commit is contained in:
@@ -0,0 +1,687 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SAM2 Video Annotation Pipeline\n",
|
||||
"\n",
|
||||
"This notebook sets up SAM2 (Segment Anything Model 2) for automatic object annotation from video.\n",
|
||||
"\n",
|
||||
"## Features\n",
|
||||
"- Install SAM2 and dependencies on Kaggle\n",
|
||||
"- Download pretrained model weights\n",
|
||||
"- Extract frames from video\n",
|
||||
"- Auto-generate masks for all objects\n",
|
||||
"- Save annotations for YOLO conversion\n",
|
||||
"\n",
|
||||
"**Platform:** Kaggle GPU (P100/T4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Setup Environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check GPU availability\n",
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install SAM2 and dependencies\n",
|
||||
"!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git\n",
|
||||
"!pip install -q opencv-python-headless supervision tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import cv2\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"from IPython.display import display, Image as IPImage\n",
|
||||
"\n",
|
||||
"print(f\"Python: {sys.version}\")\n",
|
||||
"print(f\"PyTorch: {torch.__version__}\")\n",
|
||||
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Download SAM2 Pretrained Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration\n",
|
||||
"MODEL_SIZE = 'large' # Options: 'tiny', 'small', 'base_plus', 'large'\n",
|
||||
"CHECKPOINT_DIR = './checkpoints'\n",
|
||||
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"\n",
|
||||
"# Model configurations\n",
|
||||
"MODEL_CONFIGS = {\n",
|
||||
" 'tiny': {\n",
|
||||
" 'config': 'sam2_hiera_t.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_tiny.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'\n",
|
||||
" },\n",
|
||||
" 'small': {\n",
|
||||
" 'config': 'sam2_hiera_s.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_small.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'\n",
|
||||
" },\n",
|
||||
" 'base_plus': {\n",
|
||||
" 'config': 'sam2_hiera_b+.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_base_plus.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'\n",
|
||||
" },\n",
|
||||
" 'large': {\n",
|
||||
" 'config': 'sam2_hiera_l.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_large.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(f\"Using model: SAM2 {MODEL_SIZE}\")\n",
|
||||
"print(f\"Device: {DEVICE}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Download checkpoint\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"config = MODEL_CONFIGS[MODEL_SIZE]\n",
"checkpoint_path = os.path.join(CHECKPOINT_DIR, config['checkpoint'])\n",
"\n",
"if not os.path.exists(checkpoint_path):\n",
"    print(f\"Downloading {config['checkpoint']}...\")\n",
"    # Download under a temporary name first: `wget -O` creates the output file even\n",
"    # when the request fails, so writing straight to checkpoint_path would make the\n",
"    # exists() check above skip the download on the next run.\n",
"    partial_path = checkpoint_path + '.part'\n",
"    !wget -q -O {partial_path} {config['url']}\n",
"    # IPython stores the status of the last `!` command in `_exit_code`.\n",
"    download_failed = (\n",
"        _exit_code != 0\n",
"        or not os.path.exists(partial_path)\n",
"        or os.path.getsize(partial_path) == 0\n",
"    )\n",
"    if download_failed:\n",
"        raise RuntimeError(f\"Failed to download checkpoint from {config['url']}\")\n",
"    os.replace(partial_path, checkpoint_path)\n",
"    print(\"Download complete!\")\n",
"else:\n",
"    print(f\"Checkpoint exists: {checkpoint_path}\")\n",
"\n",
"print(f\"Checkpoint size: {os.path.getsize(checkpoint_path) / 1024 / 1024:.1f} MB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Load SAM2 Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sam2.build_sam import build_sam2\n",
|
||||
"from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
|
||||
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
|
||||
"\n",
|
||||
"print(\"Loading SAM2 model...\")\n",
|
||||
"\n",
|
||||
"# Build model\n",
|
||||
"sam2_model = build_sam2(\n",
|
||||
" config['config'],\n",
|
||||
" checkpoint_path,\n",
|
||||
" device=DEVICE\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create automatic mask generator\n",
|
||||
"mask_generator = SAM2AutomaticMaskGenerator(\n",
|
||||
" model=sam2_model,\n",
|
||||
" points_per_side=32, # Grid density for point prompts\n",
|
||||
" points_per_batch=64, # Batch size for processing\n",
|
||||
" pred_iou_thresh=0.7, # IoU threshold for predictions\n",
|
||||
" stability_score_thresh=0.92, # Stability threshold\n",
|
||||
" stability_score_offset=1.0,\n",
|
||||
" box_nms_thresh=0.7, # NMS threshold for boxes\n",
|
||||
" crop_n_layers=1, # Crop layers for multi-scale\n",
|
||||
" crop_nms_thresh=0.7,\n",
|
||||
" crop_overlap_ratio=0.34,\n",
|
||||
" crop_n_points_downscale_factor=2,\n",
|
||||
" min_mask_region_area=100 # Minimum mask area in pixels\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create predictor for interactive mode\n",
|
||||
"predictor = SAM2ImagePredictor(sam2_model)\n",
|
||||
"\n",
|
||||
"print(\"Model loaded successfully!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Video Frame Extraction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration - UPDATE THESE PATHS\n",
|
||||
"VIDEO_PATH = '/kaggle/input/your-video-dataset/video.mp4' # Change to your video path\n",
|
||||
"FRAMES_DIR = './frames'\n",
|
||||
"OUTPUT_DIR = './annotations'\n",
|
||||
"\n",
|
||||
"# Frame extraction settings\n",
|
||||
"SAMPLE_FPS = 2 # Extract 2 frames per second (adjust based on video)\n",
|
||||
"MAX_FRAMES = 500 # Maximum frames to extract (None for all)\n",
|
||||
"START_TIME = 0 # Start time in seconds\n",
|
||||
"END_TIME = None # End time in seconds (None for entire video)\n",
|
||||
"RESIZE = None # Resize frames: (width, height) or None\n",
|
||||
"\n",
|
||||
"os.makedirs(FRAMES_DIR, exist_ok=True)\n",
|
||||
"os.makedirs(OUTPUT_DIR, exist_ok=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_video_info(video_path):\n",
"    \"\"\"Read basic properties of a video file.\n",
"\n",
"    Args:\n",
"        video_path: Path of the video to open with OpenCV.\n",
"\n",
"    Returns:\n",
"        dict with 'fps', 'frame_count', 'width', 'height' and 'duration'\n",
"        (seconds; 0.0 when the reported frame rate is not positive).\n",
"\n",
"    Raises:\n",
"        ValueError: If the video cannot be opened.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    try:\n",
"        if not cap.isOpened():\n",
"            raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"        info = {\n",
"            'fps': cap.get(cv2.CAP_PROP_FPS),\n",
"            'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),\n",
"            'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n",
"            'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
"        }\n",
"    finally:\n",
"        # Release the capture on the error path as well\n",
"        cap.release()\n",
"\n",
"    # A reported FPS of 0 would otherwise raise ZeroDivisionError\n",
"    fps = info['fps']\n",
"    info['duration'] = info['frame_count'] / fps if fps > 0 else 0.0\n",
"    return info\n",
|
||||
"\n",
|
||||
"# Display video info\n",
|
||||
"if os.path.exists(VIDEO_PATH):\n",
|
||||
" video_info = get_video_info(VIDEO_PATH)\n",
|
||||
" print(\"Video Information:\")\n",
|
||||
" print(f\" Resolution: {video_info['width']}x{video_info['height']}\")\n",
|
||||
" print(f\" FPS: {video_info['fps']:.2f}\")\n",
|
||||
" print(f\" Duration: {video_info['duration']:.2f}s\")\n",
|
||||
" print(f\" Total frames: {video_info['frame_count']}\")\n",
|
||||
"else:\n",
|
||||
" print(f\"Video not found: {VIDEO_PATH}\")\n",
|
||||
" print(\"Please upload your video or update VIDEO_PATH\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_frames(video_path, output_dir, sample_fps=None, max_frames=None,\n",
"                   start_time=0, end_time=None, resize=None):\n",
"    \"\"\"Extract frames from a video and save them as JPEG files.\n",
"\n",
"    Args:\n",
"        video_path: Path of the video to read.\n",
"        output_dir: Directory that receives `frame_XXXXXX.jpg` files, named\n",
"            after the source frame index. It is not created here.\n",
"        sample_fps: Target sampling rate. When lower than the video FPS, every\n",
"            int(fps / sample_fps)-th frame is kept; otherwise every frame.\n",
"        max_frames: Stop after saving this many frames (falsy means no limit).\n",
"        start_time: Offset in seconds at which extraction starts.\n",
"        end_time: Offset in seconds at which extraction stops (falsy means\n",
"            the end of the video).\n",
"        resize: Optional size passed to cv2.resize for each kept frame.\n",
"\n",
"    Returns:\n",
"        List of paths of the frames that were written.\n",
"\n",
"    Raises:\n",
"        ValueError: If the video cannot be opened.\n",
"        OSError: If a frame cannot be written to output_dir.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    if not cap.isOpened():\n",
"        raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"    saved_paths = []\n",
"    pbar = None\n",
"    try:\n",
"        fps = cap.get(cv2.CAP_PROP_FPS)\n",
"        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
"        # Calculate frame interval\n",
"        if sample_fps and sample_fps < fps:\n",
"            frame_interval = int(fps / sample_fps)\n",
"        else:\n",
"            frame_interval = 1\n",
"\n",
"        # Calculate frame range\n",
"        start_frame = int(start_time * fps)\n",
"        end_frame = int(end_time * fps) if end_time else frame_count\n",
"        end_frame = min(end_frame, frame_count)\n",
"\n",
"        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n",
"\n",
"        # Frames at offsets 0, k, 2k, ... inside the range are kept, so the expected\n",
"        # count is ceil(span / k); clamp the span at 0 when start lies past the end.\n",
"        span = max(end_frame - start_frame, 0)\n",
"        total_to_extract = -(-span // frame_interval)\n",
"        if max_frames:\n",
"            total_to_extract = min(total_to_extract, max_frames)\n",
"\n",
"        pbar = tqdm(total=int(total_to_extract), desc=\"Extracting frames\")\n",
"\n",
"        frame_idx = start_frame\n",
"        while frame_idx < end_frame:\n",
"            if max_frames and len(saved_paths) >= max_frames:\n",
"                break\n",
"\n",
"            ret, frame = cap.read()\n",
"            if not ret:\n",
"                break\n",
"\n",
"            if (frame_idx - start_frame) % frame_interval == 0:\n",
"                if resize:\n",
"                    frame = cv2.resize(frame, resize)\n",
"\n",
"                frame_name = f\"frame_{frame_idx:06d}.jpg\"\n",
"                frame_path = os.path.join(output_dir, frame_name)\n",
"                # cv2.imwrite signals failure through its return value; do not\n",
"                # report a path that was never written.\n",
"                if not cv2.imwrite(frame_path, frame):\n",
"                    raise OSError(f\"Failed to write frame: {frame_path}\")\n",
"                saved_paths.append(frame_path)\n",
"                pbar.update(1)\n",
"\n",
"            frame_idx += 1\n",
"    finally:\n",
"        # Always free the progress bar and the capture, even if a frame fails\n",
"        if pbar is not None:\n",
"            pbar.close()\n",
"        cap.release()\n",
"\n",
"    print(f\"Extracted {len(saved_paths)} frames to {output_dir}\")\n",
"    return saved_paths\n",
|
||||
"\n",
|
||||
"# Extract frames\n",
|
||||
"if os.path.exists(VIDEO_PATH):\n",
|
||||
" frame_paths = extract_frames(\n",
|
||||
" VIDEO_PATH,\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" sample_fps=SAMPLE_FPS,\n",
|
||||
" max_frames=MAX_FRAMES,\n",
|
||||
" start_time=START_TIME,\n",
|
||||
" end_time=END_TIME,\n",
|
||||
" resize=RESIZE\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preview extracted frames\n",
|
||||
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))[:6]\n",
|
||||
"\n",
|
||||
"if frame_files:\n",
|
||||
" fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
|
||||
" for ax, frame_file in zip(axes.flat, frame_files):\n",
|
||||
" img = cv2.imread(str(frame_file))\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" ax.imshow(img)\n",
|
||||
" ax.set_title(frame_file.name)\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
"else:\n",
|
||||
" print(\"No frames found. Please extract frames first.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Automatic Mask Generation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Annotation settings\n",
|
||||
"MIN_MASK_AREA = 500 # Minimum mask area in pixels\n",
|
||||
"MAX_MASK_AREA = None # Maximum mask area (None for no limit)\n",
|
||||
"SAVE_MASKS = True # Save individual mask images\n",
|
||||
"SAVE_VIS = True # Save visualization images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def visualize_masks(image, masks, alpha=0.5):\n",
"    \"\"\"Overlay every mask and its bounding box on a copy of the image.\n",
"\n",
"    Args:\n",
"        image: Image array to draw on; it is copied, not modified.\n",
"        masks: Iterable of dicts with a 'segmentation' array used to index the\n",
"            image (presumably boolean) and an [x, y, w, h] 'bbox'.\n",
"        alpha: Blend weight of the mask colour.\n",
"\n",
"    Returns:\n",
"        A new image with each mask tinted in a random colour and outlined.\n",
"    \"\"\"\n",
"    vis = image.copy()\n",
"\n",
"    for mask_data in masks:\n",
"        mask = mask_data['segmentation']\n",
"        color = np.random.randint(0, 255, 3).tolist()\n",
"\n",
"        # Paint the mask on a copy, then blend so the tint stays translucent\n",
"        overlay = vis.copy()\n",
"        overlay[mask] = color\n",
"        vis = cv2.addWeighted(vis, 1 - alpha, overlay, alpha, 0)\n",
"\n",
"        # Draw bbox; cv2.rectangle needs integer coordinates, so cast in case the\n",
"        # generator returned floats\n",
"        x, y, w, h = (int(v) for v in mask_data['bbox'])\n",
"        cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)\n",
"\n",
"    return vis\n",
|
||||
"\n",
|
||||
"def annotate_single_image(image_path, min_area=100, max_area=None):\n",
"    \"\"\"Generate masks for one image and keep those within the area limits.\n",
"\n",
"    Uses the notebook-level `mask_generator`.\n",
"\n",
"    Args:\n",
"        image_path: Path of the image file to annotate.\n",
"        min_area: Masks whose 'area' is below this value are dropped.\n",
"        max_area: Masks whose 'area' is above this value are dropped\n",
"            (None disables the upper limit).\n",
"\n",
"    Returns:\n",
"        (masks, image): the kept mask dicts sorted by area, largest first, and\n",
"        the image as loaded by cv2.imread.\n",
"\n",
"    Raises:\n",
"        ValueError: If the image cannot be read.\n",
"    \"\"\"\n",
"    image = cv2.imread(image_path)\n",
"    # cv2.imread returns None instead of raising for a missing or corrupt file\n",
"    if image is None:\n",
"        raise ValueError(f\"Cannot read image: {image_path}\")\n",
"    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"\n",
"    # Generate masks\n",
"    masks = mask_generator.generate(image_rgb)\n",
"\n",
"    # Filter by area\n",
"    filtered = [\n",
"        m for m in masks\n",
"        if m['area'] >= min_area and (max_area is None or m['area'] <= max_area)\n",
"    ]\n",
"\n",
"    # Sort by area (largest first)\n",
"    filtered.sort(key=lambda m: m['area'], reverse=True)\n",
"\n",
"    return filtered, image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test on a single frame\n",
|
||||
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))\n",
|
||||
"\n",
|
||||
"if frame_files:\n",
|
||||
" test_frame = str(frame_files[0])\n",
|
||||
" print(f\"Testing on: {test_frame}\")\n",
|
||||
" \n",
|
||||
" masks, image = annotate_single_image(\n",
|
||||
" test_frame,\n",
|
||||
" min_area=MIN_MASK_AREA,\n",
|
||||
" max_area=MAX_MASK_AREA\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" print(f\"Found {len(masks)} objects\")\n",
|
||||
" \n",
|
||||
" # Visualize\n",
|
||||
" vis = visualize_masks(image, masks)\n",
|
||||
" vis_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)\n",
|
||||
" \n",
|
||||
" fig, axes = plt.subplots(1, 2, figsize=(15, 7))\n",
|
||||
" axes[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
|
||||
" axes[0].set_title('Original')\n",
|
||||
" axes[0].axis('off')\n",
|
||||
" \n",
|
||||
" axes[1].imshow(vis_rgb)\n",
|
||||
" axes[1].set_title(f'SAM2 Annotations ({len(masks)} objects)')\n",
|
||||
" axes[1].axis('off')\n",
|
||||
" \n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
" \n",
|
||||
" # Show bbox details\n",
|
||||
" print(\"\\nDetected objects:\")\n",
|
||||
" for i, m in enumerate(masks[:10]):\n",
|
||||
" print(f\" {i}: bbox={m['bbox']}, area={m['area']}, iou={m['predicted_iou']:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Process all frames\n",
|
||||
"def annotate_all_frames(frames_dir, output_dir, min_area=100, max_area=None,\n",
"                        save_masks=True, save_vis=True):\n",
"    \"\"\"Annotate every `*.jpg` frame in a directory and save the results.\n",
"\n",
"    Writes `annotations.json` into output_dir and, depending on the flags,\n",
"    per-frame mask PNGs under `masks/<frame stem>/` and overlay images under\n",
"    `visualizations/`.\n",
"\n",
"    Args:\n",
"        frames_dir: Directory containing the extracted frames.\n",
"        output_dir: Directory that receives all outputs (created if missing).\n",
"        min_area: Minimum mask area forwarded to annotate_single_image.\n",
"        max_area: Maximum mask area forwarded to annotate_single_image.\n",
"        save_masks: Save each mask as a 0/255 PNG when True.\n",
"        save_vis: Save an overlay visualization per frame when True.\n",
"\n",
"    Returns:\n",
"        dict mapping frame file name to its list of annotation dicts\n",
"        ('id', 'bbox', 'area', 'predicted_iou', 'stability_score').\n",
"    \"\"\"\n",
"    frames_path = Path(frames_dir)\n",
"    output_path = Path(output_dir)\n",
"    output_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    if save_masks:\n",
"        masks_dir = output_path / 'masks'\n",
"        masks_dir.mkdir(exist_ok=True)\n",
"\n",
"    if save_vis:\n",
"        vis_dir = output_path / 'visualizations'\n",
"        vis_dir.mkdir(exist_ok=True)\n",
"\n",
"    frame_files = sorted(frames_path.glob(\"*.jpg\"))\n",
"    print(f\"Processing {len(frame_files)} frames...\")\n",
"\n",
"    all_annotations = {}\n",
"    total_objects = 0\n",
"\n",
"    for frame_file in tqdm(frame_files, desc=\"Annotating\"):\n",
"        masks, image = annotate_single_image(\n",
"            str(frame_file),\n",
"            min_area=min_area,\n",
"            max_area=max_area\n",
"        )\n",
"\n",
"        # Create the per-frame mask directory once instead of once per mask\n",
"        if save_masks and masks:\n",
"            frame_masks_dir = masks_dir / frame_file.stem\n",
"            frame_masks_dir.mkdir(exist_ok=True)\n",
"\n",
"        # Convert to serializable format\n",
"        annotations = []\n",
"        for i, m in enumerate(masks):\n",
"            ann = {\n",
"                'id': i,\n",
"                'bbox': [int(x) for x in m['bbox']],  # [x, y, w, h]\n",
"                'area': int(m['area']),\n",
"                'predicted_iou': float(m['predicted_iou']),\n",
"                'stability_score': float(m['stability_score'])\n",
"            }\n",
"            annotations.append(ann)\n",
"\n",
"            # Save individual mask\n",
"            if save_masks:\n",
"                mask_path = frame_masks_dir / f\"mask_{i:03d}.png\"\n",
"                cv2.imwrite(\n",
"                    str(mask_path),\n",
"                    m['segmentation'].astype(np.uint8) * 255\n",
"                )\n",
"\n",
"        all_annotations[frame_file.name] = annotations\n",
"        total_objects += len(annotations)\n",
"\n",
"        # Save visualization\n",
"        if save_vis:\n",
"            vis = visualize_masks(image, masks)\n",
"            cv2.imwrite(str(vis_dir / frame_file.name), vis)\n",
"\n",
"    # Save annotations JSON\n",
"    annotations_file = output_path / 'annotations.json'\n",
"    with open(annotations_file, 'w') as f:\n",
"        json.dump(all_annotations, f, indent=2)\n",
"\n",
"    # An empty frames directory must not crash the summary with ZeroDivisionError\n",
"    avg_objects = total_objects / len(frame_files) if frame_files else 0.0\n",
"\n",
"    print(f\"\\nAnnotation complete!\")\n",
"    print(f\"  Frames: {len(frame_files)}\")\n",
"    print(f\"  Total objects: {total_objects}\")\n",
"    print(f\"  Avg objects/frame: {avg_objects:.1f}\")\n",
"    print(f\"  Saved to: {output_path}\")\n",
"\n",
"    return all_annotations\n",
|
||||
"\n",
|
||||
"# Run annotation\n",
|
||||
"annotations = annotate_all_frames(\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" OUTPUT_DIR,\n",
|
||||
" min_area=MIN_MASK_AREA,\n",
|
||||
" max_area=MAX_MASK_AREA,\n",
|
||||
" save_masks=SAVE_MASKS,\n",
|
||||
" save_vis=SAVE_VIS\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Review Annotations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preview some annotated frames\n",
|
||||
"vis_dir = Path(OUTPUT_DIR) / 'visualizations'\n",
|
||||
"vis_files = sorted(vis_dir.glob(\"*.jpg\"))[:6]\n",
|
||||
"\n",
|
||||
"if vis_files:\n",
|
||||
" fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
|
||||
" for ax, vis_file in zip(axes.flat, vis_files):\n",
|
||||
" img = cv2.imread(str(vis_file))\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" ax.imshow(img)\n",
|
||||
" \n",
|
||||
" # Get object count\n",
|
||||
" frame_name = vis_file.name\n",
|
||||
" if frame_name in annotations:\n",
|
||||
" count = len(annotations[frame_name])\n",
|
||||
" else:\n",
|
||||
" count = 0\n",
|
||||
" ax.set_title(f\"{vis_file.name} ({count} objects)\")\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Statistics\n",
"object_counts = [len(ann) for ann in annotations.values()]\n",
"\n",
"if object_counts:\n",
"    print(\"Annotation Statistics:\")\n",
"    print(f\"  Total frames: {len(annotations)}\")\n",
"    print(f\"  Total objects: {sum(object_counts)}\")\n",
"    print(f\"  Min objects/frame: {min(object_counts)}\")\n",
"    print(f\"  Max objects/frame: {max(object_counts)}\")\n",
"    print(f\"  Avg objects/frame: {np.mean(object_counts):.1f}\")\n",
"\n",
"    # Histogram\n",
"    plt.figure(figsize=(10, 4))\n",
"    plt.hist(object_counts, bins=30, edgecolor='black')\n",
"    plt.xlabel('Objects per frame')\n",
"    plt.ylabel('Frequency')\n",
"    plt.title('Distribution of Objects per Frame')\n",
"    plt.show()\n",
"else:\n",
"    # min() and max() raise ValueError on an empty list, so skip the report\n",
"    print(\"No annotations available. Run the annotation step first.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Save Output for Next Notebook"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Summary of output files\n",
|
||||
"print(\"Output files for next notebook:\")\n",
|
||||
"print(f\" Frames: {FRAMES_DIR}/\")\n",
|
||||
"print(f\" Annotations: {OUTPUT_DIR}/annotations.json\")\n",
|
||||
"print(f\" Masks: {OUTPUT_DIR}/masks/\")\n",
|
||||
"print(f\" Visualizations: {OUTPUT_DIR}/visualizations/\")\n",
|
||||
"\n",
|
||||
"# Check sizes\n",
|
||||
"import subprocess\n",
|
||||
"!du -sh {FRAMES_DIR} {OUTPUT_DIR}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optional: Create archive for download\n",
|
||||
"ARCHIVE_NAME = 'sam2_annotations.zip'\n",
|
||||
"\n",
|
||||
"!zip -r {ARCHIVE_NAME} {FRAMES_DIR} {OUTPUT_DIR}/annotations.json {OUTPUT_DIR}/masks/\n",
|
||||
"print(f\"\\nArchive created: {ARCHIVE_NAME}\")\n",
|
||||
"!ls -lh {ARCHIVE_NAME}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## Next Steps\n",
|
||||
"\n",
|
||||
"1. **Run `02_create_yolo_dataset.ipynb`** to convert annotations to YOLO format\n",
|
||||
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t on the dataset\n",
|
||||
"\n",
|
||||
"### Tips\n",
|
||||
"\n",
|
||||
"- Adjust `MIN_MASK_AREA` to filter small detections\n",
|
||||
"- Use `MAX_MASK_AREA` to filter very large objects (e.g., background)\n",
|
||||
"- Lower `SAMPLE_FPS` for faster processing or if frames are similar\n",
|
||||
"- Review visualizations to verify annotation quality"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user