Skip to content

Commit e80063b

Browse files
authored
Optimize decode_outputs for OpenVINO (Megvii-BaseDetection#1535)
Avoid ScatterND ops that will cause error in openvino model optimizer.
1 parent d942239 commit e80063b

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

yolox/models/yolo_head.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,7 @@ def decode_outputs(self, outputs, dtype):
246246
grids = torch.cat(grids, dim=1).type(dtype)
247247
strides = torch.cat(strides, dim=1).type(dtype)
248248

249-
outputs[..., :2] = (outputs[..., :2] + grids) * strides
250-
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
249+
outputs = torch.cat([(outputs[..., 0:2] + grids) * strides, torch.exp(outputs[..., 2:4]) * strides, outputs[..., 4:]], dim=-1)
251250
return outputs
252251

253252
def get_losses(

0 commit comments

Comments
 (0)