103 lines
3.7 KiB
Markdown
103 lines
3.7 KiB
Markdown
|
|
# RoRD 模型训练损失函数详解
|
|||
|
|
|
|||
|
|
本文档详细描述了 RoRD(Robust Layout Representation and Detection)模型训练过程中使用的损失函数设计。
|
|||
|
|
|
|||
|
|
## 1. 检测损失(Detection Loss)
|
|||
|
|
|
|||
|
|
### 数学公式
|
|||
|
|
$$L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))$$
|
|||
|
|
|
|||
|
|
### 组成说明
|
|||
|
|
- **BCE损失**:二元交叉熵损失,适用于二分类检测任务
|
|||
|
|
- 衡量原始检测图与变换后检测图之间的差异
|
|||
|
|
- 公式:
|
|||
|
|
$$\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]$$
|
|||
|
|
|
|||
|
|
- **Smooth L1损失**:平滑L1损失,对异常值更鲁棒
|
|||
|
|
- 公式:
|
|||
|
|
$$\text{SmoothL1}(x) = \begin{cases}
|
|||
|
|
0.5x^2 & \text{if } |x| < 1 \\
|
|||
|
|
|x| - 0.5 & \text{otherwise}
|
|||
|
|
\end{cases}$$
|
|||
|
|
- 作为BCE损失的辅助正则项
|
|||
|
|
|
|||
|
|
- **权重比例**:
|
|||
|
|
- BCE损失:权重 1.0(主导损失)
|
|||
|
|
- Smooth L1损失:权重 0.1(辅助正则)
|
|||
|
|
|
|||
|
|
### 空间变换
|
|||
|
|
- **warp操作**:使用逆变换矩阵H⁻¹对特征图进行空间变换对齐
|
|||
|
|
- **实现**:通过`F.affine_grid`和`F.grid_sample`完成
|
|||
|
|
|
|||
|
|
## 2. 描述子损失(Descriptor Loss)
|
|||
|
|
|
|||
|
|
### Triplet Loss公式
|
|||
|
|
$$L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)$$
|
|||
|
|
|
|||
|
|
### 符号定义
|
|||
|
|
- **a** (anchor):原始图像的描述子特征
|
|||
|
|
- **p** (positive):变换后图像对应位置的描述子特征
|
|||
|
|
- **n** (negative):困难负样本的描述子特征
|
|||
|
|
- **margin**:边界参数,默认值为1.0
|
|||
|
|
- **f(·)**:描述子特征提取函数
|
|||
|
|
|
|||
|
|
### 采样策略
|
|||
|
|
|
|||
|
|
#### 正样本采样
|
|||
|
|
- **采样方法**:均匀网格采样
|
|||
|
|
- **采样点数**:200个点
|
|||
|
|
- **空间分布**:在特征图上均匀分布,确保训练稳定性
|
|||
|
|
|
|||
|
|
#### 困难负样本挖掘
|
|||
|
|
1. **候选生成**:随机生成负样本坐标点
|
|||
|
|
2. **距离计算**:计算anchor与所有负候选的距离
|
|||
|
|
3. **选择策略**:选择距离最近的负样本作为困难负样本
|
|||
|
|
4. **计算优化**:使用`torch.gather`高效选择
|
|||
|
|
|
|||
|
|
### 实现细节
|
|||
|
|
- **特征维度**:128维描述子向量
|
|||
|
|
- **归一化**:使用InstanceNorm进行特征归一化
|
|||
|
|
- **距离度量**:L2范数(欧氏距离)
|
|||
|
|
- **损失函数**:`nn.TripletMarginLoss(margin=1.0, p=2)`
|
|||
|
|
|
|||
|
|
## 3. 总损失函数
|
|||
|
|
|
|||
|
|
### 最终公式
|
|||
|
|
$$L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}$$
|
|||
|
|
|
|||
|
|
### 设计特点
|
|||
|
|
- **无权重平衡**:两个损失直接相加,依靠网络自动学习平衡
|
|||
|
|
- **端到端训练**:检测和描述任务联合优化
|
|||
|
|
- **多任务学习**:同时学习几何变换不变性和特征描述能力
|
|||
|
|
|
|||
|
|
## 4. 训练策略
|
|||
|
|
|
|||
|
|
### 损失优化
|
|||
|
|
- **优化器**:Adam优化器
|
|||
|
|
- **学习率**:初始1e-3,使用ReduceLROnPlateau调度
|
|||
|
|
- **梯度裁剪**:max_norm=1.0,防止梯度爆炸
|
|||
|
|
|
|||
|
|
### 验证指标
|
|||
|
|
- **检测损失**:验证集上的检测任务性能
|
|||
|
|
- **描述子损失**:验证集上的特征匹配性能
|
|||
|
|
- **总损失**:两个损失的加权和
|
|||
|
|
|
|||
|
|
## 5. 实现代码位置
|
|||
|
|
|
|||
|
|
- **检测损失**:`train.py::compute_detection_loss()`(第126-138行)
|
|||
|
|
- **描述子损失**:`train.py::compute_description_loss()`(第140-178行)
|
|||
|
|
- **总损失**:`train.py::main()`(第242行)
|
|||
|
|
|
|||
|
|
## 6. 数学符号对照表
|
|||
|
|
|
|||
|
|
| 符号 | 含义 | 维度 |
|
|||
|
|
|------|------|------|
|
|||
|
|
| det_original | 原始图像检测图 | (B, 1, H, W) |
|
|||
|
|
| det_rotated | 变换图像检测图 | (B, 1, H, W) |
|
|||
|
|
| desc_original | 原始图像描述子 | (B, 128, H, W) |
|
|||
|
|
| desc_rotated | 变换图像描述子 | (B, 128, H, W) |
|
|||
|
|
| H | 几何变换矩阵 | (B, 3, 3) |
|
|||
|
|
| margin | Triplet Loss边界 | 标量 |
|
|||
|
|
| B | 批次大小 | 标量 |
|
|||
|
|
| C | 特征维度 | 128 |
|
|||
|
|
| H, W | 特征图高宽 | 标量 |
|