# train.py
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform
# Logging setup
def setup_logging(save_dir):
    """Set up training logging to both a file and the console."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)
# --- Main function and command-line interface (modified) ---
def main(args):
    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent
    # CLI arguments take precedence over values from the YAML config.
    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
    patch_size = int(cfg.training.patch_size)
    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
    logger = setup_logging(save_dir)
    logger.info("--- Starting RoRD model training ---")
    logger.info(f"Training parameters: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Save directory: {save_dir}")
    transform = get_transform()
    dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
    )
    logger.info(f"Dataset size: {len(dataset)}")
    # Split into training and validation sets (80/20)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
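    # Note: without an explicit generator, random_split produces a different
    # split on every run; for a reproducible split one could pass, e.g.,
    # generator=torch.Generator().manual_seed(42) (seed value is illustrative).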
logger.info(f"训练集大小: {len(train_dataset)}, 验证集大小: {len(val_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
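    # The model and every batch are moved to the GPU via .cuda() below; fail
    # fast with a clear error on CPU-only machines (assumes a CUDA-only
    # workflow, as in the original script).
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required: this script moves the model and batches to the GPU with .cuda()")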
    model = RoRD().cuda()
    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # Add a learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
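    # ReduceLROnPlateau halves the learning rate (factor=0.5) once the
    # monitored validation loss has not improved for 5 consecutive epochs.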
    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10
    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        total_det_loss = 0
        total_desc_loss = 0
        for i, (original, rotated, H) in enumerate(train_dataloader):
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)
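            # H is the homography relating the original patch to its rotated
            # view; both losses use it to align detector and descriptor
            # outputs across the two views.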
            det_loss = compute_detection_loss(det_original, det_rotated, H)
            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = det_loss + desc_loss
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_train_loss += loss.item()
            total_det_loss += det_loss.item()
            total_desc_loss += desc_loss.item()
            if i % 10 == 0:
                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")
        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_det_loss = total_det_loss / len(train_dataloader)
        avg_desc_loss = total_desc_loss / len(train_dataloader)
        # Validation phase
        model.eval()
        total_val_loss = 0
        total_val_det_loss = 0
        total_val_desc_loss = 0
        with torch.no_grad():
            for original, rotated, H in val_dataloader:
                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
                det_original, desc_original = model(original)
                det_rotated, desc_rotated = model(rotated)
                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
                val_loss = val_det_loss + val_desc_loss
                total_val_loss += val_loss.item()
                total_val_det_loss += val_det_loss.item()
                total_val_desc_loss += val_desc_loss.item()
        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
        # Learning-rate scheduling on the validation loss
        scheduler.step(avg_val_loss)
        logger.info(f"--- Epoch {epoch+1} complete ---")
        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        # Early-stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Save the best model so far
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_path = os.path.join(save_dir, 'rord_model_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'config': {
                    'learning_rate': lr,
                    'batch_size': batch_size,
                    'epochs': epochs,
                    'config_path': str(Path(args.config).resolve()),
                }
            }, save_path)
            logger.info(f"Best model saved to: {save_path}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered: no improvement for {patience} epochs")
                break
    # Save the final model
    save_path = os.path.join(save_dir, 'rord_model_final.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_val_loss': avg_val_loss,
        'config': {
            'learning_rate': lr,
            'batch_size': batch_size,
            'epochs': epochs,
            'config_path': str(Path(args.config).resolve()),
        }
    }, save_path)
    logger.info(f"Final model saved to: {save_path}")
    logger.info("Training complete!")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
    parser.add_argument('--data_dir', type=str, default=None, help="Training data directory; falls back to the config file path if omitted")
    parser.add_argument('--save_dir', type=str, default=None, help="Model save directory; falls back to the config file path if omitted")
    parser.add_argument('--epochs', type=int, default=None, help="Number of training epochs; falls back to the config file value if omitted")
    parser.add_argument('--batch_size', type=int, default=None, help="Batch size; falls back to the config file value if omitted")
    parser.add_argument('--lr', type=float, default=None, help="Learning rate; falls back to the config file value if omitted")
    main(parser.parse_args())
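
# Example invocation (paths and values are illustrative):
#   python train.py --config configs/base_config.yaml --epochs 50 --batch_size 8
#
# To resume or evaluate from a saved checkpoint (a sketch, assuming the same
# RoRD class and the checkpoint layout written above):
#   ckpt = torch.load(os.path.join(save_dir, 'rord_model_best.pth'))
#   model.load_state_dict(ckpt['model_state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer_state_dict'])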