11from dataclasses import field
22from datetime import datetime
3- from typing import Any , Generator , List , Optional , Sequence , Union , Dict
3+ from typing import Any , Generator , List , Optional , Sequence , Type , Union , Dict
44
55from runtype import dataclass
6+ from typing_extensions import Self
67
78from ..utils import join_iter , ArithString
89from ..abcs import Compilable
@@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
322323 return QB_When (self , whens [0 ])
323324 return QB_When (self , BinBoolOp ("AND" , whens ))
324325
325- def else_ (self , then : Expr ):
326+ def else_ (self , then : Expr ) -> Self :
326327 """Add an 'else' clause to the case expression.
327328
328329 Can only be called once!
@@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
422423 schema : Optional [Schema ] = field (default = None , repr = False )
423424
424425 @property
425- def source_table (self ):
426+ def source_table (self ) -> Self :
426427 return self
427428
428429 def compile (self , c : Compiler ) -> str :
@@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
524525 columns : Sequence [Expr ] = None
525526
526527 @property
527- def source_table (self ):
528+ def source_table (self ) -> Self :
528529 return self
529530
530531 @property
@@ -533,7 +534,7 @@ def schema(self):
533534 s = self .source_tables [0 ].schema # TODO validate types match between both tables
534535 return type (s )({c .name : c .type for c in self .columns })
535536
536- def on (self , * exprs ) -> "Join" :
537+ def on (self , * exprs ) -> Self :
537538 """Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
538539 if len (exprs ) == 1 :
539540 (e ,) = exprs
@@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":
546547
547548 return self .replace (on_exprs = (self .on_exprs or []) + exprs )
548549
549- def select (self , * exprs , ** named_exprs ) -> ITable :
550+ def select (self , * exprs , ** named_exprs ) -> Union [ Self , ITable ] :
550551 """Select fields to return from the JOIN operation
551552
552553 See Also: ``ITable.select()``
@@ -600,7 +601,7 @@ def source_table(self):
600601 def __post_init__ (self ):
601602 assert self .keys or self .values
602603
603- def having (self , * exprs ):
604+ def having (self , * exprs ) -> Self :
604605 """Add a 'HAVING' clause to the group-by"""
605606 exprs = args_as_tuple (exprs )
606607 exprs = _drop_skips (exprs )
@@ -610,7 +611,7 @@ def having(self, *exprs):
610611 resolve_names (self .table , exprs )
611612 return self .replace (having_exprs = (self .having_exprs or []) + exprs )
612613
613- def agg (self , * exprs ):
614+ def agg (self , * exprs ) -> Self :
614615 """Select aggregated fields for the group-by."""
615616 exprs = args_as_tuple (exprs )
616617 exprs = _drop_skips (exprs )
@@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:
991992
992993 return f"INSERT INTO { c .compile (self .path )} { columns } { expr } "
993994
994- def returning (self , * exprs ):
995+ def returning (self , * exprs ) -> Self :
995996 """Add a 'RETURNING' clause to the insert expression.
996997
997998 Note: Not all databases support this feature!
0 commit comments