Skip to content

Commit 9e589a6

Browse files
[SDK] Add UTs for wait_for_job_conditions (kubeflow/trainer#2196)
* test(sdk): add unit test for wait_for_job_conditions. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(sdk): fix lint error. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(sdk): add patch for load_kube_config. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(trial): fix lint error with black. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(trial): add package dependency. Signed-off-by: Electronic-Waste <2690692950@qq.com> * test(sdk): reuse exisiting fixture. Signed-off-by: Electronic-Waste <2690692950@qq.com> --------- Signed-off-by: Electronic-Waste <2690692950@qq.com>
1 parent 9004459 commit 9e589a6

1 file changed

Lines changed: 97 additions & 0 deletions

File tree

python/kubeflow/training/api/training_client_test.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import pytest
66
from kubeflow.training import (
7+
KubeflowOrgV1JobCondition,
8+
KubeflowOrgV1JobStatus,
79
KubeflowOrgV1PyTorchJob,
810
KubeflowOrgV1PyTorchJobSpec,
911
KubeflowOrgV1ReplicaSpec,
@@ -70,6 +72,13 @@ def get(self, timeout):
7072
return MockResponse()
7173

7274

75+
def get_job_response(*args, **kwargs):
76+
if kwargs.get("namespace") == RUNTIME:
77+
return generate_job_with_status(create_job(), constants.JOB_CONDITION_FAILED)
78+
else:
79+
return generate_job_with_status(create_job())
80+
81+
7382
def generate_container() -> V1Container:
7483
return V1Container(
7584
name="pytorch",
@@ -127,6 +136,21 @@ def create_job():
127136
return pytorchjob
128137

129138

139+
def generate_job_with_status(
140+
job: constants.JOB_MODELS_TYPE,
141+
condition_type: str = constants.JOB_CONDITION_SUCCEEDED,
142+
) -> constants.JOB_MODELS_TYPE:
143+
job.status = KubeflowOrgV1JobStatus(
144+
conditions=[
145+
KubeflowOrgV1JobCondition(
146+
type=condition_type,
147+
status=constants.CONDITION_STATUS_TRUE,
148+
)
149+
]
150+
)
151+
return job
152+
153+
130154
class DummyJobClass:
131155
def __init__(self, kind) -> None:
132156
self.kind = kind
@@ -279,6 +303,61 @@ def __init__(self, kind) -> None:
279303
),
280304
]
281305

306+
test_data_wait_for_job_conditions = [
307+
(
308+
"timeout waiting for succeeded condition",
309+
{
310+
"name": TEST_NAME,
311+
"namespace": TIMEOUT,
312+
"wait_timeout": 0,
313+
},
314+
TimeoutError,
315+
),
316+
(
317+
"invalid expected condition",
318+
{
319+
"name": TEST_NAME,
320+
"namespace": "value",
321+
"expected_conditions": {"invalid"},
322+
},
323+
ValueError,
324+
),
325+
(
326+
"invalid expected condition(lowercase)",
327+
{
328+
"name": TEST_NAME,
329+
"namespace": "value",
330+
"expected_conditions": {"succeeded"},
331+
},
332+
ValueError,
333+
),
334+
(
335+
"job failed unexpectedly",
336+
{
337+
"name": TEST_NAME,
338+
"namespace": RUNTIME,
339+
},
340+
RuntimeError,
341+
),
342+
(
343+
"valid case",
344+
{
345+
"name": TEST_NAME,
346+
"namespace": "test-namespace",
347+
},
348+
generate_job_with_status(create_job()),
349+
),
350+
(
351+
"valid case with specified callback",
352+
{
353+
"name": TEST_NAME,
354+
"namespace": "test-namespace",
355+
"callback": lambda job: "test train function",
356+
},
357+
generate_job_with_status(create_job()),
358+
),
359+
]
360+
282361

283362
test_data_get_job_pod_names = [
284363
(
@@ -354,6 +433,8 @@ def training_client():
354433
),
355434
), patch(
356435
"kubernetes.config.load_kube_config", return_value=Mock()
436+
), patch.object(
437+
TrainingClient, "get_job", side_effect=get_job_response
357438
):
358439
client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND)
359440
yield client
@@ -434,3 +515,19 @@ def test_update_job(training_client, test_name, kwargs, expected_output):
434515
except Exception as e:
435516
assert type(e) is expected_output
436517
print("test execution complete")
518+
519+
520+
@pytest.mark.parametrize(
521+
"test_name,kwargs,expected_output", test_data_wait_for_job_conditions
522+
)
523+
def test_wait_for_job_conditions(training_client, test_name, kwargs, expected_output):
524+
"""
525+
test wait_for_job_conditions function of training client
526+
"""
527+
print("Executing test:", test_name)
528+
try:
529+
out = training_client.wait_for_job_conditions(**kwargs)
530+
assert out == expected_output
531+
except Exception as e:
532+
assert type(e) is expected_output
533+
print("test execution complete")

0 commit comments

Comments
 (0)