@@ -39,6 +39,11 @@ std::string generate_chat_request_id() {
3939 short_uuid.random ();
4040}
4141
42+ std::string generate_image_generation_request_id () {
43+ return " imggen-" + InstanceName::name ()->get_name_hash () + " -" +
44+ short_uuid.random ();
45+ }
46+
4247} // namespace
4348
4449RequestParams::RequestParams (const proto::CompletionRequest& request,
@@ -332,6 +337,76 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request,
332337 streaming = false ;
333338}
334339
340+ ImageRequestParams::ImageRequestParams (
341+ const proto::ImageGenerationRequest& request,
342+ const std::string& x_rid,
343+ const std::string& x_rtime) {
344+ request_id = generate_image_generation_request_id ();
345+ x_request_id = x_rid;
346+ x_request_time = x_rtime;
347+ model = request.model ();
348+ if (request.has_service_request_id ()) {
349+ service_request_id = request.service_request_id ();
350+ }
351+ const auto & proto_input = request.input ();
352+ input_params.prompt = proto_input.prompt ();
353+ if (proto_input.has_prompt_2 ()) {
354+ input_params.prompt_2 = proto_input.prompt_2 ();
355+ }
356+ if (proto_input.has_negative_prompt ()) {
357+ input_params.negative_prompt = proto_input.negative_prompt ();
358+ }
359+ if (proto_input.has_negative_prompt_2 ()) {
360+ input_params.negative_prompt_2 = proto_input.negative_prompt_2 ();
361+ }
362+ if (proto_input.has_prompt_embeds ()) {
363+ const auto & proto_tensor = proto_input.prompt_embeds ();
364+ input_params.prompt_embeds = proto_tensor;
365+ }
366+ if (proto_input.has_pooled_prompt_embeds ()) {
367+ input_params.pooled_prompt_embeds = proto_input.pooled_prompt_embeds ();
368+ }
369+ if (proto_input.has_negative_prompt_embeds ()) {
370+ input_params.negative_prompt_embeds = proto_input.negative_prompt_embeds ();
371+ }
372+ if (proto_input.has_negative_pooled_prompt_embeds ()) {
373+ input_params.negative_pooled_prompt_embeds =
374+ proto_input.negative_pooled_prompt_embeds ();
375+ }
376+ if (proto_input.has_latents ()) {
377+ const auto & proto_tensor = proto_input.latents ();
378+ input_params.latents = proto_tensor;
379+ }
380+ const auto & proto_params = request.parameters ();
381+ if (proto_params.has_size ()) {
382+ generation_params.size = proto_params.size ();
383+ }
384+ if (proto_params.has_num_inference_steps ()) {
385+ generation_params.num_inference_steps = proto_params.num_inference_steps ();
386+ }
387+ if (proto_params.has_true_cfg_scale ()) {
388+ generation_params.true_cfg_scale = proto_params.true_cfg_scale ();
389+ }
390+ if (proto_params.has_guidance_scale ()) {
391+ generation_params.guidance_scale = proto_params.guidance_scale ();
392+ }
393+ if (proto_params.has_num_images_per_prompt ()) {
394+ generation_params.num_images_per_prompt =
395+ static_cast <uint32_t >(proto_params.num_images_per_prompt ());
396+ } else {
397+ generation_params.num_images_per_prompt = 1 ;
398+ }
399+ if (proto_params.has_seed ()) {
400+ generation_params.seed = proto_params.seed ();
401+ }
402+ if (proto_params.has_max_sequence_length ()) {
403+ generation_params.max_sequence_length = proto_params.max_sequence_length ();
404+ }
405+ }
406+ bool ImageRequestParams::verify_params (
407+ std::function<bool (ImageRequestOutput)> callback) const {
408+ return true ;
409+ }
335410bool RequestParams::verify_params (OutputCallback callback) const {
336411 if (n == 0 ) {
337412 CALLBACK_WITH_ERROR (StatusCode::INVALID_ARGUMENT,
0 commit comments