
Commit 4cbd6b6

Merge branch 'main' into patch-2

2 parents: 4378b26 + 11dcbeb

File tree: 11 files changed, +83 −43 lines


CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ add_subdirectory(tokenizer)
 # include et_run executable
 include(runner/et.cmake)
 if(TARGET et_run)
-  target_link_libraries(et_run PUBLIC tokenizer)
+  target_link_libraries(et_run PUBLIC tokenizer microkernels-prod)
 endif()

 # include aoti_run executable

README.md

Lines changed: 2 additions & 2 deletions

@@ -477,9 +477,9 @@ The following assumes you've completed the steps for [Setting up ExecuTorch](#se

 1. Download the AAR file, which contains the Java library and corresponding JNI library, to build and run the app.

-   - [executorch-240919.aar](https://ossci-android.s3.amazonaws.com/executorch/main/executorch-240919.aar) (SHASUM: c8a5d38ead03bfa28ee8469f6355840ad0d182ba)
+   - [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-241002/executorch.aar) ([sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-241002/executorch.aar.sha256sums))

-2. Rename the downloaded AAR file to `executorch.aar` and move the file to `torchchat/edge/android/torchchat/app/libs/`. You may need to create directory `torchchat/edge/android/torchchat/app/libs/` if it does not exist.
+2. Move the downloaded AAR file to `torchchat/edge/android/torchchat/app/libs/`. You may need to create directory `torchchat/edge/android/torchchat/app/libs/` if it does not exist.

 3. Push the model and tokenizer file to your device. You can find the model file called `llama3.1.pte` in the current `torchchat` directory and the tokenizer file at `$(python3 torchchat.py where llama3.1)/tokenizer.model` path.
 ```

install/.pins/et-pin.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-286799c9c844ce6427b8eca260f9b2f28be03291
+72b3bb3194c611f7c4861e6f3b24af5de868af72

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-ae3e7c68eae7085e13241cb3d6b39481868dd162
+49b1fb61c8b8eceda755579a2fd92c756d822de2

install/install_requirements.sh

Lines changed: 3 additions & 3 deletions

@@ -47,10 +47,10 @@ fi
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20240901
+PYTORCH_NIGHTLY_VERSION=dev20241002

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20240901
+VISION_NIGHTLY_VERSION=dev20241002

 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20240928
@@ -76,7 +76,7 @@ fi

 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
-  torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
+  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
   torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
   torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )
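
The PyTorch and torchvision nightly pins move to the 2024-10-02 builds, and the torch pin advances to the 2.6.0 series. As an illustrative aside (not part of the commit), a short Python check like the following can confirm that locally installed nightlies match these pins; the package names and version strings are taken from the diff above.

# Illustrative only: compare installed packages against the pins in the diff.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {
    "torch": "2.6.0.dev20241002",
    "torchvision": "0.20.0.dev20241002",
    "torchtune": "0.3.0.dev20240928",
}

for pkg, want in EXPECTED.items():
    try:
        have = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (expected {want})")
        continue
    status = "OK" if have == want else f"MISMATCH (expected {want})"
    print(f"{pkg}=={have}: {status}")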

runner/et.cmake

Lines changed: 1 addition & 0 deletions

@@ -94,6 +94,7 @@ if(executorch_FOUND)
     optimized_native_cpu_ops_lib
     quantized_ops_lib
     xnnpack_backend
+    microkernels-prod
     XNNPACK
     pthreadpool
     cpuinfo

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 17 additions & 8 deletions

@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
+        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
+        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -93,11 +100,10 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

     def permute(w, n_heads):
-        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
             .transpose(1, 2)
-            .reshape(config.head_dim * n_heads, dim)
+            .reshape(w.shape)
         )

     merged_result = {}
@@ -130,6 +136,7 @@ def load_safetensors():
             continue
         assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
+
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -145,16 +152,18 @@ def load_safetensors():
             final_result[new_key] = value

     for key in tuple(final_result.keys()):
-        if "wq" in key:
+        if "wq.weight" in key or "wq.bias" in key:
+            wk_key = key.replace("wq", "wk")
+            wv_key = key.replace("wq", "wv")
             q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
+            k = final_result[wk_key]
+            v = final_result[wv_key]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+            del final_result[wk_key]
+            del final_result[wv_key]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")

torchchat/export.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ def export_for_server(
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack.passes.convert_to_linear import (
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
        ConvertToLinearPass,
     )
     from executorch.exir import EdgeProgramManager, to_edge
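
The import follows ConvertToLinearPass to the _passes package used by the newer ExecuTorch pin. The commit simply switches paths; purely as an illustrative alternative (not what the commit does), code that must run against both older and newer ExecuTorch checkouts could fall back between the two locations:

# Illustrative fallback, not part of the commit: try the new module path
# first and fall back to the pre-rename path on older ExecuTorch checkouts.
try:
    from executorch.backends.xnnpack._passes.convert_to_linear import (
        ConvertToLinearPass,
    )
except ImportError:
    from executorch.backends.xnnpack.passes.convert_to_linear import (
        ConvertToLinearPass,
    )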

torchchat/generate.py

Lines changed: 16 additions & 3 deletions

@@ -928,9 +928,22 @@ def chat(
                 self.model_forward, fullgraph=True, **kwargs
             )

-            self.decode_one_token = torch.compile(
-                self.decode_one_token, fullgraph=True, **kwargs
-            )
+            if self.model.config.model_type == ModelType.Flamingo:
+                # Based on https://github.com/pytorch/torchtune/blob/57ab583c84c4a9dcacac23aeabc81f2a679670fe/torchtune/training/_compile.py#L42-L52
+                from torchtune.modules import (
+                    TransformerCrossAttentionLayer,
+                    TransformerSelfAttentionLayer,
+                )
+                decoder = self.model.model.decoder
+                for m in reversed(list(decoder.modules())):
+                    if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
+                        m, TransformerCrossAttentionLayer
+                    ):
+                        m.compile()
+            else:
+                self.decode_one_token = torch.compile(
+                    self.decode_one_token, fullgraph=True, **kwargs
+                )

         if generator_args.compile_prefill:
             self.prefill = torch.compile(
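
For Flamingo-type models, the change compiles each torchtune self- and cross-attention layer in place with nn.Module.compile() instead of wrapping decode_one_token in a single torch.compile(..., fullgraph=True) call. A self-contained sketch of that per-layer pattern, using a hypothetical Block/TinyDecoder stand-in for the torchtune layer classes:

# Sketch of per-layer compilation; Block and TinyDecoder are made-up stand-ins
# for torchtune's TransformerSelfAttentionLayer / TransformerCrossAttentionLayer.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))

class TinyDecoder(nn.Module):
    def __init__(self, n_layers: int = 2, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(Block(dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

decoder = TinyDecoder()
# Compile each layer module in place rather than compiling one big decode step.
for m in reversed(list(decoder.modules())):
    if isinstance(m, Block):
        m.compile()  # nn.Module.compile() wraps this module's forward in torch.compile

print(decoder(torch.randn(1, 64)).shape)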

torchchat/model.py

Lines changed: 34 additions & 19 deletions

@@ -34,7 +34,7 @@
 try:
     # TODO: remove this after we figure out where in torchtune an `evaluate` module
     # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval # noqa
+    import lm_eval  # noqa
 except Exception:
     pass

@@ -278,6 +278,11 @@ class TransformerArgs:
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -394,7 +399,7 @@ def from_name(cls, name: str):
         config = [
             config
             for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
         ]

         # We may have two or more configs matched (e.g., "7B" and
@@ -471,7 +476,7 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
-
+
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
         if recipe.fusion_class == DeepFusionModel:
@@ -629,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
         if config.stage_idx == config.n_stages - 1:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
         else:
             self.norm = None
             self.output = None

         self.max_batch_size = -1
         self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])

     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
@@ -730,16 +743,16 @@ def __init__(self, config: TransformerArgs):

         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )

-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
         self.kv_cache = None

         self.n_heads = config.n_heads
@@ -766,14 +779,16 @@ def load_hook(self, state_dict, prefix, *args):
         # wv = state_dict.pop(prefix + "wv.weight")
         # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

         return

@@ -852,9 +867,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)

     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
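
Together, attention_bias and feed_forward_bias let checkpoints whose attention and MLP projections carry bias terms load cleanly, while tie_word_embeddings shares one parameter between the token embedding and the output head. A minimal sketch of the tying-plus-configurable-bias pattern; TinyTransformer and its sizes are hypothetical, only the pattern mirrors the diff.

# Hypothetical toy module illustrating configurable biases and tied embeddings.
import torch
import torch.nn as nn

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size=100, dim=32,
                 attention_bias=False, tie_word_embeddings=True):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # Bias is now driven by config rather than hard-coded to False.
        self.wq = nn.Linear(dim, dim, bias=attention_bias)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if tie_word_embeddings:
            # The output projection reuses the embedding table as its weight.
            self.output.weight = self.tok_embeddings.weight

model = TinyTransformer()
# One tensor backs both the input embedding and the output head.
assert model.output.weight.data_ptr() == model.tok_embeddings.weight.data_ptr()
tokens = torch.randint(0, 100, (1, 8))
logits = model.output(model.tok_embeddings(tokens))
print(logits.shape)  # torch.Size([1, 8, 100])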
