import os
import argparse
from datetime import datetime

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split

# Project modules (adjust import paths to match your layout).
from models.rotation_cnn import RotationInvariantCNN  # model implementation
from data_units import LayoutDataset, layout_transforms  # dataset and preprocessing

# Fix random seeds for reproducibility (optional).
torch.manual_seed(42)
np.random.seed(42)


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, total_epochs):
    """Run one training epoch and return the average per-batch loss.

    Prints a progress line every 10 batches. Returns 0.0 for an empty loader
    instead of dividing by zero.
    """
    model.train()
    running_loss = 0.0
    for batch_idx, (data, targets) in enumerate(loader):
        data, targets = data.to(device), targets.to(device)
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch [{epoch}/{total_epochs}] Batch {batch_idx+1}/{len(loader)} Loss: {loss.item():.4f}")
    return running_loss / max(len(loader), 1)


def _evaluate(model, loader, criterion, device):
    """Return the average validation loss over *loader* (no gradient tracking).

    Returns 0.0 for an empty loader instead of raising ZeroDivisionError
    (possible when --val_split rounds the validation set down to zero samples).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            total_loss += criterion(outputs, targets).item()
    return total_loss / max(len(loader), 1)


def main():
    """Train the rotation-invariant layout matcher.

    Parses CLI arguments, splits the dataset into train/validation subsets,
    runs the training loop, and saves the best model (lowest validation loss)
    under --model_save_dir.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher")
    parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录")
    parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例")
    parser.add_argument("--batch_size", type=int, default=16, help="批量大小")
    parser.add_argument("--epochs", type=int, default=50, help="训练轮次")
    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
    parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径")
    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.model_save_dir, exist_ok=True)

    # Load the dataset and split into train/validation subsets.
    dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms())
    total_samples = len(dataset)
    val_size = int(total_samples * args.val_split)
    train_size = total_samples - val_size
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Model, loss, and optimizer.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RotationInvariantCNN().to(device)  # adjust constructor args to your model
    criterion = nn.CrossEntropyLoss()  # classification example; pick per task
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Compute the checkpoint path once per run so each improvement overwrites
    # the same file, instead of accumulating one timestamped file per save.
    run_tag = datetime.now().strftime('%Y%m%d%H%M')
    best_model_path = os.path.join(args.model_save_dir, f"best_model_{run_tag}.pth")

    # Training loop: keep the model with the lowest validation loss.
    best_val_loss = float("inf")
    for epoch in range(1, args.epochs + 1):
        avg_train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, args.epochs
        )
        avg_val_loss = _evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)

    print("训练完成!")


if __name__ == "__main__":
    main()