Skip to content

Commit 303efd2

Browse files
gameofdimension and felix01.yu authored
Improve pos embed for Flux.1 inference on Ascend NPU (huggingface#12534)
Improve pos embed for Ascend NPU.

Co-authored-by: felix01.yu <[email protected]>
1 parent 5afbcce commit 303efd2

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ...configuration_utils import ConfigMixin, register_to_config
2424
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
25-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25+
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
2626
from ...utils.torch_utils import maybe_allow_in_graph
2727
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
2828
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ def forward(
717717
img_ids = img_ids[0]
718718

719719
ids = torch.cat((txt_ids, img_ids), dim=0)
720-
image_rotary_emb = self.pos_embed(ids)
720+
if is_torch_npu_available():
721+
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
722+
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
723+
else:
724+
image_rotary_emb = self.pos_embed(ids)
721725

722726
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
723727
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

0 commit comments

Comments (0)