Skip to content

Commit 61f906d

Browse files
authored
[ty] equality narrowing on enums that don't override __eq__ or __ne__ (#20285)
Add equality narrowing for enums, if they don't override `__eq__` or `__ne__` in an unsafe way. Follow-up to PR #20164 Fixes astral-sh/ty#939
1 parent 08a561f commit 61f906d

File tree

3 files changed

+50
-38
lines changed

3 files changed

+50
-38
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,9 @@ class Color(Enum):
168168

169169
def _(x: Color):
170170
if x in (Color.RED, Color.GREEN):
171-
# TODO should be `Literal[Color.RED, Color.GREEN]`
172-
reveal_type(x) # revealed: Color
171+
reveal_type(x) # revealed: Literal[Color.RED, Color.GREEN]
173172
else:
174-
# TODO should be `Literal[Color.BLUE]`
175-
reveal_type(x) # revealed: Color
173+
reveal_type(x) # revealed: Literal[Color.BLUE]
176174
```
177175

178176
## Union with enum and `int`
@@ -187,11 +185,9 @@ class Status(Enum):
187185

188186
def test(x: Status | int):
189187
if x in (Status.PENDING, Status.APPROVED):
190-
# TODO should be `Literal[Status.PENDING, Status.APPROVED] | int`
191188
# int is included because custom __eq__ methods could make
192189
# an int equal to Status.PENDING or Status.APPROVED, so we can't eliminate it
193-
reveal_type(x) # revealed: Status | int
190+
reveal_type(x) # revealed: Literal[Status.PENDING, Status.APPROVED] | int
194191
else:
195-
# TODO should be `Literal[Status.REJECTED] | int`
196-
reveal_type(x) # revealed: Status | int
192+
reveal_type(x) # revealed: Literal[Status.REJECTED] | int
197193
```

crates/ty_python_semantic/src/types.rs

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,35 @@ impl<'db> Type<'db> {
746746
.is_some_and(|instance| instance.class(db).is_known(db, KnownClass::Bool))
747747
}
748748

749+
fn is_enum(&self, db: &'db dyn Db) -> bool {
750+
self.into_nominal_instance().is_some_and(|instance| {
751+
crate::types::enums::enum_metadata(db, instance.class(db).class_literal(db).0).is_some()
752+
})
753+
}
754+
755+
/// Return true if this type overrides __eq__ or __ne__ methods
756+
fn overrides_equality(&self, db: &'db dyn Db) -> bool {
757+
let check_dunder = |dunder_name, allowed_return_value| {
758+
// Note that we do explicitly exclude dunder methods on `object`, `int` and `str` here.
759+
// The reason for this is that we know that these dunder methods behave in a predictable way.
760+
// Only custom dunder methods need to be examined here, as they might break single-valuedness
761+
// by always returning `False`, for example.
762+
let call_result = self.try_call_dunder_with_policy(
763+
db,
764+
dunder_name,
765+
&mut CallArguments::positional([Type::unknown()]),
766+
MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK
767+
| MemberLookupPolicy::MRO_NO_INT_OR_STR_LOOKUP,
768+
);
769+
let call_result = call_result.as_ref();
770+
call_result.is_ok_and(|bindings| {
771+
bindings.return_type(db) == Type::BooleanLiteral(allowed_return_value)
772+
}) || call_result.is_err_and(|err| matches!(err, CallDunderError::MethodNotAvailable))
773+
};
774+
775+
!(check_dunder("__eq__", true) && check_dunder("__ne__", false))
776+
}
777+
749778
pub(crate) fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
750779
self.into_nominal_instance().is_some_and(|instance| {
751780
instance
@@ -980,22 +1009,28 @@ impl<'db> Type<'db> {
9801009

9811010
pub(crate) fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool {
9821011
self.into_union().is_some_and(|union| {
983-
union
984-
.elements(db)
985-
.iter()
986-
.all(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
1012+
union.elements(db).iter().all(|ty| {
1013+
ty.is_single_valued(db)
1014+
|| ty.is_bool(db)
1015+
|| ty.is_literal_string()
1016+
|| (ty.is_enum(db) && !ty.overrides_equality(db))
1017+
})
9871018
}) || self.is_bool(db)
9881019
|| self.is_literal_string()
1020+
|| (self.is_enum(db) && !self.overrides_equality(db))
9891021
}
9901022

9911023
pub(crate) fn is_union_with_single_valued(&self, db: &'db dyn Db) -> bool {
9921024
self.into_union().is_some_and(|union| {
993-
union
994-
.elements(db)
995-
.iter()
996-
.any(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
1025+
union.elements(db).iter().any(|ty| {
1026+
ty.is_single_valued(db)
1027+
|| ty.is_bool(db)
1028+
|| ty.is_literal_string()
1029+
|| (ty.is_enum(db) && !ty.overrides_equality(db))
1030+
})
9971031
}) || self.is_bool(db)
9981032
|| self.is_literal_string()
1033+
|| (self.is_enum(db) && !self.overrides_equality(db))
9991034
}
10001035

10011036
pub(crate) fn into_string_literal(self) -> Option<StringLiteralType<'db>> {
@@ -2574,28 +2609,7 @@ impl<'db> Type<'db> {
25742609
| Type::SpecialForm(..)
25752610
| Type::KnownInstance(..) => true,
25762611

2577-
Type::EnumLiteral(_) => {
2578-
let check_dunder = |dunder_name, allowed_return_value| {
2579-
// Note that we do explicitly exclude dunder methods on `object`, `int` and `str` here.
2580-
// The reason for this is that we know that these dunder methods behave in a predictable way.
2581-
// Only custom dunder methods need to be examined here, as they might break single-valuedness
2582-
// by always returning `False`, for example.
2583-
let call_result = self.try_call_dunder_with_policy(
2584-
db,
2585-
dunder_name,
2586-
&mut CallArguments::positional([Type::unknown()]),
2587-
MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK
2588-
| MemberLookupPolicy::MRO_NO_INT_OR_STR_LOOKUP,
2589-
);
2590-
let call_result = call_result.as_ref();
2591-
call_result.is_ok_and(|bindings| {
2592-
bindings.return_type(db) == Type::BooleanLiteral(allowed_return_value)
2593-
}) || call_result
2594-
.is_err_and(|err| matches!(err, CallDunderError::MethodNotAvailable))
2595-
};
2596-
2597-
check_dunder("__eq__", true) && check_dunder("__ne__", false)
2598-
}
2612+
Type::EnumLiteral(_) => !self.overrides_equality(db),
25992613

26002614
Type::ProtocolInstance(..) => {
26012615
// See comment in the `Type::ProtocolInstance` branch for `Type::is_singleton`.

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
640640
if !element.is_single_valued(self.db)
641641
&& !element.is_literal_string()
642642
&& !element.is_bool(self.db)
643+
&& (!element.is_enum(self.db) || element.overrides_equality(self.db))
643644
{
644645
builder = builder.add(*element);
645646
}
@@ -675,6 +676,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
675676
if element.is_single_valued(self.db)
676677
|| element.is_literal_string()
677678
|| element.is_bool(self.db)
679+
|| (element.is_enum(self.db) && !element.overrides_equality(self.db))
678680
{
679681
single_builder = single_builder.add(*element);
680682
} else {

0 commit comments

Comments
 (0)