Skip to content

Commit 906b3da

Browse files
author
Tomasz Trębski
authored
Add @parametrized handling (#27)
* Add parametrized handling * fixup! Add parametrized handling * Add parameters to out * Add an example to README
1 parent f01f8c3 commit 906b3da

File tree

4 files changed

+137
-49
lines changed

4 files changed

+137
-49
lines changed

README.md

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,40 @@
66
[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
77
[![Gitter](https://badges.gitter.im/mypy-django/Lobby.svg)](https://gitter.im/mypy-django/Lobby)
88

9-
109
## Installation
1110

1211
```bash
1312
pip install pytest-mypy-plugins
1413
```
1514

16-
1715
## Usage
1816

1917
Examples of a test case:
2018

2119
```yaml
2220
# typesafety/test_request.yml
23-
- case: request_object_has_user_of_type_auth_user_model
24-
disable_cache: true
25-
main: |
26-
from django.http.request import HttpRequest
27-
reveal_type(HttpRequest().user) # N: Revealed type is 'myapp.models.MyUser'
28-
# check that other fields work ok
29-
reveal_type(HttpRequest().method) # N: Revealed type is 'Union[builtins.str, None]'
30-
files:
31-
- path: myapp/__init__.py
32-
- path: myapp/models.py
33-
content: |
34-
from django.db import models
35-
class MyUser(models.Model):
36-
pass
21+
- case: request_object_has_user_of_type_auth_user_model
22+
disable_cache: true
23+
main: |
24+
from django.http.request import HttpRequest
25+
reveal_type(HttpRequest().user) # N: Revealed type is 'myapp.models.MyUser'
26+
# check that other fields work ok
27+
reveal_type(HttpRequest().method) # N: Revealed type is 'Union[builtins.str, None]'
28+
files:
29+
- path: myapp/__init__.py
30+
- path: myapp/models.py
31+
content: |
32+
from django.db import models
33+
class MyUser(models.Model):
34+
pass
35+
- case: with_params
36+
parametrized:
37+
- val: 1
38+
rt: builtins.int
39+
- val: 1.0
40+
rt: builtins.float
41+
main: |
42+
reveal_type({[ val }}) # N: Reveal type is '{{ rt }}'
3743
```
3844
3945
Running:

pytest_mypy_plugins/collect.py

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
import platform
33
import sys
44
import tempfile
5-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
5+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Mapping
66

77
import pytest
88
import yaml
99
from _pytest.config.argparsing import Parser
1010
from _pytest.nodes import Node
1111
from py._path.local import LocalPath
12+
import pystache
1213

1314
from pytest_mypy_plugins import utils
1415

@@ -42,6 +43,27 @@ def parse_environment_variables(env_vars: List[str]) -> Dict[str, str]:
4243
return parsed_vars
4344

4445

46+
def parse_parametrized(params: List[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
47+
if not params:
48+
return [{}]
49+
50+
parsed_params: List[Mapping[str, Any]] = []
51+
known_params = None
52+
for idx, param in enumerate(params):
53+
param_keys = set(sorted(param.keys()))
54+
if not known_params:
55+
known_params = param_keys
56+
elif known_params.intersection(param_keys) != known_params:
57+
raise ValueError(
58+
"All parametrized entries must have same keys."
59+
f'First entry is {", ".join(known_params)} but {", ".join(param_keys)} '
60+
"was spotted at {idx} position",
61+
)
62+
parsed_params.append({k: v for k, v in param.items() if not k.startswith("__")})
63+
64+
return parsed_params
65+
66+
4567
class SafeLineLoader(yaml.SafeLoader):
4668
def construct_mapping(self, node: yaml.Node, deep: bool = False) -> None:
4769
mapping = super().construct_mapping(node, deep=deep)
@@ -66,37 +88,47 @@ def collect(self) -> Iterator["YamlTestItem"]:
6688
raise ValueError(f"Test file has to be YAML list, got {type(parsed_file)!r}.")
6789

6890
for raw_test in parsed_file:
69-
test_name = raw_test["case"]
70-
if " " in test_name:
71-
raise ValueError(f"Invalid test name {test_name!r}, only '[a-zA-Z0-9_]' is allowed.")
72-
73-
test_files = [File(path="main.py", content=raw_test["main"])]
74-
test_files += parse_test_files(raw_test.get("files", []))
75-
76-
output_from_comments = []
77-
for test_file in test_files:
78-
output_lines = utils.extract_errors_from_comments(test_file.path, test_file.content.split("\n"))
79-
output_from_comments.extend(output_lines)
80-
81-
starting_lineno = raw_test["__line__"]
82-
extra_environment_variables = parse_environment_variables(raw_test.get("env", []))
83-
disable_cache = raw_test.get("disable_cache", False)
84-
expected_output_lines = raw_test.get("out", "").split("\n")
85-
additional_mypy_config = raw_test.get("mypy_config", "")
86-
87-
skip = self._eval_skip(str(raw_test.get("skip", "False")))
88-
if not skip:
89-
yield YamlTestItem.from_parent(
90-
self,
91-
name=test_name,
92-
files=test_files,
93-
starting_lineno=starting_lineno,
94-
environment_variables=extra_environment_variables,
95-
disable_cache=disable_cache,
96-
expected_output_lines=output_from_comments + expected_output_lines,
97-
parsed_test_data=raw_test,
98-
mypy_config=additional_mypy_config,
99-
)
91+
test_name_prefix = raw_test["case"]
92+
if " " in test_name_prefix:
93+
raise ValueError(f"Invalid test name {test_name_prefix!r}, only '[a-zA-Z0-9_]' is allowed.")
94+
else:
95+
parametrized = parse_parametrized(raw_test.get("parametrized", []))
96+
97+
for params in parametrized:
98+
if params:
99+
test_name_suffix = ",".join(f"{k}={v}" for k, v in params.items())
100+
test_name_suffix = f"[{test_name_suffix}]"
101+
else:
102+
test_name_suffix = ""
103+
104+
test_name = f"{test_name_prefix}{test_name_suffix}"
105+
main_file = File(path="main.py", content=pystache.render(raw_test["main"], params))
106+
test_files = [main_file] + parse_test_files(raw_test.get("files", []))
107+
108+
output_from_comments = []
109+
for test_file in test_files:
110+
output_lines = utils.extract_errors_from_comments(test_file.path, test_file.content.split("\n"))
111+
output_from_comments.extend(output_lines)
112+
113+
starting_lineno = raw_test["__line__"]
114+
extra_environment_variables = parse_environment_variables(raw_test.get("env", []))
115+
disable_cache = raw_test.get("disable_cache", False)
116+
expected_output_lines = pystache.render(raw_test.get("out", ""), params).split("\n")
117+
additional_mypy_config = raw_test.get("mypy_config", "")
118+
119+
skip = self._eval_skip(str(raw_test.get("skip", "False")))
120+
if not skip:
121+
yield YamlTestItem.from_parent(
122+
self,
123+
name=test_name,
124+
files=test_files,
125+
starting_lineno=starting_lineno,
126+
environment_variables=extra_environment_variables,
127+
disable_cache=disable_cache,
128+
expected_output_lines=output_from_comments + expected_output_lines,
129+
parsed_test_data=raw_test,
130+
mypy_config=additional_mypy_config,
131+
)
100132

101133
def _eval_skip(self, skip_if: str) -> bool:
102134
return eval(skip_if, {"sys": sys, "os": os, "pytest": pytest, "platform": platform})
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
---
2+
- case: only_main
3+
parametrized:
4+
- a: 1
5+
revealed_type: builtins.int
6+
- a: 1.0
7+
revealed_type: builtins.float
8+
main: |
9+
a = {{ a }}
10+
reveal_type(a) # N: Revealed type is '{{ revealed_type }}'
11+
- case: with_extra
12+
parametrized:
13+
- a: 2
14+
b: null
15+
rt: Any
16+
- a: 3
17+
b: 3
18+
rt: Any
19+
main: |
20+
import foo
21+
reveal_type(foo.test({{ a }}, {{ b }})) # N: Revealed type is '{{ rt }}'
22+
files:
23+
- path: foo.py
24+
content: |
25+
from typing import Any
26+
27+
def test(a: Any, b: Any) -> Any:
28+
...
29+
- case: with_out
30+
parametrized:
31+
- what: cat
32+
rt: builtins.str
33+
- what: dog
34+
rt: builtins.str
35+
main: |
36+
animal = '{{ what }}'
37+
reveal_type(animal)
38+
try:
39+
animal / 2
40+
except Exception:
41+
...
42+
out: |
43+
main:2: note: Revealed type is '{{ rt }}'
44+
main:4: error: Unsupported operand types for / ("str" and "int")

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
with open("README.md", "r") as f:
44
readme = f.read()
55

6-
dependencies = ["pytest>=5.4.0", "mypy>=0.730", "decorator", "pyyaml"]
6+
dependencies = [
7+
"pytest>=5.4.0",
8+
"mypy>=0.730",
9+
"decorator",
10+
"pyyaml",
11+
"pystache>=0.5.4",
12+
]
713

814
setup(
915
name="pytest-mypy-plugins",

0 commit comments

Comments
 (0)