Skip to content

Commit a002cfa

Browse files
authored
Merge pull request #1187 from weaviate/fix_optional_moduleconfig_for_custom
Make moduleconfig optional for custom reranker/generative
2 parents fc85585 + ffe9afc commit a002cfa

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
exclude: ^proto/
22
repos:
33
- repo: https://github.com/psf/black
4-
rev: 24.4.2
4+
rev: 23.9.1
55
hooks:
66
- id: black
77

requirements-devel.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ flake8
4646
flake8-bugbear==24.4.26
4747
flake8-comprehensions==3.15.0
4848
flake8-builtins==2.5.0
49+
black==23.9.1

weaviate/collections/classes/config.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,11 @@ class _GenerativeAnyscale(_GenerativeConfigCreate):
378378

379379

380380
class _GenerativeCustom(_GenerativeConfigCreate):
381-
module_config: Dict[str, Any]
381+
module_config: Optional[Dict[str, Any]]
382382

383383
def _to_dict(self) -> Dict[str, Any]:
384+
if self.module_config is None:
385+
return {}
384386
return self.module_config
385387

386388

@@ -495,9 +497,11 @@ class _RerankerCohereConfig(_RerankerConfigCreate):
495497

496498

497499
class _RerankerCustomConfig(_RerankerConfigCreate):
498-
module_config: Dict[str, Any]
500+
module_config: Optional[Dict[str, Any]]
499501

500502
def _to_dict(self) -> Dict[str, Any]:
503+
if self.module_config is None:
504+
return {}
501505
return self.module_config
502506

503507

@@ -534,7 +538,7 @@ def anyscale(
534538
@staticmethod
535539
def custom(
536540
module_name: str,
537-
module_config: Dict[str, Any],
541+
module_config: Optional[Dict[str, Any]] = None,
538542
) -> _GenerativeConfigCreate:
539543
"""Create a `_GenerativeCustom` object for use when generating using a custom module.
540544
@@ -799,7 +803,9 @@ def transformers() -> _RerankerConfigCreate:
799803
return _RerankerTransformersConfig(reranker=Rerankers.TRANSFORMERS)
800804

801805
@staticmethod
802-
def custom(module_name: str, module_config: Dict[str, Any]) -> _RerankerConfigCreate:
806+
def custom(
807+
module_name: str, module_config: Optional[Dict[str, Any]] = None
808+
) -> _RerankerConfigCreate:
803809
"""Create a `_RerankerCustomConfig` object for use when reranking using a custom module.
804810
805811
Arguments:

0 commit comments

Comments
 (0)