From 8f6ba3177b55a40d22a1791b46bfe2564d22adfe Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 15 Oct 2025 15:24:33 +0000 Subject: [PATCH] add streaming feature --- examples/sd-server/main.cpp | 339 +++++++++++++++++++++++++++++++++++- 1 file changed, 337 insertions(+), 2 deletions(-) diff --git a/examples/sd-server/main.cpp b/examples/sd-server/main.cpp index 6bbd3a6b9..b5e8e1cdf 100644 --- a/examples/sd-server/main.cpp +++ b/examples/sd-server/main.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -509,6 +510,315 @@ struct ImageResultGuard { } }; +json logs_to_json(const LogCollector& collector); +json make_telemetry(const LogCollector& collector, + const GenerationRequest& request, + const CtxConfig& config, + int64_t elapsed_ms, + int64_t effective_seed); + +class StreamingImageResponder { + public: + StreamingImageResponder(ServerState& state, + std::unique_lock&& ctx_lock, + std::unique_ptr&& capture_scope, + std::shared_ptr collector, + GenerationRequest request, + CtxConfig ctx_config, + bool random_seed_requested, + int64_t effective_seed) + : state_(state), + ctx_lock_(std::move(ctx_lock)), + capture_scope_(std::move(capture_scope)), + collector_(std::move(collector)), + request_(std::move(request)), + ctx_config_(std::move(ctx_config)), + random_seed_requested_(random_seed_requested), + effective_seed_(effective_seed), + default_sample_method_(sd_get_default_sample_method(state.ctx)), + start_time_(std::chrono::steady_clock::now()) {} + + ~StreamingImageResponder() { + finalize_resources(); + } + + bool next(httplib::DataSink& sink) { + if (done_) { + return false; + } + + if (next_index_ < request_.batch_count) { + if (!emit_image_chunk(sink, next_index_)) { + done_ = true; + return false; + } + ++next_index_; + return true; + } + + emit_final_summary(sink); + done_ = true; + return false; + } + + void cancel() { + done_ = true; + finalize_resources(); + } + + private: + bool emit_image_chunk(httplib::DataSink& sink, int index) { + sd_img_gen_params_t params; + sd_img_gen_params_init(¶ms); + + params.prompt = request_.prompt.c_str(); + params.negative_prompt = request_.negative_prompt.c_str(); + params.clip_skip = request_.clip_skip; + params.width = request_.width; + params.height = request_.height; + params.batch_count = 1; + params.seed = effective_seed_ + index; + if (request_.has_vae_tiling_override) { + params.vae_tiling_params = request_.vae_tiling_params; + } + + sd_sample_params_t& sample_params = params.sample_params; + sample_params.sample_steps = request_.sample_steps; + sample_params.guidance.txt_cfg = request_.cfg_scale; + if (request_.has_img_cfg_scale) { + sample_params.guidance.img_cfg = request_.img_cfg_scale; + } + if (!std::isfinite(sample_params.guidance.img_cfg)) { + sample_params.guidance.img_cfg = sample_params.guidance.txt_cfg; + } + if (request_.override_sample_method) { + sample_params.sample_method = request_.sample_method; + } + if (sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { + sample_params.sample_method = default_sample_method_; + } + if (request_.override_scheduler) { + sample_params.scheduler = request_.scheduler; + } + if (request_.has_eta) { + sample_params.eta = request_.eta; + } + sample_params.shifted_timestep = request_.shifted_timestep; + + sd_image_t* results = generate_image(state_.ctx, ¶ms); + if (results == nullptr) { + emit_error(sink, "image generation failed", index); + return false; + } + + ImageResultGuard guard{results, params.batch_count}; + + sd_image_t& image = results[0]; + if (image.data == nullptr) { + emit_error(sink, "image data is empty", index); + return false; + } + + auto encode_start = std::chrono::steady_clock::now(); + + int png_size = 0; + unsigned char* png_data = stbi_write_png_to_mem(image.data, 0, image.width, image.height, image.channel, &png_size, nullptr); + if (png_data == nullptr) { + emit_error(sink, "failed to encode PNG", index); + return false; + } + std::string encoded = base64_encode(png_data, static_cast(png_size)); + STBIW_FREE(png_data); + + auto encode_end = std::chrono::steady_clock::now(); + const double encode_ms = std::chrono::duration_cast(encode_end - encode_start).count() / 1000.0; + const std::size_t encoded_size = encoded.size(); + + // Preserve the legacy -1 seed while still reporting the concrete seed that was used. + int64_t actual_seed = random_seed_requested_ ? (effective_seed_ + index) : (request_.seed + index); + int64_t reported_seed = random_seed_requested_ ? -1 : actual_seed; + + json image_chunk = json::object(); + image_chunk["type"] = "image"; + image_chunk["index"] = index; + image_chunk["seed"] = reported_seed; + image_chunk["actual_seed"] = actual_seed; + image_chunk["width"] = image.width; + image_chunk["height"] = image.height; + image_chunk["format"] = "png"; + image_chunk["mime_type"] = "image/png"; + image_chunk["payload_bytes"] = png_size; + image_chunk["encoded_bytes"] = static_cast(encoded_size); + image_chunk["encode_ms"] = encode_ms; + image_chunk["data"] = std::move(encoded); + + auto prepare_end = std::chrono::steady_clock::now(); + const double prepare_ms = std::chrono::duration_cast(prepare_end - encode_start).count() / 1000.0; + image_chunk["dispatch_prepare_ms"] = prepare_ms; + + std::size_t serialized_bytes = 0; + if (!write_json_array_item(sink, image_chunk, false, &serialized_bytes)) { + done_ = true; + finalize_resources(); + return false; + } + + auto dispatch_end = std::chrono::steady_clock::now(); + const double dispatch_total_ms = std::chrono::duration_cast(dispatch_end - encode_start).count() / 1000.0; + const double write_ms = std::chrono::duration_cast(dispatch_end - prepare_end).count() / 1000.0; + + json summary_entry = json::object(); + summary_entry["index"] = index; + summary_entry["seed"] = reported_seed; + summary_entry["actual_seed"] = actual_seed; + summary_entry["width"] = image.width; + summary_entry["height"] = image.height; + summary_entry["format"] = "png"; + summary_entry["mime_type"] = "image/png"; + summary_entry["streamed"] = true; + summary_entry["encode_ms"] = encode_ms; + summary_entry["dispatch_prepare_ms"] = prepare_ms; + summary_entry["dispatch_total_ms"] = dispatch_total_ms; + summary_entry["write_ms"] = write_ms; + summary_entry["payload_bytes"] = png_size; + summary_entry["encoded_bytes"] = static_cast(encoded_size); + summary_entry["serialized_bytes"] = static_cast(serialized_bytes); + image_summaries_.push_back(std::move(summary_entry)); + + return true; + } + + void emit_error(httplib::DataSink& sink, const std::string& message, int index) { + encountered_error_ = true; + done_ = true; + const int64_t elapsed = elapsed_ms(); + json error_chunk = json::object(); + error_chunk["type"] = "error"; + error_chunk["success"] = false; + error_chunk["error"] = message; + error_chunk["index"] = index; + error_chunk["requested_seed"] = request_.seed; + error_chunk["applied_seed"] = effective_seed_; + error_chunk["random_seed_requested"] = random_seed_requested_; + error_chunk["elapsed_ms"] = elapsed; + if (!ctx_config_.model_path.empty()) { + error_chunk["model_path"] = ctx_config_.model_path; + } + error_chunk["logs"] = logs_to_json(*collector_); + error_chunk["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_); + if (write_json_array_item(sink, error_chunk, true)) { + finalize_stream(sink); + } else { + finalize_resources(); + } + } + + void emit_final_summary(httplib::DataSink& sink) { + const int64_t elapsed = elapsed_ms(); + json summary = json::object(); + summary["type"] = "complete"; + summary["success"] = !encountered_error_; + summary["batch_count"] = request_.batch_count; + summary["requested_seed"] = request_.seed; + summary["applied_seed"] = effective_seed_; + summary["random_seed_requested"] = random_seed_requested_; + summary["elapsed_ms"] = elapsed; + if (!ctx_config_.model_path.empty()) { + summary["model_path"] = ctx_config_.model_path; + } + summary["images"] = image_summaries_; + summary["logs"] = logs_to_json(*collector_); + summary["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_); + done_ = true; + if (write_json_array_item(sink, summary, true)) { + finalize_stream(sink); + } else { + finalize_resources(); + } + } + + int64_t elapsed_ms() const { + auto end_time = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end_time - start_time_).count(); + } + + bool write_json_array_item(httplib::DataSink& sink, + const json& payload, + bool final_item, + std::size_t* serialized_size = nullptr) { + std::string serialized = payload.dump(); + if (serialized_size != nullptr) { + *serialized_size = serialized.size(); + } + + if (!array_opened_) { + const char prefix[] = "[\n"; + if (!sink.write(prefix, sizeof(prefix) - 1)) { + return false; + } + sink.os.flush(); + array_opened_ = true; + } + + if (!first_object_) { + const char separator[] = ",\n"; + if (!sink.write(separator, sizeof(separator) - 1)) { + return false; + } + sink.os.flush(); + } + + std::string chunk = std::move(serialized); + if (final_item) { + chunk.append("\n]"); + } + chunk.push_back('\n'); + bool ok = sink.write(chunk.data(), chunk.size()); + if (ok) { + sink.os.flush(); + } + first_object_ = false; + return ok; + } + + void finalize_stream(httplib::DataSink& sink) { + if (sink.done) { + sink.done(); + } + finalize_resources(); + } + + void finalize_resources() { + if (finalized_) { + return; + } + capture_scope_.reset(); + collector_.reset(); + if (ctx_lock_.owns_lock()) { + ctx_lock_.unlock(); + } + finalized_ = true; + } + + ServerState& state_; + std::unique_lock ctx_lock_; + std::unique_ptr capture_scope_; + std::shared_ptr collector_; + GenerationRequest request_; + CtxConfig ctx_config_; + bool random_seed_requested_ = false; + int64_t effective_seed_ = 0; + sample_method_t default_sample_method_ = SAMPLE_METHOD_DEFAULT; + std::chrono::steady_clock::time_point start_time_; + int next_index_ = 0; + bool done_ = false; + bool encountered_error_ = false; + bool finalized_ = false; + bool array_opened_ = false; + bool first_object_ = true; + std::vector image_summaries_; +}; + bool apply_context_overrides(const json& body, CtxConfig& config, std::string& error) { auto assign_string = [&](const char* key, std::string& target) -> bool { auto it = body.find(key); @@ -1830,7 +2140,8 @@ int main(int argc, char** argv) { }); server.Post("/generate", [&](const httplib::Request& req, httplib::Response& res) { - LogCollector collector; + auto collector_ptr = std::make_shared(); + LogCollector& collector = *collector_ptr; json body; try { @@ -1852,7 +2163,7 @@ int main(int argc, char** argv) { } std::unique_lock lock(state.mutex); - LogCaptureScope capture(state, collector); + auto capture_scope = std::make_unique(state, collector); CtxConfig desired_config = state.ctx_config; if (desired_config.model_path.empty()) { @@ -1886,6 +2197,30 @@ int main(int argc, char** argv) { effective_seed = generate_random_seed(); } + const bool enable_streaming = request_params.batch_count > 1; + if (enable_streaming) { + GenerationRequest streaming_request = std::move(request_params); + CtxConfig active_config = state.ctx_config; + auto streaming_responder = std::make_shared(state, + std::move(lock), + std::move(capture_scope), + collector_ptr, + std::move(streaming_request), + std::move(active_config), + random_seed_requested, + effective_seed); + res.status = 200; + res.set_chunked_content_provider( + "application/json", + [streaming_responder](size_t, httplib::DataSink& sink) { + return streaming_responder->next(sink); + }, + [streaming_responder](bool) { + streaming_responder->cancel(); + }); + return; + } + sd_img_gen_params_t img_params; sd_img_gen_params_init(&img_params);