Skip to content

Commit bb8cf16

Browse files
grievejiameta-codesync[bot]
authored andcommitted
Fix chained comparison operator handling
Summary: For chained comparisons like `a < b < c`, Python evaluates as `(a < b) and (b < c)`. Previously Pyrefly always used the original left operand for each comparison in the chain, which was incorrect. This diffs updates current left operand as we iterate through the chain, using the previous comparator after each comparison. # AI-generated detail explanation: The Python AST for `x in foo in bar`: ``` Compare( left=Name('x'), ops=[In(), In()], comparators=[Name('foo'), Name('bar')] ) ``` The current code iterates over `[(In, foo), (In, bar)]` and **always uses the original `left` (`x`)** for each comparison: | Iteration | Left Used | Op | Right | Correct Left | |-----------|-----------|-----|-------|--------------| | 1 | x | in | foo | x | | 2 | x | in | bar | **foo** (should be `foo`, not `x`) | The fix requires updating the "left" operand as we iterate through comparisons: - First comparison: `left=x`, `right=foo` - Second comparison: `left=foo` (previous comparator), `right=bar` Reviewed By: migeed-z, ndmitchell Differential Revision: D88191668 fbshipit-source-id: 4452be9484463706f60353a5170fa364982a3fee
1 parent 8e8c263 commit bb8cf16

File tree

2 files changed

+108
-88
lines changed

2 files changed

+108
-88
lines changed

pyrefly/lib/alt/operators.rs

Lines changed: 76 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -341,100 +341,88 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
341341
}
342342

