Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2e08c8f

Browse files
committed
Move generator base class into its own module so distributed can pick it up
1 parent e780f6c commit 2e08c8f

File tree

3 files changed

+407
-404
lines changed

3 files changed

+407
-404
lines changed

torchchat/distributed/generate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchchat.cli.builder import BuilderArgs, TokenizerArgs
2222
from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE
2323
from torchchat.distributed.logging_utils import SingletonLogger
24+
from torchchat.utils.generator import Generator, GeneratorArgs
2425

2526
logger = SingletonLogger.get_logger()
2627

@@ -194,19 +195,19 @@ def step(self) -> List[Output]:
194195
return outputs
195196

196197

197-
class DistributedGenerator(object):
198+
class DistributedGenerator(Generator):
198199
def __init__(
199200
self,
200201
# TODO: switch this to torchchat method
201202
model_name: str,
202203
builder_args: BuilderArgs,
203204
tokenizer_args: TokenizerArgs,
204-
# TODO: move GeneratorArgs into a different module
205-
generator_args,
205+
generator_args: GeneratorArgs,
206206
profile: Optional[Path],
207207
quantize: bool,
208208
draft_quantize: bool,
209209
):
210+
super().__init__(builder_args, tokenizer_args, generator_args)
210211
self.model_name = model_name
211212
self.builder_args = builder_args
212213
self.generate_args = generator_args

0 commit comments

Comments
 (0)