2025-03-25 01:42:26 +08:00
|
|
|
|
import faiss
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import torch
|
2025-03-26 22:33:36 +08:00
|
|
|
|
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
|
|
|
|
|
|
from data_units import layout_to_tensor, tile_layout
|
2025-03-25 01:42:26 +08:00
|
|
|
|
|
|
|
|
|
|
def _to_feature_vector(feat):
    """Convert a feature (torch tensor or array-like) to a flat float32 NumPy vector.

    Faiss only accepts C-contiguous float32 NumPy arrays, and NumPy cannot
    convert a CUDA tensor (or one that requires grad) directly, so tensors are
    detached and moved to the CPU first.
    """
    if isinstance(feat, torch.Tensor):
        feat = feat.detach().cpu().numpy()
    return np.ascontiguousarray(np.asarray(feat, dtype=np.float32).reshape(-1))


def _extract_features(model, array, device):
    """Run ``get_rotational_features`` on one layout array and return a flat float32 vector."""
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # as_tensor avoids the copy-construct warning torch.tensor emits for tensor inputs.
        feat = get_rotational_features(model, torch.as_tensor(array).to(device))
    return _to_feature_vector(feat)


def _best_rotation(model, tile, target_vec, device):
    """Return ``(angle_degrees, distance)`` for the quarter-turn of *tile* closest to the target.

    The tile itself is rotated by 0/90/180/270 degrees in its last two axes.
    Features are re-extracted for each rotation and compared to *target_vec*
    by L2 distance. Rotating the feature vector would be meaningless, because
    it is not an image.
    NOTE(review): assumes the tile's last two axes are spatial (H, W) — confirm
    against ``tile_layout``. If the extracted features are fully rotation-
    invariant, all four distances are (near) equal and 0 degrees is reported.
    """
    tile_t = torch.as_tensor(tile)
    best_angle, best_dist = 0, float("inf")
    for quarter_turns in range(4):  # 0°, 90°, 180°, 270°
        # torch.rot90 returns a fresh tensor, so no negative-stride NumPy views reach torch.
        rotated = torch.rot90(tile_t, k=quarter_turns, dims=(-2, -1))
        dist = float(np.linalg.norm(target_vec - _extract_features(model, rotated, device)))
        if dist < best_dist:
            best_angle, best_dist = quarter_turns * 90, dist
    return best_angle, best_dist


def main(
    *,
    block_size=64,
    target_module_path="target.png",
    large_layout_path="layout_large.png",
    model_path="rotation_cnn.pth",
    top_k=10,
):
    """Find the tiles of a large layout that best match a target module.

    Loads the rotation CNN and embeds the target module and every tile
    returned by ``tile_layout``. It then retrieves the nearest tiles with an
    exact L2 Faiss index. For each hit it prints the tile coordinates and the
    quarter-turn rotation whose features are closest to the target's.

    Args:
        block_size: Side length the target module is resized to (passed to
            ``layout_to_tensor``). NOTE(review): presumably this should match
            the tile size produced by ``tile_layout`` — confirm.
        target_module_path: Image of the module to search for.
        large_layout_path: Image of the large layout to search in.
        model_path: State-dict checkpoint for ``RotationInvariantNet``.
        top_k: Maximum number of matches to report.

    Raises:
        ValueError: if the target's feature dimension differs from the tiles'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RotationInvariantNet().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess the target module and the large layout.
    target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
    target_vec = _extract_features(model, target_tensor, device)

    large_layout = layout_to_tensor(large_layout_path)
    # Materialize so the tiles can be iterated and then indexed by search results.
    tiles = list(tile_layout(large_layout))
    if not tiles:
        print("未从大版图中得到任何分块,跳过检索")
        return

    # Build the feature index (Faiss accelerates the nearest-neighbour search).
    features_db = np.stack([_extract_features(model, tile, device) for _, _, tile in tiles])
    # Size the index from the real feature dimension instead of a hard-coded value.
    index = faiss.IndexFlatL2(features_db.shape[1])
    index.add(features_db)
    if target_vec.shape[0] != index.d:
        raise ValueError(
            f"target feature dim {target_vec.shape[0]} != tile feature dim {index.d}"
        )

    # Retrieve similar regions; never ask for more neighbours than exist.
    k = min(top_k, index.ntotal)
    _, neighbors = index.search(target_vec[np.newaxis, :], k)

    for idx in neighbors[0]:
        if idx < 0:  # Faiss pads missing results with -1; never index tiles[-1] by accident.
            continue
        x, y, tile = tiles[idx]
        # Explicitly determine the best-matching rotation angle for this tile.
        min_angle, min_dist = _best_rotation(model, tile, target_vec, device)
        print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
|
2025-03-25 01:42:26 +08:00
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the layout search only when executed directly,
    # not when this module is imported.
    main()
|