Skip to content

Commit ed45878

Browse files
committed
Only include pad into mask once
1 parent 5a92f4d commit ed45878

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

conditioner.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,6 +1441,21 @@ struct PixArtCLIPEmbedder : public Conditioner {
14411441
return {t5_tokens, t5_weights, t5_mask};
14421442
}
14431443

1444+
void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
1445+
float* mask_data = (float*)mask->data;
1446+
int num_pad = 0;
1447+
for (int64_t i = 0; i < max_seq_length; i++) {
1448+
if (num_pad >= num_extra_padding) {
1449+
break;
1450+
}
1451+
if (std::isinf(mask_data[i])) {
1452+
mask_data[i] = 0;
1453+
++num_pad;
1454+
}
1455+
}
1456+
// LOG_DEBUG("PAD: %d", num_pad);
1457+
}
1458+
14441459
SDCondition get_learned_condition_common(ggml_context* work_ctx,
14451460
int n_threads,
14461461
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
@@ -1527,6 +1542,21 @@ struct PixArtCLIPEmbedder : public Conditioner {
15271542
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
15281543
ggml_set_f32(hidden_states, 0.f);
15291544
}
1545+
1546+
int mask_pad = 1;
1547+
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1548+
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1549+
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1550+
try {
1551+
mask_pad = std::stoi(mask_pad_str);
1552+
} catch (const std::invalid_argument&) {
1553+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1554+
} catch (const std::out_of_range&) {
1555+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1556+
}
1557+
}
1558+
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
1559+
15301560
return SDCondition(hidden_states, t5_attn_mask, NULL);
15311561
}
15321562

flux.hpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -709,20 +709,6 @@ namespace Flux {
709709
return ids;
710710
}
711711

712-
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
713-
float* mask_data = (float*)mask->data;
714-
int num_pad = 0;
715-
for (int64_t i = 0; i < max_seq_length; i++) {
716-
if (num_pad >= num_extra_padding) {
717-
break;
718-
}
719-
if (std::isinf(mask_data[i])) {
720-
mask_data[i] = 0;
721-
++num_pad;
722-
}
723-
}
724-
// LOG_DEBUG("PAD: %d", num_pad);
725-
}
726712

727713
// Generate positional embeddings
728714
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
@@ -1127,19 +1113,6 @@ namespace Flux {
11271113
guidance = ggml_set_f32(guidance, 0);
11281114
}
11291115

1130-
int mask_pad = 1;
1131-
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1132-
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1133-
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1134-
try {
1135-
mask_pad = std::stoi(mask_pad_str);
1136-
} catch (const std::invalid_argument&) {
1137-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1138-
} catch (const std::out_of_range&) {
1139-
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1140-
}
1141-
}
1142-
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
11431116

11441117
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
11451118
if (SD_CHROMA_USE_DIT_MASK != nullptr) {

0 commit comments

Comments
 (0)