8 changes: 3 additions & 5 deletions torchchat/cli/cli.py
@@ -207,7 +207,7 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
exclusive_parser.add_argument(
exclusive_parser.add_argument(
"--output-snapshot-path",
type=str,
default=None,
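Note: the hunk above adds --output-snapshot-path to the same mutually exclusive group as the other export outputs. A minimal, illustrative sketch of that argparse pattern follows; the argument names and the .dso help text are copied from the diff, while the group construction and snapshot help string are assumptions, not the repo's exact code.

    import argparse

    parser = argparse.ArgumentParser()
    exclusive_parser = parser.add_mutually_exclusive_group()
    exclusive_parser.add_argument(
        "--output-dso-path",
        type=str,
        default=None,
        help="Output to the specified AOT Inductor .dso model file",
    )
    exclusive_parser.add_argument(
        "--output-snapshot-path",
        type=str,
        default=None,
        help="Output to the specified torchchat snapshot .tc model file",  # hypothetical help text
    )
    # argparse rejects invocations that pass more than one of these output paths
    print(parser.parse_args(["--output-snapshot-path", "model.tc"]))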
@@ -266,7 +266,7 @@ def _add_exported_input_path_args(parser) -> None:
default=None,
help="Use the specified torchchat snaphot .tc model file",
)


# Add CLI Args related to JIT downloading of model artifacts
def _add_jit_downloading_args(parser) -> None:
@@ -582,10 +582,8 @@ def arg_init(args):
if "mps" in args.device:
if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
print(
"Warning: compilation is not available with device MPS, ignoring option to engage compilation"
"Warning: STOP. Compilation on MPS is experimental! Don't use it yet!"
)
vars(args)["compile"] = False
vars(args)["compile_prefill"] = False

if hasattr(args, "seed") and args.seed:
# Localized import to minimize expensive imports
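Note: the seed branch is truncated by the diff view. A plausible continuation is sketched below, with the localized import the comment refers to; torch.manual_seed is an assumption about what the elided body does, not a quote of it.

    def maybe_seed(args) -> None:
        if hasattr(args, "seed") and args.seed:
            # Localized import to minimize expensive imports
            import torch
            torch.manual_seed(args.seed)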
16 changes: 8 additions & 8 deletions torchchat/generate.py
@@ -321,7 +321,7 @@ def __init__(
draft_quantize: bool,
):
torch._inductor.config.coordinate_descent_tuning = (
builder_args.device != "cpu"
builder_args.device not in ["cpu", "mps"]
)
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
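Note: the change above keeps coordinate-descent tuning off for both cpu and mps devices. A standalone sketch of the same device-dependent Inductor setup is below; the helper name is illustrative, while the config attributes are the ones already used in this hunk.

    import torch

    def configure_inductor(device: str) -> None:
        # Coordinate-descent tuning only pays off on CUDA-like backends,
        # so skip it for cpu and mps
        torch._inductor.config.coordinate_descent_tuning = device not in ["cpu", "mps"]
        torch._inductor.config.triton.unique_kernel_names = True
        # Experimental cache that reduces recompilation time
        torch._inductor.config.fx_graph_cache = True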
@@ -1315,7 +1315,7 @@ def __init__(
quantize: bool,
draft_quantize: bool,
):

is_speculative = speculative_builder_args.checkpoint_path is not None
assert is_speculative == False, "Distributed inference with pp > 1 does not support speculative inference yet."
super().__init__(
@@ -1336,7 +1336,7 @@ def distributed_input(prompt: str) -> str:
text = [input(prompt)]
else:
text = [None]

dist.broadcast_object_list(text)
return text[0]

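Note: distributed_input follows the usual broadcast_object_list idiom, where only rank 0 reads stdin and the resulting string is shared with every other rank. A self-contained sketch of that pattern is below; the explicit rank parameter and function name are illustrative, not the file's exact shape.

    import torch.distributed as dist

    def broadcast_prompt(prompt: str, rank: int) -> str:
        # Only rank 0 touches stdin; every rank then receives the same string
        text = [input(prompt)] if rank == 0 else [None]
        dist.broadcast_object_list(text, src=0)
        return text[0]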
@@ -1491,7 +1491,7 @@ def prefill(
# TODO: we need to pass `input_pos` and `cache_lane` to each stage.
lane = 0
kwargs = {"input_pos": input_pos, "cache_lane": lane}

if self.pp_rank == self.first_pp_rank:
logits = self.prefiller.step(padded_seq, **kwargs)
elif self.pp_rank == self.last_pp_rank:
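Note: the prefill path dispatches on the stage's position in the pipeline. An illustrative sketch of that shape follows, assuming a schedule object whose step() consumes tokens on the first stage and returns logits on the last; the argument handling is a guess, not the repo's exact calls.

    def prefill_dispatch(schedule, pp_rank, first_pp_rank, last_pp_rank, padded_seq, **kwargs):
        # First stage feeds the prompt tokens into the pipeline
        if pp_rank == first_pp_rank:
            return schedule.step(padded_seq, **kwargs)
        # Last stage is the only one that produces logits
        elif pp_rank == last_pp_rank:
            return schedule.step(**kwargs)
        # Middle stages just forward activations
        schedule.step(**kwargs)
        return None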
@@ -1592,7 +1592,7 @@ def sample(
return (idx_next, None)
probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
idx_next = self.multinomial_sample_one_no_sync(probs)

return idx_next, probs


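Note: sample() relies on multinomial_sample_one_no_sync and logits_to_probs, which in gpt-fast-derived generators use the exponential-noise argmax trick to draw a token without forcing a device sync. A self-contained sketch of that recipe is below, written from the gpt-fast pattern rather than copied from this file.

    import torch

    def multinomial_sample_one_no_sync(probs: torch.Tensor) -> torch.Tensor:
        # argmax(probs / Exp(1) noise) samples from the categorical distribution
        # without the sync that torch.multinomial would incur
        q = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int64)

    def logits_to_probs(logits: torch.Tensor, temperature: float = 1.0, top_k=None) -> torch.Tensor:
        logits = logits / max(temperature, 1e-5)
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Mask out everything below the k-th largest logit
            logits = logits.masked_fill(logits < v[..., -1, None], -float("inf"))
        return torch.nn.functional.softmax(logits, dim=-1)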
@@ -1601,12 +1601,12 @@ def run_generator(
rank: Optional[int] =None
):
"""
This function creates and executes a generator
This function creates and executes a generator
"""
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
#Setup rank 1 and up to suppress log messages and print messages
if builder_args.distributed and rank != 0:
logger.setLevel(logging.CRITICAL)
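Note: ranks other than 0 raise their log level (and, in the elided lines, likely silence print output as well). A hypothetical standalone helper showing the same idea is below; the builtins.print replacement is an assumption, not necessarily what the elided code does.

    import builtins
    import logging

    def quiet_worker_ranks(rank: int, logger: logging.Logger) -> None:
        # Keep console output readable by letting only rank 0 log and print
        if rank != 0:
            logger.setLevel(logging.CRITICAL)
            builtins.print = lambda *args, **kwargs: None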
@@ -1636,7 +1636,7 @@ def run_generator(

def main(args):
builder_args = BuilderArgs.from_args(args)

if builder_args.distributed:
world_size = builder_args.tp * builder_args.pp

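Note: in the distributed path the process count is world_size = tp * pp; the rest of main() is elided by the diff view. A hypothetical launcher sketch follows, showing how run_generator (defined earlier in this file) could be spawned once per rank with torch.multiprocessing; mp.spawn passes the rank as the first argument, hence the small adapter.

    import torch.multiprocessing as mp

    def _entry(rank, args):
        # mp.spawn calls _entry(rank, *spawn_args); adapt to run_generator(args, rank)
        run_generator(args, rank=rank)

    def launch_distributed(args, world_size: int) -> None:
        mp.spawn(_entry, args=(args,), nprocs=world_size, join=True)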