Skip to content

Commit b942ddf

Browse files
committed
fix: fix split module feature not used
1 parent ec4c96a commit b942ddf

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

nodes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from transformers import AutoModelForImageSegmentation, AutoConfig
22
from transformers.dynamic_module_utils import get_class_from_dynamic_module
3+
from transformers.models.auto.auto_factory import add_generation_mixin_to_remote_model
34
import torch
45
from torchvision import transforms
56
import numpy as np
@@ -147,7 +148,7 @@ def get_device_by_name(device):
147148
if torch.cuda.is_available():
148149
device = "cuda"
149150
elif torch.backends.mps.is_available():
150-
device = "mps"
151+
device = "cpu"
151152
elif torch.xpu.is_available():
152153
device = "xpu"
153154
except:
@@ -352,7 +353,12 @@ def background_remove(self,
352353
except Exception as e:
353354
print('No need to delete:', e)
354355

355-
birefnet = AutoModelForImageSegmentation.from_pretrained(local_model_path,trust_remote_code=True, **spare_params)
356+
AutoModelForImageSegmentation.register(config.__class__, model_class, exist_ok=True)
357+
model_class = add_generation_mixin_to_remote_model(model_class)
358+
birefnet = model_class.from_pretrained(
359+
local_model_path, config=config, **spare_params
360+
)
361+
# birefnet = AutoModelForImageSegmentation.from_pretrained(local_model_path,trust_remote_code=True, **spare_params)
356362
if cached:
357363
_birefnet_model = birefnet
358364
_birefnet_model_name = local_model_path

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-birefnet-super"
33
description = "This repository packages the latest BiRefNet model as a ComfyUI node for use, supporting chunked loading on both CPU and GPU, as well as model caching features."
4-
version = "1.0.6"
4+
version = "1.0.7"
55
license = {file = "LICENSE"}
66

77
[project.urls]

0 commit comments

Comments
 (0)