一个目标实时检测的模型
This commit is contained in:
@@ -1,31 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class RotationInvariantNet(nn.Module):
|
||||
"""轻量级旋转不变特征提取网络"""
|
||||
def __init__(self, input_channels=1):
|
||||
super().__init__()
|
||||
self.cnn = nn.Sequential(
|
||||
# 基础卷积层
|
||||
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2), # 下采样
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
|
||||
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4
|
||||
nn.Flatten(), # 展平为一维向量
|
||||
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.cnn(x)
|
||||
def get_rotational_features(model, input_image):
|
||||
"""计算输入图像所有旋转角度的特征平均值"""
|
||||
rotations = [0, 90, 180, 270]
|
||||
features_list = []
|
||||
for angle in rotations:
|
||||
rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3])
|
||||
feat = model(rotated_img.unsqueeze(0))
|
||||
features_list.append(feat)
|
||||
return torch.mean(torch.stack(features_list), dim=0).detach().numpy()
|
||||
47
models/superpoint_custom.py
Normal file
47
models/superpoint_custom.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class SuperPointCustom(nn.Module):
|
||||
def __init__(self, num_channels=3): # num_channels 为版图通道数
|
||||
super(SuperPointCustom, self).__init__()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
|
||||
# 编码器
|
||||
self.conv1a = nn.Conv2d(num_channels, c1, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
||||
# 检测头
|
||||
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
||||
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) # 65 = 8x8 + dustbin
|
||||
# 描述符头
|
||||
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
||||
self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
# 编码器
|
||||
x = self.relu(self.conv1a(x))
|
||||
x = self.relu(self.conv1b(x))
|
||||
x = self.pool(x)
|
||||
x = self.relu(self.conv2a(x))
|
||||
x = self.relu(self.conv2b(x))
|
||||
x = self.pool(x)
|
||||
x = self.relu(self.conv3a(x))
|
||||
x = self.relu(self.conv3b(x))
|
||||
x = self.pool(x)
|
||||
x = self.relu(self.conv4a(x))
|
||||
x = self.relu(self.conv4b(x))
|
||||
# 检测头
|
||||
cPa = self.relu(self.convPa(x))
|
||||
semi = self.convPb(cPa) # [B, 65, H/8, W/8]
|
||||
# 描述符头
|
||||
cDa = self.relu(self.convDa(x))
|
||||
desc = self.convDb(cDa) # [B, 256, H/8, W/8]
|
||||
desc = F.normalize(desc, p=2, dim=1) # L2归一化
|
||||
return semi, desc
|
||||
Reference in New Issue
Block a user