2222import warnings
2323import weakref
2424from collections import defaultdict
25+ from collections .abc import Callable
2526from contextlib import contextmanager , nullcontext , suppress
2627from glob import glob
2728from itertools import count
5455from .diagnostics .plugin import WorkerPlugin
5556from .metrics import time
5657from .nanny import Nanny
58+ from .node import ServerNode
5759from .proctitle import enable_proctitle_on_children
5860from .security import Security
5961from .utils import (
@@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
770772 await asyncio .gather (* (disconnect (addr , timeout , rpc_kwargs ) for addr in addresses ))
771773
772774
773- def gen_test (timeout = _TEST_TIMEOUT ):
775+ def gen_test (timeout : float = _TEST_TIMEOUT ) -> Callable [[ Callable ], Callable ] :
774776 """Coroutine test
775777
776778 @gen_test(timeout=5)
@@ -797,14 +799,14 @@ def test_func():
797799
798800
799801async def start_cluster (
800- nthreads ,
801- scheduler_addr ,
802- loop ,
803- security = None ,
804- Worker = Worker ,
805- scheduler_kwargs = {},
806- worker_kwargs = {},
807- ):
802+ nthreads : list [ tuple [ str , int ] | tuple [ str , int , dict ]] ,
803+ scheduler_addr : str ,
804+ loop : IOLoop ,
805+ security : Security | dict [ str , Any ] | None = None ,
806+ Worker : type [ ServerNode ] = Worker ,
807+ scheduler_kwargs : dict [ str , Any ] = {},
808+ worker_kwargs : dict [ str , Any ] = {},
809+ ) -> tuple [ Scheduler , list [ ServerNode ]] :
808810 s = await Scheduler (
809811 loop = loop ,
810812 validate = True ,
@@ -813,6 +815,7 @@ async def start_cluster(
813815 host = scheduler_addr ,
814816 ** scheduler_kwargs ,
815817 )
818+
816819 workers = [
817820 Worker (
818821 s .address ,
@@ -822,7 +825,11 @@ async def start_cluster(
822825 loop = loop ,
823826 validate = True ,
824827 host = ncore [0 ],
825- ** (merge (worker_kwargs , ncore [2 ]) if len (ncore ) > 2 else worker_kwargs ),
828+ ** (
829+ merge (worker_kwargs , ncore [2 ]) # type: ignore
830+ if len (ncore ) > 2
831+ else worker_kwargs
832+ ),
826833 )
827834 for i , ncore in enumerate (nthreads )
828835 ]
@@ -854,21 +861,24 @@ async def end_worker(w):
854861
855862
856863def gen_cluster (
857- nthreads = [("127.0.0.1" , 1 ), ("127.0.0.1" , 2 )],
858- ncores = None ,
864+ nthreads : list [tuple [str , int ] | tuple [str , int , dict ]] = [
865+ ("127.0.0.1" , 1 ),
866+ ("127.0.0.1" , 2 ),
867+ ],
868+ ncores : None = None , # deprecated
859869 scheduler = "127.0.0.1" ,
860- timeout = _TEST_TIMEOUT ,
861- security = None ,
862- Worker = Worker ,
863- client = False ,
864- scheduler_kwargs = {},
865- worker_kwargs = {},
866- client_kwargs = {},
867- active_rpc_timeout = 1 ,
868- config = {},
869- clean_kwargs = {},
870- allow_unclosed = False ,
871- ):
870+ timeout : float = _TEST_TIMEOUT ,
871+ security : Security | dict [ str , Any ] | None = None ,
872+ Worker : type [ ServerNode ] = Worker ,
873+ client : bool = False ,
874+ scheduler_kwargs : dict [ str , Any ] = {},
875+ worker_kwargs : dict [ str , Any ] = {},
876+ client_kwargs : dict [ str , Any ] = {},
877+ active_rpc_timeout : float = 1 ,
878+ config : dict [ str , Any ] = {},
879+ clean_kwargs : dict [ str , Any ] = {},
880+ allow_unclosed : bool = False ,
881+ ) -> Callable [[ Callable ], Callable ] :
872882 from distributed import Client
873883
874884 """ Coroutine test with small cluster
0 commit comments