# train.py

import argparse
import logging
import os
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from data.ic_dataset import ICLayoutTrainingDataset
from losses import compute_detection_loss, compute_description_loss
from models.rord import RoRD
from utils.config_loader import load_config, to_absolute_path
from utils.data_utils import get_transform

# Logging setup
def setup_logging(save_dir):
    """Set up logging for a training run."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

# --- Main function and command-line interface (modified) ---
def main(args):
    cfg = load_config(args.config)
    config_dir = Path(args.config).resolve().parent

    data_dir = args.data_dir or str(to_absolute_path(cfg.paths.layout_dir, config_dir))
    save_dir = args.save_dir or str(to_absolute_path(cfg.paths.save_dir, config_dir))
    epochs = args.epochs if args.epochs is not None else int(cfg.training.num_epochs)
    batch_size = args.batch_size if args.batch_size is not None else int(cfg.training.batch_size)
    lr = args.lr if args.lr is not None else float(cfg.training.learning_rate)
    patch_size = int(cfg.training.patch_size)
    scale_range = tuple(float(x) for x in cfg.training.scale_jitter_range)
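    # The loader above expects at least these config keys; an illustrative,
    # non-authoritative sketch of configs/base_config.yaml (values are made up):
    #
    #   paths:
    #     layout_dir: ../data/layouts
    #     save_dir: ../checkpoints
    #   training:
    #     num_epochs: 50
    #     batch_size: 8
    #     learning_rate: 0.0001
    #     patch_size: 256
    #     scale_jitter_range: [0.8, 1.2]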

    logger = setup_logging(save_dir)

    logger.info("--- Starting RoRD model training ---")
    logger.info(f"Training parameters: Epochs={epochs}, Batch Size={batch_size}, LR={lr}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Save directory: {save_dir}")

    transform = get_transform()

    dataset = ICLayoutTrainingDataset(
        data_dir,
        patch_size=patch_size,
        transform=transform,
        scale_range=scale_range,
    )

    logger.info(f"Dataset size: {len(dataset)}")

    # Split into training and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
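    # Note: random_split draws from the global RNG, so the split changes between
    # runs; for a reproducible split a generator can be passed, e.g.
    # random_split(dataset, [train_size, val_size],
    #              generator=torch.Generator().manual_seed(42)).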

    logger.info(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    model = RoRD().cuda()
    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
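    # ReduceLROnPlateau halves the LR (factor=0.5) after 5 epochs without
    # validation-loss improvement; the early-stopping window below (10 epochs)
    # is longer, so the LR is reduced at least once before training stops.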

    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 10

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        total_det_loss = 0
        total_desc_loss = 0

        for i, (original, rotated, H) in enumerate(train_dataloader):
            original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()

            det_original, desc_original = model(original)
            det_rotated, desc_rotated = model(rotated)
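            # The two views are related by the known homography H, so the loss
            # terms below can enforce consistent keypoint responses and matching
            # descriptors across views (assumed signatures: each loss takes the
            # two heads' outputs plus H, as imported from losses above).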

            det_loss = compute_detection_loss(det_original, det_rotated, H)
            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
            loss = det_loss + desc_loss

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            total_train_loss += loss.item()
            total_det_loss += det_loss.item()
            total_desc_loss += desc_loss.item()

            if i % 10 == 0:
                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_det_loss = total_det_loss / len(train_dataloader)
        avg_desc_loss = total_desc_loss / len(train_dataloader)

        # Validation phase
        model.eval()
        total_val_loss = 0
        total_val_det_loss = 0
        total_val_desc_loss = 0

        with torch.no_grad():
            for original, rotated, H in val_dataloader:
                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()

                det_original, desc_original = model(original)
                det_rotated, desc_rotated = model(rotated)

                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
                val_loss = val_det_loss + val_desc_loss

                total_val_loss += val_loss.item()
                total_val_det_loss += val_det_loss.item()
                total_val_desc_loss += val_desc_loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)

        # Learning-rate scheduling
        scheduler.step(avg_val_loss)

        logger.info(f"--- Epoch {epoch+1} finished ---")
        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Early-stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0

            # Save the best model
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_path = os.path.join(save_dir, 'rord_model_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'config': {
                    'learning_rate': lr,
                    'batch_size': batch_size,
                    'epochs': epochs,
                    'config_path': str(Path(args.config).resolve()),
                }
            }, save_path)
            logger.info(f"Best model saved to: {save_path}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered: no improvement for {patience} epochs")
                break

    # Save the final model
    save_path = os.path.join(save_dir, 'rord_model_final.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'final_val_loss': avg_val_loss,
        'config': {
            'learning_rate': lr,
            'batch_size': batch_size,
            'epochs': epochs,
            'config_path': str(Path(args.config).resolve()),
        }
    }, save_path)
    logger.info(f"Final model saved to: {save_path}")
    logger.info("Training complete!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the RoRD model")
    parser.add_argument('--config', type=str, default="configs/base_config.yaml", help="Path to the YAML config file")
    parser.add_argument('--data_dir', type=str, default=None, help="Training data directory; falls back to the config file if omitted")
    parser.add_argument('--save_dir', type=str, default=None, help="Model save directory; falls back to the config file if omitted")
    parser.add_argument('--epochs', type=int, default=None, help="Number of training epochs; falls back to the config file if omitted")
    parser.add_argument('--batch_size', type=int, default=None, help="Batch size; falls back to the config file if omitted")
    parser.add_argument('--lr', type=float, default=None, help="Learning rate; falls back to the config file if omitted")
    main(parser.parse_args())
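
# Example invocations (paths and values are illustrative, not taken from the repo docs):
#   python train.py --config configs/base_config.yaml
#   python train.py --data_dir ./data/layouts --save_dir ./checkpoints --epochs 50 --lr 1e-4
#
# A saved checkpoint can be restored with the usual PyTorch pattern, e.g.:
#   ckpt = torch.load("checkpoints/rord_model_best.pth", map_location="cpu")
#   model = RoRD()
#   model.load_state_dict(ckpt["model_state_dict"])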