Skip to content

Commit 163193e

Browse files
authored
Set environment variables for tpu7x (#9586)
1 parent 5522c69 commit 163193e

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

torch_xla/_internal/tpu.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,8 @@
5252
# Testing only
5353
'0x0056',
5454
'0x0062',
55+
# TPU 7x
56+
'0x0076'
5557
]
5658

5759

@@ -188,7 +190,10 @@ def version() -> int:
188190
except requests.HTTPError as e:
189191
raise EnvironmentError('Failed to get TPU metadata') from e
190192

191-
match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE])
193+
match = re.match(r'^(?:v|tpu)(\d)([A-Za-z]?){7}-(\d+)$',
194+
env[xenv.ACCELERATOR_TYPE])
195+
if not match:
196+
raise EnvironmentError('Failed to parse TPU version from metadata')
192197
return int(match.groups()[0])
193198

194199

@@ -254,7 +259,8 @@ def configure_topology(local_rank: int,
254259
tpu_env = get_tpu_env()
255260

256261
accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE]
257-
if version() >= 4:
262+
tpu_version = version()
263+
if tpu_version >= 4:
258264
# Process bounds with 4 chips per process
259265
default_process_bounds = MeshShape.from_string(
260266
tpu_env[xenv.TPU_PROCESS_BOUNDS])
@@ -270,8 +276,11 @@ def configure_topology(local_rank: int,
270276
process_bounds = default_process_bounds * chips_per_process
271277

272278
os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1')
273-
os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS,
274-
','.join(str(dim) for dim in process_bounds))
279+
process_bounds_str = ','.join(str(dim) for dim in process_bounds)
280+
if tpu_version == 7:
281+
process_bounds_str += ',2'
282+
283+
os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, process_bounds_str)
275284

276285
# Assume each TPU has the same number of local processes with the same ports
277286
worker_id = int(tpu_env[xenv.WORKER_ID])

0 commit comments

Comments
 (0)