
Commit 61d1e0e

Merge branch 'main' into multiturn-mm-single-image
2 parents 25beb26 + dc3d35e commit 61d1e0e

File tree: 7 files changed, +59 −61 lines


install/.lintrunner.toml renamed to .lintrunner.toml

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@ init_command = [
     'run',
     'pip_init',
     '--dry-run={{DRYRUN}}',
-    '--requirement=requirements-lintrunner.txt',
+    '--requirement=install/requirements-lintrunner.txt',
 ]
 
 # Black + usort
@@ -46,7 +46,7 @@ init_command = [
     'pip_init',
     '--dry-run={{DRYRUN}}',
     '--no-black-binary',
-    '--requirement=requirements-lintrunner.txt',
+    '--requirement=install/requirements-lintrunner.txt',
 ]
 is_formatter = true
 
@@ -75,6 +75,6 @@ init_command = [
     'run',
     'pip_init',
     '--dry-run={{DRYRUN}}',
-    '--requirement=requirements-lintrunner.txt',
+    '--requirement=install/requirements-lintrunner.txt',
 ]
 is_formatter = true

CONTRIBUTING.md

Lines changed: 15 additions & 1 deletion
@@ -10,9 +10,23 @@ We actively welcome your pull requests.
 2. If you've added code that should be tested, add tests.
 3. If you've changed APIs, update the documentation.
 4. Ensure the test suite passes.
-5. Make sure your code lints.
+5. Make sure your code is well-formatted using the repo linter. See "Linting" for details.
 6. If you haven't already, complete the Contributor License Agreement ("CLA").
 
+
+### Linting
+Install the lintrunner dependencies from the requirements file.
+```
+pip3 install -r install/requirements-lintrunner.txt
+```
+
+After making your changes locally, run the lintrunner and apply all suggestions to your changes.
+You can do this from the top-level torchchat directory - it will apply suggestions only to files that
+you have touched.
+```
+lintrunner -a
+```
+
 ## Contributor License Agreement ("CLA")
 In order to accept your pull request, we need you to submit a CLA. You only need
 to do this once to work on any of Meta's open source projects.

dist_run.py

Lines changed: 28 additions & 22 deletions
@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
        next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
@@ -228,9 +229,9 @@
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
 
 
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
@@ -328,15 +334,26 @@ def main(args):
     config.stage_idx = pp_rank
     config.n_stages = pp_degree
 
-    with device:
+    with torch.device("meta"):
         # TODO: we should create model instead of Transformer
         model = Transformer(config)
 
     # Distribute model on TP mesh
+    # (Surprisingly, this works even though model is on meta device and mesh is of
+    # cuda devices)
     model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")
 
+    # Load weights
+    logger.info(f"Loading weights for {pp_rank=} on {device=}")
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
+
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
@@ -352,17 +369,8 @@
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
-
-    # Load weights
-    logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
-    model.to(device)
-
-    logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+    with device:
+        model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # info on stage size and params
     stage_size = get_module_size(model)
@@ -528,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

docs/multimodal.md

Lines changed: 0 additions & 3 deletions
@@ -19,9 +19,6 @@ While we strongly encourage you to use the Hugging Face checkpoint (which is the
 ```
 
 ## Generation
-
-**We are currently debugging Multimodal Inference on MPS and will have updates soon. In the meantime, when testing on Mac, please set `--device cpu`**
-
 This generates text output based on a text prompt and (optional) image prompt.
 
 ```

torchchat/distributed/dtensor_utils.py

Lines changed: 5 additions & 12 deletions
@@ -8,24 +8,17 @@
 logger = SingletonLogger.get_logger()
 
 
-
-def is_dtensor(tensor):
-    """Check if a tensor is a DTensor by class or has a placements attribute (not sure if we want to use attr check)"""
-    return isinstance(tensor, DTensor) or hasattr(tensor, "placements")
-
-
-def load_into_dtensor(weight_tensor, model_dtensor):
+def convert_to_dtensor(weight_tensor, dtensor_template):
     """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-    weight_tensor = weight_tensor.to(model_dtensor.device)
 
-    if weight_tensor.shape != model_dtensor.shape:
+    if weight_tensor.shape != dtensor_template.shape:
         raise ValueError(
             f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
-            f"doesn't match DTensor shape {model_dtensor.shape}"
+            f"doesn't match DTensor shape {dtensor_template.shape}"
         )
 
-    placements = model_dtensor.placements
-    mesh = model_dtensor.device_mesh
+    placements = dtensor_template.placements
+    mesh = dtensor_template.device_mesh
     mesh_dims = mesh.ndim
 
     for placement in placements:
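To illustrate the `convert_to_dtensor` idea (distribute a loaded full tensor so it matches a template DTensor's mesh and placements), here is a minimal single-process sketch built on the public `distribute_tensor` helper rather than the repo's own placement loop; the gloo/world-size-1 setup and the 4x4 shapes are purely hypothetical:

```
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical single-process setup so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
# A DTensor "template", standing in for a model parameter that is already a DTensor.
dtensor_template = distribute_tensor(torch.zeros(4, 4), mesh)

def convert_to_dtensor_sketch(weight_tensor: torch.Tensor, template: DTensor) -> DTensor:
    """Distribute a full tensor so it matches the template's mesh and placements."""
    if weight_tensor.shape != template.shape:
        raise ValueError(f"Shape mismatch: {weight_tensor.shape} vs {template.shape}")
    return distribute_tensor(weight_tensor, template.device_mesh, template.placements)

loaded = convert_to_dtensor_sketch(torch.randn(4, 4), dtensor_template)
assert isinstance(loaded, DTensor)
dist.destroy_process_group()
```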

torchchat/distributed/safetensor_utils.py

Lines changed: 8 additions & 11 deletions
@@ -13,8 +13,8 @@
 from torch.nn import Module
 from typing import Dict, Tuple, Set, Optional
 
-
-from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor
+from torch.distributed._tensor import DTensor
+from torchchat.distributed.dtensor_utils import convert_to_dtensor
 
 
 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -284,9 +284,7 @@ def update_state_dict(
             continue
 
         checkpoint_tensor = checkpoint[old_param]
-        stage_tensor = state_dict[param]
-
-        stage_is_dtensor = is_dtensor(stage_tensor)
+        model_tensor = state_dict[param]
 
         if "wq" in param:
             checkpoint_tensor = permute_weight_to_attn_heads(
@@ -297,17 +295,16 @@
                 checkpoint_tensor, num_local_heads, head_dim, dim
             )
 
+        # Move checkpoint tensor to desired device
+        checkpoint_tensor = checkpoint_tensor.to(device)
+
         # here we need to check if the tensor is a DTensor and if so, adjust the
         # shape and placement to match the model DTensor.
-        if stage_is_dtensor:
-            model_tensor = load_into_dtensor(checkpoint_tensor, stage_tensor)
-            # logger.info(f"DTensor: Loaded {param} into {model_tensor=}")
-            state_dict[param] = model_tensor
+        if isinstance(model_tensor, DTensor):
+            state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
             count_dtensors_loaded += 1
-
         else:
             # regular tensor, just update directly
-            checkpoint_tensor = checkpoint_tensor.to(device)
             state_dict[param] = checkpoint_tensor
 
         # ensure matching dtypes

torchchat/generate.py

Lines changed: 0 additions & 9 deletions
@@ -264,7 +264,6 @@ def __init__(
             """
         ))
         # fmt: on
-        # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
         self.system_prompt = generator_args.prompt
         self.tokenizer = _initialize_tokenizer(self.tokenizer_args)
 
@@ -503,7 +502,6 @@ def decode_n_tokens(
                 next_prob.clone() if next_prob is not None else None
             )
 
-        # return new_tokens, new_probs
 
     def model_forward(self, model, x, input_pos):
         return model(x, input_pos)
@@ -603,8 +601,6 @@ def generate(
         is_speculative = draft_model is not None
        device, dtype = prompt.device, prompt.dtype
 
-        # create an empty tensor of the expected final shape and
-        # fill in the current tokens
        if len(prompt.shape) > 1:
            prompt = prompt.squeeze(0)
        prompt_length = prompt.size(0)
@@ -633,11 +629,6 @@
         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()
 
-        # create an empty tensor of the expected final shape and
-        # fill in the current tokens
-        empty = torch.empty(max_seq_length, dtype=dtype, device=device)
-        empty[:prompt_length] = prompt
-
         input_pos = torch.arange(
             start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
