一个目标实时检测的模型
This commit is contained in:
97
inference.py
97
inference.py
@@ -1,51 +1,70 @@
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
|
||||
from data_units import layout_to_tensor, tile_layout
|
||||
import cv2
|
||||
import numpy as np
|
||||
from models.superpoint_custom import SuperPointCustom
|
||||
|
||||
def main():
|
||||
# 配置参数(需根据实际调整)
|
||||
block_size = 64 # 分块尺寸
|
||||
target_module_path = "target.png"
|
||||
large_layout_path = "layout_large.png"
|
||||
def get_keypoints_from_heatmap(semi, threshold=0.015):
|
||||
semi = semi.squeeze().cpu().numpy() # [65, H/8, W/8]
|
||||
prob = cv2.softmax(semi, axis=0)[:-1] # [64, H/8, W/8]
|
||||
prob = prob.reshape(8, 8, semi.shape[1], semi.shape[2])
|
||||
prob = prob.transpose(0, 2, 1, 3).reshape(8*semi.shape[1], 8*semi.shape[2]) # [H, W]
|
||||
keypoints = []
|
||||
for y in range(prob.shape[0]):
|
||||
for x in range(prob.shape[1]):
|
||||
if prob[y, x] > threshold:
|
||||
keypoints.append(cv2.KeyPoint(x, y, 1))
|
||||
return keypoints
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = RotationInvariantNet().to(device)
|
||||
model.load_state_dict(torch.load("rotation_cnn.pth"))
|
||||
def get_descriptors_from_map(desc, keypoints):
|
||||
desc = desc.squeeze().cpu().numpy() # [256, H/8, W/8]
|
||||
descriptors = []
|
||||
scale = 8
|
||||
for kp in keypoints:
|
||||
x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale)
|
||||
if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]:
|
||||
descriptors.append(desc[:, y, x])
|
||||
return np.array(descriptors)
|
||||
|
||||
def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'):
|
||||
model = SuperPointCustom(num_channels=num_channels).to(device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
model.eval()
|
||||
|
||||
# 预处理目标模块与大版图
|
||||
target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
|
||||
target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device))
|
||||
layout = np.load(layout_path) # [C, H, W]
|
||||
module = np.load(module_path) # [C, H, W]
|
||||
layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device)
|
||||
module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device)
|
||||
|
||||
large_layout = layout_to_tensor(large_layout_path)
|
||||
tiles = tile_layout(large_layout)
|
||||
with torch.no_grad():
|
||||
semi_layout, desc_layout = model(layout_tensor)
|
||||
semi_module, desc_module = model(module_tensor)
|
||||
|
||||
kp_layout = get_keypoints_from_heatmap(semi_layout)
|
||||
desc_layout = get_descriptors_from_map(desc_layout, kp_layout)
|
||||
kp_module = get_keypoints_from_heatmap(semi_module)
|
||||
desc_module = get_descriptors_from_map(desc_module, kp_module)
|
||||
|
||||
# 构建特征索引(使用Faiss加速)
|
||||
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
|
||||
features_db = []
|
||||
for (x, y, tile) in tiles:
|
||||
feat = get_rotational_features(model, torch.tensor(tile).to(device))
|
||||
features_db.append(feat)
|
||||
index.add(np.stack(features_db))
|
||||
bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
|
||||
matches = bf.match(desc_module, desc_layout)
|
||||
matches = sorted(matches, key=lambda x: x.distance)
|
||||
|
||||
# 检索相似区域
|
||||
D, I = index.search(target_feat[np.newaxis, :], k=10)
|
||||
src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||
dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
|
||||
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
||||
|
||||
for idx in I[0]:
|
||||
x, y, _ = tiles[idx]
|
||||
h, w = module.shape[1], module.shape[2]
|
||||
corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2)
|
||||
transformed_corners = cv2.perspectiveTransform(corners, H)
|
||||
x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int)
|
||||
x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int)
|
||||
theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi
|
||||
|
||||
# 计算最佳匹配角度的显式计算
|
||||
min_angle, min_dist = 90, float('inf')
|
||||
target_vec = target_feat
|
||||
feat = features_db[idx]
|
||||
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
|
||||
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
|
||||
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
|
||||
if dist < min_dist:
|
||||
min_dist, min_angle = dist, a * 90
|
||||
print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees")
|
||||
return x_min, y_min, x_max, y_max, theta
|
||||
|
||||
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
layout_path = "data/large_layout.npy"
|
||||
module_path = "data/small_module.npy"
|
||||
model_path = "superpoint_custom_model.pth"
|
||||
num_channels = 3 # 替换为实际通道数
|
||||
match_and_estimate(layout_path, module_path, model_path, num_channels)
|
||||
Reference in New Issue
Block a user