|
1 | | -from dataclasses import dataclass, field |
| 1 | +from dataclasses import dataclass, field, InitVar |
2 | 2 | from typing import Optional, List, Dict |
3 | 3 |
|
4 | 4 |
|
@@ -28,12 +28,52 @@ class ModelDescription: |
28 | 28 | tasks: Dict[str, int] = field(default_factory=dict) |
29 | 29 |
|
30 | 30 |
|
31 | | -@dataclass(frozen=True, kw_only=True) |
| 31 | +# @dataclass(frozen=True, kw_only=True) |
| 32 | +# class SparseModelDescription(ModelDescription): |
| 33 | +# vocab_size: int |
| 34 | +# requires_idf: Optional[bool] = None |
| 35 | +# # For sparse models, override dim to always be None. |
| 36 | +# dim: Optional[int] = None |
| 37 | + |
| 38 | + |
| 39 | +@dataclass(frozen=True) |
32 | 40 | class SparseModelDescription(ModelDescription): |
33 | | - vocab_size: int |
34 | | - requires_idf: Optional[bool] = None |
35 | | - # For sparse models, override dim to always be None. |
36 | | - dim: Optional[int] = None |
| 41 | + _vocab_size: InitVar[Optional[int]] = None |
| 42 | + _requires_idf: InitVar[Optional[bool]] = None |
| 43 | + |
| 44 | + vocab_size: int = field(init=False) |
| 45 | + requires_idf: Optional[bool] = field(init=False, default=None) |
| 46 | + dim: Optional[int] = field(default=None, init=False) # Always None for sparse models. |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + *, |
| 51 | + model: str, |
| 52 | + sources: ModelSource, |
| 53 | + model_file: str, |
| 54 | + description: str, |
| 55 | + license: str, |
| 56 | + size_in_GB: float, |
| 57 | + additional_files: Optional[List[str]] = None, |
| 58 | + tasks: Optional[Dict[str, int]] = None, |
| 59 | + vocab_size: int, |
| 60 | + requires_idf: Optional[bool] = None, |
| 61 | + ): |
| 62 | + # Call the parent initializer with the fields it needs. |
| 63 | + object.__setattr__(self, "model", model) |
| 64 | + object.__setattr__(self, "sources", sources) |
| 65 | + object.__setattr__(self, "model_file", model_file) |
| 66 | + object.__setattr__(self, "dim", None) |
| 67 | + object.__setattr__(self, "description", description) |
| 68 | + object.__setattr__(self, "license", license) |
| 69 | + object.__setattr__(self, "size_in_GB", size_in_GB) |
| 70 | + object.__setattr__( |
| 71 | + self, "additional_files", additional_files if additional_files is not None else [] |
| 72 | + ) |
| 73 | + object.__setattr__(self, "tasks", tasks if tasks is not None else {}) |
| 74 | + # Set new fields. |
| 75 | + object.__setattr__(self, "vocab_size", vocab_size) |
| 76 | + object.__setattr__(self, "requires_idf", requires_idf) |
37 | 77 |
|
38 | 78 |
|
39 | 79 | @dataclass(frozen=True) |
|
0 commit comments