Skip to content

Commit 7e25430

Browse files
authored
feat: update vlm offline interface. (jd-opensource#356)
1 parent a995889 commit 7e25430

File tree

5 files changed

+194
-120
lines changed

5 files changed

+194
-120
lines changed

examples/generate_vlm.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,62 @@
1-
# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0'
2-
# python generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0,npu:1'
1+
# python generate_vlm.py --model /path/to/Qwen2.5-VL-7B-Instruct/ --disable_prefix_cache --disable_chunked_prefill --max_seqs_per_batch 4
32

43
import os
54
import signal
6-
from xllm import ArgumentParser, VLM, RequestParams, MMChatMessage, MMInputData
5+
6+
from xllm import ArgumentParser, VLM, RequestParams
7+
from xllm_export import MMType, MMData
8+
9+
from PIL import Image
10+
from transformers import AutoImageProcessor
711

812
# Create an VLM.
913
parser = ArgumentParser()
10-
vlm = VLM(**vars(parser.parse_args()))
14+
args = parser.parse_args()
15+
16+
vlm = VLM(**vars(args))
17+
processor = AutoImageProcessor.from_pretrained(args.model, trust_remote_code=True)
18+
19+
questions = ["简单介绍下图片"]
20+
prompts = [
21+
(
22+
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
23+
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
24+
f"{question}<|im_end|>\n"
25+
"<|im_start|>assistant\n"
26+
)
27+
for question in questions
28+
]
29+
30+
paths = ["00307664d4ce393b.png"]
31+
images = []
32+
for path in paths:
33+
images.append(Image.open(path).convert("RGB"))
34+
35+
multi_modal_datas = []
36+
for idx in range(len(images)):
37+
print(f"Processing image: {paths[idx]}")
38+
image = images[idx]
39+
40+
data = processor.preprocess([image], return_tensors="pt").data
41+
mm_data = {
42+
"pixel_values": data['pixel_values'],
43+
"image_grid_thw": data['image_grid_thw'],
44+
}
45+
multi_modal_datas.append(MMData(MMType.IMAGE, mm_data))
46+
1147

1248
# Create a request params object, including sampling params
1349
request_params = RequestParams()
14-
request_params.temperature = 0.8
15-
request_params.top_p = 0.95
16-
request_params.max_tokens = 100
17-
18-
# input_data
19-
mm_input_data1 = MMInputData()
20-
mm_input_data1.type = 'text'
21-
mm_input_data1.text = 'Please briefly introduce this picture.'
22-
mm_input_data2 = MMInputData()
23-
mm_input_data2.type = 'image_url'
24-
mm_input_data2.image_url = 'https://img2.baidu.com/it/u=2376489989,3127732063&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=657'
25-
mm_chat_msg = MMChatMessage()
26-
mm_chat_msg.role = 'user'
27-
mm_chat_msg.content = [mm_input_data1, mm_input_data2]
28-
29-
output = vlm.generate(mm_chat_msg, request_params, True)
30-
31-
prompt = output.prompt
32-
generated_text = output.outputs[0].text
33-
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
50+
request_params.temperature = 0
51+
request_params.max_tokens = 1024
52+
53+
outputs = vlm.generate(prompts, multi_modal_datas, request_params, True)
54+
55+
for output in outputs:
56+
prompt = output.prompt
57+
generated_text = output.outputs[0].text
58+
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
3459

3560
vlm.finish()
3661

62+

xllm/core/runtime/vlm_master.cpp

Lines changed: 68 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -124,74 +124,6 @@ VLMMaster::~VLMMaster() {
124124
}
125125
}
126126

