Add data augmentation scheme and initial support for diffusion-generated synthetic data
train.py
@@ -7,7 +7,7 @@ from datetime import datetime
 from pathlib import Path
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
 from torch.utils.tensorboard import SummaryWriter
 
 from data.ic_dataset import ICLayoutTrainingDataset
@@ -82,25 +82,152 @@ def main(args):
 
     transform = get_transform()
 
-    dataset = ICLayoutTrainingDataset(
+    # Read augmentation and synthetic-data config
+    augment_cfg = cfg.get("augment", {})
+    elastic_cfg = augment_cfg.get("elastic", {}) if augment_cfg else {}
+    use_albu = bool(elastic_cfg.get("enabled", False))
+    albu_params = {
+        "prob": elastic_cfg.get("prob", 0.3),
+        "alpha": elastic_cfg.get("alpha", 40),
+        "sigma": elastic_cfg.get("sigma", 6),
+        "alpha_affine": elastic_cfg.get("alpha_affine", 6),
+        "brightness_contrast": bool(augment_cfg.get("photometric", {}).get("brightness_contrast", True)) if augment_cfg else True,
+        "gauss_noise": bool(augment_cfg.get("photometric", {}).get("gauss_noise", True)) if augment_cfg else True,
+    }
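+    # Hypothetical config shape assumed by the lookups above (values are the defaults used):
+    #   augment:
+    #     elastic: {enabled: true, prob: 0.3, alpha: 40, sigma: 6, alpha_affine: 6}
+    #     photometric: {brightness_contrast: true, gauss_noise: true}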
 
+    # Build the dataset of real layout patches
+    real_dataset = ICLayoutTrainingDataset(
         data_dir,
         patch_size=patch_size,
         transform=transform,
         scale_range=scale_range,
+        use_albu=use_albu,
+        albu_params=albu_params,
     )
 
-    logger.info(f"Dataset size: {len(dataset)}")
-
-    # Split into training and validation sets
-    train_size = int(0.8 * len(dataset))
-    val_size = len(dataset) - train_size
-    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
-
-    logger.info(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}")
+    # Read synthetic-data config (procedural + diffusion)
+    syn_cfg = cfg.get("synthetic", {})
+    syn_enabled = bool(syn_cfg.get("enabled", False))
+    syn_ratio = float(syn_cfg.get("ratio", 0.0))
+    syn_dir = syn_cfg.get("png_dir", None)
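+    # Hypothetical config shape assumed here (ratio values illustrative; `ratio` is the
+    # target fraction of sampled training examples drawn from that source):
+    #   synthetic:
+    #     enabled: true
+    #     ratio: 0.2
+    #     png_dir: <dir with procedurally generated PNGs>
+    #     diffusion:
+    #       enabled: true
+    #       ratio: 0.1
+    #       png_dir: <dir with diffusion-generated PNGs>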
+
+    syn_dataset = None
+    if syn_enabled and syn_dir:
+        syn_dir_path = Path(to_absolute_path(syn_dir, config_dir))
+        if syn_dir_path.exists():
+            syn_dataset = ICLayoutTrainingDataset(
+                syn_dir_path.as_posix(),
+                patch_size=patch_size,
+                transform=transform,
+                scale_range=scale_range,
+                use_albu=use_albu,
+                albu_params=albu_params,
+            )
+            if len(syn_dataset) == 0:
+                syn_dataset = None
+        else:
+            logger.warning(f"Synthetic data directory does not exist, ignoring: {syn_dir_path}")
+            syn_enabled = False
+
+    # Diffusion-generated data config (nested under `synthetic`)
+    diff_cfg = syn_cfg.get("diffusion", {}) if syn_cfg else {}
+    diff_enabled = bool(diff_cfg.get("enabled", False))
+    diff_ratio = float(diff_cfg.get("ratio", 0.0))
+    diff_dir = diff_cfg.get("png_dir", None)
+    diff_dataset = None
+    if diff_enabled and diff_dir:
+        diff_dir_path = Path(to_absolute_path(diff_dir, config_dir))
+        if diff_dir_path.exists():
+            diff_dataset = ICLayoutTrainingDataset(
+                diff_dir_path.as_posix(),
+                patch_size=patch_size,
+                transform=transform,
+                scale_range=scale_range,
+                use_albu=use_albu,
+                albu_params=albu_params,
+            )
+            if len(diff_dataset) == 0:
+                diff_dataset = None
+        else:
+            logger.warning(f"Diffusion data directory does not exist, ignoring: {diff_dir_path}")
+            diff_enabled = False
+
+    logger.info(
+        "Real dataset size: %d%s%s" % (
+            len(real_dataset),
+            f", synthetic (procedural) dataset: {len(syn_dataset)}" if syn_dataset else "",
+            f", synthetic (diffusion) dataset: {len(diff_dataset)}" if diff_dataset else "",
+        )
+    )
+
+    # Validation uses real data only, so evaluation is not skewed by synthetic samples
+    train_size = int(0.8 * len(real_dataset))
+    val_size = max(len(real_dataset) - train_size, 1)
+    real_train_dataset, val_dataset = torch.utils.data.random_split(real_dataset, [train_size, val_size])
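+    # Note: random_split draws from torch's global RNG, so seed it (e.g. torch.manual_seed)
+    # if the train/val split needs to be reproducible across runs.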
+
+    # Training set: optionally merged with synthetic sources (procedural + diffusion)
+    datasets = [real_train_dataset]
+    if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
+        datasets.append(syn_dataset)
+    if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
+        datasets.append(diff_dataset)
+
+    if len(datasets) > 1:
+        mixed_train_dataset = ConcatDataset(datasets)
+        # Per-source sample counts, in the same order as `datasets`
+        counts = [len(real_train_dataset)]
+        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
+            counts.append(len(syn_dataset))
+        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
+            counts.append(len(diff_dataset))
+        # Target ratios: real = 1 - (syn_ratio + diff_ratio)
+        target_real = max(0.0, 1.0 - (syn_ratio + diff_ratio))
+        target_ratios = [target_real]
+        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
+            target_ratios.append(syn_ratio)
+        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
+            target_ratios.append(diff_ratio)
+        # One weight per source: its target ratio spread evenly over its samples
+        per_source_weights = []
+        for count, ratio in zip(counts, target_ratios):
+            count = max(count, 1)
+            per_source_weights.append(ratio / count)
+        # Expand to one weight per sample, in ConcatDataset order
+        weights = []
+        for count, w in zip(counts, per_source_weights):
+            weights += [w] * count
+        sampler = WeightedRandomSampler(weights, num_samples=len(mixed_train_dataset), replacement=True)
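+        # Worked example (illustrative numbers): 800 real, 150 procedural and 50 diffusion
+        # samples with ratios 0.7/0.2/0.1 give per-sample weights 0.7/800, 0.2/150 and
+        # 0.1/50, so each source is drawn at its target ratio regardless of its raw size;
+        # replacement=True lets small sources be oversampled to reach their ratio.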
+        train_dataloader = DataLoader(mixed_train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
+        logger.info(
+            f"Mixed sampling enabled: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; total samples={len(mixed_train_dataset)}"
+        )
+        if writer:
+            writer.add_text(
+                "dataset/mix",
+                f"enabled=true, ratios: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; "
+                f"counts: real_train={len(real_train_dataset)}, syn={len(syn_dataset) if syn_dataset else 0}, diff={len(diff_dataset) if diff_dataset else 0}"
+            )
+    else:
+        train_dataloader = DataLoader(real_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+        if writer:
+            writer.add_text("dataset/mix", f"enabled=false, real_train={len(real_train_dataset)}")
 
+    logger.info(f"Training set size: {len(train_dataloader.dataset)}, validation set size: {len(val_dataset)}")
     if writer:
-        writer.add_text("dataset/info", f"train={len(train_dataset)}, val={len(val_dataset)}")
-
-    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+        writer.add_text("dataset/info", f"train={len(train_dataloader.dataset)}, val={len(val_dataset)}")
 
     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
 
     model = RoRD().cuda()