343343
pub fn compare_infer(&self, x: &ExprCompare, errors: &ErrorCollector) -> Type {
344-
let left = self.expr_infer(&x.left, errors);
345-
let comparisons = x.ops.iter().zip(x.comparators.iter());
346-
self.unions(
347-
comparisons
348-
.map(|(op, comparator)| {
349-
let right = self.expr_infer(comparator, errors);
350-
self.distribute_over_union(&left, |left| {
351-
self.distribute_over_union(&right, |right| {
352-
let context = || {
353-
ErrorContext::BinaryOp(
354-
op.as_str().to_owned(),
355-
self.for_display(left.clone()),
356-
self.for_display(right.clone()),
357-
)
358-
};
359-
match op {
360-
CmpOp::Is | CmpOp::IsNot => {
361-
// These comparisons never error.
362-
self.stdlib.bool().clone().to_type()
363-
}
364-
CmpOp::In | CmpOp::NotIn => {
365-
// See https://docs.python.org/3/reference/expressions.html#membership-test-operations.
366-
// `x in y` first tries `y.__contains__(x)`, then checks if `x` matches an element
367-
// obtained by iterating over `y`.
368-
if let Some(ret) = self.call_magic_dunder_method(
369-
right,
370-
&dunder::CONTAINS,
371-
x.range,
372-
&[CallArg::ty(left, x.left.range())],
373-
&[],
374-
errors,
375-
Some(&context),
376-
) {
377-
// Comparison method called.
378-
ret
379-
} else {
380-
let iteration_errors = self.error_collector();
381-
let iterables = self.iterate(
382-
right,
383-
x.range,
384-
&iteration_errors,
385-
Some(&context),
386-
);
387-
if iteration_errors.is_empty() {
388-
// Make sure `x` matches the produced type.
389-
self.check_type(
390-
left,
391-
&self.get_produced_type(iterables),
392-
x.range,
393-
errors,
394-
&|| TypeCheckContext {
395-
kind: TypeCheckKind::Container,
396-
context: Some(context()),
397-
},
398-
);
399-
} else {
400-
// Iterating `y` failed.
401-
errors.extend(iteration_errors);
402-
}
403-
self.stdlib.bool().clone().to_type()
404-
}
405-
}
406-
_ => {
407-
// We've handled the other cases above, so we know we have a rich comparison op.
408-
let calls_to_try = [
409-
(
410-
&dunder::rich_comparison_dunder(*op).unwrap(),
411-
left,
412-
right,
413-
),
414-
(
415-
&dunder::rich_comparison_fallback(*op).unwrap(),
416-
right,
417-
left,
418-
),
419-
];
420-
let ret = self.try_binop_calls(
421-
&calls_to_try,
344+
// For chained comparisons like `a < b < c`, Python evaluates as `(a < b) and (b < c)`.
345+
// We need to track the current left operand as we iterate through the chain.
346+
let mut current_left = self.expr_infer(&x.left, errors);
347+
let mut current_left_range = x.left.range();
348+
let mut results = Vec::new();
349+
for (op, comparator) in x.ops.iter().zip(x.comparators.iter()) {
350+
let right = self.expr_infer(comparator, errors);
351+
let result = self.distribute_over_union(&current_left, |left| {
352+
self.distribute_over_union(&right, |right| {
353+
let context = || {
354+
ErrorContext::BinaryOp(
355+
op.as_str().to_owned(),
356+
self.for_display(left.clone()),
357+
self.for_display(right.clone()),
358+
)
359+
};
360+
match op {
361+
CmpOp::Is | CmpOp::IsNot => {
362+
// These comparisons never error.
363+
self.stdlib.bool().clone().to_type()
364+
}
365+
CmpOp::In | CmpOp::NotIn => {
366+
// See https://docs.python.org/3/reference/expressions.html#membership-test-operations.
367+
// `x in y` first tries `y.__contains__(x)`, then checks if `x` matches an element
368+
// obtained by iterating over `y`.
369+
if let Some(ret) = self.call_magic_dunder_method(
370+
right,
371+
&dunder::CONTAINS,
372+
x.range,
373+
&[CallArg::ty(left, current_left_range)],
374+
&[],
375+
errors,
376+
Some(&context),
377+
) {
378+
// Comparison method called.
379+
ret
380+
} else {
381+
let iteration_errors = self.error_collector();
382+
let iterables =
383+
self.iterate(right, x.range, &iteration_errors, Some(&context));
384+
if iteration_errors.is_empty() {
385+
// Make sure `x` matches the produced type.
386+
self.check_type(
387+
left,
388+
&self.get_produced_type(iterables),
422389
x.range,
423390
errors,
424-
&context,
391+
&|| TypeCheckContext {
392+
kind: TypeCheckKind::Container,
393+
context: Some(context()),
394+
},
425395
);
426-
if ret.is_error() {
427-
self.stdlib.bool().clone().to_type()
428-
} else {
429-
ret
430-
}
396+
} else {
397+
// Iterating `y` failed.
398+
errors.extend(iteration_errors);
431399
}
400+
self.stdlib.bool().clone().to_type()
432401
}
433-
})
434-
})
402+
}
403+
_ => {
404+
// We've handled the other cases above, so we know we have a rich comparison op.
405+
let calls_to_try = [
406+
(&dunder::rich_comparison_dunder(*op).unwrap(), left, right),
407+
(&dunder::rich_comparison_fallback(*op).unwrap(), right, left),
408+
];
409+
let ret =
410+
self.try_binop_calls(&calls_to_try, x.range, errors, &context);
411+
if ret.is_error() {
412+
self.stdlib.bool().clone().to_type()
413+
} else {
414+
ret
415+
}
416+
}
417+
}
435418
})
436-
.collect(),
437-
)
419+
});
420+
results.push(result);
421+
// For next comparison, the current right becomes the new left
422+
current_left = right;
423+
current_left_range = comparator.range();
424+
}
425+
self.unions(results)
438426
}
439427

440428
pub fn unop_infer(&self, x: &ExprUnaryOp, errors: &ErrorCollector) -> Type {

pyrefly/lib/test/operators.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,35 @@ def f[S, T](x: type[S], y: type[T]):
681681
return x == y
682682
"#,
683683
);
684+
685+
testcase!(
686+
test_chained_in,
687+
r#"
688+
class Foo:
689+
def __contains__(self, x: int) -> bool:
690+
...
691+
692+
class Bar:
693+
def __contains__(self, x: Foo) -> bool:
694+
...
695+
696+
def test(x: int, foo: Foo, bar: Bar) -> None:
697+
x in foo in bar # Should be OK
698+
x in bar # E: `in` is not supported between `int` and `Bar`
699+
"#,
700+
);
701+
702+
testcase!(
703+
test_chained_lt,
704+
r#"
705+
class A:
706+
def __lt__(self, other: "B") -> bool: ...
707+
class B:
708+
def __lt__(self, other: "C") -> bool: ...
709+
class C: pass
710+
711+
def test(a: A, b: B, c: C) -> None:
712+
a < b < c # Should be OK: (a < b) and (b < c)
713+
a < c # E: `<` is not supported between `A` and `C`
714+
"#,
715+
);

0 commit comments

Comments
 (0)