|
16 | 16 |
|
17 | 17 | from concurrent import futures |
18 | 18 | import datetime |
| 19 | +import functools |
19 | 20 | import os |
20 | 21 | import re |
21 | 22 | from typing import Optional |
|
24 | 25 | from flax import jax_utils |
25 | 26 | from flax.training import checkpoints |
26 | 27 | import jax |
| 28 | +from jax import sharding |
27 | 29 | from jax.experimental import maps |
28 | 30 | from jax.experimental import multihost_utils |
29 | 31 | from jax.experimental.gda_serialization import serialization as gda_serialization |
@@ -440,6 +442,11 @@ def _save_checkpoint_gda(train_state: train_states.TrainState, |
440 | 442 | checkpoint_step_dir) |
441 | 443 |
|
442 | 444 |
|
| 445 | +@functools.lru_cache() |
| 446 | +def _cached_mesh_pspec_sharding(mesh, pspec): |
| 447 | + return sharding.MeshPspecSharding(mesh, pspec) |
| 448 | + |
| 449 | + |
443 | 450 | def _restore_checkpoint_gda( |
444 | 451 | train_state: Optional[train_states.TrainState], |
445 | 452 | checkpoint_dir: str, |
@@ -500,11 +507,11 @@ def _restore_checkpoint_gda( |
500 | 507 | ] |
501 | 508 | tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths) |
502 | 509 |
|
| 510 | + shardings = [ |
| 511 | + _cached_mesh_pspec_sharding(global_mesh, s) for s in partition_spec_leaves |
| 512 | + ] |
503 | 513 | train_state_gda = gda_serialization.run_deserialization( |
504 | | - [global_mesh] * len(tspecs), |
505 | | - partition_spec_leaves, |
506 | | - tspecs, |
507 | | - global_shapes=global_shapes) |
| 514 | + shardings, tspecs, global_shapes=global_shapes) |
508 | 515 |
|
509 | 516 | restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda) |
510 | 517 | # Barrier across all processes to ensure all restore finish. |
|
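For context on this change: the old call passed run_deserialization a list of meshes plus a parallel list of PartitionSpecs, while the new call passes a single list of Sharding objects. _cached_mesh_pspec_sharding memoizes their construction with functools.lru_cache, since many train-state leaves typically share the same (mesh, pspec) pair. Below is a minimal standalone sketch of that caching pattern; the mesh shape and the axis names ('data', 'model') are hypothetical, and MeshPspecSharding was renamed NamedSharding in later JAX releases.

import functools

import jax
from jax import sharding
from jax.experimental import maps
from jax.experimental import pjit
import numpy as np


@functools.lru_cache()
def _cached_mesh_pspec_sharding(mesh, pspec):
  # Mesh and PartitionSpec are both hashable, so they can serve as cache
  # keys; equal (mesh, pspec) pairs get back the same Sharding object.
  return sharding.MeshPspecSharding(mesh, pspec)


# Hypothetical 2-D mesh over the available devices.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = maps.Mesh(devices, ('data', 'model'))
pspec = pjit.PartitionSpec('data')  # shard the leading dim over 'data'

s1 = _cached_mesh_pspec_sharding(mesh, pspec)
s2 = _cached_mesh_pspec_sharding(mesh, pspec)
assert s1 is s2  # the lru_cache returns one shared Sharding instance

The memoization is safe because functools.lru_cache only requires hashable arguments, and both Mesh and PartitionSpec satisfy that.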