Skip to content

Commit 2090549

Browse files
committed
Add misc and component store
1 parent 7aaeab0 commit 2090549

File tree

4 files changed

+200
-0
lines changed

4 files changed

+200
-0
lines changed

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
# have to explicitly bring these in here to resolve circular import issues
1515
from .aliases import alias, resolve_name
16+
from .component_store import ComponentStore
1617
from .decorators import MethodReplacer, RestartGenerator
1718
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1819
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather

monai/utils/component_store.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from collections import namedtuple
15+
from keyword import iskeyword
16+
from textwrap import dedent, indent
17+
from typing import Any, Callable, Iterable, TypeVar
18+
19+
T = TypeVar("T")
20+
21+
22+
def is_variable(name):
23+
"""Returns True if `name` is a valid Python variable name and also not a keyword."""
24+
return name.isidentifier() and not iskeyword(name)
25+
26+
27+
class ComponentStore:
28+
"""
29+
Represents a storage object for other objects (specifically functions) keyed to a name with a description.
30+
31+
These objects act as global named places for storing components for objects parameterised by component names.
32+
Typically this is functions although other objects can be added. Printing a component store will produce a
33+
list of members along with their docstring information if present.
34+
35+
Example:
36+
37+
.. code-block:: python
38+
39+
TestStore = ComponentStore("Test Store", "A test store for demo purposes")
40+
41+
@TestStore.add_def("my_func_name", "Some description of your function")
42+
def _my_func(a, b):
43+
'''A description of your function here.'''
44+
return a * b
45+
46+
print(TestStore) # will print out name, description, and 'my_func_name' with the docstring
47+
48+
func = TestStore["my_func_name"]
49+
result = func(7, 6)
50+
51+
"""
52+
53+
_Component = namedtuple("Component", ("description", "value")) # internal value pair
54+
55+
def __init__(self, name: str, description: str) -> None:
56+
self.components: dict[str, self._Component] = {}
57+
self.name: str = name
58+
self.description: str = description
59+
60+
self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()
61+
62+
def add(self, name: str, desc: str, value: T) -> T:
63+
"""Store the object `value` under the name `name` with description `desc`."""
64+
if not is_variable(name):
65+
raise ValueError("Name of component must be valid Python identifier")
66+
67+
self.components[name] = self._Component(desc, value)
68+
return value
69+
70+
def add_def(self, name: str, desc: str) -> Callable:
71+
"""Returns a decorator which stores the decorated function under `name` with description `desc`."""
72+
73+
def deco(func):
74+
"""Decorator to add a function to a store."""
75+
return self.add(name, desc, func)
76+
77+
return deco
78+
79+
def __contains__(self, name: str) -> bool:
80+
"""Returns True if the given name is stored."""
81+
return name in self.components
82+
83+
def __len__(self) -> int:
84+
"""Returns the number of stored components."""
85+
return len(self.components)
86+
87+
def __iter__(self) -> Iterable:
88+
"""Yields name/component pairs."""
89+
for k, v in self.components.items():
90+
yield k, v.value
91+
92+
def __str__(self):
93+
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
94+
for k, v in self.components.items():
95+
result += f"\n* {k}:"
96+
97+
if hasattr(v.value, "__doc__"):
98+
doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ")
99+
result += f"\n{doc}\n"
100+
else:
101+
result += f" {v.description}"
102+
103+
return result
104+
105+
def __getattr__(self, name: str) -> Any:
106+
"""Returns the stored object under the given name."""
107+
if name in self.components:
108+
return self.components[name].value
109+
else:
110+
return self.__getattribute__(name)
111+
112+
def __getitem__(self, name: str) -> Any:
113+
"""Returns the stored object under the given name."""
114+
if name in self.components:
115+
return self.components[name].value
116+
else:
117+
raise ValueError(f"Component '{name}' not found")

monai/utils/misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
888888
sqrt_num = [int(math.sqrt(_num)) for _num in num]
889889
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
890890
return ensure_tuple(ret) == num
891+
892+
893+
def unsqueeze_right(arr: T, ndim: int) -> T:
894+
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
895+
return arr[(...,) + (None,) * (ndim - arr.ndim)]
896+
897+
898+
def unsqueeze_left(arr: T, ndim: int) -> T:
899+
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
900+
return arr[(None,) * (ndim - arr.ndim)]

tests/test_component_store.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
from monai.utils import ComponentStore
17+
18+
19+
class TestComponentStore(unittest.TestCase):
20+
def setUp(self):
21+
self.cs = ComponentStore("TestStore", "I am a test store, please ignore")
22+
23+
def test_empty(self):
24+
self.assertEqual(len(self.cs), 0)
25+
self.assertEqual(list(self.cs), [])
26+
27+
def test_add(self):
28+
test_obj = object()
29+
30+
self.assertFalse("test_obj" in self.cs)
31+
32+
self.cs.add("test_obj", "Test object", test_obj)
33+
34+
self.assertTrue("test_obj" in self.cs)
35+
36+
self.assertEqual(len(self.cs), 1)
37+
self.assertEqual(list(self.cs), [("test_obj", test_obj)])
38+
39+
self.assertEqual(self.cs.test_obj, test_obj)
40+
self.assertEqual(self.cs["test_obj"], test_obj)
41+
42+
def test_add2(self):
43+
test_obj1 = object()
44+
test_obj2 = object()
45+
46+
self.cs.add("test_obj1", "Test object", test_obj1)
47+
self.cs.add("test_obj2", "Test object", test_obj2)
48+
49+
self.assertEqual(len(self.cs), 2)
50+
self.assertTrue("test_obj1" in self.cs)
51+
self.assertTrue("test_obj2" in self.cs)
52+
53+
def test_add_def(self):
54+
self.assertFalse("test_func" in self.cs)
55+
56+
@self.cs.add_def("test_func", "Test function")
57+
def test_func():
58+
return 123
59+
60+
self.assertTrue("test_func" in self.cs)
61+
62+
self.assertEqual(len(self.cs), 1)
63+
self.assertEqual(list(self.cs), [("test_func", test_func)])
64+
65+
self.assertEqual(self.cs.test_func, test_func)
66+
self.assertEqual(self.cs["test_func"], test_func)
67+
68+
# try adding the same function again
69+
self.cs.add_def("test_func", "Test function but with new description")(test_func)
70+
71+
self.assertEqual(len(self.cs), 1)
72+
self.assertEqual(self.cs.test_func, test_func)

0 commit comments

Comments
 (0)