import os
import cv2
import albumentations as A
import random
import shutil
from tqdm import tqdm
def clean_unmatched_files(img_dir, label_dir):
    """
    Remove images and labels of a YOLO dataset that have no counterpart.

    - Images (.jpg/.jpeg/.png, matched case-insensitively) whose stem has no
      ``<stem>.txt`` file in ``label_dir`` are deleted.
    - ``.txt`` files in ``label_dir`` whose stem has no image in ``img_dir``
      are deleted.

    NOTE: every ``.txt`` file in ``label_dir`` is treated as a label, so an
    auxiliary file such as ``classes.txt`` is removed as well unless an image
    with the same stem exists.

    :param img_dir: directory containing the images
    :param label_dir: directory containing the YOLO ``.txt`` labels
    """
    print("正在清理不匹配的图片和标签文件...")
    img_extensions = ('.jpg', '.jpeg', '.png')

    # Keep the real on-disk names per stem. Rebuilding the name from
    # stem + lower-case extension would miss files such as "foo.JPG" on
    # case-sensitive filesystems.
    image_names_by_stem = {}
    for name in os.listdir(img_dir):
        if name.lower().endswith(img_extensions):
            image_names_by_stem.setdefault(os.path.splitext(name)[0], []).append(name)

    label_stems = {
        os.path.splitext(name)[0]
        for name in os.listdir(label_dir)
        if name.endswith('.txt')
    }

    # Delete images that have no label.
    removed_images = 0
    for stem in sorted(image_names_by_stem.keys() - label_stems):
        for name in image_names_by_stem[stem]:
            img_path = os.path.join(img_dir, name)
            if os.path.exists(img_path):
                os.remove(img_path)
                removed_images += 1
                print(f"已删除无标签图片: {img_path}")

    # Delete labels that have no image.
    removed_labels = 0
    for stem in sorted(label_stems - image_names_by_stem.keys()):
        label_path = os.path.join(label_dir, stem + '.txt')
        if os.path.exists(label_path):
            os.remove(label_path)
            removed_labels += 1
            print(f"已删除无图片标签: {label_path}")

    # Report what was actually deleted, not merely what was detected.
    print(f"清理完成!共删除 {removed_images} 张无标签图片 和 {removed_labels} 个无图片标签。")
