一个实时目标检测模型

This commit is contained in:
jiao77
2025-03-31 14:49:04 +08:00
parent 956805997e
commit a5c63ad0de
6 changed files with 188 additions and 151 deletions

131
train.py
View File

@@ -1,99 +1,68 @@
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from datetime import datetime
import argparse
from models.superpoint_custom import SuperPointCustom
from utils.data_augmentation import generate_training_pair
# 导入项目模块(根据你的路径调整)
from models.rotation_cnn import RotationInvariantCNN # 模型实现
from data_units import LayoutDataset, layout_transforms # 数据集和预处理函数
class BinaryDataset(Dataset):
    """Dataset of pre-saved .npy images used to build SuperPoint training pairs.

    NOTE(review): this is a diff view; `__getitem__` for this class appears
    further down the page, separated by lines from the removed main().
    """

    def __init__(self, image_dir, patch_size, num_channels):
        # Directory is scanned once at construction; only .npy files are kept.
        self.image_dir = image_dir
        self.patch_size = patch_size
        # num_channels is stored but never read in the visible code --
        # presumably consumed elsewhere (e.g. by the model); verify callers.
        self.num_channels = num_channels
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
        if f.endswith('.npy')]
        # Set random seeds (optional).
        # NOTE(review): in the diff these two lines sit between the old and new
        # versions; they may belong to the removed main() rather than __init__.
        torch.manual_seed(42)
        np.random.seed(42)

    def __len__(self):
        return len(self.image_paths)
def main():
    """Training pipeline for the rotation-invariant layout matcher.

    NOTE(review): this is a diff view; the body of main() continues further
    down the page, interleaved with the newly added train() code.
    """
    # Parse command-line arguments (help strings kept verbatim).
    parser = argparse.ArgumentParser(description="Train Rotation-Invariant Layout Matcher")
    parser.add_argument("--data_dir", type=str, default="./data/train/", help="训练数据目录")
    parser.add_argument("--val_split", type=float, default=0.2, help="验证集比例")
    parser.add_argument("--batch_size", type=int, default=16, help="批量大小")
    parser.add_argument("--epochs", type=int, default=50, help="训练轮次")
    parser.add_argument("--lr", type=float, default=1e-3, help="学习率")
    parser.add_argument("--model_save_dir", type=str, default="./models/", help="模型保存路径")
    args = parser.parse_args()
def __getitem__(self, idx):
    """Return one training sample as float tensors.

    Loads the .npy image at *idx* and derives a (patch, warped_patch,
    homography) triple via generate_training_pair.
    """
    source_image = np.load(self.image_paths[idx])  # numpy array; [C, H, W] per original note
    patch_np, warped_np, homography_np = generate_training_pair(source_image, self.patch_size)
    patch_t = torch.from_numpy(patch_np).float()
    warped_t = torch.from_numpy(warped_np).float()
    homography_t = torch.from_numpy(homography_np).float()
    return patch_t, warped_t, homography_t
# Create the model output directory (continuation of the old main();
# `args` comes from the argparse block above).
os.makedirs(args.model_save_dir, exist_ok=True)
def simple_detector_loss(semi, semi_w, H, device):
    """Detector loss: MSE between the score maps of the two patches.

    Simplified placeholder (per the original note, a real implementation
    would be more involved); `H` and `device` are accepted only for
    interface compatibility and are not used here.
    """
    loss_value = F.mse_loss(semi, semi_w)
    return loss_value
# Data loading (continuation of the old main()): build the dataset and
# compute the train/validation split sizes from --val_split.
dataset = LayoutDataset(root_dir=args.data_dir, transform=layout_transforms())
total_samples = len(dataset)
val_size = int(total_samples * args.val_split)
train_size = total_samples - val_size
def simple_descriptor_loss(desc, desc_w, H, device):
    """Descriptor loss: MSE between the descriptor maps of the two patches.

    Simplified placeholder (per the original note); `H` and `device` are
    kept in the signature for compatibility but are unused.
    """
    loss_value = F.mse_loss(desc, desc_w)
    return loss_value
# Split into train/validation subsets (continuation of the old main());
# the fixed generator seed makes the split reproducible across runs.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
# NOTE(review): diff view -- from here on, lines from the removed main() and
# the newly added train() are interleaved. Fragment markers below are a
# best-effort attribution (main() uses `args`, train() uses local
# hyperparameters); code is left untouched.

# [likely new train()] device + hard-coded hyperparameters
def train():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8
learning_rate = 0.001
num_epochs = 10
image_dir = 'data/train_images'  # replace with the actual path
patch_size = 256
num_channels = 3  # replace with the actual channel count
# [likely old main()] initialize model, loss function and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RotationInvariantCNN().to(device)  # adjust parameters to your model structure
criterion = nn.CrossEntropyLoss()  # classification example; choose a loss for your task
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# [likely new train()] dataset and dataloader
dataset = BinaryDataset(image_dir, patch_size, num_channels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# [likely old main()] training loop
best_val_loss = float("inf")
for epoch in range(1, args.epochs + 1):
# [likely new train()] model and optimizer
model = SuperPointCustom(num_channels=num_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# NOTE(review): interleaved diff continues -- the epoch loops of the new
# train() and the old main() are mixed together below. Fragment markers are
# best-effort attribution; code is left untouched.
# [likely new train()] epoch loop
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
# [likely old main()] iterate labeled (data, targets) batches
for batch_idx, (data, targets) in enumerate(train_loader):
data, targets = data.to(device), targets.to(device)
# Forward pass.
outputs = model(data)
loss = criterion(outputs, targets)
# Backward pass and optimization.
total_loss = 0
# [likely new train()] iterate (patch, warped_patch, H) triples and
# combine detector + descriptor losses
for patch, warped_patch, H in dataloader:
patch, warped_patch, H = patch.to(device), warped_patch.to(device), H.to(device)
semi, desc = model(patch)
semi_w, desc_w = model(warped_patch)
det_loss = simple_detector_loss(semi, semi_w, H, device)
desc_loss = simple_descriptor_loss(desc, desc_w, H, device)
loss = det_loss + desc_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
if (batch_idx + 1) % 10 == 0:
print(f"Epoch [{epoch}/{args.epochs}] Batch {batch_idx+1}/{len(train_loader)} Loss: {loss.item():.4f}")
# [likely old main()] validation pass (no gradients)
model.eval()
val_loss = 0.0
with torch.no_grad():
for data, targets in val_loader:
data, targets = data.to(device), targets.to(device)
outputs = model(data)
loss = criterion(outputs, targets)
val_loss += loss.item()
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
# [likely old main()] checkpoint the best model by validation loss
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save(model.state_dict(), os.path.join(args.model_save_dir, f"best_model_{datetime.now().strftime('%Y%m%d%H%M')}.pth"))
total_loss += loss.item()
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
print("训练完成!")
torch.save(model.state_dict(), 'superpoint_custom_model.pth')
# Script entry point: the old version called main(); the new calls train().
if __name__ == "__main__":
main()
train()