Skip to content

Commit 9d5fa77

Browse files
add tests
Signed-off-by: wzliu <wzliu@connect.hku.hk>
1 parent ae367dd commit 9d5fa77

1 file changed

Lines changed: 100 additions & 0 deletions

File tree

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import unittest
4+
from unittest.mock import patch
5+
6+
from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector
7+
from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
8+
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
9+
from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer
10+
11+
12+
class TestOmniSerializer(unittest.TestCase):
13+
def test_pickle_serialization(self):
14+
"""Test basic pickle serialization."""
15+
data = {"key": "value", "list": [1, 2, 3]}
16+
serialized = OmniSerializer.serialize(data, method="cloudpickle")
17+
self.assertIsInstance(serialized, bytes)
18+
19+
deserialized = OmniSerializer.deserialize(serialized, method="cloudpickle")
20+
self.assertEqual(data, deserialized)
21+
22+
23+
class TestOmniConnectorFactory(unittest.TestCase):
24+
def test_create_shm_connector(self):
25+
"""Test creating SharedMemoryConnector via Factory."""
26+
spec = ConnectorSpec(name="SharedMemoryConnector", extra={"shm_threshold_bytes": 1024})
27+
connector = OmniConnectorFactory.create_connector(spec)
28+
self.assertIsInstance(connector, SharedMemoryConnector)
29+
self.assertEqual(connector.threshold, 1024)
30+
31+
def test_create_unknown_connector(self):
32+
"""Test error when creating unknown connector."""
33+
spec = ConnectorSpec(name="UnknownConnector")
34+
with self.assertRaises(ValueError):
35+
OmniConnectorFactory.create_connector(spec)
36+
37+
38+
class TestSharedMemoryConnector(unittest.TestCase):
39+
def setUp(self):
40+
self.config = {"shm_threshold_bytes": 100} # Small threshold for testing
41+
self.connector = SharedMemoryConnector(self.config)
42+
43+
def test_put_get_inline(self):
44+
"""Test inline transfer for small data."""
45+
data = {"small": "data"}
46+
# Ensure data is smaller than threshold (100 bytes)
47+
48+
success, size, metadata = self.connector.put("stage_0", "stage_1", "req_1", data)
49+
self.assertTrue(success)
50+
self.assertIn("inline_bytes", metadata)
51+
self.assertNotIn("shm", metadata)
52+
53+
# Retrieve
54+
retrieved_data, ret_size = self.connector.get("stage_0", "stage_1", "req_1", metadata)
55+
self.assertEqual(data, retrieved_data)
56+
self.assertEqual(size, ret_size)
57+
58+
@patch("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_write_bytes")
59+
@patch("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_read_bytes")
60+
def test_put_get_shm(self, mock_read, mock_write):
61+
"""Test SHM transfer logic for large data (Mocked)."""
62+
# Create data larger than 100 bytes
63+
data = {"large": "x" * 200}
64+
65+
# Mock SHM return values
66+
mock_handle = {"name": "test_shm", "size": 200}
67+
mock_write.return_value = mock_handle
68+
69+
# When reading, return the serialized bytes of the data
70+
serialized_data = self.connector.serialize_obj(data)
71+
mock_read.return_value = serialized_data
72+
73+
# Put
74+
success, size, metadata = self.connector.put("stage_0", "stage_1", "req_2", data)
75+
76+
self.assertTrue(success)
77+
# Should use SHM because data > threshold
78+
self.assertIn("shm", metadata)
79+
self.assertEqual(metadata["shm"], mock_handle)
80+
self.assertNotIn("inline_bytes", metadata)
81+
82+
mock_write.assert_called_once()
83+
84+
# Get
85+
retrieved_data, ret_size = self.connector.get("stage_0", "stage_1", "req_2", metadata)
86+
87+
self.assertEqual(data, retrieved_data)
88+
mock_read.assert_called_once_with(mock_handle)
89+
90+
def test_get_invalid_metadata(self):
91+
"""Test get with invalid metadata."""
92+
result = self.connector.get("stage_0", "stage_1", "req_3", {})
93+
self.assertIsNone(result)
94+
95+
result = self.connector.get("stage_0", "stage_1", "req_3", {"unknown": "format"})
96+
self.assertIsNone(result)
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main()

0 commit comments

Comments
 (0)