# train.py
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter

from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform


# Logging setup
def setup_logging(save_dir):
    """Configure training logging to both a timestamped file and the console."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)


# --- (modified) Main function and command-line interface ---
def main(args):
    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent
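    # Command-line arguments, when given, take precedence over the YAML config values.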
    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
    patch_size = int(cfg.training.patch_size)
    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
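
    # TensorBoard settings come from the optional `logging` section of the config;
    # the CLI flags (--log_dir, --experiment_name, --disable_tensorboard) override them.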
    logging_cfg = cfg.get("logging", None)
    use_tensorboard = False
    log_dir = None
    experiment_name = None
    if logging_cfg is not None:
        use_tensorboard = bool(logging_cfg.get("use_tensorboard", False))
        log_dir = logging_cfg.get("log_dir", "runs")
        experiment_name = logging_cfg.get("experiment_name", "default")
    if args.disable_tensorboard:
        use_tensorboard = False
    if args.log_dir is not None:
        log_dir = args.log_dir
    if args.experiment_name is not None:
        experiment_name = args.experiment_name
    writer = None
    if use_tensorboard and log_dir:
        log_root = Path(log_dir).expanduser()
        experiment_folder = experiment_name or "default"
        tb_path = log_root / "train" / experiment_folder
        tb_path.parent.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(tb_path.as_posix())

    logger = setup_logging(save_dir)
    logger.info("--- Starting RoRD model training ---")
    logger.info(f"Training parameters: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Save directory: {save_dir}")
    if writer:
        logger.info(f"TensorBoard log directory: {tb_path}")

    transform = get_transform()

    # Read the augmentation and synthetic-data configuration
    augment_cfg = cfg.get("augment", {})
    elastic_cfg = augment_cfg.get("elastic", {}) if augment_cfg else {}
    use_albu = bool(elastic_cfg.get("enabled", False))
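    # These parameters are forwarded to the dataset's optional augmentation pipeline
    # (presumably Albumentations-based, per the `use_albu` flag): elastic deformation
    # plus photometric jitter. The defaults below mirror the config schema.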
    albu_params = {
        "prob": elastic_cfg.get("prob", 0.3),
        "alpha": elastic_cfg.get("alpha", 40),
        "sigma": elastic_cfg.get("sigma", 6),
        "alpha_affine": elastic_cfg.get("alpha_affine", 6),
        "brightness_contrast": bool(augment_cfg.get("photometric", {}).get("brightness_contrast", True)) if augment_cfg else True,
        "gauss_noise": bool(augment_cfg.get("photometric", {}).get("gauss_noise", True)) if augment_cfg else True,
    }

    # Build the real-data dataset
    real_dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
        use_albu=use_albu,
        albu_params=albu_params,
    )

    # Read the synthetic-data configuration (procedural + diffusion)
    syn_cfg = cfg.get("synthetic", {})
    syn_enabled = bool(syn_cfg.get("enabled", False))
    syn_ratio = float(syn_cfg.get("ratio", 0.0))
    syn_dir = syn_cfg.get("png_dir", None)
    syn_dataset = None
    if syn_enabled and syn_dir:
        syn_dir_path = Path(to_absolute_path(syn_dir, config_dir))
        if syn_dir_path.exists():
            syn_dataset = ICLayoutTrainingDataset(
                syn_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(syn_dataset) == 0:
                syn_dataset = None
        else:
            logger.warning(f"Synthetic data directory does not exist; ignoring: {syn_dir_path}")
            syn_enabled = False

    # Diffusion-generated data configuration
    diff_cfg = syn_cfg.get("diffusion", {}) if syn_cfg else {}
    diff_enabled = bool(diff_cfg.get("enabled", False))
    diff_ratio = float(diff_cfg.get("ratio", 0.0))
    diff_dir = diff_cfg.get("png_dir", None)
    diff_dataset = None
    if diff_enabled and diff_dir:
        diff_dir_path = Path(to_absolute_path(diff_dir, config_dir))
        if diff_dir_path.exists():
            diff_dataset = ICLayoutTrainingDataset(
                diff_dir_path.as_posix(),
                patch_size=patch_size,
                transform=transform,
                scale_range=scale_range,
                use_albu=use_albu,
                albu_params=albu_params,
            )
            if len(diff_dataset) == 0:
                diff_dataset = None
        else:
            logger.warning(f"Diffusion data directory does not exist; ignoring: {diff_dir_path}")
            diff_enabled = False

    logger.info(
        "Real dataset size: %d%s%s" % (
            len(real_dataset),
            f", procedural synthetic dataset: {len(syn_dataset)}" if syn_dataset else "",
            f", diffusion synthetic dataset: {len(diff_dataset)}" if diff_dataset else "",
        )
    )

    # The validation set uses real data only, so evaluation is not skewed by synthetic samples
    train_size = int(0.8 * len(real_dataset))
    val_size = max(len(real_dataset) - train_size, 1)
    real_train_dataset, val_dataset = torch.utils.data.random_split(real_dataset, [train_size, val_size])

    # Training set: real patches, optionally merged with synthetic sources (procedural + diffusion)
    datasets = [real_train_dataset]
    # Procedural synthesis
    if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
        datasets.append(syn_dataset)
    # Diffusion synthesis
    if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
        datasets.append(diff_dataset)

    if len(datasets) > 1:
        mixed_train_dataset = ConcatDataset(datasets)
        # Sample counts per source
        counts = [len(real_train_dataset)]
        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
            counts.append(len(syn_dataset))
        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
            counts.append(len(diff_dataset))
        # Target ratios: real = 1 - (syn_ratio + diff_ratio)
        target_real = max(0.0, 1.0 - (syn_ratio + diff_ratio))
        target_ratios = [target_real]
        if syn_dataset is not None and syn_enabled and syn_ratio > 0.0:
            target_ratios.append(syn_ratio)
        if diff_dataset is not None and diff_enabled and diff_ratio > 0.0:
            target_ratios.append(diff_ratio)
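        # Giving every sample from source s the weight ratio_s / count_s makes the
        # expected fraction of draws from s equal ratio_s: WeightedRandomSampler picks
        # indices with probability proportional to weight, so each source contributes
        # count_s * (ratio_s / count_s) = ratio_s of the total mass. replacement=True
        # lets smaller sources be oversampled to reach their target ratio.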
        # Per-source weight for each sample
        per_source_weights = []
        for count, ratio in zip(counts, target_ratios):
            count = max(count, 1)
            per_source_weights.append(ratio / count)
        # Expand to one weight per sample
        weights = []
        for count, w in zip(counts, per_source_weights):
            weights += [w] * count
        sampler = WeightedRandomSampler(weights, num_samples=len(mixed_train_dataset), replacement=True)
        train_dataloader = DataLoader(mixed_train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
        logger.info(
            f"Mixed sampling enabled: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; total samples={len(mixed_train_dataset)}"
        )
        if writer:
            writer.add_text(
                "dataset/mix",
                f"enabled=true, ratios: real={target_real:.2f}, syn={syn_ratio:.2f}, diff={diff_ratio:.2f}; "
                f"counts: real_train={len(real_train_dataset)}, syn={len(syn_dataset) if syn_dataset else 0}, diff={len(diff_dataset) if diff_dataset else 0}"
            )
    else:
        train_dataloader = DataLoader(real_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        if writer:
            writer.add_text("dataset/mix", f"enabled=false, real_train={len(real_train_dataset)}")

    logger.info(f"Training set size: {len(train_dataloader.dataset)}, validation set size: {len(val_dataset)}")
    if writer:
        writer.add_text("dataset/info", f"train={len(train_dataloader.dataset)}, val={len(val_dataset)}")
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    model = RoRD().cuda()
    logger.info(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # Learning-rate scheduler
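    # ReduceLROnPlateau halves the LR (factor=0.5) after 5 epochs with no
    # improvement in validation loss, complementing the early stopping below.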
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5
)

    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        total_det_loss = 0
        total_desc_loss = 0
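
        # Each batch is a pair of views of the same patch plus the homography H
        # relating them; the detection and description losses use H to put the
        # two views' outputs into correspondence.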
        for i, (original, rotated, H) in enumerate(train_dataloader):
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)
            det_loss = compute_detection_loss(det_original, det_rotated, H)
            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()
            total_det_loss += det_loss.item()
            total_desc_loss += desc_loss.item()

            if writer:
                num_batches = len(train_dataloader) if len(train_dataloader) > 0 else 1
                global_step = epoch * num_batches + i
                writer.add_scalar("train/loss_total", loss.item(), global_step)
                writer.add_scalar("train/loss_det", det_loss.item(), global_step)
                writer.add_scalar("train/loss_desc", desc_loss.item(), global_step)
                writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], global_step)

            if i % 10 == 0:
                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_det_loss = total_det_loss / len(train_dataloader)
        avg_desc_loss = total_desc_loss / len(train_dataloader)
        if writer:
            writer.add_scalar("epoch/train_loss_total", avg_train_loss, epoch)
            writer.add_scalar("epoch/train_loss_det", avg_det_loss, epoch)
            writer.add_scalar("epoch/train_loss_desc", avg_desc_loss, epoch)

        # Validation phase
        model.eval()
        total_val_loss = 0
        total_val_det_loss = 0
        total_val_desc_loss = 0
        with torch.no_grad():
            for original, rotated, H in val_dataloader:
                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
                det_original, desc_original = model(original)
                det_rotated, desc_rotated = model(rotated)
                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
                val_loss = val_det_loss + val_desc_loss
                total_val_loss += val_loss.item()
                total_val_det_loss += val_det_loss.item()
                total_val_desc_loss += val_desc_loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)

        # Learning-rate scheduling
        scheduler.step(avg_val_loss)

        logger.info(f"--- Epoch {epoch+1} complete ---")
        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        if writer:
            writer.add_scalar("epoch/val_loss_total", avg_val_loss, epoch)
            writer.add_scalar("epoch/val_loss_det", avg_val_det_loss, epoch)
            writer.add_scalar("epoch/val_loss_desc", avg_val_desc_loss, epoch)
            writer.add_scalar("epoch/lr", optimizer.param_groups[0]['lr'], epoch)

        # Early-stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Save the best model
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_path = os.path.join(save_dir, 'rord_model_best.pth')
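            # The checkpoint bundles model/optimizer state with the key hyperparameters,
            # so a run can be inspected or resumed from the file alone.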
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'config': {
                    'learning_rate': lr,
                    'batch_size': batch_size,
                    'epochs': epochs,
                    'config_path': str(Path(args.config).resolve()),
                }
            }, save_path)
            logger.info(f"Best model saved to: {save_path}")
            if writer:
                writer.add_scalar("checkpoint/best_val_loss", best_val_loss, epoch)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered: no improvement for {patience} epochs")
                break

    # Save the final model
    save_path = os.path.join(save_dir, 'rord_model_final.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_val_loss': avg_val_loss,
        'config': {
            'learning_rate': lr,
            'batch_size': batch_size,
            'epochs': epochs,
            'config_path': str(Path(args.config).resolve()),
        }
    }, save_path)
    logger.info(f"Final model saved to: {save_path}")
    logger.info("Training complete!")
    if writer:
        writer.add_scalar("final/val_loss", avg_val_loss, epochs - 1)
        writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
    parser.add_argument('--data_dir', type=str, default=None, help="Training data directory; falls back to the config file if omitted")
    parser.add_argument('--save_dir', type=str, default=None, help="Model save directory; falls back to the config file if omitted")
    parser.add_argument('--epochs', type=int, default=None, help="Number of training epochs; falls back to the config value if omitted")
    parser.add_argument('--batch_size', type=int, default=None, help="Batch size; falls back to the config value if omitted")
    parser.add_argument('--lr', type=float, default=None, help="Learning rate; falls back to the config value if omitted")
    parser.add_argument('--log_dir', type=str, default=None, help="TensorBoard log root directory; overrides the config setting")
    parser.add_argument('--experiment_name', type=str, default=None, help="TensorBoard experiment name; overrides the config setting")
    parser.add_argument('--disable_tensorboard', action='store_true', help="Disable TensorBoard logging")
    main(parser.parse_args())
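
# Example invocation (the hyperparameter values and experiment name below are
# placeholders; adjust them to your setup):
#   python train.py --config configs/base_config.yaml \
#       --epochs 50 --batch_size 8 --experiment_name rord_mix_v1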