Skip to content

Commit b8af35e

Browse files
No public description
PiperOrigin-RevId: 713639384
1 parent 2ca4212 commit b8af35e

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

official/projects/pix2seq/modeling/pix2seq_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,17 @@ def checkpoint_items(
327327
# For backward-compatibility with prior checkpoints, the first backbone
328328
# should be named "backbone" and the second one should be named
329329
# "backbone_2", etc.
330-
items = dict(backbone=self.backbones[0], transformer=self.transformer)
330+
items = dict(
331+
backbone=self.backbones[0],
332+
transformer=self.transformer,
333+
stem_projection=self._stem_projections[0],
334+
stem_ln=self._stem_lns[0],
335+
)
331336
for i in range(1, len(self.backbones)):
332337
items[f"backbone_{i+1}"] = self.backbones[i]
338+
items[f"stem_projection_{i+1}"] = self._stem_projections[i]
339+
items[f"stem_ln_{i+1}"] = self._stem_lns[i]
340+
333341
return items
334342

335343
def _generate_image_mask(

0 commit comments

Comments
 (0)