Skip to content

Commit 25c8ffc

Browse files
committed
improve bidirectional inference in infer_collection_literal
1 parent b27c1ed commit 25c8ffc

File tree

5 files changed

+100
-8
lines changed

5 files changed

+100
-8
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,81 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3}
150150
reveal_type(r) # revealed: dict[int | str, int | str]
151151
```
152152

153-
## Incorrect collection literal assignments are complained aobut
153+
## Optional collection literal annotations are understood
154+
155+
```toml
156+
[environment]
157+
python-version = "3.12"
158+
```
159+
160+
```py
161+
import typing
162+
163+
a: list[int] | None = [1, 2, 3]
164+
reveal_type(a) # revealed: list[int]
165+
166+
b: list[int | str] | None = [1, 2, 3]
167+
reveal_type(b) # revealed: list[int | str]
168+
169+
c: typing.List[int] | None = [1, 2, 3]
170+
reveal_type(c) # revealed: list[int]
171+
172+
d: list[typing.Any] | None = []
173+
reveal_type(d) # revealed: list[Any]
174+
175+
e: set[int] | None = {1, 2, 3}
176+
reveal_type(e) # revealed: set[int]
177+
178+
f: set[int | str] | None = {1, 2, 3}
179+
reveal_type(f) # revealed: set[int | str]
180+
181+
g: typing.Set[int] | None = {1, 2, 3}
182+
reveal_type(g) # revealed: set[int]
183+
184+
h: list[list[int]] | None = [[], [42]]
185+
reveal_type(h) # revealed: list[list[int]]
186+
187+
i: list[typing.Any] | None = [1, 2, "3", ([4],)]
188+
reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]]
189+
190+
j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()]
191+
reveal_type(j) # revealed: list[tuple[str | int, ...]]
192+
193+
k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])]
194+
reveal_type(k) # revealed: list[tuple[list[int], ...]]
195+
196+
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
197+
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
198+
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
199+
200+
type IntList = list[int]
201+
202+
m: IntList | None = [1, 2, 3]
203+
reveal_type(m) # revealed: list[int]
204+
205+
# TODO: this should type-check and avoid literal promotion
206+
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[Literal[1, 2, 3]] | None`"
207+
n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3]
208+
# TODO: this should be `list[Literal[1, 2, 3]]` at this scope
209+
reveal_type(n) # revealed: list[Literal[1, 2, 3]] | None
210+
211+
# TODO: this should type-check and avoid literal promotion
212+
# error: [invalid-assignment] "Object of type `list[Unknown | str]` is not assignable to `list[LiteralString] | None`"
213+
o: list[typing.LiteralString] | None = ["a", "b", "c"]
214+
# TODO: this should be `list[LiteralString]` at this scope
215+
reveal_type(o) # revealed: list[LiteralString] | None
216+
217+
p: dict[int, int] | None = {}
218+
reveal_type(p) # revealed: dict[int, int]
219+
220+
q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3}
221+
reveal_type(q) # revealed: dict[int | str, int]
222+
223+
r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3}
224+
reveal_type(r) # revealed: dict[int | str, int | str]
225+
```
226+
227+
## Incorrect collection literal assignments are complained about
154228

155229
```py
156230
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`"

crates/ty_python_semantic/src/types.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,15 @@ impl<'db> Type<'db> {
11401140
if yes { self.negate(db) } else { *self }
11411141
}
11421142

1143+
/// Remove the union elements that are not related to `target`.
1144+
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
1145+
if let Type::Union(union) = self {
1146+
union.filter(db, |elem| !elem.is_disjoint_from(db, target))
1147+
} else {
1148+
self
1149+
}
1150+
}
1151+
11431152
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
11441153
/// is not a literal.
11451154
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,13 +1159,9 @@ impl<'db> SpecializationBuilder<'db> {
11591159
return Ok(());
11601160
}
11611161

1162-
if let Type::Union(union) = actual {
1163-
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1164-
// So, here we remove the union elements that are not related to `formal`.
1165-
actual = union.filter(self.db, |actual_elem| {
1166-
!actual_elem.is_disjoint_from(self.db, formal)
1167-
});
1168-
}
1162+
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1163+
// So, here we remove the union elements that are not related to `formal`.
1164+
actual = actual.filter_disjoint_elements(self.db, formal);
11691165

11701166
match (formal, actual) {
11711167
(Type::Union(_), Type::Union(_)) => {

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,12 @@ impl<'db> TypeContext<'db> {
390390
self.annotation
391391
.and_then(|ty| ty.known_specialization(known_class, db))
392392
}
393+
394+
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
395+
Self {
396+
annotation: self.annotation.map(f),
397+
}
398+
}
393399
}
394400

395401
/// Returns the statically-known truthiness of a given expression.

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5560,6 +5560,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
55605560
panic!("Typeshed should always have a `{name}` class in `builtins.pyi`")
55615561
});
55625562

5563+
let tcx = tcx.map_annotation(|annotation| {
5564+
// For example, if `collection_ty` is `list` and `annotation` is `list[int] | None`,
5565+
// remove any union elements of `annotation` that are not related to `collection_ty`.
5566+
let collection_ty = collection_class.to_instance(self.db());
5567+
annotation.filter_disjoint_elements(self.db(), collection_ty)
5568+
});
5569+
55635570
// Extract the annotated type of `T`, if provided.
55645571
let annotated_elt_tys = tcx
55655572
.known_specialization(collection_class, self.db())

0 commit comments

Comments
 (0)