You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/design/plugin_system.md
+96-2Lines changed: 96 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -4,7 +4,7 @@ The community frequently requests the ability to extend vLLM with custom feature
4
4
5
5
## How Plugins Work in vLLM
6
6
7
-
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview](arch_overview.md)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work.
7
+
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview](arch_overview.md)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_plugins_by_group][vllm.plugins.load_plugins_by_group] function in the `vllm.plugins` module.
8
8
9
9
## How vLLM Discovers Plugins
10
10
@@ -57,6 +57,100 @@ Every plugin has three parts:
57
57
58
58
-**Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
59
59
60
+
### Platform plugins guidelines
61
+
62
+
1. Create a platform plugin project, for example, `vllm_add_dummy_platform`. The project structure should look like this:
63
+
64
+
```shell
65
+
vllm_add_dummy_platform/
66
+
├── vllm_add_dummy_platform/
67
+
│ ├── __init__.py
68
+
│ ├── my_dummy_platform.py
69
+
│ ├── my_dummy_worker.py
70
+
│ ├── my_dummy_attention.py
71
+
│ ├── my_dummy_device_communicator.py
72
+
│ ├── my_dummy_custom_ops.py
73
+
├── setup.py
74
+
```
75
+
76
+
2. In the `setup.py` file, add the following entry point:
3. Implement the platform class `MyDummyPlatform` in `my_dummy_platform.py`. The platform class should inherit from `vllm.platforms.interface.Platform`. Please follow the interface to implement the functions one by one. There are some important functions and properties that should be implemented at least:
99
+
100
+
- `_enum`: This property is the device enumeration from [PlatformEnum][vllm.platforms.interface.PlatformEnum]. Usually, it should be `PlatformEnum.OOT`, which means the platform is out-of-tree.
101
+
- `device_type`: This property should return the type of the device which pytorch uses. For example, `"cpu"`, `"cuda"`, etc.
102
+
- `device_name`: This property is set the same as `device_type` usually. It's mainly used for logging purposes.
103
+
- `check_and_update_config`: This functionis called very early in the vLLM's initialization process. It's used forplugins to update the vllm configuration. For example, the block size, graph mode config, etc, can be updatedin this function. The most important thing is that the **worker_cls** should be setin this functiontolet vLLM know which worker class to use for the worker process.
104
+
- `get_attn_backend_cls`: This functionshouldreturn the attention backend class's fully qualified name.
105
+
- `get_device_communicator_cls`: This function should return the device communicator class's fully qualified name.
106
+
107
+
4. Implement the worker class `MyDummyWorker`in`my_dummy_worker.py`. The worker class should inherit from [WorkerBase][vllm.v1.worker.worker_base.WorkerBase]. Please follow the interface to implement the functions one by one. Basically, all interfaces in the base class should be implemented, since they are called here and there in vLLM. To make sure a model can be executed, the basic functions should be implemented are:
108
+
109
+
- `init_device`: This functionis called to set up the device for the worker.
110
+
- `initialize_cache`: This functionis called to set cache config for the worker.
111
+
- `load_model`: This functionis called to load the model weights to device.
112
+
- `get_kv_cache_spaces`: This functionis called to generate the kv cache spaces for the model.
113
+
- `determine_available_memory`: This functionis called to profiles the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs.
114
+
- `initialize_from_config`: This functionis called to allocate device KV cache with the specified kv_cache_config
115
+
- `execute_model`: This functionis called every step to inference the model.
116
+
117
+
Additional functions that can be implemented are:
118
+
119
+
- If the plugin wants to support sleep mode feature, please implement the `sleep` and `wakeup` functions.
120
+
- If the plugin wants to support graph mode feature, please implement the `compile_or_warm_up_model` function.
121
+
- If the plugin wants to support speculative decoding feature, please implement the `take_draft_token_ids` function.
122
+
- If the plugin wants to support lora feature, please implement the `add_lora`,`remove_lora`,`list_loras` and `pin_lora` functions.
123
+
- If the plugin wants to support data parallelism feature, please implement the `execute_dummy_batch` functions.
124
+
125
+
Please look at the worker base class [WorkerBase][vllm.v1.worker.worker_base.WorkerBase] for more functions that can be implemented.
126
+
127
+
5. Implement the attention backend class `MyDummyAttention`in`my_dummy_attention.py`. The attention backend class should inherit from [AttentionBackend][vllm.attention.backends.abstract.AttentionBackend]. It's used to calculate attentions with your device. Take `vllm.v1.attention.backends` as examples, it contains many attention backend implementations.
128
+
129
+
6. Implement custom ops for high performance. Most ops can be ran by pytorch native implementation, while the performance may not be good. In this case, you can implement specific custom ops for your plugins. Currently, there are kinds of custom ops vLLM supports:
130
+
131
+
- pytorch ops
132
+
there are 3 kinds of pytorch ops:
133
+
134
+
- `communicator ops`: Device communicator op. Such as all-reduce, all-gather, etc.
135
+
Please implement the device communicator class `MyDummyDeviceCommunicator` in `my_dummy_device_communicator.py`. The device communicator class should inherit from [DeviceCommunicatorBase][vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase].
136
+
- `common ops`: Common ops. Such as matmul, softmax, etc.
137
+
Please implement the common ops by register oot way. See more detail in [CustomOp][vllm.model_executor.custom_op.CustomOp] class.
138
+
- `csrc ops`: C++ ops. This kind of ops are implemented in C++ and are registered as torch custom ops.
139
+
Following csrc module and `vllm._custom_ops` to implement your ops.
140
+
141
+
- triton ops
142
+
Custom way doesn't work for triton ops now.
143
+
144
+
7. (optional) Implement other plugable modules, such as lora, graph backend, quantization, mamba attention backend, etc.
145
+
60
146
## Compatibility Guarantee
61
147
62
-
vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
148
+
vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets.
149
+
150
+
The interface for the model/module may change during vLLM's development. If you see any deprecation log info, please upgrade your plugin to the latest version.
151
+
152
+
## Deprecation announcement
153
+
154
+
!!! warning "Deprecations"
155
+
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
156
+
- `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
0 commit comments