Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def add_field(self, field_name: str, datatype: DataType, **kwargs):
if "mmap_enabled" in kwargs:
struct_schema._type_params["mmap_enabled"] = kwargs["mmap_enabled"]

if "warmup" in kwargs:
struct_schema._type_params["warmup"] = kwargs["warmup"]

self._struct_fields.append(struct_schema)
return self

Expand Down Expand Up @@ -489,6 +492,9 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
if "mmap_enabled" in kwargs:
self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]

if "warmup" in kwargs:
self._type_params["warmup"] = kwargs["warmup"]

for key in ["analyzer_params", "multi_analyzer_params"]:
if key in self._kwargs and isinstance(self._kwargs[key], dict):
self._kwargs[key] = orjson.dumps(self._kwargs[key]).decode(Config.EncodeProtocol)
Expand Down
31 changes: 31 additions & 0 deletions tests/orm/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,16 @@ def test_mmap_enabled_param(self):
field = FieldSchema("vec", DataType.FLOAT_VECTOR, dim=128, mmap_enabled=True)
assert field._type_params.get("mmap_enabled") is True

def test_warmup_param(self):
"""Test warmup type param."""
field = FieldSchema("vec", DataType.FLOAT_VECTOR, dim=128, warmup=True)
assert field._type_params.get("warmup") is True

def test_warmup_param_false(self):
"""Test warmup type param with False value."""
field = FieldSchema("vec", DataType.FLOAT_VECTOR, dim=128, warmup=False)
assert field._type_params.get("warmup") is False

def test_analyzer_params_dict(self):
"""Test analyzer_params as dict gets serialized."""
field = FieldSchema(
Expand Down Expand Up @@ -658,6 +668,27 @@ def test_add_struct_field_with_mmap(self):
assert len(schema.struct_fields) == 1
assert schema.struct_fields[0]._type_params.get("mmap_enabled") is True

def test_add_struct_field_with_warmup(self):
"""Test adding struct field with warmup."""
struct = StructFieldSchema()
struct.add_field("score", DataType.FLOAT)
schema = CollectionSchema(
[
FieldSchema("id", DataType.INT64, is_primary=True),
FieldSchema("vec", DataType.FLOAT_VECTOR, dim=128),
]
)
schema.add_field(
"struct",
DataType.ARRAY,
element_type=DataType.STRUCT,
struct_schema=struct,
max_capacity=10,
warmup=True,
)
assert len(schema.struct_fields) == 1
assert schema.struct_fields[0]._type_params.get("warmup") is True


class TestCollectionSchemaToDict:
"""Tests for CollectionSchema to_dict and construct_from_dict methods."""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_client_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,28 @@ def test_field_schema_with_mmap_disabled(self):

assert field.params["mmap_enabled"] is False

def test_field_schema_with_warmup_enabled(self):
"""Test FieldSchema with warmup type parameter."""
warmup_param = MagicMock()
warmup_param.key = "warmup"
warmup_param.value = "true"

raw = self._create_mock_raw_field(type_params=[warmup_param])
field = FieldSchema(raw)

assert field.params["warmup"] == "true"

def test_field_schema_with_warmup_disabled(self):
"""Test FieldSchema with warmup set to false."""
warmup_param = MagicMock()
warmup_param.key = "warmup"
warmup_param.value = "false"

raw = self._create_mock_raw_field(type_params=[warmup_param])
field = FieldSchema(raw)

assert field.params["warmup"] == "false"

def test_field_schema_with_json_params(self):
"""Test FieldSchema with JSON type params."""
params_param = MagicMock()
Expand Down