import torch
import torch.nn as nn
import torch.nn.functional as F


class SuperPointCustom(nn.Module):
    """SuperPoint-style keypoint detector and descriptor network.

    A shared encoder of stacked 3x3 convolutions with three 2x2 max-pools
    downsamples the input by 8 in each spatial dimension and feeds two heads:

    * a detection head producing 65 logits per 8x8 cell (64 pixel positions
      plus one "dustbin" bin), and
    * a descriptor head producing a dense descriptor map that is
      L2-normalized along the channel dimension.

    Submodule attribute names (``conv1a`` ... ``convDb``) are part of the
    state-dict layout and are intentionally kept stable.
    """

    def __init__(self, num_channels=3, *, descriptor_dim=256):
        """Build the encoder and both heads.

        Args:
            num_channels: Number of channels in the input layout image.
            descriptor_dim: Length of each output descriptor vector.
                Keyword-only; the default of 256 matches the previously
                hard-coded value.

        Raises:
            ValueError: If ``descriptor_dim`` is smaller than 1.
        """
        super().__init__()
        if descriptor_dim < 1:
            raise ValueError(
                f"descriptor_dim must be a positive integer, got {descriptor_dim!r}"
            )

        # A single stateless ReLU / pool instance is safely reused everywhere.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
        d1 = descriptor_dim

        # Encoder: four stages of two 3x3 convs each; forward() pools after
        # the first three stages, giving an overall stride of 8.
        self.conv1a = nn.Conv2d(num_channels, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        # Detection head.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        # 65 = 8x8 positions within a cell + 1 dustbin channel.
        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)

        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the encoder and both heads.

        Args:
            x: Input tensor of shape [B, num_channels, H, W].

        Returns:
            A tuple ``(semi, desc)``:

            * ``semi``: raw (un-softmaxed) detector logits of shape
              [B, 65, H // 8, W // 8].
            * ``desc``: descriptors of shape
              [B, descriptor_dim, H // 8, W // 8], L2-normalized over dim 1.

            Spatial sizes are floored because each MaxPool2d(2, 2) floors.
        """
        # Encoder.
        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Detection head.
        cPa = self.relu(self.convPa(x))
        semi = self.convPb(cPa)  # [B, 65, H/8, W/8]

        # Descriptor head.
        cDa = self.relu(self.convDa(x))
        desc = self.convDb(cDa)  # [B, descriptor_dim, H/8, W/8]
        desc = F.normalize(desc, p=2, dim=1)  # unit L2 norm per descriptor
        return semi, desc