Skip to content

Commit 109ebcf

Browse files
committed
Fix context for multiplication, fixes #310
1 parent 666be19 commit 109ebcf

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

crates/zuban_python/src/file/inference.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3035,15 +3035,29 @@ impl<'db, 'file> Inference<'db, 'file, '_> {
30353035
}
30363036

30373037
fn infer_operation(&self, op: Operation, result_context: &mut ResultContext) -> Inferred {
3038-
let context = if result_context.has_explicit_type() && op.infos.operand != "%" {
3039-
// Pass on the context to each side. I'm not sure that's correct, but it's necessary at
3040-
// least for list additions. However it's wrong for `"%s" % ...`.
3041-
&mut *result_context
3042-
} else {
3043-
&mut ResultContext::ValueExpected
3038+
let mut check = |part: ExpressionPart| {
3039+
let context = if result_context.has_explicit_type()
3040+
&& matches!(
3041+
part.maybe_unpacked_atom(),
3042+
Some(
3043+
AtomContent::List(_)
3044+
| AtomContent::ListComprehension(_)
3045+
| AtomContent::Set(_)
3046+
| AtomContent::SetComprehension(_)
3047+
| AtomContent::Dict(_)
3048+
| AtomContent::DictComprehension(_)
3049+
)
3050+
) {
3051+
// Pass on the context to each side. It's nessary at least for list additions, but
3052+
// can potentially be wrong even here. Not sure this is right.
3053+
&mut *result_context
3054+
} else {
3055+
&mut ResultContext::ValueExpected
3056+
};
3057+
self.infer_expression_part_with_context(part, context)
30443058
};
3045-
let left = self.infer_expression_part_with_context(op.left, context);
3046-
let right = self.infer_expression_part_with_context(op.right, context);
3059+
let left = check(op.left);
3060+
let right = check(op.right);
30473061
self.infer_detailed_operation(op.index, op.infos, left, &right, result_context)
30483062
}
30493063

crates/zuban_python/tests/mypylike/tests/generics.test

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,3 +1964,7 @@ class X(Generic[T]):
19641964
def __init__(self, x: Callable[[], T]): ...
19651965

19661966
X[SimpleGeneric[int]](SimpleGeneric)
1967+
1968+
[case generic_context_avoid_multiplication_context_inference]
1969+
# From GH #310
1970+
x: str = "x" * max(1, 10)

0 commit comments

Comments
 (0)