Skip to content

Commit 385cae5

Browse files
authored
Support VLM serving with MLX (#10)
1 parent 4105a45 commit 385cae5

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ python -m paroquant.cli.chat --model $MODEL
4747
python -m paroquant.cli.serve --model $MODEL --port 8000
4848
```
4949

50+
Add `--llm-only` if you do not wish to load the VLM components.
51+
5052
### Agent with Tool Calling
5153

5254
Start the API server first, then install the agent dependencies and run:

assets/model_card.jinja

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ python -m paroquant.cli.chat --model {{ paro_model_path }}
4949
python -m paroquant.cli.serve --model {{ paro_model_path }} --port 8000
5050
```
5151

52+
Add `--llm-only` if you do not wish to load the VLM components.
53+
5254
{% if supports_tool_call -%}
5355
### Agent with Tool Calling
5456

paroquant/cli/serve.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,73 @@ def _serve_vllm():
2020

2121

2222
def _serve_mlx():
    """Launch an OpenAI-compatible chat server on the MLX backend.

    Strips ``--model`` and ``--llm-only`` out of ``sys.argv`` (falling back to
    the ``MODEL`` environment variable for the model path), loads the model
    once through paroquant's own loader, then delegates to either
    ``mlx_vlm.server`` (vision-language model) or ``mlx_lm.server``
    (text-only), with the downstream server's ``load`` entry point patched to
    return the already-loaded model instead of loading again.

    Raises:
        ValueError: if ``--model`` is passed without a value, or no model
            path can be determined from argv or the environment.
    """
    import os
    import sys

    from paroquant.inference.backends.mlx.load import load as paro_load

    # Remove our own options from argv so the downstream server's argument
    # parser never sees them; every other argument is forwarded untouched.
    original_argv = list(sys.argv)
    model_arg = None
    llm_only = False
    stripped_argv = [original_argv[0]]
    i = 1
    while i < len(original_argv):
        arg = original_argv[i]
        if arg == "--model":
            if i + 1 >= len(original_argv):
                raise ValueError("--model expects a value")
            model_arg = original_argv[i + 1]
            i += 2
            continue
        if arg.startswith("--model="):
            model_arg = arg.split("=", 1)[1]
            i += 1
            continue
        if arg == "--llm-only":
            llm_only = True
            i += 1
            continue
        stripped_argv.append(arg)
        i += 1

    # CLI flag wins; the MODEL environment variable is the fallback.
    if not model_arg:
        model_arg = os.environ.get("MODEL")
    if not model_arg:
        raise ValueError("Model path is required (use --model or MODEL environment variable).")

    # Load up front; the downstream loader is patched below so the model is
    # never loaded a second time.
    model, processor, is_vlm = paro_load(model_arg, force_text=llm_only)

    # Hand the stripped argv to whichever server we delegate to (previously
    # duplicated in both branches below).
    sys.argv = stripped_argv

    if is_vlm:
        import mlx_vlm.server as mlx_server

        # Presumably read by mlx_vlm's server to locate the model — the
        # patched load below ignores it anyway. TODO(review): confirm.
        os.environ["MODEL"] = model_arg

        def _patched_load(path_or_hf_repo, *args, **kwargs):
            # Ignore the requested path: always serve the preloaded pair.
            return model, processor

        # Force uvicorn's reload off: a reload would re-import the server
        # module and drop our patched loader. NOTE(review): assumes
        # mlx_vlm.server starts uvicorn with reload enabled — confirm.
        _uvicorn_run = mlx_server.uvicorn.run

        def _run_no_reload(*args, **kwargs):
            kwargs["reload"] = False
            return _uvicorn_run(*args, **kwargs)

        mlx_server.uvicorn.run = _run_no_reload
    else:
        import mlx_lm.server as mlx_server

        # The processor may wrap the tokenizer or be the tokenizer itself.
        tokenizer = getattr(processor, "tokenizer", processor)
        # Clear mlx_lm's tool-call markers when present — presumably so the
        # server does not do its own tool-call parsing. TODO(review): confirm.
        if hasattr(tokenizer, "_tool_call_start"):
            tokenizer._tool_call_start = None
        if hasattr(tokenizer, "_tool_call_end"):
            tokenizer._tool_call_end = None

        def _patched_load(path_or_hf_repo, tokenizer_config=None, adapter_path=None, **kwargs):
            # Signature mirrors mlx_lm.server.load; arguments are ignored and
            # the preloaded pair returned.
            return model, tokenizer

    mlx_server.load = _patched_load
    mlx_server.main()
3790

3891

3992
def main():

paroquant/inference/backends/mlx/load.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def load(model_path: str, lazy: bool = False, force_text: bool = False) -> tuple
176176
weights = _convert_autoawq(weights, group_size)
177177
if hasattr(model, "sanitize"):
178178
weights = model.sanitize(weights)
179+
if is_vlm and hasattr(model, "vision_tower") and hasattr(model.vision_tower, "sanitize"):
180+
weights = model.vision_tower.sanitize(weights)
179181

180182
_patch_rotation_layers(model, weights, bits, group_size)
181183
model.load_weights(list(weights.items()), strict=False)

0 commit comments

Comments
 (0)