@@ -25,22 +25,31 @@ def test_get_untrained_model_with_inputs_tiny_llm(self):
2525 data = get_untrained_model_with_inputs (mid , verbose = 1 )
2626 model , inputs = data ["model" ], data ["inputs" ]
2727 model (** inputs )
28- self .assertEqual ((data ["size" ], data ["n_weights" ]), ( 1858125824 , 464531456 ))
28+ self .assertEqual ((1858125824 , 464531456 ), ( data ["size" ], data ["n_weights" ]))
2929
3030 @hide_stdout ()
3131 def test_get_untrained_model_with_inputs_tiny_xlm_roberta (self ):
3232 mid = "hf-internal-testing/tiny-xlm-roberta" # XLMRobertaConfig
3333 data = get_untrained_model_with_inputs (mid , verbose = 1 )
3434 model , inputs = data ["model" ], data ["inputs" ]
3535 model (** inputs )
36- self .assertEqual ((data ["size" ], data ["n_weights" ]), ( 126190824 , 31547706 ))
36+ self .assertEqual ((126190824 , 31547706 ), ( data ["size" ], data ["n_weights" ]))
3737
38+ @hide_stdout ()
3839 def test_get_untrained_model_with_inputs_tiny_gpt_neo (self ):
3940 mid = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
4041 data = get_untrained_model_with_inputs (mid , verbose = 1 )
4142 model , inputs = data ["model" ], data ["inputs" ]
4243 model (** inputs )
43- self .assertEqual ((data ["size" ], data ["n_weights" ]), (4291141632 , 1072785408 ))
44+ self .assertEqual ((4291141632 , 1072785408 ), (data ["size" ], data ["n_weights" ]))
45+
46+ @hide_stdout ()
47+ def test_get_untrained_model_with_inputs_phi_2 (self ):
48+ mid = "microsoft/phi-2"
49+ data = get_untrained_model_with_inputs (mid , verbose = 1 )
50+ model , inputs = data ["model" ], data ["inputs" ]
51+ model (** inputs )
52+ self .assertEqual ((1040498688 , 260124672 ), (data ["size" ], data ["n_weights" ]))
4453
4554
4655if __name__ == "__main__" :
0 commit comments