
Commit 21d93c1

Optimize Mixtral with expert parallelism (#2090)
1 parent f1c8520 commit 21d93c1

6 files changed: +221 −334 lines changed

Dockerfile
Lines changed: 1 addition & 13 deletions

@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads
 
 RUN python3 setup.py build_ext --inplace
 
-# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
-# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
-RUN apt-get install -y git && \
-    git clone https://github.com/stanford-futuredata/megablocks.git && \
-    cd megablocks && \
-    git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
-    MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel
-
 # image to run unit testing suite
 FROM dev AS test
 

@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate
 
-COPY vllm vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
-    rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
+COPY vllm vllm
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

README.md
Lines changed: 0 additions & 4 deletions

@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 ```bash
 pip install vllm
 ```
-**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
-```bash
-pip install megablocks
-```
 
 ## Getting Started

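Since the megablocks wheel is no longer built or installed, Mixtral runs on a plain `pip install vllm`. Below is a minimal offline-inference sketch; the model ID `mistralai/Mixtral-8x7B-Instruct-v0.1` and the two-GPU `tensor_parallel_size` are illustrative assumptions, not part of this commit.

```python
from vllm import LLM, SamplingParams

# Mixtral's experts are sharded across the tensor-parallel GPUs; 2 GPUs assumed here.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```
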
docs/source/models/supported_models.rst
Lines changed: 1 addition & 2 deletions

@@ -74,8 +74,7 @@ Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for in
 Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
 
 .. note::
-    Currently, the ROCm version of vLLM does not support Mixtral.
-    Additionally, it only supports Mistral for context lengths up to 4096.
+    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
 .. tip::
     The easiest way to check if your model is supported is to run the program below:

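The program the tip refers to is cut off in this diff view; a minimal sketch of such a support check, with `facebook/opt-125m` standing in as a placeholder for the model you want to verify:

```python
from vllm import LLM

# Instantiating LLM fails fast if the architecture is not registered in vLLM.
llm = LLM(model="facebook/opt-125m")  # swap in the model to check, e.g. a Mixtral checkpoint
output = llm.generate("Hello, my name is")
print(output)
```
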
vllm/config.py
Lines changed: 9 additions & 7 deletions

@@ -120,14 +120,16 @@ def _verify_load_format(self) -> None:
         if load_format == "auto":
             load_format = "pt"
 
-        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+        # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format != "pt":
-            logger.info(
-                "Currently, only 'pt' format is supported for Mixtral. "
-                "Changing the format to 'pt'. This may re-download the "
-                "weights if you have downloaded the safetensor weights.")
-            load_format = "pt"
+        if "MixtralForCausalLM" in architectures:
+            if load_format == "pt":
+                raise ValueError(
+                    "Currently, the 'pt' format is not supported for Mixtral. "
+                    "Please use the 'safetensors' format instead. ")
+            elif load_format == "auto":
+                # Do not fall back to pt weights.
+                load_format = "safetensors"
 
         self.load_format = load_format

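The effect of the new check, restated as a standalone sketch: 'pt' is now rejected for Mixtral, and 'auto' resolves to safetensors instead of silently falling back to pt. The helper name and free-function form here are illustrative, not vLLM's actual API.

```python
# Sketch of the Mixtral-specific branch added above (hypothetical helper).
def check_mixtral_load_format(load_format: str, architectures: list) -> str:
    if "MixtralForCausalLM" in architectures:
        if load_format == "pt":
            # Mixtral 'pt' checkpoints are rejected instead of being forced.
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        elif load_format == "auto":
            # Do not fall back to pt weights.
            load_format = "safetensors"
    return load_format


assert check_mixtral_load_format("auto", ["MixtralForCausalLM"]) == "safetensors"
assert check_mixtral_load_format("auto", ["LlamaForCausalLM"]) == "auto"
```
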
vllm/model_executor/models/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -39,13 +39,15 @@
 }
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+_ROCM_UNSUPPORTED_MODELS = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
 }

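A sketch of how these two tables might gate model resolution on ROCm: unsupported architectures are rejected outright, while partially supported ones only emit a warning. The `check_rocm_support` helper is hypothetical, for illustration only; with this change Mixtral moves from the first table to the second.

```python
import logging

logger = logging.getLogger(__name__)

# Copies of the tables from the diff above, for a self-contained example.
_ROCM_UNSUPPORTED_MODELS = []
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}


def check_rocm_support(architecture: str) -> None:
    """Hypothetical helper: reject unsupported architectures, warn on partial ones."""
    if architecture in _ROCM_UNSUPPORTED_MODELS:
        raise ValueError(
            f"Model architecture {architecture} is not supported by ROCm.")
    if architecture in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("%s is partially supported on ROCm: %s", architecture,
                       _ROCM_PARTIALLY_SUPPORTED_MODELS[architecture])


# After this commit, Mixtral warns on ROCm instead of being rejected.
check_rocm_support("MixtralForCausalLM")
```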