initial commit
This commit is contained in:
48
inference.py
Normal file
48
inference.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
# 导入 models.rotation_cnn 模块中的 RotationInvariantNet 类
|
||||
from models.rotation_cnn import RotationInvariantNet
|
||||
|
||||
from models.rotation_cnn import get_rotational_features
|
||||
|
||||
# 导入 data_utils 中的 layout_to_tensor 函数(假设该函数存在)
|
||||
from data_units import layout_to_tensor # 如果 data_utils.py 存在此函数
|
||||
|
||||
from data_units import tile_layout
|
||||
|
||||
def main():
|
||||
# 配置参数(需根据实际调整)
|
||||
block_size = 64 # 分块尺寸
|
||||
target_module_path = "target.png"
|
||||
large_layout_path = "layout_large.png"
|
||||
|
||||
# 加载模型
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = RotationInvariantNet().to(device)
|
||||
model.load_state_dict(torch.load("rotation_cnn.pth"))
|
||||
model.eval()
|
||||
|
||||
# 预处理目标模块与大版图
|
||||
target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
|
||||
target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device))
|
||||
|
||||
large_layout = layout_to_tensor(large_layout_path)
|
||||
tiles = tile_layout(large_layout)
|
||||
|
||||
# 构建特征索引(使用Faiss加速)
|
||||
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
|
||||
features_db = []
|
||||
for (x,y,tile) in tiles:
|
||||
feat = get_rotational_features(model, torch.tensor(tile).to(device))
|
||||
features_db.append(feat)
|
||||
index.add(np.stack(features_db))
|
||||
|
||||
# 检索相似区域
|
||||
D, I = index.search(target_feat[np.newaxis,:], k=10)
|
||||
for idx in I[0]:
|
||||
x,y,_ = tiles[idx]
|
||||
print(f"匹配区域坐标: ({x}, {y}), 相似度: {D[0][idx]}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user