add some change twice

This commit is contained in:
Jiao77
2025-09-25 20:30:31 +08:00
parent 09f513686d
commit 8c6c5592b6
2 changed files with 161 additions and 0 deletions

138
losses.py Normal file
View File

@@ -0,0 +1,138 @@
"""Loss utilities for RoRD training."""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
"""Append the third row [0, 0, 1] to build a full 3x3 homography."""
if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
raise ValueError("Expected homography with shape (B, 2, 3)")
batch_size = h_2x3.size(0)
device = h_2x3.device
bottom_row = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=h_2x3.dtype)
bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
return torch.cat([h_2x3, bottom_row], dim=1)
def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
"""Warp feature map according to inverse homography."""
return F.grid_sample(
feature_map,
F.affine_grid(h_inv, feature_map.size(), align_corners=False),
align_corners=False,
)
def compute_detection_loss(
det_original: torch.Tensor,
det_rotated: torch.Tensor,
h: torch.Tensor,
) -> torch.Tensor:
"""Binary cross-entropy + smooth L1 detection loss."""
h_full = _augment_homography_matrix(h)
h_inv = torch.inverse(h_full)[:, :2, :]
warped_det = warp_feature_map(det_rotated, h_inv)
bce_loss = F.binary_cross_entropy(det_original, warped_det)
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
return bce_loss + 0.1 * smooth_l1_loss
def compute_description_loss(
desc_original: torch.Tensor,
desc_rotated: torch.Tensor,
h: torch.Tensor,
margin: float = 1.0,
) -> torch.Tensor:
"""Triplet-style descriptor loss with Manhattan-aware sampling."""
batch_size, channels, height, width = desc_original.size()
num_samples = 200
grid_side = int(math.sqrt(num_samples))
h_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
w_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
anchor = F.grid_sample(
desc_original,
manhattan_coords.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
coords_hom = torch.cat(
[manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=desc_original.device)],
dim=2,
)
h_full = _augment_homography_matrix(h)
h_inv = torch.inverse(h_full)
coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]
positive = F.grid_sample(
desc_rotated,
coords_transformed.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
negative_list = []
if manhattan_coords.size(1) > 0:
angles = [0, 90, 180, 270]
for angle in angles:
if angle == 0:
continue
theta = torch.tensor(angle * math.pi / 180.0, device=desc_original.device)
cos_t = torch.cos(theta)
sin_t = torch.sin(theta)
rot = torch.stack(
[
torch.stack([cos_t, -sin_t]),
torch.stack([sin_t, cos_t]),
]
)
rotated_coords = manhattan_coords @ rot.T
negative_list.append(rotated_coords)
if negative_list:
neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
negative_candidates = F.grid_sample(
desc_rotated,
neg_coords.unsqueeze(1),
align_corners=False,
).squeeze(2).transpose(1, 2)
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
k = max(anchor.size(1) // 2, 1)
hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
negative = torch.gather(negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand)
negative = negative.mean(dim=2)
else:
negative = torch.zeros_like(anchor)
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
geometric_triplet = triplet_loss(anchor, positive, negative)
manhattan_loss = 0.0
for i in range(anchor.size(1)):
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
manhattan_loss += torch.mean(1 - cos_sim)
manhattan_loss = manhattan_loss / max(anchor.size(1), 1)
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss