Commit d980f99

Simplify call to torch._dynamo.optimize()
config_patches arg is no longer required with latest pt2 nightlies, per @aviros
1 parent: 9ef5b19

File tree

1 file changed: +5 −10 lines changed
  • server/text_generation_server/models


server/text_generation_server/models/model.py

Lines changed: 5 additions & 10 deletions
@@ -86,14 +86,7 @@ def count_kernels(guard):
             self.n_kernels += 1

         compiled_forward = torch._dynamo.optimize(
-            lambda model, inputs: compile_fx(
-                model,
-                inputs,
-                config_patches={
-                    "triton.cudagraphs": False,
-                    "size_asserts": False,
-                },
-            ),
+            compile_fx,
             dynamic=True,
             guard_fail_fn=count_kernels,
         )(self.model.forward)
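Per the commit message, the config_patches overrides are no longer required on recent PyTorch 2 nightlies, so compile_fx can be handed to torch._dynamo.optimize() directly as the backend. A minimal sketch of the simplified pattern in isolation, assuming a PyTorch 2.x nightly; log_guard_failure and the toy Linear model are illustrative stand-ins, not part of this commit:

import torch
from torch._inductor.compile_fx import compile_fx


def log_guard_failure(guard):
    # Stand-in for count_kernels above: Dynamo calls this whenever a
    # guard fails and the function is recompiled.
    print(f"guard failed: {guard}")


model = torch.nn.Linear(4, 4)

# compile_fx already has the (model, inputs) backend signature, so the
# wrapping lambda and its config_patches dict can be dropped.
compiled_forward = torch._dynamo.optimize(
    compile_fx,
    dynamic=True,
    guard_fail_fn=log_guard_failure,
)(model.forward)

print(compiled_forward(torch.randn(2, 4)).shape)  # torch.Size([2, 4])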
@@ -155,12 +148,14 @@ def get_indices_to_keep(
         return next_batch_keep_indices

     def _setup_prompt_encoder(self) -> bool:
-        if hasattr(self.model, "named_children"):
+        vocab_size = getattr(self.model.config, "vocab_size", None)
+
+        if vocab_size is not None and hasattr(self.model, "named_children"):
             # Logic derived from https://github.com/huggingface/peft/blob/75925b1aaee47fe483a3fd0322d86df3d3eb8d22/src/peft/peft_model.py#L185
             for name, module in self.model.named_children():
                 if isinstance(module, PreTrainedModel):
                     for named_param, value in list(module.named_parameters()):
-                        if value.shape[0] == self.model.config.vocab_size:
+                        if value.shape[0] == vocab_size:
                             self.word_embeddings = module.get_submodule(named_param.replace(".weight", ""))
                             return True
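The second hunk also hardens _setup_prompt_encoder: it reads vocab_size once with getattr and bails out when a model config does not define it, instead of raising AttributeError inside the parameter loop. A minimal illustration of the guard, using a hypothetical BareConfig class that is not part of the commit:

class BareConfig:
    # Hypothetical config with no vocab_size attribute, e.g. from a
    # non-text model; purely for illustration.
    pass


config = BareConfig()

# The old code dereferenced config.vocab_size inside the loop, which
# would raise AttributeError here. The new code reads it once and skips.
vocab_size = getattr(config, "vocab_size", None)

if vocab_size is not None:
    print(f"scan parameters for a {vocab_size}-row embedding matrix")
else:
    print("config has no vocab_size; prompt encoder setup is skipped")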
