Add a data augmentation scheme and ideas for a diffusion-based generation model

Jiao77
2025-10-20 21:14:03 +08:00
parent d6d00cf088
commit 08f488f0d8
22 changed files with 1903 additions and 190 deletions

train.py

@@ -7,7 +7,7 @@ from datetime import datetime
from pathlib import Path
import torch
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from data.ic_dataset import ICLayoutTrainingDataset
@@ -82,25 +82,152 @@ def main(args):
    transform = get_transform()
    # Read the augmentation and synthesis configuration
    augment_cfg = cfg.get("augment", {}) or {}
    elastic_cfg = augment_cfg.get("elastic", {}) or {}
    use_albu = bool(elastic_cfg.get("enabled", False))
    photometric_cfg = augment_cfg.get("photometric", {}) or {}
    albu_params = {
        "prob": elastic_cfg.get("prob", 0.3),
        "alpha": elastic_cfg.get("alpha", 40),
        "sigma": elastic_cfg.get("sigma", 6),
        "alpha_affine": elastic_cfg.get("alpha_affine", 6),
        "brightness_contrast": bool(photometric_cfg.get("brightness_contrast", True)),
        "gauss_noise": bool(photometric_cfg.get("gauss_noise", True)),
    }
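    # Illustrative sketch, not part of this commit: the `augment` config
    # section the lookups above expect, written as a plain dict. Key names
    # mirror the .get() calls; the values shown are the defaults assumed here.
    example_augment_cfg = {
        "elastic": {
            "enabled": True,
            "prob": 0.3,         # probability of applying the elastic warp
            "alpha": 40,         # displacement field strength
            "sigma": 6,          # Gaussian smoothing of the displacement field
            "alpha_affine": 6,   # strength of the accompanying affine jitter
        },
        "photometric": {
            "brightness_contrast": True,
            "gauss_noise": True,
        },
    }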
    # Build the dataset of real layouts
    real_dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
        use_albu=use_albu,
        albu_params=albu_params,
    )
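    # One plausible mapping of albu_params onto an albumentations pipeline
    # (a sketch under assumptions; the actual wiring lives inside
    # ICLayoutTrainingDataset in data/ic_dataset.py, which is not shown in
    # this diff). Note: ElasticTransform's alpha_affine argument exists in
    # albumentations <= 1.3 and was removed in later releases.
    import albumentations as A

    def build_albu_pipeline(p):
        ops = [
            A.ElasticTransform(
                alpha=p["alpha"], sigma=p["sigma"],
                alpha_affine=p["alpha_affine"], p=p["prob"],
            )
        ]
        if p["brightness_contrast"]:
            ops.append(A.RandomBrightnessContrast(p=0.5))  # assumed probability
        if p["gauss_noise"]:
            ops.append(A.GaussNoise(p=0.5))  # assumed probability
        return A.Compose(ops)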
logger.info(f"数据集大小: {len(dataset)}")
# 分割训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
    # Read the synthetic-data configuration (procedural + diffusion)
    syn_cfg = cfg.get("synthetic", {}) or {}
    syn_enabled = bool(syn_cfg.get("enabled", False))
    syn_ratio = float(syn_cfg.get("ratio", 0.0))
    syn_dir = syn_cfg.get("png_dir", None)
    syn_dataset = None
    if syn_enabled and syn_dir:
        syn_dir_path = Path(to_absolute_path(syn_dir, config_dir))
        if syn_dir_path.exists():
            syn_dataset = ICLayoutTrainingDataset(
                syn_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(syn_dataset) == 0:
                syn_dataset = None
        else:
            logger.warning(f"Synthetic data directory does not exist; ignoring: {syn_dir_path}")
            syn_enabled = False
    # Diffusion-generated data configuration
    diff_cfg = syn_cfg.get("diffusion", {}) or {}
    diff_enabled = bool(diff_cfg.get("enabled", False))
    diff_ratio = float(diff_cfg.get("ratio", 0.0))
    diff_dir = diff_cfg.get("png_dir", None)
    diff_dataset = None
    if diff_enabled and diff_dir:
        diff_dir_path = Path(to_absolute_path(diff_dir, config_dir))
        if diff_dir_path.exists():
            diff_dataset = ICLayoutTrainingDataset(
                diff_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(diff_dataset) == 0:
                diff_dataset = None
        else:
            logger.warning(f"Diffusion data directory does not exist; ignoring: {diff_dir_path}")
            diff_enabled = False
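    # Sketch (assumption, not part of this commit) of the matching
    # `synthetic` config section, including the nested diffusion block
    # read above; the paths are hypothetical placeholders.
    example_synthetic_cfg = {
        "enabled": True,
        "ratio": 0.25,                        # target share of procedural samples
        "png_dir": "data/synthetic/png",      # hypothetical path
        "diffusion": {
            "enabled": True,
            "ratio": 0.15,                    # target share of diffusion samples
            "png_dir": "data/diffusion/png",  # hypothetical path
        },
    }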
    logger.info(
        "Real dataset size: %d%s%s" % (
            len(real_dataset),
            f", procedural synthetic dataset: {len(syn_dataset)}" if syn_dataset else "",
            f", diffusion synthetic dataset: {len(diff_dataset)}" if diff_dataset else "",
        )
    )
    # The validation set uses only real data, so evaluation is not skewed by synthetic samples
    train_size = int(0.8 * len(real_dataset))
    val_size = max(len(real_dataset) - train_size, 1)
    real_train_dataset, val_dataset = torch.utils.data.random_split(real_dataset, [train_size, val_size])
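    # Note: random_split is unseeded above, so train/val membership changes
    # between runs. A fixed generator (sketch, not in this commit) would make
    # the real-only validation split reproducible:
    #     real_train_dataset, val_dataset = torch.utils.data.random_split(
    #         real_dataset, [train_size, val_size],
    #         generator=torch.Generator().manual_seed(42),
    #     )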
    # Training set: real data, optionally merged with synthetic sources (procedural + diffusion)
    datasets = [real_train_dataset]
    names = ["real"]
    # Procedural synthesis
    if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
        datasets.append(syn_dataset)
        names.append("synthetic")
    # Diffusion synthesis
    if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
        datasets.append(diff_dataset)
        names.append("diffusion")
    if len(datasets) > 1:
        mixed_train_dataset = ConcatDataset(datasets)
        # Per-source sample counts, in the same order as `datasets`
        counts = [len(d) for d in datasets]
        # Target ratios: real = 1 - (syn_ratio + diff_ratio)
        target_real = max(0.0, 1.0 - (syn_ratio + diff_ratio))
        target_ratios = [target_real]
        if "synthetic" in names:
            target_ratios.append(syn_ratio)
        if "diffusion" in names:
            target_ratios.append(diff_ratio)
        # Per-sample weight for each source is ratio / count, so each source's
        # expected share of draws matches its target ratio
        weights = []
        for count, ratio in zip(counts, target_ratios):
            weights += [ratio / max(count, 1)] * count
        sampler = WeightedRandomSampler(weights, num_samples=len(mixed_train_dataset), replacement=True)
        train_dataloader = DataLoader(mixed_train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
        logger.info(
            f"Mixed sampling enabled: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; total samples={len(mixed_train_dataset)}"
        )
        if writer:
            writer.add_text(
                "dataset/mix",
                f"enabled=true, ratios: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; "
                f"counts: real_train={len(real_train_dataset)}, syn={len(syn_dataset) if syn_dataset else 0}, diff={len(diff_dataset) if diff_dataset else 0}"
            )
    else:
        train_dataloader = DataLoader(real_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        if writer:
            writer.add_text("dataset/mix", f"enabled=false, real_train={len(real_train_dataset)}")
logger.info(f"训练集大小: {len(train_dataloader.dataset)}, 验证集大小: {len(val_dataset)}")
if writer:
writer.add_text("dataset/info", f"train={len(train_dataset)}, val={len(val_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
writer.add_text("dataset/info", f"train={len(train_dataloader.dataset)}, val={len(val_dataset)}")
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
model = RoRD().cuda()