This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
2 changes: 1 addition & 1 deletion install/.pins/et-pin.txt
@@ -1 +1 @@
af098c31b6f8d5f38e40a5cf35784b0969d97df8
286799c9c844ce6427b8eca260f9b2f28be03291
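This advances the pinned ExecuTorch commit from af098c3 to 286799c; presumably the newer pin is what provides the ConvertToLinearPass and the custom-SDPA behavior that the export.py changes below depend on.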
34 changes: 18 additions & 16 deletions torchchat/export.py
@@ -28,7 +28,7 @@


"""
Export for Server
Export for Server
"""


@@ -78,7 +78,7 @@ def export_for_server(
"""
Export for ExecuTorch

TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
replace_attention_with_custom_sdpa_attention with ET's implementation
"""

@@ -94,6 +94,9 @@ def export_for_server(
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)
from executorch.backends.xnnpack.passes.convert_to_linear import (
ConvertToLinearPass,
)
from executorch.exir import EdgeProgramManager, to_edge

from executorch.exir.capture._config import (
@@ -274,22 +277,20 @@ def export_for_et(model, device, output_path) -> str:
_skip_type_promotion=bool(target_precision == torch.float16),
)

if target_precision == torch.float16 or target_precision == torch.bfloat16:
if state_dict_dtype != torch.float16:
print("model.to torch.float16")
model = model.to(dtype=torch.float16)
state_dict_dtype = torch.float16
elif target_precision == torch.float32:
if state_dict_dtype != torch.float32:
print("model.to torch.float32")
model = model.to(dtype=torch.float32)
elif target_precision == torch.bfloat16:
print("model.to torch.bfloat16")
model = model.to(dtype=torch.bfloat16)
else:
if target_precision not in (torch.float16, torch.float32, torch.bfloat16):
raise ValueError(f"Unsupported dtype for ET export: {target_precision}")

replace_attention_with_custom_sdpa_attention(model)
if state_dict_dtype != target_precision:
print(f"model.to {target_precision}")
model = model.to(dtype=target_precision)
state_dict_dtype = target_precision

# Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
# support anything but float32, and our attempt to use it anyway by converting
# to and from float causes other errors.)
if target_precision != torch.bfloat16:
replace_attention_with_custom_sdpa_attention(model)

with torch.nn.attention.sdpa_kernel(
[torch.nn.attention.SDPBackend.MATH]
), torch.no_grad():
@@ -306,6 +307,7 @@ def export_for_et(model, device, output_path) -> str:
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
ConvertToLinearPass(),
QuantFusionPass(),
],
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
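Taken together, the export.py changes do three things: the per-dtype if/elif chain is collapsed into a single validate-then-convert step; the custom SDPA replacement is skipped when exporting in bfloat16, which the custom op cannot handle on CPU; and ConvertToLinearPass (which, as its name suggests, re-fuses decomposed matmul/add patterns back into linear ops for the XNNPACK flow) now runs ahead of QuantFusionPass in the ExecutorchBackendConfig pass list. Note that in the old chain a bfloat16 request fell into the first branch and was silently converted to float16, so the bfloat16 elif below it was effectively dead code. A minimal sketch of the new dtype/SDPA handling, using the names from the hunk above (the helper itself is illustrative, not part of the PR):

import torch
import torch.nn as nn

SUPPORTED_ET_DTYPES = (torch.float16, torch.float32, torch.bfloat16)

def prepare_for_et_export(model: nn.Module,
                          target_precision: torch.dtype,
                          state_dict_dtype: torch.dtype):
    # Reject unsupported dtypes up front instead of in a trailing else branch.
    if target_precision not in SUPPORTED_ET_DTYPES:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
    # Convert the weights once, and only if they are not already in the target dtype.
    if state_dict_dtype != target_precision:
        model = model.to(dtype=target_precision)
        state_dict_dtype = target_precision
    # bfloat16 keeps the stock SDPA; the other supported dtypes get the custom op.
    use_custom_sdpa = target_precision != torch.bfloat16
    return model, state_dict_dtype, use_custom_sdpa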
1 change: 1 addition & 0 deletions torchchat/generate.py
@@ -818,6 +818,7 @@ def chat(
if text_transformer_args is not None
else 2048
),
max_seq_length
)

max_seq_length = (
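In generate.py the only change in this hunk is that max_seq_length is passed as one more argument to an expression whose opening line sits outside the visible context, so the exact call cannot be read from the diff. The apparent intent, given the PTEModel change below, is to keep the chat loop from using a longer context than the loaded model supports; purely as a hypothetical illustration of that kind of clamp (not a reconstruction of the PR's code):

def effective_max_seq_length(configured: int, model_limit: int) -> int:
    # e.g. configured = 2048 from the transformer args, model_limit = 128 from a .pte model
    return min(configured, model_limit)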
26 changes: 17 additions & 9 deletions torchchat/model.py
@@ -127,10 +127,10 @@ def forward(
)

return self.decoder(decoder_input, input_pos=input_pos)

def setup_caches(self, batch_size, max_seq_len) -> None:
self.decoder.setup_caches(batch_size, max_seq_len)

def _encoder_feature_select(self, encoder_output) -> Tensor:
selected_image_feature = encoder_output[1][0].view(
*encoder_output[1][0].shape[2:]
@@ -154,7 +154,7 @@ def _get_decoder_input(
image_embeds = self.mm_projector(encoder_output)
if post_tokens is None:
return torch.cat((pre_img_embed, image_embeds), dim=1)

post_img_embed = self.tok_embeddings(post_tokens)
return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)

@@ -227,7 +227,7 @@ def _llava(cls):
},
fusion_class=ConcateFusion,
)

@classmethod
def get_recipe(cls, model_type):
match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
def from_params(cls, params_path):
with open(params_path, "r") as f:
loaded_params = json.loads(f.read())

if (model_type_name := loaded_params.get("model_type", None)) is None:
# The model params is in the transformer_args format
# set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
modules[name] = module_class(**config_args)

return recipe.fusion_class(**modules)

def _replace_known_params(self, params):
patterns = {"QuickGELUActivation()": QuickGELUActivation()}
for key, value in params.items():
if isinstance(value, Hashable) and value in patterns:
params[key] = patterns[value]
return params

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
self.model_ = exec_lib._load_for_executorch(str(path))

self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])

# TODO: attempt to use "get_max_seq_len" method on the model after
# ExecuTorch bug is fixed.
max_seq_len = 128
# try:
# max_seq_len = self.model_.run_method("get_max_seq_len", [])
# except Exception as e:
# pass
self.text_transformer_args.max_seq_length = max_seq_len

def forward(self, x, input_pos):
# model_.forward expects inputs to be wrapped in a tuple
forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):

def setup_caches(self, max_batch_size, max_seq_length):
pass

except:
pass
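The PTEModel change hard-codes max_seq_len to 128 because, per the TODO, asking the exported program for its limit currently hits an ExecuTorch bug. The commented-out lines indicate how the value is meant to be obtained once that is fixed; roughly (a sketch of that intent, assuming run_method returns the value directly):

# Inside PTEModel.__init__, after self.model_ has been loaded with _load_for_executorch:
max_seq_len = 128  # conservative default while the ExecuTorch bug is open
try:
    max_seq_len = self.model_.run_method("get_max_seq_len", [])
except Exception:
    pass  # keep the hard-coded default if the method is unavailable
self.text_transformer_args.max_seq_length = max_seq_len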
2 changes: 1 addition & 1 deletion torchchat/utils/scripts/install_utils.sh
@@ -150,7 +150,7 @@ install_executorch_cpp_libs() {
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=${EXECUTORCH_BUILD_KERNELS_CUSTOM_VAR} \
${CROSS_COMPILE_ARGS} \
-S . -B ${CMAKE_OUT_DIR} -G Ninja
cmake --build ${CMAKE_OUT_DIR}
cmake --build ${CMAKE_OUT_DIR} -j16
cmake --install ${CMAKE_OUT_DIR} --prefix ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install
popd
}
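The install script change adds -j16 to the cmake --build invocation, so the ExecuTorch C++ libraries compile with up to 16 parallel jobs instead of building serially.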