update rord.py
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -8,3 +8,6 @@ wheels/
|
|||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv
|
||||||
|
|
||||||
|
# Model Path File
|
||||||
|
model.path/
|
||||||
@@ -16,7 +16,7 @@ class RoRD(nn.Module):
|
|||||||
vgg16_features = models.vgg16(pretrained=False).features
|
vgg16_features = models.vgg16(pretrained=False).features
|
||||||
|
|
||||||
# 共享骨干网络 - 只使用到 relu4_3,确保特征图尺寸一致
|
# 共享骨干网络 - 只使用到 relu4_3,确保特征图尺寸一致
|
||||||
self.backbone = vgg16_features[:23] # 到 relu4_3
|
self.backbone = nn.Sequential(*list(vgg16_features.children())[:23])
|
||||||
|
|
||||||
# 检测头
|
# 检测头
|
||||||
self.detection_head = nn.Sequential(
|
self.detection_head = nn.Sequential(
|
||||||
|
|||||||
Reference in New Issue
Block a user