2025-06-08 15:38:56 +08:00
|
|
|
|
# models/rord.py
|
|
|
|
|
|
|
2025-06-07 23:45:32 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from torchvision import models
|
|
|
|
|
|
|
|
|
|
|
|
class RoRD(nn.Module):
|
|
|
|
|
|
def __init__(self):
|
2025-06-08 15:38:56 +08:00
|
|
|
|
"""
|
|
|
|
|
|
修复后的 RoRD 模型。
|
|
|
|
|
|
- 实现了共享骨干网络,以提高计算效率和减少内存占用。
|
2025-06-30 03:27:18 +08:00
|
|
|
|
- 确保检测头和描述子头使用相同尺寸的特征图。
|
2025-06-08 15:38:56 +08:00
|
|
|
|
"""
|
2025-06-07 23:45:32 +08:00
|
|
|
|
super(RoRD, self).__init__()
|
|
|
|
|
|
|
2025-06-09 00:55:28 +08:00
|
|
|
|
vgg16_features = models.vgg16(pretrained=False).features
|
2025-06-08 15:38:56 +08:00
|
|
|
|
|
2025-06-30 03:27:18 +08:00
|
|
|
|
# 共享骨干网络 - 只使用到 relu4_3,确保特征图尺寸一致
|
2025-07-20 15:37:42 +08:00
|
|
|
|
self.backbone = nn.Sequential(*list(vgg16_features.children())[:23])
|
2025-06-08 15:38:56 +08:00
|
|
|
|
|
|
|
|
|
|
# 检测头
|
2025-06-07 23:45:32 +08:00
|
|
|
|
self.detection_head = nn.Sequential(
|
|
|
|
|
|
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(inplace=True),
|
2025-06-30 03:27:18 +08:00
|
|
|
|
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(inplace=True),
|
|
|
|
|
|
nn.Conv2d(128, 1, kernel_size=1),
|
2025-06-07 23:45:32 +08:00
|
|
|
|
nn.Sigmoid()
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-06-08 15:38:56 +08:00
|
|
|
|
# 描述子头
|
|
|
|
|
|
self.descriptor_head = nn.Sequential(
|
2025-06-07 23:45:32 +08:00
|
|
|
|
nn.Conv2d(512, 256, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(inplace=True),
|
2025-06-30 03:27:18 +08:00
|
|
|
|
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(inplace=True),
|
|
|
|
|
|
nn.Conv2d(128, 128, kernel_size=1),
|
2025-06-07 23:45:32 +08:00
|
|
|
|
nn.InstanceNorm2d(128)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
2025-06-08 15:38:56 +08:00
|
|
|
|
# 共享特征提取
|
2025-06-30 03:27:18 +08:00
|
|
|
|
features = self.backbone(x)
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-30 03:27:18 +08:00
|
|
|
|
# 检测器和描述子使用相同的特征图
|
|
|
|
|
|
detection_map = self.detection_head(features)
|
|
|
|
|
|
descriptors = self.descriptor_head(features)
|
2025-06-07 23:45:32 +08:00
|
|
|
|
|
2025-06-08 15:38:56 +08:00
|
|
|
|
return detection_map, descriptors
|