def _clip_yolo_box(x_center, y_center, w, h):
    """Return the YOLO box clipped to the unit square; boxes already inside are returned unchanged."""
    x_min, x_max = x_center - w / 2, x_center + w / 2
    y_min, y_max = y_center - h / 2, y_center + h / 2
    if x_min >= 0.0 and y_min >= 0.0 and x_max <= 1.0 and y_max <= 1.0:
        return [x_center, y_center, w, h]
    x_min, y_min = max(0.0, x_min), max(0.0, y_min)
    x_max, y_max = min(1.0, x_max), min(1.0, y_max)
    return [(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min]


def _read_yolo_boxes(label_path):
    """Parse a YOLO label file into ``(bboxes, class_labels)``; both are empty when nothing is usable."""
    bboxes = []
    class_labels = []
    if not os.path.exists(label_path):
        return bboxes, class_labels
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 5:
                continue  # wrong number of fields
            try:
                class_id = int(parts[0])
                x_center, y_center, w, h = map(float, parts[1:])
            except ValueError:
                continue  # non-numeric field: skip the line instead of aborting the run
            # YOLO coordinates are normalised; centre in [0, 1], size in (0, 1].
            if not (0 <= x_center <= 1 and 0 <= y_center <= 1 and 0 < w <= 1 and 0 < h <= 1):
                continue
            # A valid centre does not guarantee that the box edges stay inside
            # the image; clip so the augmentation library does not reject it.
            bboxes.append(_clip_yolo_box(x_center, y_center, w, h))
            class_labels.append(class_id)
    return bboxes, class_labels


def _build_augmentations(height, width):
    """Build the augmentation pipeline for an image of the given size."""
    # Cap the crop size by the shorter image side so the crop never exceeds the image.
    min_dim = min(height, width)
    crop_h = min(500, min_dim)
    crop_w = min(500, min_dim)
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.Resize(width=640, height=640, p=0.5),
        A.RandomCrop(width=crop_w, height=crop_h, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
        A.RandomScale(scale_limit=0.2, p=0.2),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'], min_visibility=0.1, min_area=10))


def _write_yolo_labels(label_path, bboxes, class_labels):
    """Write boxes and class ids to ``label_path`` in YOLO text format."""
    with open(label_path, 'w') as f:
        for bbox, cls in zip(bboxes, class_labels):
            x_center, y_center, w, h = bbox
            # Clamp once more to guard against floating-point drift.
            x_center = max(0.0, min(1.0, x_center))
            y_center = max(0.0, min(1.0, y_center))
            w = max(0.0, min(1.0, w))
            h = max(0.0, min(1.0, h))
            f.write(f"{int(cls)} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n")


def _save_preview(view_path, image, bboxes, class_labels):
    """Save a copy of ``image`` with the boxes and class ids drawn on it."""
    view_img = image.copy()
    h_img, w_img = view_img.shape[:2]
    for bbox, cls in zip(bboxes, class_labels):
        xc, yc, bw, bh = bbox
        x1 = int((xc - bw / 2) * w_img)
        y1 = int((yc - bh / 2) * h_img)
        x2 = int((xc + bw / 2) * w_img)
        y2 = int((yc + bh / 2) * h_img)
        cv2.rectangle(view_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(view_img, str(cls), (x1, max(0, y1 - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    cv2.imwrite(view_path, view_img)


def augment_images_and_labels(img_dir, label_dir, output_img_dir, output_label_dir, augment_times=3, view_dir=None, view_ratio=0.1):
    """
    Augment a YOLO dataset (intended for the training split only).

    Every original image in ``img_dir`` (file names containing ``_aug_`` are
    treated as earlier augmentation output and ignored) is augmented up to
    ``augment_times`` times. Each successful result is saved as
    ``<stem>_aug_<i>.jpg`` together with ``<stem>_aug_<i>.txt``. Images that
    cannot be read, or that have no valid label line, are skipped with a
    message. An augmentation that raises or that leaves no bounding box
    produces no output.

    :param img_dir: directory with the original images
    :param label_dir: directory with the YOLO label files
    :param output_img_dir: directory for augmented images (may equal ``img_dir``)
    :param output_label_dir: directory for augmented labels (may equal ``label_dir``)
    :param augment_times: number of augmentation attempts per image
    :param view_dir: optional directory for preview images with boxes drawn
    :param view_ratio: fraction of the original images to preview (at least
        one image, at most all of them)
    """
    # Make sure the output directories exist.
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_label_dir, exist_ok=True)
    if view_dir:
        os.makedirs(view_dir, exist_ok=True)

    # Collect the original images, excluding previously augmented ones.
    img_extensions = ('.jpg', '.jpeg', '.png')
    img_files = [f for f in os.listdir(img_dir)
                 if f.lower().endswith(img_extensions) and '_aug_' not in f]
    if not img_files:
        print("警告:未找到任何原始图片(可能已被清理或路径错误)")
        return

    # Cap the sample size at the population: random.sample raises ValueError
    # otherwise, which happened for view_ratio > 1.
    num_view_imgs = min(len(img_files), max(1, int(len(img_files) * view_ratio)))
    view_indices = set(random.sample(range(len(img_files)), num_view_imgs))

    for idx, img_file in enumerate(tqdm(img_files)):
        base_name = os.path.splitext(img_file)[0]
        img_path = os.path.join(img_dir, img_file)
        label_path = os.path.join(label_dir, base_name + '.txt')

        image = cv2.imread(img_path)
        if image is None:
            print(f"无法读取图像,跳过: {img_path}")
            continue
        height, width = image.shape[:2]

        bboxes, class_labels = _read_yolo_boxes(label_path)
        if not bboxes:
            print(f"有效标签为空,跳过: {label_path}")
            continue

        augmentations = _build_augmentations(height, width)

        for i in range(augment_times):
            # Best effort: one failed attempt must not stop the whole run.
            try:
                augmented = augmentations(image=image, bboxes=bboxes, class_labels=class_labels)
            except Exception as e:
                print(f"增强失败({img_file} 第{i}次): {e}")
                continue

            aug_image = augmented['image']
            aug_bboxes = augmented['bboxes']
            aug_labels = augmented['class_labels']
            # Length check (not truthiness) so array-like results work too.
            if len(aug_bboxes) == 0:
                continue  # no box survived this augmentation

            out_img_path = os.path.join(output_img_dir, f"{base_name}_aug_{i}.jpg")
            out_label_path = os.path.join(output_label_dir, f"{base_name}_aug_{i}.txt")

            # cv2.imwrite reports failure through its return value; do not
            # write a label for an image that was never saved.
            if not cv2.imwrite(out_img_path, aug_image):
                print(f"保存图像失败,跳过: {out_img_path}")
                continue
            _write_yolo_labels(out_label_path, aug_bboxes, aug_labels)

            # Preview for the sampled subset of images.
            if view_dir and idx in view_indices:
                _save_preview(os.path.join(view_dir, f"{base_name}_aug_{i}.jpg"), aug_image, aug_bboxes, aug_labels)

    print("✅ 数据增强完成!")
if __name__ == "__main__":
    # Dataset paths (must contain English characters only).
    img_dir = "./train/images"
    label_dir = "./train/labels"
    output_img_dir = "./train/images"  # augmented images are appended to the source directory
    output_label_dir = "./train/labels"  # likewise for the augmented labels

    # Step 1: remove images/labels that have no counterpart.
    clean_unmatched_files(img_dir, label_dir)

    # Step 2: run the augmentation.
    augment_images_and_labels(
        img_dir=img_dir,
        label_dir=label_dir,
        output_img_dir=output_img_dir,
        output_label_dir=output_label_dir,
        augment_times=4,
        view_dir="view",
        view_ratio=0.1,
    )