Skip to content
Closed
74 changes: 74 additions & 0 deletions stdlib/@tests/test_cases/sqlite3/check_aggregations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import sqlite3

class WindowSumInt:
def __init__(self) -> None:
self.count = 0

def step(self, param: int) -> None:
self.count += param

def value(self) -> int:
return self.count

def inverse(self, param: int) -> None:
self.count -= param

def finalize(self) -> int:
return self.count


con = sqlite3.connect(":memory:")
cur = con.execute("CREATE TABLE test(x, y)")
values = [
("a", 4),
("b", 5),
("c", 3),
("d", 8),
("e", 1),
]
cur.executemany("INSERT INTO test VALUES(?, ?)", values)
con.create_window_function("sumint", 1, WindowSumInt)
con.create_aggregate("sumint", 1, WindowSumInt)
cur.execute("""
SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM test ORDER BY x
""")
con.close()


def create_window_function() -> WindowSumInt:
return WindowSumInt()


# A callable should work as well.
con.create_window_function("sumint", 1, create_window_function)
con.create_aggregate("sumint", 1, create_window_function)

# With num_args set to 1, the callable should not be called with more than one.

class WindowSumIntMultiArgs:
def __init__(self) -> None:
self.count = 0

def step(self, arg_1: int, arg_2: int) -> None:
self.count += arg_1 + arg_2

def value(self) -> int:
return self.count

def inverse(self, arg_1: int, arg_2: int) -> None:
self.count -= arg_1 + arg_2

def finalize(self) -> int:
return self.count


# This should fail because the callable is called with more than one argument.
con.create_window_function("sumint", 1, WindowSumIntMultiArgs) # type: ignore
con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) # type: ignore

# With num_args set to -1, this should work.
con.create_window_function("sumint", 2, WindowSumIntMultiArgs)
con.create_aggregate("sumint", 2, WindowSumIntMultiArgs)
53 changes: 27 additions & 26 deletions stdlib/sqlite3/dbapi2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
from datetime import date, datetime, time
from types import TracebackType
from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overload
from typing import Any, Literal, Protocol, SupportsIndex, final, overload, TypeVar
from typing_extensions import Self, TypeAlias

_T = TypeVar("_T")
_ConnectionT = TypeVar("_ConnectionT", bound=Connection)
_CursorT = TypeVar("_CursorT", bound=Cursor)
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
_SQLType = TypeVar("_SQLType", bound=_SqliteData)
# Data that is passed through adapters can be of any type accepted by an adapter.
_AdaptedInputData: TypeAlias = _SqliteData | Any
# The Mapping must really be a dict, but making it invariant is too annoying.
Expand Down Expand Up @@ -312,27 +313,26 @@ else:
def register_adapter(type: type[_T], caster: _Adapter[_T], /) -> None: ...
def register_converter(name: str, converter: _Converter, /) -> None: ...

class _AggregateProtocol(Protocol):
def step(self, value: int, /) -> object: ...
def finalize(self) -> int: ...
class _SingleParamAggregateProtocol(Protocol[_SQLType]):
def step(self, param: _SQLType, /) -> object: ...
def finalize(self) -> _SQLType: ...

class _SingleParamWindowAggregateClass(Protocol):
def step(self, param: Any, /) -> object: ...
def inverse(self, param: Any, /) -> object: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...
class _AnyParamAggregateProtocol(Protocol[_SQLType]):
def step(self, *args: _SQLType) -> object: ...
def finalize(self) -> _SQLType: ...

class _AnyParamWindowAggregateClass(Protocol):
def step(self, *args: Any) -> object: ...
def inverse(self, *args: Any) -> object: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...
class _SingleParamWindowAggregateClass(Protocol[_SQLType]):
def step(self, param: _SQLType, /) -> object: ...
def inverse(self, param: _SQLType, /) -> object: ...
def value(self) -> _SQLType: ...
def finalize(self) -> _SQLType: ...

class _AnyParamWindowAggregateClass(Protocol[_SQLType]):
def step(self, *args: _SQLType) -> object: ...
def inverse(self, *args: _SQLType) -> object: ...
def value(self) -> _SQLType: ...
def finalize(self) -> _SQLType: ...

class _WindowAggregateClass(Protocol):
step: Callable[..., object]
Copy link
Contributor Author

@max-muoto max-muoto Jun 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From testing things out, it doesn't seem this protocol really works as intended in either MyPy or Pyright. Unless you actually were annotating a lambda perhaps. Due to this, I went ahead and removed it.

Some examples of how it might not work as you would expect:

Pyright Playground

MyPy Playground

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been part of the initial commit in #7625, while the other protocols already used a function. Maybe @JelleZijlstra remembers why we used an attribute instead of a function here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really remember what I was thinking when I wrote that code, but the annotations proposed in this PR mean that protocol implementations must take *args. I am not familiar with how these things are used, but I'd expect concrete implementations to only accept a fixed number of parameters. Maybe that's why I chose to use Callable[....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried messing around with TypeVarTuples to do that, but had some issues there as well, there might not be a great way, but I'll see if I can figure something out.

inverse: Callable[..., object]
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...

class Connection:
@property
Expand Down Expand Up @@ -398,22 +398,23 @@ class Connection:
def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ...

def commit(self) -> None: ...
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ...

@overload
def create_aggregate(self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]) -> None: ...
@overload
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]]) -> None: ...

if sys.version_info >= (3, 11):
# num_params determines how many params will be passed to the aggregate class. We provide an overload
# for the case where num_params = 1, which is expected to be the common case.
@overload
def create_window_function(
self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, /
self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, /
) -> None: ...
# And for num_params = -1, which means the aggregate must accept any number of parameters.
@overload
def create_window_function(
self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
) -> None: ...
@overload
def create_window_function(
self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, /
self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass[_SQLType]] | None, /
) -> None: ...

def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...
Expand Down