1- from dataclasses import dataclass , field , InitVar
2- from typing import Optional , List , Dict
1+ from dataclasses import dataclass , field
2+ from typing import Optional , Any
33
44
55@dataclass (frozen = True )
@@ -15,67 +15,54 @@ def __post_init__(self) -> None:
1515
1616
1717@dataclass (frozen = True )
18- class ModelDescription :
18+ class BaseModelDescription :
1919 model : str
2020 sources : ModelSource
2121 model_file : str
22- dim : Optional [int ]
22+ description : str = ""
23+ license : str = ""
24+ size_in_GB : Optional [float ] = None
25+ additional_files : list [str ] = field (default_factory = list )
2326
24- description : str
25- license : str
26- size_in_GB : Optional [float ]
27- additional_files : List [str ] = field (default_factory = list )
28- tasks : Dict [str , int ] = field (default_factory = dict )
27+ def validate_info (self ) -> None :
28+ if self .license == "" :
29+ raise ValueError ("license is required in builtin model description" )
30+
31+ if self .description == "" :
32+ raise ValueError ("description is required in builtin model description" )
33+
34+ if self .size_in_GB is None :
35+ raise ValueError ("size_in_GB is required in builtin model description" )
36+
37+ def __post_init__ (self ) -> None :
38+ self .validate_info ()
2939
3040
3141@dataclass (frozen = True )
32- class MultimodalModelDescription (ModelDescription ):
33- dim : int
42+ class DenseModelDescription (BaseModelDescription ):
43+ dim : Optional [int ] = None
44+ tasks : Optional [dict [str , Any ]] = None
45+
46+ def __post_init__ (self ) -> None :
47+ assert self .dim is not None , "dim is required for dense model description"
48+ self .validate_info ()
3449
3550
3651@dataclass (frozen = True )
37- class SparseModelDescription (ModelDescription ):
38- _vocab_size : InitVar [Optional [int ]] = None
39- _requires_idf : InitVar [Optional [bool ]] = None
40-
41- vocab_size : int = field (init = False )
42- requires_idf : Optional [bool ] = field (init = False , default = None )
43- dim : Optional [int ] = field (default = None , init = False )
44-
45- def __init__ (
46- self ,
47- * ,
48- model : str ,
49- sources : ModelSource ,
50- model_file : str ,
51- description : str ,
52- license : str ,
53- size_in_GB : Optional [float ],
54- dim : Optional [int ] = None ,
55- additional_files : Optional [List [str ]] = None ,
56- tasks : Optional [Dict [str , int ]] = None ,
57- vocab_size : int ,
58- requires_idf : Optional [bool ] = None ,
59- ):
60- # Call the parent initializer with the fields it needs.
61- object .__setattr__ (self , "model" , model )
62- object .__setattr__ (self , "sources" , sources )
63- object .__setattr__ (self , "model_file" , model_file )
64- object .__setattr__ (self , "dim" , dim if dim else None )
65- object .__setattr__ (self , "description" , description )
66- object .__setattr__ (self , "license" , license )
67- object .__setattr__ (self , "size_in_GB" , size_in_GB )
68- object .__setattr__ (
69- self , "additional_files" , additional_files if additional_files is not None else []
70- )
71- object .__setattr__ (self , "tasks" , tasks if tasks is not None else {})
72- # Set new fields.
73- object .__setattr__ (self , "vocab_size" , vocab_size )
74- object .__setattr__ (self , "requires_idf" , requires_idf )
52+ class SparseModelDescription (BaseModelDescription ):
53+ requires_idf : Optional [bool ] = None
54+ vocab_size : Optional [int ] = None
7555
7656
7757@dataclass (frozen = True )
78- class CustomModelDescription (ModelDescription ):
79- description : str = ""
80- license : str = ""
81- size_in_GB : Optional [float ] = None
58+ class CustomDenseModelDescription (DenseModelDescription ):
59+ def __post_init__ (self ) -> None :
60+ if self .dim is None :
61+ raise ValueError ("dim is required for custom dense model description" )
62+ # disable self.validate_info
63+
64+
65+ @dataclass (frozen = True )
66+ class CustomSparseModelDescription (SparseModelDescription ):
67+ def __post_init__ (self ) -> None :
68+ pass # disable self.validate_info
0 commit comments