Skip to content

Commit 485a852

Browse files
authored
Merge pull request regro#3935 from mgorny/v1-combine-not
Support combining trivial negations in v1 recipes
2 parents 3172e1b + 8944b32 commit 485a852

File tree

3 files changed

+134
-36
lines changed

3 files changed

+134
-36
lines changed

conda_forge_tick/migrators/recipe_v1.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from pathlib import Path
44
from typing import Any
55

6+
from jinja2 import Environment
7+
from jinja2.nodes import Compare, Node, Not
8+
from jinja2.parser import Parser
9+
610
from conda_forge_tick.migrators.core import MiniMigrator
711
from conda_forge_tick.recipe_parser._parser import _get_yaml_parser
812

@@ -12,28 +16,71 @@
1216
logger = logging.getLogger(__name__)
1317

1418

15-
def is_same_condition(a: Any, b: Any) -> bool:
16-
return (
17-
isinstance(a, dict)
18-
and isinstance(b, dict)
19-
and "if" in a
20-
and "if" in b
21-
and a["if"] == b["if"]
22-
)
19+
def get_condition(node: Any) -> Node | None:
20+
if isinstance(node, dict) and "if" in node:
21+
return Parser(
22+
Environment(), node["if"].strip(), state="variable"
23+
).parse_expression()
24+
return None
25+
26+
27+
def is_same_condition(a: Node, b: Node) -> bool:
28+
return a == b
29+
30+
31+
INVERSE_OPS = {
32+
"eq": "ne",
33+
"ne": "eq",
34+
"gt": "lteq",
35+
"gteq": "lt",
36+
"lt": "gteq",
37+
"lteq": "gt",
38+
"in": "notin",
39+
"notin": "in",
40+
}
41+
2342

43+
def is_negated_condition(a: Node, b: Node) -> bool:
44+
# X <-> not X
45+
if Not(a) == b or a == Not(b):
46+
return True
2447

25-
def fold_branch(source: Any, dest: Any, branch: str) -> None:
48+
# unwrap (not X) <-> (not Y)
49+
if isinstance(a, Not) and isinstance(b, Not):
50+
a = a.node
51+
b = b.node
52+
53+
# A == B <-> A != B
54+
if (
55+
isinstance(a, Compare)
56+
and isinstance(b, Compare)
57+
and len(a.ops) == len(b.ops) == 1
58+
and a.expr == b.expr
59+
and a.ops[0].expr == b.ops[0].expr
60+
and a.ops[0].op == INVERSE_OPS[b.ops[0].op]
61+
):
62+
return True
63+
64+
return False
65+
66+
67+
def fold_branch(source: Any, dest: Any, branch: str, dest_branch: str) -> None:
2668
if branch not in source:
2769
return
70+
2871
source_l = source[branch]
2972
if isinstance(source_l, str):
73+
if dest_branch not in dest:
74+
# special-case: do not expand a single string to list
75+
dest[dest_branch] = source_l
76+
return
3077
source_l = [source_l]
3178

32-
if branch not in dest:
33-
dest[branch] = []
34-
elif isinstance(dest[branch], str):
35-
dest[branch] = [dest[branch]]
36-
dest[branch].extend(source_l)
79+
if dest_branch not in dest:
80+
dest[dest_branch] = []
81+
elif isinstance(dest[dest_branch], str):
82+
dest[dest_branch] = [dest[dest_branch]]
83+
dest[dest_branch].extend(source_l)
3784

3885

3986
def combine_conditions(node: Any):
@@ -45,9 +92,18 @@ def combine_conditions(node: Any):
4592
# iterate in reverse order, so we can remove elements on the fly
4693
# start at index 1, since we can only fold to the previous node
4794
for i in reversed(range(1, len(node))):
48-
if is_same_condition(node[i], node[i - 1]):
49-
fold_branch(node[i], node[i - 1], "then")
50-
fold_branch(node[i], node[i - 1], "else")
95+
node_cond = get_condition(node[i])
96+
prev_cond = get_condition(node[i - 1])
97+
if node_cond is None or prev_cond is None:
98+
continue
99+
100+
if is_same_condition(node_cond, prev_cond):
101+
fold_branch(node[i], node[i - 1], "then", "then")
102+
fold_branch(node[i], node[i - 1], "else", "else")
103+
del node[i]
104+
elif is_negated_condition(node_cond, prev_cond):
105+
fold_branch(node[i], node[i - 1], "then", "else")
106+
fold_branch(node[i], node[i - 1], "else", "then")
51107
del node[i]
52108

53109
# then we descend down the tree

tests/test_recipe_v1.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from pathlib import Path
22

3+
import pytest
34
from flaky import flaky
45
from test_migrators import run_test_migration
56

67
from conda_forge_tick.migrators import (
78
CombineV1ConditionsMigrator,
89
Version,
910
)
11+
from conda_forge_tick.migrators.recipe_v1 import (
12+
get_condition,
13+
is_negated_condition,
14+
)
1015

1116
YAML_PATH = Path(__file__).parent / "test_v1_yaml"
1217

@@ -16,6 +21,53 @@
1621
)
1722

1823

