52
52
# Testing only
53
53
'0x0056' ,
54
54
'0x0062' ,
55
+ # TPU 7x
56
+ '0x0076'
55
57
]
56
58
57
59
@@ -188,7 +190,10 @@ def version() -> int:
188
190
except requests .HTTPError as e :
189
191
raise EnvironmentError ('Failed to get TPU metadata' ) from e
190
192
191
- match = re .match (r'^v(\d)([A-Za-z]?){7}-(\d+)$' , env [xenv .ACCELERATOR_TYPE ])
193
+ match = re .match (r'^(?:v|tpu)(\d)([A-Za-z]?){7}-(\d+)$' ,
194
+ env [xenv .ACCELERATOR_TYPE ])
195
+ if not match :
196
+ raise EnvironmentError ('Failed to parse TPU version from metadata' )
192
197
return int (match .groups ()[0 ])
193
198
194
199
@@ -254,7 +259,8 @@ def configure_topology(local_rank: int,
254
259
tpu_env = get_tpu_env ()
255
260
256
261
accelerator_type = tpu_env [xenv .ACCELERATOR_TYPE ]
257
- if version () >= 4 :
262
+ tpu_version = version ()
263
+ if tpu_version >= 4 :
258
264
# Process bounds with 4 chips per process
259
265
default_process_bounds = MeshShape .from_string (
260
266
tpu_env [xenv .TPU_PROCESS_BOUNDS ])
@@ -270,8 +276,11 @@ def configure_topology(local_rank: int,
270
276
process_bounds = default_process_bounds * chips_per_process
271
277
272
278
os .environ .setdefault (xenv .TPU_CHIPS_PER_PROCESS_BOUNDS , '1,1,1' )
273
- os .environ .setdefault (xenv .TPU_PROCESS_BOUNDS ,
274
- ',' .join (str (dim ) for dim in process_bounds ))
279
+ process_bounds_str = ',' .join (str (dim ) for dim in process_bounds )
280
+ if tpu_version == 7 :
281
+ process_bounds_str += ',2'
282
+
283
+ os .environ .setdefault (xenv .TPU_PROCESS_BOUNDS , process_bounds_str )
275
284
276
285
# Assume each TPU has the same number of local processes with the same ports
277
286
worker_id = int (tpu_env [xenv .WORKER_ID ])
0 commit comments