Skip to content

Commit 86b01b4

Browse files
committed
kw_only support
1 parent 1173971 commit 86b01b4

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

fastembed/common/model_description.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass, field, InitVar
22
from typing import Optional, List, Dict
33

44

@@ -28,12 +28,52 @@ class ModelDescription:
2828
tasks: Dict[str, int] = field(default_factory=dict)
2929

3030

31-
@dataclass(frozen=True, kw_only=True)
31+
# @dataclass(frozen=True, kw_only=True)
32+
# class SparseModelDescription(ModelDescription):
33+
# vocab_size: int
34+
# requires_idf: Optional[bool] = None
35+
# # For sparse models, override dim to always be None.
36+
# dim: Optional[int] = None
37+
38+
39+
@dataclass(frozen=True)
3240
class SparseModelDescription(ModelDescription):
33-
vocab_size: int
34-
requires_idf: Optional[bool] = None
35-
# For sparse models, override dim to always be None.
36-
dim: Optional[int] = None
41+
_vocab_size: InitVar[Optional[int]] = None
42+
_requires_idf: InitVar[Optional[bool]] = None
43+
44+
vocab_size: int = field(init=False)
45+
requires_idf: Optional[bool] = field(init=False, default=None)
46+
dim: Optional[int] = field(default=None, init=False) # Always None for sparse models.
47+
48+
def __init__(
49+
self,
50+
*,
51+
model: str,
52+
sources: ModelSource,
53+
model_file: str,
54+
description: str,
55+
license: str,
56+
size_in_GB: float,
57+
additional_files: Optional[List[str]] = None,
58+
tasks: Optional[Dict[str, int]] = None,
59+
vocab_size: int,
60+
requires_idf: Optional[bool] = None,
61+
):
62+
# Call the parent initializer with the fields it needs.
63+
object.__setattr__(self, "model", model)
64+
object.__setattr__(self, "sources", sources)
65+
object.__setattr__(self, "model_file", model_file)
66+
object.__setattr__(self, "dim", None)
67+
object.__setattr__(self, "description", description)
68+
object.__setattr__(self, "license", license)
69+
object.__setattr__(self, "size_in_GB", size_in_GB)
70+
object.__setattr__(
71+
self, "additional_files", additional_files if additional_files is not None else []
72+
)
73+
object.__setattr__(self, "tasks", tasks if tasks is not None else {})
74+
# Set new fields.
75+
object.__setattr__(self, "vocab_size", vocab_size)
76+
object.__setattr__(self, "requires_idf", requires_idf)
3777

3878

3979
@dataclass(frozen=True)

0 commit comments

Comments
 (0)