增加角度显示计算,优化CNN架构

This commit is contained in:
jiao77
2025-03-26 22:33:36 +08:00
parent 88ca482d5d
commit 79cec17a50
6 changed files with 27 additions and 23 deletions

View File

@@ -1,15 +1,8 @@
import faiss
import numpy as np
import torch
# 导入 models.rotation_cnn 模块中的 RotationInvariantNet
from models.rotation_cnn import RotationInvariantNet
from models.rotation_cnn import get_rotational_features
# 导入 data_utils 中的 layout_to_tensor 函数(假设该函数存在)
from data_units import layout_to_tensor # 如果 data_utils.py 存在此函数
from data_units import tile_layout
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
from data_units import layout_to_tensor, tile_layout
def main():
# 配置参数(需根据实际调整)
@@ -17,7 +10,6 @@ def main():
target_module_path = "target.png"
large_layout_path = "layout_large.png"
# 加载模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RotationInvariantNet().to(device)
model.load_state_dict(torch.load("rotation_cnn.pth"))
@@ -33,16 +25,27 @@ def main():
# 构建特征索引使用Faiss加速
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
features_db = []
for (x,y,tile) in tiles:
for (x, y, tile) in tiles:
feat = get_rotational_features(model, torch.tensor(tile).to(device))
features_db.append(feat)
index.add(np.stack(features_db))
# 检索相似区域
D, I = index.search(target_feat[np.newaxis,:], k=10)
for idx in I[0]:
x,y,_ = tiles[idx]
print(f"匹配区域坐标: ({x}, {y}), 相似度: {D[0][idx]}")
D, I = index.search(target_feat[np.newaxis, :], k=10)
for idx in I[0]:
x, y, _ = tiles[idx]
# 计算最佳匹配角度的显式计算
min_angle, min_dist = 90, float('inf')
target_vec = target_feat
feat = features_db[idx]
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
if dist < min_dist:
min_dist, min_angle = dist, a * 90
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
if __name__ == "__main__":
main()