fix function

TRAINING_STRATEGY_ANALYSIS.md · new file · 100 lines
@@ -0,0 +1,100 @@
# RoRD Training Strategy Analysis and Improvements

## Analysis of the Original Problems

### 1. Technical Errors

- **PIL.Image.LANCZOS error**: the deprecated `Image.LANCZOS` constant was used; it should be `Image.Resampling.LANCZOS`
- **Model architecture mismatch**: the detection head and the descriptor head consumed feature maps of different sizes, destabilizing training

### 2. Training Strategy Issues

#### 2.1 Loss Function Design

- **Detection loss**: MSE is a poor fit for a binary classification target; BCE should be used instead
- **Descriptor loss**: the triplet-loss sampling strategy was ineffective; random sampling yields mostly easy negatives

#### 2.2 Data Augmentation Strategy

- **Scale jitter range too wide**: `(0.7, 1.5)` can destabilize training
- **Overly simple geometric transforms**: only 8 discrete orientations, with no continuous variation
- **Missing augmentations**: no brightness, contrast, or noise augmentation

#### 2.3 Training Configuration

- **Batch size too small**: only 4, which is inefficient on modern GPUs
- **Learning rate possibly too high**: 1e-4 can destabilize training
- **No validation mechanism**: no validation set and no early stopping
- **No monitoring**: no detailed training logs or per-term loss breakdown
## Improvements

### 1. Technical Fixes

✅ **Fixed**: PIL.Image.LANCZOS → Image.Resampling.LANCZOS (see the one-line sketch below)

✅ **Fixed**: unified the feature-map size used by the detection head and the descriptor head
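A minimal sketch of the Pillow fix (the real call site is in `ICLayoutTrainingDataset` in train.py; the file name here is hypothetical):

```python
from PIL import Image

patch = Image.open("layout_patch.png").convert("L")  # hypothetical input file
# Old, deprecated spelling: patch.resize((256, 256), Image.LANCZOS)
patch = patch.resize((256, 256), Image.Resampling.LANCZOS)
```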
### 2. Loss Function Improvements

✅ **Detection loss**:

- Use BCE loss instead of MSE
- Add a smooth L1 loss as an auxiliary term (combined as sketched below)
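In code, the combined objective looks like this (a minimal sketch mirroring `compute_detection_loss` in the train.py diff below; the 0.1 weight on the auxiliary term comes from that code):

```python
import torch.nn.functional as F

def detection_loss(det_pred, det_target):
    # Both maps come from a Sigmoid head, so values lie in [0, 1] and
    # the warped rotated map can serve directly as the BCE target.
    bce = F.binary_cross_entropy(det_pred, det_target)
    smooth_l1 = F.smooth_l1_loss(det_pred, det_target)  # auxiliary term
    return bce + 0.1 * smooth_l1
```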
✅ **Descriptor loss**:

- Increase the number of sampled points (100 → 200)
- Use grid sampling instead of random sampling
- Implement hard negative mining (both ideas are sketched after this list)
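The two sampling ideas, condensed into a sketch. The full `compute_description_loss` in the train.py diff below additionally warps the grid through the homography to obtain positives; `torch.cdist` here is a compact stand-in for that code's explicit expand-and-norm:

```python
import math
import torch

def grid_coords(batch_size, num_samples, device):
    # Evenly spaced sample locations in normalized [-1, 1] coordinates,
    # instead of uniformly random points
    side = int(math.sqrt(num_samples))
    h = torch.linspace(-1, 1, side, device=device)
    w = torch.linspace(-1, 1, side, device=device)
    hg, wg = torch.meshgrid(h, w, indexing="ij")
    coords = torch.stack([hg.flatten(), wg.flatten()], dim=1)  # (side*side, 2)
    return coords.unsqueeze(0).repeat(batch_size, 1, 1)        # (B, side*side, 2)

def hardest_negatives(anchor, candidates):
    # anchor: (B, N, C), candidates: (B, M, C); for each anchor descriptor,
    # keep the candidate with the smallest L2 distance (the hardest negative)
    dists = torch.cdist(anchor, candidates)  # (B, N, M)
    idx = dists.argmin(dim=2)                # (B, N)
    return torch.gather(candidates, 1, idx.unsqueeze(-1).expand(-1, -1, anchor.size(-1)))
```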
### 3. Data Augmentation Optimizations

✅ **Scale jitter**: narrowed the range to `(0.8, 1.2)`

✅ **Additional augmentations** (combined in the sketch after this list):

- Brightness adjustment (0.8-1.2×)
- Contrast adjustment (0.8-1.2×)
- Gaussian noise (σ = 5)
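Taken together, the three augmentations amount to the following sketch of the logic that appears in the dataset's `__getitem__` in the train.py diff below, applied to a grayscale PIL patch:

```python
import numpy as np
from PIL import Image

def augment_patch(patch: Image.Image) -> Image.Image:
    # Brightness: scale all pixel values by a random factor in [0.8, 1.2]
    if np.random.random() < 0.5:
        b = np.random.uniform(0.8, 1.2)
        patch = patch.point(lambda v: int(v * b))
    # Contrast: stretch or compress values around mid-gray (128)
    if np.random.random() < 0.5:
        c = np.random.uniform(0.8, 1.2)
        patch = patch.point(lambda v: int((v - 128) * c + 128))
    # Gaussian noise with sigma = 5, clipped back to the valid range
    if np.random.random() < 0.3:
        arr = np.array(patch, dtype=np.float32)
        arr = np.clip(arr + np.random.normal(0, 5, arr.shape), 0, 255)
        patch = Image.fromarray(arr.astype(np.uint8))
    return patch
```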
### 4. Training Configuration Optimizations

✅ **Batch size**: 4 → 8

✅ **Learning rate**: 1e-4 → 5e-5

✅ **Epochs**: 20 → 50

✅ **Weight decay added**: 1e-4 (see the optimizer line below)
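These values map directly onto the optimizer construction, matching the train.py change below (assuming `model` is the RoRD instance built there):

```python
import torch

# lr = 5e-5 and weight_decay = 1e-4 are the values adopted above
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)
```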
### 5. Training Pipeline Improvements

✅ **Validation set**: 80/20 split

✅ **LR scheduling**: ReduceLROnPlateau

✅ **Early stopping**: stop after 10 epochs without improvement

✅ **Gradient clipping**: max_norm=1.0

✅ **Detailed logging**: per-term loss breakdown for training and validation

These pieces fit together as in the skeleton below.
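A self-contained skeleton (with a stub `run_epoch` and a dummy model standing in for the real train/validation passes in train.py):

```python
import torch
import torch.nn as nn

# Stand-ins so the skeleton runs on its own; the real loop is in train.py below.
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5  # halve the LR after 5 stale epochs
)

def run_epoch(train: bool) -> float:
    # Placeholder: returns a loss value; the real code iterates the dataloaders
    # and clips gradients to max_norm=1.0 before each optimizer step.
    return torch.rand(1).item()

best_val_loss, patience_counter, patience = float("inf"), 0, 10
for epoch in range(50):  # NUM_EPOCHS
    model.train(); run_epoch(train=True)
    model.eval(); val_loss = run_epoch(train=False)
    scheduler.step(val_loss)
    if val_loss < best_val_loss:
        best_val_loss, patience_counter = val_loss, 0
        torch.save(model.state_dict(), "rord_model_best.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:  # early stop after 10 stale epochs
            break
```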
## Expected Benefits

### 1. Training Stability

- A smoother, more stable loss curve
- Lower risk of overfitting
- Better generalization

### 2. Model Performance

- More accurate detections
- More robust descriptors
- Better geometric invariance

### 3. Training Efficiency

- Faster convergence
- Better hardware utilization
- More complete monitoring
## Usage Recommendations

### 1. Before Training

```bash
# Make sure the data paths are correct
python train.py --data_dir /path/to/layouts --save_dir /path/to/save
```
### 2. Monitoring Training

- Check the log file for detailed training progress
- Watch the validation-loss trend
- Track the automatic learning-rate adjustments
### 3. Model Selection

- Use `rord_model_best.pth` as the final model
- It is the checkpoint that performed best on the validation set (loading it is sketched below)
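Since the best checkpoint is saved as a dict (see the train.py diff below for the exact keys), loading it looks roughly like this:

```python
import torch
from models.rord import RoRD

ckpt = torch.load("rord_model_best.pth", map_location="cuda")
model = RoRD().cuda()
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
print(f"epoch {ckpt['epoch']}, best val loss {ckpt['best_val_loss']:.4f}")
```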
### 4. Further Optimization Suggestions

- Consider initializing from pretrained weights
- Experiment with different combinations of data augmentations
- Try other weightings between the loss terms
- Consider mixed-precision training to speed things up (see the sketch below)
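For the mixed-precision suggestion, the standard `torch.cuda.amp` pattern would look like this. This is a sketch only, reusing the names from train.py's `main()` and requiring a CUDA device; note that `F.binary_cross_entropy` is rejected under autocast, so the detection head would first need to emit logits and the loss switch to `binary_cross_entropy_with_logits`:

```python
import torch

scaler = torch.cuda.amp.GradScaler()
for original, rotated, H in train_dataloader:
    original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        det_o, desc_o = model(original)
        det_r, desc_r = model(rotated)
        loss = compute_detection_loss(det_o, det_r, H) + \
               compute_description_loss(desc_o, desc_r, H)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees the true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
```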
config.py · 10 changed lines

@@ -1,12 +1,12 @@
 # config.py

 # --- Training parameters ---
-LEARNING_RATE = 1e-4
+LEARNING_RATE = 5e-5  # Lower learning rate for more stable training
-BATCH_SIZE = 4
+BATCH_SIZE = 8  # Larger batch size for better training throughput
-NUM_EPOCHS = 20
+NUM_EPOCHS = 50  # More training epochs
 PATCH_SIZE = 256
-# (new) Scale-jitter range used during training
+# (tuned) Scale-jitter range used during training - narrowed for stability
-SCALE_JITTER_RANGE = (0.7, 1.5)
+SCALE_JITTER_RANGE = (0.8, 1.2)

 # --- Matching and evaluation parameters ---
 KEYPOINT_THRESHOLD = 0.5
models/rord.py

@@ -9,21 +9,22 @@ class RoRD(nn.Module):
         """
         The fixed RoRD model.
         - Implements a shared backbone for better compute efficiency and lower memory use.
-        - Removed the redundant descriptor_head_vanilla.
+        - Ensures the detection head and the descriptor head consume feature maps of the same size.
         """
         super(RoRD, self).__init__()

         vgg16_features = models.vgg16(pretrained=False).features

-        # Shared backbone
-        self.slice1 = vgg16_features[:23]    # up to relu4_3
-        self.slice2 = vgg16_features[23:30]  # relu4_3 to relu5_3
+        # Shared backbone - only up to relu4_3, so the feature-map size is consistent
+        self.backbone = vgg16_features[:23]  # up to relu4_3

         # Detection head
         self.detection_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 1, kernel_size=1),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1),
             nn.Sigmoid()
         )

@@ -31,19 +32,18 @@ class RoRD(nn.Module):
         self.descriptor_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 128, kernel_size=1),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1),
             nn.InstanceNorm2d(128)
         )

     def forward(self, x):
         # Shared feature extraction
-        features_shared = self.slice1(x)
-
-        # Descriptor branch
-        descriptors = self.descriptor_head(features_shared)
-
-        # Detector branch
-        features_det = self.slice2(features_shared)
-        detection_map = self.detection_head(features_det)
+        features = self.backbone(x)
+
+        # Detector and descriptor consume the same feature map
+        detection_map = self.detection_head(features)
+        descriptors = self.descriptor_head(features)

         return detection_map, descriptors
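A quick sanity check of the unified heads (a sketch, assuming the standard 3-channel input the VGG backbone expects; with a 256×256 patch, relu4_3 downsamples by 8, so both outputs share a 32×32 grid):

```python
import torch
from models.rord import RoRD

model = RoRD().eval()
x = torch.randn(1, 3, 256, 256)  # PATCH_SIZE input
with torch.no_grad():
    det, desc = model(x)
print(det.shape)   # torch.Size([1, 1, 32, 32])   - detection map
print(desc.shape)  # torch.Size([1, 128, 32, 32]) - dense descriptors
```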
train.py · 241 changed lines

@@ -9,12 +9,31 @@ import numpy as np
 import cv2
 import os
 import argparse
+import logging
+from datetime import datetime

 # Project modules
 import config
 from models.rord import RoRD
 from utils.data_utils import get_transform

+# Logging setup
+def setup_logging(save_dir):
+    """Configure logging for a training run."""
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler(log_file),
+            logging.StreamHandler()
+        ]
+    )
+    return logging.getLogger(__name__)
+
 # --- (modified) Training dataset class ---
 class ICLayoutTrainingDataset(Dataset):
     def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
@@ -48,9 +67,28 @@ class ICLayoutTrainingDataset(Dataset):
         patch = image.crop((x, y, x + crop_size, y + crop_size))

         # 4. Resize the crop back to the standard patch_size
-        patch = patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)
+        patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
         # --- End of scale jitter ---

+        # --- New: extra data augmentation ---
+        # Brightness adjustment
+        if np.random.random() < 0.5:
+            brightness_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda x: int(x * brightness_factor))
+
+        # Contrast adjustment
+        if np.random.random() < 0.5:
+            contrast_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda x: int(((x - 128) * contrast_factor) + 128))
+
+        # Additive Gaussian noise
+        if np.random.random() < 0.3:
+            patch_np = np.array(patch, dtype=np.float32)
+            noise = np.random.normal(0, 5, patch_np.shape)
+            patch_np = np.clip(patch_np + noise, 0, 255)
+            patch = Image.fromarray(patch_np.astype(np.uint8))
+        # --- End of extra augmentation ---
+
         patch_np = np.array(patch)

         # 8-orientation discrete geometric transforms (unchanged)
@@ -79,37 +117,78 @@ class ICLayoutTrainingDataset(Dataset):
         H_tensor = torch.from_numpy(H[:2, :]).float()
         return patch, transformed_patch, H_tensor

-# --- Feature-map warping and loss functions (unchanged) ---
+# --- Feature-map warping and loss functions (improved) ---
 def warp_feature_map(feature_map, H_inv):
     B, C, H, W = feature_map.size()
     grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
     return F.grid_sample(feature_map, grid, align_corners=False)

 def compute_detection_loss(det_original, det_rotated, H):
+    """Improved detection loss: BCE instead of MSE."""
     with torch.no_grad():
         H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
         warped_det_rotated = warp_feature_map(det_rotated, H_inv)
-    return F.mse_loss(det_original, warped_det_rotated)
+
+    # BCE loss, better suited to the binary detection target
+    bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
+
+    # Smooth L1 loss as an auxiliary term
+    smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
+
+    return bce_loss + 0.1 * smooth_l1_loss

 def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
+    """Improved descriptor loss: a more effective sampling strategy."""
     B, C, H_feat, W_feat = desc_original.size()
-    num_samples = 100
-    coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
+    # More sampled points for more stable training
+    num_samples = 200
+
+    # Grid sampling instead of random sampling, for a more even spatial distribution
+    h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
+    w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
+    h_grid, w_grid = torch.meshgrid(h_coords, w_coords, indexing='ij')
+    coords = torch.stack([h_grid.flatten(), w_grid.flatten()], dim=1).unsqueeze(0).repeat(B, 1, 1)
+
+    # Sample anchor points
     anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
+
+    # Compute the corresponding positive samples
+    coords_hom = torch.cat([coords, torch.ones(B, coords.size(1), 1, device=coords.device)], dim=2)
     M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
     coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
     positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
-    negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
+
+    # Hard negative mining
+    with torch.no_grad():
+        # Draw a pool of candidate negatives
+        neg_coords = torch.rand(B, num_samples * 2, 2, device=desc_original.device) * 2 - 1
+        negative_candidates = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
+
+        # Pick the hardest negative for each anchor
+        anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
+        negative_candidates_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
+
+        distances = torch.norm(anchor_expanded - negative_candidates_expanded, dim=3)
+        hard_negative_indices = torch.argmin(distances, dim=2)
+        negative = torch.gather(negative_candidates, 1, hard_negative_indices.unsqueeze(2).expand(-1, -1, C))
+
+    # Improved triplet loss
+    triplet_loss = nn.TripletMarginLoss(margin=margin, p=2, reduction='mean')
     return triplet_loss(anchor, positive, negative)

 # --- (modified) Main function and CLI ---
 def main(args):
-    print("--- Starting RoRD training ---")
-    print(f"Training params: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
+    # Set up logging
+    logger = setup_logging(args.save_dir)
+
+    logger.info("--- Starting RoRD training ---")
+    logger.info(f"Training params: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
+    logger.info(f"Data dir: {args.data_dir}")
+    logger.info(f"Save dir: {args.save_dir}")

     transform = get_transform()

     # Pass the scale-jitter range when constructing the dataset
     dataset = ICLayoutTrainingDataset(
         args.data_dir,
@@ -117,29 +196,145 @@ def main(args):
         transform=transform,
         scale_range=config.SCALE_JITTER_RANGE
     )
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+
+    logger.info(f"Dataset size: {len(dataset)}")
+
+    # Split into training and validation sets
+    train_size = int(0.8 * len(dataset))
+    val_size = len(dataset) - train_size
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    logger.info(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")
+
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

     model = RoRD().cuda()
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    logger.info(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
+
+    # Learning-rate scheduler
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode='min', factor=0.5, patience=5
+    )
+
+    # Early stopping
+    best_val_loss = float('inf')
+    patience_counter = 0
+    patience = 10

     for epoch in range(args.epochs):
+        # Training phase
         model.train()
-        total_loss_val = 0
-        for i, (original, rotated, H) in enumerate(dataloader):
+        total_train_loss = 0
+        total_det_loss = 0
+        total_desc_loss = 0
+
+        for i, (original, rotated, H) in enumerate(train_dataloader):
             original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
+
             det_original, desc_original = model(original)
             det_rotated, desc_rotated = model(rotated)
-            loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
+
+            det_loss = compute_detection_loss(det_original, det_rotated, H)
+            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
+            loss = det_loss + desc_loss
+
             optimizer.zero_grad()
             loss.backward()
+
+            # Gradient clipping to guard against exploding gradients
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
             optimizer.step()
-            total_loss_val += loss.item()
-        print(f"--- Epoch {epoch+1} done, avg loss: {total_loss_val / len(dataloader):.4f} ---")
-
-    if not os.path.exists(args.save_dir):
-        os.makedirs(args.save_dir)
+            total_train_loss += loss.item()
+            total_det_loss += det_loss.item()
+            total_desc_loss += desc_loss.item()
+
+            if i % 10 == 0:
+                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
+                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")
+
+        avg_train_loss = total_train_loss / len(train_dataloader)
+        avg_det_loss = total_det_loss / len(train_dataloader)
+        avg_desc_loss = total_desc_loss / len(train_dataloader)
+
+        # Validation phase
+        model.eval()
+        total_val_loss = 0
+        total_val_det_loss = 0
+        total_val_desc_loss = 0
+
+        with torch.no_grad():
+            for original, rotated, H in val_dataloader:
+                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
+
+                det_original, desc_original = model(original)
+                det_rotated, desc_rotated = model(rotated)
+
+                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
+                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
+                val_loss = val_det_loss + val_desc_loss
+
+                total_val_loss += val_loss.item()
+                total_val_det_loss += val_det_loss.item()
+                total_val_desc_loss += val_desc_loss.item()
+
+        avg_val_loss = total_val_loss / len(val_dataloader)
+        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
+        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
+
+        # LR scheduling
+        scheduler.step(avg_val_loss)
+
+        logger.info(f"--- Epoch {epoch+1} done ---")
+        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
+        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
+        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
+
+        # Early-stopping check
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            patience_counter = 0
+
+            # Save the best model
+            if not os.path.exists(args.save_dir):
+                os.makedirs(args.save_dir)
+            save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'best_val_loss': best_val_loss,
+                'config': {
+                    'learning_rate': args.lr,
+                    'batch_size': args.batch_size,
+                    'epochs': args.epochs
+                }
+            }, save_path)
+            logger.info(f"Best model saved to: {save_path}")
+        else:
+            patience_counter += 1
+            if patience_counter >= patience:
+                logger.info(f"Early stopping: no improvement for {patience} epochs")
+                break

+    # Save the final model
     save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
-    torch.save(model.state_dict(), save_path)
-    print(f"Model saved to: {save_path}")
+    torch.save({
+        'epoch': args.epochs,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'final_val_loss': avg_val_loss,
+        'config': {
+            'learning_rate': args.lr,
+            'batch_size': args.batch_size,
+            'epochs': args.epochs
+        }
+    }, save_path)
+    logger.info(f"Final model saved to: {save_path}")
+    logger.info("Training finished!")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train the RoRD model")