Skip to content

Commit fae9a9f

Browse files
laurentes authored and copybara-github committed
Update GDA serializer to make use of the new API.
PiperOrigin-RevId: 479408922
1 parent d31fb9e commit fae9a9f

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

lingvo/jax/checkpoints.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from concurrent import futures
1818
import datetime
19+
import functools
1920
import os
2021
import re
2122
from typing import Optional
@@ -24,6 +25,7 @@
2425
from flax import jax_utils
2526
from flax.training import checkpoints
2627
import jax
28+
from jax import sharding
2729
from jax.experimental import maps
2830
from jax.experimental import multihost_utils
2931
from jax.experimental.gda_serialization import serialization as gda_serialization
@@ -440,6 +442,11 @@ def _save_checkpoint_gda(train_state: train_states.TrainState,
440442
checkpoint_step_dir)
441443

442444

445+
@functools.lru_cache(maxsize=128)
def _cached_mesh_pspec_sharding(mesh, pspec):
  """Returns a memoized `MeshPspecSharding` for a (mesh, pspec) pair.

  The result is cached by `functools.lru_cache`, so repeated calls with an
  equal `mesh` and `pspec` reuse the previously constructed object instead of
  building a new one. Both arguments are used as cache keys and therefore
  must be hashable.

  Args:
    mesh: Mesh forwarded as the first argument of
      `sharding.MeshPspecSharding`.
    pspec: Partition spec forwarded as the second argument of
      `sharding.MeshPspecSharding`.

  Returns:
    The `sharding.MeshPspecSharding` instance for `mesh` and `pspec`.
  """
  mesh_sharding = sharding.MeshPspecSharding(mesh, pspec)
  return mesh_sharding
448+
449+
443450
def _restore_checkpoint_gda(
444451
train_state: Optional[train_states.TrainState],
445452
checkpoint_dir: str,
@@ -500,11 +507,11 @@ def _restore_checkpoint_gda(
500507
]
501508
tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)
502509

510+
shardings = [
511+
_cached_mesh_pspec_sharding(global_mesh, s) for s in partition_spec_leaves
512+
]
503513
train_state_gda = gda_serialization.run_deserialization(
504-
[global_mesh] * len(tspecs),
505-
partition_spec_leaves,
506-
tspecs,
507-
global_shapes=global_shapes)
514+
shardings, tspecs, global_shapes=global_shapes)
508515

509516
restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda)
510517
# Barrier across all processes to ensure all restore finish.

0 commit comments

Comments
 (0)