diff --git a/torchchat/generate.py b/torchchat/generate.py index 5eb946f7d..fcbe5513b 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -928,9 +928,22 @@ def chat( self.model_forward, fullgraph=True, **kwargs ) - self.decode_one_token = torch.compile( - self.decode_one_token, fullgraph=True, **kwargs - ) + if self.model.config.model_type == ModelType.Flamingo: + # Based on https://github.com/pytorch/torchtune/blob/57ab583c84c4a9dcacac23aeabc81f2a679670fe/torchtune/training/_compile.py#L42-L52 + from torchtune.modules import ( + TransformerCrossAttentionLayer, + TransformerSelfAttentionLayer, + ) + decoder = self.model.model.decoder + for m in reversed(list(decoder.modules())): + if isinstance(m, TransformerSelfAttentionLayer) or isinstance( + m, TransformerCrossAttentionLayer + ): + m.compile() + else: + self.decode_one_token = torch.compile( + self.decode_one_token, fullgraph=True, **kwargs + ) if generator_args.compile_prefill: self.prefill = torch.compile(