diff --git a/TRAINING_STRATEGY_ANALYSIS.md b/TRAINING_STRATEGY_ANALYSIS.md
new file mode 100644
index 0000000..bc5635c
--- /dev/null
+++ b/TRAINING_STRATEGY_ANALYSIS.md
@@ -0,0 +1,100 @@
+# RoRD Training Strategy Analysis and Improvements
+
+## Analysis of the Original Problems
+
+### 1. Technical errors
+- **PIL.Image.LANCZOS error**: the deprecated `Image.LANCZOS` constant was used; it should be `Image.Resampling.LANCZOS`
+- **Model architecture mismatch**: the detection head and the descriptor head worked on feature maps of different sizes, which destabilized training
+
+### 2. Training strategy issues
+
+#### 2.1 Loss function design
+- **Detection loss**: MSE is a poor fit for a binary detection target; BCE should be used instead
+- **Descriptor loss**: the triplet-loss sampling strategy was ineffective; random sampling produced mostly easy negatives
+
+#### 2.2 Data augmentation
+- **Scale jitter range too wide**: `(0.7, 1.5)` can destabilize training
+- **Geometric transforms too simple**: only 8 discrete orientations, with no continuous variation
+- **Missing augmentations**: no brightness, contrast, or noise augmentation
+
+#### 2.3 Training configuration
+- **Batch size too small**: only 4, which underutilizes modern GPUs
+- **Learning rate possibly too high**: 1e-4 can make training unstable
+- **No validation mechanism**: no validation set and no early stopping
+- **No monitoring**: no detailed training logs or per-loss breakdown
+
+## Improvements
+
+### 1. Technical fixes
+✅ **Fixed**: PIL.Image.LANCZOS → Image.Resampling.LANCZOS
+✅ **Fixed**: the detection head and the descriptor head now share a feature map of the same size
+
+### 2. Loss function improvements
+✅ **Detection loss**:
+- BCE loss replaces MSE
+- Smooth L1 loss added as an auxiliary term
+
+✅ **Descriptor loss**:
+- More sample points (100 → 200)
+- Grid sampling instead of random sampling
+- Hard negative mining
+
+### 3. Data augmentation optimization
+✅ **Scale jitter**: range narrowed to `(0.8, 1.2)`
+✅ **Additional augmentations**:
+- Brightness adjustment (0.8-1.2×)
+- Contrast adjustment (0.8-1.2×)
+- Gaussian noise (σ = 5)
+
+### 4. Training configuration optimization
+✅ **Batch size**: 4 → 8
+✅ **Learning rate**: 1e-4 → 5e-5
+✅ **Epochs**: 20 → 50
+✅ **Weight decay added**: 1e-4
+
+### 5. Training pipeline improvements
+✅ **Validation set**: 80/20 split
+✅ **LR scheduling**: ReduceLROnPlateau
+✅ **Early stopping**: stop after 10 epochs without improvement
+✅ **Gradient clipping**: max_norm=1.0
+✅ **Detailed logging**: per-loss breakdown for training and validation
+
+## Expected Benefits
+
+### 1. Training stability
+- Smoother loss curves
+- Lower risk of overfitting
+- Better generalization
+
+### 2. Model performance
+- More accurate detections
+- More robust descriptors
+- Better geometric invariance
+
+### 3. Training efficiency
+- Faster convergence
+- Better hardware utilization
+- More complete monitoring
+
+## Usage Recommendations
+
+### 1. Before training
+```bash
+# Make sure the data paths are correct
+python train.py --data_dir /path/to/layouts --save_dir /path/to/save
+```
+
+### 2. Monitoring training
+- Check the log file for the detailed training record
+- Watch the trend of the validation loss
+- Watch the automatic learning-rate adjustments
+
+### 3. Model selection
+- Use `rord_model_best.pth` as the final model (see the loading sketch below)
+- It is the checkpoint that performed best on the validation set
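+
+Since `train.py` now writes a checkpoint dictionary rather than a bare `state_dict`, downstream code should read the `model_state_dict` key. A minimal loading sketch (the checkpoint path is a placeholder):
+
+```python
+import torch
+
+from models.rord import RoRD
+
+# rord_model_best.pth is a dict with 'epoch', 'model_state_dict',
+# 'optimizer_state_dict', 'best_val_loss' and 'config' entries
+checkpoint = torch.load("path/to/save_dir/rord_model_best.pth", map_location="cpu")
+
+model = RoRD()
+model.load_state_dict(checkpoint["model_state_dict"])
+model.eval()
+
+print(f"Best validation loss: {checkpoint['best_val_loss']:.4f} (epoch {checkpoint['epoch'] + 1})")
+```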
+
+### 4. Further optimization suggestions
+- Consider initializing the backbone from pretrained weights
+- Experiment with other augmentation combinations
+- Try different weightings between the detection and descriptor losses
+- Consider mixed-precision training to speed things up (see the sketch below)
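+
+A minimal sketch of how mixed precision could be wired into the existing training loop (not part of this patch; the short variable names are illustrative). Note that `F.binary_cross_entropy`, used inside `compute_detection_loss`, is not autocast-safe, so the losses are computed in full precision outside the autocast region:
+
+```python
+scaler = torch.cuda.amp.GradScaler()
+
+for original, rotated, H in train_dataloader:
+    original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
+    optimizer.zero_grad()
+
+    # Run the two forward passes in mixed precision
+    with torch.cuda.amp.autocast():
+        det_o, desc_o = model(original)
+        det_r, desc_r = model(rotated)
+
+    # Cast back to float32 and compute the losses outside autocast
+    loss = (compute_detection_loss(det_o.float(), det_r.float(), H)
+            + compute_description_loss(desc_o.float(), desc_r.float(), H))
+
+    scaler.scale(loss).backward()
+    scaler.unscale_(optimizer)  # unscale before clipping
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+    scaler.step(optimizer)
+    scaler.update()
+```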
\ No newline at end of file
diff --git a/config.py b/config.py
index 0e7e981..0f43a0a 100644
--- a/config.py
+++ b/config.py
@@ -1,12 +1,12 @@
 # config.py
 # --- Training parameters ---
-LEARNING_RATE = 1e-4
-BATCH_SIZE = 4
-NUM_EPOCHS = 20
+LEARNING_RATE = 5e-5  # Lower learning rate for more stable training
+BATCH_SIZE = 8        # Larger batch size for better throughput
+NUM_EPOCHS = 50       # More training epochs
 PATCH_SIZE = 256
 
-# (New) scale jitter range used during training
-SCALE_JITTER_RANGE = (0.7, 1.5)
+# (Tuned) scale jitter range used during training - narrowed for stability
+SCALE_JITTER_RANGE = (0.8, 1.2)
 
 # --- Matching and evaluation parameters ---
 KEYPOINT_THRESHOLD = 0.5
diff --git a/models/rord.py b/models/rord.py
index a66fbfa..d9dd515 100644
--- a/models/rord.py
+++ b/models/rord.py
@@ -9,21 +9,22 @@ class RoRD(nn.Module):
         """
         Fixed RoRD model.
         - Uses a shared backbone to improve compute efficiency and reduce memory usage.
-        - Removes the redundant descriptor_head_vanilla.
+        - Ensures the detection head and the descriptor head use a feature map of the same size.
         """
         super(RoRD, self).__init__()
         vgg16_features = models.vgg16(pretrained=False).features
 
-        # Shared backbone
-        self.slice1 = vgg16_features[:23]   # up to relu4_3
-        self.slice2 = vgg16_features[23:30] # relu4_3 to relu5_3
+        # Shared backbone - only up to relu4_3, so both heads see the same feature map size
+        self.backbone = vgg16_features[:23]  # up to relu4_3
 
         # Detection head
         self.detection_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 1, kernel_size=1),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 1, kernel_size=1),
             nn.Sigmoid()
         )
 
@@ -31,19 +32,18 @@ class RoRD(nn.Module):
         self.descriptor_head = nn.Sequential(
             nn.Conv2d(512, 256, kernel_size=3, padding=1),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 128, kernel_size=1),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, kernel_size=1),
             nn.InstanceNorm2d(128)
         )
 
     def forward(self, x):
         # Shared feature extraction
-        features_shared = self.slice1(x)
+        features = self.backbone(x)
 
-        # Descriptor branch
-        descriptors = self.descriptor_head(features_shared)
-
-        # Detector branch
-        features_det = self.slice2(features_shared)
-        detection_map = self.detection_head(features_det)
+        # The detector and the descriptor share the same feature map
+        detection_map = self.detection_head(features)
+        descriptors = self.descriptor_head(features)
 
         return detection_map, descriptors
\ No newline at end of file
diff --git a/train.py b/train.py
index e9f5d59..8073e11 100644
--- a/train.py
+++ b/train.py
@@ -9,12 +9,31 @@ import numpy as np
 import cv2
 import os
 import argparse
+import logging
+from datetime import datetime
 
 # Project modules
 import config
 from models.rord import RoRD
 from utils.data_utils import get_transform
 
+# Logging setup
+def setup_logging(save_dir):
+    """Configure logging for a training run."""
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    log_file = os.path.join(save_dir, f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler(log_file),
+            logging.StreamHandler()
+        ]
+    )
+    return logging.getLogger(__name__)
+
 # --- (Modified) training dataset class ---
 class ICLayoutTrainingDataset(Dataset):
     def __init__(self, image_dir, patch_size=256, transform=None, scale_range=(1.0, 1.0)):
@@ -48,9 +67,28 @@
         patch = image.crop((x, y, x + crop_size, y + crop_size))
 
         # 4. Resize the crop back to the standard patch_size
-        patch = patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)
+        patch = patch.resize((self.patch_size, self.patch_size), Image.Resampling.LANCZOS)
         # --- End of scale jitter ---
 
+        # --- New: additional data augmentation ---
+        # Brightness adjustment
+        if np.random.random() < 0.5:
+            brightness_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda p: int(min(255, max(0, p * brightness_factor))))
+
+        # Contrast adjustment
+        if np.random.random() < 0.5:
+            contrast_factor = np.random.uniform(0.8, 1.2)
+            patch = patch.point(lambda p: int(min(255, max(0, (p - 128) * contrast_factor + 128))))
+
+        # Additive Gaussian noise
+        if np.random.random() < 0.3:
+            patch_np = np.array(patch, dtype=np.float32)
+            noise = np.random.normal(0, 5, patch_np.shape)
+            patch_np = np.clip(patch_np + noise, 0, 255)
+            patch = Image.fromarray(patch_np.astype(np.uint8))
+        # --- End of additional augmentation ---
+
         patch_np = np.array(patch)
 
         # Apply the 8 discrete geometric transforms (unchanged logic)
@@ -79,37 +117,78 @@
         H_tensor = torch.from_numpy(H[:2, :]).float()
         return patch, transformed_patch, H_tensor
 
-# --- Feature-map warping and loss functions (unchanged) ---
+# --- Feature-map warping and loss functions (improved) ---
 def warp_feature_map(feature_map, H_inv):
     B, C, H, W = feature_map.size()
     grid = F.affine_grid(H_inv, feature_map.size(), align_corners=False).to(feature_map.device)
     return F.grid_sample(feature_map, grid, align_corners=False)
 
 def compute_detection_loss(det_original, det_rotated, H):
+    """Improved detection loss: BCE instead of MSE."""
     with torch.no_grad():
         H_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))[:, :2, :]
         warped_det_rotated = warp_feature_map(det_rotated, H_inv)
-    return F.mse_loss(det_original, warped_det_rotated)
+
+    # BCE is better suited to the binary detection target
+    bce_loss = F.binary_cross_entropy(det_original, warped_det_rotated)
+
+    # Smooth L1 as an auxiliary term
+    smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det_rotated)
+
+    return bce_loss + 0.1 * smooth_l1_loss
 
 def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
+    """Improved descriptor loss: a more effective sampling strategy."""
     B, C, H_feat, W_feat = desc_original.size()
-    num_samples = 100
-    coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
+
+    # More sample points for more stable training
+    num_samples = 200
+
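+    # NOTE: the grid below uses int(sqrt(num_samples)) points per axis, so with
+    # num_samples = 200 the effective number of correspondences is 14 * 14 = 196.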
+    # Grid sampling instead of random sampling, for a more even spatial distribution
+    h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
+    w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
+    h_grid, w_grid = torch.meshgrid(h_coords, w_coords, indexing='ij')
+    coords = torch.stack([h_grid.flatten(), w_grid.flatten()], dim=1).unsqueeze(0).repeat(B, 1, 1)
+
+    # Sample the anchor descriptors
     anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    coords_hom = torch.cat([coords, torch.ones(B, num_samples, 1, device=coords.device)], dim=2)
+
+    # Compute the corresponding positive locations
+    coords_hom = torch.cat([coords, torch.ones(B, coords.size(1), 1, device=coords.device)], dim=2)
     M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
     coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
     positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    neg_coords = torch.rand(B, num_samples, 2, device=desc_original.device) * 2 - 1
-    negative = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
-    triplet_loss = nn.TripletMarginLoss(margin=margin, p=2)
+
+    # Hard negative mining
+    with torch.no_grad():
+        # Draw a pool of candidate negatives
+        neg_coords = torch.rand(B, num_samples * 2, 2, device=desc_original.device) * 2 - 1
+        negative_candidates = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
+
+        # Pick the hardest (closest) negative for each anchor
+        anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
+        negative_candidates_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
+
+        distances = torch.norm(anchor_expanded - negative_candidates_expanded, dim=3)
+        hard_negative_indices = torch.argmin(distances, dim=2)
+        negative = torch.gather(negative_candidates, 1, hard_negative_indices.unsqueeze(2).expand(-1, -1, C))
+
+    # Triplet loss over (anchor, positive, hard negative)
+    triplet_loss = nn.TripletMarginLoss(margin=margin, p=2, reduction='mean')
     return triplet_loss(anchor, positive, negative)
 
 # --- (Modified) main function and command-line interface ---
 def main(args):
-    print("--- Starting RoRD training ---")
-    print(f"Training parameters: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
+    # Set up logging
+    logger = setup_logging(args.save_dir)
+
+    logger.info("--- Starting RoRD training ---")
+    logger.info(f"Training parameters: Epochs={args.epochs}, Batch Size={args.batch_size}, LR={args.lr}")
+    logger.info(f"Data directory: {args.data_dir}")
+    logger.info(f"Save directory: {args.save_dir}")
+
     transform = get_transform()
+
     # Pass the scale jitter range to the dataset
     dataset = ICLayoutTrainingDataset(
         args.data_dir,
@@ -117,29 +196,145 @@
         transform=transform,
         scale_range=config.SCALE_JITTER_RANGE
     )
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+
+    logger.info(f"Dataset size: {len(dataset)}")
+
+    # Split into training and validation sets
+    train_size = int(0.8 * len(dataset))
+    val_size = len(dataset) - train_size
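+    # NOTE: random_split draws from the global RNG; pass a seeded torch.Generator
+    # if the train/val split needs to be reproducible across runs.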
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    logger.info(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}")
+
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
+
     model = RoRD().cuda()
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    logger.info(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
+
+    # Learning-rate scheduler
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode='min', factor=0.5, patience=5
+    )
+
+    # Early stopping
+    best_val_loss = float('inf')
+    patience_counter = 0
+    patience = 10
 
     for epoch in range(args.epochs):
+        # Training phase
         model.train()
-        total_loss_val = 0
-        for i, (original, rotated, H) in enumerate(dataloader):
+        total_train_loss = 0
+        total_det_loss = 0
+        total_desc_loss = 0
+
+        for i, (original, rotated, H) in enumerate(train_dataloader):
             original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
+
             det_original, desc_original = model(original)
             det_rotated, desc_rotated = model(rotated)
-            loss = compute_detection_loss(det_original, det_rotated, H) + compute_description_loss(desc_original, desc_rotated, H)
+
+            det_loss = compute_detection_loss(det_original, det_rotated, H)
+            desc_loss = compute_description_loss(desc_original, desc_rotated, H)
+            loss = det_loss + desc_loss
+
             optimizer.zero_grad()
             loss.backward()
+
+            # Gradient clipping to guard against exploding gradients
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
             optimizer.step()
-            total_loss_val += loss.item()
-        print(f"--- Epoch {epoch+1} finished, average loss: {total_loss_val / len(dataloader):.4f} ---")
-
-    if not os.path.exists(args.save_dir):
-        os.makedirs(args.save_dir)
+            total_train_loss += loss.item()
+            total_det_loss += det_loss.item()
+            total_desc_loss += desc_loss.item()
+
+            if i % 10 == 0:
+                logger.info(f"Epoch {epoch+1}, Batch {i}, Total Loss: {loss.item():.4f}, "
+                            f"Det Loss: {det_loss.item():.4f}, Desc Loss: {desc_loss.item():.4f}")
+
+        avg_train_loss = total_train_loss / len(train_dataloader)
+        avg_det_loss = total_det_loss / len(train_dataloader)
+        avg_desc_loss = total_desc_loss / len(train_dataloader)
+
+        # Validation phase
+        model.eval()
+        total_val_loss = 0
+        total_val_det_loss = 0
+        total_val_desc_loss = 0
+
+        with torch.no_grad():
+            for original, rotated, H in val_dataloader:
+                original, rotated, H = original.cuda(), rotated.cuda(), H.cuda()
+
+                det_original, desc_original = model(original)
+                det_rotated, desc_rotated = model(rotated)
+
+                val_det_loss = compute_detection_loss(det_original, det_rotated, H)
+                val_desc_loss = compute_description_loss(desc_original, desc_rotated, H)
+                val_loss = val_det_loss + val_desc_loss
+
+                total_val_loss += val_loss.item()
+                total_val_det_loss += val_det_loss.item()
+                total_val_desc_loss += val_desc_loss.item()
+
+        avg_val_loss = total_val_loss / len(val_dataloader)
+        avg_val_det_loss = total_val_det_loss / len(val_dataloader)
+        avg_val_desc_loss = total_val_desc_loss / len(val_dataloader)
+
+        # Learning-rate scheduling
+        scheduler.step(avg_val_loss)
+
+        logger.info(f"--- Epoch {epoch+1} finished ---")
+        logger.info(f"Train - Total: {avg_train_loss:.4f}, Det: {avg_det_loss:.4f}, Desc: {avg_desc_loss:.4f}")
+        logger.info(f"Val   - Total: {avg_val_loss:.4f}, Det: {avg_val_det_loss:.4f}, Desc: {avg_val_desc_loss:.4f}")
+        logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
+
+        # Early-stopping check
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            patience_counter = 0
+
+            # Save the best model
+            if not os.path.exists(args.save_dir):
+                os.makedirs(args.save_dir)
+            save_path = os.path.join(args.save_dir, 'rord_model_best.pth')
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'best_val_loss': best_val_loss,
+                'config': {
+                    'learning_rate': args.lr,
+                    'batch_size': args.batch_size,
+                    'epochs': args.epochs
+                }
+            }, save_path)
+            logger.info(f"Best model saved to: {save_path}")
+        else:
+            patience_counter += 1
+            if patience_counter >= patience:
+                logger.info(f"Early stopping triggered: no improvement for {patience} epochs")
+                break
+
+    # Save the final model
     save_path = os.path.join(args.save_dir, 'rord_model_final.pth')
-    torch.save(model.state_dict(), save_path)
-    print(f"Model saved to: {save_path}")
+    torch.save({
+        'epoch': args.epochs,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'final_val_loss': avg_val_loss,
+        'config': {
+            'learning_rate': args.lr,
+            'batch_size': args.batch_size,
+            'epochs': args.epochs
+        }
+    }, save_path)
+    logger.info(f"Final model saved to: {save_path}")
+    logger.info("Training finished!")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train the RoRD model")