add sam2 yolo auto annotation
This commit is contained in:
@@ -0,0 +1,687 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SAM2 Video Annotation Pipeline\n",
|
||||
"\n",
|
||||
"This notebook sets up SAM2 (Segment Anything Model 2) for automatic object annotation from video.\n",
|
||||
"\n",
|
||||
"## Features\n",
|
||||
"- Install SAM2 and dependencies on Kaggle\n",
|
||||
"- Download pretrained model weights\n",
|
||||
"- Extract frames from video\n",
|
||||
"- Auto-generate masks for all objects\n",
|
||||
"- Save annotations for YOLO conversion\n",
|
||||
"\n",
|
||||
"**Platform:** Kaggle GPU (P100/T4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Setup Environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check GPU availability\n",
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install SAM2 and dependencies\n",
|
||||
"!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git\n",
|
||||
"!pip install -q opencv-python-headless supervision tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import cv2\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"from IPython.display import display, Image as IPImage\n",
|
||||
"\n",
|
||||
"print(f\"Python: {sys.version}\")\n",
|
||||
"print(f\"PyTorch: {torch.__version__}\")\n",
|
||||
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Download SAM2 Pretrained Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration\n",
|
||||
"MODEL_SIZE = 'large' # Options: 'tiny', 'small', 'base_plus', 'large'\n",
|
||||
"CHECKPOINT_DIR = './checkpoints'\n",
|
||||
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"\n",
|
||||
"# Model configurations\n",
|
||||
"MODEL_CONFIGS = {\n",
|
||||
" 'tiny': {\n",
|
||||
" 'config': 'sam2_hiera_t.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_tiny.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'\n",
|
||||
" },\n",
|
||||
" 'small': {\n",
|
||||
" 'config': 'sam2_hiera_s.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_small.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'\n",
|
||||
" },\n",
|
||||
" 'base_plus': {\n",
|
||||
" 'config': 'sam2_hiera_b+.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_base_plus.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'\n",
|
||||
" },\n",
|
||||
" 'large': {\n",
|
||||
" 'config': 'sam2_hiera_l.yaml',\n",
|
||||
" 'checkpoint': 'sam2_hiera_large.pt',\n",
|
||||
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(f\"Using model: SAM2 {MODEL_SIZE}\")\n",
|
||||
"print(f\"Device: {DEVICE}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Download checkpoint\n",
"import urllib.request\n",
"\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"config = MODEL_CONFIGS[MODEL_SIZE]\n",
"checkpoint_path = os.path.join(CHECKPOINT_DIR, config['checkpoint'])\n",
"\n",
"if not os.path.exists(checkpoint_path):\n",
"    print(f\"Downloading {config['checkpoint']}...\")\n",
"    # Fetch into a temporary file and rename only after a complete download, so a\n",
"    # failed or interrupted transfer is never mistaken for a valid checkpoint.\n",
"    partial_path = checkpoint_path + '.part'\n",
"    try:\n",
"        urllib.request.urlretrieve(config['url'], partial_path)\n",
"    except Exception:\n",
"        if os.path.exists(partial_path):\n",
"            os.remove(partial_path)\n",
"        raise\n",
"    os.replace(partial_path, checkpoint_path)\n",
"    print(\"Download complete!\")\n",
"else:\n",
"    print(f\"Checkpoint exists: {checkpoint_path}\")\n",
"\n",
"print(f\"Checkpoint size: {os.path.getsize(checkpoint_path) / 1024 / 1024:.1f} MB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Load SAM2 Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sam2.build_sam import build_sam2\n",
|
||||
"from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
|
||||
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
|
||||
"\n",
|
||||
"print(\"Loading SAM2 model...\")\n",
|
||||
"\n",
|
||||
"# Build model\n",
|
||||
"sam2_model = build_sam2(\n",
|
||||
" config['config'],\n",
|
||||
" checkpoint_path,\n",
|
||||
" device=DEVICE\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create automatic mask generator\n",
|
||||
"mask_generator = SAM2AutomaticMaskGenerator(\n",
|
||||
" model=sam2_model,\n",
|
||||
" points_per_side=32, # Grid density for point prompts\n",
|
||||
" points_per_batch=64, # Batch size for processing\n",
|
||||
" pred_iou_thresh=0.7, # IoU threshold for predictions\n",
|
||||
" stability_score_thresh=0.92, # Stability threshold\n",
|
||||
" stability_score_offset=1.0,\n",
|
||||
" box_nms_thresh=0.7, # NMS threshold for boxes\n",
|
||||
" crop_n_layers=1, # Crop layers for multi-scale\n",
|
||||
" crop_nms_thresh=0.7,\n",
|
||||
" crop_overlap_ratio=0.34,\n",
|
||||
" crop_n_points_downscale_factor=2,\n",
|
||||
" min_mask_region_area=100 # Minimum mask area in pixels\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create predictor for interactive mode\n",
|
||||
"predictor = SAM2ImagePredictor(sam2_model)\n",
|
||||
"\n",
|
||||
"print(\"Model loaded successfully!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Video Frame Extraction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration - UPDATE THESE PATHS\n",
|
||||
"VIDEO_PATH = '/kaggle/input/your-video-dataset/video.mp4' # Change to your video path\n",
|
||||
"FRAMES_DIR = './frames'\n",
|
||||
"OUTPUT_DIR = './annotations'\n",
|
||||
"\n",
|
||||
"# Frame extraction settings\n",
|
||||
"SAMPLE_FPS = 2 # Extract 2 frames per second (adjust based on video)\n",
|
||||
"MAX_FRAMES = 500 # Maximum frames to extract (None for all)\n",
|
||||
"START_TIME = 0 # Start time in seconds\n",
|
||||
"END_TIME = None # End time in seconds (None for entire video)\n",
|
||||
"RESIZE = None # Resize frames: (width, height) or None\n",
|
||||
"\n",
|
||||
"os.makedirs(FRAMES_DIR, exist_ok=True)\n",
|
||||
"os.makedirs(OUTPUT_DIR, exist_ok=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_video_info(video_path):\n",
"    \"\"\"Return basic properties of a video file.\n",
"\n",
"    Args:\n",
"        video_path: Path to a video that OpenCV can open.\n",
"\n",
"    Returns:\n",
"        dict with 'fps', 'frame_count', 'width', 'height' and 'duration'\n",
"        (seconds). 'duration' is 0.0 when no positive frame rate is reported.\n",
"\n",
"    Raises:\n",
"        ValueError: If the video cannot be opened.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    try:\n",
"        if not cap.isOpened():\n",
"            raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"        info = {\n",
"            'fps': cap.get(cv2.CAP_PROP_FPS),\n",
"            'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),\n",
"            'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n",
"            'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
"        }\n",
"    finally:\n",
"        # Release the capture on every path, including the error above\n",
"        cap.release()\n",
"\n",
"    # A reported frame rate of 0 would otherwise raise ZeroDivisionError\n",
"    fps = info['fps']\n",
"    info['duration'] = info['frame_count'] / fps if fps > 0 else 0.0\n",
"    return info\n",
|
||||
"\n",
|
||||
"# Display video info\n",
|
||||
"if os.path.exists(VIDEO_PATH):\n",
|
||||
" video_info = get_video_info(VIDEO_PATH)\n",
|
||||
" print(\"Video Information:\")\n",
|
||||
" print(f\" Resolution: {video_info['width']}x{video_info['height']}\")\n",
|
||||
" print(f\" FPS: {video_info['fps']:.2f}\")\n",
|
||||
" print(f\" Duration: {video_info['duration']:.2f}s\")\n",
|
||||
" print(f\" Total frames: {video_info['frame_count']}\")\n",
|
||||
"else:\n",
|
||||
" print(f\"Video not found: {VIDEO_PATH}\")\n",
|
||||
" print(\"Please upload your video or update VIDEO_PATH\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extract_frames(video_path, output_dir, sample_fps=None, max_frames=None,\n",
"                   start_time=0, end_time=None, resize=None):\n",
"    \"\"\"Extract frames from a video and save them as JPEG files.\n",
"\n",
"    Files are named ``frame_<index>.jpg`` where ``<index>`` is the zero-padded\n",
"    source frame index, so names stay stable across sampling settings.\n",
"\n",
"    Args:\n",
"        video_path: Path to the input video.\n",
"        output_dir: Directory that receives the frame files.\n",
"        sample_fps: Target sampling rate; ignored unless lower than the video fps.\n",
"        max_frames: Maximum number of frames to save (None or 0 for no limit).\n",
"        start_time: Start offset in seconds.\n",
"        end_time: End offset in seconds (None for the end of the video).\n",
"        resize: Optional (width, height) passed to cv2.resize.\n",
"\n",
"    Returns:\n",
"        List of saved frame paths in extraction order.\n",
"\n",
"    Raises:\n",
"        ValueError: If the video cannot be opened.\n",
"        IOError: If a frame cannot be written to output_dir.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(video_path)\n",
"    if not cap.isOpened():\n",
"        raise ValueError(f\"Cannot open video: {video_path}\")\n",
"\n",
"    saved_paths = []\n",
"    pbar = None\n",
"    try:\n",
"        fps = cap.get(cv2.CAP_PROP_FPS)\n",
"        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
"        # Keep every frame unless a lower sampling rate was requested\n",
"        if sample_fps and sample_fps < fps:\n",
"            frame_interval = int(fps / sample_fps)\n",
"        else:\n",
"            frame_interval = 1\n",
"\n",
"        # Calculate frame range\n",
"        start_frame = int(start_time * fps)\n",
"        # `is not None` so an explicit end_time of 0 is not read as \"no limit\"\n",
"        end_frame = int(end_time * fps) if end_time is not None else frame_count\n",
"        end_frame = min(end_frame, frame_count)\n",
"\n",
"        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n",
"\n",
"        # Offsets 0, k, 2k, ... inside the span are kept: ceil(span / k) frames\n",
"        span = max(end_frame - start_frame, 0)\n",
"        total_to_extract = -(-span // frame_interval)\n",
"        if max_frames:\n",
"            total_to_extract = min(total_to_extract, max_frames)\n",
"\n",
"        pbar = tqdm(total=total_to_extract, desc=\"Extracting frames\")\n",
"\n",
"        for frame_idx in range(start_frame, end_frame):\n",
"            if max_frames and len(saved_paths) >= max_frames:\n",
"                break\n",
"\n",
"            # grab() advances the stream; retrieve() is only called for kept frames\n",
"            if not cap.grab():\n",
"                break\n",
"            if (frame_idx - start_frame) % frame_interval != 0:\n",
"                continue\n",
"\n",
"            ret, frame = cap.retrieve()\n",
"            if not ret:\n",
"                break\n",
"            if resize:\n",
"                frame = cv2.resize(frame, resize)\n",
"\n",
"            frame_path = os.path.join(output_dir, f\"frame_{frame_idx:06d}.jpg\")\n",
"            # imwrite signals failure through its return value, not an exception\n",
"            if not cv2.imwrite(frame_path, frame):\n",
"                raise IOError(f\"Failed to write frame: {frame_path}\")\n",
"            saved_paths.append(frame_path)\n",
"            pbar.update(1)\n",
"    finally:\n",
"        # Always free the progress bar and the capture, even on errors\n",
"        if pbar is not None:\n",
"            pbar.close()\n",
"        cap.release()\n",
"\n",
"    print(f\"Extracted {len(saved_paths)} frames to {output_dir}\")\n",
"    return saved_paths\n",
|
||||
"\n",
|
||||
"# Extract frames\n",
|
||||
"if os.path.exists(VIDEO_PATH):\n",
|
||||
" frame_paths = extract_frames(\n",
|
||||
" VIDEO_PATH,\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" sample_fps=SAMPLE_FPS,\n",
|
||||
" max_frames=MAX_FRAMES,\n",
|
||||
" start_time=START_TIME,\n",
|
||||
" end_time=END_TIME,\n",
|
||||
" resize=RESIZE\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preview extracted frames\n",
|
||||
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))[:6]\n",
|
||||
"\n",
|
||||
"if frame_files:\n",
|
||||
" fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
|
||||
" for ax, frame_file in zip(axes.flat, frame_files):\n",
|
||||
" img = cv2.imread(str(frame_file))\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" ax.imshow(img)\n",
|
||||
" ax.set_title(frame_file.name)\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
"else:\n",
|
||||
" print(\"No frames found. Please extract frames first.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Automatic Mask Generation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Annotation settings\n",
|
||||
"MIN_MASK_AREA = 500 # Minimum mask area in pixels\n",
|
||||
"MAX_MASK_AREA = None # Maximum mask area (None for no limit)\n",
|
||||
"SAVE_MASKS = True # Save individual mask images\n",
|
||||
"SAVE_VIS = True # Save visualization images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def visualize_masks(image, masks, alpha=0.5):\n",
"    \"\"\"Overlay each mask and its bounding box on a copy of the image.\n",
"\n",
"    Args:\n",
"        image: Image array to draw on; it is copied, not modified.\n",
"        masks: Iterable of dicts with a 'segmentation' array usable as a NumPy\n",
"            index (assumed boolean - TODO confirm) and a 'bbox' in\n",
"            [x, y, w, h] order.\n",
"        alpha: Blend weight of the mask colour.\n",
"\n",
"    Returns:\n",
"        New image with one randomly coloured overlay and box per mask.\n",
"    \"\"\"\n",
"    vis = image.copy()\n",
"\n",
"    for mask_data in masks:\n",
"        mask = mask_data['segmentation']\n",
"        color = np.random.randint(0, 255, 3).tolist()\n",
"\n",
"        # Paint the mask on a copy, then blend so the frame shows through\n",
"        overlay = vis.copy()\n",
"        overlay[mask] = color\n",
"        vis = cv2.addWeighted(vis, 1 - alpha, overlay, alpha, 0)\n",
"\n",
"        # Draw bbox; cast to int as the JSON export does, since\n",
"        # cv2.rectangle only accepts integer pixel coordinates\n",
"        x, y, w, h = (int(v) for v in mask_data['bbox'])\n",
"        cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)\n",
"\n",
"    return vis\n",
|
||||
"\n",
|
||||
"def annotate_single_image(image_path, min_area=100, max_area=None):\n",
"    \"\"\"Generate masks for a single image with the global mask generator.\n",
"\n",
"    Args:\n",
"        image_path: Path of the image file to annotate.\n",
"        min_area: Masks with a smaller 'area' are dropped.\n",
"        max_area: Masks with a larger 'area' are dropped (None for no limit).\n",
"\n",
"    Returns:\n",
"        (masks, image): the kept masks sorted by area, largest first, and the\n",
"        image exactly as loaded by cv2.imread.\n",
"\n",
"    Raises:\n",
"        ValueError: If the image cannot be read.\n",
"    \"\"\"\n",
"    image = cv2.imread(image_path)\n",
"    # imread returns None instead of raising when the file is missing or corrupt\n",
"    if image is None:\n",
"        raise ValueError(f\"Cannot read image: {image_path}\")\n",
"    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"\n",
"    # Generate masks\n",
"    masks = mask_generator.generate(image_rgb)\n",
"\n",
"    # Filter by area\n",
"    filtered = [\n",
"        m for m in masks\n",
"        if m['area'] >= min_area and (max_area is None or m['area'] <= max_area)\n",
"    ]\n",
"\n",
"    # Sort by area (largest first)\n",
"    filtered.sort(key=lambda m: m['area'], reverse=True)\n",
"\n",
"    return filtered, image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test on a single frame\n",
|
||||
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))\n",
|
||||
"\n",
|
||||
"if frame_files:\n",
|
||||
" test_frame = str(frame_files[0])\n",
|
||||
" print(f\"Testing on: {test_frame}\")\n",
|
||||
" \n",
|
||||
" masks, image = annotate_single_image(\n",
|
||||
" test_frame,\n",
|
||||
" min_area=MIN_MASK_AREA,\n",
|
||||
" max_area=MAX_MASK_AREA\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" print(f\"Found {len(masks)} objects\")\n",
|
||||
" \n",
|
||||
" # Visualize\n",
|
||||
" vis = visualize_masks(image, masks)\n",
|
||||
" vis_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)\n",
|
||||
" \n",
|
||||
" fig, axes = plt.subplots(1, 2, figsize=(15, 7))\n",
|
||||
" axes[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
|
||||
" axes[0].set_title('Original')\n",
|
||||
" axes[0].axis('off')\n",
|
||||
" \n",
|
||||
" axes[1].imshow(vis_rgb)\n",
|
||||
" axes[1].set_title(f'SAM2 Annotations ({len(masks)} objects)')\n",
|
||||
" axes[1].axis('off')\n",
|
||||
" \n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()\n",
|
||||
" \n",
|
||||
" # Show bbox details\n",
|
||||
" print(\"\\nDetected objects:\")\n",
|
||||
" for i, m in enumerate(masks[:10]):\n",
|
||||
" print(f\" {i}: bbox={m['bbox']}, area={m['area']}, iou={m['predicted_iou']:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Process all frames\n",
|
||||
"def annotate_all_frames(frames_dir, output_dir, min_area=100, max_area=None,\n",
"                        save_masks=True, save_vis=True):\n",
"    \"\"\"Annotate every JPEG frame in a directory and save the results.\n",
"\n",
"    Writes ``annotations.json`` (frame file name -> list of objects) into\n",
"    output_dir, plus ``masks/<frame stem>/mask_NNN.png`` and\n",
"    ``visualizations/<frame name>`` when the matching flag is set.\n",
"\n",
"    Args:\n",
"        frames_dir: Directory containing ``*.jpg`` frames.\n",
"        output_dir: Directory for all outputs (created if missing).\n",
"        min_area: Minimum mask area forwarded to annotate_single_image.\n",
"        max_area: Maximum mask area forwarded to annotate_single_image.\n",
"        save_masks: Save each mask as a PNG with values 0/255.\n",
"        save_vis: Save an overlay visualization per frame.\n",
"\n",
"    Returns:\n",
"        dict mapping frame file name to its list of annotation dicts.\n",
"    \"\"\"\n",
"    frames_path = Path(frames_dir)\n",
"    output_path = Path(output_dir)\n",
"    output_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    if save_masks:\n",
"        masks_dir = output_path / 'masks'\n",
"        masks_dir.mkdir(exist_ok=True)\n",
"\n",
"    if save_vis:\n",
"        vis_dir = output_path / 'visualizations'\n",
"        vis_dir.mkdir(exist_ok=True)\n",
"\n",
"    frame_files = sorted(frames_path.glob(\"*.jpg\"))\n",
"    print(f\"Processing {len(frame_files)} frames...\")\n",
"\n",
"    all_annotations = {}\n",
"    total_objects = 0\n",
"\n",
"    for frame_file in tqdm(frame_files, desc=\"Annotating\"):\n",
"        masks, image = annotate_single_image(\n",
"            str(frame_file),\n",
"            min_area=min_area,\n",
"            max_area=max_area\n",
"        )\n",
"\n",
"        # One directory per frame, created once instead of once per mask\n",
"        if save_masks and masks:\n",
"            frame_masks_dir = masks_dir / frame_file.stem\n",
"            frame_masks_dir.mkdir(exist_ok=True)\n",
"\n",
"        # Convert to serializable format\n",
"        annotations = []\n",
"        for i, m in enumerate(masks):\n",
"            annotations.append({\n",
"                'id': i,\n",
"                'bbox': [int(x) for x in m['bbox']],  # [x, y, w, h]\n",
"                'area': int(m['area']),\n",
"                'predicted_iou': float(m['predicted_iou']),\n",
"                'stability_score': float(m['stability_score'])\n",
"            })\n",
"\n",
"            # Save individual mask\n",
"            if save_masks:\n",
"                mask_path = frame_masks_dir / f\"mask_{i:03d}.png\"\n",
"                cv2.imwrite(\n",
"                    str(mask_path),\n",
"                    m['segmentation'].astype(np.uint8) * 255\n",
"                )\n",
"\n",
"        all_annotations[frame_file.name] = annotations\n",
"        total_objects += len(annotations)\n",
"\n",
"        # Save visualization\n",
"        if save_vis:\n",
"            vis = visualize_masks(image, masks)\n",
"            cv2.imwrite(str(vis_dir / frame_file.name), vis)\n",
"\n",
"    # Save annotations JSON\n",
"    annotations_file = output_path / 'annotations.json'\n",
"    with open(annotations_file, 'w') as f:\n",
"        json.dump(all_annotations, f, indent=2)\n",
"\n",
"    # The frame list is empty when extraction was skipped; avoid dividing by zero\n",
"    avg_objects = total_objects / len(frame_files) if frame_files else 0.0\n",
"\n",
"    print(f\"\\nAnnotation complete!\")\n",
"    print(f\" Frames: {len(frame_files)}\")\n",
"    print(f\" Total objects: {total_objects}\")\n",
"    print(f\" Avg objects/frame: {avg_objects:.1f}\")\n",
"    print(f\" Saved to: {output_path}\")\n",
"\n",
"    return all_annotations\n",
|
||||
"\n",
|
||||
"# Run annotation\n",
|
||||
"annotations = annotate_all_frames(\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" OUTPUT_DIR,\n",
|
||||
" min_area=MIN_MASK_AREA,\n",
|
||||
" max_area=MAX_MASK_AREA,\n",
|
||||
" save_masks=SAVE_MASKS,\n",
|
||||
" save_vis=SAVE_VIS\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Review Annotations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preview some annotated frames\n",
|
||||
"vis_dir = Path(OUTPUT_DIR) / 'visualizations'\n",
|
||||
"vis_files = sorted(vis_dir.glob(\"*.jpg\"))[:6]\n",
|
||||
"\n",
|
||||
"if vis_files:\n",
|
||||
" fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
|
||||
" for ax, vis_file in zip(axes.flat, vis_files):\n",
|
||||
" img = cv2.imread(str(vis_file))\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" ax.imshow(img)\n",
|
||||
" \n",
|
||||
" # Get object count\n",
|
||||
" frame_name = vis_file.name\n",
|
||||
" if frame_name in annotations:\n",
|
||||
" count = len(annotations[frame_name])\n",
|
||||
" else:\n",
|
||||
" count = 0\n",
|
||||
" ax.set_title(f\"{vis_file.name} ({count} objects)\")\n",
|
||||
" ax.axis('off')\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Statistics\n",
"object_counts = [len(ann) for ann in annotations.values()]\n",
"\n",
"print(\"Annotation Statistics:\")\n",
"print(f\" Total frames: {len(annotations)}\")\n",
"print(f\" Total objects: {sum(object_counts)}\")\n",
"\n",
"# min()/max() raise on an empty sequence, so summarise only when frames exist\n",
"if object_counts:\n",
"    print(f\" Min objects/frame: {min(object_counts)}\")\n",
"    print(f\" Max objects/frame: {max(object_counts)}\")\n",
"    print(f\" Avg objects/frame: {np.mean(object_counts):.1f}\")\n",
"\n",
"    # Histogram\n",
"    plt.figure(figsize=(10, 4))\n",
"    plt.hist(object_counts, bins=30, edgecolor='black')\n",
"    plt.xlabel('Objects per frame')\n",
"    plt.ylabel('Frequency')\n",
"    plt.title('Distribution of Objects per Frame')\n",
"    plt.show()\n",
"else:\n",
"    print(\" No annotated frames - run frame extraction and annotation first.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Save Output for Next Notebook"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Summary of output files\n",
|
||||
"print(\"Output files for next notebook:\")\n",
|
||||
"print(f\" Frames: {FRAMES_DIR}/\")\n",
|
||||
"print(f\" Annotations: {OUTPUT_DIR}/annotations.json\")\n",
|
||||
"print(f\" Masks: {OUTPUT_DIR}/masks/\")\n",
|
||||
"print(f\" Visualizations: {OUTPUT_DIR}/visualizations/\")\n",
|
||||
"\n",
|
||||
"# Check sizes\n",
|
||||
"import subprocess\n",
|
||||
"!du -sh {FRAMES_DIR} {OUTPUT_DIR}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optional: Create archive for download\n",
|
||||
"ARCHIVE_NAME = 'sam2_annotations.zip'\n",
|
||||
"\n",
|
||||
"!zip -r {ARCHIVE_NAME} {FRAMES_DIR} {OUTPUT_DIR}/annotations.json {OUTPUT_DIR}/masks/\n",
|
||||
"print(f\"\\nArchive created: {ARCHIVE_NAME}\")\n",
|
||||
"!ls -lh {ARCHIVE_NAME}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## Next Steps\n",
|
||||
"\n",
|
||||
"1. **Run `02_create_yolo_dataset.ipynb`** to convert annotations to YOLO format\n",
|
||||
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t on the dataset\n",
|
||||
"\n",
|
||||
"### Tips\n",
|
||||
"\n",
|
||||
"- Adjust `MIN_MASK_AREA` to filter small detections\n",
|
||||
"- Use `MAX_MASK_AREA` to filter very large objects (e.g., background)\n",
|
||||
"- Lower `SAMPLE_FPS` for faster processing or if frames are similar\n",
|
||||
"- Review visualizations to verify annotation quality"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,726 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Create YOLO Dataset from SAM2 Annotations\n",
|
||||
"\n",
|
||||
"Convert SAM2 mask annotations to YOLO detection format (bounding boxes).\n",
|
||||
"\n",
|
||||
"## Input\n",
|
||||
"- Frames from `01_sam2_video_annotation.ipynb`\n",
|
||||
"- Annotations JSON file\n",
|
||||
"\n",
|
||||
"## Output\n",
|
||||
"- YOLO format dataset ready for training\n",
|
||||
"- `data.yaml` configuration file\n",
|
||||
"\n",
|
||||
"**Platform:** Kaggle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import cv2\n",
|
||||
"import json\n",
|
||||
"import yaml\n",
|
||||
"import shutil\n",
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"print(\"Setup complete!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration - UPDATE THESE PATHS\n",
|
||||
"FRAMES_DIR = './frames' # From notebook 01\n",
|
||||
"ANNOTATIONS_FILE = './annotations/annotations.json' # From notebook 01\n",
|
||||
"MASKS_DIR = './annotations/masks' # Optional: mask images\n",
|
||||
"\n",
|
||||
"# Output dataset\n",
|
||||
"DATASET_DIR = './yolo_dataset'\n",
|
||||
"\n",
|
||||
"# Dataset settings\n",
|
||||
"CLASS_NAMES = ['object'] # Single class for generic objects\n",
|
||||
"VAL_SPLIT = 0.2 # 20% validation\n",
|
||||
"SEED = 42 # Random seed for reproducibility\n",
|
||||
"\n",
|
||||
"# Filtering\n",
|
||||
"MIN_BBOX_AREA = 100 # Minimum bbox area in pixels\n",
|
||||
"MIN_BBOX_SIZE = 0.01 # Minimum bbox dimension (normalized, 0-1)\n",
|
||||
"MAX_OBJECTS_PER_IMAGE = 100 # Maximum objects per image\n",
|
||||
"MIN_IOU_SCORE = 0.5 # Minimum SAM2 IoU score\n",
|
||||
"\n",
|
||||
"random.seed(SEED)\n",
|
||||
"np.random.seed(SEED)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Load Annotations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load annotations\n",
"with open(ANNOTATIONS_FILE, 'r') as f:\n",
"    annotations = json.load(f)\n",
"\n",
"print(f\"Loaded annotations for {len(annotations)} frames\")\n",
"\n",
"# Show sample; an empty annotations file has no first key to index\n",
"sample_frame = next(iter(annotations), None)\n",
"if sample_frame is None:\n",
"    print(\"\\nNo annotated frames found - run notebook 01 first.\")\n",
"else:\n",
"    print(f\"\\nSample annotation ({sample_frame}):\")\n",
"    print(json.dumps(annotations[sample_frame][:2], indent=2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Verify frames exist\n",
|
||||
"frames_path = Path(FRAMES_DIR)\n",
|
||||
"frame_files = list(frames_path.glob(\"*.jpg\")) + list(frames_path.glob(\"*.png\"))\n",
|
||||
"\n",
|
||||
"print(f\"Found {len(frame_files)} frame images\")\n",
|
||||
"\n",
|
||||
"# Check matching\n",
|
||||
"annotation_frames = set(annotations.keys())\n",
|
||||
"image_frames = {f.name for f in frame_files}\n",
|
||||
"\n",
|
||||
"matched = annotation_frames & image_frames\n",
|
||||
"print(f\"Matched frames: {len(matched)}\")\n",
|
||||
"\n",
|
||||
"if len(matched) < len(annotation_frames):\n",
|
||||
" missing = annotation_frames - image_frames\n",
|
||||
" print(f\"Warning: {len(missing)} annotated frames missing images\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Create YOLO Dataset Structure"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create directory structure\n",
|
||||
"dataset_path = Path(DATASET_DIR)\n",
|
||||
"\n",
|
||||
"images_train = dataset_path / 'images' / 'train'\n",
|
||||
"images_val = dataset_path / 'images' / 'val'\n",
|
||||
"labels_train = dataset_path / 'labels' / 'train'\n",
|
||||
"labels_val = dataset_path / 'labels' / 'val'\n",
|
||||
"\n",
|
||||
"for dir_path in [images_train, images_val, labels_train, labels_val]:\n",
|
||||
" dir_path.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print(f\"Created: {dir_path}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Convert Annotations to YOLO Format"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def bbox_xywh_to_yolo(bbox, image_width, image_height):\n",
"    \"\"\"\n",
"    Convert [x, y, w, h] bbox to YOLO format [x_center, y_center, width, height] normalized.\n",
"\n",
"    The box is clipped to the image before normalizing, so a box that crosses\n",
"    an image edge is reduced to its visible part. Boxes fully inside the\n",
"    image give the same result as a plain normalization.\n",
"\n",
"    Args:\n",
"        bbox: Pixel box as [x, y, w, h] with (x, y) the top-left corner.\n",
"        image_width: Image width in pixels.\n",
"        image_height: Image height in pixels.\n",
"\n",
"    Returns:\n",
"        Tuple (x_center, y_center, width, height), each in [0, 1].\n",
"    \"\"\"\n",
"    x, y, w, h = bbox\n",
"\n",
"    # Clip the corners rather than the center and size: clamping those\n",
"    # independently leaves a box that still extends past the image edge.\n",
"    x_min = max(0, min(image_width, x))\n",
"    y_min = max(0, min(image_height, y))\n",
"    x_max = max(0, min(image_width, x + w))\n",
"    y_max = max(0, min(image_height, y + h))\n",
"\n",
"    x_center = (x_min + x_max) / 2 / image_width\n",
"    y_center = (y_min + y_max) / 2 / image_height\n",
"    width = (x_max - x_min) / image_width\n",
"    height = (y_max - y_min) / image_height\n",
"\n",
"    return x_center, y_center, width, height\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def filter_annotations(anns, img_width, img_height,\n",
"                       min_area=100, min_size=0.01,\n",
"                       min_iou=0.5, max_objects=100):\n",
"    \"\"\"\n",
"    Filter annotations based on criteria.\n",
"\n",
"    Args:\n",
"        anns: List of annotation dicts with 'bbox' ([x, y, w, h] in pixels),\n",
"            'area' and 'predicted_iou' keys.\n",
"        img_width: Image width in pixels.\n",
"        img_height: Image height in pixels.\n",
"        min_area: Minimum value of the annotation's 'area' field.\n",
"        min_size: Minimum bbox width and height as a fraction of the image.\n",
"        min_iou: Minimum 'predicted_iou' (a missing score counts as 1.0).\n",
"        max_objects: Maximum number of annotations to return.\n",
"\n",
"    Returns:\n",
"        List of annotations passing every check. When more than max_objects\n",
"        pass, only the highest-scoring ones are returned, best first.\n",
"    \"\"\"\n",
"    filtered = []\n",
"\n",
"    for ann in anns:\n",
"        bbox = ann.get('bbox', [])\n",
"\n",
"        # A malformed bbox cannot become a label, so drop it here instead of\n",
"        # letting it take one of the max_objects slots\n",
"        if len(bbox) != 4:\n",
"            continue\n",
"\n",
"        # Check area\n",
"        if ann.get('area', 0) < min_area:\n",
"            continue\n",
"\n",
"        # Check IoU score\n",
"        if ann.get('predicted_iou', 1.0) < min_iou:\n",
"            continue\n",
"\n",
"        # Check bbox dimensions\n",
"        _, _, w, h = bbox\n",
"        if w / img_width < min_size or h / img_height < min_size:\n",
"            continue\n",
"\n",
"        filtered.append(ann)\n",
"\n",
"    # Limit number of objects (keep highest IoU; missing scores rank last)\n",
"    if len(filtered) > max_objects:\n",
"        filtered.sort(key=lambda a: a.get('predicted_iou', 0), reverse=True)\n",
"        filtered = filtered[:max_objects]\n",
"\n",
"    return filtered"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_frame(frame_name, frame_anns, frames_dir, images_dir, labels_dir, class_id=0):\n",
"    \"\"\"\n",
"    Process a single frame: copy image and create YOLO label file.\n",
"\n",
"    Args:\n",
"        frame_name: File name of the frame inside frames_dir.\n",
"        frame_anns: List of annotation dicts for this frame.\n",
"        frames_dir: Directory holding the source frames (str or Path).\n",
"        images_dir: Destination directory for the image copy (str or Path).\n",
"        labels_dir: Destination directory for the label file (str or Path).\n",
"        class_id: Class index written at the start of every label line.\n",
"\n",
"    Returns:\n",
"        Number of label lines written; 0 when the frame is missing or unreadable.\n",
"    \"\"\"\n",
"    frame_path = Path(frames_dir) / frame_name\n",
"\n",
"    if not frame_path.exists():\n",
"        return 0\n",
"\n",
"    # Read image dimensions\n",
"    image = cv2.imread(str(frame_path))\n",
"    if image is None:\n",
"        return 0\n",
"\n",
"    height, width = image.shape[:2]\n",
"\n",
"    # Filter annotations\n",
"    filtered_anns = filter_annotations(\n",
"        frame_anns, width, height,\n",
"        min_area=MIN_BBOX_AREA,\n",
"        min_size=MIN_BBOX_SIZE,\n",
"        min_iou=MIN_IOU_SCORE,\n",
"        max_objects=MAX_OBJECTS_PER_IMAGE\n",
"    )\n",
"\n",
"    # Copy image; Path() lets string directories work, as frames_dir already does\n",
"    dest_image = Path(images_dir) / frame_name\n",
"    shutil.copy2(frame_path, dest_image)\n",
"\n",
"    # Create YOLO labels\n",
"    labels = []\n",
"    for ann in filtered_anns:\n",
"        bbox = ann.get('bbox', [])\n",
"        if len(bbox) != 4:\n",
"            continue\n",
"\n",
"        x_center, y_center, w, h = bbox_xywh_to_yolo(bbox, width, height)\n",
"\n",
"        # YOLO format: class x_center y_center width height\n",
"        labels.append(f\"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\")\n",
"\n",
"    # Write label file (an empty file marks a frame with no objects)\n",
"    label_path = Path(labels_dir) / (Path(frame_name).stem + '.txt')\n",
"\n",
"    with open(label_path, 'w') as f:\n",
"        f.write('\\n'.join(labels))\n",
"\n",
"    return len(labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Split frames into train/val\n",
|
||||
"frame_names = list(annotations.keys())\n",
|
||||
"random.shuffle(frame_names)\n",
|
||||
"\n",
|
||||
"split_idx = int(len(frame_names) * (1 - VAL_SPLIT))\n",
|
||||
"train_frames = frame_names[:split_idx]\n",
|
||||
"val_frames = frame_names[split_idx:]\n",
|
||||
"\n",
|
||||
"print(f\"Train frames: {len(train_frames)}\")\n",
|
||||
"print(f\"Val frames: {len(val_frames)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Process training frames\n",
|
||||
"train_objects = 0\n",
|
||||
"for frame_name in tqdm(train_frames, desc=\"Processing train\"):\n",
|
||||
" count = process_frame(\n",
|
||||
" frame_name,\n",
|
||||
" annotations.get(frame_name, []),\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" images_train,\n",
|
||||
" labels_train,\n",
|
||||
" class_id=0\n",
|
||||
" )\n",
|
||||
" train_objects += count\n",
|
||||
"\n",
|
||||
"print(f\"\\nTrain: {len(train_frames)} images, {train_objects} objects\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Process validation frames\n",
|
||||
"val_objects = 0\n",
|
||||
"for frame_name in tqdm(val_frames, desc=\"Processing val\"):\n",
|
||||
" count = process_frame(\n",
|
||||
" frame_name,\n",
|
||||
" annotations.get(frame_name, []),\n",
|
||||
" FRAMES_DIR,\n",
|
||||
" images_val,\n",
|
||||
" labels_val,\n",
|
||||
" class_id=0\n",
|
||||
" )\n",
|
||||
" val_objects += count\n",
|
||||
"\n",
|
||||
"print(f\"\\nVal: {len(val_frames)} images, {val_objects} objects\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Create data.yaml"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create YOLO data.yaml configuration\n",
|
||||
"data_config = {\n",
|
||||
" 'path': str(Path(DATASET_DIR).absolute()),\n",
|
||||
" 'train': 'images/train',\n",
|
||||
" 'val': 'images/val',\n",
|
||||
" 'names': {i: name for i, name in enumerate(CLASS_NAMES)},\n",
|
||||
" 'nc': len(CLASS_NAMES)\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"yaml_path = dataset_path / 'data.yaml'\n",
|
||||
"with open(yaml_path, 'w') as f:\n",
|
||||
" yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)\n",
|
||||
"\n",
|
||||
"print(f\"Created: {yaml_path}\")\n",
|
||||
"print(\"\\nContents:\")\n",
|
||||
"with open(yaml_path) as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Validate Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def validate_dataset(dataset_dir):\n",
|
||||
" \"\"\"Validate YOLO dataset structure.\"\"\"\n",
|
||||
" dataset_path = Path(dataset_dir)\n",
|
||||
" results = {'valid': True, 'errors': [], 'warnings': [], 'stats': {}}\n",
|
||||
" \n",
|
||||
" # Check data.yaml\n",
|
||||
" yaml_path = dataset_path / 'data.yaml'\n",
|
||||
" if not yaml_path.exists():\n",
|
||||
" results['errors'].append(\"Missing data.yaml\")\n",
|
||||
" results['valid'] = False\n",
|
||||
" else:\n",
|
||||
" with open(yaml_path) as f:\n",
|
||||
" config = yaml.safe_load(f)\n",
|
||||
" results['stats']['num_classes'] = config.get('nc', 0)\n",
|
||||
" results['stats']['class_names'] = config.get('names', {})\n",
|
||||
" \n",
|
||||
" # Check directories and count files\n",
|
||||
" for split in ['train', 'val']:\n",
|
||||
" images_dir = dataset_path / 'images' / split\n",
|
||||
" labels_dir = dataset_path / 'labels' / split\n",
|
||||
" \n",
|
||||
" if not images_dir.exists():\n",
|
||||
" results['errors'].append(f\"Missing images/{split}\")\n",
|
||||
" results['valid'] = False\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" image_files = list(images_dir.glob(\"*.jpg\")) + list(images_dir.glob(\"*.png\"))\n",
|
||||
" label_files = list(labels_dir.glob(\"*.txt\"))\n",
|
||||
" \n",
|
||||
" results['stats'][f'{split}_images'] = len(image_files)\n",
|
||||
" results['stats'][f'{split}_labels'] = len(label_files)\n",
|
||||
" \n",
|
||||
" # Check for missing labels\n",
|
||||
" image_stems = {f.stem for f in image_files}\n",
|
||||
" label_stems = {f.stem for f in label_files}\n",
|
||||
" missing = image_stems - label_stems\n",
|
||||
" \n",
|
||||
" if missing:\n",
|
||||
" results['warnings'].append(f\"{len(missing)} {split} images missing labels\")\n",
|
||||
" \n",
|
||||
" # Count total objects\n",
|
||||
" total_objects = 0\n",
|
||||
" for label_file in label_files:\n",
|
||||
" with open(label_file) as f:\n",
|
||||
" lines = [l.strip() for l in f if l.strip()]\n",
|
||||
" total_objects += len(lines)\n",
|
||||
" results['stats'][f'{split}_objects'] = total_objects\n",
|
||||
" \n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"# Validate\n",
|
||||
"validation = validate_dataset(DATASET_DIR)\n",
|
||||
"\n",
|
||||
"print(\"Dataset Validation:\")\n",
|
||||
"print(f\" Valid: {validation['valid']}\")\n",
|
||||
"print(f\"\\nStatistics:\")\n",
|
||||
"for key, value in validation['stats'].items():\n",
|
||||
" print(f\" {key}: {value}\")\n",
|
||||
"\n",
|
||||
"if validation['errors']:\n",
|
||||
" print(f\"\\nErrors:\")\n",
|
||||
" for err in validation['errors']:\n",
|
||||
" print(f\" - {err}\")\n",
|
||||
"\n",
|
||||
"if validation['warnings']:\n",
|
||||
" print(f\"\\nWarnings:\")\n",
|
||||
" for warn in validation['warnings']:\n",
|
||||
" print(f\" - {warn}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Visualize Dataset Samples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def visualize_yolo_sample(image_path, label_path, class_names):\n",
|
||||
" \"\"\"Visualize YOLO annotation on image.\"\"\"\n",
|
||||
" image = cv2.imread(str(image_path))\n",
|
||||
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
|
||||
" height, width = image.shape[:2]\n",
|
||||
" \n",
|
||||
" # Read labels\n",
|
||||
" if label_path.exists():\n",
|
||||
" with open(label_path) as f:\n",
|
||||
" labels = [l.strip().split() for l in f if l.strip()]\n",
|
||||
" else:\n",
|
||||
" labels = []\n",
|
||||
" \n",
|
||||
" # Draw bboxes\n",
|
||||
" colors = plt.cm.tab10(np.linspace(0, 1, 10))\n",
|
||||
" \n",
|
||||
" for label in labels:\n",
|
||||
" class_id = int(label[0])\n",
|
||||
" x_center, y_center, w, h = map(float, label[1:5])\n",
|
||||
" \n",
|
||||
" # Convert to pixel coordinates\n",
|
||||
" x1 = int((x_center - w/2) * width)\n",
|
||||
" y1 = int((y_center - h/2) * height)\n",
|
||||
" x2 = int((x_center + w/2) * width)\n",
|
||||
" y2 = int((y_center + h/2) * height)\n",
|
||||
" \n",
|
||||
" color = tuple(int(c * 255) for c in colors[class_id % 10][:3])\n",
|
||||
" cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)\n",
|
||||
" \n",
|
||||
" # Add label\n",
|
||||
" class_name = class_names.get(class_id, str(class_id))\n",
|
||||
" cv2.putText(image, class_name, (x1, y1-5), \n",
|
||||
" cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
|
||||
" \n",
|
||||
" return image, len(labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize train samples\n",
|
||||
"train_images = sorted(images_train.glob(\"*.jpg\"))[:6]\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
|
||||
"class_names_dict = {i: name for i, name in enumerate(CLASS_NAMES)}\n",
|
||||
"\n",
|
||||
"for ax, img_path in zip(axes.flat, train_images):\n",
|
||||
" label_path = labels_train / (img_path.stem + '.txt')\n",
|
||||
" vis, count = visualize_yolo_sample(img_path, label_path, class_names_dict)\n",
|
||||
" \n",
|
||||
" ax.imshow(vis)\n",
|
||||
" ax.set_title(f\"{img_path.name} ({count} objects)\")\n",
|
||||
" ax.axis('off')\n",
|
||||
"\n",
|
||||
"plt.suptitle('Training Samples with YOLO Annotations', fontsize=14)\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Object count distribution\n",
|
||||
"train_label_files = list(labels_train.glob(\"*.txt\"))\n",
|
||||
"val_label_files = list(labels_val.glob(\"*.txt\"))\n",
|
||||
"\n",
|
||||
"def count_objects_in_labels(label_files):\n",
|
||||
" counts = []\n",
|
||||
" for lf in label_files:\n",
|
||||
" with open(lf) as f:\n",
|
||||
" lines = [l.strip() for l in f if l.strip()]\n",
|
||||
" counts.append(len(lines))\n",
|
||||
" return counts\n",
|
||||
"\n",
|
||||
"train_counts = count_objects_in_labels(train_label_files)\n",
|
||||
"val_counts = count_objects_in_labels(val_label_files)\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
|
||||
"\n",
|
||||
"axes[0].hist(train_counts, bins=30, edgecolor='black', alpha=0.7, label='Train')\n",
|
||||
"axes[0].hist(val_counts, bins=30, edgecolor='black', alpha=0.7, label='Val')\n",
|
||||
"axes[0].set_xlabel('Objects per image')\n",
|
||||
"axes[0].set_ylabel('Frequency')\n",
|
||||
"axes[0].set_title('Objects per Image Distribution')\n",
|
||||
"axes[0].legend()\n",
|
||||
"\n",
|
||||
"# Bbox area distribution (normalized w * h, training labels only)\n",
|
||||
"bbox_sizes = []\n",
|
||||
"for lf in train_label_files:\n",
|
||||
" with open(lf) as f:\n",
|
||||
" for line in f:\n",
|
||||
" parts = line.strip().split()\n",
|
||||
" if len(parts) >= 5:\n",
|
||||
" w, h = float(parts[3]), float(parts[4])\n",
|
||||
" bbox_sizes.append(w * h)\n",
|
||||
"\n",
|
||||
"axes[1].hist(bbox_sizes, bins=50, edgecolor='black', alpha=0.7)\n",
|
||||
"axes[1].set_xlabel('Bbox area (normalized)')\n",
|
||||
"axes[1].set_ylabel('Frequency')\n",
|
||||
"axes[1].set_title('Bounding Box Size Distribution')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"print(f\"\\nBbox size stats:\")\n",
|
||||
"print(f\" Min: {min(bbox_sizes):.4f}\")\n",
|
||||
"print(f\" Max: {max(bbox_sizes):.4f}\")\n",
|
||||
"print(f\" Mean: {np.mean(bbox_sizes):.4f}\")\n",
|
||||
"print(f\" Median: {np.median(bbox_sizes):.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Export for Kaggle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create zip archive for Kaggle dataset\n",
|
||||
"import zipfile\n",
|
||||
"\n",
|
||||
"EXPORT_ZIP = 'yolo_dataset.zip'\n",
|
||||
"\n",
|
||||
"print(f\"Creating {EXPORT_ZIP}...\")\n",
|
||||
"\n",
|
||||
"with zipfile.ZipFile(EXPORT_ZIP, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
|
||||
" for root, dirs, files in os.walk(DATASET_DIR):\n",
|
||||
" for file in files:\n",
|
||||
" file_path = os.path.join(root, file)\n",
|
||||
" arcname = os.path.relpath(file_path, os.path.dirname(DATASET_DIR))\n",
|
||||
" zipf.write(file_path, arcname)\n",
|
||||
"\n",
|
||||
"zip_size = os.path.getsize(EXPORT_ZIP) / 1024 / 1024\n",
|
||||
"print(f\"\\nExport complete!\")\n",
|
||||
"print(f\" File: {EXPORT_ZIP}\")\n",
|
||||
"print(f\" Size: {zip_size:.1f} MB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Alternative: Create Kaggle dataset directly (if kaggle CLI available)\n",
|
||||
"# Uncomment to use\n",
|
||||
"\n",
|
||||
"# KAGGLE_USERNAME = 'your-username'\n",
|
||||
"# DATASET_NAME = 'sam2-yolo-custom'\n",
|
||||
"\n",
|
||||
"# # Create dataset metadata\n",
|
||||
"# metadata = {\n",
|
||||
"# 'title': 'SAM2 Auto-Annotated YOLO Dataset',\n",
|
||||
"# 'id': f'{KAGGLE_USERNAME}/{DATASET_NAME}',\n",
|
||||
"# 'licenses': [{'name': 'CC0-1.0'}]\n",
|
||||
"# }\n",
|
||||
"\n",
|
||||
"# metadata_path = dataset_path / 'dataset-metadata.json'\n",
|
||||
"# with open(metadata_path, 'w') as f:\n",
|
||||
"# json.dump(metadata, f, indent=2)\n",
|
||||
"\n",
|
||||
"# # Upload to Kaggle\n",
|
||||
"# !kaggle datasets create -p {DATASET_DIR}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 9. Dataset Summary"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Final summary\n",
|
||||
"print(\"=\" * 50)\n",
|
||||
"print(\"YOLO DATASET SUMMARY\")\n",
|
||||
"print(\"=\" * 50)\n",
|
||||
"print(f\"\\nDataset location: {Path(DATASET_DIR).absolute()}\")\n",
|
||||
"print(f\"\\nClasses ({len(CLASS_NAMES)}):\")\n",
|
||||
"for i, name in enumerate(CLASS_NAMES):\n",
|
||||
" print(f\" {i}: {name}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nSplit:\")\n",
|
||||
"print(f\" Train: {validation['stats']['train_images']} images, {validation['stats']['train_objects']} objects\")\n",
|
||||
"print(f\" Val: {validation['stats']['val_images']} images, {validation['stats']['val_objects']} objects\")\n",
|
||||
"print(f\" Total: {validation['stats']['train_images'] + validation['stats']['val_images']} images\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nFiles:\")\n",
|
||||
"print(f\" data.yaml: {yaml_path}\")\n",
|
||||
"print(f\" Export: {EXPORT_ZIP}\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 50)\n",
|
||||
"print(\"Ready for YOLOv9t training!\")\n",
|
||||
"print(\"=\" * 50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## Next Steps\n",
|
||||
"\n",
|
||||
"1. **Upload dataset to Kaggle** (if not already done)\n",
|
||||
" - Go to kaggle.com/datasets/new\n",
|
||||
" - Upload `yolo_dataset.zip`\n",
|
||||
" \n",
|
||||
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t\n",
|
||||
"\n",
|
||||
"### Dataset Structure\n",
|
||||
"```\n",
|
||||
"yolo_dataset/\n",
|
||||
"├── data.yaml\n",
|
||||
"├── images/\n",
|
||||
"│ ├── train/\n",
|
||||
"│ └── val/\n",
|
||||
"└── labels/\n",
|
||||
" ├── train/\n",
|
||||
" └── val/\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,729 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Train YOLOv9t on Custom Dataset\n",
|
||||
"\n",
|
||||
"Train the YOLOv9t (tiny) model on the YOLO-format dataset created from SAM2 annotations.\n",
|
||||
"\n",
|
||||
"## Input\n",
|
||||
"- YOLO format dataset from `02_create_yolo_dataset.ipynb`\n",
|
||||
"\n",
|
||||
"## Output\n",
|
||||
"- Trained YOLOv9t model weights\n",
|
||||
"- Training metrics and visualizations\n",
|
||||
"\n",
|
||||
"**Platform:** Kaggle GPU (P100/T4) - Enable GPU in notebook settings!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Setup Environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check GPU\n",
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install Ultralytics (the package ships YOLOv9 models, including yolov9t)\n",
|
||||
"!pip install -q ultralytics\n",
|
||||
"\n",
|
||||
"# Alternative: Install official YOLOv9 repo\n",
|
||||
"# !git clone https://github.com/WongKinYiu/yolov9.git\n",
|
||||
"# %cd yolov9\n",
|
||||
"# !pip install -q -r requirements.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import yaml\n",
|
||||
"import shutil\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from pathlib import Path\n",
|
||||
"from datetime import datetime\n",
|
||||
"from IPython.display import Image, display\n",
|
||||
"\n",
|
||||
"print(f\"Python: {sys.version}\")\n",
|
||||
"print(f\"PyTorch: {torch.__version__}\")\n",
|
||||
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
||||
" print(f\"CUDA version: {torch.version.cuda}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ultralytics import YOLO\n",
|
||||
"import ultralytics\n",
|
||||
"print(f\"Ultralytics version: {ultralytics.__version__}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Dataset configuration - UPDATE THIS PATH\n",
|
||||
"# Option 1: From previous notebook (local)\n",
|
||||
"DATASET_PATH = './yolo_dataset'\n",
|
||||
"\n",
|
||||
"# Option 2: From Kaggle dataset (uncomment and update)\n",
|
||||
"# DATASET_PATH = '/kaggle/input/your-dataset-name/yolo_dataset'\n",
|
||||
"\n",
|
||||
"# Training configuration\n",
|
||||
"CONFIG = {\n",
|
||||
" # Model\n",
|
||||
" 'model': 'yolov9t.pt', # Pretrained YOLOv9t (tiny)\n",
|
||||
" \n",
|
||||
" # Training parameters\n",
|
||||
" 'epochs': 100, # Number of epochs\n",
|
||||
" 'batch': 16, # Batch size (adjust based on GPU memory)\n",
|
||||
" 'imgsz': 640, # Image size\n",
|
||||
" 'patience': 20, # Early stopping patience\n",
|
||||
" \n",
|
||||
" # Optimizer\n",
|
||||
" 'optimizer': 'AdamW', # Optimizer: SGD, Adam, AdamW\n",
|
||||
" 'lr0': 0.001, # Initial learning rate\n",
|
||||
" 'lrf': 0.01, # Final learning rate factor\n",
|
||||
" 'momentum': 0.937, # SGD momentum\n",
|
||||
" 'weight_decay': 0.0005, # Weight decay\n",
|
||||
" \n",
|
||||
" # Augmentation\n",
|
||||
" 'hsv_h': 0.015, # HSV-Hue augmentation\n",
|
||||
" 'hsv_s': 0.7, # HSV-Saturation\n",
|
||||
" 'hsv_v': 0.4, # HSV-Value\n",
|
||||
" 'degrees': 0.0, # Rotation\n",
|
||||
" 'translate': 0.1, # Translation\n",
|
||||
" 'scale': 0.5, # Scale\n",
|
||||
" 'shear': 0.0, # Shear\n",
|
||||
" 'flipud': 0.0, # Flip up-down\n",
|
||||
" 'fliplr': 0.5, # Flip left-right\n",
|
||||
" 'mosaic': 1.0, # Mosaic augmentation\n",
|
||||
" 'mixup': 0.0, # Mixup augmentation\n",
|
||||
" \n",
|
||||
" # Other\n",
|
||||
" 'workers': 4, # DataLoader workers\n",
|
||||
" 'device': 0, # GPU device (0 for first GPU)\n",
|
||||
" 'project': 'runs/train', # Output directory\n",
|
||||
" 'name': 'yolov9t_custom', # Experiment name\n",
|
||||
" 'exist_ok': True, # Overwrite existing\n",
|
||||
" 'pretrained': True, # Use pretrained weights\n",
|
||||
" 'verbose': True, # Verbose output\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(\"Configuration loaded!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Prepare Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check dataset exists\n",
|
||||
"dataset_path = Path(DATASET_PATH)\n",
|
||||
"\n",
|
||||
"if not dataset_path.exists():\n",
|
||||
" print(f\"Dataset not found: {dataset_path}\")\n",
|
||||
" print(\"Please update DATASET_PATH or upload your dataset.\")\n",
|
||||
"else:\n",
|
||||
" print(f\"Dataset found: {dataset_path}\")\n",
|
||||
" \n",
|
||||
" # List contents\n",
|
||||
" print(\"\\nContents:\")\n",
|
||||
" for item in dataset_path.iterdir():\n",
|
||||
" if item.is_dir():\n",
|
||||
" count = len(list(item.rglob('*')))\n",
|
||||
" print(f\" {item.name}/ ({count} files)\")\n",
|
||||
" else:\n",
|
||||
" print(f\" {item.name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load and display data.yaml; rewrite it with an absolute dataset path when the stored path is relative\n",
|
||||
"data_yaml = dataset_path / 'data.yaml'\n",
|
||||
"\n",
|
||||
"if data_yaml.exists():\n",
|
||||
" with open(data_yaml) as f:\n",
|
||||
" data_config = yaml.safe_load(f)\n",
|
||||
" \n",
|
||||
" print(\"data.yaml contents:\")\n",
|
||||
" print(yaml.dump(data_config, default_flow_style=False))\n",
|
||||
" \n",
|
||||
" # Update path to absolute if needed\n",
|
||||
" if not Path(data_config.get('path', '')).is_absolute():\n",
|
||||
" data_config['path'] = str(dataset_path.absolute())\n",
|
||||
" \n",
|
||||
" # Save updated config\n",
|
||||
" with open(data_yaml, 'w') as f:\n",
|
||||
" yaml.dump(data_config, f, default_flow_style=False)\n",
|
||||
" print(\"\\nUpdated path to absolute.\")\n",
|
||||
"else:\n",
|
||||
" print(f\"data.yaml not found: {data_yaml}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Count images and labels\n",
|
||||
"train_images = len(list((dataset_path / 'images' / 'train').glob('*')))\n",
|
||||
"val_images = len(list((dataset_path / 'images' / 'val').glob('*')))\n",
|
||||
"train_labels = len(list((dataset_path / 'labels' / 'train').glob('*.txt')))\n",
|
||||
"val_labels = len(list((dataset_path / 'labels' / 'val').glob('*.txt')))\n",
|
||||
"\n",
|
||||
"print(\"Dataset Statistics:\")\n",
|
||||
"print(f\" Train images: {train_images}\")\n",
|
||||
"print(f\" Train labels: {train_labels}\")\n",
|
||||
"print(f\" Val images: {val_images}\")\n",
|
||||
"print(f\" Val labels: {val_labels}\")\n",
|
||||
"print(f\" Classes: {data_config.get('nc', 'unknown')}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Load YOLOv9t Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load pretrained YOLOv9t model\n",
|
||||
"model = YOLO(CONFIG['model'])\n",
|
||||
"\n",
|
||||
"print(f\"Model: {CONFIG['model']}\")\n",
|
||||
"print(f\"Task: {model.task}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Model information\n",
|
||||
"model.info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Train Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start training\n",
|
||||
"print(\"Starting training...\")\n",
|
||||
"print(f\" Dataset: {data_yaml}\")\n",
|
||||
"print(f\" Epochs: {CONFIG['epochs']}\")\n",
|
||||
"print(f\" Batch size: {CONFIG['batch']}\")\n",
|
||||
"print(f\" Image size: {CONFIG['imgsz']}\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"results = model.train(\n",
|
||||
" data=str(data_yaml),\n",
|
||||
" epochs=CONFIG['epochs'],\n",
|
||||
" batch=CONFIG['batch'],\n",
|
||||
" imgsz=CONFIG['imgsz'],\n",
|
||||
" patience=CONFIG['patience'],\n",
|
||||
" optimizer=CONFIG['optimizer'],\n",
|
||||
" lr0=CONFIG['lr0'],\n",
|
||||
" lrf=CONFIG['lrf'],\n",
|
||||
" momentum=CONFIG['momentum'],\n",
|
||||
" weight_decay=CONFIG['weight_decay'],\n",
|
||||
" hsv_h=CONFIG['hsv_h'],\n",
|
||||
" hsv_s=CONFIG['hsv_s'],\n",
|
||||
" hsv_v=CONFIG['hsv_v'],\n",
|
||||
" degrees=CONFIG['degrees'],\n",
|
||||
" translate=CONFIG['translate'],\n",
|
||||
" scale=CONFIG['scale'],\n",
|
||||
" shear=CONFIG['shear'],\n",
|
||||
" flipud=CONFIG['flipud'],\n",
|
||||
" fliplr=CONFIG['fliplr'],\n",
|
||||
" mosaic=CONFIG['mosaic'],\n",
|
||||
" mixup=CONFIG['mixup'],\n",
|
||||
" workers=CONFIG['workers'],\n",
|
||||
" device=CONFIG['device'],\n",
|
||||
" project=CONFIG['project'],\n",
|
||||
" name=CONFIG['name'],\n",
|
||||
" exist_ok=CONFIG['exist_ok'],\n",
|
||||
" pretrained=CONFIG['pretrained'],\n",
|
||||
" verbose=CONFIG['verbose'],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"\\nTraining complete!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Training Results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Find training output directory\n",
|
||||
"train_dir = Path(CONFIG['project']) / CONFIG['name']\n",
|
||||
"\n",
|
||||
"print(f\"Training output: {train_dir}\")\n",
|
||||
"print(\"\\nContents:\")\n",
|
||||
"for item in sorted(train_dir.iterdir()):\n",
|
||||
" print(f\" {item.name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display training curves\n",
|
||||
"results_png = train_dir / 'results.png'\n",
|
||||
"if results_png.exists():\n",
|
||||
" display(Image(filename=str(results_png), width=1000))\n",
|
||||
"else:\n",
|
||||
" print(\"results.png not found\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display confusion matrix\n",
|
||||
"confusion_matrix = train_dir / 'confusion_matrix.png'\n",
|
||||
"if confusion_matrix.exists():\n",
|
||||
" display(Image(filename=str(confusion_matrix), width=600))\n",
|
||||
"else:\n",
|
||||
" print(\"confusion_matrix.png not found\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display F1 curve\n",
|
||||
"f1_curve = train_dir / 'F1_curve.png'\n",
|
||||
"if f1_curve.exists():\n",
|
||||
" display(Image(filename=str(f1_curve), width=600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display PR curve\n",
|
||||
"pr_curve = train_dir / 'PR_curve.png'\n",
|
||||
"if pr_curve.exists():\n",
|
||||
" display(Image(filename=str(pr_curve), width=600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display sample predictions\n",
|
||||
"val_batch = train_dir / 'val_batch0_pred.jpg'\n",
|
||||
"if val_batch.exists():\n",
|
||||
" print(\"Validation batch predictions:\")\n",
|
||||
" display(Image(filename=str(val_batch), width=1000))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Evaluate Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load best weights\n",
|
||||
"best_weights = train_dir / 'weights' / 'best.pt'\n",
|
||||
"last_weights = train_dir / 'weights' / 'last.pt'\n",
|
||||
"\n",
|
||||
"print(f\"Best weights: {best_weights}\")\n",
|
||||
"print(f\" Size: {best_weights.stat().st_size / 1024 / 1024:.1f} MB\")\n",
|
||||
"\n",
|
||||
"# Load best model\n",
|
||||
"best_model = YOLO(str(best_weights))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Evaluate on validation set\n",
|
||||
"print(\"Evaluating on validation set...\")\n",
|
||||
"metrics = best_model.val(data=str(data_yaml))\n",
|
||||
"\n",
|
||||
"print(\"\\nValidation Metrics:\")\n",
|
||||
"print(f\" mAP50: {metrics.box.map50:.4f}\")\n",
|
||||
"print(f\" mAP50-95: {metrics.box.map:.4f}\")\n",
|
||||
"print(f\" Precision: {metrics.box.mp:.4f}\")\n",
|
||||
"print(f\" Recall: {metrics.box.mr:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Test Inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test on validation images\n",
|
||||
"val_images_dir = dataset_path / 'images' / 'val'\n",
|
||||
"test_images = list(val_images_dir.glob('*.jpg'))[:4]\n",
|
||||
"\n",
|
||||
"if test_images:\n",
|
||||
" print(f\"Testing on {len(test_images)} images...\")\n",
|
||||
" \n",
|
||||
" results = best_model.predict(\n",
|
||||
" source=test_images,\n",
|
||||
" conf=0.25,\n",
|
||||
" save=True,\n",
|
||||
" project='runs/predict',\n",
|
||||
" name='test_inference'\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Display results\n",
|
||||
" predict_dir = Path('runs/predict/test_inference')\n",
|
||||
" for img_path in sorted(predict_dir.glob('*.jpg'))[:4]:\n",
|
||||
" print(f\"\\n{img_path.name}\")\n",
|
||||
" display(Image(filename=str(img_path), width=600))\n",
|
||||
"else:\n",
|
||||
" print(\"No validation images found\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Inference speed test\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"if test_images:\n",
|
||||
" test_img = str(test_images[0])\n",
|
||||
" \n",
|
||||
" # Warmup\n",
|
||||
" for _ in range(3):\n",
|
||||
" _ = best_model.predict(test_img, verbose=False)\n",
|
||||
" \n",
|
||||
" # Benchmark\n",
|
||||
" times = []\n",
|
||||
" for _ in range(10):\n",
|
||||
" start = time.time()\n",
|
||||
" _ = best_model.predict(test_img, verbose=False)\n",
|
||||
" times.append(time.time() - start)\n",
|
||||
" \n",
|
||||
" avg_time = np.mean(times) * 1000\n",
|
||||
" fps = 1000 / avg_time\n",
|
||||
" \n",
|
||||
" print(f\"Inference speed:\")\n",
|
||||
" print(f\" Average: {avg_time:.1f} ms\")\n",
|
||||
" print(f\" FPS: {fps:.1f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 9. Export Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Export to ONNX\n",
|
||||
"print(\"Exporting to ONNX...\")\n",
|
||||
"onnx_path = best_model.export(format='onnx', simplify=True)\n",
|
||||
"print(f\"ONNX exported: {onnx_path}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Export best_model to TorchScript and print the path returned by export()\n",
|
||||
"print(\"Exporting to TorchScript...\")\n",
|
||||
"torchscript_path = best_model.export(format='torchscript')\n",
|
||||
"print(f\"TorchScript exported: {torchscript_path}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optional: Export to other formats\n",
|
||||
"# TensorRT (requires TensorRT installation)\n",
|
||||
"# engine_path = best_model.export(format='engine')\n",
|
||||
"\n",
|
||||
"# OpenVINO\n",
|
||||
"# openvino_path = best_model.export(format='openvino')\n",
|
||||
"\n",
|
||||
"# CoreML (macOS)\n",
|
||||
"# coreml_path = best_model.export(format='coreml')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 10. Save and Download"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create output archive with the trained weights, exported ONNX model, result plots/CSVs and args.yaml\n",
|
||||
"import zipfile\n",
|
||||
"\n",
|
||||
"OUTPUT_ZIP = 'yolov9t_trained.zip'\n",
|
||||
"\n",
|
||||
"print(f\"Creating {OUTPUT_ZIP}...\")\n",
|
||||
"\n",
|
||||
"with zipfile.ZipFile(OUTPUT_ZIP, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
|
||||
"    # Add weights. Each checkpoint is checked first, like the ONNX and args files below,\n",
|
||||
"    # so one missing file does not abort the cell and leave a partial archive behind.\n",
|
||||
"    for weights_file, arcname in ((best_weights, 'weights/best.pt'),\n",
|
||||
"                                  (last_weights, 'weights/last.pt')):\n",
|
||||
"        if os.path.exists(weights_file):\n",
|
||||
"            zipf.write(weights_file, arcname)\n",
|
||||
"        else:\n",
|
||||
"            print(f\"Warning: {weights_file} not found, skipping\")\n",
|
||||
"\n",
|
||||
"    # Add ONNX if exists\n",
|
||||
"    onnx_file = best_weights.with_suffix('.onnx')\n",
|
||||
"    if onnx_file.exists():\n",
|
||||
"        zipf.write(onnx_file, 'weights/best.onnx')\n",
|
||||
"\n",
|
||||
"    # Add results (plots first, then CSV logs) under results/\n",
|
||||
"    for pattern in ('*.png', '*.csv'):\n",
|
||||
"        for result_file in train_dir.glob(pattern):\n",
|
||||
"            zipf.write(result_file, f'results/{result_file.name}')\n",
|
||||
"\n",
|
||||
"    # Add args\n",
|
||||
"    args_file = train_dir / 'args.yaml'\n",
|
||||
"    if args_file.exists():\n",
|
||||
"        zipf.write(args_file, 'args.yaml')\n",
|
||||
"\n",
|
||||
"zip_size = os.path.getsize(OUTPUT_ZIP) / 1024 / 1024  # bytes -> MB\n",
|
||||
"print(f\"\\nExport complete!\")\n",
|
||||
"print(f\" File: {OUTPUT_ZIP}\")\n",
|
||||
"print(f\" Size: {zip_size:.1f} MB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Print the training directory path, then list the files in its weights/ subfolder via the shell\n",
|
||||
"print(\"\\nAll output files:\")\n",
|
||||
"print(f\"\\nTraining directory: {train_dir}\")\n",
|
||||
"!ls -la {train_dir}/weights/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Final summary: print the dataset, training settings, detection metrics and output paths.\n",
|
||||
"# Uses names defined outside this cell: dataset_path, train_images, val_images,\n",
|
||||
"# data_config, CONFIG, metrics, best_weights and OUTPUT_ZIP.\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"TRAINING SUMMARY\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(f\"\\nModel: YOLOv9t\")\n",
|
||||
"print(f\"Dataset: {dataset_path}\")\n",
|
||||
"print(f\" Train images: {train_images}\")\n",
|
||||
"print(f\" Val images: {val_images}\")\n",
|
||||
"print(f\" Classes: {data_config.get('nc', 'unknown')}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nTraining:\")\n",
|
||||
"print(f\" Epochs: {CONFIG['epochs']}\")\n",
|
||||
"print(f\" Batch size: {CONFIG['batch']}\")\n",
|
||||
"print(f\" Image size: {CONFIG['imgsz']}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nResults:\")\n",
|
||||
"print(f\" mAP50: {metrics.box.map50:.4f}\")\n",
|
||||
"print(f\" mAP50-95: {metrics.box.map:.4f}\")\n",
|
||||
"print(f\" Precision: {metrics.box.mp:.4f}\")\n",
|
||||
"print(f\" Recall: {metrics.box.mr:.4f}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nOutput files:\")\n",
|
||||
"print(f\" Best weights: {best_weights}\")\n",
|
||||
"print(f\" Export archive: {OUTPUT_ZIP}\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 60)\n",
|
||||
"print(\"Training complete! Download weights for deployment.\")\n",
|
||||
"print(\"=\" * 60)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## Usage Example\n",
|
||||
"\n",
|
||||
"After training, use the model for inference:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from ultralytics import YOLO\n",
|
||||
"\n",
|
||||
"# Load trained model\n",
|
||||
"model = YOLO('best.pt')\n",
|
||||
"\n",
|
||||
"# Inference on image\n",
|
||||
"results = model.predict('image.jpg', conf=0.25)\n",
|
||||
"\n",
|
||||
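"# Inference on video (for long videos add stream=True so results are yielded frame by frame instead of all being held in memory)\n",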
|
||||
"results = model.predict('video.mp4', conf=0.25, save=True)\n",
|
||||
"\n",
|
||||
"# Access detections\n",
|
||||
"for result in results:\n",
|
||||
" boxes = result.boxes\n",
|
||||
" for box in boxes:\n",
|
||||
" x1, y1, x2, y2 = box.xyxy[0]\n",
|
||||
" confidence = box.conf[0]\n",
|
||||
" class_id = box.cls[0]\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"## Tips\n",
|
||||
"\n",
|
||||
"- **Low mAP?** Try:\n",
|
||||
" - More training epochs\n",
|
||||
" - Data augmentation adjustments\n",
|
||||
" - Lower learning rate\n",
|
||||
" - More training data\n",
|
||||
"\n",
|
||||
"- **Overfitting?** Try:\n",
|
||||
" - More augmentation\n",
|
||||
" - Dropout/regularization\n",
|
||||
" - Early stopping (patience)\n",
|
||||
"\n",
|
||||
"- **Slow training?** Try:\n",
|
||||
" - Larger batch size (if GPU memory allows)\n",
|
||||
" - Mixed precision (amp=True)\n",
|
||||
" - Smaller image size"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user