127-
void VLMMaster::handle_request(const std::vector<Message>& messages,
128-
const MMInput& mm_inputs,
129-
RequestParams sp,
130-
OutputCallback callback) {
131-
MMData mm_data;
132-
if (!mm_inputs.empty() && !image_processor_->process(mm_inputs, mm_data)) {
133-
LOG(ERROR) << " image processor process failed";
134-
}
135-
136-
this->handle_request(messages, mm_data, sp, callback);
137-
}
138-
139-
void VLMMaster::handle_batch_request(const std::vector<std::string>& prompts,
140-
const std::vector<MMData>& mm_datas,
141-
std::vector<RequestParams> sps,
142-
BatchOutputCallback callback) {
143-
CHECK(prompts.size() == sps.size() || sps.size() == 1)
144-
<< "Number of prompts and sampling parameters should be the same";
145-
146-
const size_t num_requests = prompts.size();
147-
for (size_t i = 0; i < num_requests; ++i) {
148-
handle_request(std::move(prompts[i]),
149-
std::move(mm_datas[i]),
150-
// the sampling parameter may be shared
151-
sps.size() == 1 ? sps[0] : std::move(sps[i]),
152-
[i, callback](const RequestOutput& output) {
153-
output.log_request_status();
154-
return callback(i, output);
155-
});
156-
}
157-
}
158-
159-
void VLMMaster::handle_batch_request(
160-
const std::vector<std::vector<Message>>& conversations,
161-
const std::vector<MMData>& mm_datas,
162-
std::vector<RequestParams> sps,
163-
BatchOutputCallback callback) {
164-
CHECK(conversations.size() == sps.size() || sps.size() == 1)
165-
<< "Number of conversations and sampling parameters should be the same";
166-
167-
const size_t num_requests = conversations.size();
168-
for (size_t i = 0; i < num_requests; ++i) {
169-
handle_request(std::move(conversations[i]),
170-
std::move(mm_datas[i]),
171-
// the sampling parameter may be shared
172-
sps.size() == 1 ? sps[0] : std::move(sps[i]),
173-
[i, callback](const RequestOutput& output) {
174-
output.log_request_status();
175-
return callback(i, output);
176-
});
177-
}
178-
}
179-
180-
void VLMMaster::handle_request(const std::vector<MMChatMessage>& raw_input_data,
181-
RequestParams sp,
182-
OutputCallback callback) {
183-
static MMInputHelper helper;
184-
std::vector<Message> messages;
185-
MMInput mm_inputs;
186-
187-
if (!helper.trans(raw_input_data, messages, mm_inputs.items_)) {
188-
LOG(ERROR) << "MMInputHelper trans failed, ignore this input.";
189-
return;
190-
}
191-
192-
handle_request(std::move(messages), std::move(mm_inputs), sp, callback);
193-
}
194-
195127
void VLMMaster::handle_request(const std::string& prompt,
196128
const MMData& mm_data,
197129
RequestParams sp,
@@ -232,6 +164,18 @@ void VLMMaster::handle_request(const std::string& prompt,
232164
});
233165
}
234166

167+
void VLMMaster::handle_request(const std::vector<Message>& messages,
168+
const MMInput& mm_inputs,
169+
RequestParams sp,
170+
OutputCallback callback) {
171+
MMData mm_data;
172+
if (!mm_inputs.empty() && !image_processor_->process(mm_inputs, mm_data)) {
173+
LOG(ERROR) << " image processor process failed";
174+
}
175+
176+
this->handle_request(messages, mm_data, sp, callback);
177+
}
178+
235179
void VLMMaster::handle_request(const std::vector<Message>& messages,
236180
const MMData& mm_data,
237181
RequestParams sp,
@@ -270,6 +214,62 @@ void VLMMaster::handle_request(const std::vector<Message>& messages,
270214
});
271215
}
272216

