Files
dataset-yolo-script/augment_yolov9_dataset.py
T
2026-02-04 15:22:28 +07:00

394 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Augment a YOLOv9-format dataset by creating new image and label files for:
horizontal flip, vertical flip, +10% hue, +30% contrast, and grayscale.
"""
from __future__ import annotations
import argparse
import logging
import random
import time
from pathlib import Path
import cv2
# Augmentation strength constants (tune as needed).
# Hue shift as a fraction of the full hue circle (0.1 == 10% == 36 degrees).
HUE_DELTA = 0.1
# Multiplier applied to each pixel's deviation from the image mean (1.3 == +30%).
CONTRAST_FACTOR = 1.3
# Filename suffix for each augmentation type; also the name accepted by
# --suffixes. Outputs are named "<stem>_<suffix><ext>" and "<stem>_<suffix>.txt".
SUFFIX_HFLIP = "hflip"
SUFFIX_VFLIP = "vflip"
SUFFIX_HUE = "hue"
SUFFIX_CONTRAST = "contrast"
SUFFIX_GRAY = "gray"
LOG = logging.getLogger(__name__)
# Default image extensions to discover (matched case-insensitively).
DEFAULT_IMAGE_EXTS = (".jpg", ".jpeg", ".png")
def read_yolo_labels(path: Path) -> list[tuple[int, float, float, float, float]]:
    """Read a YOLO detection label file.

    Each non-blank line must hold exactly five whitespace-separated values:
    ``class_id x_center y_center width height``. A line with a different
    number of values, or with values that do not parse as numbers, is logged
    and skipped, so a single bad line cannot abort a whole dataset run.

    Args:
        path: Label ``.txt`` file. It is decoded as UTF-8; a leading UTF-8 BOM
            (common for files saved by Windows editors) is ignored.

    Returns:
        ``(class_id, x_center, y_center, width, height)`` tuples in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    rows: list[tuple[int, float, float, float, float]] = []
    # "utf-8-sig" reads plain UTF-8 too, but strips a BOM that would otherwise
    # end up glued to the first class id and make int() fail.
    with path.open(encoding="utf-8-sig") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) != 5:
                LOG.warning(
                    "Skipping malformed line in %s (expected 5 values): %s",
                    path,
                    line[:80],
                )
                continue
            try:
                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])
            except ValueError:
                LOG.warning(
                    "Skipping line %d in %s with non-numeric values: %s",
                    line_no,
                    path,
                    line[:80],
                )
                continue
            rows.append((class_id, x_center, y_center, width, height))
    return rows
def write_yolo_labels(path: Path, rows: list[tuple[int, float, float, float, float]]) -> None:
    """Serialize label rows to *path*, one ``class x y w h`` object per line.

    Missing parent directories are created first, and an existing file is
    overwritten. Coordinates are formatted with six decimal places.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w") as handle:
        handle.writelines(
            f"{cls} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n"
            for cls, xc, yc, bw, bh in rows
        )
def flip_labels_horizontal(
    rows: list[tuple[int, float, float, float, float]],
) -> list[tuple[int, float, float, float, float]]:
    """Mirror boxes left-to-right by mapping each x_center to ``1 - x_center``.

    The class id, y_center, width and height pass through unchanged. A new
    list is returned; *rows* is not modified.
    """
    return [
        (class_id, 1.0 - x_center, y_center, width, height)
        for class_id, x_center, y_center, width, height in rows
    ]
def flip_labels_vertical(
    rows: list[tuple[int, float, float, float, float]],
) -> list[tuple[int, float, float, float, float]]:
    """Mirror boxes top-to-bottom by mapping each y_center to ``1 - y_center``.

    The class id, x_center, width and height pass through unchanged. A new
    list is returned; *rows* is not modified.
    """
    return [
        (class_id, x_center, 1.0 - y_center, width, height)
        for class_id, x_center, y_center, width, height in rows
    ]
def _load_image(path: Path):
    """Read *path* with ``cv2.imread`` in its default colour mode (BGR order).

    Raises:
        OSError: If OpenCV returns no image (for example a missing, unreadable
            or undecodable file).
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image
    raise OSError(f"Failed to load image: {path}. Check path and format (e.g. .jpg, .png).")
