
Commit f3cbd53 (merge commit)

Commit message: merge with main
Parents: 128566c + f730056

File tree: 4 files changed (+29 −15 lines)


dist_run.py

Lines changed: 8 additions & 5 deletions
@@ -55,7 +55,6 @@
     "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
     "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
-CACHE_PRECISION = torch.bfloat16


 def _init_distributed():
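
The removed CACHE_PRECISION global had pinned the cache precision to bfloat16 regardless of the model. With this change the dtype comes from the per-model entry in the table shown above. A minimal, hypothetical sketch of that lookup follows; the actual table and variable names in dist_run.py may differ, but the (HF repo id, dtype) pairing is taken from the diff:

```python
import torch

# Hypothetical reconstruction of how model_dtype is obtained; the real table
# name in dist_run.py may differ, but the entries mirror the diff above.
MODEL_TABLE = {
    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}

distribution, model_dtype = MODEL_TABLE["llama3"]
print(distribution, model_dtype)  # meta-llama/Meta-Llama-3-8B-Instruct torch.bfloat16
```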
@@ -245,8 +244,8 @@ def main(args):

     tokenizer = _build_chat_tokenizer(model_name)

-    set_precision(CACHE_PRECISION)
-    logger.info(f"Using cache precision {CACHE_PRECISION}")
+    set_precision(model_dtype)
+    logger.info(f"Using cache precision {model_dtype}")

     hf_config = get_hf_config_file(distribution)
     if hf_config is None:
@@ -285,8 +284,6 @@ def main(args):
     with device:
         model = Transformer(config)

-    model.setup_caches(1, 4096)
-
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
     if rank == 0:
@@ -300,6 +297,12 @@ def main(args):
     dim = 4096  # embedding dimension
     assert seqlen % sp_degree == 0

+    # Setup KV caches (after model distribution)
+    # TODO: the setting below only works for 1 micro-batch case. To support
+    # multiple micro-batches, we need the KV cache in the model to be aware of
+    # the number of micro-batches and the current micro-batch index.
+    model.setup_caches(mb_size, seqlen)
+
     mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
     activation = torch.rand(
         mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
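
Taken together, the dist_run.py changes amount to a reordering: precision is set from the model's own dtype, and KV caches are allocated only after tensor parallelism has been applied, sized by the micro-batch and sequence length. A condensed sketch of the resulting flow in main(); this is not runnable on its own, since it assumes the objects built earlier in the script (set_precision, Transformer, logger, config, device, tp_mesh, mb_size, seqlen):

```python
# Condensed from the diff above; assumes dist_run.py's existing imports/setup.
set_precision(model_dtype)                      # dtype from the model table, not a global
logger.info(f"Using cache precision {model_dtype}")

with device:
    model = Transformer(config)                 # build the model first

model.distribute(tp_mesh)                       # apply tensor parallelism on the TP mesh

# Allocate KV caches only after distribution, so each rank sizes its cache
# for its post-TP share of heads (see torchchat/model.py below).
model.setup_caches(mb_size, seqlen)
```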

install/install_requirements.sh

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ PYTORCH_NIGHTLY_VERSION=dev20240814
 VISION_NIGHTLY_VERSION=dev20240814

 # Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20240910
+TUNE_NIGHTLY_VERSION=dev20240916


 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same

torchchat/generate.py

Lines changed: 4 additions & 4 deletions
@@ -727,22 +727,22 @@ def chat(
         if generator_args.image_prompts is not None:
             print("Image prompts", generator_args.image_prompts)

+            # Support for just the first image prompt for now
+            images = [Image.open(generator_args.image_prompts[0])]
             messages = [
                 Message(
                     role="user",
                     content=[
-                        {"type": "image"},
+                        {"type": "image", "content": images[0]},
                         {"type": "text", "content": generator_args.prompt},
                     ],
                     eot=True,
                 ),
                 Message(role="assistant", content=""),
             ]

-            images = [Image.open(generator_args.image_prompts[0])]
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-
-            data = transform({"images": images, "messages": messages}, inference=True)
+            data = transform({"messages": messages}, inference=True)
             batch = padded_collate([data], self.builder_args.device)
             batch.pop("mask")
             encoded = batch["tokens"]
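
The PIL image is now attached directly to the user Message's content, so the Flamingo transform no longer receives a separate "images" entry. A hedged sketch of the resulting flow, assuming the torchtune APIs that generate.py already imports (Message, flamingo_transform) and placeholder paths and prompt text; in generate.py these come from generator_args and tokenizer_args:

```python
from PIL import Image

# Message and flamingo_transform come from torchtune, imported as in generate.py.
# The paths and prompt below are placeholders for illustration only.
image_path = "path/to/image_prompt.jpg"
tokenizer_path = "path/to/tokenizer.model"
prompt = "What is in this image?"

# Only the first image prompt is supported for now, mirroring the diff above.
image = Image.open(image_path)

messages = [
    Message(
        role="user",
        content=[
            {"type": "image", "content": image},  # image embedded in the message itself
            {"type": "text", "content": prompt},
        ],
        eot=True,
    ),
    Message(role="assistant", content=""),
]

transform = flamingo_transform(tokenizer_path)
# No separate "images" key anymore: the transform pulls images from the messages.
data = transform({"messages": messages}, inference=True)
```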

torchchat/model.py

Lines changed: 16 additions & 5 deletions
@@ -623,13 +623,15 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers.values():
-            b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+            # Lower the setup_cache call to the attention module because tensor
+            # parallelism may have been applied there and the `n_local_heads``
+            # value being adjusted.
+            b.attention.setup_cache(
+                max_batch_size, max_seq_length,
             )

         freqs_cis = precompute_freqs_cis(
@@ -745,6 +747,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)

+    def setup_cache(self, max_batch_size, max_seq_length):
+        n_local_heads = self.n_local_heads
+        # If TP is enabled, the heads would be divided and assigned to different ranks
+        if hasattr(self, "tp_degree"):
+            n_local_heads = self.n_local_heads // self.tp_degree
+
+        self.kv_cache = KVCache(
+            max_batch_size, max_seq_length, n_local_heads, self.head_dim
+        )
+
     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
         #     wq = state_dict.pop(prefix + "wq.weight")
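
The cache construction moves out of Transformer.setup_caches and into the new Attention.setup_cache so the head count can reflect tensor parallelism: once distribute() has recorded a tp_degree, each rank only owns n_local_heads // tp_degree KV heads. A small, self-contained illustration of that sizing logic with made-up numbers; in torchchat the real values come from the model config and the TP device mesh, and the KVCache buffers are shaped roughly (max_batch_size, n_heads, max_seq_length, head_dim):

```python
# Illustrative numbers only; not taken from any specific model config.
n_local_heads = 8      # KV heads in the full (undistributed) model
head_dim = 128         # dim // n_heads
tp_degree = 4          # device_mesh.size(), recorded by Attention.distribute()

max_batch_size, max_seq_length = 1, 4096

# Without TP (no tp_degree attribute) the cache covers all KV heads.
full_shape = (max_batch_size, n_local_heads, max_seq_length, head_dim)

# With TP, wk/wv are column-sharded, so each rank holds only its share of
# heads and the per-rank cache shrinks accordingly.
per_rank_heads = n_local_heads // tp_degree
tp_shape = (max_batch_size, per_rank_heads, max_seq_length, head_dim)

print(full_shape)  # (1, 8, 4096, 128)
print(tp_shape)    # (1, 2, 4096, 128)
```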
@@ -785,14 +797,13 @@ def _unfuse_wqkv_state_dict(

     def distribute(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh
+        self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
         parallelize_module(
             self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
         )
-        # TODO: enable kv cache in distributed case
-        self.kv_cache = None

     def forward(
         self,
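
distribute() now records the TP degree instead of nulling out the cache, which implies an ordering contract: distribution must happen before cache setup for the hasattr(self, "tp_degree") branch to fire. A two-line sketch of that contract, where attn is a hypothetical handle to an Attention module and tp_mesh, mb_size, seqlen stand in for the values used in dist_run.py:

```python
# Sketch of the ordering the hasattr(self, "tp_degree") check relies on.
attn.distribute(tp_mesh)            # shards wq/wk/wv/wo and sets attn.tp_degree
attn.setup_cache(mb_size, seqlen)   # KVCache sized for n_local_heads // tp_degree

# If the calls were reversed, setup_cache would not see tp_degree and would
# allocate a cache for the full head count on every rank.
```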
