Commit 06bb661 (1 parent: b2367ce)
src/llmcompressor/entrypoints/train.py
@@ -8,6 +8,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets.utils import get_processed_dataset
 from llmcompressor.transformers.finetune.trainer import Trainer
+from llmcompressor.utils.dev import dispatch_for_generation

 from .utils import post_process, pre_process

@@ -63,6 +64,7 @@ def train(**kwargs) -> PreTrainedModel:
     )

     pre_process(model_args)
+    dispatch_for_generation(model_args.model)  # train is dispatched same as generation

     processed_dataset = get_processed_dataset(
         dataset_args=dataset_args,
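
With this change, the train() entrypoint dispatches the model (per the inline comment, using the same dispatch path as generation) right after pre_process() and before the dataset is prepared, so callers should not need to handle device placement themselves. Below is a minimal usage sketch; the top-level import path and the keyword arguments (model, dataset, output_dir, num_train_epochs) are illustrative assumptions and are not taken from this commit.

    # Hypothetical sketch of exercising the updated train() entrypoint.
    # The import path and argument names are assumptions; consult the
    # llmcompressor documentation for the exact supported kwargs.
    from llmcompressor import train  # assumed public export of entrypoints.train.train

    train(
        model="Qwen/Qwen2.5-0.5B",   # loaded, then dispatched via dispatch_for_generation()
        dataset="open_platypus",     # preprocessed through get_processed_dataset()
        output_dir="./trained-model",
        num_train_epochs=1,
    )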
0 commit comments