Skip to content

Commit 04b7d02

Browse files
committed
Support for grid search algorithm in Optuna Suggestion Service
Signed-off-by: tenzen-y <[email protected]>
1 parent f941ec6 commit 04b7d02

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

manifests/v1beta1/components/controller/katib-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ data:
3131
"image": "docker.io/kubeflowkatib/suggestion-hyperopt:latest"
3232
},
3333
"grid": {
34-
"image": "docker.io/kubeflowkatib/suggestion-chocolate:latest"
34+
"image": "docker.io/kubeflowkatib/suggestion-optuna:latest"
3535
},
3636
"hyperband": {
3737
"image": "docker.io/kubeflowkatib/suggestion-hyperband:latest"

pkg/suggestion/v1beta1/optuna/base_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def _create_sampler(self):
4848
elif self.algorithm_name == "random":
4949
return optuna.samplers.RandomSampler(**self.algorithm_config)
5050

51+
elif self.algorithm_name == "grid":
52+
return optuna.samplers.GridSampler(**self.algorithm_config)
53+
5154
def get_suggestions(self, trials, current_request_number):
5255
if len(trials) != 0:
5356
self._tell(trials)

pkg/suggestion/v1beta1/optuna/service.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ class OptimizerConfiguration(object):
8686
"random": {
8787
"seed": lambda x: int(x),
8888
},
89+
"grid": {
90+
"seed": lambda x: int(x),
91+
}
8992
}
9093

9194
@classmethod
@@ -120,6 +123,8 @@ def validate_algorithm_spec(cls, algorithm_spec):
120123
return cls._validate_cmaes_setting(algorithm_settings)
121124
elif algorithm_name == "random":
122125
return cls._validate_random_setting(algorithm_settings)
126+
elif algorithm_name == "grid":
127+
return cls._validate_grid_setting(algorithm_settings)
123128

124129
@classmethod
125130
def _validate_tpe_setting(cls, algorithm_spec):
@@ -178,3 +183,19 @@ def _validate_random_setting(cls, algorithm_settings):
178183
exception=e)
179184

180185
return True, ""
186+
187+
@classmethod
188+
def _validate_grid_setting(cls, algorithm_settings):
189+
for s in algorithm_settings:
190+
try:
191+
if s.name == "seed":
192+
if not int(s.value) >= 0:
193+
return False, ""
194+
else:
195+
return False, "unknown setting {} for algorithm grid".format(s.name)
196+
197+
except Exception as e:
198+
return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
199+
exception=e)
200+
201+
return True, ""

0 commit comments

Comments
 (0)