"""Locate regions of a large layout image that resemble a target module.

The large layout is split into tiles. Each tile and the target are embedded with a
rotation-invariant CNN, the tile embeddings are indexed with Faiss (exact L2), and
the nearest tiles to the target are printed with their coordinates.
"""

import faiss
import numpy as np
import torch

# NOTE(review): the original comments refer to "data_utils", but the module imported
# here is "data_units" -- confirm which module name is actually correct.
from data_units import layout_to_tensor, tile_layout
from models.rotation_cnn import RotationInvariantNet, get_rotational_features


def _load_model(model_path, device):
    """Build a RotationInvariantNet, load weights from *model_path*, and set eval mode."""
    model = RotationInvariantNet().to(device)
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


def _to_feature_vector(feat):
    """Return *feat* as a flat, C-contiguous float32 NumPy vector (what Faiss requires)."""
    if isinstance(feat, torch.Tensor):
        # Drop autograd history and move off the GPU before handing data to NumPy.
        feat = feat.detach().cpu().numpy()
    return np.ascontiguousarray(feat, dtype=np.float32).reshape(-1)


def _extract_feature(model, data, device):
    """Embed one image array/tensor with the model and return a float32 feature vector."""
    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        feat = get_rotational_features(model, torch.as_tensor(data).to(device))
    return _to_feature_vector(feat)


def _build_index(model, tiles, device):
    """Embed every (x, y, tile) entry and return an exact-L2 Faiss index over them.

    Row i of the index corresponds to tiles[i]. The index dimension is taken from
    the features themselves rather than being hard-coded.
    """
    vectors = np.stack([_extract_feature(model, tile, device) for _, _, tile in tiles])
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index


def main(
    *,
    block_size: int = 64,
    target_module_path: str = "target.png",
    large_layout_path: str = "layout_large.png",
    model_path: str = "rotation_cnn.pth",
    top_k: int = 10,
) -> None:
    """Search *large_layout_path* for the tiles most similar to *target_module_path*.

    Args:
        block_size: side length the target image is resized to before embedding.
        target_module_path: image of the module to look for.
        large_layout_path: image of the full layout to search in.
        model_path: state-dict checkpoint for RotationInvariantNet.
        top_k: maximum number of matches to report.

    Raises:
        ValueError: if the layout yields no tiles, or if the target's feature
            dimension differs from the tiles' feature dimension.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_path, device)

    # Preprocess the target module and embed it.
    target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
    query = _extract_feature(model, target_tensor, device)

    # Preprocess the large layout. Materialize the tiles so they can be indexed by
    # Faiss result ids below, even if tile_layout returns an iterator.
    # NOTE(review): block_size is not passed to tile_layout -- confirm its tile size
    # matches the target's block_size.
    large_layout = layout_to_tensor(large_layout_path)
    tiles = list(tile_layout(large_layout))
    if not tiles:
        raise ValueError(f"tile_layout produced no tiles for {large_layout_path!r}")

    index = _build_index(model, tiles, device)
    if query.shape[0] != index.d:
        raise ValueError(
            f"target feature dim {query.shape[0]} != tile feature dim {index.d}"
        )

    # Never ask for more neighbours than the index holds.
    k = min(top_k, index.ntotal)
    distances, ids = index.search(query[np.newaxis, :], k)

    # distances[0][r] belongs to ids[0][r], so pair them by rank. For IndexFlatL2
    # these are squared L2 distances: smaller means more similar.
    for dist, idx in zip(distances[0], ids[0]):
        if idx < 0:
            # Faiss pads the result with id -1 when fewer than k hits exist.
            continue
        x, y, _ = tiles[idx]
        print(f"匹配区域坐标: ({x}, {y}), 相似度: {dist}")


if __name__ == "__main__":
    main()