Files
RoRD-Layout-Recognation/docs/loss_function.md
2025-07-20 16:06:33 +08:00

103 lines
3.7 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# RoRD 模型训练损失函数详解
本文档详细描述了 RoRDRobust Layout Representation and Detection模型训练过程中使用的损失函数设计。
## 1. 检测损失Detection Loss
### 数学公式
$$L_{\text{det}} = \text{BCE}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1})) + 0.1 \times \text{SmoothL1}(\text{det}_{\text{original}}, \text{warp}(\text{det}_{\text{rotated}}, H^{-1}))$$
### 组成说明
- **BCE损失**:二元交叉熵损失,适用于二分类检测任务
- 衡量原始检测图与变换后检测图之间的差异
- 公式:
$$\text{BCE}(y, \hat{y}) = -[y \cdot \log(\hat{y}) + (1-y) \cdot \log(1-\hat{y})]$$
- **Smooth L1损失**平滑L1损失对异常值更鲁棒
- 公式:
$$\text{SmoothL1}(x) = \begin{cases}
0.5x^2 & \text{if } |x| < 1 \\
|x| - 0.5 & \text{otherwise}
\end{cases}$$
- 作为BCE损失的辅助正则项
- **权重比例**
- BCE损失权重 1.0主导损失
- Smooth L1损失权重 0.1辅助正则
### 空间变换
- **warp操作**使用逆变换矩阵H⁻¹对特征图进行空间变换对齐
- **实现**通过`F.affine_grid``F.grid_sample`完成
## 2. 描述子损失Descriptor Loss
### Triplet Loss公式
$$L_{\text{desc}} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \text{margin}\right)$$
### 符号定义
- **a** (anchor)原始图像的描述子特征
- **p** (positive)变换后图像对应位置的描述子特征
- **n** (negative)困难负样本的描述子特征
- **margin**边界参数默认值为1.0
- **f(·)**描述子特征提取函数
### 采样策略
#### 正样本采样
- **采样方法**均匀网格采样
- **采样点数**200个点
- **空间分布**在特征图上均匀分布确保训练稳定性
#### 困难负样本挖掘
1. **候选生成**随机生成负样本坐标点
2. **距离计算**计算anchor与所有负候选的距离
3. **选择策略**选择距离最近的负样本作为困难负样本
4. **计算优化**使用`torch.gather`高效选择
### 实现细节
- **特征维度**128维描述子向量
- **归一化**使用InstanceNorm进行特征归一化
- **距离度量**L2范数欧氏距离
- **损失函数**`nn.TripletMarginLoss(margin=1.0, p=2)`
## 3. 总损失函数
### 最终公式
$$L_{\text{total}} = L_{\text{det}} + L_{\text{desc}}$$
### 设计特点
- **无权重平衡**两个损失直接相加依靠网络自动学习平衡
- **端到端训练**检测和描述任务联合优化
- **多任务学习**同时学习几何变换不变性和特征描述能力
## 4. 训练策略
### 损失优化
- **优化器**Adam优化器
- **学习率**初始1e-3使用ReduceLROnPlateau调度
- **梯度裁剪**max_norm=1.0,防止梯度爆炸
### 验证指标
- **检测损失**验证集上的检测任务性能
- **描述子损失**验证集上的特征匹配性能
- **总损失**两个损失的加权和
## 5. 实现代码位置
- **检测损失**`train.py::compute_detection_loss()`第126-138行
- **描述子损失**`train.py::compute_description_loss()`第140-178行
- **总损失**`train.py::main()`第242行
## 6. 数学符号对照表
| 符号 | 含义 | 维度 |
|------|------|------|
| det_original | 原始图像检测图 | (B, 1, H, W) |
| det_rotated | 变换图像检测图 | (B, 1, H, W) |
| desc_original | 原始图像描述子 | (B, 128, H, W) |
| desc_rotated | 变换图像描述子 | (B, 128, H, W) |
| H | 几何变换矩阵 | (B, 3, 3) |
| margin | Triplet Loss边界 | 标量 |
| B | 批次大小 | 标量 |
| C | 特征维度 | 128 |
| H, W | 特征图高宽 | 标量 |