import torch import torch.nn as nn class RotationInvariantNet(nn.Module): """轻量级旋转不变特征提取网络""" def __init__(self, input_channels=1, num_features=64): super().__init__() self.cnn = nn.Sequential( # 基础卷积层 nn.Conv2d(input_channels, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2), # 下采样 nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1,1)) # 全局池化获取全局特征 ) def forward(self, x): features = self.cnn(x) return torch.flatten(features, 1) # 展平为特征向量 def get_rotational_features(model, input_image): """计算输入图像所有旋转角度的特征平均值""" rotations = [0, 90, 180, 270] features_list = [] for angle in rotations: rotated_img = torch.rot90(input_image, k=angle//90, dims=[2,3]) feat = model(rotated_img.unsqueeze(0)) features_list.append(feat) return torch.mean(torch.stack(features_list), dim=0).detach().numpy()