# train.py
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter

from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform

# Logging setup
def setup_logging(save_dir):
    """Configure training logging to both a timestamped file and the console."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)
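
# NOTE: logging.basicConfig configures the root logger only on its first
# call; if logging has already been configured elsewhere in the process, the
# handlers above are silently skipped (Python 3.8+ accepts force=True to
# override that behavior, should it ever be needed here).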

# --- (Modified) Main function and command-line interface ---
def main(args):
    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent

    # Command-line arguments take precedence over values from the YAML config
    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
    patch_size = int(cfg.training.patch_size)
    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)

    # TensorBoard configuration (config values, overridable from the CLI)
    logging_cfg = cfg.get("logging", None)
    use_tensorboard = False
    log_dir = None
    experiment_name = None
    if logging_cfg is not None:
        use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
        log_dir = logging_cfg.get("log_dir", "runs")
        experiment_name = logging_cfg.get("experiment_name", "default")
    if args.disable_tensorboard:
        use_tensorboard = False
    if args.log_dir is not None:
        log_dir = args.log_dir
    if args.experiment_name is not None:
        experiment_name = args.experiment_name

    writer = None
    if use_tensorboard and log_dir:
        log_root = Path(log_dir).expanduser()
        experiment_folder = experiment_name or "default"
        tb_path = log_root / "train" / experiment_folder
        tb_path.parent.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(tb_path.as_posix())

    logger = setup_logging(save_dir)

    logger.info("--- Starting RoRD model training ---")
    logger.info(f"Training parameters: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Save directory: {save_dir}")
    if writer:
        logger.info(f"TensorBoard log directory: {tb_path}")

    transform = get_transform()

    # Read augmentation and synthesis configuration
    augment_cfg = cfg.get("augment", {})
    elastic_cfg = augment_cfg.get("elastic", {}) if augment_cfg else {}
    use_albu = bool(elastic_cfg.get("enabled", False))
    albu_params = {
        "prob": elastic_cfg.get("prob", 0.3),
        "alpha": elastic_cfg.get("alpha", 40),
        "sigma": elastic_cfg.get("sigma", 6),
        "alpha_affine": elastic_cfg.get("alpha_affine", 6),
        "brightness_contrast": bool(augment_cfg.get("photometric", {}).get("brightness_contrast", True)) if augment_cfg else True,
        "gauss_noise": bool(augment_cfg.get("photometric", {}).get("gauss_noise", True)) if augment_cfg else True,
    }

    # Build the real-data dataset
    real_dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
        use_albu=use_albu,
        albu_params=albu_params,
    )
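
    # For reference, the augmentation keys read above map to a config section
    # shaped like the following (a sketch; the numeric values mirror the
    # fallback defaults used in albu_params above):
    #
    # augment:
    #   elastic:
    #     enabled: true
    #     prob: 0.3
    #     alpha: 40
    #     sigma: 6
    #     alpha_affine: 6
    #   photometric:
    #     brightness_contrast: true
    #     gauss_noise: true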

    # Read synthetic-data configuration (procedural + diffusion)
    syn_cfg = cfg.get("synthetic", {})
    syn_enabled = bool(syn_cfg.get("enabled", False))
    syn_ratio = float(syn_cfg.get("ratio", 0.0))
    syn_dir = syn_cfg.get("png_dir", None)
    syn_dataset = None
    if syn_enabled and syn_dir:
        syn_dir_path = Path(to_absolute_path(syn_dir, config_dir))
        if syn_dir_path.exists():
            syn_dataset = ICLayoutTrainingDataset(
                syn_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(syn_dataset) == 0:
                syn_dataset = None
        else:
            logger.warning(f"Synthetic data directory does not exist, ignoring: {syn_dir_path}")
            syn_enabled = False

    # Diffusion-generated data configuration
    diff_cfg = syn_cfg.get("diffusion", {}) if syn_cfg else {}
    diff_enabled = bool(diff_cfg.get("enabled", False))
    diff_ratio = float(diff_cfg.get("ratio", 0.0))
    diff_dir = diff_cfg.get("png_dir", None)
    diff_dataset = None
    if diff_enabled and diff_dir:
        diff_dir_path = Path(to_absolute_path(diff_dir, config_dir))
        if diff_dir_path.exists():
            diff_dataset = ICLayoutTrainingDataset(
                diff_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(diff_dataset) == 0:
                diff_dataset = None
        else:
            logger.warning(f"Diffusion data directory does not exist, ignoring: {diff_dir_path}")
            diff_enabled = False

    logger.info(
        "Real dataset size: %d%s%s" % (
            len(real_dataset),
            f", synthetic (procedural) dataset: {len(syn_dataset)}" if syn_dataset else "",
            f", synthetic (diffusion) dataset: {len(diff_dataset)}" if diff_dataset else "",
        )
    )
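
    # For reference, the synthetic-data section read above is expected to look
    # roughly like this (a sketch inferred from the keys used; the png_dir
    # paths are illustrative, not repo defaults):
    #
    # synthetic:
    #   enabled: true
    #   ratio: 0.2            # target fraction of procedural samples when mixing
    #   png_dir: data/synthetic/png
    #   diffusion:
    #     enabled: false
    #     ratio: 0.1
    #     png_dir: data/diffusion/png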

    # The validation set uses real data only, so evaluation is not skewed by synthetic samples
    train_size = int(0.8 * len(real_dataset))
    val_size = max(len(real_dataset) - train_size, 1)
    real_train_dataset, val_dataset = torch.utils.data.random_split(real_dataset, [train_size, val_size])

    # Training set: the real split can be merged with synthetic sources (procedural + diffusion)
    datasets = [real_train_dataset]
    if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
        datasets.append(syn_dataset)
    if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
        datasets.append(diff_dataset)

    if len(datasets) > 1:
        mixed_train_dataset = ConcatDataset(datasets)
        # Per-source sample counts
        counts = [len(real_train_dataset)]
        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
            counts.append(len(syn_dataset))
        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
            counts.append(len(diff_dataset))
        # Target ratios: real = 1 - (syn_ratio + diff_ratio)
        target_real = max(0.0, 1.0 - (syn_ratio + diff_ratio))
        target_ratios = [target_real]
        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
            target_ratios.append(syn_ratio)
        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
            target_ratios.append(diff_ratio)
        # Per-source weight: each source's target ratio spread evenly over its samples
        per_source_weights = []
        for count, ratio in zip(counts, target_ratios):
            count = max(count, 1)
            per_source_weights.append(ratio / count)
        # Expand to one weight per sample
        weights = []
        for count, w in zip(counts, per_source_weights):
            weights += [w] * count
        sampler = WeightedRandomSampler(weights, num_samples=len(mixed_train_dataset), replacement=True)
        train_dataloader = DataLoader(mixed_train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
        logger.info(
            f"Mixed sampling enabled: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; total samples={len(mixed_train_dataset)}"
        )
        if writer:
            writer.add_text(
                "dataset/mix",
                f"enabled=true, ratios: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; "
                f"counts: real_train={len(real_train_dataset)}, syn={len(syn_dataset) if syn_dataset else 0}, diff={len(diff_dataset) if diff_dataset else 0}"
            )
    else:
        train_dataloader = DataLoader(real_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        if writer:
            writer.add_text("dataset/mix", f"enabled=false, real_train={len(real_train_dataset)}")

    logger.info(f"Training set size: {len(train_dataloader.dataset)}, validation set size: {len(val_dataset)}")
    if writer:
        writer.add_text("dataset/info", f"train={len(train_dataloader.dataset)}, val={len(val_dataset)}")

    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    model = RoRD().cuda()
    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # Learning-rate scheduler: halve the LR after 5 epochs without validation improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Early-stopping state
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        total_det_loss = 0
        total_desc_loss = 0
        for i, (original, rotated, H) in enumerate(train_dataloader):
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()

            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)

            det_loss = compute_detection_loss(det_original, det_rotated, H)
            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()
            total_det_loss += det_loss.item()
            total_desc_loss += desc_loss.item()

            if writer:
                num_batches = len(train_dataloader) if len(train_dataloader) > 0 else 1
                global_step = epoch * num_batches + i
                writer.add_scalar("train/loss_total", loss.item(), global_step)
                writer.add_scalar("train/loss_det", det_loss.item(), global_step)
                writer.add_scalar("train/loss_desc", desc_loss.item(), global_step)
                writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step)

            if i % 10 == 0:
                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_det_loss = total_det_loss / len(train_dataloader)
        avg_desc_loss = total_desc_loss / len(train_dataloader)
        if writer:
            writer.add_scalar("epoch/train_loss_total", avg_train_loss, epoch)
            writer.add_scalar("epoch/train_loss_det", avg_det_loss, epoch)
            writer.add_scalar("epoch/train_loss_desc", avg_desc_loss, epoch)

        # Validation phase
        model.eval()
        total_val_loss = 0
        total_val_det_loss = 0
        total_val_desc_loss = 0
        with torch.no_grad():
            for original, rotated, H in val_dataloader:
                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
                det_original, desc_original = model(original)
                det_rotated, desc_rotated = model(rotated)
                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
                val_loss = val_det_loss + val_desc_loss
                total_val_loss += val_loss.item()
                total_val_det_loss += val_det_loss.item()
                total_val_desc_loss += val_desc_loss.item()
        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)

        # Step the LR scheduler on the validation loss
        scheduler.step(avg_val_loss)

        logger.info(f"--- Epoch {epoch+1} complete ---")
        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        if writer:
            writer.add_scalar("epoch/val_loss_total", avg_val_loss, epoch)
            writer.add_scalar("epoch/val_loss_det", avg_val_det_loss, epoch)
            writer.add_scalar("epoch/val_loss_desc", avg_val_desc_loss, epoch)
            writer.add_scalar("epoch/lr", optimizer.param_groups[0]['lr'], epoch)

        # Early-stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Save the best model so far
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_path = os.path.join(save_dir, 'rord_model_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'config': {
                    'learning_rate': lr,
                    'batch_size': batch_size,
                    'epochs': epochs,
                    'config_path': str(Path(args.config).resolve()),
                }
            }, save_path)
            logger.info(f"Best model saved to: {save_path}")
            if writer:
                writer.add_scalar("checkpoint/best_val_loss", best_val_loss, epoch)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered! No improvement for {patience} epochs")
                break

    # Save the final model
    save_path = os.path.join(save_dir, 'rord_model_final.pth')
    torch.save({
        'epoch': epochs,  # configured epoch count (may exceed the last trained epoch if early stopping fired)
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_val_loss': avg_val_loss,
        'config': {
            'learning_rate': lr,
            'batch_size': batch_size,
            'epochs': epochs,
            'config_path': str(Path(args.config).resolve()),
        }
    }, save_path)
    logger.info(f"Final model saved to: {save_path}")
    logger.info("Training complete!")

    if writer:
        writer.add_scalar("final/val_loss", avg_val_loss, epochs - 1)
        writer.close()
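
# Restoring a checkpoint later (a sketch; the keys match the dicts saved
# above, and the path assumes the save_dir used during training):
#
#   ckpt = torch.load(os.path.join(save_dir, 'rord_model_best.pth'))
#   model = RoRD().cuda()
#   model.load_state_dict(ckpt['model_state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#   start_epoch = ckpt['epoch'] + 1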

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
    parser.add_argument('--data_dir', type=str, default=None, help="Training data directory; falls back to the path in the config file")
    parser.add_argument('--save_dir', type=str, default=None, help="Model save directory; falls back to the path in the config file")
    parser.add_argument('--epochs', type=int, default=None, help="Number of training epochs; falls back to the value in the config file")
    parser.add_argument('--batch_size', type=int, default=None, help="Batch size; falls back to the value in the config file")
    parser.add_argument('--lr', type=float, default=None, help="Learning rate; falls back to the value in the config file")
    parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard log root directory; overrides the config file")
    parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard experiment name; overrides the config file")
    parser.add_argument('--disable_tensorboard', action='store_true', help="Disable TensorBoard logging")

    main(parser.parse_args())
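
# Example invocation (flag values are illustrative; any omitted flag falls
# back to the value in configs/base_config.yaml):
#
#   python train.py --config configs/base_config.yaml --epochs 50 \
#       --batch_size 8 --experiment_name rord_baseline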