24+
@pytest.mark.parametrize(
25+
"a,b",
26+
[
27+
("unix", "not unix"),
28+
('cuda_compiler_version == "None"', 'not cuda_compiler_version == "None"'),
29+
('cuda_compiler_version == "None"', 'cuda_compiler_version != "None"'),
30+
('not cuda_compiler_version == "None"', 'not cuda_compiler_version != "None"'),
31+
(
32+
'cuda_compiler_version != "None" and linux',
33+
'not (cuda_compiler_version != "None" and linux)',
34+
),
35+
("linux or osx", "not (linux or osx)"),
36+
("a >= 14", "a < 14"),
37+
("a >= 14", "not (a >= 14)"),
38+
("a in [1, 2, 3]", "a not in [1, 2, 3]"),
39+
("a in [1, 2, 3]", "not a in [1, 2, 3]"),
40+
("a + b < 10", "a + b >= 10"),
41+
("a == b == c", "not (a == b == c)"),
42+
],
43+
)
44+
def test_is_negated_condition(a, b):
45+
a_cond = get_condition({"if": a})
46+
b_cond = get_condition({"if": b})
47+
assert is_negated_condition(a_cond, b_cond)
48+
assert is_negated_condition(b_cond, a_cond)
49+
50+
51+
@pytest.mark.parametrize(
52+
"a,b",
53+
[
54+
("not unix", "not unix"),
55+
('cuda_compiler_version == "None"', 'not cuda_compiler_version != "None"'),
56+
('cuda_compiler_version != "None"', 'not cuda_compiler_version == "None"'),
57+
("a or b", "not a or b"),
58+
("a and b", "not a and b"),
59+
("a == b == c", "a != b != c"),
60+
("a > 4", "a < 4"),
61+
("a == b == c", "not (a == b) == c"),
62+
],
63+
)
64+
def test_not_is_negated_condition(a, b):
65+
a_cond = get_condition({"if": a})
66+
b_cond = get_condition({"if": b})
67+
assert not is_negated_condition(a_cond, b_cond)
68+
assert not is_negated_condition(b_cond, a_cond)
69+
70+
1971
@flaky
2072
def test_combine_v1_conditions(tmp_path):
2173
run_test_migration(

tests/test_v1_yaml/version_pytorch_correct.yaml

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ outputs:
3939
then:
4040
- python 3.12.*
4141
- numpy *
42-
- if: not megabuild
43-
then:
42+
else:
4443
- python
4544
- numpy
4645
- cross-python_${{ target_platform }}
@@ -51,8 +50,7 @@ outputs:
5150
then: ${{ compiler('cuda') }}
5251
- if: not win
5352
then: llvm-openmp
54-
- if: win
55-
then:
53+
else:
5654
- intel-openmp ${{ mkl }}
5755
- libuv
5856
- cmake
@@ -123,8 +121,7 @@ outputs:
123121
- liblapack
124122
- if: not win
125123
then: llvm-openmp
126-
- if: win
127-
then: intel-openmp ${{ mkl }}
124+
else: intel-openmp ${{ mkl }}
128125
- libabseil
129126
- libprotobuf
130127
- sleef
@@ -152,8 +149,7 @@ outputs:
152149
else:
153150
- pytorch-gpu ==${{ version }}
154151
- pytorch-cpu ==99999999
155-
- if: "cuda_compiler_version != \"None\""
156-
then: pytorch =${{ version }} cuda${{ cuda_compiler_version | replace('.', '') }}_${{ blas_impl }}_*_${{ build }}
152+
- pytorch =${{ version }} cuda${{ cuda_compiler_version | replace('.', '') }}_${{ blas_impl }}_*_${{ build }}
157153
- if: "unix and blas_impl != \"mkl\""
158154
then: openblas * openmp_*
159155

@@ -230,12 +226,10 @@ outputs:
230226
- libcblas * *_mkl
231227
else:
232228
- libcblas
233-
- if: "blas_impl != \"mkl\""
234-
then: liblapack
229+
- liblapack
235230
- if: not win
236231
then: llvm-openmp
237-
- if: win
238-
then: intel-openmp ${{ mkl }}
232+
else: intel-openmp ${{ mkl }}
239233
- libabseil
240234
- libprotobuf
241235
- pybind11
@@ -251,17 +245,13 @@ outputs:
251245
then: ${{ pin_subpackage('libtorch', exact=True) }}
252246
# for non-megabuild, allow libtorch from any python version;
253247
# pinning build number would be nice but breaks conda
254-
- if: not megabuild
255-
then: libtorch ${{ version }}.*
248+
else: libtorch ${{ version }}.*
256249
- if: not win
257250
then: llvm-openmp
258-
- if: win
259-
then: intel-openmp ${{ mkl }}
251+
else: intel-openmp ${{ mkl }}
260252
- if: "blas_impl == \"mkl\""
261253
then: libblas * *${{ blas_impl }}
262-
- if: "blas_impl != \"mkl\""
263-
then: nomkl
264-
# GPU requirements without run_exports
254+
else: nomkl
265255
- if: "cuda_compiler_version != \"None\""
266256
then: ${{ pin_compatible('cudnn') }}
267257
- if: "cuda_compiler_version != \"None\" and not win"

0 commit comments

Comments
 (0)