2025-03-25 01:42:26 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
class RotationInvariantNet(nn.Module):
|
|
|
|
|
|
"""轻量级旋转不变特征提取网络"""
|
2025-03-26 22:33:36 +08:00
|
|
|
|
def __init__(self, input_channels=1):
|
2025-03-25 01:42:26 +08:00
|
|
|
|
super().__init__()
|
|
|
|
|
|
self.cnn = nn.Sequential(
|
|
|
|
|
|
# 基础卷积层
|
|
|
|
|
|
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(),
|
|
|
|
|
|
nn.MaxPool2d(2), # 下采样
|
|
|
|
|
|
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
|
|
|
|
|
nn.ReLU(),
|
2025-03-26 22:33:36 +08:00
|
|
|
|
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
|
|
|
|
|
|
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4
|
|
|
|
|
|
nn.Flatten(), # 展平为一维向量
|
|
|
|
|
|
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
|
2025-03-25 01:42:26 +08:00
|
|
|
|
)
|
2025-03-26 22:33:36 +08:00
|
|
|
|
|
2025-03-25 01:42:26 +08:00
|
|
|
|
def forward(self, x):
|
2025-03-26 22:33:36 +08:00
|
|
|
|
return self.cnn(x)
|
2025-03-25 01:42:26 +08:00
|
|
|
|
def get_rotational_features(model, input_image):
|
|
|
|
|
|
"""计算输入图像所有旋转角度的特征平均值"""
|
|
|
|
|
|
rotations = [0, 90, 180, 270]
|
|
|
|
|
|
features_list = []
|
|
|
|
|
|
for angle in rotations:
|
|
|
|
|
|
rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3])
|
|
|
|
|
|
feat = model(rotated_img.unsqueeze(0))
|
|
|
|
|
|
features_list.append(feat)
|
|
|
|
|
|
return torch.mean(torch.stack(features_list), dim=0).detach().numpy()
|