
Commit 0a8e501

Address review comments: refactor tests and fix linting

1 parent: b5a4e7a

4 files changed (+103, -15 lines)

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 99 additions & 11 deletions
@@ -2140,24 +2140,17 @@ def __init__(
       max_batch_size: int = 10,
       max_batch_weight: Optional[int] = None,
       element_size_fn=None):
-    self._max_batch_size = max_batch_size
-    self._max_batch_weight = max_batch_weight
-    self._element_size_fn = element_size_fn
+    super().__init__(
+        max_batch_size=max_batch_size,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn)
 
   def load_model(self) -> FakeModel:
     return FakeModel()
 
   def run_inference(self, batch, model, inference_args=None):
     return [model.predict(x) for x in batch]
 
-  def batch_elements_kwargs(self):
-    kwargs = {'max_batch_size': self._max_batch_size}
-    if self._max_batch_weight is not None:
-      kwargs['max_batch_weight'] = self._max_batch_weight
-    if self._element_size_fn:
-      kwargs['element_size_fn'] = self._element_size_fn
-    return kwargs
-
 
 class RunInferenceSizeTest(unittest.TestCase):
   """Tests for ModelHandler.batch_elements_kwargs with element_size_fn."""
@@ -2191,5 +2184,100 @@ def test_sizing_with_edge_cases(self):
     self.assertEqual(kwargs['element_size_fn'](1), 1000000)
 
 
+class FakeModelHandlerForBatching(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test batching behavior via base class __init__."""
+  def __init__(self, **kwargs):
+    super().__init__(**kwargs)
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+
+class ModelHandlerBatchingArgsTest(unittest.TestCase):
+  """Tests for ModelHandler.__init__ batching parameters."""
+  def test_batch_elements_kwargs_all_args(self):
+    """All batching args passed to __init__ are in batch_elements_kwargs."""
+    def size_fn(x):
+      return 10
+
+    handler = FakeModelHandlerForBatching(
+        min_batch_size=5,
+        max_batch_size=20,
+        max_batch_duration_secs=30,
+        max_batch_weight=100,
+        element_size_fn=size_fn)
+
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs['min_batch_size'], 5)
+    self.assertEqual(kwargs['max_batch_size'], 20)
+    self.assertEqual(kwargs['max_batch_duration_secs'], 30)
+    self.assertEqual(kwargs['max_batch_weight'], 100)
+    self.assertIn('element_size_fn', kwargs)
+    self.assertEqual(kwargs['element_size_fn'](1), 10)
+
+  def test_batch_elements_kwargs_partial_args(self):
+    """Only provided batching args are included in kwargs."""
+    handler = FakeModelHandlerForBatching(max_batch_size=50)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'max_batch_size': 50})
+
+  def test_batch_elements_kwargs_empty_when_no_args(self):
+    """No batching kwargs when none are provided."""
+    handler = FakeModelHandlerForBatching()
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {})
+
+  def test_large_model_sets_share_across_processes(self):
+    """Setting large_model=True enables share_model_across_processes."""
+    handler = FakeModelHandlerForBatching(large_model=True)
+
+    self.assertTrue(handler.share_model_across_processes())
+
+  def test_model_copies_sets_share_across_processes(self):
+    """Setting model_copies enables share_model_across_processes."""
+    handler = FakeModelHandlerForBatching(model_copies=2)
+
+    self.assertTrue(handler.share_model_across_processes())
+    self.assertEqual(handler.model_copies(), 2)
+
+  def test_default_share_across_processes_is_false(self):
+    """Default share_model_across_processes is False."""
+    handler = FakeModelHandlerForBatching()
+
+    self.assertFalse(handler.share_model_across_processes())
+
+  def test_default_model_copies_is_one(self):
+    """Default model_copies is 1."""
+    handler = FakeModelHandlerForBatching()
+
+    self.assertEqual(handler.model_copies(), 1)
+
+  def test_env_vars_from_kwargs(self):
+    """Environment variables can be passed via kwargs."""
+    handler = FakeModelHandlerForBatching(env_vars={'MY_VAR': 'value'})
+
+    self.assertEqual(handler._env_vars, {'MY_VAR': 'value'})
+
+  def test_min_batch_size_only(self):
+    """min_batch_size can be passed alone."""
+    handler = FakeModelHandlerForBatching(min_batch_size=10)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'min_batch_size': 10})
+
+  def test_max_batch_duration_secs_only(self):
+    """max_batch_duration_secs can be passed alone."""
+    handler = FakeModelHandlerForBatching(max_batch_duration_secs=60)
+    kwargs = handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs, {'max_batch_duration_secs': 60})
+
+
 if __name__ == '__main__':
   unittest.main()
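Note on the refactor above: the test handler no longer overrides batch_elements_kwargs, because the base ModelHandler.__init__ accepts the batching parameters directly and assembles that dict itself. A minimal standalone sketch of the pattern the new tests exercise (SketchModelHandler and the FakeModel stub are illustrative names, not part of Beam):

# Sketch: forward batching parameters to base.ModelHandler.__init__ and
# let the base class build the dict returned by batch_elements_kwargs().
# SketchModelHandler and FakeModel are illustrative stand-ins.
from apache_beam.ml.inference import base


class FakeModel:
  def predict(self, x: int) -> int:
    return x + 1


class SketchModelHandler(base.ModelHandler[int, int, FakeModel]):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)  # batching kwargs handled by the base class

  def load_model(self) -> FakeModel:
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(x) for x in batch]


handler = SketchModelHandler(
    max_batch_size=20, max_batch_weight=100, element_size_fn=lambda x: 10)
# The base class surfaces exactly the kwargs that were provided:
print(handler.batch_elements_kwargs())
# {'max_batch_size': 20, 'max_batch_weight': 100, 'element_size_fn': <function ...>}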

sdks/python/apache_beam/ml/inference/onnx_inference.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 
 from collections.abc import Callable
 from collections.abc import Iterable
-from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional

sdks/python/apache_beam/ml/inference/tensorflow_inference.py

Lines changed: 4 additions & 2 deletions
@@ -143,7 +143,8 @@ def __init__(
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
       max_batch_weight: the maximum total weight of a batch.
-      element_size_fn: a function that returns the size (weight) of an element.
+      element_size_fn: a function that returns the size (weight) of an
+        element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -275,7 +276,8 @@ def __init__(
         onto your machine. This can be useful if you exactly know your CPU or
         GPU capacity and want to maximize resource utilization.
       max_batch_weight: the maximum total weight of a batch.
-      element_size_fn: a function that returns the size (weight) of an element.
+      element_size_fn: a function that returns the size (weight) of an
+        element.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
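For context on the two docstring fixes above: max_batch_weight and element_size_fn together enable weight-based batching, where a batch is capped by the summed weight of its elements rather than by element count. A hedged usage sketch with TFModelHandlerNumpy, the handler these docstrings belong to (the model path is a hypothetical placeholder, and byte size is just one possible weight):

# Sketch: weight-based batching with the documented parameters. The model
# URI is a hypothetical placeholder; using ndarray.nbytes as the element
# weight is one reasonable choice, not the only one.
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy

model_handler = TFModelHandlerNumpy(
    model_uri='gs://example-bucket/example-model',  # hypothetical path
    max_batch_weight=1_000_000,  # cap each batch at ~1 MB of input
    element_size_fn=lambda arr: arr.nbytes)  # weight = bytes per ndarray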

sdks/python/apache_beam/ml/inference/xgboost_inference.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 from abc import ABC
 from collections.abc import Callable
 from collections.abc import Iterable
-from collections.abc import Mapping
 from collections.abc import Sequence
 from typing import Any
 from typing import Optional
