update loss function.
This commit is contained in:
78
train.py
78
train.py
@@ -138,44 +138,84 @@ def compute_detection_loss(det_original, det_rotated, H):
|
||||
return bce_loss + 0.1 * smooth_l1_loss
|
||||
|
||||
def compute_description_loss(desc_original, desc_rotated, H, margin=1.0):
|
||||
"""改进的描述子损失:使用更有效的采样策略"""
|
||||
"""IC版图专用几何感知描述子损失:编码曼哈顿几何特征"""
|
||||
B, C, H_feat, W_feat = desc_original.size()
|
||||
|
||||
# 增加采样点数量,提高训练稳定性
|
||||
# 曼哈顿几何感知采样:重点采样边缘和角点区域
|
||||
num_samples = 200
|
||||
|
||||
# 使用网格采样而不是随机采样,确保空间分布更均匀
|
||||
# 生成曼哈顿对齐的采样网格(水平和垂直优先)
|
||||
h_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
|
||||
w_coords = torch.linspace(-1, 1, int(np.sqrt(num_samples)), device=desc_original.device)
|
||||
h_grid, w_grid = torch.meshgrid(h_coords, w_coords, indexing='ij')
|
||||
coords = torch.stack([h_grid.flatten(), w_grid.flatten()], dim=1).unsqueeze(0).repeat(B, 1, 1)
|
||||
|
||||
# 增加曼哈顿方向的采样密度
|
||||
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
|
||||
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
|
||||
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1).unsqueeze(0).repeat(B, 1, 1)
|
||||
|
||||
# 采样anchor点
|
||||
anchor = F.grid_sample(desc_original, coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||
anchor = F.grid_sample(desc_original, manhattan_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||
|
||||
# 计算对应的正样本点
|
||||
coords_hom = torch.cat([coords, torch.ones(B, coords.size(1), 1, device=coords.device)], dim=2)
|
||||
coords_hom = torch.cat([manhattan_coords, torch.ones(B, manhattan_coords.size(1), 1, device=manhattan_coords.device)], dim=2)
|
||||
M_inv = torch.inverse(torch.cat([H, torch.tensor([0.0, 0.0, 1.0]).view(1, 1, 3).repeat(H.shape[0], 1, 1)], dim=1))
|
||||
coords_transformed = (coords_hom @ M_inv.transpose(1, 2))[:, :, :2]
|
||||
positive = F.grid_sample(desc_rotated, coords_transformed.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||
|
||||
# 使用困难负样本挖掘
|
||||
# IC版图专用负样本策略:考虑重复结构
|
||||
with torch.no_grad():
|
||||
# 计算所有可能的负样本对
|
||||
neg_coords = torch.rand(B, num_samples * 2, 2, device=desc_original.device) * 2 - 1
|
||||
negative_candidates = F.grid_sample(desc_rotated, neg_coords.unsqueeze(1), align_corners=False).squeeze(2).transpose(1, 2)
|
||||
# 1. 几何感知的负样本:曼哈顿变换后的不同区域
|
||||
neg_coords = []
|
||||
for b in range(B):
|
||||
# 生成曼哈顿变换后的坐标(90度旋转等)
|
||||
angles = [0, 90, 180, 270]
|
||||
for angle in angles:
|
||||
if angle != 0:
|
||||
theta = torch.tensor([angle * np.pi / 180])
|
||||
rot_matrix = torch.tensor([
|
||||
[torch.cos(theta), -torch.sin(theta), 0],
|
||||
[torch.sin(theta), torch.cos(theta), 0]
|
||||
])
|
||||
rotated_coords = manhattan_coords[b] @ rot_matrix[:2, :2].T
|
||||
neg_coords.append(rotated_coords)
|
||||
|
||||
# 选择最困难的负样本
|
||||
neg_coords = torch.stack(neg_coords[:B*num_samples//2]).reshape(B, -1, 2)
|
||||
|
||||
# 2. 特征空间困难负样本
|
||||
negative_candidates = F.grid_sample(desc_rotated, neg_coords, align_corners=False).squeeze(2).transpose(1, 2)
|
||||
|
||||
# 3. 曼哈顿距离约束的困难样本选择
|
||||
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
|
||||
negative_candidates_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
|
||||
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
|
||||
|
||||
distances = torch.norm(anchor_expanded - negative_candidates_expanded, dim=3)
|
||||
hard_negative_indices = torch.argmin(distances, dim=2)
|
||||
negative = torch.gather(negative_candidates, 1, hard_negative_indices.unsqueeze(2).expand(-1, -1, C))
|
||||
# 使用曼哈顿距离而非欧氏距离
|
||||
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
|
||||
hard_indices = torch.topk(manhattan_dist, k=anchor.size(1)//2, largest=False)[1]
|
||||
negative = torch.gather(negative_candidates, 1, hard_indices)
|
||||
|
||||
# 使用改进的Triplet Loss
|
||||
triplet_loss = nn.TripletMarginLoss(margin=margin, p=2, reduction='mean')
|
||||
return triplet_loss(anchor, positive, negative)
|
||||
# IC版图专用的几何一致性损失
|
||||
# 1. 曼哈顿方向一致性损失
|
||||
manhattan_loss = 0
|
||||
for i in range(anchor.size(1)):
|
||||
# 计算水平和垂直方向的几何一致性
|
||||
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
|
||||
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
|
||||
|
||||
# 鼓励描述子对曼哈顿变换不变
|
||||
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
|
||||
manhattan_loss += torch.mean(1 - cos_sim)
|
||||
|
||||
# 2. 稀疏性正则化(IC版图特征稀疏)
|
||||
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
|
||||
|
||||
# 3. 二值化特征距离(处理二值化输入)
|
||||
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
|
||||
|
||||
# 综合损失
|
||||
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean') # 使用L1距离
|
||||
geometric_triplet = triplet_loss(anchor, positive, negative)
|
||||
|
||||
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
|
||||
|
||||
# --- (已修改) 主函数与命令行接口 ---
|
||||
def main(args):
|
||||
|
||||
Reference in New Issue
Block a user