Skip to content

Commit 79a8369

Browse files
committed
fix: fixed bug in loading kbit model
There was a bug in loading GenericLoraKbitModel, fixed that
1 parent 60d4a97 commit 79a8369

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/xturing/models/causal.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,19 @@ def __init__(
399399

400400

401401
class CausalLoraKbitModel(CausalLoraModel):
402-
def __init__(self, engine: str, weights_path: Optional[str] = None):
402+
def __init__(
403+
self,
404+
engine: str,
405+
weights_path: Optional[str] = None,
406+
model_name: Optional[str] = None,
407+
target_modules: Optional[List[str]] = None,
408+
**kwargs,
409+
):
403410
assert_not_cpu_int8()
404-
super().__init__(engine, weights_path)
411+
super().__init__(
412+
engine,
413+
weights_path=weights_path,
414+
model_name=model_name,
415+
target_modules=target_modules,
416+
**kwargs,
417+
)

src/xturing/models/generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,12 @@ def __init__(
9494
model_name: str,
9595
target_modules: List[str] = ["c_attn"],
9696
weights_path: Optional[str] = None,
97+
**kwargs,
9798
):
9899
super().__init__(
99100
GenericLoraKbitEngine.config_name,
100101
weights_path,
101102
model_name=model_name,
102103
target_modules=target_modules,
104+
**kwargs,
103105
)

0 commit comments

Comments
 (0)