增加角度显示计算,优化CNN架构

This commit is contained in:
jiao77
2025-03-26 22:33:36 +08:00
parent 88ca482d5d
commit 79cec17a50
6 changed files with 27 additions and 23 deletions

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
class RotationInvariantNet(nn.Module):
"""轻量级旋转不变特征提取网络"""
def __init__(self, input_channels=1, num_features=64):
def __init__(self, input_channels=1):
super().__init__()
self.cnn = nn.Sequential(
# 基础卷积层
@@ -12,13 +12,14 @@ class RotationInvariantNet(nn.Module):
nn.MaxPool2d(2), # 下采样
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1,1)) # 全局池化获取全局特征
nn.Conv2d(64, 64, kernel_size=3, stride=2), # 更大感受野
nn.AdaptiveAvgPool2d((4,4)), # 全局池化获取全局特征调整输出尺寸为4x4
nn.Flatten(), # 展平为一维向量
nn.Linear(64*16, 128) # 增加全连接层以降低维度到128
)
def forward(self, x):
features = self.cnn(x)
return torch.flatten(features, 1) # 展平为特征向量
return self.cnn(x)
def get_rotational_features(model, input_image):
"""计算输入图像所有旋转角度的特征平均值"""
rotations = [0, 90, 180, 270]