22import platform
33import sys
44import tempfile
5- from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional
5+ from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Mapping
66
77import pytest
88import yaml
99from _pytest .config .argparsing import Parser
1010from _pytest .nodes import Node
1111from py ._path .local import LocalPath
12+ import pystache
1213
1314from pytest_mypy_plugins import utils
1415
@@ -42,6 +43,27 @@ def parse_environment_variables(env_vars: List[str]) -> Dict[str, str]:
4243 return parsed_vars
4344
4445
46+ def parse_parametrized (params : List [Mapping [str , Any ]]) -> List [Mapping [str , Any ]]:
47+ if not params :
48+ return [{}]
49+
50+ parsed_params : List [Mapping [str , Any ]] = []
51+ known_params = None
52+ for idx , param in enumerate (params ):
53+ param_keys = set (sorted (param .keys ()))
54+ if not known_params :
55+ known_params = param_keys
56+ elif known_params .intersection (param_keys ) != known_params :
57+ raise ValueError (
58+ "All parametrized entries must have same keys."
59+ f'First entry is { ", " .join (known_params )} but { ", " .join (param_keys )} '
60+ "was spotted at {idx} position" ,
61+ )
62+ parsed_params .append ({k : v for k , v in param .items () if not k .startswith ("__" )})
63+
64+ return parsed_params
65+
66+
4567class SafeLineLoader (yaml .SafeLoader ):
4668 def construct_mapping (self , node : yaml .Node , deep : bool = False ) -> None :
4769 mapping = super ().construct_mapping (node , deep = deep )
@@ -66,37 +88,47 @@ def collect(self) -> Iterator["YamlTestItem"]:
6688 raise ValueError (f"Test file has to be YAML list, got { type (parsed_file )!r} ." )
6789
6890 for raw_test in parsed_file :
69- test_name = raw_test ["case" ]
70- if " " in test_name :
71- raise ValueError (f"Invalid test name { test_name !r} , only '[a-zA-Z0-9_]' is allowed." )
72-
73- test_files = [File (path = "main.py" , content = raw_test ["main" ])]
74- test_files += parse_test_files (raw_test .get ("files" , []))
75-
76- output_from_comments = []
77- for test_file in test_files :
78- output_lines = utils .extract_errors_from_comments (test_file .path , test_file .content .split ("\n " ))
79- output_from_comments .extend (output_lines )
80-
81- starting_lineno = raw_test ["__line__" ]
82- extra_environment_variables = parse_environment_variables (raw_test .get ("env" , []))
83- disable_cache = raw_test .get ("disable_cache" , False )
84- expected_output_lines = raw_test .get ("out" , "" ).split ("\n " )
85- additional_mypy_config = raw_test .get ("mypy_config" , "" )
86-
87- skip = self ._eval_skip (str (raw_test .get ("skip" , "False" )))
88- if not skip :
89- yield YamlTestItem .from_parent (
90- self ,
91- name = test_name ,
92- files = test_files ,
93- starting_lineno = starting_lineno ,
94- environment_variables = extra_environment_variables ,
95- disable_cache = disable_cache ,
96- expected_output_lines = output_from_comments + expected_output_lines ,
97- parsed_test_data = raw_test ,
98- mypy_config = additional_mypy_config ,
99- )
91+ test_name_prefix = raw_test ["case" ]
92+ if " " in test_name_prefix :
93+ raise ValueError (f"Invalid test name { test_name_prefix !r} , only '[a-zA-Z0-9_]' is allowed." )
94+ else :
95+ parametrized = parse_parametrized (raw_test .get ("parametrized" , []))
96+
97+ for params in parametrized :
98+ if params :
99+ test_name_suffix = "," .join (f"{ k } ={ v } " for k , v in params .items ())
100+ test_name_suffix = f"[{ test_name_suffix } ]"
101+ else :
102+ test_name_suffix = ""
103+
104+ test_name = f"{ test_name_prefix } { test_name_suffix } "
105+ main_file = File (path = "main.py" , content = pystache .render (raw_test ["main" ], params ))
106+ test_files = [main_file ] + parse_test_files (raw_test .get ("files" , []))
107+
108+ output_from_comments = []
109+ for test_file in test_files :
110+ output_lines = utils .extract_errors_from_comments (test_file .path , test_file .content .split ("\n " ))
111+ output_from_comments .extend (output_lines )
112+
113+ starting_lineno = raw_test ["__line__" ]
114+ extra_environment_variables = parse_environment_variables (raw_test .get ("env" , []))
115+ disable_cache = raw_test .get ("disable_cache" , False )
116+ expected_output_lines = pystache .render (raw_test .get ("out" , "" ), params ).split ("\n " )
117+ additional_mypy_config = raw_test .get ("mypy_config" , "" )
118+
119+ skip = self ._eval_skip (str (raw_test .get ("skip" , "False" )))
120+ if not skip :
121+ yield YamlTestItem .from_parent (
122+ self ,
123+ name = test_name ,
124+ files = test_files ,
125+ starting_lineno = starting_lineno ,
126+ environment_variables = extra_environment_variables ,
127+ disable_cache = disable_cache ,
128+ expected_output_lines = output_from_comments + expected_output_lines ,
129+ parsed_test_data = raw_test ,
130+ mypy_config = additional_mypy_config ,
131+ )
100132
101133 def _eval_skip (self , skip_if : str ) -> bool :
102134 return eval (skip_if , {"sys" : sys , "os" : os , "pytest" : pytest , "platform" : platform })
0 commit comments