394 lines
14 KiB
Python
394 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Augment a YOLOv9-format dataset by creating new image and label files for:
|
||
horizontal flip, vertical flip, +10% hue, +30% contrast, and grayscale.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import logging
|
||
import random
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import cv2
|
||
|
||
# Augmentation strength constants (tune as needed)
|
||
HUE_DELTA = 0.1 # 10% hue shift in [0, 1] scale
|
||
CONTRAST_FACTOR = 1.3 # 30% contrast increase
|
||
|
||
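# Filename suffix for each augmentation type (also the values accepted by --suffixes).
# The two flips rewrite label coordinates; hue, contrast and gray copy labels unchanged.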
|
||
SUFFIX_HFLIP = "hflip"
|
||
SUFFIX_VFLIP = "vflip"
|
||
SUFFIX_HUE = "hue"
|
||
SUFFIX_CONTRAST = "contrast"
|
||
SUFFIX_GRAY = "gray"
|
||
|
||
LOG = logging.getLogger(__name__)
|
||
|
||
# Default image extensions to discover (case-insensitive)
|
||
DEFAULT_IMAGE_EXTS = (".jpg", ".jpeg", ".png")
|
||
|
||
|
||
def read_yolo_labels(path: Path) -> list[tuple[int, float, float, float, float]]:
    """Read a YOLO label file into a list of bounding-box rows.

    Every non-blank line is expected to hold five whitespace-separated values:
    ``class_id x_center y_center width height``. Lines that do not have exactly
    five fields, or whose fields cannot be parsed as numbers, are skipped with
    a warning so that one bad line does not abort the whole read.

    Args:
        path: Location of the ``.txt`` label file.

    Returns:
        One ``(class_id, x_center, y_center, width, height)`` tuple per valid
        line, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    rows: list[tuple[int, float, float, float, float]] = []
    with path.open() as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) != 5:
                LOG.warning(
                    "Skipping malformed line in %s (expected 5 values): %s",
                    path,
                    line[:80],
                )
                continue
            try:
                class_id = int(parts[0])
                x_center, y_center, width, height = (float(p) for p in parts[1:])
            except ValueError:
                # A non-numeric field (e.g. a class *name* instead of an id)
                # is treated like any other malformed line rather than being
                # allowed to propagate and stop the caller's batch.
                LOG.warning(
                    "Skipping malformed line in %s (non-numeric value): %s",
                    path,
                    line[:80],
                )
                continue
            rows.append((class_id, x_center, y_center, width, height))
    return rows
|
||
|
||
|
||
def write_yolo_labels(path: Path, rows: list[tuple[int, float, float, float, float]]) -> None:
    """Serialise label rows to *path*, one object per line.

    Each row is written as ``class_id x_center y_center width height`` with
    the four coordinates formatted to six decimal places. Missing parent
    directories are created first and an existing file is overwritten.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w") as handle:
        handle.writelines(
            f"{cls} {xc:.6f} {yc:.6f} {box_w:.6f} {box_h:.6f}\n"
            for cls, xc, yc, box_w, box_h in rows
        )
|
||
|
||
|
||
def flip_labels_horizontal(
    rows: list[tuple[int, float, float, float, float]],
) -> list[tuple[int, float, float, float, float]]:
    """Mirror label rows left-to-right.

    Builds a new list in which each row keeps its class, y_center, width and
    height, while x_center is replaced by ``1 - x_center``. The input list is
    not modified.
    """
    mirrored = []
    for class_id, x_center, y_center, width, height in rows:
        mirrored.append((class_id, 1.0 - x_center, y_center, width, height))
    return mirrored
|
||
|
||
|
||
def flip_labels_vertical(
    rows: list[tuple[int, float, float, float, float]],
) -> list[tuple[int, float, float, float, float]]:
    """Mirror label rows top-to-bottom.

    Builds a new list in which each row keeps its class, x_center, width and
    height, while y_center is replaced by ``1 - y_center``. The input list is
    not modified.
    """
    mirrored = []
    for class_id, x_center, y_center, width, height in rows:
        mirrored.append((class_id, x_center, 1.0 - y_center, width, height))
    return mirrored
|
||
|
||
|
||
def _load_image(path: Path):
    """Read *path* with ``cv2.imread`` and return the decoded image.

    ``cv2.imread`` signals failure by returning ``None`` instead of raising,
    so that case is converted into an ``OSError`` here.
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image
    raise OSError(f"Failed to load image: {path}. Check path and format (e.g. .jpg, .png).")
|
||
|
||
|
||
def _ensure_parent(path: Path) -> None:
|
||
path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
|
||
def apply_horizontal_flip(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a horizontally mirrored copy of an image together with mirrored labels.

    The image is flipped with ``cv2.flip(img, 1)`` and every label row has its
    ``x_center`` replaced by ``1 - x_center``.

    Args:
        image_path: Source image.
        labels_path: YOLO label file belonging to the source image.
        out_image_path: Destination of the flipped image.
        out_labels_path: Destination of the transformed label file.
        dry_run: When true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or a label file
            cannot be read or written.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    img = _load_image(image_path)
    # Read the labels before writing anything: if the label file turns out to
    # be unreadable, no augmented image is left behind without its labels.
    rows = read_yolo_labels(labels_path)
    flipped = cv2.flip(img, 1)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), flipped):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, flip_labels_horizontal(rows))
|
||
|
||
|
||
def apply_vertical_flip(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a vertically mirrored copy of an image together with mirrored labels.

    The image is flipped with ``cv2.flip(img, 0)`` and every label row has its
    ``y_center`` replaced by ``1 - y_center``.

    Args:
        image_path: Source image.
        labels_path: YOLO label file belonging to the source image.
        out_image_path: Destination of the flipped image.
        out_labels_path: Destination of the transformed label file.
        dry_run: When true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or a label file
            cannot be read or written.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    img = _load_image(image_path)
    # Read the labels before writing anything: if the label file turns out to
    # be unreadable, no augmented image is left behind without its labels.
    rows = read_yolo_labels(labels_path)
    flipped = cv2.flip(img, 0)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), flipped):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, flip_labels_vertical(rows))
|
||
|
||
|
||
def apply_hue_shift(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    delta: float = HUE_DELTA,
    dry_run: bool = False,
) -> None:
    """Write a hue-shifted copy of an image and copy its labels unchanged.

    Args:
        image_path: Source image.
        labels_path: YOLO label file belonging to the source image.
        out_image_path: Destination of the hue-shifted image.
        out_labels_path: Destination of the copied label file.
        delta: Hue shift expressed as a fraction of the full hue circle
            (0–1 scale).
        dry_run: When true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or a label file
            cannot be read or written.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    img = _load_image(image_path)
    # Read the labels before writing anything: if the label file turns out to
    # be unreadable, no augmented image is left behind without its labels.
    rows = read_yolo_labels(labels_path)
    # Work in float32 so the shifted hue can be wrapped before the cast back
    # to uint8.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype("float32")
    h, s, v = cv2.split(hsv)
    # OpenCV H is 0–180; treat delta as fraction of full circle and wrap.
    h = (h + delta * 180) % 180
    hsv = cv2.merge([h, s, v]).astype("uint8")
    out = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
|
||
|
||
|
||
def apply_contrast(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    factor: float = CONTRAST_FACTOR,
    dry_run: bool = False,
) -> None:
    """Write a contrast-adjusted copy of an image and copy its labels unchanged.

    Each pixel value becomes ``(pixel - mean) * factor + mean``, where ``mean``
    is the mean over the whole image array (all channels together), and the
    result is clipped to [0, 255].

    Args:
        image_path: Source image.
        labels_path: YOLO label file belonging to the source image.
        out_image_path: Destination of the adjusted image.
        out_labels_path: Destination of the copied label file.
        factor: Contrast multiplier; values above 1 increase contrast.
        dry_run: When true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or a label file
            cannot be read or written.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    img = _load_image(image_path).astype("float32")
    # Read the labels before writing anything: if the label file turns out to
    # be unreadable, no augmented image is left behind without its labels.
    rows = read_yolo_labels(labels_path)
    mean = img.mean()
    out = (img - mean) * factor + mean
    out = out.clip(0, 255).astype("uint8")
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
|
||
|
||
|
||
def apply_grayscale(
    image_path: Path,
    labels_path: Path,
    out_image_path: Path,
    out_labels_path: Path,
    dry_run: bool = False,
) -> None:
    """Write a grayscale copy of an image and copy its labels unchanged.

    The image is converted BGR -> GRAY -> BGR, so the output still has three
    channels.

    Args:
        image_path: Source image.
        labels_path: YOLO label file belonging to the source image.
        out_image_path: Destination of the grayscale image.
        out_labels_path: Destination of the copied label file.
        dry_run: When true, only log the paths that would be created.

    Raises:
        OSError: If the image cannot be loaded or written, or a label file
            cannot be read or written.
    """
    if dry_run:
        LOG.info("Would create: %s, %s", out_image_path, out_labels_path)
        return
    img = _load_image(image_path)
    # Read the labels before writing anything: if the label file turns out to
    # be unreadable, no augmented image is left behind without its labels.
    rows = read_yolo_labels(labels_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    out = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    _ensure_parent(out_image_path)
    if not cv2.imwrite(str(out_image_path), out):
        raise OSError(f"Failed to write image: {out_image_path}. Check permissions and disk space.")
    write_yolo_labels(out_labels_path, rows)
|
||
|
||
|
||
def discover_images_and_labels(
    dataset_dir: Path,
    image_ext: str,
) -> list[tuple[Path, Path]]:
    """Find ``(image_path, label_path)`` pairs in a dataset directory.

    ``dataset_dir/images`` and ``dataset_dir/labels`` are used when the
    ``images`` folder exists; otherwise ``dataset_dir`` itself is searched for
    both images and labels (flat layout). An image is paired with the
    ``<stem>.txt`` file in the labels directory; images without such a file
    are skipped with a warning.

    Args:
        dataset_dir: Root of the dataset.
        image_ext: A single extension or a comma-separated list, with or
            without leading dots, matched case-insensitively. An empty value
            or one of ``*``, ``any``, ``all``, ``auto`` selects
            ``DEFAULT_IMAGE_EXTS``.

    Returns:
        The pairs sorted by image path, so that the result (and any seeded
        random selection built on top of it) is the same on every run.

    Raises:
        FileNotFoundError: If the images or labels directory does not exist.
    """
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    if not images_dir.is_dir():
        images_dir = dataset_dir
        labels_dir = dataset_dir
    if not images_dir.is_dir():
        raise FileNotFoundError(
            f"Dataset directory not found or has no 'images' subdir: {dataset_dir}. "
            "Provide a path that contains an 'images' folder or is the folder with image files."
        )
    if not labels_dir.is_dir():
        raise FileNotFoundError(
            f"Labels directory not found: {labels_dir}. "
            "Expected a 'labels' folder next to 'images', or the same folder for flat layout."
        )

    raw = (image_ext or "").strip()
    if raw.lower() in {"*", "any", "all", "auto"}:
        parts = []
    else:
        parts = [p.strip() for p in raw.split(",") if p.strip()]
    if parts:
        # Accept both "jpg" and ".jpg"; compare case-insensitively.
        allowed_exts = {(p if p.startswith(".") else f".{p}").lower() for p in parts}
    else:
        allowed_exts = {e.lower() for e in DEFAULT_IMAGE_EXTS}

    pairs = []
    # iterdir() yields entries in arbitrary, platform-dependent order; sort so
    # the pair order does not vary between runs or machines.
    for img_path in sorted(images_dir.iterdir()):
        if not img_path.is_file():
            continue
        if img_path.suffix.lower() not in allowed_exts:
            continue
        label_path = labels_dir / f"{img_path.stem}.txt"
        if not label_path.is_file():
            LOG.warning("No label file for image %s, skipping: %s", img_path.name, label_path)
            continue
        pairs.append((img_path, label_path))
    return pairs
|
||
|
||
|
||
def run_augmentations(
    dataset_dir: Path,
    output_dir: Path | None,
    image_ext: str,
    enabled: set[str],
    max_per_image: int,
    dry_run: bool,
) -> None:
    """Discover image/label pairs and apply random augmentations to each image.

    For every discovered image, up to ``max_per_image`` distinct augmentation
    types are drawn with ``random.sample`` from the enabled ones. Each result
    is written as ``<stem>_<suffix><ext>`` with a matching
    ``<stem>_<suffix>.txt`` label file. A failure of a single augmentation is
    logged and does not stop the run.

    Args:
        dataset_dir: Root of the source dataset.
        output_dir: Root for the augmented files (``images/`` and ``labels/``
            below it). ``None`` writes next to the source files.
        image_ext: Extension filter passed to ``discover_images_and_labels``.
        enabled: Augmentation suffixes that may be chosen; unknown suffixes
            are ignored with a warning.
        max_per_image: Maximum number of augmentation types per image.
        dry_run: When true, only log what would be created.

    Raises:
        FileNotFoundError: Propagated from discovery when the dataset layout
            is missing.
    """
    pairs = discover_images_and_labels(dataset_dir, image_ext)
    if not pairs:
        LOG.warning("No image/label pairs found in %s with image-ext %s.", dataset_dir, image_ext)
        return

    handlers = {
        SUFFIX_HFLIP: apply_horizontal_flip,
        SUFFIX_VFLIP: apply_vertical_flip,
        SUFFIX_HUE: apply_hue_shift,
        SUFFIX_CONTRAST: apply_contrast,
        SUFFIX_GRAY: apply_grayscale,
    }
    unknown = sorted(str(s) for s in enabled if s not in handlers)
    if unknown:
        # Previously an unknown suffix produced no files but was still counted.
        LOG.warning("Ignoring unknown augmentation types: %s", ", ".join(unknown))
    # Sort: iteration order of a set of strings changes between interpreter
    # runs (hash randomisation), which would make a seeded random.sample pick
    # different augmentations each time.
    enabled_list = sorted(s for s in enabled if s in handlers)
    if not enabled_list:
        LOG.warning("No augmentation types enabled.")
        return

    if output_dir is None:
        # Write next to the sources. Discovery puts every image in one
        # directory and every label in one directory, so the first pair tells
        # us exactly which directories it used.
        out_images = pairs[0][0].parent
        out_labels = pairs[0][1].parent
    else:
        out_images = output_dir / "images"
        out_labels = output_dir / "labels"

    total_images = len(pairs)
    total_augmentations = 0
    per_image = min(max_per_image, len(enabled_list))
    start_time = time.perf_counter()
    LOG.info("Starting augmentation: %d images, max %d per image.", total_images, max_per_image)
    for idx, (img_path, label_path) in enumerate(pairs, start=1):
        chosen = random.sample(enabled_list, per_image)
        LOG.info(
            "Processing image %d/%d: %s (%s)",
            idx,
            total_images,
            img_path.name,
            ", ".join(chosen),
        )
        for suffix in chosen:
            out_img = out_images / f"{img_path.stem}_{suffix}{img_path.suffix}"
            out_lbl = out_labels / f"{img_path.stem}_{suffix}.txt"
            try:
                handlers[suffix](img_path, label_path, out_img, out_lbl, dry_run=dry_run)
            except (OSError, ValueError) as e:
                # ValueError covers label files that cannot be parsed or
                # decoded; one bad file should not abort the whole batch.
                LOG.error("Skipping %s %s: %s", suffix, img_path.name, e)
            else:
                total_augmentations += 1
    elapsed = time.perf_counter() - start_time
    LOG.info(
        "Completed: %d images, %d augmentations in %.1f s.",
        total_images,
        total_augmentations,
        elapsed,
    )
|
||
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the augmentation script."""
    all_suffixes = [SUFFIX_HFLIP, SUFFIX_VFLIP, SUFFIX_HUE, SUFFIX_CONTRAST, SUFFIX_GRAY]
    parser = argparse.ArgumentParser(
        description="Augment YOLOv9 dataset with flips, hue, contrast, and grayscale.",
    )
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Root of the dataset (containing images/ and labels/ or flat image+label files).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Where to write augmented files (default: same as dataset-dir).",
    )
    parser.add_argument(
        "--image-ext",
        type=str,
        default=",".join(DEFAULT_IMAGE_EXTS),
        help=(
            "Image extension(s) to look for. Provide a single ext (e.g. .jpg) or a comma-separated list "
            "(e.g. .jpg,.jpeg,.png). Use 'all'/'any' to use the defaults."
        ),
    )
    parser.add_argument(
        "--suffixes",
        type=str,
        nargs="+",
        default=list(all_suffixes),
        choices=all_suffixes,
        help="Which augmentations can be applied (default: all).",
    )
    parser.add_argument(
        "--max-per-image",
        type=int,
        default=2,
        metavar="N",
        help="Maximum number of augmentation types to apply per image (default: 2).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible augmentation selection.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only print which files would be created.",
    )
    return parser


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments, validate them and run the augmentations.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: With status 1 when ``--max-per-image`` is below 1, the
            dataset directory does not exist, or the dataset layout is
            rejected by discovery. argparse itself exits on invalid
            arguments.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    args = _build_arg_parser().parse_args(argv)
    if args.max_per_image < 1:
        LOG.error("--max-per-image must be at least 1.")
        raise SystemExit(1)
    if args.seed is not None:
        random.seed(args.seed)
    if not args.dataset_dir.is_dir():
        LOG.error(
            "Dataset directory does not exist: %s. Create it and add images/ and labels/ (or image + label files).",
            args.dataset_dir,
        )
        raise SystemExit(1)
    try:
        run_augmentations(
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
            image_ext=args.image_ext,
            enabled=set(args.suffixes),
            max_per_image=args.max_per_image,
            dry_run=args.dry_run,
        )
    except FileNotFoundError as exc:
        # Discovery raises this for a missing labels/ folder; report it the
        # same way as the other user errors instead of showing a traceback.
        LOG.error("%s", exc)
        raise SystemExit(1) from exc
|
||
|
||
|
||
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|