Skip to content

Commit caf2b85

Browse files
committed
Apply Flux fixes to SD3
1 parent 5b1ae45 commit caf2b85

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

conditioner.hpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -798,21 +798,21 @@ struct SD3CLIPEmbedder : public Conditioner {
798798
}
799799

800800
if (chunk_idx == 0) {
801-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
802-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
803-
// clip_l->compute(n_threads,
804-
// input_ids,
805-
// 0,
806-
// NULL,
807-
// max_token_idx,
808-
// true,
809-
// &pooled_l,
810-
// work_ctx);
801+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
802+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
803+
clip_l->compute(n_threads,
804+
input_ids,
805+
0,
806+
NULL,
807+
max_token_idx,
808+
true,
809+
&pooled_l,
810+
work_ctx);
811811

812812
// clip_l.transformer.text_model.text_projection not in file, ignore
813813
// TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
814-
pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
815-
ggml_set_f32(pooled_l, 0.f);
814+
// pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
815+
// ggml_set_f32(pooled_l, 0.f);
816816
}
817817
}
818818

@@ -852,21 +852,21 @@ struct SD3CLIPEmbedder : public Conditioner {
852852
}
853853

854854
if (chunk_idx == 0) {
855-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
856-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
857-
// clip_g->compute(n_threads,
858-
// input_ids,
859-
// 0,
860-
// NULL,
861-
// max_token_idx,
862-
// true,
863-
// &pooled_g,
864-
// work_ctx);
855+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
856+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
857+
clip_g->compute(n_threads,
858+
input_ids,
859+
0,
860+
NULL,
861+
max_token_idx,
862+
true,
863+
&pooled_g,
864+
work_ctx);
865865
// clip_l.transformer.text_model.text_projection not in file, ignore pooled_g too
866866

867867
// TODO: fix pooled_g
868-
pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
869-
ggml_set_f32(pooled_g, 0.f);
868+
// pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
869+
// ggml_set_f32(pooled_g, 0.f);
870870
}
871871
}
872872

0 commit comments

Comments
 (0)