55from io import StringIO
66from pathlib import Path
77from unittest import mock
8- from unittest .mock import Mock , PropertyMock , call , ANY
8+ from unittest .mock import Mock , call , ANY
99
1010import pytest
1111import torch
@@ -59,13 +59,12 @@ def test_main(tmp_path, monkeypatch):
5959 tokenizer_path = tmp_path / "tokenizer"
6060 tokenizer_path .touch ()
6161
62- class FabricMock (PropertyMock ):
62+ class FabricMock (Mock ):
6363 @property
6464 def device (self ):
6565 return torch .device ("cpu" )
6666
67- fabric_mock = FabricMock ()
68- monkeypatch .setattr (generate .L , "Fabric" , fabric_mock )
67+ monkeypatch .setattr (generate .L , "Fabric" , FabricMock )
6968 model_mock = Mock ()
7069 monkeypatch .setattr (generate .LLaMA , "from_name" , model_mock )
7170 load_mock = Mock ()
@@ -85,7 +84,6 @@ def device(self):
8584 checkpoint_path = checkpoint_path ,
8685 tokenizer_path = tokenizer_path ,
8786 model_size = "1T" ,
88- accelerator = "litpu" ,
8987 temperature = 2.0 ,
9088 top_k = 2 ,
9189 num_samples = num_samples ,
@@ -96,18 +94,7 @@ def device(self):
9694 tokenizer_mock .assert_called_once_with (tokenizer_path )
9795 assert len (tokenizer_mock .return_value .decode .mock_calls ) == num_samples
9896 assert torch .allclose (tokenizer_mock .return_value .decode .call_args [0 ][0 ], generate_mock .return_value )
99- model = model_mock .return_value
100- assert fabric_mock .mock_calls == [
101- call (accelerator = "litpu" , devices = 1 ),
102- call ().device .__enter__ (),
103- call ().device .__exit__ (None , None , None ),
104- call ().setup_module (model ),
105- ]
106- model = fabric_mock .return_value .setup_module .return_value
107- assert (
108- generate_mock .mock_calls
109- == [call (model , ANY , 50 , model .config .block_size , temperature = 2.0 , top_k = 2 )] * num_samples
110- )
97+ assert generate_mock .mock_calls == [call (ANY , ANY , 50 , ANY , temperature = 2.0 , top_k = 2 )] * num_samples
11198 # only the generated result is printed to stdout
11299 assert out .getvalue () == "foo bar baz\n " * num_samples
113100
0 commit comments