Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4336,6 +4336,67 @@ def my_pipeline():
pipeline_func=my_pipeline, package_path=output_yaml)


class TestPipelineSemaphoreMutex(unittest.TestCase):

def test_pipeline_with_semaphore(self):
"""Test that pipeline config correctly sets the semaphore key."""
config = PipelineConfig()
config.semaphore_key = 'semaphore'

@dsl.pipeline(pipeline_config=config)
def my_pipeline():
task = comp()

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_docs = list(yaml.safe_load_all(f))

platform_spec = None
for doc in pipeline_docs:
if 'platforms' in doc:
platform_spec = doc
break

self.assertIsNotNone(platform_spec,
'No platforms section found in compiled output')
kubernetes_spec = platform_spec['platforms']['kubernetes'][
'pipelineConfig']
self.assertEqual(kubernetes_spec['semaphoreKey'], 'semaphore')

def test_pipeline_with_mutex(self):
"""Test that pipeline config correctly sets the mutex name."""
config = PipelineConfig()
config.mutex_name = 'mutex'

@dsl.pipeline(pipeline_config=config)
def my_pipeline():
task = comp()

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_docs = list(yaml.safe_load_all(f))

platform_spec = None
for doc in pipeline_docs:
if 'platforms' in doc:
platform_spec = doc
break

self.assertIsNotNone(platform_spec,
'No platforms section found in compiled output')
kubernetes_spec = platform_spec['platforms']['kubernetes'][
'pipelineConfig']
self.assertEqual(kubernetes_spec['mutexName'], 'mutex')


class ExtractInputOutputDescription(unittest.TestCase):

def test_no_descriptions(self):
Expand Down
18 changes: 12 additions & 6 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,14 +2242,20 @@ def _write_kubernetes_manifest_to_file(

def _merge_pipeline_config(pipelineConfig: pipeline_config.PipelineConfig,
platformSpec: pipeline_spec_pb2.PlatformSpec):
config_dict = {}

workspace = pipelineConfig.workspace
if workspace is None:
return platformSpec
if workspace is not None:
config_dict['workspace'] = workspace.get_workspace()

json_format.ParseDict(
{'pipelineConfig': {
'workspace': workspace.get_workspace(),
}}, platformSpec.platforms['kubernetes'])
if pipelineConfig.semaphore_key is not None:
config_dict['semaphoreKey'] = pipelineConfig.semaphore_key
if pipelineConfig.mutex_name is not None:
config_dict['mutexName'] = pipelineConfig.mutex_name

if config_dict:
json_format.ParseDict({'pipelineConfig': config_dict},
platformSpec.platforms['kubernetes'])

return platformSpec

Expand Down
54 changes: 53 additions & 1 deletion sdk/python/kfp/dsl/pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,57 @@ def set_kubernetes_config(self,
class PipelineConfig:
"""PipelineConfig contains pipeline-level config options."""

def __init__(self, workspace: Optional[WorkspaceConfig] = None):
def __init__(self,
workspace: Optional[WorkspaceConfig] = None,
semaphore_key: Optional[str] = None,
mutex_name: Optional[str] = None):
self.workspace = workspace
self._semaphore_key = semaphore_key
self._mutex_name = mutex_name

@property
def semaphore_key(self) -> Optional[str]:
"""Get the semaphore key for controlling pipeline concurrency.

Returns:
Optional[str]: The semaphore key, or None if not set.
"""
return self._semaphore_key

@semaphore_key.setter
def semaphore_key(self, value: str):
"""Set the semaphore key to control pipeline concurrency.

Pipelines with the same semaphore key will be limited to a configured maximum
number of concurrent executions. This allows you to control resource usage by
ensuring that only a specific number of pipelines can run simultaneously.

Note: A pipeline can use both semaphores and mutexes together. The pipeline
will wait until all required locks are available before starting.

Args:
value (str): The semaphore key name for controlling concurrent executions.
"""
self._semaphore_key = (value and value.strip()) or None

@property
def mutex_name(self) -> Optional[str]:
"""Get the mutex name for exclusive pipeline execution.

Returns:
Optional[str]: The mutex name, or None if not set.
"""
return self._mutex_name

@mutex_name.setter
def mutex_name(self, value: str):
"""Set the name of the mutex to ensure mutual exclusion.

Pipelines with the same mutex name will only run one at a time. This ensures
exclusive access to shared resources and prevents conflicts when multiple
pipelines would otherwise compete for the same resources.

Args:
value (str): Name of the mutex for exclusive pipeline execution.
"""
self._mutex_name = (value and value.strip()) or None
Loading