diff --git a/install/.pins/et-pin.txt b/install/.pins/et-pin.txt
index 01c77f102..ceb4852bf 100644
--- a/install/.pins/et-pin.txt
+++ b/install/.pins/et-pin.txt
@@ -1 +1 @@
-af098c31b6f8d5f38e40a5cf35784b0969d97df8
+286799c9c844ce6427b8eca260f9b2f28be03291
diff --git a/torchchat/export.py b/torchchat/export.py
index 21e7fcaa8..6b06f1df1 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -28,7 +28,7 @@
 
 
 """
-Export for Server 
+Export for Server
 """
 
 
@@ -78,7 +78,7 @@ def export_for_server(
 """
 Export for ExecuTorch
 
-TODO (https://github.com/pytorch/torchchat/issues/1058): Replace 
+TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
 replace_attention_with_custom_sdpa_attention with ET's implementation
 """
 
@@ -94,6 +94,9 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
+    from executorch.backends.xnnpack.passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.exir import EdgeProgramManager, to_edge
     from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
             _skip_type_promotion=bool(target_precision == torch.float16),
         )
 
-        if target_precision == torch.float16 or target_precision == torch.bfloat16:
-            if state_dict_dtype != torch.float16:
-                print("model.to torch.float16")
-                model = model.to(dtype=torch.float16)
-                state_dict_dtype = torch.float16
-        elif target_precision == torch.float32:
-            if state_dict_dtype != torch.float32:
-                print("model.to torch.float32")
-                model = model.to(dtype=torch.float32)
-        elif target_precision == torch.bfloat16:
-            print("model.to torch.bfloat16")
-            model = model.to(dtype=torch.bfloat16)
-        else:
+        if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
             raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
 
-        replace_attention_with_custom_sdpa_attention(model)
+        if state_dict_dtype != target_precision:
+            print(f"model.to {target_precision}")
+            model = model.to(dtype=target_precision)
+            state_dict_dtype = target_precision
+
+        # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
+        # support anything but float32, and our attempt to use it anyway by converting
+        # to and from float causes other errors.)
+        if target_precision != torch.bfloat16:
+            replace_attention_with_custom_sdpa_attention(model)
+
         with torch.nn.attention.sdpa_kernel(
             [torch.nn.attention.SDPBackend.MATH]
         ), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
                 passes=[
+                    ConvertToLinearPass(),
                     QuantFusionPass(),
                 ],
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 9e60f9494..14c4832e3 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -818,6 +818,7 @@ def chat(
                     if text_transformer_args is not None
                     else 2048
                 ),
+                max_seq_length
             )
 
         max_seq_length = (
diff --git a/torchchat/model.py b/torchchat/model.py
index 3300ebee9..336eb864c 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -127,10 +127,10 @@ def forward(
         )
         return self.decoder(decoder_input, input_pos=input_pos)
-    
+
     def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
-    
+
     def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
         )
@@ -154,7 +154,7 @@ def _get_decoder_input(
         image_embeds = self.mm_projector(encoder_output)
         if post_tokens is None:
             return torch.cat((pre_img_embed, image_embeds), dim=1)
-    
+
         post_img_embed = self.tok_embeddings(post_tokens)
         return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
 
@@ -227,7 +227,7 @@ def _llava(cls):
             },
             fusion_class=ConcateFusion,
         )
-    
+
     @classmethod
     def get_recipe(cls, model_type):
         match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-        
+
         if (model_type_name := loaded_params.get("model_type", None)) is None:
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
             modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
-    
+
     def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
-    
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
             self.model_ = exec_lib._load_for_executorch(str(path))
 
             self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
-    
+
+            # TODO: attempt to use "get_max_seq_len" method on the model after
+            # ExecuTorch bug is fixed.
+            max_seq_len = 128
+            # try:
+            #     max_seq_len = self.model_.run_method("get_max_seq_len", [])
+            # except Exception as e:
+            #     pass
+            self.text_transformer_args.max_seq_length = max_seq_len
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):
 
         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-    
+
 except:
     pass
diff --git a/torchchat/utils/scripts/install_utils.sh b/torchchat/utils/scripts/install_utils.sh
index 2da3d044c..0ff4608c6 100644
--- a/torchchat/utils/scripts/install_utils.sh
+++ b/torchchat/utils/scripts/install_utils.sh
@@ -150,7 +150,7 @@ install_executorch_cpp_libs() {
       -DEXECUTORCH_BUILD_KERNELS_CUSTOM=${EXECUTORCH_BUILD_KERNELS_CUSTOM_VAR} \
      ${CROSS_COMPILE_ARGS} \
      -S . -B ${CMAKE_OUT_DIR} -G Ninja
-  cmake --build ${CMAKE_OUT_DIR}
+  cmake --build ${CMAKE_OUT_DIR} -j16
   cmake --install ${CMAKE_OUT_DIR} --prefix ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install
   popd
 }