Skip to content

Commit d2f015e

Browse files
Chendi Xue authored
[0.16.0] remove cuda hard-code for Hunyuan Image3 (#1402)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent 0b577a7 commit d2f015e

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464

6565
from vllm_omni.diffusion.attention.layer import Attention
6666
from vllm_omni.diffusion.distributed.parallel_state import get_pp_group
67+
from vllm_omni.diffusion.distributed.utils import get_local_device
6768
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
6869

6970
logger = logging.getLogger(__name__)
@@ -1737,6 +1738,7 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
17371738
lora_config = None
17381739
self.num_redundant_experts = 0
17391740
self.config = config
1741+
self.device = get_local_device()
17401742
self.quant_config = quant_config
17411743
self.padding_idx = config.pad_token_id
17421744
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
@@ -2430,7 +2432,7 @@ def __call__(
24302432
**model_kwargs,
24312433
)
24322434

2433-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
2435+
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True):
24342436
model_output = self.model.forward_call(**model_inputs, first_step=(i == 0))
24352437
pred = model_output["diffusion_prediction"]
24362438
pred = pred.to(dtype=torch.float32)
@@ -2477,7 +2479,7 @@ def __call__(
24772479
if hasattr(self.vae, "ffactor_temporal"):
24782480
latents = latents.unsqueeze(2)
24792481

2480-
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
2482+
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=True):
24812483
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
24822484

24832485
if hasattr(self.vae, "ffactor_temporal"):

vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def pre_load(self):
140140
if hasattr(self, prefix.split(".")[0]):
141141
module = dict(self.named_modules()).get(prefix)
142142
if module:
143-
module.to(f"cuda:{tp_rank}")
143+
module.to(f"{self.model.device.type}:{tp_rank}")
144144

145145
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
146146
self.pre_load()
@@ -369,7 +369,7 @@ def build_batch_rope_image_info(output, sections):
369369
def vae_encode(self, image, cfg_factor=1):
370370
config = self.vae.config
371371

372-
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
372+
with torch.autocast(device_type=self.model.device.type, dtype=torch.float16, enabled=True):
373373
vae_encode_result = self.vae.encode(image)
374374
if isinstance(vae_encode_result, torch.Tensor):
375375
latents = vae_encode_result

vllm_omni/platforms/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def xpu_omni_platform_plugin() -> str | None:
8484
is_xpu = False
8585
logger.debug("Checking if XPU OmniPlatform is available.")
8686
try:
87-
# installed IPEX if the machine has XPUs.
88-
import intel_extension_for_pytorch # noqa: F401
8987
import torch
9088

9189
if supports_xccl():

0 commit comments

Comments (0)