一个目标实时检测的模型

This commit is contained in:
jiao77
2025-03-31 14:49:04 +08:00
parent 956805997e
commit a5c63ad0de
6 changed files with 188 additions and 151 deletions

View File

@@ -1,51 +1,70 @@
import faiss
import numpy as np
import torch
from models.rotation_cnn import RotationInvariantNet, get_rotational_features
from data_units import layout_to_tensor, tile_layout
import cv2
import numpy as np
from models.superpoint_custom import SuperPointCustom
def main():
# 配置参数(需根据实际调整)
block_size = 64 # 分块尺寸
target_module_path = "target.png"
large_layout_path = "layout_large.png"
def get_keypoints_from_heatmap(semi, threshold=0.015):
semi = semi.squeeze().cpu().numpy() # [65, H/8, W/8]
prob = cv2.softmax(semi, axis=0)[:-1] # [64, H/8, W/8]
prob = prob.reshape(8, 8, semi.shape[1], semi.shape[2])
prob = prob.transpose(0, 2, 1, 3).reshape(8*semi.shape[1], 8*semi.shape[2]) # [H, W]
keypoints = []
for y in range(prob.shape[0]):
for x in range(prob.shape[1]):
if prob[y, x] > threshold:
keypoints.append(cv2.KeyPoint(x, y, 1))
return keypoints
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RotationInvariantNet().to(device)
model.load_state_dict(torch.load("rotation_cnn.pth"))
def get_descriptors_from_map(desc, keypoints):
desc = desc.squeeze().cpu().numpy() # [256, H/8, W/8]
descriptors = []
scale = 8
for kp in keypoints:
x, y = int(kp.pt[0] / scale), int(kp.pt[1] / scale)
if 0 <= x < desc.shape[2] and 0 <= y < desc.shape[1]:
descriptors.append(desc[:, y, x])
return np.array(descriptors)
def match_and_estimate(layout_path, module_path, model_path, num_channels, device='cuda'):
model = SuperPointCustom(num_channels=num_channels).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
# 预处理目标模块与大版图
target_tensor = layout_to_tensor(target_module_path, (block_size, block_size))
target_feat = get_rotational_features(model, torch.tensor(target_tensor).to(device))
layout = np.load(layout_path) # [C, H, W]
module = np.load(module_path) # [C, H, W]
layout_tensor = torch.from_numpy(layout).float().unsqueeze(0).to(device)
module_tensor = torch.from_numpy(module).float().unsqueeze(0).to(device)
large_layout = layout_to_tensor(large_layout_path)
tiles = tile_layout(large_layout)
with torch.no_grad():
semi_layout, desc_layout = model(layout_tensor)
semi_module, desc_module = model(module_tensor)
kp_layout = get_keypoints_from_heatmap(semi_layout)
desc_layout = get_descriptors_from_map(desc_layout, kp_layout)
kp_module = get_keypoints_from_heatmap(semi_module)
desc_module = get_descriptors_from_map(desc_module, kp_module)
# 构建特征索引使用Faiss加速
index = faiss.IndexFlatL2(64) # 特征维度由模型决定
features_db = []
for (x, y, tile) in tiles:
feat = get_rotational_features(model, torch.tensor(tile).to(device))
features_db.append(feat)
index.add(np.stack(features_db))
bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
matches = bf.match(desc_module, desc_layout)
matches = sorted(matches, key=lambda x: x.distance)
# 检索相似区域
D, I = index.search(target_feat[np.newaxis, :], k=10)
src_pts = np.float32([kp_module[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
dst_pts = np.float32([kp_layout[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
for idx in I[0]:
x, y, _ = tiles[idx]
h, w = module.shape[1], module.shape[2]
corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]]).reshape(-1, 1, 2)
transformed_corners = cv2.perspectiveTransform(corners, H)
x_min, y_min = np.min(transformed_corners, axis=0).ravel().astype(int)
x_max, y_max = np.max(transformed_corners, axis=0).ravel().astype(int)
theta = np.arctan2(H[1, 0], H[0, 0]) * 180 / np.pi
# 计算最佳匹配角度的显式计算
min_angle, min_dist = 90, float('inf')
target_vec = target_feat
feat = features_db[idx]
for a in [0, 1, 2, 3]: # 代表0°、90°、180°、270°
rotated_feat = np.rot90(feat.reshape(block_size, block_size), k=a)
dist = np.linalg.norm(target_vec - rotated_feat.flatten())
if dist < min_dist:
min_dist, min_angle = dist, a * 90
print(f"Matched region: [{x_min}, {y_min}, {x_max}, {y_max}], Rotation: {theta:.2f} degrees")
return x_min, y_min, x_max, y_max, theta
print(f"坐标({x},{y}), 最佳旋转方向{min_angle}度,距离: {min_dist}")
if __name__ == "__main__":
main()
layout_path = "data/large_layout.npy"
module_path = "data/small_module.npy"
model_path = "superpoint_custom_model.pth"
num_channels = 3 # 替换为实际通道数
match_and_estimate(layout_path, module_path, model_path, num_channels)