@@ -2140,24 +2140,17 @@ def __init__(
21402140 max_batch_size : int = 10 ,
21412141 max_batch_weight : Optional [int ] = None ,
21422142 element_size_fn = None ):
2143- self ._max_batch_size = max_batch_size
2144- self ._max_batch_weight = max_batch_weight
2145- self ._element_size_fn = element_size_fn
2143+ super ().__init__ (
2144+ max_batch_size = max_batch_size ,
2145+ max_batch_weight = max_batch_weight ,
2146+ element_size_fn = element_size_fn )
21462147
21472148 def load_model (self ) -> FakeModel :
21482149 return FakeModel ()
21492150
21502151 def run_inference (self , batch , model , inference_args = None ):
21512152 return [model .predict (x ) for x in batch ]
21522153
2153- def batch_elements_kwargs (self ):
2154- kwargs = {'max_batch_size' : self ._max_batch_size }
2155- if self ._max_batch_weight is not None :
2156- kwargs ['max_batch_weight' ] = self ._max_batch_weight
2157- if self ._element_size_fn :
2158- kwargs ['element_size_fn' ] = self ._element_size_fn
2159- return kwargs
2160-
21612154
21622155class RunInferenceSizeTest (unittest .TestCase ):
21632156 """Tests for ModelHandler.batch_elements_kwargs with element_size_fn."""
@@ -2191,5 +2184,100 @@ def test_sizing_with_edge_cases(self):
21912184 self .assertEqual (kwargs ['element_size_fn' ](1 ), 1000000 )
21922185
21932186
2187+ class FakeModelHandlerForBatching (base .ModelHandler [int , int , FakeModel ]):
2188+ """A ModelHandler used to test batching behavior via base class __init__."""
2189+ def __init__ (self , ** kwargs ):
2190+ super ().__init__ (** kwargs )
2191+
2192+ def load_model (self ) -> FakeModel :
2193+ return FakeModel ()
2194+
2195+ def run_inference (self , batch , model , inference_args = None ):
2196+ return [model .predict (x ) for x in batch ]
2197+
2198+
2199+ class ModelHandlerBatchingArgsTest (unittest .TestCase ):
2200+ """Tests for ModelHandler.__init__ batching parameters."""
2201+ def test_batch_elements_kwargs_all_args (self ):
2202+ """All batching args passed to __init__ are in batch_elements_kwargs."""
2203+ def size_fn (x ):
2204+ return 10
2205+
2206+ handler = FakeModelHandlerForBatching (
2207+ min_batch_size = 5 ,
2208+ max_batch_size = 20 ,
2209+ max_batch_duration_secs = 30 ,
2210+ max_batch_weight = 100 ,
2211+ element_size_fn = size_fn )
2212+
2213+ kwargs = handler .batch_elements_kwargs ()
2214+
2215+ self .assertEqual (kwargs ['min_batch_size' ], 5 )
2216+ self .assertEqual (kwargs ['max_batch_size' ], 20 )
2217+ self .assertEqual (kwargs ['max_batch_duration_secs' ], 30 )
2218+ self .assertEqual (kwargs ['max_batch_weight' ], 100 )
2219+ self .assertIn ('element_size_fn' , kwargs )
2220+ self .assertEqual (kwargs ['element_size_fn' ](1 ), 10 )
2221+
2222+ def test_batch_elements_kwargs_partial_args (self ):
2223+ """Only provided batching args are included in kwargs."""
2224+ handler = FakeModelHandlerForBatching (max_batch_size = 50 )
2225+ kwargs = handler .batch_elements_kwargs ()
2226+
2227+ self .assertEqual (kwargs , {'max_batch_size' : 50 })
2228+
2229+ def test_batch_elements_kwargs_empty_when_no_args (self ):
2230+ """No batching kwargs when none are provided."""
2231+ handler = FakeModelHandlerForBatching ()
2232+ kwargs = handler .batch_elements_kwargs ()
2233+
2234+ self .assertEqual (kwargs , {})
2235+
2236+ def test_large_model_sets_share_across_processes (self ):
2237+ """Setting large_model=True enables share_model_across_processes."""
2238+ handler = FakeModelHandlerForBatching (large_model = True )
2239+
2240+ self .assertTrue (handler .share_model_across_processes ())
2241+
2242+ def test_model_copies_sets_share_across_processes (self ):
2243+ """Setting model_copies enables share_model_across_processes."""
2244+ handler = FakeModelHandlerForBatching (model_copies = 2 )
2245+
2246+ self .assertTrue (handler .share_model_across_processes ())
2247+ self .assertEqual (handler .model_copies (), 2 )
2248+
2249+ def test_default_share_across_processes_is_false (self ):
2250+ """Default share_model_across_processes is False."""
2251+ handler = FakeModelHandlerForBatching ()
2252+
2253+ self .assertFalse (handler .share_model_across_processes ())
2254+
2255+ def test_default_model_copies_is_one (self ):
2256+ """Default model_copies is 1."""
2257+ handler = FakeModelHandlerForBatching ()
2258+
2259+ self .assertEqual (handler .model_copies (), 1 )
2260+
2261+ def test_env_vars_from_kwargs (self ):
2262+ """Environment variables can be passed via kwargs."""
2263+ handler = FakeModelHandlerForBatching (env_vars = {'MY_VAR' : 'value' })
2264+
2265+ self .assertEqual (handler ._env_vars , {'MY_VAR' : 'value' })
2266+
2267+ def test_min_batch_size_only (self ):
2268+ """min_batch_size can be passed alone."""
2269+ handler = FakeModelHandlerForBatching (min_batch_size = 10 )
2270+ kwargs = handler .batch_elements_kwargs ()
2271+
2272+ self .assertEqual (kwargs , {'min_batch_size' : 10 })
2273+
2274+ def test_max_batch_duration_secs_only (self ):
2275+ """max_batch_duration_secs can be passed alone."""
2276+ handler = FakeModelHandlerForBatching (max_batch_duration_secs = 60 )
2277+ kwargs = handler .batch_elements_kwargs ()
2278+
2279+ self .assertEqual (kwargs , {'max_batch_duration_secs' : 60 })
2280+
2281+
21942282if __name__ == '__main__' :
21952283 unittest .main ()
0 commit comments