Files
RoRD-Layout-Recognation/models/rord.py
2025-06-08 00:05:19 +08:00

47 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from torchvision import models
class RoRD(nn.Module):
    """Keypoint detector with two parallel descriptor heads on VGG-16 features.

    The module owns two independent pretrained VGG-16 feature extractors:

    * ``backbone_det``  -- VGG-16 up to relu5_3 (feature layers 0-29), feeding
      ``detection_head``, which outputs a single-channel keypoint probability
      map (sigmoid-activated).
    * ``backbone_desc`` -- VGG-16 up to relu4_3 (feature layers 0-22), feeding
      both ``descriptor_head_vanilla`` (D2-Net style) and
      ``descriptor_head_rord`` (rotation-robust), each producing 128-channel
      descriptor maps.

    NOTE(review): the two backbones are truncated at different VGG depths, so
    the detection map and the descriptor maps presumably have different
    spatial resolutions -- confirm that callers account for this.
    """

    def __init__(self):
        super().__init__()
        # Detection backbone: VGG-16 up to relu5_3 (feature layers 0-29).
        self.backbone_det = self._vgg16_features(30)
        # Descriptor backbone: VGG-16 up to relu4_3 (feature layers 0-22).
        self.backbone_desc = self._vgg16_features(23)
        # Detection head: outputs the keypoint probability map.
        self.detection_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1, kernel_size=1),
            nn.Sigmoid(),
        )
        # Vanilla descriptor head (D2-Net style).
        self.descriptor_head_vanilla = self._make_descriptor_head()
        # RoRD descriptor head (rotation-robust).
        self.descriptor_head_rord = self._make_descriptor_head()

    @staticmethod
    def _vgg16_features(num_layers):
        """Return the first ``num_layers`` feature layers of a pretrained VGG-16.

        Uses the ``weights=`` API when the installed torchvision provides
        ``VGG16_Weights``; ``pretrained=True`` is deprecated there and emits a
        warning. Falls back to the legacy argument on older torchvision
        releases that do not have the weights enum.
        """
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` maps to.
            vgg = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            vgg = models.vgg16(pretrained=True)
        return vgg.features[:num_layers]

    @staticmethod
    def _make_descriptor_head():
        """Build one 512 -> 256 -> 128 channel descriptor head.

        Both descriptor heads share this architecture but get their own,
        independently initialised parameters.
        """
        return nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=1),
            nn.InstanceNorm2d(128),
        )

    def forward(self, x):
        """Run the detection and descriptor branches on ``x``.

        Args:
            x: Input tensor fed to both backbones; presumably an
                ``(N, 3, H, W)`` image batch as VGG-16 expects -- TODO confirm
                against callers.

        Returns:
            A tuple ``(detection, desc_vanilla, desc_rord)``: the 1-channel
            sigmoid keypoint map and the two 128-channel descriptor maps.
        """
        # Detection branch.
        features_det = self.backbone_det(x)
        detection = self.detection_head(features_det)
        # Descriptor branch: one shared feature map feeds both heads.
        features_desc = self.backbone_desc(x)
        desc_vanilla = self.descriptor_head_vanilla(features_desc)
        desc_rord = self.descriptor_head_rord(features_desc)
        return detection, desc_vanilla, desc_rord