1 parent 2d87993 commit cb965c9
src/llmcompressor/utils/dev.py
@@ -116,6 +116,14 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
 
 
 def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
+    """
+    Dispatch a model for autoregressive generation. This means that modules are
+    dispatched evenly across available devices and kept onloaded if possible.
+    Removes any HF hooks that may have existed previously.
+
+    :param model: model to dispatch
+    :return: model which is dispatched
+    """
     remove_hook_from_module(model, recurse=True)
     max_memory = get_balanced_memory(
         model,
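For context, here is a minimal usage sketch of the helper this commit documents. It assumes the function is importable from `llmcompressor.utils.dev` (the file touched by this diff) and uses an illustrative model name and prompt that are not part of the commit:

```python
from llmcompressor.utils.dev import dispatch_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical model choice for illustration; any PreTrainedModel should work.
model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Remove any stale HF hooks and balance the model's modules across
# available devices before running autoregressive generation.
model = dispatch_for_generation(model)

# Accelerate's dispatch hooks handle moving activations between devices;
# placing inputs on model.device is assumed to suffice here.
inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0]))
```

Internally, the function relies on accelerate's `get_balanced_memory` to compute an even per-device memory budget, which keeps modules onloaded rather than offloaded to CPU or disk whenever the devices have room.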