Commit 0b6544e

feat: add vlm python interface to support offline inference. (jd-opensource#236)
Signed-off-by: pengtao.156 <[email protected]>
1 parent a9e34b1 commit 0b6544e

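In brief: this commit adds an offline-inference Python interface for vision-language models. It exposes a VLM class plus MMChatMessage/MMInputData bindings in the xllm package, adds the matching C++ structs and MMInputHelper/handler overloads, ships example scripts for both LLM and VLM usage, and reworks LLMMaster's pending-request accounting so a request counts as pending until its output is delivered rather than until it is scheduled.
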
File tree

14 files changed: +412 −30 lines changed

examples/__init__.py

Whitespace-only changes.

examples/generate.py

Lines changed: 33 additions & 0 deletions

# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0'

from xllm import ArgumentParser, LLM, RequestParams

# Create an LLM.
parser = ArgumentParser()
llm = LLM(**vars(parser.parse_args()))

# Create request params, including sampling params.
request_params = RequestParams()
request_params.temperature = 0.8
request_params.top_p = 0.95
request_params.max_tokens = 10

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

outputs = llm.generate(prompts, request_params, True)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

llm.finish()

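The flow above folds naturally into a small helper. The sketch below is hypothetical and uses only the calls shown in this example (it assumes the same imports); the trailing boolean passed to generate() is copied verbatim from the example, since its semantics are not documented in this diff.

# Hypothetical wrapper around the API shown above; run_prompts and its
# defaults are illustrative, not part of xllm.
def run_prompts(llm, prompts, temperature=0.8, top_p=0.95, max_tokens=10):
    params = RequestParams()
    params.temperature = temperature
    params.top_p = top_p
    params.max_tokens = max_tokens
    # The final boolean is passed through as in the example; its exact
    # meaning is not defined anywhere in this diff.
    outputs = llm.generate(prompts, params, True)
    return [output.outputs[0].text for output in outputs]
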
examples/generate_vlm.py

Lines changed: 35 additions & 0 deletions

# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0' --master_node_addr=127.0.0.1:8888

import os
import signal

from xllm import ArgumentParser, VLM, RequestParams, MMChatMessage, MMInputData

# Create a VLM.
parser = ArgumentParser()
vlm = VLM(**vars(parser.parse_args()))

# Create request params, including sampling params.
request_params = RequestParams()
request_params.temperature = 0.8
request_params.top_p = 0.95
request_params.max_tokens = 100

# Build the multimodal input: one text part and one image part.
mm_input_data1 = MMInputData()
mm_input_data1.type = 'text'
mm_input_data1.text = 'Please briefly introduce this picture.'
mm_input_data2 = MMInputData()
mm_input_data2.type = 'image_url'
mm_input_data2.image_url = 'https://img2.baidu.com/it/u=2376489989,3127732063&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=657'
mm_chat_msg = MMChatMessage()
mm_chat_msg.role = 'user'
mm_chat_msg.content = [mm_input_data1, mm_input_data2]

output = vlm.generate(mm_chat_msg, request_params, True)

prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vlm.finish()

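The MMInputData struct added in types.h below also carries video_url and audio_url fields. Assuming the Python binding exposes them the same way it exposes image_url (only the C++ side appears in this diff, and no video or audio handler is shown), a video request would be assembled the same way:

# Hypothetical: mirrors the image request above using the video_url field
# declared on the C++ MMInputData struct. Whether a handler is registered
# for the 'video_url' type is not shown in this diff.
text_part = MMInputData()
text_part.type = 'text'
text_part.text = 'Describe what happens in this clip.'

video_part = MMInputData()
video_part.type = 'video_url'
video_part.video_url = 'https://example.com/clip.mp4'

msg = MMChatMessage()
msg.role = 'user'
msg.content = [text_part, video_part]
output = vlm.generate(msg, request_params, True)
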
setup.py

Lines changed: 1 addition & 1 deletion

@@ -593,7 +593,7 @@ def apply_patch():
       },
       zip_safe=False,
       py_modules=["xllm/launch_xllm", "xllm/__init__",
-                  "xllm/pybind/llm", "xllm/pybind/args"],
+                  "xllm/pybind/llm", "xllm/pybind/vlm", "xllm/pybind/args"],
       entry_points={
           'console_scripts': [
               'xllm = xllm.launch_xllm:launch_xllm'

xllm/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -14,19 +14,21 @@
     xllm_export = importlib.util.module_from_spec(spec)
 
 from xllm.pybind.llm import LLM
+from xllm.pybind.vlm import VLM
 from xllm.pybind.args import ArgumentParser
 from xllm_export import (LLMMaster, Options, RequestParams, RequestOutput,
-                         SequenceOutput, Status, StatusCode)
+                         SequenceOutput, Status, StatusCode, MMChatMessage, MMInputData)
 
 __all__ = [
     "ArgumentParser",
     "LLM",
     "LLMMaster",
+    "VLM",
+    "VLMMaster",
     "Options",
     "RequestParams",
     "RequestOutput",
     "SequenceOutput",
     "Status",
     "StatusCode",
 ]

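With these exports in place, the VLM example above needs only: from xllm import ArgumentParser, VLM, RequestParams, MMChatMessage, MMInputData. Note that VLMMaster is listed in __all__ without a matching import in this hunk, so it presumably comes from another module.
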
xllm/core/common/types.h

Lines changed: 13 additions & 0 deletions

@@ -274,4 +274,17 @@ struct EplbInfo {
   int32_t update_layer_id = -1;
 };
 
+struct MMInputData {
+  std::string type = "";
+  std::string text = "";
+  std::string image_url = "";
+  std::string video_url = "";
+  std::string audio_url = "";
+};
+
+struct MMChatMessage {
+  std::string role = "";
+  std::vector<MMInputData> content;
+};
+
 }  // namespace xllm

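These structs mirror the OpenAI-style chat payload: a message is a role plus a list of typed content parts, and for each part only the field matching its type string ('text', 'image_url', and so on) is meaningful. They are plain-C++ counterparts of the existing proto::MMInputData messages, which appears to let the pybind layer hand Python objects straight to MMInputHelper without a protobuf round trip.
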
xllm/core/framework/request/mm_input_helper.cpp

Lines changed: 96 additions & 0 deletions

@@ -147,7 +147,22 @@ class Handler {
     return true;
   }
 
+  bool process(const MMInputData& msg, MMInputItem& input) {
+    if (!this->load(msg, input)) {
+      LOG(ERROR) << " load mm data failed";
+      return false;
+    }
+
+    if (!this->decode(input)) {
+      LOG(ERROR) << " decode mm data failed";
+      return false;
+    }
+
+    return true;
+  }
+
   virtual bool load(const proto::MMInputData& msg, MMInputItem& input) = 0;
+  virtual bool load(const MMInputData& msg, MMInputItem& input) = 0;
   virtual bool decode(MMInputItem& input) = 0;
 
  protected:
@@ -197,6 +212,23 @@ class ImageHandler : public Handler {
     }
   }
 
+  virtual bool load(const MMInputData& msg, MMInputItem& input) {
+    input.clear();
+
+    const auto& url = msg.image_url;
+    if (url.compare(0, dataurl_prefix_.size(), dataurl_prefix_) == 0) {
+      // data url
+      input.type_ = MMType::IMAGE;
+      return this->load_from_dataurl(url, input.raw_data_);
+    } else if (url.compare(0, httpurl_prefix_.size(), httpurl_prefix_) == 0) {
+      // http url
+      input.type_ = MMType::IMAGE;
+      return this->load_from_http(url, input.raw_data_);
+    }
+    return false;  // neither a data URL nor an http(s) URL
+  }
+
   virtual bool decode(MMInputItem& input) {
     OpenCVImageDecoder decoder;
     return decoder.decode(input.raw_data_, input.decode_data_);
@@ -223,6 +255,18 @@ class MMHandlerSet {
     return handler->process(msg, input);
   }
 
+  bool process(const std::string& type,
+               const MMInputData& msg,
+               MMInputItem& input) {
+    auto itor = handlers_.find(type);
+    if (itor == handlers_.end()) {
+      return false;
+    }
+
+    auto& handler = itor->second;
+    return handler->process(msg, input);
+  }
+
  private:
   std::unordered_map<std::string, std::unique_ptr<Handler>> handlers_;
 };
@@ -259,6 +303,32 @@ bool MMInputHelper::trans(const MMChatMessageVec& vec,
   return true;
 }
 
+bool MMInputHelper::trans(const std::vector<MMChatMessage>& raw_input_data,
+                          std::vector<Message>& messages,
+                          MMInputItemVec& inputs) {
+  messages.clear();
+  inputs.clear();
+  messages.reserve(raw_input_data.size());
+  inputs.reserve(raw_input_data.size());
+
+  for (size_t idx = 0; idx < raw_input_data.size(); ++idx) {
+    const auto& chat = raw_input_data[idx];
+    const auto& role = chat.role;
+    const auto& content = chat.content;
+
+    Message::MMContentVec mmc;
+    MMInputItemVec ins;
+    if (!this->trans(content, mmc, ins)) {
+      return false;
+    }
+
+    messages.emplace_back(role, mmc);
+    inputs.insert(inputs.end(), ins.begin(), ins.end());
+  }
+
+  return true;
+}
+
 bool MMInputHelper::trans(const MMInputDataVec& vec,
                           Message::MMContentVec& mmc,
                           MMInputItemVec& inputs) {
@@ -285,4 +355,30 @@ bool MMInputHelper::trans(const MMInputDataVec& vec,
   return true;
 }
 
+bool MMInputHelper::trans(const std::vector<MMInputData>& vec,
+                          Message::MMContentVec& mmc,
+                          MMInputItemVec& inputs) {
+  mmc.clear();
+  inputs.clear();
+
+  for (size_t idx = 0; idx < vec.size(); ++idx) {
+    const auto& item = vec[idx];
+    const auto& type = item.type;
+
+    if (type == "text") {
+      mmc.emplace_back(type, item.text);
+    } else {
+      MMInputItem input;
+      if (!mm_handlers_->process(type, item, input)) {
+        return false;
+      }
+
+      mmc.emplace_back(type);
+      inputs.emplace_back(input);
+    }
+  }
+
+  return true;
+}
+
 }  // namespace xllm

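The new overloads retrace the existing proto-based path: trans() walks each MMChatMessage, copies 'text' parts directly into the message content, and routes every other part to the handler registered for its type string; ImageHandler then loads raw bytes from either a data: URL or an http(s) URL and decodes them with OpenCV. A toy Python rendering of that dispatch shape (the handler body is a stand-in, not the real loader):

# Illustrative sketch of the MMHandlerSet load-then-decode dispatch,
# restated in Python. handle_image stands in for ImageHandler.
from typing import Callable, Dict, Optional

def handle_image(url: str) -> Optional[str]:
    # Stand-in for ImageHandler::load (data/http URL) plus OpenCV decode.
    if url.startswith('data:') or url.startswith('http'):
        return 'decoded image bytes from ' + url
    return None  # mirrors the unrecognized-scheme failure path

HANDLERS: Dict[str, Callable[[str], Optional[str]]] = {
    'image_url': handle_image,
}

def process(part_type: str, url: str) -> Optional[str]:
    handler = HANDLERS.get(part_type)
    return handler(url) if handler else None  # unknown type fails, as in C++
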
xllm/core/framework/request/mm_input_helper.h

Lines changed: 9 additions & 0 deletions

@@ -18,6 +18,7 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "core/common/types.h"
 #include "core/framework/chat_template/jinja_chat_template.h"
 #include "mm_data.h"
 #include "multimodal.pb.h"
@@ -69,11 +70,19 @@ class MMInputHelper {
              std::vector<Message>& messages,
              MMInputItemVec& inputs);
 
+  bool trans(const std::vector<MMChatMessage>& raw_input_data,
+             std::vector<Message>& messages,
+             MMInputItemVec& inputs);
+
  private:
   bool trans(const MMInputDataVec& vec,
              Message::MMContentVec& mmc,
              MMInputItemVec& input);
 
+  bool trans(const std::vector<MMInputData>& vec,
+             Message::MMContentVec& mmc,
+             MMInputItemVec& input);
+
   std::unique_ptr<MMHandlerSet> mm_handlers_;
 };

xllm/core/runtime/llm_master.cpp

Lines changed: 6 additions & 9 deletions

@@ -131,7 +131,6 @@ void LLMMaster::handle_batch_request(std::vector<std::string> prompts,
       << "Number of prompts and sampling parameters should be the same";
 
   const size_t num_requests = prompts.size();
-  scheduler_->incr_pending_requests(num_requests);
   for (size_t i = 0; i < num_requests; ++i) {
     handle_request(std::move(prompts[i]),
                    std::nullopt,
@@ -153,7 +152,6 @@ void LLMMaster::handle_batch_request(
       << "Number of conversations and sampling parameters should be the same";
 
   const size_t num_requests = conversations.size();
-  scheduler_->incr_pending_requests(num_requests);
   for (size_t i = 0; i < num_requests; ++i) {
     handle_request(std::move(conversations[i]),
                    std::nullopt,
@@ -173,8 +171,10 @@ void LLMMaster::handle_request(std::string prompt,
                               std::optional<Call*> call,
                               OutputCallback callback) {
   scheduler_->incr_pending_requests(1);
-  auto cb = [callback = std::move(callback)](const RequestOutput& output) {
+  auto cb = [callback = std::move(callback),
+             scheduler = scheduler_.get()](const RequestOutput& output) {
     output.log_request_status();
+    scheduler->decr_pending_requests();
     return callback(output);
   };
   // add into the queue
@@ -186,9 +186,6 @@ void LLMMaster::handle_request(std::string prompt,
                  call]() mutable {
     AUTO_COUNTER(request_handling_latency_seconds_completion);
 
-    // remove the pending request after scheduling
-    SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); });
-
     Timer timer;
     // verify the prompt
     if (!sp.verify_params(callback)) {
@@ -214,8 +211,10 @@ void LLMMaster::handle_request(std::vector<Message> messages,
                               std::optional<Call*> call,
                               OutputCallback callback) {
   scheduler_->incr_pending_requests(1);
-  auto cb = [callback = std::move(callback)](const RequestOutput& output) {
+  auto cb = [callback = std::move(callback),
+             scheduler = scheduler_.get()](const RequestOutput& output) {
     output.log_request_status();
+    scheduler->decr_pending_requests();
     return callback(output);
   };
   // add into the queue
@@ -226,8 +225,6 @@ void LLMMaster::handle_request(std::vector<Message> messages,
                  callback = std::move(cb),
                  call]() mutable {
     AUTO_COUNTER(request_handling_latency_seconds_chat);
-    // remove the pending request after scheduling
-    SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); });
 
     // verify the prompt
     if (!sp.verify_params(callback)) {

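The accounting change here is behavioral, not cosmetic. Previously, handle_batch_request pre-incremented the pending counter for the whole batch and a SCOPE_GUARD decremented it as soon as each request had been scheduled. Now handle_request increments once per request and the decrement happens inside the output callback, so a request stays counted as pending until its output is actually delivered. The callback captures the raw Scheduler pointer instead of this, keeping the lambda independent of the LLMMaster instance.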