Skip to content

Commit 06bb661

Browse files
committed
add train dispatch
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b2367ce commit 06bb661

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/llmcompressor/entrypoints/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from llmcompressor.core.session_functions import active_session
99
from llmcompressor.datasets.utils import get_processed_dataset
1010
from llmcompressor.transformers.finetune.trainer import Trainer
11+
from llmcompressor.utils.dev import dispatch_for_generation
1112

1213
from .utils import post_process, pre_process
1314

@@ -63,6 +64,7 @@ def train(**kwargs) -> PreTrainedModel:
6364
)
6465

6566
pre_process(model_args)
67+
dispatch_for_generation(model_args.model) # train is dispatched same as generation
6668

6769
processed_dataset = get_processed_dataset(
6870
dataset_args=dataset_args,

0 commit comments

Comments
 (0)