Files
RoRD-Layout-Recognation/models/rord.py

49 lines
1.7 KiB
Python
Raw Normal View History

2025-06-08 15:38:56 +08:00
# models/rord.py
2025-06-07 23:45:32 +08:00
import torch
import torch.nn as nn
from torchvision import models
class RoRD(nn.Module):
    """RoRD keypoint detector / descriptor network.

    A truncated VGG16 backbone is shared by two convolutional heads, so the
    detection map and the descriptor map are computed from the same feature
    tensor (one backbone pass) and always have the same spatial size.
    """

    def __init__(self):
        super().__init__()

        # Call vgg16() with no arguments: in every torchvision release the
        # default is random initialisation (``pretrained=False`` before 0.13,
        # ``weights=None`` from 0.13 on). Passing ``pretrained=False``
        # explicitly is deprecated since torchvision 0.13, emits a warning,
        # and may be removed in a future release.
        vgg16_features = models.vgg16().features

        # Shared backbone: VGG16 ``features`` children 0..22, i.e. conv1_1 up
        # to and including relu4_3 (pool4 at index 23 is dropped). It outputs
        # 512 channels after three 2x2 max-pools, i.e. at 1/8 of the input
        # resolution.
        self.backbone = nn.Sequential(*list(vgg16_features.children())[:23])

        # Detection head: 512 -> 256 -> 128 -> 1 channels; the final sigmoid
        # maps every spatial location to a score in (0, 1).
        self.detection_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Descriptor head: 512 -> 256 -> 128 -> 128 channels.
        # NOTE: InstanceNorm2d (default affine=False, so no learnable
        # parameters) standardises each channel over the spatial dimensions
        # of each sample; it does NOT L2-normalise the per-pixel descriptor
        # vectors.
        self.descriptor_head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1),
            nn.InstanceNorm2d(128),
        )

    def forward(self, x):
        """Run the shared backbone once and feed its output to both heads.

        Args:
            x: Image batch of shape (N, 3, H, W); the first VGG16 convolution
                expects 3 input channels.

        Returns:
            Tuple ``(detection_map, descriptors)`` with shapes
            (N, 1, H // 8, W // 8) and (N, 128, H // 8, W // 8).
        """
        # Compute features once; both heads consume the same tensor, which
        # guarantees matching spatial dimensions.
        features = self.backbone(x)
        detection_map = self.detection_head(features)
        descriptors = self.descriptor_head(features)
        return detection_map, descriptors