Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3ef1296

Browse files
committed
Address some nits from PR review
1 parent 76895cc commit 3ef1296

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

torchchat/usages/openai_api.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dataclasses import dataclass
1414
from io import BytesIO
1515
from pwd import getpwuid
16-
from typing import Any, Dict, List, Optional, Union
16+
from typing import Any, Dict, List, Optional, Union, Type
1717

1818
import torch
1919

@@ -486,18 +486,13 @@ def _callback(self, x, *, buffer, done_generating):
486486
pass
487487

488488

489-
def create_openai_api_generator(distributed):
489+
def create_openai_api_generator(distributed: bool) -> Type:
490490
"""
491491
Factory method to create an OpenAiApiGenerator
492492
"""
493493

494-
if distributed:
495-
# Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator
496-
return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator), {})
497-
else:
498-
return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, LocalGenerator), {})
499-
500-
494+
# Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator
495+
return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator if distributed else LocalGenerator), {})
501496

502497

503498
"""

torchchat/usages/server.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,11 @@ def create_app(args): # noqa: C901
6767

6868
builder_args = BuilderArgs.from_args(args)
6969
procs = []
70+
queue = None
7071
if builder_args.distributed:
7172
world_size = builder_args.tp * builder_args.pp
7273
mp_context = mp.get_context('spawn')
7374
queue = mp_context.Queue()
74-
else:
75-
world_size = 1
76-
queue = None
77-
7875

7976
if builder_args.distributed:
8077
for i in range(1, world_size):

0 commit comments

Comments
 (0)