def _ensure_parent(path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
def apply_horizontal_flip(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a left-right mirrored copy of an image and its YOLO labels.

    The source label file is read before anything is written. If it cannot be
    read or parsed, no orphan augmented image is left behind without a
    matching label file.

    Args:
        image_path: Source image.
        labels_path: Source YOLO label file for *image_path*.
        out_image_path: Destination of the flipped image.
        out_labels_path: Destination of the transformed labels
            (``x_center -> 1 - x_center``).
        dry_run: If true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or if label file
            I/O fails. Other errors from label parsing propagate unchanged.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    rows = read_yolo_labels(labels_path)
    img = _load_image(image_path)
    flipped = cv2.flip(img, 1)  # flipCode=1: mirror around the vertical axis
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), flipped):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, flip_labels_horizontal(rows))
def apply_vertical_flip(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a top-bottom mirrored copy of an image and its YOLO labels.

    The source label file is read before anything is written. If it cannot be
    read or parsed, no orphan augmented image is left behind without a
    matching label file.

    Args:
        image_path: Source image.
        labels_path: Source YOLO label file for *image_path*.
        out_image_path: Destination of the flipped image.
        out_labels_path: Destination of the transformed labels
            (``y_center -> 1 - y_center``).
        dry_run: If true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or if label file
            I/O fails. Other errors from label parsing propagate unchanged.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    rows = read_yolo_labels(labels_path)
    img = _load_image(image_path)
    flipped = cv2.flip(img, 0)  # flipCode=0: mirror around the horizontal axis
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), flipped):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, flip_labels_vertical(rows))
def apply_hue_shift(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    delta: float = HUE_DELTA,
    dry_run: bool = False,
) -> None:
    """Rotate the image hue by *delta* (fraction of the 0-1 hue circle); copy labels.

    The source label file is read before anything is written. If it cannot be
    read or parsed, no orphan augmented image is left behind without a
    matching label file.

    Args:
        image_path: Source image.
        labels_path: Source YOLO label file, copied unchanged.
        out_image_path: Destination of the hue-shifted image.
        out_labels_path: Destination of the copied labels.
        delta: Hue rotation as a fraction of the full circle. It may be
            negative, and values outside [-1, 1] wrap around.
        dry_run: If true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or if label file
            I/O fails. Other errors from label parsing propagate unchanged.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    rows = read_yolo_labels(labels_path)
    img = _load_image(image_path)
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    # For 8-bit images OpenCV stores H as 0-179 (degrees / 2), so one full
    # circle is 180 units. Round delta to whole hue units (instead of the
    # implicit floor a float -> uint8 cast would apply), and wrap with an
    # exact integer modulo so no value can land on 180.
    shift = int(round(delta * 180))
    h = ((h.astype("int32") + shift) % 180).astype("uint8")
    out = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
def apply_contrast(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    factor: float = CONTRAST_FACTOR,
    dry_run: bool = False,
) -> None:
    """Scale contrast around the image mean; copy labels unchanged.

    Each pixel becomes ``(pixel - mean) * factor + mean``, where *mean* is the
    mean over all pixels and channels. The result is rounded and clipped to
    [0, 255]. The source label file is read before anything is written. If it
    cannot be read or parsed, no orphan augmented image is left behind.

    Args:
        image_path: Source image.
        labels_path: Source YOLO label file, copied unchanged.
        out_image_path: Destination of the contrast-adjusted image.
        out_labels_path: Destination of the copied labels.
        factor: Contrast multiplier (>1 increases contrast, <1 reduces it).
        dry_run: If true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or if label file
            I/O fails. Other errors from label parsing propagate unchanged.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    rows = read_yolo_labels(labels_path)
    img = _load_image(image_path).astype("float32")
    mean = img.mean()
    # Round before the uint8 cast: astype() alone truncates toward zero and
    # would darken every output pixel by up to one intensity level.
    out = ((img - mean) * factor + mean).round().clip(0, 255).astype("uint8")
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
def apply_grayscale(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a grayscale copy of an image (stored as 3 equal BGR channels); copy labels.

    The source label file is read before anything is written. If it cannot be
    read or parsed, no orphan augmented image is left behind without a
    matching label file.

    Args:
        image_path: Source image.
        labels_path: Source YOLO label file, copied unchanged.
        out_image_path: Destination of the grayscale image.
        out_labels_path: Destination of the copied labels.
        dry_run: If true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or if label file
            I/O fails. Other errors from label parsing propagate unchanged.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    rows = read_yolo_labels(labels_path)
    img = _load_image(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Convert back to 3 channels so the output has the same channel count as
    # the other augmented images.
    out = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
def _parse_image_exts(image_ext: str) -> set[str]:
    """Turn an ``--image-ext`` value into a set of lower-case, dot-prefixed suffixes.

    ``"*"``, ``"any"``, ``"all"``, ``"auto"`` (any case) or an empty/blank
    value select DEFAULT_IMAGE_EXTS. Otherwise the value is split on commas
    and a leading dot is added to each entry that lacks one.
    """
    raw = (image_ext or "").strip()
    if raw.lower() in {"*", "any", "all", "auto"}:
        return {e.lower() for e in DEFAULT_IMAGE_EXTS}
    parts = [p.strip() for p in raw.split(",") if p.strip()]
    if not parts:
        return {e.lower() for e in DEFAULT_IMAGE_EXTS}
    return {(p if p.startswith(".") else f".{p}").lower() for p in parts}


def discover_images_and_labels(
    dataset_dir: Path,
    image_ext: str,
) -> list[tuple[Path, Path]]:
    """Find ``(image_path, label_path)`` pairs in a YOLO dataset.

    The ``dataset_dir/images`` + ``dataset_dir/labels`` layout is preferred.
    If there is no ``images`` subdirectory, ``dataset_dir`` itself holds both
    images and labels (flat layout). Only regular files directly inside the
    images directory (no recursion) whose suffix matches *image_ext*
    (case-insensitively) are considered. Images without a ``<stem>.txt``
    label are logged and skipped.

    Pairs are returned sorted by image path. ``iterdir()`` order is
    filesystem-dependent, and sorting keeps the seeded random augmentation
    choice per image reproducible across machines.

    Args:
        dataset_dir: Dataset root (or flat folder of images and labels).
        image_ext: A single extension, a comma-separated list, or
            ``*``/``any``/``all``/``auto`` for the defaults.

    Raises:
        FileNotFoundError: If the images or labels directory is missing.
    """
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    if not images_dir.is_dir():
        images_dir = dataset_dir
        labels_dir = dataset_dir
    if not images_dir.is_dir():
        raise FileNotFoundError(
            f"Dataset directory not found or has no 'images' subdir: {dataset_dir}. "
            "Provide a path that contains an 'images' folder or is the folder with image files."
        )
    if not labels_dir.is_dir():
        raise FileNotFoundError(
            f"Labels directory not found: {labels_dir}. "
            "Expected a 'labels' folder next to 'images', or the same folder for flat layout."
        )
    allowed_exts = _parse_image_exts(image_ext)
    pairs: list[tuple[Path, Path]] = []
    for img_path in sorted(images_dir.iterdir()):
        if not img_path.is_file() or img_path.suffix.lower() not in allowed_exts:
            continue
        label_path = labels_dir / f"{img_path.stem}.txt"
        if not label_path.is_file():
            LOG.warning("No label file for image %s, skipping: %s", img_path.name, label_path)
            continue
        pairs.append((img_path, label_path))
    return pairs
def run_augmentations(
    dataset_dir: Path,
    output_dir: Path | None,
    image_ext: str,
    enabled: set[str],
    max_per_image: int,
    dry_run: bool,
) -> None:
    """Discover image/label pairs and apply up to *max_per_image* random augmentations each.

    For each image, ``min(max_per_image, len(enabled))`` distinct
    augmentation types are drawn with ``random.sample``. Outputs are named
    ``<stem>_<suffix><ext>`` and ``<stem>_<suffix>.txt``. They go into the
    dataset's own images/labels folders when *output_dir* is None, and into
    ``output_dir/images`` and ``output_dir/labels`` otherwise.

    Args:
        dataset_dir: Dataset root passed to discover_images_and_labels.
        output_dir: Output root, or None to write next to the inputs.
        image_ext: Extension filter passed to discover_images_and_labels.
        enabled: Augmentation suffixes allowed. Unknown names are ignored
            with a warning.
        max_per_image: Upper bound on augmentations per image.
        dry_run: If true, only log what would be created.

    A failure on one image/augmentation (OSError, or ValueError from label
    parsing) is logged and skipped; the run continues.
    """
    pairs = discover_images_and_labels(dataset_dir, image_ext)
    if not pairs:
        LOG.warning("No image/label pairs found in %s with image-ext %s.", dataset_dir, image_ext)
        return
    augmenters = {
        SUFFIX_HFLIP: apply_horizontal_flip,
        SUFFIX_VFLIP: apply_vertical_flip,
        SUFFIX_HUE: apply_hue_shift,
        SUFFIX_CONTRAST: apply_contrast,
        SUFFIX_GRAY: apply_grayscale,
    }
    requested = set(enabled)
    unknown = requested - augmenters.keys()
    if unknown:
        LOG.warning("Ignoring unknown augmentation type(s): %s", ", ".join(sorted(unknown)))
    # Sort: the iteration order of a set of str depends on PYTHONHASHSEED, so
    # without it random.sample() would pick different augmentations from run
    # to run even with the same --seed.
    enabled_list = sorted(requested & augmenters.keys())
    if not enabled_list:
        LOG.warning("No augmentation types enabled.")
        return
    if output_dir is None:
        out_images = dataset_dir / "images" if (dataset_dir / "images").is_dir() else dataset_dir
        out_labels = dataset_dir / "labels" if (dataset_dir / "labels").is_dir() else dataset_dir
    else:
        out_images = output_dir / "images"
        out_labels = output_dir / "labels"
    total_images = len(pairs)
    total_augmentations = 0
    k = min(max_per_image, len(enabled_list))
    start_time = time.perf_counter()
    LOG.info("Starting augmentation: %d images, max %d per image.", total_images, max_per_image)
    for idx, (img_path, label_path) in enumerate(pairs, start=1):
        base = img_path.stem
        ext = img_path.suffix
        chosen = random.sample(enabled_list, k)
        LOG.info(
            "Processing image %d/%d: %s (%s)",
            idx,
            total_images,
            img_path.name,
            ", ".join(chosen),
        )
        for suffix in chosen:
            out_img = out_images / f"{base}_{suffix}{ext}"
            out_lbl = out_labels / f"{base}_{suffix}.txt"
            try:
                augmenters[suffix](img_path, label_path, out_img, out_lbl, dry_run=dry_run)
            except (OSError, ValueError) as e:
                LOG.error("Skipping %s %s: %s", suffix, img_path.name, e)
                continue
            total_augmentations += 1
    elapsed = time.perf_counter() - start_time
    LOG.info(
        "Completed: %d images, %d augmentations in %.1f s.",
        total_images,
        total_augmentations,
        elapsed,
    )
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    all_suffixes = [SUFFIX_HFLIP, SUFFIX_VFLIP, SUFFIX_HUE, SUFFIX_CONTRAST, SUFFIX_GRAY]
    parser = argparse.ArgumentParser(
        description="Augment YOLOv9 dataset with flips, hue, contrast, and grayscale.",
    )
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Root of the dataset (containing images/ and labels/ or flat image+label files).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Where to write augmented files (default: same as dataset-dir).",
    )
    ext_help = (
        "Image extension(s) to look for. Provide a single ext (e.g. .jpg) or a comma-separated list "
        "(e.g. .jpg,.jpeg,.png). Use 'all'/'any' to use the defaults."
    )
    parser.add_argument(
        "--image-ext",
        type=str,
        default=",".join(DEFAULT_IMAGE_EXTS),
        help=ext_help,
    )
    parser.add_argument(
        "--suffixes",
        type=str,
        nargs="+",
        default=list(all_suffixes),
        choices=all_suffixes,
        help="Which augmentations can be applied (default: all).",
    )
    parser.add_argument(
        "--max-per-image",
        type=int,
        default=2,
        metavar="N",
        help="Maximum number of augmentation types to apply per image (default: 2).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible augmentation selection.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only print which files would be created.",
    )
    return parser


def main() -> None:
    """Command-line entry point: configure logging, parse and validate args, run.

    Exits with status 1, after logging an error, when ``--max-per-image`` is
    below 1 or ``--dataset-dir`` is not an existing directory.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    args = _build_parser().parse_args()
    if args.max_per_image < 1:
        LOG.error("--max-per-image must be at least 1.")
        raise SystemExit(1)
    if args.seed is not None:
        random.seed(args.seed)
    if not args.dataset_dir.is_dir():
        LOG.error(
            "Dataset directory does not exist: %s. Create it and add images/ and labels/ (or image + label files).",
            args.dataset_dir,
        )
        raise SystemExit(1)
    run_augmentations(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        image_ext=args.image_ext,
        enabled=set(args.suffixes),
        max_per_image=args.max_per_image,
        dry_run=args.dry_run,
    )


if __name__ == "__main__":
    main()