Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/union_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,46 @@ def _(
reveal_type(i1) # revealed: P & Q
reveal_type(i2) # revealed: P & Q
```

## Unions of literals with `AlwaysTruthy` and `AlwaysFalsy`

```py
from typing import Literal
from knot_extensions import AlwaysTruthy, AlwaysFalsy

type strings = Literal["foo", ""]
type ints = Literal[0, 1]
type bytes = Literal[b"foo", b""]

def _(
strings_or_truthy: strings | AlwaysTruthy,
truthy_or_strings: AlwaysTruthy | strings,
strings_or_falsy: strings | AlwaysFalsy,
falsy_or_strings: AlwaysFalsy | strings,
ints_or_truthy: ints | AlwaysTruthy,
truthy_or_ints: AlwaysTruthy | ints,
ints_or_falsy: ints | AlwaysFalsy,
falsy_or_ints: AlwaysFalsy | ints,
bytes_or_truthy: bytes | AlwaysTruthy,
truthy_or_bytes: AlwaysTruthy | bytes,
bytes_or_falsy: bytes | AlwaysFalsy,
falsy_or_bytes: AlwaysFalsy | bytes,
):
reveal_type(strings_or_truthy) # revealed: Literal[""] | AlwaysTruthy
reveal_type(truthy_or_strings) # revealed: AlwaysTruthy | Literal[""]

reveal_type(strings_or_falsy) # revealed: Literal["foo"] | AlwaysFalsy
reveal_type(falsy_or_strings) # revealed: AlwaysFalsy | Literal["foo"]

reveal_type(ints_or_truthy) # revealed: Literal[0] | AlwaysTruthy
reveal_type(truthy_or_ints) # revealed: AlwaysTruthy | Literal[0]

reveal_type(ints_or_falsy) # revealed: Literal[1] | AlwaysFalsy
reveal_type(falsy_or_ints) # revealed: AlwaysFalsy | Literal[1]

reveal_type(bytes_or_truthy) # revealed: Literal[b""] | AlwaysTruthy
reveal_type(truthy_or_bytes) # revealed: AlwaysTruthy | Literal[b""]

reveal_type(bytes_or_falsy) # revealed: Literal[b"foo"] | AlwaysFalsy
reveal_type(falsy_or_bytes) # revealed: AlwaysFalsy | Literal[b"foo"]
```
99 changes: 75 additions & 24 deletions crates/red_knot_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,67 @@ enum UnionElement<'db> {
Type(Type<'db>),
}

impl<'db> UnionElement<'db> {
    /// Simplify this `UnionElement` against `other_type`, another member of the same union.
    ///
    /// For a group of literals the answer is `ReduceResult::KeepIf(keep)`: `keep` is `false`
    /// when nothing from the group has to stay in the union and `true` otherwise. When
    /// needed, the group is narrowed in place first by dropping every literal that is a
    /// subtype of `other_type`.
    ///
    /// For any other element the wrapped type is handed back as `ReduceResult::Type`, so that
    /// `UnionBuilder` can run its more involved checks on it.
    fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
        // Only `AlwaysTruthy` and `AlwaysFalsy` can be a supertype of some, but not all,
        // literals of one kind. For those two every literal is tested individually; for any
        // other type the first literal answers for the whole group.
        let check_each = matches!(other_type, Type::AlwaysTruthy | Type::AlwaysFalsy);

        let keep = match self {
            UnionElement::Type(existing) => return ReduceResult::Type(*existing),
            UnionElement::IntLiterals(ints) => {
                if check_each {
                    ints.retain(|int| !Type::IntLiteral(*int).is_subtype_of(db, other_type));
                    !ints.is_empty()
                } else {
                    // Invariant: a literal group is never empty, so index 0 exists.
                    let representative = Type::IntLiteral(ints[0]);
                    !representative.is_subtype_of(db, other_type)
                }
            }
            UnionElement::StringLiterals(strings) => {
                if check_each {
                    strings.retain(|string| {
                        !Type::StringLiteral(*string).is_subtype_of(db, other_type)
                    });
                    !strings.is_empty()
                } else {
                    // Invariant: a literal group is never empty, so index 0 exists.
                    let representative = Type::StringLiteral(strings[0]);
                    !representative.is_subtype_of(db, other_type)
                }
            }
            UnionElement::BytesLiterals(byte_strings) => {
                if check_each {
                    byte_strings.retain(|byte_string| {
                        !Type::BytesLiteral(*byte_string).is_subtype_of(db, other_type)
                    });
                    !byte_strings.is_empty()
                } else {
                    // Invariant: a literal group is never empty, so index 0 exists.
                    let representative = Type::BytesLiteral(byte_strings[0]);
                    !representative.is_subtype_of(db, other_type)
                }
            }
        };

        ReduceResult::KeepIf(keep)
    }
}

/// Outcome of `UnionElement::try_reduce`: either a final keep/remove verdict for the element,
/// or a `Type` on which the caller continues its own simplification checks.
enum ReduceResult<'db> {
/// Reduction of this `UnionElement` is complete; keep it in the union if the nested
/// boolean is true, eliminate it from the union if false.
KeepIf(bool),
/// The given `Type` can stand in for the entire `UnionElement` for further union
/// simplification checks.
Type(Type<'db>),
}

// TODO increase this once we extend `UnionElement` throughout all union/intersection
// representations, so that we can make large unions of literals fast in all operations.
// NOTE(review): presumably an upper bound on the number of literals kept in a union; the
// use site is outside this view, so confirm what happens when the limit is exceeded.
const MAX_UNION_LITERALS: usize = 200;
Expand Down Expand Up @@ -197,27 +258,17 @@ impl<'db> UnionBuilder<'db> {
let mut to_remove = SmallVec::<[usize; 2]>::new();
let ty_negated = ty.negate(self.db);

for (index, element) in self
.elements
.iter()
.map(|element| {
// For literals, the first element in the set can stand in for all the rest,
// since they all have the same super-types. SAFETY: a `UnionElement` of
// literal kind must always have at least one element in it.
match element {
UnionElement::IntLiterals(literals) => Type::IntLiteral(literals[0]),
UnionElement::StringLiterals(literals) => {
Type::StringLiteral(literals[0])
for (index, element) in self.elements.iter_mut().enumerate() {
let element_type = match element.try_reduce(self.db, ty) {
ReduceResult::KeepIf(keep) => {
if !keep {
to_remove.push(index);
}
UnionElement::BytesLiterals(literals) => {
Type::BytesLiteral(literals[0])
}
UnionElement::Type(ty) => *ty,
continue;
}
})
.enumerate()
{
if Some(element) == bool_pair {
ReduceResult::Type(ty) => ty,
};
if Some(element_type) == bool_pair {
to_add = KnownClass::Bool.to_instance(self.db);
to_remove.push(index);
// The type we are adding is a BooleanLiteral, which doesn't have any
Expand All @@ -227,14 +278,14 @@ impl<'db> UnionBuilder<'db> {
break;
}

if ty.is_same_gradual_form(element)
|| ty.is_subtype_of(self.db, element)
|| element.is_object(self.db)
if ty.is_same_gradual_form(element_type)
|| ty.is_subtype_of(self.db, element_type)
|| element_type.is_object(self.db)
{
return;
} else if element.is_subtype_of(self.db, ty) {
} else if element_type.is_subtype_of(self.db, ty) {
to_remove.push(index);
} else if ty_negated.is_subtype_of(self.db, element) {
} else if ty_negated.is_subtype_of(self.db, element_type) {
// We add `ty` to the union. We just checked that `~ty` is a subtype of an existing `element`.
// This also means that `~ty | ty` is a subtype of `element | ty`, because both elements in the
// first union are subtypes of the corresponding elements in the second union. But `~ty | ty` is
Expand Down
Loading