Skip to content

Commit 52c1399

Browse files
authored
DetectMultiBackend() return device update (#6958)
Fixes ONNX validation that returns outputs on CPU.
1 parent c84dd27 commit 52c1399

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

models/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,8 @@ def forward(self, im, augment=False, visualize=False, val=False):
458458
y = (y.astype(np.float32) - zero_point) * scale # re-scale
459459
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
460460

461-
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
461+
if isinstance(y, np.ndarray):
462+
y = torch.tensor(y, device=self.device)
462463
return (y, []) if val else y
463464

464465
def warmup(self, imgsz=(1, 3, 640, 640)):

0 commit comments

Comments
 (0)