Skip to content

Commit 2fbcfa4

Browse files
committed
feat: refine dit modules.
1 parent ae7eff1 commit 2fbcfa4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+598
-600
lines changed

xllm/api_service/api_service_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ limitations under the License.
1919
#include <memory>
2020

2121
#include "call.h"
22-
#include "core/runtime/dit_master.h"
2322
#include "core/runtime/llm_master.h"
23+
2424
namespace xllm {
2525

2626
template <typename T>

xllm/api_service/image_generation_service_impl.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ bool send_result_to_client_brpc(std::shared_ptr<ImageGenerationCall> call,
4444
proto_output->mutable_results()->Reserve(outputs.size());
4545
for (const auto& output : outputs) {
4646
auto* proto_result = proto_output->add_results();
47-
// proto_result->set_base64(output.image_tensor); // TODO proto tensor to
48-
// base64
47+
48+
// proto_result->set_image(output.image);
4949
proto_result->set_width(output.width);
5050
proto_result->set_height(output.height);
5151
proto_result->set_seed(output.seed);
@@ -73,12 +73,11 @@ void ImageGenerationServiceImpl::process_async(
7373
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
7474
return;
7575
}
76-
// create RequestParams for image generation request
77-
// set is_image_generation and max_tokens = 1 to control engine step once.
76+
77+
// create DiTRequestParams for image generation request
7878
DiTRequestParams request_params(
7979
rpc_request, call->get_x_request_id(), call->get_x_request_time());
80-
// TODO only support input_str for now
81-
auto& input = rpc_request.input().prompt();
80+
8281
// schedule the request
8382
master_->handle_request(
8483
std::move(request_params),
@@ -88,16 +87,13 @@ void ImageGenerationServiceImpl::process_async(
8887
request_id = request_params.request_id,
8988
created_time = absl::ToUnixSeconds(absl::Now())](
9089
const DiTRequestOutput& req_output) -> bool {
91-
LOG(INFO) << "into callback before request finished";
92-
LOG(INFO) << req_output.outputs.size();
93-
LOG(INFO) << req_output.outputs[0].image_tensor;
9490
if (req_output.status.has_value()) {
9591
const auto& status = req_output.status.value();
9692
if (!status.ok()) {
9793
return call->finish_with_error(status.code(), status.message());
9894
}
9995
}
100-
LOG(INFO) << "into callback after request finished";
96+
10197
return send_result_to_client_brpc(
10298
call, request_id, created_time, model, req_output);
10399
});

xllm/api_service/image_generation_service_impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ limitations under the License.
1616
#pragma once
1717
#include <absl/container/flat_hash_set.h>
1818

19-
#include "api_service/api_service_impl.h"
20-
#include "api_service/call.h"
2119
#include "api_service/non_stream_call.h"
2220
#include "image_generation.pb.h"
2321

xllm/core/framework/batch/batch.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ limitations under the License.
2121
#include <limits>
2222
#include <vector>
2323

24-
#include "framework/request/dit_request_params.h"
2524
#include "framework/request/mm_data.h"
2625
#include "framework/request/request.h"
2726
#include "framework/request/sequence.h"

xllm/core/framework/batch/dit_batch.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ limitations under the License.
2424
namespace xllm {
2525

2626
DiTForwardInput xllm::DiTBatch::prepare_forward_input() {
27+
CHECK(!dit_request_vec_.empty());
28+
2729
DiTForwardInput forward_input;
28-
if (dit_request_data_vec_.empty()) {
29-
return forward_input;
30-
}
31-
forward_input.input_params = dit_request_data_vec_[0].input_params;
32-
forward_input.generation_params = dit_request_data_vec_[0].generation_params;
30+
forward_input.input_params = dit_request_vec_[0]->state().input_params();
31+
forward_input.generation_params =
32+
dit_request_vec_[0]->state().generation_params();
33+
3334
return forward_input;
3435
}
3536

36-
} // namespace xllm
37+
} // namespace xllm

xllm/core/framework/batch/dit_batch.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,25 @@ limitations under the License.
2121
#include <limits>
2222
#include <vector>
2323

24-
#include "framework/request/dit_request_params.h"
24+
#include "framework/request/dit_request.h"
2525
#include "runtime/dit_forward_params.h"
2626

2727
namespace xllm {
2828

2929
struct DiTBatch {
3030
public:
3131
DiTBatch() = default;
32-
void add(const DiTRequestParams& dit_request_state) {
33-
dit_request_data_vec_.emplace_back(dit_request_state);
32+
void add(const std::shared_ptr<DiTRequest>& request) {
33+
dit_request_vec_.emplace_back(request);
3434
}
35-
size_t size() const { return dit_request_data_vec_.size(); }
36-
bool empty() const { return dit_request_data_vec_.empty(); }
35+
size_t size() const { return dit_request_vec_.size(); }
36+
bool empty() const { return dit_request_vec_.empty(); }
3737

3838
// prepare forward input
3939
DiTForwardInput prepare_forward_input();
4040

4141
private:
42-
std::vector<DiTRequestParams> dit_request_data_vec_;
42+
std::vector<std::shared_ptr<DiTRequest>> dit_request_vec_;
4343
};
4444

45-
} // namespace xllm
45+
} // namespace xllm

xllm/core/framework/dit_model_loader.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,13 @@ DiTModelLoader::DiTModelLoader(const std::string& model_root_path)
237237

238238
const nlohmann::json root_json = model_index_reader.data();
239239
if (!root_json.is_object()) {
240-
LOG(FATAL) << "DITModelLoader: model_index.json root is not an object!";
240+
LOG(FATAL) << "DiTModelLoader: model_index.json root is not an object!";
241241
}
242242

243243
// parse model_index.json & initialize model_loader
244244
for (const auto& [json_key, json_value] : root_json.items()) {
245245
if (!json_value.is_array() || json_value.size() != 2) {
246-
LOG(WARNING) << "DITModelLoader: Invalid format for component! "
246+
LOG(WARNING) << "DiTModelLoader: Invalid format for component! "
247247
<< "JsonKey=" << json_key
248248
<< ", Expected [library, class_name] array";
249249
continue;
@@ -254,13 +254,13 @@ DiTModelLoader::DiTModelLoader(const std::string& model_root_path)
254254
std::filesystem::path(model_root_path_) / json_key;
255255
const std::string component_folder = component_folder_path.string();
256256
if (!std::filesystem::exists(component_folder)) {
257-
LOG(WARNING) << "DITModelLoader: Component folder not found! "
257+
LOG(WARNING) << "DiTModelLoader: Component folder not found! "
258258
<< "ComponentName=" << component_name
259259
<< ", Folder=" << component_folder;
260260
continue;
261261
}
262262
if (!std::filesystem::is_directory(component_folder)) {
263-
LOG(WARNING) << "DITModelLoader: Component path is not a directory! "
263+
LOG(WARNING) << "DiTModelLoader: Component path is not a directory! "
264264
<< "ComponentName=" << component_name
265265
<< ", Path=" << component_folder;
266266
continue;

xllm/core/framework/model/dit_model.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ class DiTModel : public torch::nn::Module {
2929
public:
3030
~DiTModel() override = default;
3131

32-
virtual torch::Tensor forward(const InputParams& input_params,
33-
const GenerationParams& gen_params) = 0;
32+
virtual torch::Tensor forward(const DiTInputParams& input_params,
33+
const DiTGenerationParams& gen_params) = 0;
3434
virtual torch::Device device() const = 0;
3535
virtual const torch::TensorOptions& options() const = 0;
3636
virtual void load_model(std::unique_ptr<DiTModelLoader> loader) = 0;
@@ -43,8 +43,8 @@ class DiTModelImpl : public DiTModel {
4343
: model_(std::move(model)), options_(options) {
4444
LOG(INFO) << "DiTModelImpl created.";
4545
}
46-
torch::Tensor forward(const InputParams& input_params,
47-
const GenerationParams& gen_params) override {
46+
torch::Tensor forward(const DiTInputParams& input_params,
47+
const DiTGenerationParams& gen_params) override {
4848
return model_->forward(input_params, gen_params);
4949
}
5050
torch::Device device() const override { return options_.device(); }

xllm/core/framework/request/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ cc_library(
55
NAME
66
request
77
HDRS
8+
dit_request.h
9+
dit_request_params.h
810
finish_reason.h
911
incremental_decoder.h
1012
mm_data.h
1113
mm_input_helper.h
14+
request_base.h
1215
request.h
1316
dit_request.h
1417
request_output.h
@@ -23,6 +26,7 @@ cc_library(
2326
stopping_checker.h
2427
priority_comparator.h
2528
SRCS
29+
dit_request.cpp
2630
finish_reason.cpp
2731
incremental_decoder.cpp
2832
mm_data.cpp

xllm/core/framework/request/dit_request.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,9 @@ DiTRequest::DiTRequest(const std::string& request_id,
3131
const std::string& x_request_id,
3232
const std::string& x_request_time,
3333
const DiTRequestState& state,
34-
const std::string& service_request_id,
35-
bool offline,
36-
int32_t slo_ms,
37-
RequestPriority priority)
38-
: created_time_(absl::Now()),
39-
request_id_(request_id),
40-
service_request_id_(service_request_id),
41-
x_request_id_(x_request_id),
42-
x_request_time_(x_request_time),
43-
state_(state),
44-
offline_(offline),
45-
slo_ms_(slo_ms),
46-
priority_(priority) {}
34+
const std::string& service_request_id)
35+
: RequestBase(request_id, x_request_id, x_request_time, service_request_id),
36+
state_(state) {}
4737

4838
bool DiTRequest::finished() const { return true; }
4939

@@ -54,16 +44,23 @@ void DiTRequest::log_statistic(double total_latency) {
5444
<< "total_latency: " << total_latency * 1000 << "ms";
5545
}
5646

57-
DiTRequestOutput DiTRequest::generate_dit_output(DiTForwardOutput dit_output) {
47+
const DiTRequestOutput DiTRequest::generate_output(
48+
DiTForwardOutput dit_output) {
5849
DiTRequestOutput output;
5950
output.request_id = request_id_;
6051
output.service_request_id = service_request_id_;
6152
output.status = Status(StatusCode::OK);
6253
output.finished = finished();
6354
output.cancelled = false;
55+
6456
DiTGenerationOutput result;
65-
result.image_tensor = dit_output.image;
57+
result.image = dit_output.image;
58+
result.height = state_.generation_params().height;
59+
result.width = state_.generation_params().width;
60+
result.seed = state_.generation_params().seed.value();
6661
output.outputs.push_back(result);
62+
6763
return output;
6864
}
69-
} // namespace xllm
65+
66+
} // namespace xllm

0 commit comments

Comments
 (0)