initial commit
This commit is contained in:
99
train.py
Normal file
99
train.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
|
||||
# 导入项目模块(根据你的路径调整)
|
||||
from models.rotation_cnn import RotationInvariantCNN # 模型实现
|
||||
from data_units import LayoutDataset, layout_transforms # 数据集和预处理函数
|
||||
|
||||
# 设置随机种子(可选)
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def main():
|
||||
"""训练流程"""
|
||||
# 解析命令行参数
|
||||
parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher")
|
||||
parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录")
|
||||
parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="批量大小")
|
||||
parser.add_argument("--epochs", type=int, default=50, help="训练轮次")
|
||||
parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
|
||||
parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(args.model_save_dir, exist_ok=True)
|
||||
|
||||
# 数据加载
|
||||
dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms())
|
||||
total_samples = len(dataset)
|
||||
val_size = int(total_samples * args.val_split)
|
||||
train_size = total_samples - val_size
|
||||
|
||||
# 划分训练集和验证集
|
||||
train_dataset, val_dataset = random_split(
|
||||
dataset,
|
||||
[train_size, val_size],
|
||||
generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||
|
||||
# 初始化模型、损失函数和优化器
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = RotationInvariantCNN().to(device) # 根据你的模型结构调整参数
|
||||
criterion = nn.CrossEntropyLoss() # 分类任务示例,根据任务类型选择损失函数
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
# 训练循环
|
||||
best_val_loss = float("inf")
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
||||
data, targets = data.to(device), targets.to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(data)
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
# 反向传播和优化
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
|
||||
if (batch_idx + 1) % 10 == 0:
|
||||
print(f"Epoch [{epoch}/{args.epochs}] Batch {batch_idx+1}/{len(train_loader)} Loss: {loss.item():.4f}")
|
||||
|
||||
# 验证
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for data, targets in val_loader:
|
||||
data, targets = data.to(device), targets.to(device)
|
||||
outputs = model(data)
|
||||
loss = criterion(outputs, targets)
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_train_loss = train_loss / len(train_loader)
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
|
||||
|
||||
# 保存最佳模型
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
torch.save(model.state_dict(), os.path.join(args.model_save_dir, f"best_model_{datetime.now().strftime('%Y%m%d%H%M')}.pth"))
|
||||
|
||||
print("训练完成!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user