339 changes: 337 additions & 2 deletions examples/sd-server/main.cpp
@@ -11,6 +11,7 @@
#include <optional>
#include <sstream>
#include <iostream>
#include <memory>
#include <limits>
#include <mutex>
#include <random>
@@ -509,6 +510,315 @@ struct ImageResultGuard {
}
};

json logs_to_json(const LogCollector& collector);
json make_telemetry(const LogCollector& collector,
const GenerationRequest& request,
const CtxConfig& config,
int64_t elapsed_ms,
int64_t effective_seed);

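// Streams a batch generation as a single top-level JSON array over a chunked
// HTTP response: one {"type":"image", ...} object per generated image,
// followed by a terminating {"type":"complete", ...} (or {"type":"error", ...})
// object carrying logs, telemetry, and per-image timing summaries. The
// responder owns the context lock, log capture scope, and collector for the
// lifetime of the stream and releases them when the stream finishes or the
// client disconnects.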
class StreamingImageResponder {
public:
StreamingImageResponder(ServerState& state,
std::unique_lock<std::mutex>&& ctx_lock,
std::unique_ptr<LogCaptureScope>&& capture_scope,
std::shared_ptr<LogCollector> collector,
GenerationRequest request,
CtxConfig ctx_config,
bool random_seed_requested,
int64_t effective_seed)
: state_(state),
ctx_lock_(std::move(ctx_lock)),
capture_scope_(std::move(capture_scope)),
collector_(std::move(collector)),
request_(std::move(request)),
ctx_config_(std::move(ctx_config)),
random_seed_requested_(random_seed_requested),
effective_seed_(effective_seed),
default_sample_method_(sd_get_default_sample_method(state.ctx)),
start_time_(std::chrono::steady_clock::now()) {}

~StreamingImageResponder() {
finalize_resources();
}

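// Invoked repeatedly by the chunked content provider. Emits one image chunk
// per call, then the final summary, and returns false once the stream has
// ended or an error chunk has already been written.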
bool next(httplib::DataSink& sink) {
if (done_) {
return false;
}

if (next_index_ < request_.batch_count) {
if (!emit_image_chunk(sink, next_index_)) {
done_ = true;
return false;
}
++next_index_;
return true;
}

emit_final_summary(sink);
done_ = true;
return false;
}

void cancel() {
done_ = true;
finalize_resources();
}

private:
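// Generates image `index` with a per-image seed, encodes it as a base64 PNG,
// and writes it as one array element. Returns false if generation, encoding,
// or the sink write fails.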
bool emit_image_chunk(httplib::DataSink& sink, int index) {
sd_img_gen_params_t params;
sd_img_gen_params_init(&params);

params.prompt = request_.prompt.c_str();
params.negative_prompt = request_.negative_prompt.c_str();
params.clip_skip = request_.clip_skip;
params.width = request_.width;
params.height = request_.height;
params.batch_count = 1;
params.seed = effective_seed_ + index;
if (request_.has_vae_tiling_override) {
params.vae_tiling_params = request_.vae_tiling_params;
}

sd_sample_params_t& sample_params = params.sample_params;
sample_params.sample_steps = request_.sample_steps;
sample_params.guidance.txt_cfg = request_.cfg_scale;
if (request_.has_img_cfg_scale) {
sample_params.guidance.img_cfg = request_.img_cfg_scale;
}
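// Fall back to the text CFG when no finite image CFG was provided.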
if (!std::isfinite(sample_params.guidance.img_cfg)) {
sample_params.guidance.img_cfg = sample_params.guidance.txt_cfg;
}
if (request_.override_sample_method) {
sample_params.sample_method = request_.sample_method;
}
if (sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
sample_params.sample_method = default_sample_method_;
}
if (request_.override_scheduler) {
sample_params.scheduler = request_.scheduler;
}
if (request_.has_eta) {
sample_params.eta = request_.eta;
}
sample_params.shifted_timestep = request_.shifted_timestep;

sd_image_t* results = generate_image(state_.ctx, &params);
if (results == nullptr) {
emit_error(sink, "image generation failed", index);
return false;
}

ImageResultGuard guard{results, params.batch_count};

sd_image_t& image = results[0];
if (image.data == nullptr) {
emit_error(sink, "image data is empty", index);
return false;
}

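// Encode the raw image to an in-memory PNG and base64 it for the JSON
// payload, timing the encode step for the per-image summary.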
auto encode_start = std::chrono::steady_clock::now();

int png_size = 0;
unsigned char* png_data = stbi_write_png_to_mem(image.data, 0, image.width, image.height, image.channel, &png_size, nullptr);
if (png_data == nullptr) {
emit_error(sink, "failed to encode PNG", index);
return false;
}
std::string encoded = base64_encode(png_data, static_cast<size_t>(png_size));
STBIW_FREE(png_data);

auto encode_end = std::chrono::steady_clock::now();
const double encode_ms = std::chrono::duration_cast<std::chrono::microseconds>(encode_end - encode_start).count() / 1000.0;
const std::size_t encoded_size = encoded.size();

// Preserve the legacy -1 seed while still reporting the concrete seed that was used.
int64_t actual_seed = random_seed_requested_ ? (effective_seed_ + index) : (request_.seed + index);
int64_t reported_seed = random_seed_requested_ ? -1 : actual_seed;

json image_chunk = json::object();
image_chunk["type"] = "image";
image_chunk["index"] = index;
image_chunk["seed"] = reported_seed;
image_chunk["actual_seed"] = actual_seed;
image_chunk["width"] = image.width;
image_chunk["height"] = image.height;
image_chunk["format"] = "png";
image_chunk["mime_type"] = "image/png";
image_chunk["payload_bytes"] = png_size;
image_chunk["encoded_bytes"] = static_cast<int64_t>(encoded_size);
image_chunk["encode_ms"] = encode_ms;
image_chunk["data"] = std::move(encoded);

auto prepare_end = std::chrono::steady_clock::now();
const double prepare_ms = std::chrono::duration_cast<std::chrono::microseconds>(prepare_end - encode_start).count() / 1000.0;
image_chunk["dispatch_prepare_ms"] = prepare_ms;

std::size_t serialized_bytes = 0;
if (!write_json_array_item(sink, image_chunk, false, &serialized_bytes)) {
done_ = true;
finalize_resources();
return false;
}

auto dispatch_end = std::chrono::steady_clock::now();
const double dispatch_total_ms = std::chrono::duration_cast<std::chrono::microseconds>(dispatch_end - encode_start).count() / 1000.0;
const double write_ms = std::chrono::duration_cast<std::chrono::microseconds>(dispatch_end - prepare_end).count() / 1000.0;

json summary_entry = json::object();
summary_entry["index"] = index;
summary_entry["seed"] = reported_seed;
summary_entry["actual_seed"] = actual_seed;
summary_entry["width"] = image.width;
summary_entry["height"] = image.height;
summary_entry["format"] = "png";
summary_entry["mime_type"] = "image/png";
summary_entry["streamed"] = true;
summary_entry["encode_ms"] = encode_ms;
summary_entry["dispatch_prepare_ms"] = prepare_ms;
summary_entry["dispatch_total_ms"] = dispatch_total_ms;
summary_entry["write_ms"] = write_ms;
summary_entry["payload_bytes"] = png_size;
summary_entry["encoded_bytes"] = static_cast<int64_t>(encoded_size);
summary_entry["serialized_bytes"] = static_cast<int64_t>(serialized_bytes);
image_summaries_.push_back(std::move(summary_entry));

return true;
}

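// Writes a terminal error object (with logs and telemetry), closes the JSON
// array, and releases the context lock and log capture.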
void emit_error(httplib::DataSink& sink, const std::string& message, int index) {
encountered_error_ = true;
done_ = true;
const int64_t elapsed = elapsed_ms();
json error_chunk = json::object();
error_chunk["type"] = "error";
error_chunk["success"] = false;
error_chunk["error"] = message;
error_chunk["index"] = index;
error_chunk["requested_seed"] = request_.seed;
error_chunk["applied_seed"] = effective_seed_;
error_chunk["random_seed_requested"] = random_seed_requested_;
error_chunk["elapsed_ms"] = elapsed;
if (!ctx_config_.model_path.empty()) {
error_chunk["model_path"] = ctx_config_.model_path;
}
error_chunk["logs"] = logs_to_json(*collector_);
error_chunk["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_);
if (write_json_array_item(sink, error_chunk, true)) {
finalize_stream(sink);
} else {
finalize_resources();
}
}

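// Writes the closing summary object with per-image stats, logs, and
// telemetry, then ends the chunked response.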
void emit_final_summary(httplib::DataSink& sink) {
const int64_t elapsed = elapsed_ms();
json summary = json::object();
summary["type"] = "complete";
summary["success"] = !encountered_error_;
summary["batch_count"] = request_.batch_count;
summary["requested_seed"] = request_.seed;
summary["applied_seed"] = effective_seed_;
summary["random_seed_requested"] = random_seed_requested_;
summary["elapsed_ms"] = elapsed;
if (!ctx_config_.model_path.empty()) {
summary["model_path"] = ctx_config_.model_path;
}
summary["images"] = image_summaries_;
summary["logs"] = logs_to_json(*collector_);
summary["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_);
done_ = true;
if (write_json_array_item(sink, summary, true)) {
finalize_stream(sink);
} else {
finalize_resources();
}
}

int64_t elapsed_ms() const {
auto end_time = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time_).count();
}

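// Serializes `payload` as one element of the top-level JSON array: emits the
// opening "[" on first use, a "," separator between elements, and the
// closing "]" when final_item is true. Optionally reports the serialized size.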
bool write_json_array_item(httplib::DataSink& sink,
const json& payload,
bool final_item,
std::size_t* serialized_size = nullptr) {
std::string serialized = payload.dump();
if (serialized_size != nullptr) {
*serialized_size = serialized.size();
}

if (!array_opened_) {
const char prefix[] = "[\n";
if (!sink.write(prefix, sizeof(prefix) - 1)) {
return false;
}
sink.os.flush();
array_opened_ = true;
}

if (!first_object_) {
const char separator[] = ",\n";
if (!sink.write(separator, sizeof(separator) - 1)) {
return false;
}
sink.os.flush();
}

std::string chunk = std::move(serialized);
if (final_item) {
chunk.append("\n]");
}
chunk.push_back('\n');
bool ok = sink.write(chunk.data(), chunk.size());
if (ok) {
sink.os.flush();
}
first_object_ = false;
return ok;
}

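// Signals end-of-stream to the sink and releases held resources.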
void finalize_stream(httplib::DataSink& sink) {
if (sink.done) {
sink.done();
}
finalize_resources();
}

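// Idempotent teardown: releases the log capture scope and collector, then
// unlocks the context mutex if it is still held.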
void finalize_resources() {
if (finalized_) {
return;
}
capture_scope_.reset();
collector_.reset();
if (ctx_lock_.owns_lock()) {
ctx_lock_.unlock();
}
finalized_ = true;
}

ServerState& state_;
std::unique_lock<std::mutex> ctx_lock_;
std::unique_ptr<LogCaptureScope> capture_scope_;
std::shared_ptr<LogCollector> collector_;
GenerationRequest request_;
CtxConfig ctx_config_;
bool random_seed_requested_ = false;
int64_t effective_seed_ = 0;
sample_method_t default_sample_method_ = SAMPLE_METHOD_DEFAULT;
std::chrono::steady_clock::time_point start_time_;
int next_index_ = 0;
bool done_ = false;
bool encountered_error_ = false;
bool finalized_ = false;
bool array_opened_ = false;
bool first_object_ = true;
std::vector<json> image_summaries_;
};

bool apply_context_overrides(const json& body, CtxConfig& config, std::string& error) {
auto assign_string = [&](const char* key, std::string& target) -> bool {
auto it = body.find(key);
@@ -1830,7 +2140,8 @@ int main(int argc, char** argv) {
});

server.Post("/generate", [&](const httplib::Request& req, httplib::Response& res) {
LogCollector collector;
auto collector_ptr = std::make_shared<LogCollector>();
LogCollector& collector = *collector_ptr;

json body;
try {
@@ -1852,7 +2163,7 @@
}

std::unique_lock<std::mutex> lock(state.mutex);
LogCaptureScope capture(state, collector);
auto capture_scope = std::make_unique<LogCaptureScope>(state, collector);

CtxConfig desired_config = state.ctx_config;
if (desired_config.model_path.empty()) {
@@ -1886,6 +2197,30 @@
effective_seed = generate_random_seed();
}

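// Requests with batch_count > 1 stream each image as soon as it is generated;
// single-image requests keep the buffered JSON response below.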
const bool enable_streaming = request_params.batch_count > 1;
if (enable_streaming) {
GenerationRequest streaming_request = std::move(request_params);
CtxConfig active_config = state.ctx_config;
auto streaming_responder = std::make_shared<StreamingImageResponder>(state,
std::move(lock),
std::move(capture_scope),
collector_ptr,
std::move(streaming_request),
std::move(active_config),
random_seed_requested,
effective_seed);
res.status = 200;
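// Both lambdas hold the responder via shared_ptr; the completion callback
// cancels the stream so the context lock and log capture are released even
// if the client disconnects mid-transfer.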
res.set_chunked_content_provider(
"application/json",
[streaming_responder](size_t, httplib::DataSink& sink) {
return streaming_responder->next(sink);
},
[streaming_responder](bool) {
streaming_responder->cancel();
});
return;
}

sd_img_gen_params_t img_params;
sd_img_gen_params_init(&img_params);
