
Commit 3718637

feat: support embedding models for offline inference. (jd-opensource#318)

Signed-off-by: pengtao.156 <[email protected]>

1 parent 9bb0836

File tree: 10 files changed (+183, -9 lines)

examples/generate_embedding.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+# python examples/generate_embedding.py --model='/path/models/Qwen3-8B' --devices='npu:0'
+# python generate_embedding.py --model='/path/models/Qwen3-8B' --devices='npu:0,npu:1'
+
+from xllm import ArgumentParser, Embedding, RequestParams
+
+# Create an Embedding instance.
+parser = ArgumentParser()
+emb = Embedding(**vars(parser.parse_args()))
+
+# Create request params, including sampling params.
+request_params = RequestParams()
+request_params.is_embeddings = True
+request_params.max_tokens = 1
+
+inputs = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+outputs = emb.embedding(inputs, request_params, True)
+
+# Print the outputs.
+for i, output in enumerate(outputs):
+    input_str = output.prompt
+    generated_embedding = output.outputs[0].embeddings
+    print(f"Input: {input_str!r}, Generated embedding: {generated_embedding!r}")
+
+emb.finish()
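
The script above prints raw vectors. For a quick sanity check they can be compared in plain Python; a minimal sketch, assuming each output.outputs[0].embeddings behaves like a flat list of floats (the commit does not pin down the exact return type):

import math

def cosine_similarity(a, b):
    # Plain-Python cosine similarity: dot(a, b) / (|a| * |b|).
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# e.g. similarity between the first two prompts:
# sim = cosine_similarity(outputs[0].outputs[0].embeddings,
#                         outputs[1].outputs[0].embeddings)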

setup.py

Lines changed: 2 additions & 1 deletion

@@ -611,7 +611,8 @@ def apply_patch():
     zip_safe=False,
     py_modules=["xllm/launch_xllm", "xllm/__init__",
                 "xllm/pybind/llm", "xllm/pybind/vlm",
-                "xllm/pybind/util", "xllm/pybind/args"],
+                "xllm/pybind/embedding", "xllm/pybind/util",
+                "xllm/pybind/args"],
     entry_points={
         'console_scripts': [
             'xllm = xllm.launch_xllm:launch_xllm'

xllm/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
 spec = importlib.util.spec_from_file_location("xllm_export", export_so_path)
 xllm_export = importlib.util.module_from_spec(spec)
 
+from xllm.pybind.embedding import Embedding
 from xllm.pybind.llm import LLM
 from xllm.pybind.vlm import VLM
 from xllm.pybind.args import ArgumentParser
@@ -21,6 +22,7 @@
 
 __all__ = [
     "ArgumentParser",
+    "Embedding",
     "LLM",
     "LLMMaster",
     "VLM",

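With this re-export in place, the new class is importable from the package root, which is what the example script above relies on. A trivial check, assuming the pybind modules from this commit are built and installed:

# Top-level import enabled by this change (mirrors the example script).
from xllm import ArgumentParser, Embedding
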
xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp

Lines changed: 4 additions & 2 deletions

@@ -40,7 +40,8 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
                                      int num_decoding_tokens,
                                      int block_size,
                                      bool enable_shm,
-                                     bool is_local) {
+                                     bool is_local,
+                                     const std::string& task_type) {
   // TODO: pass whole xllm::runtime::Options here from main process.
   xllm::runtime::Options runner_options;
   runner_options.block_size(block_size)
@@ -49,7 +50,8 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
       .enable_offline_inference(true)
       .master_node_addr(master_node_addr)
       .enable_shm(enable_shm)
-      .is_local(is_local);
+      .is_local(is_local)
+      .task_type(task_type);
   FLAGS_enable_schedule_overlap = false;
   FLAGS_master_node_addr = master_node_addr;
   FLAGS_block_size = block_size;

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h

Lines changed: 2 additions & 1 deletion

@@ -29,7 +29,8 @@ class SpawnWorkerServer final {
                     int num_decoding_tokens,
                     int block_size,
                     bool enable_shm,
-                    bool is_local);
+                    bool is_local,
+                    const std::string& task_type);
 
   ~SpawnWorkerServer() = default;
 
xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp

Lines changed: 9 additions & 5 deletions

@@ -30,10 +30,11 @@ limitations under the License.
 // @block_size
 // @enable_shm
 // @is_local
+// @task_type
 int main(int argc, char* argv[]) {
-  if (argc < 9) {
+  if (argc < 10) {
     LOG(ERROR)
-        << "Spwan worker process receive wrong args. Need 9 args, receive "
+        << "Spwan worker process receive wrong args. Need 10 args, receive "
         << argc;
     return 1;
   }
@@ -54,16 +55,18 @@ int main(int argc, char* argv[]) {
   int block_size = atoi(argv[7]);
   int enable_shm = atoi(argv[8]);
   int is_local = atoi(argv[9]);
+  std::string task_type = std::string(argv[10]);
 
   LOG(INFO) << "Spwan worker: "
             << "master_node_addr = " << master_node_addr
-            << ", is_local = " << is_local << ", local_rank = " << local_rank
+            << ", local_rank = " << local_rank
             << ", world_size = " << world_size
             << ", device_idx = " << device_idx
             << ", num_decoding_tokens = " << num_decoding_tokens
             << ", block_size = " << block_size
             << ", enable_shm = " << (enable_shm > 0)
-            << ", enable_shm = " << (is_local > 0) << "\n";
+            << ", is_local = " << (is_local > 0)
+            << ", task_type = " << task_type << "\n";
 
   xllm::SpawnWorkerServer worker(master_node_addr,
                                  local_rank,
@@ -73,7 +76,8 @@ int main(int argc, char* argv[]) {
                                  num_decoding_tokens,
                                  block_size,
                                  enable_shm > 0,
-                                 is_local > 0);
+                                 is_local > 0,
+                                 task_type);
 
   worker.run();

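The positional argv protocol above is easy to misread, so here is a small Python mirror of how the tail of the worker argv is parsed after this change. Illustration only: the helper name is made up, and argv[1] through argv[6] are omitted because they fall outside this hunk.

# Mirrors the C++ parsing in spawn_worker_server_process.cpp (sketch).
def parse_worker_argv_tail(argv):
    return {
        "block_size": int(argv[7]),    # atoi(argv[7])
        "enable_shm": int(argv[8]) > 0,
        "is_local": int(argv[9]) > 0,
        "task_type": argv[10],         # new in this commit, e.g. "embed"
    }
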
xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 1 addition & 0 deletions

@@ -156,6 +156,7 @@ void WorkerServer::create_spawn_server(int local_rank,
                            block_size_ptr,
                            enable_shm_ptr,
                            is_local_ptr,
+                           options.task_type().c_str(),
                            nullptr};
   pid_t pid;
   posix_spawn_file_actions_init(&file_actions_);

xllm/pybind/CMakeLists.txt

Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,9 @@ pybind_extension(
     gflags::gflags
     glog::glog
     Python::Module
+    torch_python
+    torch
+    c10
 )
 target_link_libraries(common PRIVATE leveldb::leveldb ZLIB::ZLIB OpenSSL::SSL OpenSSL::Crypto protobuf::libprotobuf)
 add_dependencies(common brpc-static)

xllm/pybind/bind.cpp

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ limitations under the License.
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <pybind11/stl_bind.h>
+#include <torch/python.h>
 
 #include "api_service/call.h"
 #include "core/common/options.h"

xllm/pybind/embedding.py

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+import os
+import signal
+import time
+from . import util
+from typing import List, Optional, Union
+
+from xllm_export import (LLMMaster, Options, RequestOutput,
+                         RequestParams)
+
+class Embedding:
+    def __init__(
+        self,
+        model: str,
+        devices: str = 'auto',
+        block_size: int = 128,
+        max_cache_size: int = 0,
+        max_memory_utilization: float = 0.9,
+        disable_prefix_cache: bool = False,
+        max_tokens_per_batch: int = 20000,
+        max_seqs_per_batch: int = 256,
+        max_tokens_per_chunk_for_prefill: int = 512,
+        num_request_handling_threads: int = 4,
+        communication_backend: str = 'lccl',
+        rank_tablefile: str = '',
+        expert_parallel_degree: int = 0,
+        enable_mla: bool = False,
+        disable_chunked_prefill: bool = False,
+        instance_role: str = 'DEFAULT',
+        nnodes: int = 1,
+        node_rank: int = 0,
+        dp_size: int = 1,
+        ep_size: int = 1,
+        enable_shm: bool = False,
+        is_local: bool = True,
+        **kwargs,
+    ) -> None:
+        if not os.path.exists(model):
+            raise ValueError(f"model {model} not exists")
+
+        options = Options()
+        options.model_path = model
+        options.task_type = "embed"
+        options.devices = devices
+        options.draft_model_path = None
+        options.draft_devices = None
+        options.block_size = block_size
+        options.max_cache_size = max_cache_size
+        options.max_memory_utilization = max_memory_utilization
+        if disable_prefix_cache:
+            options.enable_prefix_cache = False
+        else:
+            options.enable_prefix_cache = True
+        options.max_tokens_per_batch = max_tokens_per_batch
+        options.max_seqs_per_batch = max_seqs_per_batch
+        options.max_tokens_per_chunk_for_prefill = max_tokens_per_chunk_for_prefill
+        options.num_request_handling_threads = num_request_handling_threads
+        options.communication_backend = communication_backend
+        options.rank_tablefile = rank_tablefile
+        options.expert_parallel_degree = expert_parallel_degree
+        options.enable_mla = enable_mla
+        if disable_chunked_prefill:
+            options.enable_chunked_prefill = False
+        else:
+            options.enable_chunked_prefill = True
+        free_port = util.get_free_port()
+        options.master_node_addr = "127.0.0.1:" + str(free_port)
+        options.nnodes = nnodes
+        options.node_rank = node_rank
+        options.dp_size = dp_size
+        options.ep_size = ep_size
+        options.enable_disagg_pd = False
+        options.enable_schedule_overlap = False
+        options.enable_offline_inference = True
+        options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+        options.enable_shm = enable_shm
+        options.is_local = is_local
+        self.master = LLMMaster(options)
+
+    def finish(self):
+        try:
+            #os.kill(os.getpid(), signal.SIGTERM)
+            os.kill(os.getpid(), signal.SIGKILL)
+        except Exception as e:
+            pass
+
+    def embedding(
+        self,
+        inputs: Union[str, List[str]],
+        request_params: Optional[Union[RequestParams, List[RequestParams]]] = None,
+        wait_schedule_done: bool = True,
+    ) -> List[RequestOutput]:
+        if request_params is None:
+            request_params = RequestParams()
+        if isinstance(inputs, str):
+            inputs = [inputs]
+        if isinstance(request_params, RequestParams):
+            request_params.is_embeddings = True
+            request_params = [request_params]
+        else:
+            for i in range(len(request_params)):
+                request_params[i].is_embeddings = True
+
+        outputs = [None] * len(inputs)
+        def callback(index: int, output: RequestOutput) -> bool:
+            outputs[index] = output
+            return True
+
+        # schedule all requests
+        self.master.handle_batch_request(
+            inputs, request_params, callback
+        )
+
+        # TODO: add wait later
+        if wait_schedule_done:
+            pass
+
+        # generate
+        self.master.generate()
+
+        # wait async output
+        for i in range(len(outputs)):
+            while outputs[i] is None:
+                time.sleep(0.01)
+            if outputs[i].status is not None and not outputs[i].status.ok:
+                raise ValidationError(outputs[i].status.code, outputs[i].status.message)
+            outputs[i].prompt = inputs[i]
+
+        return outputs
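
For reference, a minimal sketch of driving this class directly; the model path and device string are placeholders taken from the example script, not guaranteed defaults:

from xllm import Embedding

# Placeholder model path and device; adjust to your environment.
emb = Embedding(model="/path/models/Qwen3-8B", devices="npu:0")

# A single string is normalized to a one-element batch, and
# is_embeddings is forced to True on the request params.
out = emb.embedding("The capital of France is")
print(out[0].outputs[0].embeddings)

# Note: finish() terminates the current process via SIGKILL.
emb.finish()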