217+
void VLMMaster::handle_request(const std::vector<MMChatMessage>& raw_input_data,
218+
RequestParams sp,
219+
OutputCallback callback) {
220+
static MMInputHelper helper;
221+
std::vector<Message> messages;
222+
MMInput mm_inputs;
223+
224+
if (!helper.trans(raw_input_data, messages, mm_inputs.items_)) {
225+
LOG(ERROR) << "MMInputHelper trans failed, ignore this input.";
226+
return;
227+
}
228+
229+
handle_request(std::move(messages), std::move(mm_inputs), sp, callback);
230+
}
231+
232+
void VLMMaster::handle_batch_request(const std::vector<std::string>& prompts,
233+
const std::vector<MMData>& mm_datas,
234+
const std::vector<RequestParams>& sps,
235+
BatchOutputCallback callback) {
236+
CHECK(prompts.size() == sps.size() || sps.size() == 1)
237+
<< "Number of prompts and sampling parameters should be the same";
238+
239+
const size_t num_requests = prompts.size();
240+
for (size_t i = 0; i < num_requests; ++i) {
241+
handle_request(std::move(prompts[i]),
242+
std::move(mm_datas[i]),
243+
// the sampling parameter may be shared
244+
sps.size() == 1 ? sps[0] : std::move(sps[i]),
245+
[i, callback](const RequestOutput& output) {
246+
output.log_request_status();
247+
return callback(i, output);
248+
});
249+
}
250+
}
251+
252+
void VLMMaster::handle_batch_request(
253+
const std::vector<std::vector<Message>>& conversations,
254+
const std::vector<MMData>& mm_datas,
255+
const std::vector<RequestParams>& sps,
256+
BatchOutputCallback callback) {
257+
CHECK(conversations.size() == sps.size() || sps.size() == 1)
258+
<< "Number of conversations and sampling parameters should be the same";
259+
260+
const size_t num_requests = conversations.size();
261+
for (size_t i = 0; i < num_requests; ++i) {
262+
handle_request(std::move(conversations[i]),
263+
std::move(mm_datas[i]),
264+
// the sampling parameter may be shared
265+
sps.size() == 1 ? sps[0] : std::move(sps[i]),
266+
[i, callback](const RequestOutput& output) {
267+
output.log_request_status();
268+
return callback(i, output);
269+
});
270+
}
271+
}
272+
273273
void VLMMaster::run() {
274274
const bool already_running = running_.load(std::memory_order_relaxed);
275275
if (already_running) {

xllm/core/runtime/vlm_master.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class VLMMaster : public Master {
4646
explicit VLMMaster(const Options& options);
4747
~VLMMaster();
4848

49+
// completion
4950
void handle_request(const std::string& prompt,
5051
const MMData& mm_data,
5152
RequestParams sp,
@@ -71,14 +72,14 @@ class VLMMaster : public Master {
7172
// batch completion
7273
void handle_batch_request(const std::vector<std::string>& prompts,
7374
const std::vector<MMData>& mm_datas,
74-
std::vector<RequestParams> sps,
75+
const std::vector<RequestParams>& sps,
7576
BatchOutputCallback callback);
7677

7778
// batch chat
7879
void handle_batch_request(
7980
const std::vector<std::vector<Message>>& conversations,
8081
const std::vector<MMData>& mm_datas,
81-
std::vector<RequestParams> sps,
82+
const std::vector<RequestParams>& sps,
8283
BatchOutputCallback callback);
8384

8485
// start the handling loop

xllm/pybind/bind.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "api_service/call.h"
2323
#include "core/common/options.h"
2424
#include "core/common/types.h"
25+
#include "core/framework/request/mm_data.h"
2526
#include "core/framework/request/request_output.h"
2627
#include "core/framework/request/request_params.h"
2728
#include "core/runtime/llm_master.h"
@@ -232,6 +233,43 @@ PYBIND11_MODULE(xllm_export, m) {
232233
.def_readwrite("role", &MMChatMessage::role)
233234
.def_readwrite("content", &MMChatMessage::content);
234235

236+
// 10. export MMType
237+
py::enum_<MMType::Value>(m, "MMType")
238+
.value("NONE", MMType::Value::NONE)
239+
.value("IMAGE", MMType::Value::IMAGE)
240+
.value("VIDEO", MMType::Value::VIDEO)
241+
.value("AUDIO", MMType::Value::AUDIO)
242+
.value("EMBEDDING", MMType::EMBEDDING)
243+
.export_values();
244+
245+
// 11. export MMData
246+
py::class_<MMData>(m, "MMData")
247+
.def(py::init<int, const MMDict&>(), py::arg("ty"), py::arg("data"))
248+
.def("get",
249+
[](const MMData& self, const MMKey& key) -> py::object {
250+
auto value = self.get<torch::Tensor>(key);
251+
if (value.has_value()) {
252+
return py::cast(value.value());
253+
}
254+
return py::none();
255+
})
256+
.def("get_list",
257+
[](const MMData& self, const MMKey& key) -> py::object {
258+
auto value = self.get<std::vector<torch::Tensor>>(key);
259+
if (value.has_value()) {
260+
return py::cast(value.value());
261+
}
262+
return py::none();
263+
})
264+
.def_readwrite("ty", &MMData::ty_)
265+
.def_readwrite("data", &MMData::data_)
266+
.def("__repr__", [](const MMData& self) {
267+
std::stringstream ss;
268+
ss << "MMData(" << static_cast<int>(self.ty_) << ": "
269+
<< self.data_.size() << " items)";
270+
return ss.str();
271+
});
272+
235273
// 10. export VLMMaster
236274
py::class_<VLMMaster>(m, "VLMMaster")
237275
.def(py::init<const Options&>(),
@@ -242,6 +280,13 @@ PYBIND11_MODULE(xllm_export, m) {
242280
RequestParams,
243281
OutputCallback>(&VLMMaster::handle_request),
244282
py::call_guard<py::gil_scoped_release>())
283+
.def("handle_batch_request",
284+
py::overload_cast<const std::vector<std::string>&,
285+
const std::vector<MMData>&,
286+
const std::vector<RequestParams>&,
287+
BatchOutputCallback>(
288+
&VLMMaster::handle_batch_request),
289+
py::call_guard<py::gil_scoped_release>())
245290
.def("generate",
246291
&VLMMaster::generate,
247292
py::call_guard<py::gil_scoped_release>())

0 commit comments

Comments
 (0)