import torch import torch.nn as nn class RotationInvariantNet(nn.Module): """轻量级旋转不变特征提取网络""" def __init__(self, input_channels=1): super().__init__() self.cnn = nn.Sequential( # 基础卷积层 nn.Conv2d(input_channels, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 下采样 nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野 nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征,调整输出尺寸为4x4 nn.Flatten(), # 展平为一维向量 nn.Linear(64*16, 128) # 增加全连接层以降低维度到128 ) def forward(self, x): return self.cnn(x) def get_rotational_features(model, input_image): """计算输入图像所有旋转角度的特征平均值""" rotations = [0, 90, 180, 270] features_list = [] for angle in rotations: rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3]) feat = model(rotated_img.unsqueeze(0)) features_list.append(feat) return torch.mean(torch.stack(features_list), dim=0).detach().numpy()