diff --git a/tests/utils/test_device_mesh.py b/tests/utils/test_device_mesh.py index 1c486bdaac..2969e9d06d 100644 --- a/tests/utils/test_device_mesh.py +++ b/tests/utils/test_device_mesh.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest from torchtnt.utils.device_mesh import ( diff --git a/torchtnt/utils/device_mesh.py b/torchtnt/utils/device_mesh.py index 46e224ff23..8f343cffad 100644 --- a/torchtnt/utils/device_mesh.py +++ b/torchtnt/utils/device_mesh.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Optional from torch.distributed.device_mesh import DeviceMesh, init_device_mesh