Skip to content

Commit ae7eff1

Browse files
yiming-l21xiao-yu-chen
authored andcommitted
feat: add DiT data structures.
1 parent 2b7384d commit ae7eff1

File tree

11 files changed

+242
-6
lines changed

11 files changed

+242
-6
lines changed

xllm/api_service/api_service_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ limitations under the License.
1919
#include <memory>
2020

2121
#include "call.h"
22+
#include "core/runtime/dit_master.h"
2223
#include "core/runtime/llm_master.h"
23-
2424
namespace xllm {
2525

2626
template <typename T>

xllm/core/framework/batch/batch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <limits>
2222
#include <vector>
2323

24+
#include "framework/request/dit_request_params.h"
2425
#include "framework/request/mm_data.h"
2526
#include "framework/request/request.h"
2627
#include "framework/request/sequence.h"

xllm/core/framework/request/request.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,23 @@ Request::Request(const std::string& request_id,
4848
slo_ms_(slo_ms) {
4949
create_sequences_group();
5050
}
51+
// Constructs a DiT (image generation) request.
//
// Stores the request identifiers, the x-request time, the scheduling
// attributes (offline flag, priority, SLO) and a copy of the DiT request
// state, and stamps the creation time with absl::Now(). Unlike the
// constructor defined just above, this one does not call
// create_sequences_group().
Request::Request(const std::string& request_id,
                 const std::string& x_request_id,
                 const std::string& x_request_time,
                 const DITRequestState& state,
                 const std::string& service_request_id,
                 bool offline,
                 int32_t slo_ms,
                 RequestPriority priority)
    : request_id_(request_id),
      service_request_id_(service_request_id),
      x_request_id_(x_request_id),
      x_request_time_(x_request_time),
      // `state` is a const reference, so it can only be copied. Wrapping it in
      // std::move() would still select the copy constructor while suggesting
      // a move to the reader.
      dit_state_(state),
      created_time_(absl::Now()),
      offline_(offline),
      priority_(priority),
      slo_ms_(slo_ms) {}
5168

5269
void Request::create_sequences_group() {
5370
SequenceParams sequence_params;

xllm/core/framework/request/request.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ class Request {
4343
bool offline = false,
4444
int32_t slo_ms = 0,
4545
RequestPriority priority = RequestPriority::NORMAL);
46+
Request(const std::string& request_id,
47+
const std::string& x_request_id,
48+
const std::string& x_request_time,
49+
const DITRequestState& state,
50+
const std::string& service_request_id = "",
51+
bool offline = false,
52+
int32_t slo_ms = 0,
53+
RequestPriority priority = RequestPriority::NORMAL);
4654

4755
bool finished() const;
4856

@@ -93,7 +101,7 @@ class Request {
93101
const RequestPriority priority() const { return priority_; }
94102

95103
RequestState& state() { return state_; }
96-
104+
DITRequestState& dit_state() { return dit_state_; }
97105
void update_connection_status();
98106

99107
bool check_beam_search() const {
@@ -115,7 +123,7 @@ class Request {
115123
std::string x_request_time_;
116124

117125
RequestState state_;
118-
126+
DITRequestState dit_state_;
119127
// list of sequences to generate completions for the prompt
120128
// use deque instead of vector to avoid no-copy move for Sequence
121129
// std::deque<Sequence> sequences;

xllm/core/framework/request/request_output.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,34 @@ void RequestOutput::log_request_status() const {
5151
}
5252
}
5353

54+
void ImageRequestOutput::log_request_status() const {
55+
if (!status.has_value()) {
56+
return;
57+
}
58+
59+
auto code = status.value().code();
60+
switch (code) {
61+
case StatusCode::OK:
62+
COUNTER_INC(request_status_total_ok);
63+
break;
64+
case StatusCode::CANCELLED:
65+
COUNTER_INC(request_status_total_cancelled);
66+
break;
67+
case StatusCode::UNKNOWN:
68+
COUNTER_INC(request_status_total_unknown);
69+
break;
70+
case StatusCode::INVALID_ARGUMENT:
71+
COUNTER_INC(request_status_total_invalid_argument);
72+
break;
73+
case StatusCode::DEADLINE_EXCEEDED:
74+
COUNTER_INC(request_status_total_deadline_exceeded);
75+
break;
76+
case StatusCode::RESOURCE_EXHAUSTED:
77+
COUNTER_INC(request_status_total_resource_exhausted);
78+
break;
79+
default:
80+
COUNTER_INC(request_status_total_unknown);
81+
break;
82+
}
83+
}
5484
} // namespace xllm

xllm/core/framework/request/request_output.h

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <vector>
2525

2626
#include "core/common/types.h"
27+
#include "image_generation.pb.h"
2728

2829
namespace xllm {
2930
struct Usage {
@@ -75,6 +76,22 @@ struct SequenceOutput {
7576
// the embeddings of the prompt token
7677
std::optional<std::vector<float>> embeddings;
7778
};
79+
struct ImageGenerationOutput {
80+
// the index of the sequence in the request.
81+
size_t index;
82+
83+
// the generated image in proto tensor format.
84+
proto::Tensor image_tensor;
85+
86+
// the height of the generated image.
87+
int32_t height;
88+
89+
// the width of the generated image.
90+
int32_t width;
91+
92+
// seed used for image generation.
93+
int64_t seed;
94+
};
7895

7996
struct RequestOutput {
8097
RequestOutput() = default;
@@ -108,10 +125,44 @@ struct RequestOutput {
108125
bool cancelled = false;
109126
};
110127

128+
struct ImageRequestOutput {
129+
ImageRequestOutput() = default;
130+
131+
ImageRequestOutput(Status&& _status) : status(std::move(_status)) {}
132+
133+
void log_request_status() const;
134+
135+
// the id of the request.
136+
std::string request_id;
137+
138+
// the id of the request which generated in xllm service.
139+
std::string service_request_id;
140+
141+
// the status of the request.
142+
std::optional<Status> status;
143+
144+
// the output for each sequence in the request.
145+
std::vector<ImageGenerationOutput> outputs;
146+
147+
// whether the request is finished.
148+
bool finished = false;
149+
150+
// whether the request is cancelled.
151+
bool cancelled = false;
152+
};
153+
111154
// callback function for output, return true to continue, false to stop/cancel
112155
using OutputCallback = std::function<bool(RequestOutput output)>;
113-
156+
// callback function for image request output, return true to continue, false to
157+
// stop/cancel
158+
using ImageOutputCallback = std::function<bool(ImageRequestOutput output)>;
159+
// callback function for batch output, return true to continue, false to
160+
// stop/cancel
114161
using BatchOutputCallback =
115162
std::function<bool(size_t index, RequestOutput output)>;
163+
// callback function for batch image output, return true to continue, false to
164+
// stop/cancel
165+
using BatchImageOutputCallback =
166+
std::function<bool(size_t index, ImageRequestOutput output)>;
116167

117168
} // namespace xllm

xllm/core/framework/request/request_params.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ std::string generate_chat_request_id() {
3939
short_uuid.random();
4040
}
4141

42+
std::string generate_image_generation_request_id() {
43+
return "imggen-" + InstanceName::name()->get_name_hash() + "-" +
44+
short_uuid.random();
45+
}
46+
4247
} // namespace
4348

4449
RequestParams::RequestParams(const proto::CompletionRequest& request,
@@ -332,6 +337,76 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request,
332337
streaming = false;
333338
}
334339

340+
// Builds image-generation request parameters from the protobuf request.
//
// A fresh request id is generated, and the x-request id/time are taken from
// the arguments. Optional protobuf fields are copied only when they are
// present; absent fields are left untouched. The single exception is
// num_images_per_prompt, which is set to 1 when the request omits it.
ImageRequestParams::ImageRequestParams(
    const proto::ImageGenerationRequest& request,
    const std::string& x_rid,
    const std::string& x_rtime)
    : request_id(generate_image_generation_request_id()),
      x_request_id(x_rid),
      x_request_time(x_rtime),
      model(request.model()) {
  if (request.has_service_request_id()) {
    service_request_id = request.service_request_id();
  }

  // ---- input section: prompts, precomputed embeddings and latents ----
  const auto& in = request.input();
  input_params.prompt = in.prompt();
  if (in.has_prompt_2()) {
    input_params.prompt_2 = in.prompt_2();
  }
  if (in.has_negative_prompt()) {
    input_params.negative_prompt = in.negative_prompt();
  }
  if (in.has_negative_prompt_2()) {
    input_params.negative_prompt_2 = in.negative_prompt_2();
  }
  if (in.has_prompt_embeds()) {
    input_params.prompt_embeds = in.prompt_embeds();
  }
  if (in.has_pooled_prompt_embeds()) {
    input_params.pooled_prompt_embeds = in.pooled_prompt_embeds();
  }
  if (in.has_negative_prompt_embeds()) {
    input_params.negative_prompt_embeds = in.negative_prompt_embeds();
  }
  if (in.has_negative_pooled_prompt_embeds()) {
    input_params.negative_pooled_prompt_embeds =
        in.negative_pooled_prompt_embeds();
  }
  if (in.has_latents()) {
    input_params.latents = in.latents();
  }

  // ---- generation section: knobs controlling the generation process ----
  const auto& gen = request.parameters();
  if (gen.has_size()) {
    generation_params.size = gen.size();
  }
  if (gen.has_num_inference_steps()) {
    generation_params.num_inference_steps = gen.num_inference_steps();
  }
  if (gen.has_true_cfg_scale()) {
    generation_params.true_cfg_scale = gen.true_cfg_scale();
  }
  if (gen.has_guidance_scale()) {
    generation_params.guidance_scale = gen.guidance_scale();
  }
  if (!gen.has_num_images_per_prompt()) {
    // Fall back to a single image when the request does not say otherwise.
    generation_params.num_images_per_prompt = 1;
  } else {
    generation_params.num_images_per_prompt =
        static_cast<uint32_t>(gen.num_images_per_prompt());
  }
  if (gen.has_seed()) {
    generation_params.seed = gen.seed();
  }
  if (gen.has_max_sequence_length()) {
    generation_params.max_sequence_length = gen.max_sequence_length();
  }
}
406+
bool ImageRequestParams::verify_params(
407+
std::function<bool(ImageRequestOutput)> callback) const {
408+
return true;
409+
}
335410
bool RequestParams::verify_params(OutputCallback callback) const {
336411
if (n == 0) {
337412
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,

xllm/core/framework/request/request_params.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ limitations under the License.
2727
#include "completion.pb.h"
2828
#include "core/common/macros.h"
2929
#include "core/common/types.h"
30+
#include "dit_request_params.h"
3031
#include "embedding.pb.h"
32+
#include "image_generation.pb.h"
3133
#include "multimodal.pb.h"
3234
#include "request.h"
3335
#include "request_output.h"
@@ -139,4 +141,32 @@ struct RequestParams {
139141
nlohmann::json chat_template_kwargs = nlohmann::json::object();
140142
};
141143

144+
struct ImageRequestParams {
145+
ImageRequestParams() = default;
146+
ImageRequestParams(const proto::ImageGenerationRequest& request,
147+
const std::string& x_rid,
148+
const std::string& x_rtime);
149+
150+
bool verify_params(ImageOutputCallback callback) const;
151+
152+
// request id
153+
std::string request_id;
154+
std::string service_request_id = "";
155+
std::string x_request_id;
156+
std::string x_request_time;
157+
158+
std::string model;
159+
160+
bool offline = false;
161+
162+
int32_t slo_ms = 0;
163+
164+
RequestPriority priority = RequestPriority::NORMAL;
165+
166+
InputParams input_params;
167+
// Mandatory: Generation control parameters (encapsulates all fields related
168+
// to "image generation process")
169+
GenerationParams generation_params;
170+
};
171+
142172
} // namespace xllm

xllm/core/framework/request/request_state.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <vector>
2525

2626
#include "core/framework/sampling/sampling_params.h"
27+
#include "dit_request_params.h"
2728
#include "mm_data.h"
2829
#include "request_output.h"
2930
#include "stopping_checker.h"
@@ -150,4 +151,18 @@ struct RequestState final {
150151
std::optional<Call*> call_;
151152
};
152153

154+
struct DITRequestState {
155+
public:
156+
DITRequestState(InputParams&& input_params,
157+
GenerationParams&& generation_params)
158+
: input_params_(std::move(input_params)),
159+
generation_params_(std::move(generation_params)) {}
160+
DITRequestState() {}
161+
InputParams& input_params() { return input_params_; }
162+
GenerationParams& generation_params() { return generation_params_; }
163+
164+
private:
165+
InputParams input_params_;
166+
GenerationParams generation_params_;
167+
};
153168
} // namespace xllm

xllm/core/runtime/forward_params.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ limitations under the License.
2424
#include "common/types.h"
2525
#include "framework/model/model_input_params.h"
2626
#include "framework/sampling/sampling_params.h"
27-
27+
#include "tensor.pb.h"
2828
namespace xllm {
2929

3030
class WorkerType {
@@ -33,6 +33,7 @@ class WorkerType {
3333
INVALID = 0,
3434
LLM, // LLM
3535
VLM, // VLM
36+
DIT, // DIT
3637
ELM, // Embedding LM
3738
EVLM, // Embedding VLM
3839
};
@@ -43,6 +44,8 @@ class WorkerType {
4344
value_ = LLM;
4445
} else if (str == "VLM") {
4546
value_ = VLM;
47+
} else if (str == "DIT") {
48+
value_ = DIT;
4649
} else if (str == "ELM") {
4750
value_ = ELM;
4851
} else if (str == "EVLM") {
@@ -67,6 +70,8 @@ class WorkerType {
6770
return "LLM";
6871
} else if (this->value_ == VLM) {
6972
return "VLM";
73+
} else if (this->value_ == DIT) {
74+
return "DIT";
7075
} else if (this->value_ == ELM) {
7176
return "ELM";
7277
} else if (this->value_ == EVLM) {
@@ -118,6 +123,9 @@ struct ForwardOutput {
118123
torch::Tensor expert_load_data;
119124

120125
int32_t prepared_layer_id;
126+
127+
// dit related output
128+
torch::Tensor image;
121129
};
122130

123131
// Model input with raw data, which will be

0 commit comments

Comments
 (0)