@@ -164,49 +164,49 @@ def from_params(cls, params):
164164@dataclass
165165class ModelArgs :
166166 model_type : ModelType
167- transformer_args : Dict [str , Union [Dict , TransformerArgs ]]
167+ transformer_args : Dict [str , Dict [str , Any ]]
168+ use_tiktoken : bool
168169
169170 def __init__ (
170171 self ,
171- transformer_args : Union [ TransformerArgs , Dict [str , TransformerArgs ]],
172+ transformer_args : Dict [ str , Dict [str , Any ]],
172173 model_type : ModelType = ModelType .TextOnly ,
174+ use_tiktoken : bool = False ,
173175 ) -> None :
174176 self ._sanity_check (transformer_args , model_type )
175177
176178 self .model_type = model_type
177- if isinstance (transformer_args , TransformerArgs ):
178- assert model_type == ModelType .TextOnly
179- self .transformer_args = {"text" : transformer_args }
180- else :
181- self .transformer_args = transformer_args
179+ self .transformer_args = transformer_args
180+
181+ # Model-level attributes
182+ self .use_tiktoken = use_tiktoken
182183
183184 def _sanity_check (
184185 self ,
185- transformer_args : Union [ TransformerArgs , Dict [str , TransformerArgs ]],
186+ transformer_args : Dict [ str , Dict [str , Any ]],
186187 model_type : ModelType ,
187188 ) -> None :
188- assert isinstance (model_type , ModelType )
189- assert isinstance (transformer_args , ( TransformerArgs , dict ) )
189+ assert isinstance (model_type , ModelType ), model_type
190+ assert isinstance (transformer_args , dict )
190191
191192 @classmethod
192193 def from_params (cls , params_path ):
193194 with open (params_path , "r" ) as f :
194195 loaded_params = json .loads (f .read ())
195-
196- try :
197- # try to interpret as a single transformer config
198- transformer_args : Dict [str , TransformerArgs ] = {}
199- transformer_args ["text" ] = TransformerArgs .from_params (loaded_params )
200- if (model_type := loaded_params .get ("model_type" , None )) is None :
201- model_type = ModelType .TextOnly
202-
203- except TypeError :
204- # try to interpret as a dict of transformer configs
205- model_type = ModelType (loaded_params ["model_type" ])
196+
197+ if (model_type_name := loaded_params .get ("model_type" , None )) is None :
198+ # The model params is in the transformer_args format
199+ # set the model_type to TextOnly and reformat the params
200+ model_type = ModelType .TextOnly
201+ transformer_args = {"text" : {"config" : loaded_params }}
202+ else :
203+ model_type = ModelType (model_type_name )
206204 transformer_args = {
207205 k : v for k , v in loaded_params .items () if k != "model_type"
208206 }
209- return cls (transformer_args , model_type )
207+
208+ use_tiktoken = loaded_params .get ("use_tiktoken" , False )
209+ return cls (transformer_args , model_type , use_tiktoken )
210210
211211 @classmethod
212212 def from_table (cls , name : str ):
@@ -304,10 +304,8 @@ def build_model(self) -> nn.Module:
304304 recipe = ModelRecipe .get_recipe (self .config .model_type )
305305 modules = {}
306306 for name , module_class in recipe .modules .items ():
307- if isinstance (config_args := self .config .transformer_args [name ], dict ):
308- modules [name ] = module_class (** config_args )
309- else :
310- modules [name ] = module_class (config_args )
307+ config_args = self .config .transformer_args [name ]
308+ modules [name ] = module_class (** config_args )
311309
312310 return recipe .fusion_class (** modules )
313311
@@ -399,8 +397,9 @@ def reset_caches(self):
399397
400398
401399class Transformer (nn .Module ):
402- def __init__ (self , config : TransformerArgs ) -> None :
400+ def __init__ (self , config : Dict [ str , Any ] ) -> None :
403401 super ().__init__ ()
402+ config = TransformerArgs .from_params (config )
404403 self .config = config
405404 layers_per_stage = config .n_layers // config .n_stages
406405
0 commit comments