
Commit 7bdcb49

tdoublep authored and njhill committed
Fixes to enable PT2C for ibm/mpt-7b-instruct2
There were multiple issues using PT2C with this model:

1. PT2C fails with the error `'NoneType' object has no attribute 'node'` when compiling the first kernel. This is an issue on the PT side and I have opened [an issue](pytorch/pytorch#107721) accordingly. There is a simple workaround, which is just to call forward once before compiling, so this does not block us for now.

2. There are issues using PT2C dynamic shapes together with the `einops` package. Fixed by updating einops to the latest rc version.

3. The other models that we've tried with PT2C so far return the `past_key_values` (pkv) tensors as a tuple of tuples. The exception is when we concatenate a batch, after which the `past_key_values` tensors are a list of lists. Since type changes break the PT2C guards, we had logic in the code to check whether the first dimension is a list and, if so, convert the pkvs to a tuple of tuples. This logic breaks down for this model, because its forward function returns the pkvs as a list of tuples and starts erroring out if we try to pass them in as a tuple of tuples. To solve this, I've added logic to detect which types the model expects and, in the case of concatenation, always convert to those types.

4. There is one line in the modelling code that PT2C does not play well with. It compares the shape of the `attention_mask` to the actual values inside it, which creates complete chaos: guards break whenever we concatenate batches that were started at different times. The solution is the trivial change below in the modelling code. Ideally we could get this change applied to the model on the HF side; until then, it can be patched in our local versions. Note that without this change in the model everything still works, but we fall back to eager more often than if the change is applied.
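The modelling-code change itself is not part of this commit. A hedged sketch of the kind of rewrite described, assuming the culprit is the right-padding check in MPT's prepare_inputs_for_generation (both the original line and the replacement are assumptions, not the verified patch):

    # before (assumed): compares a reduction over attention_mask *values* against
    # its (dynamic) batch dimension, so compiled guards end up depending on data
    if attention_mask[:, -1].sum() != attention_mask.shape[0]:
        raise NotImplementedError("MPT does not support generation with right padding.")

    # after (assumed): the same check -- every sequence must end in a non-padded
    # token -- expressed without referencing the shape
    if not attention_mask[:, -1].all():
        raise NotImplementedError("MPT does not support generation with right padding.")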
1 parent 62de0ed commit 7bdcb49

File tree

6 files changed: +42, -8 lines

server/poetry.lock

Lines changed: 5 additions & 5 deletions
(Generated file; diff not rendered.)

server/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ optimum = { version = "1.11.0", extras = ["onnxruntime-gpu"], optional = true }
 onnxruntime = { version = "1.15.1", optional = true }
 onnxruntime-gpu = { version = "1.15.1", optional = true }
 onnx = { version = "1.14.0", optional = true }
-einops = "^0.6.1"
+einops = "^0.7.0rc2"

 # Explicitly install some transitive dependencies to avoid CVEs
 mpmath = ">=1.3.0"
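The einops bump pairs with the allow_ops_in_compiled_graph() call added in model.py below. A minimal standalone sketch of that hook, with an illustrative rearrange (the pattern and shapes are made up):

    import torch
    from einops import rearrange
    from einops._torch_specific import allow_ops_in_compiled_graph

    # register einops ops with torch.compile so rearrange/reduce calls
    # do not cause graph breaks or spurious recompiles
    allow_ops_in_compiled_graph()

    @torch.compile(dynamic=True)
    def split_heads(x: torch.Tensor) -> torch.Tensor:
        return rearrange(x, "b s (h d) -> b h s d", h=8)

    print(split_heads(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 8, 16, 8])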

server/text_generation_server/inference_engine/hf_transformers.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@ def __init__(
             "trust_remote_code": TRUST_REMOTE_CODE,
         }

+        if model_config.model_type == "mpt":
+            model_config.init_device = str(self.device)
+            kwargs["config"] = model_config
+
         if dtype == torch.int8:
             # using LLM.int8()
             kwargs["load_in_8bit"] = True

server/text_generation_server/models/causal_lm.py

Lines changed: 8 additions & 0 deletions
@@ -481,6 +481,14 @@ def batch_type(self) -> Type[CausalLMBatch]:
     def batch_type(self, value):
         self._batch_type = value

+    def determine_pkv_types(self) -> Tuple[Type, Type]:
+        one_token = torch.tensor([[1]], device=self.device)
+        _, pkv, _ = self.forward(
+            input_ids=one_token,
+            attention_mask=one_token,
+        )
+        return type(pkv), type(pkv[0])
+
     def forward(
         self,
         input_ids: torch.Tensor,
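The probe runs a real single-token forward pass and records the container types of the returned past_key_values; model.py then treats those types as the canonical form. An illustrative reading of what it yields (the concrete pairings restate the commit message, not verified output):

    type_pkv_dim0, type_pkv_dim1 = model.determine_pkv_types()
    # typical HF decoder models:       (tuple, tuple)
    # after our batch concatenation:   (list, list)
    # ibm/mpt-7b-instruct2's forward:  (list, tuple)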

server/text_generation_server/models/model.py

Lines changed: 15 additions & 2 deletions
@@ -27,6 +27,8 @@
 if PT2_COMPILE:
     import torch._dynamo
     from torch._inductor.compile_fx import compile_fx
+    from einops._torch_specific import allow_ops_in_compiled_graph
+    allow_ops_in_compiled_graph()


 class Model(ABC):
@@ -70,6 +72,12 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
         if not PT2_COMPILE:
             self.compiled = False
         else:
+
+            # Perform a forward pass using a single token. This serves 2 purposes:
+            # (1) work-around for PT2C issue #107721
+            # (2) determine types of past_key_value output
+            type_pkv_dim0, type_pkv_dim1 = self.determine_pkv_types()
+
             torch._dynamo.config.cache_size_limit = 512
             self.n_kernels = 0

@@ -93,8 +101,13 @@ def count_kernels(guard):
             run_forward = torch._dynamo.run(compiled_forward)

             def parse_kwargs(kwargs):
-                if "past_key_values" in kwargs and type(kwargs["past_key_values"]) is list:
-                    kwargs["past_key_values"] = tuple(tuple(t) for t in kwargs["past_key_values"])
+                # after batch concatenation the past_key_value tensor is a list of lists.
+                # this will lead to guard failures unless we convert them to the typical
+                # types that we expect to be returned by forward.
+                pkv = kwargs.get("past_key_values")
+                if pkv is not None:
+                    if type(pkv) != type_pkv_dim0 or type(pkv[0]) != type_pkv_dim1:
+                        kwargs["past_key_values"] = type_pkv_dim0(type_pkv_dim1(t) for t in pkv)
                 return kwargs

             def override_forward_with_compile(self, *args, **kwargs):
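For intuition on why the conversion matters: dynamo guards on the Python container types of compiled-function inputs, so swapping a tuple-of-tuples for a list-of-lists invalidates the compiled graph. A minimal illustrative repro, independent of this codebase:

    import torch

    @torch.compile
    def first_kv(pkv):
        return pkv[0][0] + 1

    t = torch.ones(2)
    first_kv(((t,),))  # traced and guarded with tuple-of-tuples
    first_kv([[t]])    # same data, different container types -> guard miss, recompile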

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 9 additions & 0 deletions
@@ -510,6 +510,15 @@ def batch_type(self) -> Type[Seq2SeqLMBatch]:
     def batch_type(self, value):
         self._batch_type = value

+    def determine_pkv_types(self) -> Tuple[Type, Type]:
+        one_token = torch.tensor([[1]], device=self.device)
+        _, _, pkv, _ = self.forward(
+            input_ids=one_token,
+            attention_mask=one_token,
+            decoder_input_ids=one_token,
+        )
+        return type(pkv), type(pkv[0])
+
     def forward(
         self,
         input_ids: torch.Tensor,
