diff --git a/install/requirements.txt b/install/requirements.txt index bb9e35a6c..991768be5 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -27,6 +27,9 @@ cmake>=3.24, < 4.0.0 # 4.0 is BC breaking ninja zstd +# Test tools +pytest + # Browser mode streamlit diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..8fe1384b8 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + model_config: Tests related to model config diff --git a/torchchat/model_config/tests/test_model_config.py b/torchchat/model_config/tests/test_model_config.py new file mode 100644 index 000000000..5b4d86c6b --- /dev/null +++ b/torchchat/model_config/tests/test_model_config.py @@ -0,0 +1,24 @@ +import pytest +from torchchat.model_config.model_config import load_model_configs, resolve_model_config + + +TEST_CONFIG = "meta-llama/llama-3.2-11b-vision" +TEST_CONFIG_NAME = "meta-llama/Llama-3.2-11B-Vision" + + +@pytest.mark.model_config +def test_load_model_configs(): + configs = load_model_configs() + assert TEST_CONFIG in configs + assert configs[TEST_CONFIG].name == TEST_CONFIG_NAME + + +@pytest.mark.model_config +def test_resolve_model_config(): + config = resolve_model_config(TEST_CONFIG) + print(config) + assert config.name == TEST_CONFIG_NAME + assert config.checkpoint_file == "model.pth" + + with pytest.raises(ValueError): + resolve_model_config("UnknownModel")