diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 8b616ca9161..ddbaf4b6cbe 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -250,7 +250,8 @@ def roi_align( rois = convert_boxes_to_roi_format(rois) if not torch.jit.is_scripting(): if ( - not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)) + not _has_ops() + or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu)) ) and is_compile_supported(input.device.type): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) _assert_has_ops()