Files
augment_dataset/rotate_dataset.py
2026-02-09 10:10:11 +07:00

90 lines
3.0 KiB
Python

import argparse
from copy import deepcopy
import datumaro as dm
from datumaro.components.transformer import ItemTransform
from datumaro.components.media import Image
from datumaro.components.annotation import Bbox
import numpy as np
class RotateTransform(ItemTransform):
    """Rotate image items clockwise by a multiple of 90 degrees.

    The pixel data is rotated with ``np.rot90`` and every ``Bbox`` annotation
    is remapped into the rotated image frame.  Each transformed item gets an
    ``_r<angle>`` suffix appended to its id.

    NOTE(review): annotations that are not ``Bbox`` (e.g. polygons or masks)
    are passed through unchanged, so they will not line up with the rotated
    image.  Confirm the input datasets only carry bounding boxes.
    """

    def __init__(self, extractor, angle):
        """Store the rotation to apply.

        Args:
            extractor: Source extractor/dataset, forwarded to ``ItemTransform``.
            angle: Clockwise rotation in degrees; must be a multiple of 90.
                Values outside 0..270 (e.g. 450 or -90) are normalised.

        Raises:
            ValueError: If ``angle`` is not a multiple of 90.
        """
        if angle % 90 != 0:
            # Anything else would silently emit unrotated copies that are
            # labelled as rotated.
            raise ValueError(
                f"angle must be a multiple of 90 degrees, got {angle!r}"
            )
        super().__init__(extractor)
        self.angle = angle
        # Number of clockwise quarter turns, normalised to 0..3.
        self.k = int(angle // 90) % 4

    def _rotate_bbox(self, bbox, img_w, img_h):
        """Return a copy of ``bbox`` mapped into the rotated image frame.

        Args:
            bbox: The ``Bbox`` annotation to remap.
            img_w: Width of the image before rotation.
            img_h: Height of the image before rotation.
        """
        # Datumaro keeps Bbox.points as corners [x1, y1, x2, y2]; get_bbox()
        # is the accessor that yields (x, y, width, height).
        x, y, bw, bh = bbox.get_bbox()

        # Branch on the normalised quarter-turn count (not the raw angle) so
        # the boxes always follow the same rotation as the pixels.
        if self.k == 1:
            x, y, bw, bh = img_h - (y + bh), x, bh, bw
        elif self.k == 2:
            x, y = img_w - (x + bw), img_h - (y + bh)
        elif self.k == 3:
            x, y, bw, bh = y, img_w - (x + bw), bh, bw

        # Bbox.wrap() takes x/y/w/h; a ``points=`` override is discarded by
        # the Bbox constructor and would leave the box unchanged.
        return bbox.wrap(x=x, y=y, w=bw, h=bh)

    def transform_item(self, item):
        """Return ``item`` rotated by the configured angle.

        Items without media, or whose media is not an ``Image``, are returned
        unchanged (including their id).
        """
        if not item.media or not isinstance(item.media, Image):
            return item

        data = item.media.data
        h, w = data.shape[:2]

        # np.rot90 turns counter-clockwise for positive k, so negate it to get
        # a clockwise rotation.
        rotated_image = np.rot90(data, k=-self.k)

        new_annotations = [
            self._rotate_bbox(ann, w, h) if isinstance(ann, Bbox) else ann
            for ann in item.annotations
        ]

        new_id = f"{item.id}_r{self.angle}"
        print(new_id)
        return item.wrap(
            id=new_id,
            media=Image.from_numpy(rotated_image),
            annotations=new_annotations,
        )
def main():
    """Command-line entry point.

    Loads a COCO dataset from ``--input``, builds one rotated copy per value
    in ``--angles``, merges the copies with the original items and exports
    the result (annotations and images) in COCO format to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Rotate Datumaro datasets.")
    parser.add_argument("--input", required=True, help="Path to input dataset")
    parser.add_argument("--output", default="output_yolo", help="Output directory")
    parser.add_argument(
        "--angles",
        nargs="+",
        type=int,
        default=[90, 180, 270],
        help="Space separated angles (e.g. 90 180)",
    )
    args = parser.parse_args()

    # Reject unsupported angles up front; only quarter-turn rotations are
    # implemented, and transforms are otherwise only evaluated at export time.
    invalid_angles = [angle for angle in args.angles if angle % 90 != 0]
    if invalid_angles:
        parser.error(
            f"--angles must be multiples of 90, got: {invalid_angles}"
        )

    # 1. Load the original dataset.
    dataset = dm.Dataset.import_from(args.input, "coco")

    # 2. Build one rotated dataset per requested angle.  Dataset.transform()
    #    modifies the dataset it is called on, so every angle works on its
    #    own copy and the original stays untouched for the merge below.
    rotated_sets = [
        deepcopy(dataset).transform(RotateTransform, angle=angle)
        for angle in args.angles
    ]

    # 3. Merge the original items with all rotated variants.
    combined_ds = dm.Dataset.from_extractors(dataset, *rotated_sets)

    # 4. Export in COCO format, writing the image files as well.
    combined_ds.export(args.output, format="coco", save_media=True)
    print(f"✅ Success! Augmented dataset saved to: {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()