add some change twice
This commit is contained in:
138
losses.py
Normal file
138
losses.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Loss utilities for RoRD training."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def _augment_homography_matrix(h_2x3: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Append the third row [0, 0, 1] to build a full 3x3 homography."""
|
||||||
|
if h_2x3.dim() != 3 or h_2x3.size(1) != 2 or h_2x3.size(2) != 3:
|
||||||
|
raise ValueError("Expected homography with shape (B, 2, 3)")
|
||||||
|
|
||||||
|
batch_size = h_2x3.size(0)
|
||||||
|
device = h_2x3.device
|
||||||
|
bottom_row = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=h_2x3.dtype)
|
||||||
|
bottom_row = bottom_row.view(1, 1, 3).expand(batch_size, -1, -1)
|
||||||
|
return torch.cat([h_2x3, bottom_row], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def warp_feature_map(feature_map: torch.Tensor, h_inv: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Warp feature map according to inverse homography."""
|
||||||
|
return F.grid_sample(
|
||||||
|
feature_map,
|
||||||
|
F.affine_grid(h_inv, feature_map.size(), align_corners=False),
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_detection_loss(
|
||||||
|
det_original: torch.Tensor,
|
||||||
|
det_rotated: torch.Tensor,
|
||||||
|
h: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Binary cross-entropy + smooth L1 detection loss."""
|
||||||
|
h_full = _augment_homography_matrix(h)
|
||||||
|
h_inv = torch.inverse(h_full)[:, :2, :]
|
||||||
|
warped_det = warp_feature_map(det_rotated, h_inv)
|
||||||
|
|
||||||
|
bce_loss = F.binary_cross_entropy(det_original, warped_det)
|
||||||
|
smooth_l1_loss = F.smooth_l1_loss(det_original, warped_det)
|
||||||
|
return bce_loss + 0.1 * smooth_l1_loss
|
||||||
|
|
||||||
|
|
||||||
|
def compute_description_loss(
|
||||||
|
desc_original: torch.Tensor,
|
||||||
|
desc_rotated: torch.Tensor,
|
||||||
|
h: torch.Tensor,
|
||||||
|
margin: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Triplet-style descriptor loss with Manhattan-aware sampling."""
|
||||||
|
batch_size, channels, height, width = desc_original.size()
|
||||||
|
num_samples = 200
|
||||||
|
|
||||||
|
grid_side = int(math.sqrt(num_samples))
|
||||||
|
h_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
|
||||||
|
w_coords = torch.linspace(-1, 1, grid_side, device=desc_original.device)
|
||||||
|
|
||||||
|
manhattan_h = torch.cat([h_coords, torch.zeros_like(h_coords)])
|
||||||
|
manhattan_w = torch.cat([torch.zeros_like(w_coords), w_coords])
|
||||||
|
manhattan_coords = torch.stack([manhattan_h, manhattan_w], dim=1)
|
||||||
|
manhattan_coords = manhattan_coords.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||||
|
|
||||||
|
anchor = F.grid_sample(
|
||||||
|
desc_original,
|
||||||
|
manhattan_coords.unsqueeze(1),
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
|
coords_hom = torch.cat(
|
||||||
|
[manhattan_coords, torch.ones(batch_size, manhattan_coords.size(1), 1, device=desc_original.device)],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
h_full = _augment_homography_matrix(h)
|
||||||
|
h_inv = torch.inverse(h_full)
|
||||||
|
coords_transformed = (coords_hom @ h_inv.transpose(1, 2))[:, :, :2]
|
||||||
|
|
||||||
|
positive = F.grid_sample(
|
||||||
|
desc_rotated,
|
||||||
|
coords_transformed.unsqueeze(1),
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
|
negative_list = []
|
||||||
|
if manhattan_coords.size(1) > 0:
|
||||||
|
angles = [0, 90, 180, 270]
|
||||||
|
for angle in angles:
|
||||||
|
if angle == 0:
|
||||||
|
continue
|
||||||
|
theta = torch.tensor(angle * math.pi / 180.0, device=desc_original.device)
|
||||||
|
cos_t = torch.cos(theta)
|
||||||
|
sin_t = torch.sin(theta)
|
||||||
|
rot = torch.stack(
|
||||||
|
[
|
||||||
|
torch.stack([cos_t, -sin_t]),
|
||||||
|
torch.stack([sin_t, cos_t]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rotated_coords = manhattan_coords @ rot.T
|
||||||
|
negative_list.append(rotated_coords)
|
||||||
|
|
||||||
|
if negative_list:
|
||||||
|
neg_coords = torch.stack(negative_list, dim=1).reshape(batch_size, -1, 2)
|
||||||
|
negative_candidates = F.grid_sample(
|
||||||
|
desc_rotated,
|
||||||
|
neg_coords.unsqueeze(1),
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(2).transpose(1, 2)
|
||||||
|
|
||||||
|
anchor_expanded = anchor.unsqueeze(2).expand(-1, -1, negative_candidates.size(1), -1)
|
||||||
|
negative_expanded = negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1)
|
||||||
|
manhattan_dist = torch.sum(torch.abs(anchor_expanded - negative_expanded), dim=3)
|
||||||
|
|
||||||
|
k = max(anchor.size(1) // 2, 1)
|
||||||
|
hard_indices = torch.topk(manhattan_dist, k=k, largest=False)[1]
|
||||||
|
idx_expand = hard_indices.unsqueeze(-1).expand(-1, -1, -1, negative_candidates.size(2))
|
||||||
|
negative = torch.gather(negative_candidates.unsqueeze(1).expand(-1, anchor.size(1), -1, -1), 2, idx_expand)
|
||||||
|
negative = negative.mean(dim=2)
|
||||||
|
else:
|
||||||
|
negative = torch.zeros_like(anchor)
|
||||||
|
|
||||||
|
triplet_loss = nn.TripletMarginLoss(margin=margin, p=1, reduction='mean')
|
||||||
|
geometric_triplet = triplet_loss(anchor, positive, negative)
|
||||||
|
|
||||||
|
manhattan_loss = 0.0
|
||||||
|
for i in range(anchor.size(1)):
|
||||||
|
anchor_norm = F.normalize(anchor[:, i], p=2, dim=1)
|
||||||
|
positive_norm = F.normalize(positive[:, i], p=2, dim=1)
|
||||||
|
cos_sim = torch.sum(anchor_norm * positive_norm, dim=1)
|
||||||
|
manhattan_loss += torch.mean(1 - cos_sim)
|
||||||
|
|
||||||
|
manhattan_loss = manhattan_loss / max(anchor.size(1), 1)
|
||||||
|
sparsity_loss = torch.mean(torch.abs(anchor)) + torch.mean(torch.abs(positive))
|
||||||
|
binary_loss = torch.mean(torch.abs(torch.sign(anchor) - torch.sign(positive)))
|
||||||
|
|
||||||
|
return geometric_triplet + 0.1 * manhattan_loss + 0.01 * sparsity_loss + 0.05 * binary_loss
|
||||||
23
utils/config_loader.py
Normal file
23
utils/config_loader.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Configuration loading utilities using OmegaConf."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path: Union[str, Path]) -> DictConfig:
|
||||||
|
"""Load a YAML configuration file into a DictConfig."""
|
||||||
|
path = Path(config_path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Config file not found: {path}")
|
||||||
|
return OmegaConf.load(path)
|
||||||
|
|
||||||
|
|
||||||
|
def to_absolute_path(path_str: str, base_dir: Union[str, Path]) -> Path:
|
||||||
|
"""Resolve a possibly relative path against the configuration file directory."""
|
||||||
|
path = Path(path_str).expanduser()
|
||||||
|
if path.is_absolute():
|
||||||
|
return path.resolve()
|
||||||
|
return (Path(base_dir) / path).resolve()
|
||||||
Reference in New Issue
Block a user