add sam2 yolo auto annotation

This commit is contained in:
2026-02-04 15:29:36 +07:00
parent 7e56948ece
commit 5a951d8812
2061 changed files with 316473 additions and 0 deletions
@@ -0,0 +1,687 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SAM2 Video Annotation Pipeline\n",
"\n",
"This notebook sets up SAM2 (Segment Anything Model 2) for automatic object annotation from video.\n",
"\n",
"## Features\n",
"- Install SAM2 and dependencies on Kaggle\n",
"- Download pretrained model weights\n",
"- Extract frames from video\n",
"- Auto-generate masks for all objects\n",
"- Save annotations for YOLO conversion\n",
"\n",
"**Platform:** Kaggle GPU (P100/T4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup Environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check GPU availability\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install SAM2 and dependencies\n",
"!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git\n",
"!pip install -q opencv-python-headless supervision tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import cv2\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from tqdm.notebook import tqdm\n",
"from IPython.display import display, Image as IPImage\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Download SAM2 Pretrained Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration\n",
"MODEL_SIZE = 'large' # Options: 'tiny', 'small', 'base_plus', 'large'\n",
"CHECKPOINT_DIR = './checkpoints'\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# Model configurations\n",
"MODEL_CONFIGS = {\n",
" 'tiny': {\n",
" 'config': 'sam2_hiera_t.yaml',\n",
" 'checkpoint': 'sam2_hiera_tiny.pt',\n",
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt'\n",
" },\n",
" 'small': {\n",
" 'config': 'sam2_hiera_s.yaml',\n",
" 'checkpoint': 'sam2_hiera_small.pt',\n",
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt'\n",
" },\n",
" 'base_plus': {\n",
" 'config': 'sam2_hiera_b+.yaml',\n",
" 'checkpoint': 'sam2_hiera_base_plus.pt',\n",
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt'\n",
" },\n",
" 'large': {\n",
" 'config': 'sam2_hiera_l.yaml',\n",
" 'checkpoint': 'sam2_hiera_large.pt',\n",
" 'url': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'\n",
" }\n",
"}\n",
"\n",
"print(f\"Using model: SAM2 {MODEL_SIZE}\")\n",
"print(f\"Device: {DEVICE}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download checkpoint\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"config = MODEL_CONFIGS[MODEL_SIZE]\n",
"checkpoint_path = os.path.join(CHECKPOINT_DIR, config['checkpoint'])\n",
"\n",
"if not os.path.exists(checkpoint_path):\n",
" print(f\"Downloading {config['checkpoint']}...\")\n",
" !wget -q -O {checkpoint_path} {config['url']}\n",
" print(\"Download complete!\")\n",
"else:\n",
" print(f\"Checkpoint exists: {checkpoint_path}\")\n",
"\n",
"print(f\"Checkpoint size: {os.path.getsize(checkpoint_path) / 1024 / 1024:.1f} MB\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load SAM2 Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sam2.build_sam import build_sam2\n",
"from sam2.sam2_image_predictor import SAM2ImagePredictor\n",
"from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator\n",
"\n",
"print(\"Loading SAM2 model...\")\n",
"\n",
"# Build model\n",
"sam2_model = build_sam2(\n",
" config['config'],\n",
" checkpoint_path,\n",
" device=DEVICE\n",
")\n",
"\n",
"# Create automatic mask generator\n",
"mask_generator = SAM2AutomaticMaskGenerator(\n",
" model=sam2_model,\n",
" points_per_side=32, # Grid density for point prompts\n",
" points_per_batch=64, # Batch size for processing\n",
" pred_iou_thresh=0.7, # IoU threshold for predictions\n",
" stability_score_thresh=0.92, # Stability threshold\n",
" stability_score_offset=1.0,\n",
" box_nms_thresh=0.7, # NMS threshold for boxes\n",
" crop_n_layers=1, # Crop layers for multi-scale\n",
" crop_nms_thresh=0.7,\n",
" crop_overlap_ratio=0.34,\n",
" crop_n_points_downscale_factor=2,\n",
" min_mask_region_area=100 # Minimum mask area in pixels\n",
")\n",
"\n",
"# Create predictor for interactive mode\n",
"predictor = SAM2ImagePredictor(sam2_model)\n",
"\n",
"print(\"Model loaded successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Video Frame Extraction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration - UPDATE THESE PATHS\n",
"VIDEO_PATH = '/kaggle/input/your-video-dataset/video.mp4' # Change to your video path\n",
"FRAMES_DIR = './frames'\n",
"OUTPUT_DIR = './annotations'\n",
"\n",
"# Frame extraction settings\n",
"SAMPLE_FPS = 2 # Extract 2 frames per second (adjust based on video)\n",
"MAX_FRAMES = 500 # Maximum frames to extract (None for all)\n",
"START_TIME = 0 # Start time in seconds\n",
"END_TIME = None # End time in seconds (None for entire video)\n",
"RESIZE = None # Resize frames: (width, height) or None\n",
"\n",
"os.makedirs(FRAMES_DIR, exist_ok=True)\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_video_info(video_path):\n",
" \"\"\"Get video information.\"\"\"\n",
" cap = cv2.VideoCapture(video_path)\n",
" if not cap.isOpened():\n",
" raise ValueError(f\"Cannot open video: {video_path}\")\n",
" \n",
" info = {\n",
" 'fps': cap.get(cv2.CAP_PROP_FPS),\n",
" 'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),\n",
" 'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n",
" 'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
" }\n",
" info['duration'] = info['frame_count'] / info['fps']\n",
" cap.release()\n",
" return info\n",
"\n",
"# Display video info\n",
"if os.path.exists(VIDEO_PATH):\n",
" video_info = get_video_info(VIDEO_PATH)\n",
" print(\"Video Information:\")\n",
" print(f\" Resolution: {video_info['width']}x{video_info['height']}\")\n",
" print(f\" FPS: {video_info['fps']:.2f}\")\n",
" print(f\" Duration: {video_info['duration']:.2f}s\")\n",
" print(f\" Total frames: {video_info['frame_count']}\")\n",
"else:\n",
" print(f\"Video not found: {VIDEO_PATH}\")\n",
" print(\"Please upload your video or update VIDEO_PATH\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def extract_frames(video_path, output_dir, sample_fps=None, max_frames=None, \n",
" start_time=0, end_time=None, resize=None):\n",
" \"\"\"Extract frames from video.\"\"\"\n",
" cap = cv2.VideoCapture(video_path)\n",
" if not cap.isOpened():\n",
" raise ValueError(f\"Cannot open video: {video_path}\")\n",
" \n",
" fps = cap.get(cv2.CAP_PROP_FPS)\n",
" frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
" \n",
" # Calculate frame interval\n",
" if sample_fps and sample_fps < fps:\n",
" frame_interval = int(fps / sample_fps)\n",
" else:\n",
" frame_interval = 1\n",
" \n",
" # Calculate frame range\n",
" start_frame = int(start_time * fps)\n",
" end_frame = int(end_time * fps) if end_time else frame_count\n",
" end_frame = min(end_frame, frame_count)\n",
" \n",
" cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n",
" \n",
" saved_paths = []\n",
" frame_idx = start_frame\n",
" extracted = 0\n",
" \n",
" total_to_extract = min(\n",
" (end_frame - start_frame) // frame_interval,\n",
" max_frames or float('inf')\n",
" )\n",
" \n",
" pbar = tqdm(total=int(total_to_extract), desc=\"Extracting frames\")\n",
" \n",
" while frame_idx < end_frame:\n",
" if max_frames and extracted >= max_frames:\n",
" break\n",
" \n",
" ret, frame = cap.read()\n",
" if not ret:\n",
" break\n",
" \n",
" if (frame_idx - start_frame) % frame_interval == 0:\n",
" if resize:\n",
" frame = cv2.resize(frame, resize)\n",
" \n",
" frame_name = f\"frame_{frame_idx:06d}.jpg\"\n",
" frame_path = os.path.join(output_dir, frame_name)\n",
" cv2.imwrite(frame_path, frame)\n",
" saved_paths.append(frame_path)\n",
" extracted += 1\n",
" pbar.update(1)\n",
" \n",
" frame_idx += 1\n",
" \n",
" pbar.close()\n",
" cap.release()\n",
" \n",
" print(f\"Extracted {len(saved_paths)} frames to {output_dir}\")\n",
" return saved_paths\n",
"\n",
"# Extract frames\n",
"if os.path.exists(VIDEO_PATH):\n",
" frame_paths = extract_frames(\n",
" VIDEO_PATH,\n",
" FRAMES_DIR,\n",
" sample_fps=SAMPLE_FPS,\n",
" max_frames=MAX_FRAMES,\n",
" start_time=START_TIME,\n",
" end_time=END_TIME,\n",
" resize=RESIZE\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview extracted frames\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))[:6]\n",
"\n",
"if frame_files:\n",
" fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
" for ax, frame_file in zip(axes.flat, frame_files):\n",
" img = cv2.imread(str(frame_file))\n",
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
" ax.imshow(img)\n",
" ax.set_title(frame_file.name)\n",
" ax.axis('off')\n",
" plt.tight_layout()\n",
" plt.show()\n",
"else:\n",
" print(\"No frames found. Please extract frames first.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Automatic Mask Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Annotation settings\n",
"MIN_MASK_AREA = 500 # Minimum mask area in pixels\n",
"MAX_MASK_AREA = None # Maximum mask area (None for no limit)\n",
"SAVE_MASKS = True # Save individual mask images\n",
"SAVE_VIS = True # Save visualization images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def visualize_masks(image, masks, alpha=0.5):\n",
" \"\"\"Create visualization with masks overlaid.\"\"\"\n",
" vis = image.copy()\n",
" \n",
" for mask_data in masks:\n",
" mask = mask_data['segmentation']\n",
" color = np.random.randint(0, 255, 3).tolist()\n",
" \n",
" overlay = vis.copy()\n",
" overlay[mask] = color\n",
" vis = cv2.addWeighted(vis, 1 - alpha, overlay, alpha, 0)\n",
" \n",
" # Draw bbox\n",
" x, y, w, h = mask_data['bbox']\n",
" cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)\n",
" \n",
" return vis\n",
"\n",
"def annotate_single_image(image_path, min_area=100, max_area=None):\n",
" \"\"\"Generate masks for a single image.\"\"\"\n",
" image = cv2.imread(image_path)\n",
" image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
" \n",
" # Generate masks\n",
" masks = mask_generator.generate(image_rgb)\n",
" \n",
" # Filter by area\n",
" filtered = []\n",
" for m in masks:\n",
" area = m['area']\n",
" if area >= min_area:\n",
" if max_area is None or area <= max_area:\n",
" filtered.append(m)\n",
" \n",
" # Sort by area (largest first)\n",
" filtered.sort(key=lambda x: x['area'], reverse=True)\n",
" \n",
" return filtered, image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test on a single frame\n",
"frame_files = sorted(Path(FRAMES_DIR).glob(\"*.jpg\"))\n",
"\n",
"if frame_files:\n",
" test_frame = str(frame_files[0])\n",
" print(f\"Testing on: {test_frame}\")\n",
" \n",
" masks, image = annotate_single_image(\n",
" test_frame,\n",
" min_area=MIN_MASK_AREA,\n",
" max_area=MAX_MASK_AREA\n",
" )\n",
" \n",
" print(f\"Found {len(masks)} objects\")\n",
" \n",
" # Visualize\n",
" vis = visualize_masks(image, masks)\n",
" vis_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)\n",
" \n",
" fig, axes = plt.subplots(1, 2, figsize=(15, 7))\n",
" axes[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
" axes[0].set_title('Original')\n",
" axes[0].axis('off')\n",
" \n",
" axes[1].imshow(vis_rgb)\n",
" axes[1].set_title(f'SAM2 Annotations ({len(masks)} objects)')\n",
" axes[1].axis('off')\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n",
" \n",
" # Show bbox details\n",
" print(\"\\nDetected objects:\")\n",
" for i, m in enumerate(masks[:10]):\n",
" print(f\" {i}: bbox={m['bbox']}, area={m['area']}, iou={m['predicted_iou']:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Process all frames\n",
"def annotate_all_frames(frames_dir, output_dir, min_area=100, max_area=None,\n",
" save_masks=True, save_vis=True):\n",
" \"\"\"Annotate all frames and save results.\"\"\"\n",
" frames_path = Path(frames_dir)\n",
" output_path = Path(output_dir)\n",
" output_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" if save_masks:\n",
" masks_dir = output_path / 'masks'\n",
" masks_dir.mkdir(exist_ok=True)\n",
" \n",
" if save_vis:\n",
" vis_dir = output_path / 'visualizations'\n",
" vis_dir.mkdir(exist_ok=True)\n",
" \n",
" frame_files = sorted(frames_path.glob(\"*.jpg\"))\n",
" print(f\"Processing {len(frame_files)} frames...\")\n",
" \n",
" all_annotations = {}\n",
" total_objects = 0\n",
" \n",
" for frame_file in tqdm(frame_files, desc=\"Annotating\"):\n",
" masks, image = annotate_single_image(\n",
" str(frame_file),\n",
" min_area=min_area,\n",
" max_area=max_area\n",
" )\n",
" \n",
" # Convert to serializable format\n",
" annotations = []\n",
" for i, m in enumerate(masks):\n",
" ann = {\n",
" 'id': i,\n",
" 'bbox': [int(x) for x in m['bbox']], # [x, y, w, h]\n",
" 'area': int(m['area']),\n",
" 'predicted_iou': float(m['predicted_iou']),\n",
" 'stability_score': float(m['stability_score'])\n",
" }\n",
" annotations.append(ann)\n",
" \n",
" # Save individual mask\n",
" if save_masks:\n",
" frame_masks_dir = masks_dir / frame_file.stem\n",
" frame_masks_dir.mkdir(exist_ok=True)\n",
" mask_path = frame_masks_dir / f\"mask_{i:03d}.png\"\n",
" cv2.imwrite(\n",
" str(mask_path),\n",
" m['segmentation'].astype(np.uint8) * 255\n",
" )\n",
" \n",
" all_annotations[frame_file.name] = annotations\n",
" total_objects += len(annotations)\n",
" \n",
" # Save visualization\n",
" if save_vis:\n",
" vis = visualize_masks(image, masks)\n",
" cv2.imwrite(str(vis_dir / frame_file.name), vis)\n",
" \n",
" # Save annotations JSON\n",
" annotations_file = output_path / 'annotations.json'\n",
" with open(annotations_file, 'w') as f:\n",
" json.dump(all_annotations, f, indent=2)\n",
" \n",
" print(f\"\\nAnnotation complete!\")\n",
" print(f\" Frames: {len(frame_files)}\")\n",
" print(f\" Total objects: {total_objects}\")\n",
" print(f\" Avg objects/frame: {total_objects/len(frame_files):.1f}\")\n",
" print(f\" Saved to: {output_path}\")\n",
" \n",
" return all_annotations\n",
"\n",
"# Run annotation\n",
"annotations = annotate_all_frames(\n",
" FRAMES_DIR,\n",
" OUTPUT_DIR,\n",
" min_area=MIN_MASK_AREA,\n",
" max_area=MAX_MASK_AREA,\n",
" save_masks=SAVE_MASKS,\n",
" save_vis=SAVE_VIS\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Review Annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview some annotated frames\n",
"vis_dir = Path(OUTPUT_DIR) / 'visualizations'\n",
"vis_files = sorted(vis_dir.glob(\"*.jpg\"))[:6]\n",
"\n",
"if vis_files:\n",
" fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
" for ax, vis_file in zip(axes.flat, vis_files):\n",
" img = cv2.imread(str(vis_file))\n",
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
" ax.imshow(img)\n",
" \n",
" # Get object count\n",
" frame_name = vis_file.name\n",
" if frame_name in annotations:\n",
" count = len(annotations[frame_name])\n",
" else:\n",
" count = 0\n",
" ax.set_title(f\"{vis_file.name} ({count} objects)\")\n",
" ax.axis('off')\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Statistics\n",
"object_counts = [len(ann) for ann in annotations.values()]\n",
"\n",
"print(\"Annotation Statistics:\")\n",
"print(f\" Total frames: {len(annotations)}\")\n",
"print(f\" Total objects: {sum(object_counts)}\")\n",
"print(f\" Min objects/frame: {min(object_counts)}\")\n",
"print(f\" Max objects/frame: {max(object_counts)}\")\n",
"print(f\" Avg objects/frame: {np.mean(object_counts):.1f}\")\n",
"\n",
"# Histogram\n",
"plt.figure(figsize=(10, 4))\n",
"plt.hist(object_counts, bins=30, edgecolor='black')\n",
"plt.xlabel('Objects per frame')\n",
"plt.ylabel('Frequency')\n",
"plt.title('Distribution of Objects per Frame')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Save Output for Next Notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summary of output files\n",
"print(\"Output files for next notebook:\")\n",
"print(f\" Frames: {FRAMES_DIR}/\")\n",
"print(f\" Annotations: {OUTPUT_DIR}/annotations.json\")\n",
"print(f\" Masks: {OUTPUT_DIR}/masks/\")\n",
"print(f\" Visualizations: {OUTPUT_DIR}/visualizations/\")\n",
"\n",
"# Check sizes\n",
"import subprocess\n",
"!du -sh {FRAMES_DIR} {OUTPUT_DIR}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: Create archive for download\n",
"ARCHIVE_NAME = 'sam2_annotations.zip'\n",
"\n",
"!zip -r {ARCHIVE_NAME} {FRAMES_DIR} {OUTPUT_DIR}/annotations.json {OUTPUT_DIR}/masks/\n",
"print(f\"\\nArchive created: {ARCHIVE_NAME}\")\n",
"!ls -lh {ARCHIVE_NAME}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"## Next Steps\n",
"\n",
"1. **Run `02_create_yolo_dataset.ipynb`** to convert annotations to YOLO format\n",
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t on the dataset\n",
"\n",
"### Tips\n",
"\n",
"- Adjust `MIN_MASK_AREA` to filter small detections\n",
"- Use `MAX_MASK_AREA` to filter very large objects (e.g., background)\n",
"- Lower `SAMPLE_FPS` for faster processing or if frames are similar\n",
"- Review visualizations to verify annotation quality"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -0,0 +1,726 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create YOLO Dataset from SAM2 Annotations\n",
"\n",
"Convert SAM2 mask annotations to YOLO detection format (bounding boxes).\n",
"\n",
"## Input\n",
"- Frames from `01_sam2_video_annotation.ipynb`\n",
"- Annotations JSON file\n",
"\n",
"## Output\n",
"- YOLO format dataset ready for training\n",
"- `data.yaml` configuration file\n",
"\n",
"**Platform:** Kaggle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import json\n",
"import yaml\n",
"import shutil\n",
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from tqdm.notebook import tqdm\n",
"from collections import defaultdict\n",
"\n",
"print(\"Setup complete!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configuration - UPDATE THESE PATHS\n",
"FRAMES_DIR = './frames' # From notebook 01\n",
"ANNOTATIONS_FILE = './annotations/annotations.json' # From notebook 01\n",
"MASKS_DIR = './annotations/masks' # Optional: mask images\n",
"\n",
"# Output dataset\n",
"DATASET_DIR = './yolo_dataset'\n",
"\n",
"# Dataset settings\n",
"CLASS_NAMES = ['object'] # Single class for generic objects\n",
"VAL_SPLIT = 0.2 # 20% validation\n",
"SEED = 42 # Random seed for reproducibility\n",
"\n",
"# Filtering\n",
"MIN_BBOX_AREA = 100 # Minimum bbox area in pixels\n",
"MIN_BBOX_SIZE = 0.01 # Minimum bbox dimension (normalized, 0-1)\n",
"MAX_OBJECTS_PER_IMAGE = 100 # Maximum objects per image\n",
"MIN_IOU_SCORE = 0.5 # Minimum SAM2 IoU score\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Load Annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load annotations\n",
"with open(ANNOTATIONS_FILE, 'r') as f:\n",
" annotations = json.load(f)\n",
"\n",
"print(f\"Loaded annotations for {len(annotations)} frames\")\n",
"\n",
"# Show sample\n",
"sample_frame = list(annotations.keys())[0]\n",
"print(f\"\\nSample annotation ({sample_frame}):\")\n",
"print(json.dumps(annotations[sample_frame][:2], indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify frames exist\n",
"frames_path = Path(FRAMES_DIR)\n",
"frame_files = list(frames_path.glob(\"*.jpg\")) + list(frames_path.glob(\"*.png\"))\n",
"\n",
"print(f\"Found {len(frame_files)} frame images\")\n",
"\n",
"# Check matching\n",
"annotation_frames = set(annotations.keys())\n",
"image_frames = {f.name for f in frame_files}\n",
"\n",
"matched = annotation_frames & image_frames\n",
"print(f\"Matched frames: {len(matched)}\")\n",
"\n",
"if len(matched) < len(annotation_frames):\n",
" missing = annotation_frames - image_frames\n",
" print(f\"Warning: {len(missing)} annotated frames missing images\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Create YOLO Dataset Structure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create directory structure\n",
"dataset_path = Path(DATASET_DIR)\n",
"\n",
"images_train = dataset_path / 'images' / 'train'\n",
"images_val = dataset_path / 'images' / 'val'\n",
"labels_train = dataset_path / 'labels' / 'train'\n",
"labels_val = dataset_path / 'labels' / 'val'\n",
"\n",
"for dir_path in [images_train, images_val, labels_train, labels_val]:\n",
" dir_path.mkdir(parents=True, exist_ok=True)\n",
" print(f\"Created: {dir_path}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Convert Annotations to YOLO Format"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bbox_xywh_to_yolo(bbox, image_width, image_height):\n",
" \"\"\"\n",
" Convert [x, y, w, h] bbox to YOLO format [x_center, y_center, width, height] normalized.\n",
" \"\"\"\n",
" x, y, w, h = bbox\n",
" \n",
" x_center = (x + w / 2) / image_width\n",
" y_center = (y + h / 2) / image_height\n",
" width = w / image_width\n",
" height = h / image_height\n",
" \n",
" # Clamp to [0, 1]\n",
" x_center = max(0, min(1, x_center))\n",
" y_center = max(0, min(1, y_center))\n",
" width = max(0, min(1, width))\n",
" height = max(0, min(1, height))\n",
" \n",
" return x_center, y_center, width, height\n",
"\n",
"\n",
"def filter_annotations(anns, img_width, img_height, \n",
" min_area=100, min_size=0.01, \n",
" min_iou=0.5, max_objects=100):\n",
" \"\"\"\n",
" Filter annotations based on criteria.\n",
" \"\"\"\n",
" filtered = []\n",
" \n",
" for ann in anns:\n",
" bbox = ann.get('bbox', [])\n",
" area = ann.get('area', 0)\n",
" iou = ann.get('predicted_iou', 1.0)\n",
" \n",
" # Check area\n",
" if area < min_area:\n",
" continue\n",
" \n",
" # Check IoU score\n",
" if iou < min_iou:\n",
" continue\n",
" \n",
" # Check bbox dimensions\n",
" if len(bbox) == 4:\n",
" _, _, w, h = bbox\n",
" if w / img_width < min_size or h / img_height < min_size:\n",
" continue\n",
" \n",
" filtered.append(ann)\n",
" \n",
" # Limit number of objects (keep highest IoU)\n",
" if len(filtered) > max_objects:\n",
" filtered.sort(key=lambda x: x.get('predicted_iou', 0), reverse=True)\n",
" filtered = filtered[:max_objects]\n",
" \n",
" return filtered"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def process_frame(frame_name, frame_anns, frames_dir, images_dir, labels_dir, class_id=0):\n",
" \"\"\"\n",
" Process a single frame: copy image and create YOLO label file.\n",
" \"\"\"\n",
" frame_path = Path(frames_dir) / frame_name\n",
" \n",
" if not frame_path.exists():\n",
" return 0\n",
" \n",
" # Read image dimensions\n",
" image = cv2.imread(str(frame_path))\n",
" if image is None:\n",
" return 0\n",
" \n",
" height, width = image.shape[:2]\n",
" \n",
" # Filter annotations\n",
" filtered_anns = filter_annotations(\n",
" frame_anns, width, height,\n",
" min_area=MIN_BBOX_AREA,\n",
" min_size=MIN_BBOX_SIZE,\n",
" min_iou=MIN_IOU_SCORE,\n",
" max_objects=MAX_OBJECTS_PER_IMAGE\n",
" )\n",
" \n",
" # Copy image\n",
" dest_image = images_dir / frame_name\n",
" shutil.copy2(frame_path, dest_image)\n",
" \n",
" # Create YOLO labels\n",
" labels = []\n",
" for ann in filtered_anns:\n",
" bbox = ann.get('bbox', [])\n",
" if len(bbox) != 4:\n",
" continue\n",
" \n",
" x_center, y_center, w, h = bbox_xywh_to_yolo(bbox, width, height)\n",
" \n",
" # YOLO format: class x_center y_center width height\n",
" label_line = f\"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\"\n",
" labels.append(label_line)\n",
" \n",
" # Write label file\n",
" label_name = Path(frame_name).stem + '.txt'\n",
" label_path = labels_dir / label_name\n",
" \n",
" with open(label_path, 'w') as f:\n",
" f.write('\\n'.join(labels))\n",
" \n",
" return len(labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split frames into train/val\n",
"frame_names = list(annotations.keys())\n",
"random.shuffle(frame_names)\n",
"\n",
"split_idx = int(len(frame_names) * (1 - VAL_SPLIT))\n",
"train_frames = frame_names[:split_idx]\n",
"val_frames = frame_names[split_idx:]\n",
"\n",
"print(f\"Train frames: {len(train_frames)}\")\n",
"print(f\"Val frames: {len(val_frames)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Process training frames\n",
"train_objects = 0\n",
"for frame_name in tqdm(train_frames, desc=\"Processing train\"):\n",
" count = process_frame(\n",
" frame_name,\n",
" annotations.get(frame_name, []),\n",
" FRAMES_DIR,\n",
" images_train,\n",
" labels_train,\n",
" class_id=0\n",
" )\n",
" train_objects += count\n",
"\n",
"print(f\"\\nTrain: {len(train_frames)} images, {train_objects} objects\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Process validation frames\n",
"val_objects = 0\n",
"for frame_name in tqdm(val_frames, desc=\"Processing val\"):\n",
" count = process_frame(\n",
" frame_name,\n",
" annotations.get(frame_name, []),\n",
" FRAMES_DIR,\n",
" images_val,\n",
" labels_val,\n",
" class_id=0\n",
" )\n",
" val_objects += count\n",
"\n",
"print(f\"\\nVal: {len(val_frames)} images, {val_objects} objects\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Create data.yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create YOLO data.yaml configuration\n",
"data_config = {\n",
" 'path': str(Path(DATASET_DIR).absolute()),\n",
" 'train': 'images/train',\n",
" 'val': 'images/val',\n",
" 'names': {i: name for i, name in enumerate(CLASS_NAMES)},\n",
" 'nc': len(CLASS_NAMES)\n",
"}\n",
"\n",
"yaml_path = dataset_path / 'data.yaml'\n",
"with open(yaml_path, 'w') as f:\n",
" yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)\n",
"\n",
"print(f\"Created: {yaml_path}\")\n",
"print(\"\\nContents:\")\n",
"with open(yaml_path) as f:\n",
" print(f.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Validate Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def validate_dataset(dataset_dir):\n",
" \"\"\"Validate YOLO dataset structure.\"\"\"\n",
" dataset_path = Path(dataset_dir)\n",
" results = {'valid': True, 'errors': [], 'warnings': [], 'stats': {}}\n",
" \n",
" # Check data.yaml\n",
" yaml_path = dataset_path / 'data.yaml'\n",
" if not yaml_path.exists():\n",
" results['errors'].append(\"Missing data.yaml\")\n",
" results['valid'] = False\n",
" else:\n",
" with open(yaml_path) as f:\n",
" config = yaml.safe_load(f)\n",
" results['stats']['num_classes'] = config.get('nc', 0)\n",
" results['stats']['class_names'] = config.get('names', {})\n",
" \n",
" # Check directories and count files\n",
" for split in ['train', 'val']:\n",
" images_dir = dataset_path / 'images' / split\n",
" labels_dir = dataset_path / 'labels' / split\n",
" \n",
" if not images_dir.exists():\n",
" results['errors'].append(f\"Missing images/{split}\")\n",
" results['valid'] = False\n",
" continue\n",
" \n",
" image_files = list(images_dir.glob(\"*.jpg\")) + list(images_dir.glob(\"*.png\"))\n",
" label_files = list(labels_dir.glob(\"*.txt\"))\n",
" \n",
" results['stats'][f'{split}_images'] = len(image_files)\n",
" results['stats'][f'{split}_labels'] = len(label_files)\n",
" \n",
" # Check for missing labels\n",
" image_stems = {f.stem for f in image_files}\n",
" label_stems = {f.stem for f in label_files}\n",
" missing = image_stems - label_stems\n",
" \n",
" if missing:\n",
" results['warnings'].append(f\"{len(missing)} {split} images missing labels\")\n",
" \n",
" # Count total objects\n",
" total_objects = 0\n",
" for label_file in label_files:\n",
" with open(label_file) as f:\n",
" lines = [l.strip() for l in f if l.strip()]\n",
" total_objects += len(lines)\n",
" results['stats'][f'{split}_objects'] = total_objects\n",
" \n",
" return results\n",
"\n",
"# Validate\n",
"validation = validate_dataset(DATASET_DIR)\n",
"\n",
"print(\"Dataset Validation:\")\n",
"print(f\" Valid: {validation['valid']}\")\n",
"print(f\"\\nStatistics:\")\n",
"for key, value in validation['stats'].items():\n",
" print(f\" {key}: {value}\")\n",
"\n",
"if validation['errors']:\n",
" print(f\"\\nErrors:\")\n",
" for err in validation['errors']:\n",
" print(f\" - {err}\")\n",
"\n",
"if validation['warnings']:\n",
" print(f\"\\nWarnings:\")\n",
" for warn in validation['warnings']:\n",
" print(f\" - {warn}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Visualize Dataset Samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def visualize_yolo_sample(image_path, label_path, class_names):\n",
" \"\"\"Visualize YOLO annotation on image.\"\"\"\n",
" image = cv2.imread(str(image_path))\n",
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
" height, width = image.shape[:2]\n",
" \n",
" # Read labels\n",
" if label_path.exists():\n",
" with open(label_path) as f:\n",
" labels = [l.strip().split() for l in f if l.strip()]\n",
" else:\n",
" labels = []\n",
" \n",
" # Draw bboxes\n",
" colors = plt.cm.tab10(np.linspace(0, 1, 10))\n",
" \n",
" for label in labels:\n",
" class_id = int(label[0])\n",
" x_center, y_center, w, h = map(float, label[1:5])\n",
" \n",
" # Convert to pixel coordinates\n",
" x1 = int((x_center - w/2) * width)\n",
" y1 = int((y_center - h/2) * height)\n",
" x2 = int((x_center + w/2) * width)\n",
" y2 = int((y_center + h/2) * height)\n",
" \n",
" color = tuple(int(c * 255) for c in colors[class_id % 10][:3])\n",
" cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)\n",
" \n",
" # Add label\n",
" class_name = class_names.get(class_id, str(class_id))\n",
" cv2.putText(image, class_name, (x1, y1-5), \n",
" cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
" \n",
" return image, len(labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize train samples\n",
"train_images = sorted(images_train.glob(\"*.jpg\"))[:6]\n",
"\n",
"fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
"class_names_dict = {i: name for i, name in enumerate(CLASS_NAMES)}\n",
"\n",
"for ax, img_path in zip(axes.flat, train_images):\n",
" label_path = labels_train / (img_path.stem + '.txt')\n",
" vis, count = visualize_yolo_sample(img_path, label_path, class_names_dict)\n",
" \n",
" ax.imshow(vis)\n",
" ax.set_title(f\"{img_path.name} ({count} objects)\")\n",
" ax.axis('off')\n",
"\n",
"plt.suptitle('Training Samples with YOLO Annotations', fontsize=14)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Object count distribution\n",
"train_label_files = list(labels_train.glob(\"*.txt\"))\n",
"val_label_files = list(labels_val.glob(\"*.txt\"))\n",
"\n",
"def count_objects_in_labels(label_files):\n",
" counts = []\n",
" for lf in label_files:\n",
" with open(lf) as f:\n",
" lines = [l.strip() for l in f if l.strip()]\n",
" counts.append(len(lines))\n",
" return counts\n",
"\n",
"train_counts = count_objects_in_labels(train_label_files)\n",
"val_counts = count_objects_in_labels(val_label_files)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
"\n",
"axes[0].hist(train_counts, bins=30, edgecolor='black', alpha=0.7, label='Train')\n",
"axes[0].hist(val_counts, bins=30, edgecolor='black', alpha=0.7, label='Val')\n",
"axes[0].set_xlabel('Objects per image')\n",
"axes[0].set_ylabel('Frequency')\n",
"axes[0].set_title('Objects per Image Distribution')\n",
"axes[0].legend()\n",
"\n",
"# Bbox size distribution\n",
"bbox_sizes = []\n",
"for lf in train_label_files:\n",
" with open(lf) as f:\n",
" for line in f:\n",
" parts = line.strip().split()\n",
" if len(parts) >= 5:\n",
" w, h = float(parts[3]), float(parts[4])\n",
" bbox_sizes.append(w * h)\n",
"\n",
"axes[1].hist(bbox_sizes, bins=50, edgecolor='black', alpha=0.7)\n",
"axes[1].set_xlabel('Bbox area (normalized)')\n",
"axes[1].set_ylabel('Frequency')\n",
"axes[1].set_title('Bounding Box Size Distribution')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"\\nBbox size stats:\")\n",
"print(f\" Min: {min(bbox_sizes):.4f}\")\n",
"print(f\" Max: {max(bbox_sizes):.4f}\")\n",
"print(f\" Mean: {np.mean(bbox_sizes):.4f}\")\n",
"print(f\" Median: {np.median(bbox_sizes):.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Export for Kaggle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create zip archive for Kaggle dataset\n",
"import zipfile\n",
"\n",
"EXPORT_ZIP = 'yolo_dataset.zip'\n",
"\n",
"print(f\"Creating {EXPORT_ZIP}...\")\n",
"\n",
"with zipfile.ZipFile(EXPORT_ZIP, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
" for root, dirs, files in os.walk(DATASET_DIR):\n",
" for file in files:\n",
" file_path = os.path.join(root, file)\n",
" arcname = os.path.relpath(file_path, os.path.dirname(DATASET_DIR))\n",
" zipf.write(file_path, arcname)\n",
"\n",
"zip_size = os.path.getsize(EXPORT_ZIP) / 1024 / 1024\n",
"print(f\"\\nExport complete!\")\n",
"print(f\" File: {EXPORT_ZIP}\")\n",
"print(f\" Size: {zip_size:.1f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative: Create Kaggle dataset directly (if kaggle CLI available)\n",
"# Uncomment to use\n",
"\n",
"# KAGGLE_USERNAME = 'your-username'\n",
"# DATASET_NAME = 'sam2-yolo-custom'\n",
"\n",
"# # Create dataset metadata\n",
"# metadata = {\n",
"# 'title': 'SAM2 Auto-Annotated YOLO Dataset',\n",
"# 'id': f'{KAGGLE_USERNAME}/{DATASET_NAME}',\n",
"# 'licenses': [{'name': 'CC0-1.0'}]\n",
"# }\n",
"\n",
"# metadata_path = dataset_path / 'dataset-metadata.json'\n",
"# with open(metadata_path, 'w') as f:\n",
"# json.dump(metadata, f, indent=2)\n",
"\n",
"# # Upload to Kaggle\n",
"# !kaggle datasets create -p {DATASET_DIR}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Dataset Summary"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Final summary\n",
"print(\"=\" * 50)\n",
"print(\"YOLO DATASET SUMMARY\")\n",
"print(\"=\" * 50)\n",
"print(f\"\\nDataset location: {Path(DATASET_DIR).absolute()}\")\n",
"print(f\"\\nClasses ({len(CLASS_NAMES)}):\")\n",
"for i, name in enumerate(CLASS_NAMES):\n",
" print(f\" {i}: {name}\")\n",
"\n",
"print(f\"\\nSplit:\")\n",
"print(f\" Train: {validation['stats']['train_images']} images, {validation['stats']['train_objects']} objects\")\n",
"print(f\" Val: {validation['stats']['val_images']} images, {validation['stats']['val_objects']} objects\")\n",
"print(f\" Total: {validation['stats']['train_images'] + validation['stats']['val_images']} images\")\n",
"\n",
"print(f\"\\nFiles:\")\n",
"print(f\" data.yaml: {yaml_path}\")\n",
"print(f\" Export: {EXPORT_ZIP}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 50)\n",
"print(\"Ready for YOLOv9t training!\")\n",
"print(\"=\" * 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"## Next Steps\n",
"\n",
"1. **Upload dataset to Kaggle** (if not already done)\n",
" - Go to kaggle.com/datasets/new\n",
" - Upload `yolo_dataset.zip`\n",
" \n",
"2. **Run `03_train_yolov9t.ipynb`** to train YOLOv9t\n",
"\n",
"### Dataset Structure\n",
"```\n",
"yolo_dataset/\n",
"├── data.yaml\n",
"├── images/\n",
"│ ├── train/\n",
"│ └── val/\n",
"└── labels/\n",
" ├── train/\n",
" └── val/\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
+729
View File
@@ -0,0 +1,729 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train YOLOv9t on Custom Dataset\n",
"\n",
"Train YOLOv9t (tiny) model on YOLO format dataset created from SAM2 annotations.\n",
"\n",
"## Input\n",
"- YOLO format dataset from `02_create_yolo_dataset.ipynb`\n",
"\n",
"## Output\n",
"- Trained YOLOv9t model weights\n",
"- Training metrics and visualizations\n",
"\n",
"**Platform:** Kaggle GPU (P100/T4) - Enable GPU in notebook settings!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup Environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check GPU\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install YOLOv9 (ultralytics fork with v9 support)\n",
"!pip install -q ultralytics\n",
"\n",
"# Alternative: Install official YOLOv9 repo\n",
"# !git clone https://github.com/WongKinYiu/yolov9.git\n",
"# %cd yolov9\n",
"# !pip install -q -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"import yaml\n",
"import shutil\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from datetime import datetime\n",
"from IPython.display import Image, display\n",
"\n",
"print(f\"Python: {sys.version}\")\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
" print(f\"CUDA version: {torch.version.cuda}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"import ultralytics\n",
"print(f\"Ultralytics version: {ultralytics.__version__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dataset configuration - UPDATE THIS PATH\n",
"# Option 1: From previous notebook (local)\n",
"DATASET_PATH = './yolo_dataset'\n",
"\n",
"# Option 2: From Kaggle dataset (uncomment and update)\n",
"# DATASET_PATH = '/kaggle/input/your-dataset-name/yolo_dataset'\n",
"\n",
"# Training configuration\n",
"CONFIG = {\n",
" # Model\n",
" 'model': 'yolov9t.pt', # Pretrained YOLOv9t (tiny)\n",
" \n",
" # Training parameters\n",
" 'epochs': 100, # Number of epochs\n",
" 'batch': 16, # Batch size (adjust based on GPU memory)\n",
" 'imgsz': 640, # Image size\n",
" 'patience': 20, # Early stopping patience\n",
" \n",
" # Optimizer\n",
" 'optimizer': 'AdamW', # Optimizer: SGD, Adam, AdamW\n",
" 'lr0': 0.001, # Initial learning rate\n",
" 'lrf': 0.01, # Final learning rate factor\n",
" 'momentum': 0.937, # SGD momentum\n",
" 'weight_decay': 0.0005, # Weight decay\n",
" \n",
" # Augmentation\n",
" 'hsv_h': 0.015, # HSV-Hue augmentation\n",
" 'hsv_s': 0.7, # HSV-Saturation\n",
" 'hsv_v': 0.4, # HSV-Value\n",
" 'degrees': 0.0, # Rotation\n",
" 'translate': 0.1, # Translation\n",
" 'scale': 0.5, # Scale\n",
" 'shear': 0.0, # Shear\n",
" 'flipud': 0.0, # Flip up-down\n",
" 'fliplr': 0.5, # Flip left-right\n",
" 'mosaic': 1.0, # Mosaic augmentation\n",
" 'mixup': 0.0, # Mixup augmentation\n",
" \n",
" # Other\n",
" 'workers': 4, # DataLoader workers\n",
" 'device': 0, # GPU device (0 for first GPU)\n",
" 'project': 'runs/train', # Output directory\n",
" 'name': 'yolov9t_custom', # Experiment name\n",
" 'exist_ok': True, # Overwrite existing\n",
" 'pretrained': True, # Use pretrained weights\n",
" 'verbose': True, # Verbose output\n",
"}\n",
"\n",
"print(\"Configuration loaded!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Prepare Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check dataset exists\n",
"dataset_path = Path(DATASET_PATH)\n",
"\n",
"if not dataset_path.exists():\n",
" print(f\"Dataset not found: {dataset_path}\")\n",
" print(\"Please update DATASET_PATH or upload your dataset.\")\n",
"else:\n",
" print(f\"Dataset found: {dataset_path}\")\n",
" \n",
" # List contents\n",
" print(\"\\nContents:\")\n",
" for item in dataset_path.iterdir():\n",
" if item.is_dir():\n",
" count = len(list(item.rglob('*')))\n",
" print(f\" {item.name}/ ({count} files)\")\n",
" else:\n",
" print(f\" {item.name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load and display data.yaml\n",
"data_yaml = dataset_path / 'data.yaml'\n",
"\n",
"if data_yaml.exists():\n",
" with open(data_yaml) as f:\n",
" data_config = yaml.safe_load(f)\n",
" \n",
" print(\"data.yaml contents:\")\n",
" print(yaml.dump(data_config, default_flow_style=False))\n",
" \n",
" # Update path to absolute if needed\n",
" if not Path(data_config.get('path', '')).is_absolute():\n",
" data_config['path'] = str(dataset_path.absolute())\n",
" \n",
" # Save updated config\n",
" with open(data_yaml, 'w') as f:\n",
" yaml.dump(data_config, f, default_flow_style=False)\n",
" print(\"\\nUpdated path to absolute.\")\n",
"else:\n",
" print(f\"data.yaml not found: {data_yaml}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count images and labels\n",
"train_images = len(list((dataset_path / 'images' / 'train').glob('*')))\n",
"val_images = len(list((dataset_path / 'images' / 'val').glob('*')))\n",
"train_labels = len(list((dataset_path / 'labels' / 'train').glob('*.txt')))\n",
"val_labels = len(list((dataset_path / 'labels' / 'val').glob('*.txt')))\n",
"\n",
"print(\"Dataset Statistics:\")\n",
"print(f\" Train images: {train_images}\")\n",
"print(f\" Train labels: {train_labels}\")\n",
"print(f\" Val images: {val_images}\")\n",
"print(f\" Val labels: {val_labels}\")\n",
"print(f\" Classes: {data_config.get('nc', 'unknown')}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Load YOLOv9t Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load pretrained YOLOv9t model\n",
"model = YOLO(CONFIG['model'])\n",
"\n",
"print(f\"Model: {CONFIG['model']}\")\n",
"print(f\"Task: {model.task}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model information\n",
"model.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Train Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start training\n",
"print(\"Starting training...\")\n",
"print(f\" Dataset: {data_yaml}\")\n",
"print(f\" Epochs: {CONFIG['epochs']}\")\n",
"print(f\" Batch size: {CONFIG['batch']}\")\n",
"print(f\" Image size: {CONFIG['imgsz']}\")\n",
"print()\n",
"\n",
"results = model.train(\n",
" data=str(data_yaml),\n",
" epochs=CONFIG['epochs'],\n",
" batch=CONFIG['batch'],\n",
" imgsz=CONFIG['imgsz'],\n",
" patience=CONFIG['patience'],\n",
" optimizer=CONFIG['optimizer'],\n",
" lr0=CONFIG['lr0'],\n",
" lrf=CONFIG['lrf'],\n",
" momentum=CONFIG['momentum'],\n",
" weight_decay=CONFIG['weight_decay'],\n",
" hsv_h=CONFIG['hsv_h'],\n",
" hsv_s=CONFIG['hsv_s'],\n",
" hsv_v=CONFIG['hsv_v'],\n",
" degrees=CONFIG['degrees'],\n",
" translate=CONFIG['translate'],\n",
" scale=CONFIG['scale'],\n",
" shear=CONFIG['shear'],\n",
" flipud=CONFIG['flipud'],\n",
" fliplr=CONFIG['fliplr'],\n",
" mosaic=CONFIG['mosaic'],\n",
" mixup=CONFIG['mixup'],\n",
" workers=CONFIG['workers'],\n",
" device=CONFIG['device'],\n",
" project=CONFIG['project'],\n",
" name=CONFIG['name'],\n",
" exist_ok=CONFIG['exist_ok'],\n",
" pretrained=CONFIG['pretrained'],\n",
" verbose=CONFIG['verbose'],\n",
")\n",
"\n",
"print(\"\\nTraining complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Training Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find training output directory\n",
"train_dir = Path(CONFIG['project']) / CONFIG['name']\n",
"\n",
"print(f\"Training output: {train_dir}\")\n",
"print(\"\\nContents:\")\n",
"for item in sorted(train_dir.iterdir()):\n",
" print(f\" {item.name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display training curves\n",
"results_png = train_dir / 'results.png'\n",
"if results_png.exists():\n",
" display(Image(filename=str(results_png), width=1000))\n",
"else:\n",
" print(\"results.png not found\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display confusion matrix\n",
"confusion_matrix = train_dir / 'confusion_matrix.png'\n",
"if confusion_matrix.exists():\n",
" display(Image(filename=str(confusion_matrix), width=600))\n",
"else:\n",
" print(\"confusion_matrix.png not found\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display F1 curve\n",
"f1_curve = train_dir / 'F1_curve.png'\n",
"if f1_curve.exists():\n",
" display(Image(filename=str(f1_curve), width=600))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display PR curve\n",
"pr_curve = train_dir / 'PR_curve.png'\n",
"if pr_curve.exists():\n",
" display(Image(filename=str(pr_curve), width=600))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display sample predictions\n",
"val_batch = train_dir / 'val_batch0_pred.jpg'\n",
"if val_batch.exists():\n",
" print(\"Validation batch predictions:\")\n",
" display(Image(filename=str(val_batch), width=1000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Evaluate Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load best weights\n",
"best_weights = train_dir / 'weights' / 'best.pt'\n",
"last_weights = train_dir / 'weights' / 'last.pt'\n",
"\n",
"print(f\"Best weights: {best_weights}\")\n",
"print(f\" Size: {best_weights.stat().st_size / 1024 / 1024:.1f} MB\")\n",
"\n",
"# Load best model\n",
"best_model = YOLO(str(best_weights))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate on validation set\n",
"print(\"Evaluating on validation set...\")\n",
"metrics = best_model.val(data=str(data_yaml))\n",
"\n",
"print(\"\\nValidation Metrics:\")\n",
"print(f\" mAP50: {metrics.box.map50:.4f}\")\n",
"print(f\" mAP50-95: {metrics.box.map:.4f}\")\n",
"print(f\" Precision: {metrics.box.mp:.4f}\")\n",
"print(f\" Recall: {metrics.box.mr:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Test Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test on validation images\n",
"val_images_dir = dataset_path / 'images' / 'val'\n",
"test_images = list(val_images_dir.glob('*.jpg'))[:4]\n",
"\n",
"if test_images:\n",
" print(f\"Testing on {len(test_images)} images...\")\n",
" \n",
" results = best_model.predict(\n",
" source=test_images,\n",
" conf=0.25,\n",
" save=True,\n",
" project='runs/predict',\n",
" name='test_inference'\n",
" )\n",
" \n",
" # Display results\n",
" predict_dir = Path('runs/predict/test_inference')\n",
" for img_path in sorted(predict_dir.glob('*.jpg'))[:4]:\n",
" print(f\"\\n{img_path.name}\")\n",
" display(Image(filename=str(img_path), width=600))\n",
"else:\n",
" print(\"No validation images found\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inference speed test\n",
"import time\n",
"\n",
"if test_images:\n",
" test_img = str(test_images[0])\n",
" \n",
" # Warmup\n",
" for _ in range(3):\n",
" _ = best_model.predict(test_img, verbose=False)\n",
" \n",
" # Benchmark\n",
" times = []\n",
" for _ in range(10):\n",
" start = time.time()\n",
" _ = best_model.predict(test_img, verbose=False)\n",
" times.append(time.time() - start)\n",
" \n",
" avg_time = np.mean(times) * 1000\n",
" fps = 1000 / avg_time\n",
" \n",
" print(f\"Inference speed:\")\n",
" print(f\" Average: {avg_time:.1f} ms\")\n",
" print(f\" FPS: {fps:.1f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Export Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Export to ONNX\n",
"print(\"Exporting to ONNX...\")\n",
"onnx_path = best_model.export(format='onnx', simplify=True)\n",
"print(f\"ONNX exported: {onnx_path}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Export to TorchScript\n",
"print(\"Exporting to TorchScript...\")\n",
"torchscript_path = best_model.export(format='torchscript')\n",
"print(f\"TorchScript exported: {torchscript_path}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: Export to other formats\n",
"# TensorRT (requires TensorRT installation)\n",
"# engine_path = best_model.export(format='engine')\n",
"\n",
"# OpenVINO\n",
"# openvino_path = best_model.export(format='openvino')\n",
"\n",
"# CoreML (macOS)\n",
"# coreml_path = best_model.export(format='coreml')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Save and Download"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create output archive\n",
"import zipfile\n",
"\n",
"OUTPUT_ZIP = 'yolov9t_trained.zip'\n",
"\n",
"print(f\"Creating {OUTPUT_ZIP}...\")\n",
"\n",
"with zipfile.ZipFile(OUTPUT_ZIP, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
" # Add weights\n",
" zipf.write(best_weights, 'weights/best.pt')\n",
" zipf.write(last_weights, 'weights/last.pt')\n",
" \n",
" # Add ONNX if exists\n",
" onnx_file = best_weights.with_suffix('.onnx')\n",
" if onnx_file.exists():\n",
" zipf.write(onnx_file, 'weights/best.onnx')\n",
" \n",
" # Add results\n",
" for result_file in train_dir.glob('*.png'):\n",
" zipf.write(result_file, f'results/{result_file.name}')\n",
" \n",
" for result_file in train_dir.glob('*.csv'):\n",
" zipf.write(result_file, f'results/{result_file.name}')\n",
" \n",
" # Add args\n",
" args_file = train_dir / 'args.yaml'\n",
" if args_file.exists():\n",
" zipf.write(args_file, 'args.yaml')\n",
"\n",
"zip_size = os.path.getsize(OUTPUT_ZIP) / 1024 / 1024\n",
"print(f\"\\nExport complete!\")\n",
"print(f\" File: {OUTPUT_ZIP}\")\n",
"print(f\" Size: {zip_size:.1f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List all output files\n",
"print(\"\\nAll output files:\")\n",
"print(f\"\\nTraining directory: {train_dir}\")\n",
"!ls -la {train_dir}/weights/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Final summary\n",
"print(\"=\" * 60)\n",
"print(\"TRAINING SUMMARY\")\n",
"print(\"=\" * 60)\n",
"print(f\"\\nModel: YOLOv9t\")\n",
"print(f\"Dataset: {dataset_path}\")\n",
"print(f\" Train images: {train_images}\")\n",
"print(f\" Val images: {val_images}\")\n",
"print(f\" Classes: {data_config.get('nc', 'unknown')}\")\n",
"\n",
"print(f\"\\nTraining:\")\n",
"print(f\" Epochs: {CONFIG['epochs']}\")\n",
"print(f\" Batch size: {CONFIG['batch']}\")\n",
"print(f\" Image size: {CONFIG['imgsz']}\")\n",
"\n",
"print(f\"\\nResults:\")\n",
"print(f\" mAP50: {metrics.box.map50:.4f}\")\n",
"print(f\" mAP50-95: {metrics.box.map:.4f}\")\n",
"print(f\" Precision: {metrics.box.mp:.4f}\")\n",
"print(f\" Recall: {metrics.box.mr:.4f}\")\n",
"\n",
"print(f\"\\nOutput files:\")\n",
"print(f\" Best weights: {best_weights}\")\n",
"print(f\" Export archive: {OUTPUT_ZIP}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)\n",
"print(\"Training complete! Download weights for deployment.\")\n",
"print(\"=\" * 60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"## Usage Example\n",
"\n",
"After training, use the model for inference:\n",
"\n",
"```python\n",
"from ultralytics import YOLO\n",
"\n",
"# Load trained model\n",
"model = YOLO('best.pt')\n",
"\n",
"# Inference on image\n",
"results = model.predict('image.jpg', conf=0.25)\n",
"\n",
"# Inference on video\n",
"results = model.predict('video.mp4', conf=0.25, save=True)\n",
"\n",
"# Access detections\n",
"for result in results:\n",
" boxes = result.boxes\n",
" for box in boxes:\n",
" x1, y1, x2, y2 = box.xyxy[0]\n",
" confidence = box.conf[0]\n",
" class_id = box.cls[0]\n",
"```\n",
"\n",
"## Tips\n",
"\n",
"- **Low mAP?** Try:\n",
" - More training epochs\n",
" - Data augmentation adjustments\n",
" - Lower learning rate\n",
" - More training data\n",
"\n",
"- **Overfitting?** Try:\n",
" - More augmentation\n",
" - Dropout/regularization\n",
" - Early stopping (patience)\n",
"\n",
"- **Slow training?** Try:\n",
" - Larger batch size (if GPU memory allows)\n",
" - Mixed precision (amp=True)\n",
" - Smaller image size"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}