Commit f7cf5b7

JKSenthil authored and facebook-github-bot committed
support non-sharded llm inference (#1007)
Summary:
Pull Request resolved: #1007

# Context

Small LLMs (~1B parameters) don't need to be sharded for inference, but our recipes currently assume sharding is required.

# This Diff

1) Enables materialization of params at the TNT level when moving a model from the meta device to CUDA
2) Makes the sharding / global mesh coordinator optional
3) Adds llama3pt2_1b_inference.yaml, which loads the full model on each rank

Reviewed By: rshakoor

Differential Revision: D75823535

fbshipit-source-id: 9d8bd09ca50d032d9144eb2c8ea2f486dfc50dfe
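For context, here is a minimal sketch of how meta-device parameters can be materialized before the final device move. This is an illustrative approximation only, not TorchTNT's actual `materialize_meta_params` implementation; the `to_empty`/`reset_parameters` pattern is an assumption based on common PyTorch practice.

```python
import torch

def materialize_meta_params(module: torch.nn.Module, device: torch.device) -> None:
    """Illustrative sketch: replace meta tensors with real storage on `device`.

    Not the actual TorchTNT implementation; shown only to clarify the idea.
    """
    for submodule in module.modules():
        # Only touch submodules that still hold meta-device parameters.
        if any(p.is_meta for p in submodule.parameters(recurse=False)):
            # Allocate uninitialized storage for this submodule on the target device.
            submodule.to_empty(device=device, recurse=False)
            # Re-run the layer's own initializer where one is defined.
            if hasattr(submodule, "reset_parameters"):
                submodule.reset_parameters()
```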
1 parent f5bf0c1 commit f7cf5b7

File tree

1 file changed (+3, -0)


torchtnt/utils/prepare_module.py

Lines changed: 3 additions & 0 deletions
@@ -728,6 +728,9 @@ def _prepare_module_1d(
     elif isinstance(strategy, FSDP2Strategy):
         module = prepare_fsdp2(module, device, strategy, global_mesh=global_mesh)
     else:
+        # materialize any meta device params
+        materialize_meta_params(module=module, device=device)
+        # then move entire module to device
         module = module.to(device)
 
     if activation_checkpoint_params:
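To show why the materialization step is needed on the unsharded path, here is a hedged end-to-end usage sketch. It uses plain PyTorch plus the illustrative `materialize_meta_params` above (not the real TorchTNT helper), and the toy model is hypothetical:

```python
import torch

# Build the model on the meta device: no real memory is allocated, which is
# how large checkpoints are typically staged before placement.
with torch.device("meta"):
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 4096),
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Calling `model.to(device)` directly on meta tensors raises
# ("Cannot copy out of meta tensor; no data!"), so materialize first,
# mirroring the order of operations in the diff above.
materialize_meta_params(module=model, device=device)
model = model.to(device)
```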
