From ef6d6e96be91366234e907d6ae25c9bf855e48f0 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 22:32:39 +0100 Subject: [PATCH 01/22] Add new black-check and black-format Make targets which verify code comforms to black formatting rules and run it on CI. --- Makefile | 59 ++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 13 ++++++++++ test-requirements.txt | 1 + 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 pyproject.toml diff --git a/Makefile b/Makefile index 65ca6204ae..abf89f450b 100644 --- a/Makefile +++ b/Makefile @@ -326,6 +326,63 @@ schemasgen: requirements .schemasgen . $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc --load-plugins=pylint_plugins.api_models tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc pylint_plugins/*.py || exit 1; +# Black task which checks if the code comforts to black code style +.PHONY: black-check +black: requirements .black-check + +.PHONY: .black-check +.black: + @echo + @echo "================== black-check ====================" + @echo + # st2 components + @for component in $(COMPONENTS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ + done + # runner modules and packages + @for component in $(COMPONENTS_RUNNERS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ + done + # Python pack management actions + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/* || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml scripts/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml tools/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml pylint_plugins/*.py || exit 1; + +# Black task which reformats the code using black +.PHONY: black-format +black: requirements .black-format + +.PHONY: .black-format +.black-format: + @echo + @echo "================== black ====================" + @echo + # st2 components + @for component in $(COMPONENTS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \ + done + # runner modules and packages + @for component in $(COMPONENTS_RUNNERS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \ + done + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml contrib/ || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml scripts/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml tools/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml pylint_plugins/*.py || exit 1; + .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec @@ -979,7 +1036,7 @@ debs: ci: ci-checks ci-unit ci-integration ci-packs-tests .PHONY: ci-checks -ci-checks: .generated-files-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages +ci-checks: .generated-files-check .black-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages .PHONY: .rst-check .rst-check: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..4d03482994 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.black] +max-line-length = 100 +target_version = ['py36'] +include = '\.pyi?$' +exclude = ''' +( + /( + | \.git + | \.virtualenv + | __pycache__ + )/ +) +''' diff --git a/test-requirements.txt b/test-requirements.txt index 6ca0e9608d..b1909e4535 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,6 +5,7 @@ st2flake8==0.1.0 astroid==2.4.2 pylint==2.6.0 pylint-plugin-utils>=0.4 +black==20.8b1 bandit==1.5.1 ipython<6.0.0 isort>=4.2.5 From 8496bb2407b969f0937431992172b98b545f6756 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Wed, 17 Feb 2021 22:34:26 +0100 Subject: [PATCH 02/22] Reformat all the code using black tool. --- .../actions/format_execution_result.py | 43 +- contrib/chatops/actions/match.py | 17 +- contrib/chatops/actions/match_and_execute.py | 27 +- contrib/chatops/tests/test_format_result.py | 28 +- contrib/core/actions/generate_uuid.py | 8 +- contrib/core/actions/inject_trigger.py | 13 +- contrib/core/actions/pause.py | 4 +- .../core/tests/test_action_inject_trigger.py | 26 +- contrib/core/tests/test_action_sendmail.py | 265 +-- contrib/core/tests/test_action_uuid.py | 6 +- contrib/examples/actions/noop.py | 4 +- contrib/examples/actions/print_config.py | 4 +- .../actions/print_to_stdout_and_stderr.py | 6 +- .../actions/python-mock-core-remote.py | 13 +- .../examples/actions/python-mock-create-vm.py | 9 +- .../actions/pythonactions/fibonacci.py | 5 +- ...loop_increase_index_and_check_condition.py | 6 +- .../forloop_parse_github_repos.py | 4 +- .../examples/actions/pythonactions/isprime.py | 13 +- .../pythonactions/json_string_to_object.py | 1 - .../actions/pythonactions/object_return.py | 3 +- .../pythonactions/print_python_environment.py | 11 +- .../pythonactions/print_python_version.py | 5 +- .../pythonactions/yaml_string_to_object.py | 1 - .../ubuntu_pkg_info/lib/datatransformer.py | 8 +- .../ubuntu_pkg_info/ubuntu_pkg_info.py | 11 +- contrib/examples/sensors/echo_flask_app.py | 21 +- contrib/examples/sensors/fibonacci_sensor.py | 15 +- contrib/hello_st2/sensors/sensor1.py | 10 +- contrib/linux/actions/checks/check_loadavg.py | 22 +- .../linux/actions/checks/check_processes.py | 29 +- contrib/linux/actions/dig.py | 28 +- contrib/linux/actions/service.py | 23 +- contrib/linux/actions/wait_for_ssh.py | 40 +- contrib/linux/sensors/file_watch_sensor.py | 24 +- contrib/linux/tests/test_action_dig.py | 18 +- contrib/packs/actions/get_config.py | 4 +- contrib/packs/actions/pack_mgmt/delete.py | 19 +- contrib/packs/actions/pack_mgmt/download.py | 93 +- .../packs/actions/pack_mgmt/get_installed.py | 33 +- .../pack_mgmt/get_pack_dependencies.py | 51 +- .../actions/pack_mgmt/get_pack_warnings.py | 6 +- contrib/packs/actions/pack_mgmt/register.py | 77 +- contrib/packs/actions/pack_mgmt/search.py | 48 +- .../actions/pack_mgmt/setup_virtualenv.py | 71 +- .../packs/actions/pack_mgmt/show_remote.py | 5 +- contrib/packs/actions/pack_mgmt/unload.py | 73 +- .../pack_mgmt/virtualenv_setup_prerun.py | 2 +- contrib/packs/tests/test_action_aliases.py | 60 +- contrib/packs/tests/test_action_download.py | 535 +++--- contrib/packs/tests/test_action_unload.py | 19 +- .../packs/tests/test_get_pack_dependencies.py | 103 +- contrib/packs/tests/test_get_pack_warnings.py | 45 +- .../tests/test_virtualenv_setup_prerun.py | 23 +- .../action_chain_runner/__init__.py | 2 +- .../action_chain_runner.py | 570 +++--- .../runners/action_chain_runner/dist_utils.py | 65 +- contrib/runners/action_chain_runner/setup.py | 30 +- .../tests/unit/test_actionchain.py | 835 +++++---- .../tests/unit/test_actionchain_cancel.py | 164 +- .../unit/test_actionchain_notifications.py | 56 +- .../unit/test_actionchain_params_rendering.py | 100 +- .../unit/test_actionchain_pause_resume.py | 596 +++--- .../announcement_runner/__init__.py | 2 +- .../announcement_runner.py | 35 +- .../runners/announcement_runner/dist_utils.py | 65 +- contrib/runners/announcement_runner/setup.py | 28 +- .../tests/unit/test_announcementrunner.py | 64 +- contrib/runners/http_runner/dist_utils.py | 65 +- .../http_runner/http_runner/__init__.py | 2 +- .../http_runner/http_runner/http_runner.py | 199 +- contrib/runners/http_runner/setup.py | 28 +- .../tests/unit/test_http_runner.py | 337 ++-- contrib/runners/inquirer_runner/dist_utils.py | 65 +- .../inquirer_runner/__init__.py | 2 +- .../inquirer_runner/inquirer_runner.py | 47 +- contrib/runners/inquirer_runner/setup.py | 28 +- .../tests/unit/test_inquirer_runner.py | 96 +- contrib/runners/local_runner/dist_utils.py | 65 +- .../local_runner/local_runner/__init__.py | 2 +- .../runners/local_runner/local_runner/base.py | 169 +- .../local_shell_command_runner.py | 36 +- .../local_runner/local_shell_script_runner.py | 42 +- contrib/runners/local_runner/setup.py | 32 +- .../tests/integration/test_localrunner.py | 453 ++--- contrib/runners/noop_runner/dist_utils.py | 65 +- .../noop_runner/noop_runner/__init__.py | 2 +- .../noop_runner/noop_runner/noop_runner.py | 25 +- contrib/runners/noop_runner/setup.py | 26 +- .../noop_runner/tests/unit/test_nooprunner.py | 14 +- contrib/runners/orquesta_runner/dist_utils.py | 65 +- .../orquesta_functions/runtime.py | 44 +- .../orquesta_functions/st2kv.py | 20 +- .../orquesta_runner/__init__.py | 2 +- .../orquesta_runner/orquesta_runner.py | 132 +- contrib/runners/orquesta_runner/setup.py | 90 +- .../test_wiring_functions_st2kv.py | 63 +- .../orquesta_runner/tests/unit/base.py | 12 +- .../orquesta_runner/tests/unit/test_basic.py | 339 ++-- .../orquesta_runner/tests/unit/test_cancel.py | 182 +- .../tests/unit/test_context.py | 285 +-- .../tests/unit/test_data_flow.py | 92 +- .../orquesta_runner/tests/unit/test_delay.py | 114 +- .../tests/unit/test_error_handling.py | 779 ++++---- .../tests/unit/test_functions_common.py | 214 +-- .../tests/unit/test_functions_st2kv.py | 118 +- .../tests/unit/test_functions_task.py | 208 ++- .../tests/unit/test_inquiries.py | 365 ++-- .../orquesta_runner/tests/unit/test_notify.py | 228 +-- .../tests/unit/test_output_schema.py | 112 +- .../tests/unit/test_pause_and_resume.py | 660 ++++--- .../tests/unit/test_policies.py | 107 +- .../orquesta_runner/tests/unit/test_rerun.py | 398 ++-- .../tests/unit/test_with_items.py | 293 +-- contrib/runners/python_runner/dist_utils.py | 65 +- .../python_runner/python_runner/__init__.py | 2 +- .../python_runner/python_action_wrapper.py | 196 +- .../python_runner/python_runner.py | 226 ++- contrib/runners/python_runner/setup.py | 26 +- .../test_python_action_process_wrapper.py | 91 +- .../integration/test_pythonrunner_behavior.py | 43 +- .../tests/unit/test_output_schema.py | 35 +- .../tests/unit/test_pythonrunner.py | 841 +++++---- contrib/runners/remote_runner/dist_utils.py | 65 +- .../remote_runner/remote_runner/__init__.py | 2 +- .../remote_runner/remote_command_runner.py | 68 +- .../remote_runner/remote_script_runner.py | 154 +- contrib/runners/remote_runner/setup.py | 32 +- contrib/runners/winrm_runner/dist_utils.py | 65 +- contrib/runners/winrm_runner/setup.py | 34 +- .../tests/unit/test_winrm_base.py | 1180 ++++++------ .../tests/unit/test_winrm_command_runner.py | 15 +- .../unit/test_winrm_ps_command_runner.py | 15 +- .../tests/unit/test_winrm_ps_script_runner.py | 29 +- .../winrm_runner/winrm_runner/__init__.py | 2 +- .../winrm_runner/winrm_runner/winrm_base.py | 218 ++- .../winrm_runner/winrm_command_runner.py | 18 +- .../winrm_runner/winrm_ps_command_runner.py | 18 +- .../winrm_runner/winrm_ps_script_runner.py | 20 +- lint-configs/python/.flake8 | 5 +- pylint_plugins/api_models.py | 40 +- pylint_plugins/db_models.py | 13 +- scripts/dist_utils.py | 65 +- scripts/dist_utils_old.py | 47 +- scripts/fixate-requirements.py | 122 +- st2actions/dist_utils.py | 65 +- st2actions/setup.py | 28 +- st2actions/st2actions/__init__.py | 2 +- st2actions/st2actions/cmd/actionrunner.py | 32 +- st2actions/st2actions/cmd/scheduler.py | 41 +- st2actions/st2actions/cmd/st2notifier.py | 27 +- st2actions/st2actions/cmd/workflow_engine.py | 23 +- st2actions/st2actions/config.py | 7 +- st2actions/st2actions/container/base.py | 230 ++- st2actions/st2actions/notifier/config.py | 15 +- st2actions/st2actions/notifier/notifier.py | 220 ++- st2actions/st2actions/policies/concurrency.py | 59 +- .../policies/concurrency_by_attr.py | 84 +- st2actions/st2actions/policies/retry.py | 118 +- st2actions/st2actions/runners/pythonrunner.py | 4 +- st2actions/st2actions/scheduler/config.py | 57 +- st2actions/st2actions/scheduler/entrypoint.py | 37 +- st2actions/st2actions/scheduler/handler.py | 199 +- st2actions/st2actions/worker.py | 160 +- st2actions/st2actions/workflows/config.py | 15 +- st2actions/st2actions/workflows/workflows.py | 62 +- st2actions/tests/unit/policies/test_base.py | 62 +- .../tests/unit/policies/test_concurrency.py | 240 ++- .../unit/policies/test_concurrency_by_attr.py | 249 ++- .../tests/unit/policies/test_retry_policy.py | 120 +- .../tests/unit/test_action_runner_worker.py | 13 +- .../tests/unit/test_actions_registrar.py | 178 +- st2actions/tests/unit/test_async_runner.py | 16 +- .../tests/unit/test_execution_cancellation.py | 190 +- st2actions/tests/unit/test_executions.py | 125 +- st2actions/tests/unit/test_notifier.py | 429 +++-- st2actions/tests/unit/test_parallel_ssh.py | 404 +++-- .../test_paramiko_remote_script_runner.py | 293 +-- st2actions/tests/unit/test_paramiko_ssh.py | 975 ++++++---- .../tests/unit/test_paramiko_ssh_runner.py | 212 ++- st2actions/tests/unit/test_policies.py | 98 +- .../tests/unit/test_polling_async_runner.py | 16 +- st2actions/tests/unit/test_queue_consumers.py | 65 +- st2actions/tests/unit/test_remote_runners.py | 25 +- .../tests/unit/test_runner_container.py | 278 +-- st2actions/tests/unit/test_scheduler.py | 113 +- .../tests/unit/test_scheduler_entrypoint.py | 41 +- st2actions/tests/unit/test_scheduler_retry.py | 103 +- st2actions/tests/unit/test_worker.py | 79 +- st2actions/tests/unit/test_workflow_engine.py | 140 +- st2api/dist_utils.py | 65 +- st2api/setup.py | 22 +- st2api/st2api/__init__.py | 2 +- st2api/st2api/app.py | 53 +- st2api/st2api/cmd/__init__.py | 2 +- st2api/st2api/cmd/api.py | 35 +- st2api/st2api/config.py | 49 +- st2api/st2api/controllers/base.py | 10 +- .../controllers/controller_transforms.py | 8 +- st2api/st2api/controllers/resource.py | 473 +++-- st2api/st2api/controllers/root.py | 16 +- st2api/st2api/controllers/v1/action_views.py | 155 +- st2api/st2api/controllers/v1/actionalias.py | 202 ++- .../st2api/controllers/v1/actionexecutions.py | 694 ++++--- st2api/st2api/controllers/v1/actions.py | 212 ++- .../st2api/controllers/v1/aliasexecution.py | 198 +- st2api/st2api/controllers/v1/auth.py | 116 +- .../st2api/controllers/v1/execution_views.py | 46 +- st2api/st2api/controllers/v1/inquiries.py | 90 +- st2api/st2api/controllers/v1/keyvalue.py | 222 ++- .../controllers/v1/pack_config_schemas.py | 24 +- st2api/st2api/controllers/v1/pack_configs.py | 61 +- st2api/st2api/controllers/v1/pack_views.py | 98 +- st2api/st2api/controllers/v1/packs.py | 252 +-- st2api/st2api/controllers/v1/policies.py | 215 ++- st2api/st2api/controllers/v1/rbac.py | 90 +- .../controllers/v1/rule_enforcement_views.py | 93 +- .../controllers/v1/rule_enforcements.py | 69 +- st2api/st2api/controllers/v1/rule_views.py | 122 +- st2api/st2api/controllers/v1/rules.py | 198 +- st2api/st2api/controllers/v1/ruletypes.py | 25 +- st2api/st2api/controllers/v1/runnertypes.py | 85 +- st2api/st2api/controllers/v1/sensors.py | 74 +- .../st2api/controllers/v1/service_registry.py | 30 +- st2api/st2api/controllers/v1/timers.py | 52 +- st2api/st2api/controllers/v1/traces.py | 60 +- st2api/st2api/controllers/v1/triggers.py | 374 ++-- st2api/st2api/controllers/v1/user.py | 28 +- st2api/st2api/controllers/v1/webhooks.py | 99 +- .../controllers/v1/workflow_inspection.py | 11 +- st2api/st2api/validation.py | 25 +- st2api/st2api/wsgi.py | 8 +- .../integration/test_gunicorn_configs.py | 24 +- st2api/tests/unit/controllers/test_root.py | 10 +- .../unit/controllers/v1/test_action_alias.py | 199 +- .../unit/controllers/v1/test_action_views.py | 290 +-- .../tests/unit/controllers/v1/test_actions.py | 689 +++---- .../controllers/v1/test_alias_execution.py | 374 ++-- st2api/tests/unit/controllers/v1/test_auth.py | 186 +- .../unit/controllers/v1/test_auth_api_keys.py | 271 +-- st2api/tests/unit/controllers/v1/test_base.py | 90 +- .../unit/controllers/v1/test_executions.py | 1470 ++++++++------- .../controllers/v1/test_executions_auth.py | 232 +-- .../v1/test_executions_descendants.py | 65 +- .../controllers/v1/test_executions_filters.py | 330 ++-- .../unit/controllers/v1/test_inquiries.py | 235 ++- st2api/tests/unit/controllers/v1/test_kvps.py | 639 +++---- .../controllers/v1/test_pack_config_schema.py | 39 +- .../unit/controllers/v1/test_pack_configs.py | 92 +- .../tests/unit/controllers/v1/test_packs.py | 697 +++---- .../unit/controllers/v1/test_packs_views.py | 94 +- .../unit/controllers/v1/test_policies.py | 161 +- .../v1/test_rule_enforcement_views.py | 126 +- .../controllers/v1/test_rule_enforcements.py | 70 +- .../unit/controllers/v1/test_rule_views.py | 65 +- .../tests/unit/controllers/v1/test_rules.py | 442 +++-- .../unit/controllers/v1/test_ruletypes.py | 24 +- .../unit/controllers/v1/test_runnertypes.py | 67 +- .../unit/controllers/v1/test_sensortypes.py | 103 +- .../controllers/v1/test_service_registry.py | 56 +- .../tests/unit/controllers/v1/test_timers.py | 61 +- .../tests/unit/controllers/v1/test_traces.py | 180 +- .../controllers/v1/test_triggerinstances.py | 195 +- .../unit/controllers/v1/test_triggers.py | 99 +- .../unit/controllers/v1/test_triggertypes.py | 67 +- .../unit/controllers/v1/test_webhooks.py | 441 +++-- .../v1/test_workflow_inspection.py | 77 +- st2api/tests/unit/test_validation_utils.py | 48 +- st2auth/dist_utils.py | 65 +- st2auth/setup.py | 28 +- st2auth/st2auth/__init__.py | 2 +- st2auth/st2auth/app.py | 46 +- st2auth/st2auth/backends/__init__.py | 21 +- st2auth/st2auth/backends/base.py | 12 +- st2auth/st2auth/backends/constants.py | 10 +- st2auth/st2auth/cmd/api.py | 48 +- st2auth/st2auth/config.py | 85 +- st2auth/st2auth/controllers/v1/auth.py | 42 +- st2auth/st2auth/controllers/v1/sso.py | 43 +- st2auth/st2auth/handlers.py | 182 +- st2auth/st2auth/sso/__init__.py | 24 +- st2auth/st2auth/sso/base.py | 8 +- st2auth/st2auth/sso/noop.py | 6 +- st2auth/st2auth/validation.py | 20 +- st2auth/st2auth/wsgi.py | 8 +- st2auth/tests/base.py | 1 - st2auth/tests/unit/controllers/v1/test_sso.py | 104 +- .../tests/unit/controllers/v1/test_token.py | 206 ++- st2auth/tests/unit/test_auth_backends.py | 2 +- st2auth/tests/unit/test_handlers.py | 166 +- st2auth/tests/unit/test_validation_utils.py | 47 +- st2client/dist_utils.py | 65 +- st2client/setup.py | 64 +- st2client/st2client/__init__.py | 2 +- st2client/st2client/base.py | 228 ++- st2client/st2client/client.py | 310 ++-- st2client/st2client/commands/__init__.py | 27 +- st2client/st2client/commands/action.py | 1604 ++++++++++------- st2client/st2client/commands/action_alias.py | 164 +- st2client/st2client/commands/auth.py | 417 +++-- st2client/st2client/commands/inquiry.py | 134 +- st2client/st2client/commands/keyvalue.py | 410 +++-- st2client/st2client/commands/pack.py | 508 ++++-- st2client/st2client/commands/policy.py | 84 +- st2client/st2client/commands/rbac.py | 172 +- st2client/st2client/commands/resource.py | 480 +++-- st2client/st2client/commands/rule.py | 166 +- .../st2client/commands/rule_enforcement.py | 194 +- st2client/st2client/commands/sensor.py | 66 +- .../st2client/commands/service_registry.py | 67 +- st2client/st2client/commands/timer.py | 35 +- st2client/st2client/commands/trace.py | 349 ++-- st2client/st2client/commands/trigger.py | 90 +- .../st2client/commands/triggerinstance.py | 194 +- st2client/st2client/commands/webhook.py | 40 +- st2client/st2client/commands/workflow.py | 45 +- st2client/st2client/config.py | 5 +- st2client/st2client/config_parser.py | 130 +- st2client/st2client/exceptions/base.py | 5 +- st2client/st2client/formatters/__init__.py | 4 +- st2client/st2client/formatters/doc.py | 25 +- st2client/st2client/formatters/execution.py | 74 +- st2client/st2client/formatters/table.py | 118 +- st2client/st2client/models/__init__.py | 24 +- st2client/st2client/models/action.py | 36 +- st2client/st2client/models/action_alias.py | 29 +- st2client/st2client/models/aliasexecution.py | 25 +- st2client/st2client/models/auth.py | 16 +- st2client/st2client/models/config.py | 16 +- st2client/st2client/models/core.py | 318 ++-- st2client/st2client/models/inquiry.py | 17 +- st2client/st2client/models/keyvalue.py | 10 +- st2client/st2client/models/pack.py | 10 +- st2client/st2client/models/policy.py | 14 +- st2client/st2client/models/rbac.py | 29 +- st2client/st2client/models/reactor.py | 56 +- .../st2client/models/service_registry.py | 35 +- st2client/st2client/models/timer.py | 8 +- st2client/st2client/models/trace.py | 10 +- st2client/st2client/models/webhook.py | 10 +- st2client/st2client/shell.py | 348 ++-- st2client/st2client/utils/color.py | 62 +- st2client/st2client/utils/date.py | 11 +- st2client/st2client/utils/httpclient.py | 50 +- st2client/st2client/utils/interactive.py | 218 +-- st2client/st2client/utils/jsutil.py | 13 +- st2client/st2client/utils/logging.py | 6 +- st2client/st2client/utils/misc.py | 4 +- st2client/st2client/utils/schema.py | 28 +- st2client/st2client/utils/strutil.py | 12 +- st2client/st2client/utils/terminal.py | 33 +- st2client/st2client/utils/types.py | 15 +- st2client/tests/base.py | 36 +- st2client/tests/fixtures/loader.py | 15 +- st2client/tests/unit/test_action.py | 716 ++++---- st2client/tests/unit/test_action_alias.py | 28 +- st2client/tests/unit/test_app.py | 18 +- st2client/tests/unit/test_auth.py | 472 ++--- st2client/tests/unit/test_client.py | 133 +- st2client/tests/unit/test_client_actions.py | 55 +- .../tests/unit/test_client_executions.py | 236 ++- .../tests/unit/test_command_actionrun.py | 207 ++- st2client/tests/unit/test_commands.py | 354 ++-- st2client/tests/unit/test_config_parser.py | 114 +- .../tests/unit/test_execution_tail_command.py | 437 ++--- st2client/tests/unit/test_formatters.py | 287 +-- st2client/tests/unit/test_inquiry.py | 274 +-- st2client/tests/unit/test_interactive.py | 372 ++-- st2client/tests/unit/test_keyvalue.py | 326 ++-- st2client/tests/unit/test_models.py | 275 ++- st2client/tests/unit/test_shell.py | 659 ++++--- st2client/tests/unit/test_ssl.py | 99 +- st2client/tests/unit/test_trace_commands.py | 237 ++- st2client/tests/unit/test_util_date.py | 24 +- st2client/tests/unit/test_util_json.py | 151 +- st2client/tests/unit/test_util_misc.py | 30 +- st2client/tests/unit/test_util_strutil.py | 8 +- st2client/tests/unit/test_util_terminal.py | 28 +- st2client/tests/unit/test_workflow.py | 87 +- ...grate-datastore-to-include-scope-secret.py | 26 +- .../v2.1/st2-migrate-datastore-scopes.py | 26 +- .../v3.1/st2-cleanup-policy-delayed.py | 10 +- st2common/bin/paramiko_ssh_evenlets_tester.py | 72 +- st2common/dist_utils.py | 65 +- st2common/setup.py | 62 +- st2common/st2common/__init__.py | 2 +- .../st2common/bootstrap/actionsregistrar.py | 120 +- .../st2common/bootstrap/aliasesregistrar.py | 103 +- st2common/st2common/bootstrap/base.py | 79 +- .../st2common/bootstrap/configsregistrar.py | 74 +- .../st2common/bootstrap/policiesregistrar.py | 91 +- .../st2common/bootstrap/rulesregistrar.py | 112 +- .../st2common/bootstrap/ruletypesregistrar.py | 37 +- .../st2common/bootstrap/runnersregistrar.py | 38 +- .../st2common/bootstrap/sensorsregistrar.py | 103 +- .../st2common/bootstrap/triggersregistrar.py | 90 +- st2common/st2common/callback/base.py | 3 +- st2common/st2common/cmd/download_pack.py | 50 +- st2common/st2common/cmd/generate_api_spec.py | 8 +- st2common/st2common/cmd/install_pack.py | 60 +- st2common/st2common/cmd/purge_executions.py | 43 +- .../st2common/cmd/purge_trigger_instances.py | 19 +- .../st2common/cmd/setup_pack_virtualenv.py | 48 +- st2common/st2common/cmd/validate_api_spec.py | 41 +- st2common/st2common/cmd/validate_config.py | 38 +- st2common/st2common/config.py | 797 ++++---- st2common/st2common/constants/action.py | 132 +- st2common/st2common/constants/api.py | 8 +- st2common/st2common/constants/auth.py | 30 +- .../st2common/constants/error_messages.py | 23 +- st2common/st2common/constants/exit_codes.py | 8 +- .../st2common/constants/garbage_collection.py | 8 +- st2common/st2common/constants/keyvalue.py | 57 +- st2common/st2common/constants/logging.py | 6 +- st2common/st2common/constants/meta.py | 9 +- st2common/st2common/constants/pack.py | 69 +- st2common/st2common/constants/policy.py | 9 +- .../st2common/constants/rule_enforcement.py | 13 +- st2common/st2common/constants/rules.py | 10 +- st2common/st2common/constants/runners.py | 54 +- st2common/st2common/constants/scheduler.py | 9 +- st2common/st2common/constants/secrets.py | 19 +- st2common/st2common/constants/sensors.py | 8 +- st2common/st2common/constants/system.py | 17 +- st2common/st2common/constants/timer.py | 9 +- st2common/st2common/constants/trace.py | 6 +- st2common/st2common/constants/triggers.py | 479 ++--- st2common/st2common/constants/types.py | 50 +- st2common/st2common/content/bootstrap.py | 220 +-- st2common/st2common/content/loader.py | 69 +- st2common/st2common/content/utils.py | 115 +- st2common/st2common/content/validators.py | 15 +- st2common/st2common/database_setup.py | 40 +- st2common/st2common/exceptions/__init__.py | 26 +- st2common/st2common/exceptions/action.py | 6 +- st2common/st2common/exceptions/actionalias.py | 4 +- st2common/st2common/exceptions/api.py | 3 +- st2common/st2common/exceptions/auth.py | 26 +- st2common/st2common/exceptions/connection.py | 3 + st2common/st2common/exceptions/db.py | 7 +- st2common/st2common/exceptions/inquiry.py | 17 +- st2common/st2common/exceptions/keyvalue.py | 6 +- st2common/st2common/exceptions/rbac.py | 53 +- st2common/st2common/exceptions/ssh.py | 4 +- st2common/st2common/exceptions/workflow.py | 38 +- .../st2common/expressions/functions/data.py | 32 +- .../expressions/functions/datastore.py | 12 +- .../st2common/expressions/functions/path.py | 5 +- .../st2common/expressions/functions/regex.py | 7 +- .../st2common/expressions/functions/time.py | 29 +- .../expressions/functions/version.py | 14 +- st2common/st2common/fields.py | 15 +- .../garbage_collection/executions.py | 147 +- .../st2common/garbage_collection/inquiries.py | 25 +- .../garbage_collection/trigger_instances.py | 34 +- st2common/st2common/log.py | 93 +- st2common/st2common/logging/filters.py | 15 +- st2common/st2common/logging/formatters.py | 71 +- st2common/st2common/logging/handlers.py | 33 +- st2common/st2common/logging/misc.py | 46 +- st2common/st2common/metrics/base.py | 32 +- .../st2common/metrics/drivers/echo_driver.py | 16 +- .../st2common/metrics/drivers/noop_driver.py | 4 +- .../metrics/drivers/statsd_driver.py | 51 +- st2common/st2common/metrics/utils.py | 9 +- st2common/st2common/middleware/cors.py | 45 +- .../st2common/middleware/error_handling.py | 34 +- .../st2common/middleware/instrumentation.py | 47 +- st2common/st2common/middleware/logging.py | 52 +- st2common/st2common/middleware/streaming.py | 8 +- st2common/st2common/models/api/action.py | 505 +++--- .../st2common/models/api/actionrunner.py | 13 +- st2common/st2common/models/api/auth.py | 134 +- st2common/st2common/models/api/base.py | 31 +- st2common/st2common/models/api/execution.py | 132 +- st2common/st2common/models/api/inquiry.py | 131 +- st2common/st2common/models/api/keyvalue.py | 188 +- .../st2common/models/api/notification.py | 81 +- st2common/st2common/models/api/pack.py | 355 ++-- st2common/st2common/models/api/policy.py | 153 +- st2common/st2common/models/api/rbac.py | 377 ++-- st2common/st2common/models/api/rule.py | 258 ++- .../st2common/models/api/rule_enforcement.py | 113 +- st2common/st2common/models/api/sensor.py | 61 +- st2common/st2common/models/api/tag.py | 12 +- st2common/st2common/models/api/trace.py | 188 +- st2common/st2common/models/api/trigger.py | 196 +- st2common/st2common/models/api/webhook.py | 15 +- st2common/st2common/models/base.py | 4 +- st2common/st2common/models/db/__init__.py | 386 ++-- st2common/st2common/models/db/action.py | 80 +- st2common/st2common/models/db/actionalias.py | 65 +- st2common/st2common/models/db/auth.py | 55 +- st2common/st2common/models/db/execution.py | 121 +- .../st2common/models/db/execution_queue.py | 48 +- .../st2common/models/db/executionstate.py | 23 +- st2common/st2common/models/db/keyvalue.py | 20 +- st2common/st2common/models/db/liveaction.py | 67 +- st2common/st2common/models/db/marker.py | 14 +- st2common/st2common/models/db/notification.py | 30 +- st2common/st2common/models/db/pack.py | 36 +- st2common/st2common/models/db/policy.py | 93 +- st2common/st2common/models/db/rbac.py | 52 +- st2common/st2common/models/db/reactor.py | 19 +- st2common/st2common/models/db/rule.py | 78 +- .../st2common/models/db/rule_enforcement.py | 56 +- st2common/st2common/models/db/runner.py | 38 +- st2common/st2common/models/db/sensor.py | 33 +- st2common/st2common/models/db/stormbase.py | 90 +- st2common/st2common/models/db/timer.py | 4 +- st2common/st2common/models/db/trace.py | 70 +- st2common/st2common/models/db/trigger.py | 65 +- st2common/st2common/models/db/webhook.py | 4 +- st2common/st2common/models/db/workflow.py | 40 +- st2common/st2common/models/system/action.py | 387 ++-- .../st2common/models/system/actionchain.py | 93 +- st2common/st2common/models/system/common.py | 26 +- st2common/st2common/models/system/keyvalue.py | 6 +- .../models/system/paramiko_command_action.py | 26 +- .../models/system/paramiko_script_action.py | 52 +- .../models/utils/action_alias_utils.py | 137 +- .../models/utils/action_param_utils.py | 65 +- st2common/st2common/models/utils/profiling.py | 75 +- .../models/utils/sensor_type_utils.py | 100 +- st2common/st2common/operators.py | 162 +- st2common/st2common/persistence/action.py | 12 +- st2common/st2common/persistence/auth.py | 34 +- st2common/st2common/persistence/base.py | 104 +- st2common/st2common/persistence/cleanup.py | 57 +- st2common/st2common/persistence/db_init.py | 58 +- st2common/st2common/persistence/execution.py | 4 +- .../st2common/persistence/execution_queue.py | 4 +- .../st2common/persistence/executionstate.py | 8 +- st2common/st2common/persistence/keyvalue.py | 70 +- st2common/st2common/persistence/liveaction.py | 4 +- st2common/st2common/persistence/marker.py | 4 +- st2common/st2common/persistence/pack.py | 6 +- st2common/st2common/persistence/policy.py | 12 +- st2common/st2common/persistence/rbac.py | 7 +- st2common/st2common/persistence/reactor.py | 10 +- st2common/st2common/persistence/rule.py | 2 +- st2common/st2common/persistence/runner.py | 2 +- st2common/st2common/persistence/sensor.py | 4 +- st2common/st2common/persistence/trace.py | 10 +- st2common/st2common/persistence/trigger.py | 20 +- st2common/st2common/persistence/workflow.py | 5 +- st2common/st2common/policies/__init__.py | 5 +- st2common/st2common/policies/base.py | 13 +- st2common/st2common/policies/concurrency.py | 15 +- st2common/st2common/rbac/backends/__init__.py | 12 +- st2common/st2common/rbac/backends/base.py | 20 +- st2common/st2common/rbac/backends/noop.py | 16 +- st2common/st2common/rbac/migrations.py | 8 +- st2common/st2common/rbac/types.py | 598 +++--- st2common/st2common/router.py | 481 +++-- st2common/st2common/runners/__init__.py | 9 +- st2common/st2common/runners/base.py | 226 +-- st2common/st2common/runners/base_action.py | 19 +- st2common/st2common/runners/parallel_ssh.py | 229 ++- st2common/st2common/runners/paramiko_ssh.py | 326 ++-- .../st2common/runners/paramiko_ssh_runner.py | 158 +- st2common/st2common/runners/utils.py | 75 +- st2common/st2common/script_setup.py | 22 +- st2common/st2common/service_setup.py | 85 +- st2common/st2common/services/access.py | 32 +- st2common/st2common/services/action.py | 213 ++- st2common/st2common/services/config.py | 6 +- st2common/st2common/services/coordination.py | 57 +- st2common/st2common/services/datastore.py | 61 +- st2common/st2common/services/executions.py | 141 +- st2common/st2common/services/inquiry.py | 36 +- st2common/st2common/services/keyvalues.py | 69 +- st2common/st2common/services/packs.py | 98 +- st2common/st2common/services/policies.py | 43 +- st2common/st2common/services/queries.py | 8 +- st2common/st2common/services/rules.py | 15 +- .../st2common/services/sensor_watcher.py | 57 +- st2common/st2common/services/trace.py | 161 +- .../st2common/services/trigger_dispatcher.py | 54 +- st2common/st2common/services/triggers.py | 293 +-- .../st2common/services/triggerwatcher.py | 73 +- st2common/st2common/services/workflows.py | 748 ++++---- st2common/st2common/signal_handlers.py | 2 +- st2common/st2common/stream/listener.py | 124 +- st2common/st2common/transport/__init__.py | 16 +- .../transport/actionexecutionstate.py | 12 +- st2common/st2common/transport/announcement.py | 30 +- st2common/st2common/transport/bootstrap.py | 7 +- .../st2common/transport/bootstrap_utils.py | 93 +- .../transport/connection_retry_wrapper.py | 41 +- st2common/st2common/transport/consumers.py | 76 +- st2common/st2common/transport/execution.py | 35 +- st2common/st2common/transport/liveaction.py | 16 +- st2common/st2common/transport/publishers.py | 57 +- st2common/st2common/transport/queues.py | 109 +- st2common/st2common/transport/reactor.py | 32 +- st2common/st2common/transport/utils.py | 68 +- st2common/st2common/transport/workflow.py | 26 +- st2common/st2common/triggers.py | 68 +- st2common/st2common/util/action_db.py | 216 ++- .../st2common/util/actionalias_helpstring.py | 24 +- .../st2common/util/actionalias_matching.py | 113 +- st2common/st2common/util/api.py | 8 +- st2common/st2common/util/argument_parser.py | 40 +- st2common/st2common/util/auth.py | 33 +- st2common/st2common/util/casts.py | 12 +- st2common/st2common/util/compat.py | 13 +- st2common/st2common/util/concurrency.py | 102 +- st2common/st2common/util/config_loader.py | 83 +- st2common/st2common/util/config_parser.py | 14 +- st2common/st2common/util/crypto.py | 156 +- st2common/st2common/util/date.py | 11 +- st2common/st2common/util/debugging.py | 6 +- st2common/st2common/util/deprecation.py | 8 +- st2common/st2common/util/driver_loader.py | 17 +- st2common/st2common/util/enum.py | 11 +- st2common/st2common/util/file_system.py | 11 +- st2common/st2common/util/green/shell.py | 103 +- st2common/st2common/util/greenpooldispatch.py | 39 +- st2common/st2common/util/gunicorn_workers.py | 6 +- st2common/st2common/util/hash.py | 6 +- st2common/st2common/util/http.py | 27 +- st2common/st2common/util/ip_utils.py | 30 +- st2common/st2common/util/isotime.py | 30 +- st2common/st2common/util/jinja.py | 93 +- st2common/st2common/util/jsonify.py | 26 +- st2common/st2common/util/keyvalue.py | 43 +- st2common/st2common/util/loader.py | 71 +- st2common/st2common/util/misc.py | 48 +- st2common/st2common/util/mongoescape.py | 19 +- st2common/st2common/util/monkey_patch.py | 24 +- st2common/st2common/util/output_schema.py | 33 +- st2common/st2common/util/pack.py | 128 +- st2common/st2common/util/pack_management.py | 267 +-- st2common/st2common/util/param.py | 193 +- st2common/st2common/util/payload.py | 5 +- st2common/st2common/util/queues.py | 8 +- st2common/st2common/util/reference.py | 25 +- st2common/st2common/util/sandboxing.py | 86 +- st2common/st2common/util/schema/__init__.py | 327 ++-- st2common/st2common/util/secrets.py | 38 +- st2common/st2common/util/service.py | 4 +- st2common/st2common/util/shell.py | 44 +- st2common/st2common/util/spec_loader.py | 39 +- st2common/st2common/util/system_info.py | 14 +- st2common/st2common/util/templating.py | 12 +- st2common/st2common/util/types.py | 15 +- st2common/st2common/util/uid.py | 8 +- st2common/st2common/util/ujson.py | 4 +- st2common/st2common/util/url.py | 6 +- st2common/st2common/util/versioning.py | 17 +- st2common/st2common/util/virtualenvs.py | 188 +- st2common/st2common/util/wsgi.py | 8 +- st2common/st2common/validators/api/action.py | 75 +- st2common/st2common/validators/api/misc.py | 8 +- st2common/st2common/validators/api/reactor.py | 93 +- .../st2common/validators/workflow/base.py | 1 - .../tests/fixtures/mock_runner/mock_runner.py | 11 +- st2common/tests/fixtures/version_file.py | 2 +- .../integration/test_rabbitmq_ssl_listener.py | 183 +- .../test_register_content_script.py | 112 +- .../test_service_setup_log_level_filtering.py | 83 +- st2common/tests/unit/base.py | 39 +- st2common/tests/unit/services/test_access.py | 26 +- st2common/tests/unit/services/test_action.py | 384 ++-- .../tests/unit/services/test_keyvalue.py | 23 +- st2common/tests/unit/services/test_policy.py | 51 +- .../unit/services/test_synchronization.py | 8 +- st2common/tests/unit/services/test_trace.py | 635 ++++--- .../test_trace_injection_action_services.py | 46 +- .../tests/unit/services/test_workflow.py | 263 +-- .../services/test_workflow_cancellation.py | 52 +- .../test_workflow_identify_orphans.py | 189 +- .../unit/services/test_workflow_rerun.py | 214 ++- .../services/test_workflow_service_retries.py | 205 ++- .../tests/unit/test_action_alias_utils.py | 267 +-- .../tests/unit/test_action_api_validator.py | 108 +- st2common/tests/unit/test_action_db_utils.py | 506 +++--- .../tests/unit/test_action_param_utils.py | 106 +- .../tests/unit/test_action_system_models.py | 81 +- .../tests/unit/test_actionchain_schema.py | 44 +- st2common/tests/unit/test_aliasesregistrar.py | 14 +- .../tests/unit/test_api_model_validation.py | 245 +-- st2common/tests/unit/test_casts.py | 16 +- st2common/tests/unit/test_config_loader.py | 503 +++--- st2common/tests/unit/test_config_parser.py | 14 +- .../tests/unit/test_configs_registrar.py | 186 +- .../unit/test_connection_retry_wrapper.py | 13 +- st2common/tests/unit/test_content_loader.py | 67 +- st2common/tests/unit/test_content_utils.py | 277 +-- st2common/tests/unit/test_crypto_utils.py | 166 +- st2common/tests/unit/test_datastore.py | 103 +- st2common/tests/unit/test_date_utils.py | 30 +- st2common/tests/unit/test_db.py | 682 ++++--- st2common/tests/unit/test_db_action_state.py | 6 +- st2common/tests/unit/test_db_auth.py | 35 +- st2common/tests/unit/test_db_base.py | 59 +- st2common/tests/unit/test_db_execution.py | 144 +- st2common/tests/unit/test_db_fields.py | 10 +- st2common/tests/unit/test_db_liveaction.py | 89 +- st2common/tests/unit/test_db_marker.py | 11 +- st2common/tests/unit/test_db_model_uids.py | 78 +- st2common/tests/unit/test_db_pack.py | 24 +- st2common/tests/unit/test_db_policy.py | 203 ++- st2common/tests/unit/test_db_rbac.py | 46 +- .../tests/unit/test_db_rule_enforcement.py | 85 +- st2common/tests/unit/test_db_task.py | 55 +- st2common/tests/unit/test_db_trace.py | 149 +- st2common/tests/unit/test_db_uid_mixin.py | 41 +- st2common/tests/unit/test_db_workflow.py | 35 +- st2common/tests/unit/test_dist_utils.py | 94 +- .../tests/unit/test_exceptions_workflow.py | 5 +- st2common/tests/unit/test_executions.py | 269 +-- st2common/tests/unit/test_executions_util.py | 254 +-- .../tests/unit/test_greenpooldispatch.py | 17 +- st2common/tests/unit/test_hash.py | 5 +- st2common/tests/unit/test_ip_utils.py | 43 +- st2common/tests/unit/test_isotime_utils.py | 102 +- .../unit/test_jinja_render_crypto_filters.py | 103 +- .../unit/test_jinja_render_data_filters.py | 59 +- .../test_jinja_render_json_escape_filters.py | 41 +- ...est_jinja_render_jsonpath_query_filters.py | 59 +- .../unit/test_jinja_render_path_filters.py | 29 +- .../unit/test_jinja_render_regex_filters.py | 47 +- .../unit/test_jinja_render_time_filters.py | 16 +- .../unit/test_jinja_render_version_filters.py | 97 +- st2common/tests/unit/test_json_schema.py | 478 ++--- st2common/tests/unit/test_jsonify.py | 19 +- st2common/tests/unit/test_keyvalue_lookup.py | 151 +- .../tests/unit/test_keyvalue_system_model.py | 22 +- st2common/tests/unit/test_logger.py | 382 ++-- st2common/tests/unit/test_logging.py | 22 +- .../tests/unit/test_logging_middleware.py | 56 +- st2common/tests/unit/test_metrics.py | 214 ++- st2common/tests/unit/test_misc_utils.py | 118 +- .../tests/unit/test_model_utils_profiling.py | 22 +- st2common/tests/unit/test_mongoescape.py | 90 +- .../tests/unit/test_notification_helper.py | 133 +- st2common/tests/unit/test_operators.py | 1329 +++++++------- ...st_pack_action_alias_unit_testing_utils.py | 152 +- st2common/tests/unit/test_pack_management.py | 24 +- st2common/tests/unit/test_param_utils.py | 1100 ++++++----- .../test_paramiko_command_action_model.py | 107 +- .../unit/test_paramiko_script_action_model.py | 166 +- st2common/tests/unit/test_persistence.py | 112 +- .../unit/test_persistence_change_revision.py | 17 +- st2common/tests/unit/test_plugin_loader.py | 65 +- st2common/tests/unit/test_policies.py | 46 +- .../tests/unit/test_policies_registrar.py | 109 +- st2common/tests/unit/test_purge_executions.py | 287 +-- .../unit/test_purge_trigger_instances.py | 43 +- st2common/tests/unit/test_queue_consumer.py | 31 +- st2common/tests/unit/test_queue_utils.py | 53 +- st2common/tests/unit/test_rbac_types.py | 400 ++-- st2common/tests/unit/test_reference.py | 33 +- .../unit/test_register_internal_trigger.py | 5 +- .../tests/unit/test_resource_reference.py | 71 +- .../tests/unit/test_resource_registrar.py | 182 +- st2common/tests/unit/test_runners_base.py | 11 +- st2common/tests/unit/test_runners_utils.py | 28 +- .../tests/unit/test_sensor_type_utils.py | 60 +- st2common/tests/unit/test_sensor_watcher.py | 37 +- st2common/tests/unit/test_service_setup.py | 118 +- .../unit/test_shell_action_system_model.py | 494 ++--- st2common/tests/unit/test_state_publisher.py | 17 +- st2common/tests/unit/test_stream_generator.py | 21 +- st2common/tests/unit/test_system_info.py | 5 +- st2common/tests/unit/test_tags.py | 54 +- .../tests/unit/test_time_jinja_filters.py | 22 +- st2common/tests/unit/test_transport.py | 66 +- st2common/tests/unit/test_trigger_services.py | 247 +-- .../tests/unit/test_triggers_registrar.py | 22 +- .../tests/unit/test_unit_testing_mocks.py | 71 +- .../unit/test_util_actionalias_helpstrings.py | 161 +- .../unit/test_util_actionalias_matching.py | 142 +- st2common/tests/unit/test_util_api.py | 41 +- st2common/tests/unit/test_util_compat.py | 12 +- st2common/tests/unit/test_util_db.py | 81 +- st2common/tests/unit/test_util_file_system.py | 30 +- st2common/tests/unit/test_util_http.py | 20 +- st2common/tests/unit/test_util_jinja.py | 106 +- st2common/tests/unit/test_util_keyvalue.py | 108 +- .../tests/unit/test_util_output_schema.py | 60 +- st2common/tests/unit/test_util_pack.py | 44 +- st2common/tests/unit/test_util_payload.py | 28 +- st2common/tests/unit/test_util_sandboxing.py | 213 ++- st2common/tests/unit/test_util_secrets.py | 1140 +++++------- st2common/tests/unit/test_util_shell.py | 32 +- st2common/tests/unit/test_util_templating.py | 36 +- st2common/tests/unit/test_util_types.py | 4 +- st2common/tests/unit/test_util_url.py | 16 +- st2common/tests/unit/test_versioning_utils.py | 54 +- st2common/tests/unit/test_virtualenvs.py | 341 ++-- st2exporter/dist_utils.py | 65 +- st2exporter/setup.py | 22 +- .../st2exporter/cmd/st2exporter_starter.py | 20 +- st2exporter/st2exporter/config.py | 23 +- st2exporter/st2exporter/exporter/dumper.py | 67 +- .../st2exporter/exporter/file_writer.py | 12 +- .../st2exporter/exporter/json_converter.py | 7 +- st2exporter/st2exporter/worker.py | 52 +- .../integration/test_dumper_integration.py | 63 +- .../tests/integration/test_export_worker.py | 55 +- st2exporter/tests/unit/test_dumper.py | 133 +- st2exporter/tests/unit/test_json_converter.py | 29 +- st2reactor/dist_utils.py | 65 +- st2reactor/setup.py | 32 +- st2reactor/st2reactor/__init__.py | 2 +- st2reactor/st2reactor/cmd/garbagecollector.py | 31 +- st2reactor/st2reactor/cmd/rule_tester.py | 40 +- st2reactor/st2reactor/cmd/rulesengine.py | 28 +- st2reactor/st2reactor/cmd/sensormanager.py | 42 +- st2reactor/st2reactor/cmd/timersengine.py | 29 +- st2reactor/st2reactor/cmd/trigger_re_fire.py | 58 +- .../st2reactor/container/hash_partitioner.py | 31 +- st2reactor/st2reactor/container/manager.py | 90 +- .../container/partitioner_lookup.py | 35 +- .../st2reactor/container/partitioners.py | 49 +- .../st2reactor/container/process_container.py | 193 +- .../st2reactor/container/sensor_wrapper.py | 237 ++- st2reactor/st2reactor/container/utils.py | 10 +- .../st2reactor/garbage_collector/base.py | 143 +- .../st2reactor/garbage_collector/config.py | 65 +- st2reactor/st2reactor/rules/config.py | 15 +- st2reactor/st2reactor/rules/enforcer.py | 146 +- st2reactor/st2reactor/rules/engine.py | 42 +- st2reactor/st2reactor/rules/filter.py | 157 +- st2reactor/st2reactor/rules/matcher.py | 57 +- st2reactor/st2reactor/rules/tester.py | 103 +- st2reactor/st2reactor/rules/worker.py | 46 +- st2reactor/st2reactor/sensor/base.py | 9 +- st2reactor/st2reactor/sensor/config.py | 63 +- st2reactor/st2reactor/timer/base.py | 111 +- st2reactor/st2reactor/timer/config.py | 7 +- .../integration/test_garbage_collector.py | 218 ++- .../tests/integration/test_rules_engine.py | 56 +- .../integration/test_sensor_container.py | 90 +- .../tests/integration/test_sensor_watcher.py | 27 +- st2reactor/tests/unit/test_container_utils.py | 63 +- st2reactor/tests/unit/test_enforce.py | 535 +++--- st2reactor/tests/unit/test_filter.py | 325 ++-- .../tests/unit/test_garbage_collector.py | 58 +- .../tests/unit/test_hash_partitioner.py | 53 +- st2reactor/tests/unit/test_partitioners.py | 82 +- .../tests/unit/test_process_container.py | 170 +- st2reactor/tests/unit/test_rule_engine.py | 208 +-- st2reactor/tests/unit/test_rule_matcher.py | 310 ++-- .../unit/test_sensor_and_rule_registration.py | 81 +- st2reactor/tests/unit/test_sensor_service.py | 203 ++- st2reactor/tests/unit/test_sensor_wrapper.py | 172 +- st2reactor/tests/unit/test_tester.py | 90 +- st2reactor/tests/unit/test_timer.py | 26 +- st2stream/dist_utils.py | 65 +- st2stream/setup.py | 22 +- st2stream/st2stream/__init__.py | 2 +- st2stream/st2stream/app.py | 47 +- st2stream/st2stream/cmd/__init__.py | 2 +- st2stream/st2stream/cmd/api.py | 49 +- st2stream/st2stream/config.py | 27 +- .../st2stream/controllers/v1/executions.py | 78 +- st2stream/st2stream/controllers/v1/root.py | 4 +- st2stream/st2stream/controllers/v1/stream.py | 70 +- st2stream/st2stream/signal_handlers.py | 4 +- st2stream/st2stream/wsgi.py | 7 +- st2stream/tests/unit/controllers/v1/base.py | 4 +- .../tests/unit/controllers/v1/test_stream.py | 202 +-- .../v1/test_stream_execution_output.py | 138 +- st2tests/dist_utils.py | 65 +- st2tests/integration/orquesta/base.py | 55 +- .../integration/orquesta/test_performance.py | 23 +- st2tests/integration/orquesta/test_wiring.py | 112 +- .../orquesta/test_wiring_cancel.py | 56 +- .../orquesta/test_wiring_data_flow.py | 42 +- .../integration/orquesta/test_wiring_delay.py | 25 +- .../orquesta/test_wiring_error_handling.py | 349 ++-- .../orquesta/test_wiring_functions.py | 211 +-- .../orquesta/test_wiring_functions_st2kv.py | 150 +- .../orquesta/test_wiring_functions_task.py | 81 +- .../orquesta/test_wiring_inquiry.py | 57 +- .../orquesta/test_wiring_pause_and_resume.py | 122 +- .../integration/orquesta/test_wiring_rerun.py | 90 +- .../orquesta/test_wiring_task_retry.py | 22 +- .../orquesta/test_wiring_with_items.py | 89 +- st2tests/setup.py | 20 +- st2tests/st2tests/__init__.py | 12 +- st2tests/st2tests/action_aliases.py | 45 +- st2tests/st2tests/actions.py | 14 +- st2tests/st2tests/api.py | 173 +- st2tests/st2tests/base.py | 277 +-- st2tests/st2tests/config.py | 408 +++-- .../fixtures/history_views/__init__.py | 4 +- .../localrunner_pack/actions/text_gen.py | 8 +- .../actions/render_config_context.py | 1 - .../dummy_pack_9/actions/invalid_syntax.py | 4 +- .../fixtures/packs/executions/__init__.py | 8 +- .../test_async_runner/test_async_runner.py | 16 +- .../test_polling_async_runner.py | 16 +- .../actions/get_library_path.py | 4 +- st2tests/st2tests/fixturesloader.py | 234 ++- st2tests/st2tests/http.py | 1 - st2tests/st2tests/mocks/action.py | 13 +- st2tests/st2tests/mocks/auth.py | 18 +- st2tests/st2tests/mocks/datastore.py | 29 +- st2tests/st2tests/mocks/execution.py | 6 +- st2tests/st2tests/mocks/liveaction.py | 7 +- .../st2tests/mocks/runners/async_runner.py | 16 +- .../mocks/runners/polling_async_runner.py | 16 +- st2tests/st2tests/mocks/runners/runner.py | 30 +- st2tests/st2tests/mocks/sensor.py | 24 +- st2tests/st2tests/mocks/workflow.py | 5 +- st2tests/st2tests/pack_resource.py | 10 +- st2tests/st2tests/policies/concurrency.py | 12 +- st2tests/st2tests/policies/mock_exception.py | 3 +- .../packs/pythonactions/actions/echoer.py | 2 +- .../pythonactions/actions/non_simple_type.py | 8 +- .../packs/pythonactions/actions/pascal_row.py | 44 +- .../actions/print_config_item_doesnt_exist.py | 4 +- .../actions/print_to_stdout_and_stderr.py | 4 +- .../pythonactions/actions/python_paths.py | 4 +- .../packs/pythonactions/actions/test.py | 2 +- st2tests/st2tests/sensors.py | 20 +- .../checks/actions/checks/check_loadavg.py | 28 +- tools/config_gen.py | 130 +- tools/diff-db-disk.py | 200 +- tools/direct_queue_publisher.py | 25 +- tools/enumerate-runners.py | 9 +- tools/json2yaml.py | 39 +- tools/list_group_members.py | 26 +- tools/log_watcher.py | 90 +- tools/migrate_messaging_setup.py | 14 +- tools/migrate_rules_to_include_pack.py | 49 +- .../migrate_triggers_to_include_ref_count.py | 7 +- tools/queue_consumer.py | 39 +- tools/queue_producer.py | 20 +- tools/st2-analyze-links.py | 83 +- tools/st2-inject-trigger-instances.py | 90 +- tools/visualize_action_chain.py | 101 +- 937 files changed, 54139 insertions(+), 42097 deletions(-) diff --git a/contrib/chatops/actions/format_execution_result.py b/contrib/chatops/actions/format_execution_result.py index 8790ae4ae7..d6830df004 100755 --- a/contrib/chatops/actions/format_execution_result.py +++ b/contrib/chatops/actions/format_execution_result.py @@ -23,51 +23,50 @@ class FormatResultAction(Action): def __init__(self, config=None, action_service=None): - super(FormatResultAction, self).__init__(config=config, action_service=action_service) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + super(FormatResultAction, self).__init__( + config=config, action_service=action_service + ) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) self.jinja = jinja_utils.get_jinja_environment(allow_undefined=True) - self.jinja.tests['in'] = lambda item, list: item in list + self.jinja.tests["in"] = lambda item, list: item in list path = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(path, 'templates/default.j2'), 'r') as f: + with open(os.path.join(path, "templates/default.j2"), "r") as f: self.default_template = f.read() def run(self, execution_id): execution = self._get_execution(execution_id) - context = { - 'six': six, - 'execution': execution - } + context = {"six": six, "execution": execution} template = self.default_template result = {"enabled": True} - alias_id = execution['context'].get('action_alias_ref', {}).get('id', None) + alias_id = execution["context"].get("action_alias_ref", {}).get("id", None) if alias_id: - alias = self.client.managers['ActionAlias'].get_by_id(alias_id) + alias = self.client.managers["ActionAlias"].get_by_id(alias_id) - context.update({ - 'alias': alias - }) + context.update({"alias": alias}) - result_params = getattr(alias, 'result', None) + result_params = getattr(alias, "result", None) if result_params: - if not result_params.get('enabled', True): + if not result_params.get("enabled", True): result["enabled"] = False else: - if 'format' in alias.result: - template = alias.result['format'] - if 'extra' in alias.result: - result['extra'] = jinja_utils.render_values(alias.result['extra'], context) + if "format" in alias.result: + template = alias.result["format"] + if "extra" in alias.result: + result["extra"] = jinja_utils.render_values( + alias.result["extra"], context + ) - result['message'] = self.jinja.from_string(template).render(context) + result["message"] = self.jinja.from_string(template).render(context) return result def _get_execution(self, execution_id): if not execution_id: - raise ValueError('Invalid execution_id provided.') + raise ValueError("Invalid execution_id provided.") execution = self.client.liveactions.get_by_id(id=execution_id) if not execution: return None diff --git a/contrib/chatops/actions/match.py b/contrib/chatops/actions/match.py index 46dac1ff64..7ee2154b42 100644 --- a/contrib/chatops/actions/match.py +++ b/contrib/chatops/actions/match.py @@ -23,23 +23,16 @@ class MatchAction(Action): def __init__(self, config=None): super(MatchAction, self).__init__(config=config) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) def run(self, text): alias_match = ActionAliasMatch() alias_match.command = text - matches = self.client.managers['ActionAlias'].match(alias_match) - return { - 'alias': _format_match(matches[0]), - 'representation': matches[1] - } + matches = self.client.managers["ActionAlias"].match(alias_match) + return {"alias": _format_match(matches[0]), "representation": matches[1]} def _format_match(match): - return { - 'name': match.name, - 'pack': match.pack, - 'action_ref': match.action_ref - } + return {"name": match.name, "pack": match.pack, "action_ref": match.action_ref} diff --git a/contrib/chatops/actions/match_and_execute.py b/contrib/chatops/actions/match_and_execute.py index 11388e599b..5e90080f03 100644 --- a/contrib/chatops/actions/match_and_execute.py +++ b/contrib/chatops/actions/match_and_execute.py @@ -19,25 +19,26 @@ from st2common.runners.base_action import Action from st2client.models.action_alias import ActionAliasMatch from st2client.models.aliasexecution import ActionAliasExecution -from st2client.commands.action import (LIVEACTION_STATUS_REQUESTED, - LIVEACTION_STATUS_SCHEDULED, - LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING) +from st2client.commands.action import ( + LIVEACTION_STATUS_REQUESTED, + LIVEACTION_STATUS_SCHEDULED, + LIVEACTION_STATUS_RUNNING, + LIVEACTION_STATUS_CANCELING, +) from st2client.client import Client class ExecuteActionAliasAction(Action): def __init__(self, config=None): super(ExecuteActionAliasAction, self).__init__(config=config) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) def run(self, text, source_channel=None, user=None): alias_match = ActionAliasMatch() alias_match.command = text - alias, representation = self.client.managers['ActionAlias'].match( - alias_match) + alias, representation = self.client.managers["ActionAlias"].match(alias_match) execution = ActionAliasExecution() execution.name = alias.name @@ -48,20 +49,20 @@ def run(self, text, source_channel=None, user=None): execution.notification_route = None execution.user = user - action_exec_mgr = self.client.managers['ActionAliasExecution'] + action_exec_mgr = self.client.managers["ActionAliasExecution"] execution = action_exec_mgr.create(execution) - self._wait_execution_to_finish(execution.execution['id']) - return execution.execution['id'] + self._wait_execution_to_finish(execution.execution["id"]) + return execution.execution["id"] def _wait_execution_to_finish(self, execution_id): pending_statuses = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING + LIVEACTION_STATUS_CANCELING, ] - action_exec_mgr = self.client.managers['LiveAction'] + action_exec_mgr = self.client.managers["LiveAction"] execution = action_exec_mgr.get_by_id(execution_id) while execution.status in pending_statuses: time.sleep(1) diff --git a/contrib/chatops/tests/test_format_result.py b/contrib/chatops/tests/test_format_result.py index e700af7454..05114cb361 100644 --- a/contrib/chatops/tests/test_format_result.py +++ b/contrib/chatops/tests/test_format_result.py @@ -20,9 +20,7 @@ from format_execution_result import FormatResultAction -__all__ = [ - 'FormatResultActionTestCase' -] +__all__ = ["FormatResultActionTestCase"] class FormatResultActionTestCase(BaseActionTestCase): @@ -30,47 +28,45 @@ class FormatResultActionTestCase(BaseActionTestCase): def test_rendering_works_remote_shell_cmd(self): remote_shell_cmd_execution_model = json.loads( - self.get_fixture_content('remote_cmd_execution.json') + self.get_fixture_content("remote_cmd_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=remote_shell_cmd_execution_model ) - result = action.run(execution_id='57967f9355fc8c19a96d9e4f') + result = action.run(execution_id="57967f9355fc8c19a96d9e4f") self.assertTrue(result) - self.assertIn('web_url', result['message']) - self.assertIn('Took 2s to complete', result['message']) + self.assertIn("web_url", result["message"]) + self.assertIn("Took 2s to complete", result["message"]) def test_rendering_local_shell_cmd(self): local_shell_cmd_execution_model = json.loads( - self.get_fixture_content('local_cmd_execution.json') + self.get_fixture_content("local_cmd_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=local_shell_cmd_execution_model ) - self.assertTrue(action.run(execution_id='5799522f55fc8c2d33ac03e0')) + self.assertTrue(action.run(execution_id="5799522f55fc8c2d33ac03e0")) def test_rendering_http_request(self): http_execution_model = json.loads( - self.get_fixture_content('http_execution.json') + self.get_fixture_content("http_execution.json") ) action = self.get_action_instance() - action._get_execution = mock.MagicMock( - return_value=http_execution_model - ) - self.assertTrue(action.run(execution_id='579955f055fc8c2d33ac03e3')) + action._get_execution = mock.MagicMock(return_value=http_execution_model) + self.assertTrue(action.run(execution_id="579955f055fc8c2d33ac03e3")) def test_rendering_python_action(self): python_action_execution_model = json.loads( - self.get_fixture_content('python_action_execution.json') + self.get_fixture_content("python_action_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=python_action_execution_model ) - self.assertTrue(action.run(execution_id='5799572a55fc8c2d33ac03ec')) + self.assertTrue(action.run(execution_id="5799572a55fc8c2d33ac03ec")) diff --git a/contrib/core/actions/generate_uuid.py b/contrib/core/actions/generate_uuid.py index 972b7cb552..88d8125549 100644 --- a/contrib/core/actions/generate_uuid.py +++ b/contrib/core/actions/generate_uuid.py @@ -18,16 +18,14 @@ from st2common.runners.base_action import Action -__all__ = [ - 'GenerateUUID' -] +__all__ = ["GenerateUUID"] class GenerateUUID(Action): def run(self, uuid_type): - if uuid_type == 'uuid1': + if uuid_type == "uuid1": return str(uuid.uuid1()) - elif uuid_type == 'uuid4': + elif uuid_type == "uuid4": return str(uuid.uuid4()) else: raise ValueError("Unknown uuid_type. Only uuid1 and uuid4 are supported") diff --git a/contrib/core/actions/inject_trigger.py b/contrib/core/actions/inject_trigger.py index 706e2165db..a6b2e68317 100644 --- a/contrib/core/actions/inject_trigger.py +++ b/contrib/core/actions/inject_trigger.py @@ -17,9 +17,7 @@ from st2common.runners.base_action import Action -__all__ = [ - 'InjectTriggerAction' -] +__all__ = ["InjectTriggerAction"] class InjectTriggerAction(Action): @@ -34,8 +32,11 @@ def run(self, trigger, payload=None, trace_tag=None): # results in a TriggerInstanceDB database object creation or not. The object is created # inside rulesengine service and could fail due to the user providing an invalid trigger # reference or similar. - self.logger.debug('Injecting trigger "%s" with payload="%s"' % (trigger, str(payload))) - result = client.webhooks.post_generic_webhook(trigger=trigger, payload=payload, - trace_tag=trace_tag) + self.logger.debug( + 'Injecting trigger "%s" with payload="%s"' % (trigger, str(payload)) + ) + result = client.webhooks.post_generic_webhook( + trigger=trigger, payload=payload, trace_tag=trace_tag + ) return result diff --git a/contrib/core/actions/pause.py b/contrib/core/actions/pause.py index 99b9ed9e9b..7ef8b4eccb 100755 --- a/contrib/core/actions/pause.py +++ b/contrib/core/actions/pause.py @@ -19,9 +19,7 @@ from st2common.runners.base_action import Action -__all__ = [ - 'PauseAction' -] +__all__ = ["PauseAction"] class PauseAction(Action): diff --git a/contrib/core/tests/test_action_inject_trigger.py b/contrib/core/tests/test_action_inject_trigger.py index 4e0c3b1a29..7c8e44ac98 100644 --- a/contrib/core/tests/test_action_inject_trigger.py +++ b/contrib/core/tests/test_action_inject_trigger.py @@ -27,50 +27,46 @@ class InjectTriggerActionTestCase(BaseActionTestCase): action_cls = InjectTriggerAction - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_only_trigger_no_payload(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger1') + action.run(trigger="dummy_pack.trigger1") mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger1', - payload={}, - trace_tag=None + trigger="dummy_pack.trigger1", payload={}, trace_tag=None ) mock_api_client.webhooks.post_generic_webhook.reset() - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_trigger_and_payload(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger2', payload={'foo': 'bar'}) + action.run(trigger="dummy_pack.trigger2", payload={"foo": "bar"}) mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger2', - payload={'foo': 'bar'}, - trace_tag=None + trigger="dummy_pack.trigger2", payload={"foo": "bar"}, trace_tag=None ) mock_api_client.webhooks.post_generic_webhook.reset() - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_trigger_payload_trace_tag(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger3', payload={'foo': 'bar'}, trace_tag='Tag1') + action.run( + trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1" + ) mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger3', - payload={'foo': 'bar'}, - trace_tag='Tag1' + trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1" ) diff --git a/contrib/core/tests/test_action_sendmail.py b/contrib/core/tests/test_action_sendmail.py index 241fd35d68..b821ca5f12 100644 --- a/contrib/core/tests/test_action_sendmail.py +++ b/contrib/core/tests/test_action_sendmail.py @@ -33,12 +33,10 @@ from local_runner.local_shell_script_runner import LocalShellScriptRunner -__all__ = [ - 'SendmailActionTestCase' -] +__all__ = ["SendmailActionTestCase"] MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" HOSTNAME = socket.gethostname() @@ -47,134 +45,151 @@ class SendmailActionTestCase(RunnerTestCase, CleanDbTestCase, CleanFilesTestCase NOTE: Those tests rely on stanley user being available on the system and having passwordless sudo access. """ + fixtures_loader = FixturesLoader() def test_sendmail_default_text_html_content_type(self): action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is subject 1', - 'send_empty_body': False, - 'content_type': 'text/html', - 'body': 'Hello there html.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is subject 1", + "send_empty_body": False, + "content_type": "text/html", + "body": "Hello there html.", + "attachments": "", } - expected_body = ('Hello there html.\n' - '

\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there html.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/html; charset=UTF-8') + self.assertEqual(message.content_type, "text/html; charset=UTF-8") def test_sendmail_text_plain_content_type(self): action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is subject 2', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': 'Hello there plain.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is subject 2", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there plain.", + "attachments": "", } - expected_body = ('Hello there plain.\n\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there plain.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/plain; charset=UTF-8') + self.assertEqual(message.content_type, "text/plain; charset=UTF-8") def test_sendmail_utf8_subject_and_body(self): # 1. tex/html action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': u'Å unicode subject 😃😃', - 'send_empty_body': False, - 'content_type': 'text/html', - 'body': u'Hello there 😃😃.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "Å unicode subject 😃😃", + "send_empty_body": False, + "content_type": "text/html", + "body": "Hello there 😃😃.", + "attachments": "", } if six.PY2: - expected_body = (u'Hello there 😃😃.\n' - u'

\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there 😃😃.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) else: - expected_body = (u'Hello there \\U0001f603\\U0001f603.\n' - u'

\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) - - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + expected_body = ( + "Hello there \\U0001f603\\U0001f603.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) + + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/html; charset=UTF-8') + self.assertEqual(message.content_type, "text/html; charset=UTF-8") # 2. text/plain action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': u'Å unicode subject 😃😃', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': u'Hello there 😃😃.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "Å unicode subject 😃😃", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there 😃😃.", + "attachments": "", } if six.PY2: - expected_body = (u'Hello there 😃😃.\n\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there 😃😃.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) else: - expected_body = (u'Hello there \\U0001f603\\U0001f603.\n\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) - - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + expected_body = ( + "Hello there \\U0001f603\\U0001f603.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) + + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/plain; charset=UTF-8') + self.assertEqual(message.content_type, "text/plain; charset=UTF-8") def test_sendmail_with_attachments(self): _, path_1 = tempfile.mkstemp() @@ -185,48 +200,52 @@ def test_sendmail_with_attachments(self): self.to_delete_files.append(path_1) self.to_delete_files.append(path_2) - with open(path_1, 'w') as fp: - fp.write('content 1') + with open(path_1, "w") as fp: + fp.write("content 1") - with open(path_2, 'w') as fp: - fp.write('content 2') + with open(path_2, "w") as fp: + fp.write("content 2") action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is email with attachments', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': 'Hello there plain.', - 'attachments': '%s,%s' % (path_1, path_2) + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is email with attachments", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there plain.", + "attachments": "%s,%s" % (path_1, path_2), } - expected_body = ('Hello there plain.\n\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there plain.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, - 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"') + self.assertEqual( + message.content_type, 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"' + ) # There should be 3 message parts - 2 for attachments, one for body - self.assertEqual(email_data.count('--ZZ_/afg6432dfgkl.94531q'), 3) + self.assertEqual(email_data.count("--ZZ_/afg6432dfgkl.94531q"), 3) # There should be 2 attachments - self.assertEqual(email_data.count('Content-Transfer-Encoding: base64'), 2) - self.assertIn(base64.b64encode(b'content 1').decode('utf-8'), email_data) - self.assertIn(base64.b64encode(b'content 2').decode('utf-8'), email_data) + self.assertEqual(email_data.count("Content-Transfer-Encoding: base64"), 2) + self.assertIn(base64.b64encode(b"content 1").decode("utf-8"), email_data) + self.assertIn(base64.b64encode(b"content 2").decode("utf-8"), email_data) def _run_action(self, action_parameters): """ @@ -234,10 +253,12 @@ def _run_action(self, action_parameters): parse the output email data. """ models = self.fixtures_loader.load_models( - fixtures_pack='packs/core', fixtures_dict={'actions': ['sendmail.yaml']}) - action_db = models['actions']['sendmail.yaml'] + fixtures_pack="packs/core", fixtures_dict={"actions": ["sendmail.yaml"]} + ) + action_db = models["actions"]["sendmail.yaml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'packs/core', 'actions', 'send_mail/send_mail') + "packs/core", "actions", "send_mail/send_mail" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() @@ -246,13 +267,13 @@ def _run_action(self, action_parameters): # Remove footer added by the action which is not part of raw email data and parse # the message - if 'stdout' in result: - email_data = result['stdout'] - email_data = email_data.split('\n')[:-2] - email_data = '\n'.join(email_data) + if "stdout" in result: + email_data = result["stdout"] + email_data = email_data.split("\n")[:-2] + email_data = "\n".join(email_data) if six.PY2 and isinstance(email_data, six.text_type): - email_data = email_data.encode('utf-8') + email_data = email_data.encode("utf-8") message = mailparser.parse_from_string(email_data) else: @@ -273,5 +294,5 @@ def _get_runner(self, action_db, entry_point): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/core/tests/test_action_uuid.py b/contrib/core/tests/test_action_uuid.py index 4e946f3062..e13cfd4c18 100644 --- a/contrib/core/tests/test_action_uuid.py +++ b/contrib/core/tests/test_action_uuid.py @@ -28,13 +28,13 @@ def test_run(self): action = self.get_action_instance() # accepts uuid1 as a type - result = action.run(uuid_type='uuid1') + result = action.run(uuid_type="uuid1") self.assertTrue(result) # accepts uuid4 as a type - result = action.run(uuid_type='uuid4') + result = action.run(uuid_type="uuid4") self.assertTrue(result) # fails on incorrect type with self.assertRaises(ValueError): - result = action.run(uuid_type='foobar') + result = action.run(uuid_type="foobar") diff --git a/contrib/examples/actions/noop.py b/contrib/examples/actions/noop.py index 0283499ce1..bbdf5e67e6 100644 --- a/contrib/examples/actions/noop.py +++ b/contrib/examples/actions/noop.py @@ -5,6 +5,6 @@ class PrintParametersAction(Action): def run(self, **parameters): - print('=========') + print("=========") pprint(parameters) - print('=========') + print("=========") diff --git a/contrib/examples/actions/print_config.py b/contrib/examples/actions/print_config.py index 68bdf1e2d6..15b3103b61 100644 --- a/contrib/examples/actions/print_config.py +++ b/contrib/examples/actions/print_config.py @@ -5,6 +5,6 @@ class PrintConfigAction(Action): def run(self): - print('=========') + print("=========") pprint(self.config) - print('=========') + print("=========") diff --git a/contrib/examples/actions/print_to_stdout_and_stderr.py b/contrib/examples/actions/print_to_stdout_and_stderr.py index da31dc14b4..124a32a67c 100644 --- a/contrib/examples/actions/print_to_stdout_and_stderr.py +++ b/contrib/examples/actions/print_to_stdout_and_stderr.py @@ -23,12 +23,12 @@ class PrintToStdoutAndStderrAction(Action): def run(self, count=100, sleep_delay=0.5): for i in range(0, count): if i % 2 == 0: - text = 'stderr' + text = "stderr" stream = sys.stderr else: - text = 'stdout' + text = "stdout" stream = sys.stdout - stream.write('%s -> Line: %s\n' % (text, (i + 1))) + stream.write("%s -> Line: %s\n" % (text, (i + 1))) stream.flush() time.sleep(sleep_delay) diff --git a/contrib/examples/actions/python-mock-core-remote.py b/contrib/examples/actions/python-mock-core-remote.py index cd4d44500e..52c13d804e 100644 --- a/contrib/examples/actions/python-mock-core-remote.py +++ b/contrib/examples/actions/python-mock-core-remote.py @@ -2,7 +2,6 @@ class MockCoreRemoteAction(Action): - def run(self, cmd, hosts, hosts_dict): if hosts_dict: return hosts_dict @@ -10,14 +9,14 @@ def run(self, cmd, hosts, hosts_dict): if not hosts: return None - host_list = hosts.split(',') + host_list = hosts.split(",") results = {} for h in hosts: results[h] = { - 'failed': False, - 'return_code': 0, - 'stderr': '', - 'succeeded': True, - 'stdout': cmd, + "failed": False, + "return_code": 0, + "stderr": "", + "succeeded": True, + "stdout": cmd, } return results diff --git a/contrib/examples/actions/python-mock-create-vm.py b/contrib/examples/actions/python-mock-create-vm.py index 60a88b7967..62fdaa36c1 100644 --- a/contrib/examples/actions/python-mock-create-vm.py +++ b/contrib/examples/actions/python-mock-create-vm.py @@ -5,17 +5,12 @@ class MockCreateVMAction(Action): - def run(self, cpu_cores, memory_mb, vm_name, ip): eventlet.sleep(5) data = { - 'vm_id': 'vm' + str(random.randint(0, 10000)), - ip: { - 'cpu_cores': cpu_cores, - 'memory_mb': memory_mb, - 'vm_name': vm_name - } + "vm_id": "vm" + str(random.randint(0, 10000)), + ip: {"cpu_cores": cpu_cores, "memory_mb": memory_mb, "vm_name": vm_name}, } return data diff --git a/contrib/examples/actions/pythonactions/fibonacci.py b/contrib/examples/actions/pythonactions/fibonacci.py index afab612161..bd9a479f35 100755 --- a/contrib/examples/actions/pythonactions/fibonacci.py +++ b/contrib/examples/actions/pythonactions/fibonacci.py @@ -12,12 +12,13 @@ def fib(n): return n return fib(n - 2) + fib(n - 1) -if __name__ == '__main__': + +if __name__ == "__main__": try: startNumber = int(float(sys.argv[1])) endNumber = int(float(sys.argv[2])) results = map(str, map(fib, list(range(startNumber, endNumber)))) - results = ' '.join(results) + results = " ".join(results) print(results) except Exception as e: traceback.print_exc(file=sys.stderr) diff --git a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py index 989467570c..8cb3c42f4b 100644 --- a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py +++ b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py @@ -3,13 +3,13 @@ class IncreaseIndexAndCheckCondition(Action): def run(self, index, pagesize, input): - if pagesize and pagesize != '': + if pagesize and pagesize != "": if len(input) < int(pagesize): return (False, "Breaking out of the loop") else: pagesize = 0 - if not index or index == '': + if not index or index == "": index = 1 - return(True, int(index) + 1) + return (True, int(index) + 1) diff --git a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py index a2cdfd1063..dbefc1b07e 100644 --- a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py +++ b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py @@ -6,12 +6,12 @@ class ParseGithubRepos(Action): def run(self, content): try: - soup = BeautifulSoup(content, 'html.parser') + soup = BeautifulSoup(content, "html.parser") repo_list = soup.find_all("h3") output = {} for each_item in repo_list: - repo_half_url = each_item.find("a")['href'] + repo_half_url = each_item.find("a")["href"] repo_name = repo_half_url.split("/")[-1] repo_url = "https://github.com" + repo_half_url output[repo_name] = repo_url diff --git a/contrib/examples/actions/pythonactions/isprime.py b/contrib/examples/actions/pythonactions/isprime.py index 911594a01e..e55d202922 100644 --- a/contrib/examples/actions/pythonactions/isprime.py +++ b/contrib/examples/actions/pythonactions/isprime.py @@ -6,18 +6,19 @@ class PrimeCheckerAction(Action): def run(self, value=0): - self.logger.debug('PYTHONPATH: %s', get_environ('PYTHONPATH')) - self.logger.debug('value=%s' % (value)) + self.logger.debug("PYTHONPATH: %s", get_environ("PYTHONPATH")) + self.logger.debug("value=%s" % (value)) if math.floor(value) != value: - raise ValueError('%s should be an integer.' % value) + raise ValueError("%s should be an integer." % value) if value < 2: return False - for test in range(2, int(math.floor(math.sqrt(value)))+1): + for test in range(2, int(math.floor(math.sqrt(value))) + 1): if value % test == 0: return False return True -if __name__ == '__main__': + +if __name__ == "__main__": checker = PrimeCheckerAction() for i in range(0, 10): - print('%s : %s' % (i, checker.run(value=1))) + print("%s : %s" % (i, checker.run(value=1))) diff --git a/contrib/examples/actions/pythonactions/json_string_to_object.py b/contrib/examples/actions/pythonactions/json_string_to_object.py index 1072c4554e..e3c492d7a2 100644 --- a/contrib/examples/actions/pythonactions/json_string_to_object.py +++ b/contrib/examples/actions/pythonactions/json_string_to_object.py @@ -4,6 +4,5 @@ class JsonStringToObject(Action): - def run(self, json_str): return json.loads(json_str) diff --git a/contrib/examples/actions/pythonactions/object_return.py b/contrib/examples/actions/pythonactions/object_return.py index ecaaf57391..f8a008b73d 100644 --- a/contrib/examples/actions/pythonactions/object_return.py +++ b/contrib/examples/actions/pythonactions/object_return.py @@ -2,6 +2,5 @@ class ObjectReturnAction(Action): - def run(self): - return {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + return {"a": "b", "c": {"d": "e", "f": 1, "g": True}} diff --git a/contrib/examples/actions/pythonactions/print_python_environment.py b/contrib/examples/actions/pythonactions/print_python_environment.py index 9c070cc1c0..dd92bfc202 100644 --- a/contrib/examples/actions/pythonactions/print_python_environment.py +++ b/contrib/examples/actions/pythonactions/print_python_environment.py @@ -6,10 +6,9 @@ class PrintPythonEnvironmentAction(Action): - def run(self): - print('Using Python executable: %s' % (sys.executable)) - print('Using Python version: %s' % (sys.version)) - print('Platform: %s' % (platform.platform())) - print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH'))) - print('sys.path: %s' % (sys.path)) + print("Using Python executable: %s" % (sys.executable)) + print("Using Python version: %s" % (sys.version)) + print("Platform: %s" % (platform.platform())) + print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH"))) + print("sys.path: %s" % (sys.path)) diff --git a/contrib/examples/actions/pythonactions/print_python_version.py b/contrib/examples/actions/pythonactions/print_python_version.py index 0ae2a27b18..201c68dd5f 100644 --- a/contrib/examples/actions/pythonactions/print_python_version.py +++ b/contrib/examples/actions/pythonactions/print_python_version.py @@ -4,7 +4,6 @@ class PrintPythonVersionAction(Action): - def run(self): - print('Using Python executable: %s' % (sys.executable)) - print('Using Python version: %s' % (sys.version)) + print("Using Python executable: %s" % (sys.executable)) + print("Using Python version: %s" % (sys.version)) diff --git a/contrib/examples/actions/pythonactions/yaml_string_to_object.py b/contrib/examples/actions/pythonactions/yaml_string_to_object.py index 297451cdad..aa888ce408 100644 --- a/contrib/examples/actions/pythonactions/yaml_string_to_object.py +++ b/contrib/examples/actions/pythonactions/yaml_string_to_object.py @@ -4,6 +4,5 @@ class YamlStringToObject(Action): - def run(self, yaml_str): return yaml.safe_load(yaml_str) diff --git a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py index c2c0198a42..14f19582fd 100644 --- a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py +++ b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py @@ -5,11 +5,11 @@ def to_json(out, err, code): payload = {} if err: - payload['err'] = err - payload['exit_code'] = code + payload["err"] = err + payload["exit_code"] = code return json.dumps(payload) - payload['pkg_info'] = out - payload['exit_code'] = code + payload["pkg_info"] = out + payload["exit_code"] = code return json.dumps(payload) diff --git a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py index d8213f4342..ec5e5f7ace 100755 --- a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py +++ b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py @@ -7,17 +7,20 @@ def main(args): - command_list = shlex.split('apt-cache policy ' + ' '.join(args[1:])) - process = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command_list = shlex.split("apt-cache policy " + " ".join(args[1:])) + process = subprocess.Popen( + command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) command_stdout, command_stderr = process.communicate() command_exitcode = process.returncode try: payload = transformer.to_json(command_stdout, command_stderr, command_exitcode) except Exception as e: - sys.stderr.write('JSON conversion failed. %s' % six.text_type(e)) + sys.stderr.write("JSON conversion failed. %s" % six.text_type(e)) sys.exit(1) sys.stdout.write(payload) -if __name__ == '__main__': + +if __name__ == "__main__": main(sys.argv) diff --git a/contrib/examples/sensors/echo_flask_app.py b/contrib/examples/sensors/echo_flask_app.py index c4025ae441..5177df2306 100644 --- a/contrib/examples/sensors/echo_flask_app.py +++ b/contrib/examples/sensors/echo_flask_app.py @@ -6,13 +6,12 @@ class EchoFlaskSensor(Sensor): def __init__(self, sensor_service, config): super(EchoFlaskSensor, self).__init__( - sensor_service=sensor_service, - config=config + sensor_service=sensor_service, config=config ) - self._host = '127.0.0.1' + self._host = "127.0.0.1" self._port = 5000 - self._path = '/echo' + self._path = "/echo" self._log = self._sensor_service.get_logger(__name__) self._app = Flask(__name__) @@ -21,15 +20,19 @@ def setup(self): pass def run(self): - @self._app.route(self._path, methods=['POST']) + @self._app.route(self._path, methods=["POST"]) def echo(): payload = request.get_json(force=True) - self._sensor_service.dispatch(trigger="examples.echoflasksensor", - payload=payload) + self._sensor_service.dispatch( + trigger="examples.echoflasksensor", payload=payload + ) return request.data - self._log.info('Listening for payload on http://{}:{}{}'.format( - self._host, self._port, self._path)) + self._log.info( + "Listening for payload on http://{}:{}{}".format( + self._host, self._port, self._path + ) + ) self._app.run(host=self._host, port=self._port, threaded=False) def cleanup(self): diff --git a/contrib/examples/sensors/fibonacci_sensor.py b/contrib/examples/sensors/fibonacci_sensor.py index 266e81aba3..2df956335b 100644 --- a/contrib/examples/sensors/fibonacci_sensor.py +++ b/contrib/examples/sensors/fibonacci_sensor.py @@ -4,12 +4,9 @@ class FibonacciSensor(PollingSensor): - def __init__(self, sensor_service, config, poll_interval=20): super(FibonacciSensor, self).__init__( - sensor_service=sensor_service, - config=config, - poll_interval=poll_interval + sensor_service=sensor_service, config=config, poll_interval=poll_interval ) self.a = None self.b = None @@ -26,19 +23,21 @@ def setup(self): def poll(self): # Reset a and b if there are large enough to avoid integer overflow problems if self.a > 10000 or self.b > 10000: - self.logger.debug('Reseting values to avoid integer overflow issues') + self.logger.debug("Reseting values to avoid integer overflow issues") self.a = 0 self.b = 1 self.count = 2 - fib = (self.a + self.b) - self.logger.debug('Count: %d, a: %d, b: %d, fib: %s', self.count, self.a, self.b, fib) + fib = self.a + self.b + self.logger.debug( + "Count: %d, a: %d, b: %d, fib: %s", self.count, self.a, self.b, fib + ) payload = { "count": self.count, "fibonacci": fib, - "pythonpath": os.environ.get("PYTHONPATH", None) + "pythonpath": os.environ.get("PYTHONPATH", None), } self.sensor_service.dispatch(trigger="examples.fibonacci", payload=payload) diff --git a/contrib/hello_st2/sensors/sensor1.py b/contrib/hello_st2/sensors/sensor1.py index 501de54a98..a4914cdf8b 100644 --- a/contrib/hello_st2/sensors/sensor1.py +++ b/contrib/hello_st2/sensors/sensor1.py @@ -14,11 +14,11 @@ def setup(self): def run(self): while not self._stop: - self._logger.debug('HelloSensor dispatching trigger...') - count = self.sensor_service.get_value('hello_st2.count') or 0 - payload = {'greeting': 'Yo, StackStorm!', 'count': int(count) + 1} - self.sensor_service.dispatch(trigger='hello_st2.event1', payload=payload) - self.sensor_service.set_value('hello_st2.count', payload['count']) + self._logger.debug("HelloSensor dispatching trigger...") + count = self.sensor_service.get_value("hello_st2.count") or 0 + payload = {"greeting": "Yo, StackStorm!", "count": int(count) + 1} + self.sensor_service.dispatch(trigger="hello_st2.event1", payload=payload) + self.sensor_service.set_value("hello_st2.count", payload["count"]) eventlet.sleep(60) def cleanup(self): diff --git a/contrib/linux/actions/checks/check_loadavg.py b/contrib/linux/actions/checks/check_loadavg.py index fb7d3938cc..04036924e8 100755 --- a/contrib/linux/actions/checks/check_loadavg.py +++ b/contrib/linux/actions/checks/check_loadavg.py @@ -29,7 +29,7 @@ output = {} try: - fh = open(loadAvgFile, 'r') + fh = open(loadAvgFile, "r") load = fh.readline().split()[0:3] except: print("Error opening %s" % loadAvgFile) @@ -38,7 +38,7 @@ fh.close() try: - fh = open(cpuInfoFile, 'r') + fh = open(cpuInfoFile, "r") for line in fh: if "processor" in line: cpus += 1 @@ -48,16 +48,16 @@ finally: fh.close() -output['1'] = str(float(load[0]) / cpus) -output['5'] = str(float(load[1]) / cpus) -output['15'] = str(float(load[2]) / cpus) +output["1"] = str(float(load[0]) / cpus) +output["5"] = str(float(load[1]) / cpus) +output["15"] = str(float(load[2]) / cpus) -if time == '1' or time == 'one': - print(output['1']) -elif time == '5' or time == 'five': - print(output['5']) -elif time == '15' or time == 'fifteen': - print(output['15']) +if time == "1" or time == "one": + print(output["1"]) +elif time == "5" or time == "five": + print(output["5"]) +elif time == "15" or time == "fifteen": + print(output["15"]) else: print(json.dumps(output)) diff --git a/contrib/linux/actions/checks/check_processes.py b/contrib/linux/actions/checks/check_processes.py index b1ff1af0ae..d2a7db195f 100755 --- a/contrib/linux/actions/checks/check_processes.py +++ b/contrib/linux/actions/checks/check_processes.py @@ -41,8 +41,11 @@ def setup(self, debug=False, pidlist=False): if debug is True: print("Debug is on") - self.allProcs = [procs for procs in os.listdir(self.procDir) if procs.isdigit() and - int(procs) != int(self.myPid)] + self.allProcs = [ + procs + for procs in os.listdir(self.procDir) + if procs.isdigit() and int(procs) != int(self.myPid) + ] def process(self, criteria): for p in self.allProcs: @@ -58,37 +61,37 @@ def process(self, criteria): cmdfh.close() fh.close() - if criteria == 'state': + if criteria == "state": if pInfo[2] == self.state: self.interestingProcs.append(pInfo) - elif criteria == 'name': + elif criteria == "name": if re.search(self.name, pInfo[1]): self.interestingProcs.append(pInfo) - elif criteria == 'pid': + elif criteria == "pid": if pInfo[0] == self.pid: self.interestingProcs.append(pInfo) def byState(self, state): self.state = state - self.process(criteria='state') + self.process(criteria="state") self.show() def byPid(self, pid): self.pid = pid - self.process(criteria='pid') + self.process(criteria="pid") self.show() def byName(self, name): self.name = name - self.process(criteria='name') + self.process(criteria="name") self.show() def run(self, foo, criteria): - if foo == 'state': + if foo == "state": self.byState(criteria) - elif foo == 'name': + elif foo == "name": self.byName(criteria) - elif foo == 'pid': + elif foo == "pid": self.byPid(criteria) def show(self): @@ -99,13 +102,13 @@ def show(self): prettyOut[proc[0]] = proc[1] if self.pidlist is True: - pidlist = ' '.join(prettyOut.keys()) + pidlist = " ".join(prettyOut.keys()) sys.stderr.write(pidlist) print(json.dumps(prettyOut)) -if __name__ == '__main__': +if __name__ == "__main__": if "pidlist" in sys.argv: pidlist = True else: diff --git a/contrib/linux/actions/dig.py b/contrib/linux/actions/dig.py index 9a3b58a5cd..7eb8518a2a 100644 --- a/contrib/linux/actions/dig.py +++ b/contrib/linux/actions/dig.py @@ -25,29 +25,28 @@ class DigAction(Action): - def run(self, rand, count, nameserver, hostname, queryopts): opt_list = [] output = [] - cmd_args = ['dig'] + cmd_args = ["dig"] if nameserver: - nameserver = '@' + nameserver + nameserver = "@" + nameserver cmd_args.append(nameserver) - if isinstance(queryopts, str) and ',' in queryopts: - opt_list = queryopts.split(',') + if isinstance(queryopts, str) and "," in queryopts: + opt_list = queryopts.split(",") else: opt_list.append(queryopts) - cmd_args.extend(['+' + option for option in opt_list]) + cmd_args.extend(["+" + option for option in opt_list]) cmd_args.append(hostname) try: - raw_result = subprocess.Popen(cmd_args, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE).communicate()[0] + raw_result = subprocess.Popen( + cmd_args, stderr=subprocess.PIPE, stdout=subprocess.PIPE + ).communicate()[0] if sys.version_info >= (3,): # This function might call getpreferred encoding unless we pass @@ -57,16 +56,19 @@ def run(self, rand, count, nameserver, hostname, queryopts): else: result_list_str = str(raw_result) - result_list = list(filter(None, result_list_str.split('\n'))) + result_list = list(filter(None, result_list_str.split("\n"))) # NOTE: Python3 supports the FileNotFoundError, the errono.ENOENT is for py2 compat # for Python3: # except FileNotFoundError as e: except OSError as e: if e.errno == errno.ENOENT: - return False, "Can't find dig installed in the path (usually /usr/bin/dig). If " \ - "dig isn't installed, you can install it with 'sudo yum install " \ - "bind-utils' or 'sudo apt install dnsutils'" + return ( + False, + "Can't find dig installed in the path (usually /usr/bin/dig). If " + "dig isn't installed, you can install it with 'sudo yum install " + "bind-utils' or 'sudo apt install dnsutils'", + ) else: raise e diff --git a/contrib/linux/actions/service.py b/contrib/linux/actions/service.py index 3961438431..335e5038f6 100644 --- a/contrib/linux/actions/service.py +++ b/contrib/linux/actions/service.py @@ -26,20 +26,23 @@ distro = platform.linux_distribution()[0] if len(sys.argv) < 3: - raise ValueError('Usage: service.py ') + raise ValueError("Usage: service.py ") -args = {'act': quote_unix(sys.argv[1]), 'service': quote_unix(sys.argv[2])} +args = {"act": quote_unix(sys.argv[1]), "service": quote_unix(sys.argv[2])} -if re.search(distro, 'Ubuntu'): - if os.path.isfile("/etc/init/%s.conf" % args['service']): - cmd_args = ['service', args['service'], args['act']] - elif os.path.isfile("/etc/init.d/%s" % args['service']): - cmd_args = ['/etc/init.d/%s' % (args['service']), args['act']] +if re.search(distro, "Ubuntu"): + if os.path.isfile("/etc/init/%s.conf" % args["service"]): + cmd_args = ["service", args["service"], args["act"]] + elif os.path.isfile("/etc/init.d/%s" % args["service"]): + cmd_args = ["/etc/init.d/%s" % (args["service"]), args["act"]] else: print("Unknown service") sys.exit(2) -elif re.search(distro, 'Redhat') or re.search(distro, 'Fedora') or \ - re.search(distro, 'CentOS Linux'): - cmd_args = ['systemctl', args['act'], args['service']] +elif ( + re.search(distro, "Redhat") + or re.search(distro, "Fedora") + or re.search(distro, "CentOS Linux") +): + cmd_args = ["systemctl", args["act"], args["service"]] subprocess.call(cmd_args, shell=False) diff --git a/contrib/linux/actions/wait_for_ssh.py b/contrib/linux/actions/wait_for_ssh.py index 4ad4a66050..c29e91ba03 100644 --- a/contrib/linux/actions/wait_for_ssh.py +++ b/contrib/linux/actions/wait_for_ssh.py @@ -25,29 +25,47 @@ class BaseAction(Action): - def run(self, hostname, port, username, password=None, keyfile=None, ssh_timeout=5, - sleep_delay=20, retries=10): + def run( + self, + hostname, + port, + username, + password=None, + keyfile=None, + ssh_timeout=5, + sleep_delay=20, + retries=10, + ): # Note: If neither password nor key file is provided, we try to use system user # key file if not password and not keyfile: keyfile = cfg.CONF.system_user.ssh_key_file - self.logger.info('Neither "password" nor "keyfile" parameter provided, ' - 'defaulting to using "%s" key file' % (keyfile)) + self.logger.info( + 'Neither "password" nor "keyfile" parameter provided, ' + 'defaulting to using "%s" key file' % (keyfile) + ) - client = ParamikoSSHClient(hostname=hostname, port=port, username=username, - password=password, key_files=keyfile, - timeout=ssh_timeout) + client = ParamikoSSHClient( + hostname=hostname, + port=port, + username=username, + password=password, + key_files=keyfile, + timeout=ssh_timeout, + ) for index in range(retries): attempt = index + 1 try: - self.logger.debug('SSH connection attempt: %s' % (attempt)) + self.logger.debug("SSH connection attempt: %s" % (attempt)) client.connect() return True except Exception as e: - self.logger.info('Attempt %s failed (%s), sleeping for %s seconds...' % - (attempt, six.text_type(e), sleep_delay)) + self.logger.info( + "Attempt %s failed (%s), sleeping for %s seconds..." + % (attempt, six.text_type(e), sleep_delay) + ) time.sleep(sleep_delay) - raise Exception('Exceeded max retries (%s)' % (retries)) + raise Exception("Exceeded max retries (%s)" % (retries)) diff --git a/contrib/linux/sensors/file_watch_sensor.py b/contrib/linux/sensors/file_watch_sensor.py index 2597d63926..52e2943116 100644 --- a/contrib/linux/sensors/file_watch_sensor.py +++ b/contrib/linux/sensors/file_watch_sensor.py @@ -24,8 +24,9 @@ class FileWatchSensor(Sensor): def __init__(self, sensor_service, config=None): - super(FileWatchSensor, self).__init__(sensor_service=sensor_service, - config=config) + super(FileWatchSensor, self).__init__( + sensor_service=sensor_service, config=config + ) self._trigger = None self._logger = self._sensor_service.get_logger(__name__) self._tail = None @@ -48,16 +49,16 @@ def cleanup(self): pass def add_trigger(self, trigger): - file_path = trigger['parameters'].get('file_path', None) + file_path = trigger["parameters"].get("file_path", None) if not file_path: self._logger.error('Received trigger type without "file_path" field.') return - self._trigger = trigger.get('ref', None) + self._trigger = trigger.get("ref", None) if not self._trigger: - raise Exception('Trigger %s did not contain a ref.' % trigger) + raise Exception("Trigger %s did not contain a ref." % trigger) # Wait a bit to avoid initialization race in logshipper library eventlet.sleep(1.0) @@ -69,7 +70,7 @@ def update_trigger(self, trigger): pass def remove_trigger(self, trigger): - file_path = trigger['parameters'].get('file_path', None) + file_path = trigger["parameters"].get("file_path", None) if not file_path: self._logger.error('Received trigger type without "file_path" field.') @@ -83,10 +84,11 @@ def remove_trigger(self, trigger): def _handle_line(self, file_path, line): trigger = self._trigger payload = { - 'file_path': file_path, - 'file_name': os.path.basename(file_path), - 'line': line + "file_path": file_path, + "file_name": os.path.basename(file_path), + "line": line, } - self._logger.debug('Sending payload %s for trigger %s to sensor_service.', - payload, trigger) + self._logger.debug( + "Sending payload %s for trigger %s to sensor_service.", payload, trigger + ) self.sensor_service.dispatch(trigger=trigger, payload=payload) diff --git a/contrib/linux/tests/test_action_dig.py b/contrib/linux/tests/test_action_dig.py index 4f363521d9..008cf16e76 100644 --- a/contrib/linux/tests/test_action_dig.py +++ b/contrib/linux/tests/test_action_dig.py @@ -27,15 +27,18 @@ def test_run_with_empty_hostname(self): action = self.get_action_instance() # Use the defaults from dig.yaml - result = action.run(rand=False, count=0, nameserver=None, hostname='', queryopts='short') + result = action.run( + rand=False, count=0, nameserver=None, hostname="", queryopts="short" + ) self.assertIsInstance(result, list) self.assertEqual(len(result), 0) def test_run_with_empty_queryopts(self): action = self.get_action_instance() - results = action.run(rand=False, count=0, nameserver=None, hostname='google.com', - queryopts='') + results = action.run( + rand=False, count=0, nameserver=None, hostname="google.com", queryopts="" + ) self.assertIsInstance(results, list) for result in results: @@ -45,8 +48,13 @@ def test_run_with_empty_queryopts(self): def test_run(self): action = self.get_action_instance() - results = action.run(rand=False, count=0, nameserver=None, hostname='google.com', - queryopts='short') + results = action.run( + rand=False, + count=0, + nameserver=None, + hostname="google.com", + queryopts="short", + ) self.assertIsInstance(results, list) self.assertGreater(len(results), 0) diff --git a/contrib/packs/actions/get_config.py b/contrib/packs/actions/get_config.py index 505ef683c4..07e4654cef 100755 --- a/contrib/packs/actions/get_config.py +++ b/contrib/packs/actions/get_config.py @@ -22,8 +22,8 @@ class RenderTemplateAction(Action): def run(self): result = { - 'pack_group': utils.get_pack_group(), - 'pack_path': utils.get_system_packs_base_path() + "pack_group": utils.get_pack_group(), + "pack_path": utils.get_system_packs_base_path(), } return result diff --git a/contrib/packs/actions/pack_mgmt/delete.py b/contrib/packs/actions/pack_mgmt/delete.py index 93bcc46044..ca0436e834 100644 --- a/contrib/packs/actions/pack_mgmt/delete.py +++ b/contrib/packs/actions/pack_mgmt/delete.py @@ -27,15 +27,18 @@ class UninstallPackAction(Action): def __init__(self, config=None, action_service=None): - super(UninstallPackAction, self).__init__(config=config, action_service=action_service) - self._base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, - 'virtualenvs/') + super(UninstallPackAction, self).__init__( + config=config, action_service=action_service + ) + self._base_virtualenvs_path = os.path.join( + cfg.CONF.system.base_path, "virtualenvs/" + ) def run(self, packs, abs_repo_base, delete_env=True): intersection = BLOCKED_PACKS & frozenset(packs) if len(intersection) > 0: - names = ', '.join(list(intersection)) - raise ValueError('Uninstall includes an uninstallable pack - %s.' % (names)) + names = ", ".join(list(intersection)) + raise ValueError("Uninstall includes an uninstallable pack - %s." % (names)) # 1. Delete pack content for fp in os.listdir(abs_repo_base): @@ -51,6 +54,8 @@ def run(self, packs, abs_repo_base, delete_env=True): virtualenv_path = os.path.join(self._base_virtualenvs_path, pack_name) if os.path.isdir(virtualenv_path): - self.logger.debug('Deleting virtualenv "%s" for pack "%s"' % - (virtualenv_path, pack_name)) + self.logger.debug( + 'Deleting virtualenv "%s" for pack "%s"' + % (virtualenv_path, pack_name) + ) shutil.rmtree(virtualenv_path) diff --git a/contrib/packs/actions/pack_mgmt/download.py b/contrib/packs/actions/pack_mgmt/download.py index b4d888630b..cc0f7cd8fb 100644 --- a/contrib/packs/actions/pack_mgmt/download.py +++ b/contrib/packs/actions/pack_mgmt/download.py @@ -21,68 +21,85 @@ from st2common.runners.base_action import Action from st2common.util.pack_management import download_pack -__all__ = [ - 'DownloadGitRepoAction' -] +__all__ = ["DownloadGitRepoAction"] class DownloadGitRepoAction(Action): def __init__(self, config=None, action_service=None): - super(DownloadGitRepoAction, self).__init__(config=config, action_service=action_service) + super(DownloadGitRepoAction, self).__init__( + config=config, action_service=action_service + ) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } # This is needed for git binary to work with a proxy - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy - def run(self, packs, abs_repo_base, verifyssl=True, force=False, - dependency_list=None): + def run( + self, packs, abs_repo_base, verifyssl=True, force=False, dependency_list=None + ): result = {} pack_url = None if dependency_list: for pack_dependency in dependency_list: - pack_result = download_pack(pack=pack_dependency, abs_repo_base=abs_repo_base, - verify_ssl=verifyssl, force=force, - proxy_config=self.proxy_config, force_permissions=True, - logger=self.logger) + pack_result = download_pack( + pack=pack_dependency, + abs_repo_base=abs_repo_base, + verify_ssl=verifyssl, + force=force, + proxy_config=self.proxy_config, + force_permissions=True, + logger=self.logger, + ) pack_url, pack_ref, pack_result = pack_result result[pack_ref] = pack_result else: for pack in packs: - pack_result = download_pack(pack=pack, abs_repo_base=abs_repo_base, - verify_ssl=verifyssl, force=force, - proxy_config=self.proxy_config, - force_permissions=True, - logger=self.logger) + pack_result = download_pack( + pack=pack, + abs_repo_base=abs_repo_base, + verify_ssl=verifyssl, + force=force, + proxy_config=self.proxy_config, + force_permissions=True, + logger=self.logger, + ) pack_url, pack_ref, pack_result = pack_result result[pack_ref] = pack_result @@ -99,14 +116,16 @@ def _validate_result(result, repo_url): if not atleast_one_success: message_list = [] - message_list.append('The pack has not been downloaded from "%s".\n' % (repo_url)) - message_list.append('Errors:') + message_list.append( + 'The pack has not been downloaded from "%s".\n' % (repo_url) + ) + message_list.append("Errors:") for pack, value in result.items(): success, error = value message_list.append(error) - message = '\n'.join(message_list) + message = "\n".join(message_list) raise Exception(message) return sanitized_result diff --git a/contrib/packs/actions/pack_mgmt/get_installed.py b/contrib/packs/actions/pack_mgmt/get_installed.py index eaa88b6319..36f2504b85 100644 --- a/contrib/packs/actions/pack_mgmt/get_installed.py +++ b/contrib/packs/actions/pack_mgmt/get_installed.py @@ -28,6 +28,7 @@ class GetInstalled(Action): """"Get information about installed pack.""" + def run(self, pack): """ :param pack: Installed Pack Name to get info about @@ -47,46 +48,42 @@ def run(self, pack): # Pack doesn't exist, finish execution normally with empty metadata if not os.path.isdir(pack_path): - return { - 'pack': None, - 'git_status': None - } + return {"pack": None, "git_status": None} if not metadata_file: - error = ('Pack "%s" doesn\'t contain pack.yaml file.' % (pack)) + error = 'Pack "%s" doesn\'t contain pack.yaml file.' % (pack) raise Exception(error) try: details = self._parse_yaml_file(metadata_file) except Exception as e: - error = ('Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % (pack, - six.text_type(e))) + error = 'Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % ( + pack, + six.text_type(e), + ) raise Exception(error) try: repo = Repo(pack_path) git_status = "Status:\n%s\n\nRemotes:\n%s" % ( - repo.git.status().split('\n')[0], - "\n".join([remote.url for remote in repo.remotes]) + repo.git.status().split("\n")[0], + "\n".join([remote.url for remote in repo.remotes]), ) ahead_behind = repo.git.rev_list( - '--left-right', '--count', 'HEAD...origin/master' + "--left-right", "--count", "HEAD...origin/master" ).split() # Dear god. - if ahead_behind != [u'0', u'0']: + if ahead_behind != ["0", "0"]: git_status += "\n\n" - git_status += "%s commits ahead " if ahead_behind[0] != u'0' else "" - git_status += "and " if u'0' not in ahead_behind else "" - git_status += "%s commits behind " if ahead_behind[1] != u'0' else "" + git_status += "%s commits ahead " if ahead_behind[0] != "0" else "" + git_status += "and " if "0" not in ahead_behind else "" + git_status += "%s commits behind " if ahead_behind[1] != "0" else "" git_status += "origin/master." except InvalidGitRepositoryError: git_status = None - return { - 'pack': details, - 'git_status': git_status - } + return {"pack": details, "git_status": git_status} def _parse_yaml_file(self, file_path): with open(file_path) as data_file: diff --git a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py index 60ab2c9503..b9168526a2 100644 --- a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py +++ b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py @@ -40,7 +40,7 @@ def run(self, packs_status, nested): return result for pack, status in six.iteritems(packs_status): - if 'success' not in status.lower(): + if "success" not in status.lower(): continue dependency_packs = get_dependency_list(pack) @@ -50,40 +50,51 @@ def run(self, packs_status, nested): for dep_pack in dependency_packs: name_or_url, pack_version = self.get_name_and_version(dep_pack) - if len(name_or_url.split('/')) == 1: + if len(name_or_url.split("/")) == 1: pack_name = name_or_url else: name_or_git = name_or_url.split("/")[-1] - pack_name = name_or_git if '.git' not in name_or_git else \ - name_or_git.split('.')[0] + pack_name = ( + name_or_git + if ".git" not in name_or_git + else name_or_git.split(".")[0] + ) # Check existing pack by pack name existing_pack_version = get_pack_version(pack_name) # Try one more time to get existing pack version by name if 'stackstorm-' is in # pack name - if not existing_pack_version and 'stackstorm-' in pack_name.lower(): - existing_pack_version = get_pack_version(pack_name.split('stackstorm-')[-1]) + if not existing_pack_version and "stackstorm-" in pack_name.lower(): + existing_pack_version = get_pack_version( + pack_name.split("stackstorm-")[-1] + ) if existing_pack_version: - if existing_pack_version and not existing_pack_version.startswith('v'): - existing_pack_version = 'v' + existing_pack_version - if pack_version and not pack_version.startswith('v'): - pack_version = 'v' + pack_version - if pack_version and existing_pack_version != pack_version \ - and dep_pack not in conflict_list: + if existing_pack_version and not existing_pack_version.startswith( + "v" + ): + existing_pack_version = "v" + existing_pack_version + if pack_version and not pack_version.startswith("v"): + pack_version = "v" + pack_version + if ( + pack_version + and existing_pack_version != pack_version + and dep_pack not in conflict_list + ): conflict_list.append(dep_pack) else: - conflict = self.check_dependency_list_for_conflict(name_or_url, pack_version, - dependency_list) + conflict = self.check_dependency_list_for_conflict( + name_or_url, pack_version, dependency_list + ) if conflict: conflict_list.append(dep_pack) elif dep_pack not in dependency_list: dependency_list.append(dep_pack) - result['dependency_list'] = dependency_list - result['conflict_list'] = conflict_list - result['nested'] = nested - 1 + result["dependency_list"] = dependency_list + result["conflict_list"] = conflict_list + result["nested"] = nested - 1 return result @@ -112,7 +123,7 @@ def get_pack_version(pack=None): pack_path = get_pack_base_path(pack) try: pack_metadata = get_pack_metadata(pack_dir=pack_path) - result = pack_metadata.get('version', None) + result = pack_metadata.get("version", None) except Exception: result = None finally: @@ -124,9 +135,9 @@ def get_dependency_list(pack=None): try: pack_metadata = get_pack_metadata(pack_dir=pack_path) - result = pack_metadata.get('dependencies', None) + result = pack_metadata.get("dependencies", None) except Exception: - print('Could not open pack.yaml at location %s' % pack_path) + print("Could not open pack.yaml at location %s" % pack_path) result = None finally: return result diff --git a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py index 445a5df0c2..e8f42dcbb6 100755 --- a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py +++ b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py @@ -34,7 +34,7 @@ def run(self, packs_status): return result for pack, status in six.iteritems(packs_status): - if 'success' not in status.lower(): + if "success" not in status.lower(): continue warning = get_warnings(pack) @@ -42,7 +42,7 @@ def run(self, packs_status): if warning: warning_list.append(warning) - result['warning_list'] = warning_list + result["warning_list"] = warning_list return result @@ -54,6 +54,6 @@ def get_warnings(pack=None): pack_metadata = get_pack_metadata(pack_dir=pack_path) result = get_pack_warnings(pack_metadata) except Exception: - print('Could not open pack.yaml at location %s' % pack_path) + print("Could not open pack.yaml at location %s" % pack_path) finally: return result diff --git a/contrib/packs/actions/pack_mgmt/register.py b/contrib/packs/actions/pack_mgmt/register.py index 220962f0f4..1587333d5b 100644 --- a/contrib/packs/actions/pack_mgmt/register.py +++ b/contrib/packs/actions/pack_mgmt/register.py @@ -19,21 +19,19 @@ from st2client.models.keyvalue import KeyValuePair # pylint: disable=no-name-in-module from st2common.runners.base_action import Action -__all__ = [ - 'St2RegisterAction' -] +__all__ = ["St2RegisterAction"] COMPATIBILITY_TRANSFORMATIONS = { - 'runners': 'runner', - 'triggers': 'trigger', - 'sensors': 'sensor', - 'actions': 'action', - 'rules': 'rule', - 'rule_types': 'rule_type', - 'aliases': 'alias', - 'policiy_types': 'policy_type', - 'policies': 'policy', - 'configs': 'config', + "runners": "runner", + "triggers": "trigger", + "sensors": "sensor", + "actions": "action", + "rules": "rule", + "rule_types": "rule_type", + "aliases": "alias", + "policiy_types": "policy_type", + "policies": "policy", + "configs": "config", } @@ -63,23 +61,23 @@ def __init__(self, config): def run(self, register, packs=None): types = [] - for type in register.split(','): + for type in register.split(","): if type in COMPATIBILITY_TRANSFORMATIONS: types.append(COMPATIBILITY_TRANSFORMATIONS[type]) else: types.append(type) - method_kwargs = { - 'types': types - } + method_kwargs = {"types": types} packs.reverse() if packs: - method_kwargs['packs'] = packs + method_kwargs["packs"] = packs - result = self._run_client_method(method=self.client.packs.register, - method_kwargs=method_kwargs, - format_func=format_result) + result = self._run_client_method( + method=self.client.packs.register, + method_kwargs=method_kwargs, + format_func=format_result, + ) # TODO: make sure to return proper model return result @@ -90,42 +88,48 @@ def _get_client(self): client_kwargs = {} if cacert: - client_kwargs['cacert'] = cacert + client_kwargs["cacert"] = cacert - return self._client(base_url=base_url, api_url=api_url, - auth_url=auth_url, token=token, - **client_kwargs) + return self._client( + base_url=base_url, + api_url=api_url, + auth_url=auth_url, + token=token, + **client_kwargs, + ) def _get_st2_urls(self): # First try to use base_url from config. - base_url = self.config.get('base_url', None) - api_url = self.config.get('api_url', None) - auth_url = self.config.get('auth_url', None) + base_url = self.config.get("base_url", None) + api_url = self.config.get("api_url", None) + auth_url = self.config.get("auth_url", None) # not found look up from env vars. Assuming the pack is # configuered to work with current StackStorm instance. if not base_url: - api_url = os.environ.get('ST2_ACTION_API_URL', None) - auth_url = os.environ.get('ST2_ACTION_AUTH_URL', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + auth_url = os.environ.get("ST2_ACTION_AUTH_URL", None) return base_url, api_url, auth_url def _get_auth_token(self): # First try to use auth_token from config. - token = self.config.get('auth_token', None) + token = self.config.get("auth_token", None) # not found look up from env vars. Assuming the pack is # configuered to work with current StackStorm instance. if not token: - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) return token def _get_cacert(self): - cacert = self.config.get('cacert', None) + cacert = self.config.get("cacert", None) return cacert - def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=None): + def _run_client_method( + self, method, method_kwargs, format_func, format_kwargs=None + ): """ Run the provided client method and format the result. @@ -144,8 +148,9 @@ def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=N # This is a work around since the default values can only be strings method_kwargs = filter_none_values(method_kwargs) method_name = method.__name__ - self.logger.debug('Calling client method "%s" with kwargs "%s"' % (method_name, - method_kwargs)) + self.logger.debug( + 'Calling client method "%s" with kwargs "%s"' % (method_name, method_kwargs) + ) result = method(**method_kwargs) result = format_func(result, **format_kwargs or {}) diff --git a/contrib/packs/actions/pack_mgmt/search.py b/contrib/packs/actions/pack_mgmt/search.py index b7cb07f7fc..dd732c1b29 100644 --- a/contrib/packs/actions/pack_mgmt/search.py +++ b/contrib/packs/actions/pack_mgmt/search.py @@ -22,43 +22,51 @@ class PackSearch(Action): def __init__(self, config=None, action_service=None): super(PackSearch, self).__init__(config=config, action_service=action_service) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy """"Search for packs in StackStorm Exchange and other directories.""" + def run(self, query): """ :param query: A word or a phrase to search for :type query: ``str`` """ - self.logger.debug('Proxy config: %s', self.proxy_config) + self.logger.debug("Proxy config: %s", self.proxy_config) return search_pack_index(query, proxy_config=self.proxy_config) diff --git a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py index 23f8a75ef7..bf7a32ed7e 100644 --- a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py +++ b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py @@ -18,9 +18,7 @@ from st2common.runners.base_action import Action from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'SetupVirtualEnvironmentAction' -] +__all__ = ["SetupVirtualEnvironmentAction"] class SetupVirtualEnvironmentAction(Action): @@ -37,42 +35,50 @@ class SetupVirtualEnvironmentAction(Action): creation of the virtual environment and performs an update of the current dependencies as well as an installation of new dependencies """ + def __init__(self, config=None, action_service=None): super(SetupVirtualEnvironmentAction, self).__init__( - config=config, - action_service=action_service) + config=config, action_service=action_service + ) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy def run(self, packs, update=False, no_download=True): """ @@ -84,10 +90,15 @@ def run(self, packs, update=False, no_download=True): """ for pack_name in packs: - setup_pack_virtualenv(pack_name=pack_name, update=update, logger=self.logger, - proxy_config=self.proxy_config, - no_download=no_download) - - message = ('Successfully set up virtualenv for the following packs: %s' % - (', '.join(packs))) + setup_pack_virtualenv( + pack_name=pack_name, + update=update, + logger=self.logger, + proxy_config=self.proxy_config, + no_download=no_download, + ) + + message = "Successfully set up virtualenv for the following packs: %s" % ( + ", ".join(packs) + ) return message diff --git a/contrib/packs/actions/pack_mgmt/show_remote.py b/contrib/packs/actions/pack_mgmt/show_remote.py index ba5bff8141..6b2f655594 100644 --- a/contrib/packs/actions/pack_mgmt/show_remote.py +++ b/contrib/packs/actions/pack_mgmt/show_remote.py @@ -19,11 +19,10 @@ class ShowRemote(Action): """Get detailed information about an available pack from the StackStorm Exchange index""" + def run(self, pack): """ :param pack: Pack Name to get info about :type pack: ``str`` """ - return { - 'pack': get_pack_from_index(pack) - } + return {"pack": get_pack_from_index(pack)} diff --git a/contrib/packs/actions/pack_mgmt/unload.py b/contrib/packs/actions/pack_mgmt/unload.py index c72cdf9ce1..46caf9cc7a 100644 --- a/contrib/packs/actions/pack_mgmt/unload.py +++ b/contrib/packs/actions/pack_mgmt/unload.py @@ -36,31 +36,48 @@ class UnregisterPackAction(BaseAction): def __init__(self, config=None, action_service=None): - super(UnregisterPackAction, self).__init__(config=config, action_service=action_service) + super(UnregisterPackAction, self).__init__( + config=config, action_service=action_service + ) self.initialize() def initialize(self): # 1. Setup db connection - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None - db_setup(cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, - ssl=cfg.CONF.database.ssl, - ssl_keyfile=cfg.CONF.database.ssl_keyfile, - ssl_certfile=cfg.CONF.database.ssl_certfile, - ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, - authentication_mechanism=cfg.CONF.database.authentication_mechanism, - ssl_match_hostname=cfg.CONF.database.ssl_match_hostname) + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) + db_setup( + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ssl=cfg.CONF.database.ssl, + ssl_keyfile=cfg.CONF.database.ssl_keyfile, + ssl_certfile=cfg.CONF.database.ssl_certfile, + ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, + authentication_mechanism=cfg.CONF.database.authentication_mechanism, + ssl_match_hostname=cfg.CONF.database.ssl_match_hostname, + ) def run(self, packs): intersection = BLOCKED_PACKS & frozenset(packs) if len(intersection) > 0: - names = ', '.join(list(intersection)) - raise ValueError('Unregister includes an unregisterable pack - %s.' % (names)) + names = ", ".join(list(intersection)) + raise ValueError( + "Unregister includes an unregisterable pack - %s." % (names) + ) for pack in packs: - self.logger.debug('Removing pack %s.', pack) + self.logger.debug("Removing pack %s.", pack) self._unregister_sensors(pack=pack) self._unregister_trigger_types(pack=pack) self._unregister_triggers(pack=pack) @@ -69,21 +86,27 @@ def run(self, packs): self._unregister_aliases(pack=pack) self._unregister_policies(pack=pack) self._unregister_pack(pack=pack) - self.logger.info('Removed pack %s.', pack) + self.logger.info("Removed pack %s.", pack) def _unregister_sensors(self, pack): return self._delete_pack_db_objects(pack=pack, access_cls=SensorType) def _unregister_trigger_types(self, pack): - deleted_trigger_types_dbs = self._delete_pack_db_objects(pack=pack, access_cls=TriggerType) + deleted_trigger_types_dbs = self._delete_pack_db_objects( + pack=pack, access_cls=TriggerType + ) # 2. Check if deleted trigger is used by any other rules outside this pack for trigger_type_db in deleted_trigger_types_dbs: - rule_dbs = Rule.query(trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack) + rule_dbs = Rule.query( + trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack + ) for rule_db in rule_dbs: - self.logger.warning('Rule "%s" references deleted trigger "%s"' % - (rule_db.name, trigger_type_db.ref)) + self.logger.warning( + 'Rule "%s" references deleted trigger "%s"' + % (rule_db.name, trigger_type_db.ref) + ) return deleted_trigger_types_dbs @@ -136,25 +159,25 @@ def _delete_pack_db_object(self, pack): pack_db = None if not pack_db: - self.logger.exception('Pack DB object not found') + self.logger.exception("Pack DB object not found") return try: Pack.delete(pack_db) except: - self.logger.exception('Failed to remove DB object %s.', pack_db) + self.logger.exception("Failed to remove DB object %s.", pack_db) def _delete_config_schema_db_object(self, pack): try: config_schema_db = ConfigSchema.get_by_pack(value=pack) except StackStormDBObjectNotFoundError: - self.logger.exception('ConfigSchemaDB object not found') + self.logger.exception("ConfigSchemaDB object not found") return try: ConfigSchema.delete(config_schema_db) except: - self.logger.exception('Failed to remove DB object %s.', config_schema_db) + self.logger.exception("Failed to remove DB object %s.", config_schema_db) def _delete_pack_db_objects(self, pack, access_cls): db_objs = access_cls.get_all(pack=pack) @@ -166,6 +189,6 @@ def _delete_pack_db_objects(self, pack, access_cls): access_cls.delete(db_obj) deleted_objs.append(db_obj) except: - self.logger.exception('Failed to remove DB object %s.', db_obj) + self.logger.exception("Failed to remove DB object %s.", db_obj) return deleted_objs diff --git a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py index aedc993f6b..abde082ed3 100644 --- a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py +++ b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py @@ -32,7 +32,7 @@ def run(self, packs_status, packs_list=None): packs = [] for pack_name, status in six.iteritems(packs_status): - if 'success' in status.lower(): + if "success" in status.lower(): packs.append(pack_name) packs_list.extend(packs) diff --git a/contrib/packs/tests/test_action_aliases.py b/contrib/packs/tests/test_action_aliases.py index 858a167751..ecfebe8b68 100644 --- a/contrib/packs/tests/test_action_aliases.py +++ b/contrib/packs/tests/test_action_aliases.py @@ -19,73 +19,65 @@ class PackGet(BaseActionAliasTestCase): action_alias_name = "pack_get" def test_alias_pack_get(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack get st2" - expected_parameters = { - 'pack': "st2" - } + expected_parameters = {"pack": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) class PackInstall(BaseActionAliasTestCase): action_alias_name = "pack_install" def test_alias_pack_install(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] command = "pack install st2" - expected_parameters = { - 'packs': "st2" - } + expected_parameters = {"packs": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) class PackSearch(BaseActionAliasTestCase): action_alias_name = "pack_search" def test_alias_pack_search(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack search st2" - expected_parameters = { - 'query': "st2" - } + expected_parameters = {"query": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) class PackShow(BaseActionAliasTestCase): action_alias_name = "pack_show" def test_alias_pack_show(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack show st2" - expected_parameters = { - 'pack': "st2" - } + expected_parameters = {"pack": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) diff --git a/contrib/packs/tests/test_action_download.py b/contrib/packs/tests/test_action_download.py index 3eeda00886..c29e95fccc 100644 --- a/contrib/packs/tests/test_action_download.py +++ b/contrib/packs/tests/test_action_download.py @@ -22,6 +22,7 @@ import hashlib from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from lockfile import LockFile @@ -46,7 +47,7 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", }, "test2": { "version": "0.5.0", @@ -55,7 +56,7 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", }, "test3": { "version": "0.5.0", @@ -65,16 +66,17 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", }, "test4": { "version": "0.5.0", "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" - } + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", + "description": "another st2 pack to test package management pipeline", + }, } @@ -85,7 +87,7 @@ def mock_is_dir_func(path): """ Mock function which returns True if path ends with .git """ - if path.endswith('.git'): + if path.endswith(".git"): return True return original_is_dir_func(path) @@ -95,9 +97,9 @@ def mock_get_gitref(repo, ref): Mock get_gitref function which return mocked object if ref passed is PACK_INDEX['test']['version'] """ - if PACK_INDEX['test']['version'] in ref: - if ref[0] == 'v': - return mock.MagicMock(hexsha=PACK_INDEX['test']['version']) + if PACK_INDEX["test"]["version"] in ref: + if ref[0] == "v": + return mock.MagicMock(hexsha=PACK_INDEX["test"]["version"]) else: return None elif ref: @@ -106,21 +108,24 @@ def mock_get_gitref(repo, ref): return None -@mock.patch.object(pack_service, 'fetch_pack_index', mock.MagicMock(return_value=(PACK_INDEX, {}))) +@mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) +) class DownloadGitRepoActionTestCase(BaseActionTestCase): action_cls = DownloadGitRepoAction def setUp(self): super(DownloadGitRepoActionTestCase, self).setUp() - clone_from = mock.patch.object(Repo, 'clone_from') + clone_from = mock.patch.object(Repo, "clone_from") self.addCleanup(clone_from.stop) self.clone_from = clone_from.start() self.expand_user_path = tempfile.mkdtemp() - expand_user = mock.patch.object(os.path, 'expanduser', - mock.MagicMock(return_value=self.expand_user_path)) + expand_user = mock.patch.object( + os.path, "expanduser", mock.MagicMock(return_value=self.expand_user_path) + ) self.addCleanup(expand_user.stop) self.expand_user = expand_user.start() @@ -132,8 +137,10 @@ def setUp(self): def side_effect(url, to_path, **kwargs): # Since we have no way to pass pack name here, we would have to derive it from repo url - fixture_name = url.split('/')[-1] - fixture_path = os.path.join(self._get_base_pack_path(), 'tests/fixtures', fixture_name) + fixture_name = url.split("/")[-1] + fixture_path = os.path.join( + self._get_base_pack_path(), "tests/fixtures", fixture_name + ) shutil.copytree(fixture_path, to_path) return self.repo_instance @@ -145,13 +152,15 @@ def tearDown(self): def test_run_pack_download(self): action = self.get_action_instance() - result = action.run(packs=['test'], abs_repo_base=self.repo_base) - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + result = action.run(packs=["test"], abs_repo_base=self.repo_base) + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() - self.assertEqual(result, {'test': 'Success.'}) - self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dir)) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + self.assertEqual(result, {"test": "Success."}) + self.clone_from.assert_called_once_with( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dir), + ) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) self.repo_instance.git.checkout.assert_called() self.repo_instance.git.branch.assert_called() @@ -159,65 +168,81 @@ def test_run_pack_download(self): def test_run_pack_download_dependencies(self): action = self.get_action_instance() - result = action.run(packs=['test'], dependency_list=['test2', 'test4'], - abs_repo_base=self.repo_base) + result = action.run( + packs=["test"], + dependency_list=["test2", "test4"], + abs_repo_base=self.repo_base, + ) temp_dirs = [ - hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest(), - hashlib.md5(PACK_INDEX['test4']['repo_url'].encode()).hexdigest() + hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(), + hashlib.md5(PACK_INDEX["test4"]["repo_url"].encode()).hexdigest(), ] - self.assertEqual(result, {'test2': 'Success.', 'test4': 'Success.'}) - self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[0])) - self.clone_from.assert_any_call(PACK_INDEX['test4']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[1])) + self.assertEqual(result, {"test2": "Success.", "test4": "Success."}) + self.clone_from.assert_any_call( + PACK_INDEX["test2"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[0]), + ) + self.clone_from.assert_any_call( + PACK_INDEX["test4"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[1]), + ) self.assertEqual(self.clone_from.call_count, 2) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml'))) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test4/pack.yaml'))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml"))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test4/pack.yaml"))) def test_run_pack_download_existing_pack(self): action = self.get_action_instance() - action.run(packs=['test'], abs_repo_base=self.repo_base) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + action.run(packs=["test"], abs_repo_base=self.repo_base) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) - result = action.run(packs=['test'], abs_repo_base=self.repo_base) + result = action.run(packs=["test"], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) def test_run_pack_download_multiple_packs(self): action = self.get_action_instance() - result = action.run(packs=['test', 'test2'], abs_repo_base=self.repo_base) + result = action.run(packs=["test", "test2"], abs_repo_base=self.repo_base) temp_dirs = [ - hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest(), - hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest() + hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest(), + hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(), ] - self.assertEqual(result, {'test': 'Success.', 'test2': 'Success.'}) - self.clone_from.assert_any_call(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[0])) - self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[1])) + self.assertEqual(result, {"test": "Success.", "test2": "Success."}) + self.clone_from.assert_any_call( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[0]), + ) + self.clone_from.assert_any_call( + PACK_INDEX["test2"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[1]), + ) self.assertEqual(self.clone_from.call_count, 2) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml'))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml"))) - @mock.patch.object(Repo, 'clone_from') + @mock.patch.object(Repo, "clone_from") def test_run_pack_download_error(self, clone_from): - clone_from.side_effect = Exception('Something went terribly wrong during the clone') + clone_from.side_effect = Exception( + "Something went terribly wrong during the clone" + ) action = self.get_action_instance() - self.assertRaises(Exception, action.run, packs=['test'], abs_repo_base=self.repo_base) + self.assertRaises( + Exception, action.run, packs=["test"], abs_repo_base=self.repo_base + ) def test_run_pack_download_no_tag(self): self.repo_instance.commit.side_effect = BadName action = self.get_action_instance() - self.assertRaises(ValueError, action.run, packs=['test=1.2.3'], - abs_repo_base=self.repo_base) + self.assertRaises( + ValueError, action.run, packs=["test=1.2.3"], abs_repo_base=self.repo_base + ) def test_run_pack_lock_is_already_acquired(self): action = self.get_action_instance() - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() original_acquire = LockFile.acquire @@ -227,15 +252,20 @@ def mock_acquire(self, timeout=None): LockFile.acquire = mock_acquire try: - lock_file = LockFile('/tmp/%s' % (temp_dir)) + lock_file = LockFile("/tmp/%s" % (temp_dir)) # Acquire a lock (file) so acquire inside download will fail - with open(lock_file.lock_file, 'w') as fp: - fp.write('') - - expected_msg = 'Timeout waiting to acquire lock for' - self.assertRaisesRegexp(LockTimeout, expected_msg, action.run, packs=['test'], - abs_repo_base=self.repo_base) + with open(lock_file.lock_file, "w") as fp: + fp.write("") + + expected_msg = "Timeout waiting to acquire lock for" + self.assertRaisesRegexp( + LockTimeout, + expected_msg, + action.run, + packs=["test"], + abs_repo_base=self.repo_base, + ) finally: os.unlink(lock_file.lock_file) LockFile.acquire = original_acquire @@ -243,7 +273,7 @@ def mock_acquire(self, timeout=None): def test_run_pack_lock_is_already_acquired_force_flag(self): # Lock is already acquired but force is true so it should be deleted and released action = self.get_action_instance() - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() original_acquire = LockFile.acquire @@ -253,194 +283,266 @@ def mock_acquire(self, timeout=None): LockFile.acquire = mock_acquire try: - lock_file = LockFile('/tmp/%s' % (temp_dir)) + lock_file = LockFile("/tmp/%s" % (temp_dir)) # Acquire a lock (file) so acquire inside download will fail - with open(lock_file.lock_file, 'w') as fp: - fp.write('') + with open(lock_file.lock_file, "w") as fp: + fp.write("") - result = action.run(packs=['test'], abs_repo_base=self.repo_base, force=True) + result = action.run( + packs=["test"], abs_repo_base=self.repo_base, force=True + ) finally: LockFile.acquire = original_acquire - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) def test_run_pack_download_v_tag(self): def side_effect(ref): - if ref[0] != 'v': + if ref[0] != "v": raise BadName() - return mock.MagicMock(hexsha='abcdef') + return mock.MagicMock(hexsha="abcdef") self.repo_instance.commit.side_effect = side_effect self.repo_instance.git = mock.MagicMock( - branch=(lambda *args: 'master'), - checkout=(lambda *args: True) + branch=(lambda *args: "master"), checkout=(lambda *args: True) ) action = self.get_action_instance() - result = action.run(packs=['test=1.2.3'], abs_repo_base=self.repo_base) + result = action.run(packs=["test=1.2.3"], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) - @mock.patch.object(st2common.util.pack_management, 'get_valid_versions_for_repo', - mock.Mock(return_value=['1.0.0', '2.0.0'])) + @mock.patch.object( + st2common.util.pack_management, + "get_valid_versions_for_repo", + mock.Mock(return_value=["1.0.0", "2.0.0"]), + ) def test_run_pack_download_invalid_version(self): self.repo_instance.commit.side_effect = lambda ref: None action = self.get_action_instance() - expected_msg = ('is not a valid version, hash, tag or branch.*?' - 'Available versions are: 1.0.0, 2.0.0.') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test=2.2.3'], abs_repo_base=self.repo_base) + expected_msg = ( + "is not a valid version, hash, tag or branch.*?" + "Available versions are: 1.0.0, 2.0.0." + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test=2.2.3"], + abs_repo_base=self.repo_base, + ) def test_download_pack_stackstorm_version_identifier_check(self): action = self.get_action_instance() # Version is satisfied - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.0.0' + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.0.0" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base) - self.assertEqual(result['test3'], 'Success.') + result = action.run(packs=["test3"], abs_repo_base=self.repo_base) + self.assertEqual(result["test3"], "Success.") # Pack requires a version which is not satisfied by current StackStorm version - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.2.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "2.2.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.3.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "2.3.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.9' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "1.5.9"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "1.5.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.2.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "2.2.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.3.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "2.3.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.9" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "1.5.9"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "1.5.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) # Version is not met, but force=true parameter is provided - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0' - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=True) - self.assertEqual(result['test3'], 'Success.') + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0" + result = action.run(packs=["test3"], abs_repo_base=self.repo_base, force=True) + self.assertEqual(result["test3"], "Success.") def test_download_pack_python_version_check(self): action = self.get_action_instance() # No python_versions attribute specified in the metadata file - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': [] + "name": "test3", + "stackstorm_version": "", + "python_versions": [], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.11' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.11" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") # Pack works with Python 2.x installation is running 2.7 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.12' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.12" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") # Pack works with Python 2.x installation is running 3.5 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2"], } st2common.util.pack_management.six.PY2 = False st2common.util.pack_management.six.PY3 = True - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.5.2' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.5.2" - expected_msg = (r'Pack "test3" requires Python 2.x, but current Python version is ' - '"3.5.2"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test3'], abs_repo_base=self.repo_base, force=False) + expected_msg = ( + r'Pack "test3" requires Python 2.x, but current Python version is ' + '"3.5.2"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + force=False, + ) # Pack works with Python 3.x installation is running 2.7 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['3'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["3"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.2' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.2" - expected_msg = (r'Pack "test3" requires Python 3.x, but current Python version is ' - '"2.7.2"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test3'], abs_repo_base=self.repo_base, force=False) + expected_msg = ( + r'Pack "test3" requires Python 3.x, but current Python version is ' + '"2.7.2"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + force=False, + ) # Pack works with Python 2.x and 3.x installation is running 2.7 and 3.6.1 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2', '3'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2", "3"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") st2common.util.pack_management.six.PY2 = False st2common.util.pack_management.six.PY3 = True - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.6.1' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.6.1" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") def test_resolve_urls(self): - url = eval_repo_url( - "https://github.com/StackStorm-Exchange/stackstorm-test") + url = eval_repo_url("https://github.com/StackStorm-Exchange/stackstorm-test") self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test") url = eval_repo_url( - "https://github.com/StackStorm-Exchange/stackstorm-test.git") - self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test.git") + "https://github.com/StackStorm-Exchange/stackstorm-test.git" + ) + self.assertEqual( + url, "https://github.com/StackStorm-Exchange/stackstorm-test.git" + ) url = eval_repo_url("StackStorm-Exchange/stackstorm-test") self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test") @@ -460,11 +562,11 @@ def test_resolve_urls(self): url = eval_repo_url("file://localhost/home/vagrant/stackstorm-test") self.assertEqual(url, "file://localhost/home/vagrant/stackstorm-test") - url = eval_repo_url('ssh:///AutomationStackStorm') - self.assertEqual(url, 'ssh:///AutomationStackStorm') + url = eval_repo_url("ssh:///AutomationStackStorm") + self.assertEqual(url, "ssh:///AutomationStackStorm") - url = eval_repo_url('ssh://joe@local/AutomationStackStorm') - self.assertEqual(url, 'ssh://joe@local/AutomationStackStorm') + url = eval_repo_url("ssh://joe@local/AutomationStackStorm") + self.assertEqual(url, "ssh://joe@local/AutomationStackStorm") def test_run_pack_download_edge_cases(self): """ @@ -479,36 +581,35 @@ def test_run_pack_download_edge_cases(self): """ def side_effect(ref): - if ref[0] != 'v': + if ref[0] != "v": raise BadName() - return mock.MagicMock(hexsha='abcdeF') + return mock.MagicMock(hexsha="abcdeF") self.repo_instance.commit.side_effect = side_effect edge_cases = [ - ('master', '1.2.3'), - ('master', 'some-branch'), - ('master', 'default-branch'), - ('master', None), - ('default-branch', '1.2.3'), - ('default-branch', 'some-branch'), - ('default-branch', 'default-branch'), - ('default-branch', None) + ("master", "1.2.3"), + ("master", "some-branch"), + ("master", "default-branch"), + ("master", None), + ("default-branch", "1.2.3"), + ("default-branch", "some-branch"), + ("default-branch", "default-branch"), + ("default-branch", None), ] for default_branch, ref in edge_cases: self.repo_instance.git = mock.MagicMock( - branch=(lambda *args: default_branch), - checkout=(lambda *args: True) + branch=(lambda *args: default_branch), checkout=(lambda *args: True) ) # Set default branch self.repo_instance.active_branch.name = default_branch - self.repo_instance.active_branch.object = 'aBcdef' - self.repo_instance.head.commit = 'aBcdef' + self.repo_instance.active_branch.object = "aBcdef" + self.repo_instance.head.commit = "aBcdef" # Fake gitref object - gitref = mock.MagicMock(hexsha='abcDef') + gitref = mock.MagicMock(hexsha="abcDef") # Fool _get_gitref into working when its ref == our ref def fake_commit(arg_ref): @@ -516,30 +617,34 @@ def fake_commit(arg_ref): return gitref else: raise BadName() + self.repo_instance.commit = fake_commit self.repo_instance.active_branch.object = gitref action = self.get_action_instance() if ref: - packs = ['test=%s' % (ref)] + packs = ["test=%s" % (ref)] else: - packs = ['test'] + packs = ["test"] result = action.run(packs=packs, abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) - @mock.patch('os.path.isdir', mock_is_dir_func) + @mock.patch("os.path.isdir", mock_is_dir_func) def test_run_pack_dowload_local_git_repo_detached_head_state(self): action = self.get_action_instance() - type(self.repo_instance).active_branch = \ - mock.PropertyMock(side_effect=TypeError('detached head')) + type(self.repo_instance).active_branch = mock.PropertyMock( + side_effect=TypeError("detached head") + ) - pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test') + pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test") - result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + result = action.run( + packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base + ) + self.assertEqual(result, {"test": "Success."}) # Verify function has bailed out early self.repo_instance.git.checkout.assert_not_called() @@ -551,41 +656,55 @@ def test_run_pack_download_local_directory(self): # 1. Local directory doesn't exist expected_msg = r'Local pack directory ".*" doesn\'t exist' - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['file://doesnt_exist'], abs_repo_base=self.repo_base) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["file://doesnt_exist"], + abs_repo_base=self.repo_base, + ) # 2. Local pack which is not a git repository - pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test4') + pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test4") - result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test4': 'Success.'}) + result = action.run( + packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base + ) + self.assertEqual(result, {"test4": "Success."}) # Verify pack contents have been copied over - destination_path = os.path.join(self.repo_base, 'test4') + destination_path = os.path.join(self.repo_base, "test4") self.assertTrue(os.path.exists(destination_path)) - self.assertTrue(os.path.exists(os.path.join(destination_path, 'pack.yaml'))) + self.assertTrue(os.path.exists(os.path.join(destination_path, "pack.yaml"))) - @mock.patch('st2common.util.pack_management.get_gitref', mock_get_gitref) + @mock.patch("st2common.util.pack_management.get_gitref", mock_get_gitref) def test_run_pack_download_with_tag(self): action = self.get_action_instance() - result = action.run(packs=['test'], abs_repo_base=self.repo_base) - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + result = action.run(packs=["test"], abs_repo_base=self.repo_base) + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() - self.assertEqual(result, {'test': 'Success.'}) - self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dir)) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + self.assertEqual(result, {"test": "Success."}) + self.clone_from.assert_called_once_with( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dir), + ) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) # Check repo.git.checkout is called three times self.assertEqual(self.repo_instance.git.checkout.call_count, 3) # Check repo.git.checkout called with latest tag or branch - self.assertEqual(PACK_INDEX['test']['version'], - self.repo_instance.git.checkout.call_args_list[1][0][0]) + self.assertEqual( + PACK_INDEX["test"]["version"], + self.repo_instance.git.checkout.call_args_list[1][0][0], + ) # Check repo.git.checkout called with head - self.assertEqual(self.repo_instance.head.reference, - self.repo_instance.git.checkout.call_args_list[2][0][0]) + self.assertEqual( + self.repo_instance.head.reference, + self.repo_instance.git.checkout.call_args_list[2][0][0], + ) self.repo_instance.git.branch.assert_called_with( - '-f', self.repo_instance.head.reference, PACK_INDEX['test']['version']) + "-f", self.repo_instance.head.reference, PACK_INDEX["test"]["version"] + ) diff --git a/contrib/packs/tests/test_action_unload.py b/contrib/packs/tests/test_action_unload.py index 5e642483d4..fc07ff87c3 100644 --- a/contrib/packs/tests/test_action_unload.py +++ b/contrib/packs/tests/test_action_unload.py @@ -20,6 +20,7 @@ from oslo_config import cfg from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from st2common.content.bootstrap import register_content @@ -39,11 +40,11 @@ from pack_mgmt.unload import UnregisterPackAction -__all__ = [ - 'UnloadActionTestCase' -] +__all__ = ["UnloadActionTestCase"] -PACK_PATH_1 = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1') +PACK_PATH_1 = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) class UnloadActionTestCase(BaseActionTestCase, CleanDbTestCase): @@ -64,13 +65,15 @@ def setUp(self): # Register the pack with all the content # TODO: Don't use pack cache - cfg.CONF.set_override(name='all', override=True, group='register') - cfg.CONF.set_override(name='pack', override=PACK_PATH_1, group='register') - cfg.CONF.set_override(name='no_fail_on_failure', override=True, group='register') + cfg.CONF.set_override(name="all", override=True, group="register") + cfg.CONF.set_override(name="pack", override=PACK_PATH_1, group="register") + cfg.CONF.set_override( + name="no_fail_on_failure", override=True, group="register" + ) register_content() def test_run(self): - pack = 'dummy_pack_1' + pack = "dummy_pack_1" # Verify all the resources are there pack_dbs = Pack.query(ref=pack) diff --git a/contrib/packs/tests/test_get_pack_dependencies.py b/contrib/packs/tests/test_get_pack_dependencies.py index e047d7fca4..a90f940638 100644 --- a/contrib/packs/tests/test_get_pack_dependencies.py +++ b/contrib/packs/tests/test_get_pack_dependencies.py @@ -21,21 +21,20 @@ from pack_mgmt.get_pack_dependencies import GetPackDependencies -UNINSTALLED_PACK = 'uninstalled_pack' +UNINSTALLED_PACK = "uninstalled_pack" UNINSTALLED_PACKS = [ UNINSTALLED_PACK, - 'https://github.com/StackStorm-Exchange/stackstorm-pack1', - 'https://github.com/StackStorm-Exchange/stackstorm-pack2.git', - 'https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1', - 'StackStorm-Exchange/stackstorm-pack4', - 'git://StackStorm-Exchange/stackstorm-pack5=v2.1.1', - 'git://StackStorm-Exchange/stackstorm-pack6.git', - 'git@github.com:foo/pack7.git' - 'git@github.com:foo/pack8.git=v3.2.1', - 'file:///home/vagrant/stackstorm-pack9', - 'file://localhost/home/vagrant/stackstorm-pack10', - 'ssh:///AutomationStackStorm11', - 'ssh://joe@local/AutomationStackStorm12' + "https://github.com/StackStorm-Exchange/stackstorm-pack1", + "https://github.com/StackStorm-Exchange/stackstorm-pack2.git", + "https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1", + "StackStorm-Exchange/stackstorm-pack4", + "git://StackStorm-Exchange/stackstorm-pack5=v2.1.1", + "git://StackStorm-Exchange/stackstorm-pack6.git", + "git@github.com:foo/pack7.git" "git@github.com:foo/pack8.git=v3.2.1", + "file:///home/vagrant/stackstorm-pack9", + "file://localhost/home/vagrant/stackstorm-pack10", + "ssh:///AutomationStackStorm11", + "ssh://joe@local/AutomationStackStorm12", ] DOWNLOADED_OR_INSTALLED_PACK_METAdATA = { @@ -58,7 +57,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ['uninstalled_pack', 'no_dependencies'] + "dependencies": ["uninstalled_pack", "no_dependencies"], }, # List of uninstalled dependency packs. "test3": { @@ -70,7 +69,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": UNINSTALLED_PACKS + "dependencies": UNINSTALLED_PACKS, }, # One conflict pack with existing pack. "test4": { @@ -82,9 +81,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": [ - "test2=v0.4.0" - ] + "dependencies": ["test2=v0.4.0"], }, # One uninstalled conflict pack. "test5": { @@ -93,9 +90,10 @@ "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ["uninstalled_pack=v0.4.0"] + "dependencies": ["uninstalled_pack=v0.4.0"], }, # One dependency pack without version. It is not checked against conflict. "test6": { @@ -104,10 +102,11 @@ "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ["test2"] - } + "dependencies": ["test2"], + }, } @@ -119,7 +118,7 @@ def mock_get_dependency_list(pack): if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA: metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack] - dependencies = metadata.get('dependencies', None) + dependencies = metadata.get("dependencies", None) return dependencies @@ -132,13 +131,15 @@ def mock_get_pack_version(pack): if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA: metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack] - version = metadata.get('version', None) + version = metadata.get("version", None) return version -@mock.patch('pack_mgmt.get_pack_dependencies.get_dependency_list', mock_get_dependency_list) -@mock.patch('pack_mgmt.get_pack_dependencies.get_pack_version', mock_get_pack_version) +@mock.patch( + "pack_mgmt.get_pack_dependencies.get_dependency_list", mock_get_dependency_list +) +@mock.patch("pack_mgmt.get_pack_dependencies.get_pack_version", mock_get_pack_version) class GetPackDependenciesTestCase(BaseActionTestCase): action_cls = GetPackDependencies @@ -167,9 +168,9 @@ def test_run_get_pack_dependencies_with_failed_packs_status(self): nested = 2 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], []) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], []) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self): action = self.get_action_instance() @@ -177,9 +178,9 @@ def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self): nested = 2 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_no_dependency(self): action = self.get_action_instance() @@ -187,9 +188,9 @@ def test_run_get_pack_dependencies_with_no_dependency(self): nested = 3 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], []) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], []) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependency(self): action = self.get_action_instance() @@ -197,9 +198,9 @@ def test_run_get_pack_dependencies_with_dependency(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependencies(self): action = self.get_action_instance() @@ -207,9 +208,9 @@ def test_run_get_pack_dependencies_with_dependencies(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], UNINSTALLED_PACKS) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], UNINSTALLED_PACKS) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_existing_pack_conflict(self): action = self.get_action_instance() @@ -217,9 +218,9 @@ def test_run_get_pack_dependencies_with_existing_pack_conflict(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], ['test2=v0.4.0']) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], ["test2=v0.4.0"]) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependency_conflict(self): action = self.get_action_instance() @@ -227,9 +228,9 @@ def test_run_get_pack_dependencies_with_dependency_conflict(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], ['uninstalled_pack']) - self.assertEqual(result['conflict_list'], ['uninstalled_pack=v0.4.0']) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], ["uninstalled_pack"]) + self.assertEqual(result["conflict_list"], ["uninstalled_pack=v0.4.0"]) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_no_version(self): action = self.get_action_instance() @@ -237,6 +238,6 @@ def test_run_get_pack_dependencies_with_no_version(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) diff --git a/contrib/packs/tests/test_get_pack_warnings.py b/contrib/packs/tests/test_get_pack_warnings.py index 49e2d920a8..3eac7ba356 100644 --- a/contrib/packs/tests/test_get_pack_warnings.py +++ b/contrib/packs/tests/test_get_pack_warnings.py @@ -29,7 +29,7 @@ "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", "description": "st2 pack to test package management pipeline", - "python_versions": ["2","3"], + "python_versions": ["2", "3"], }, # Python 3 "py3": { @@ -72,10 +72,11 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "python_versions": ["2"] - } + "python_versions": ["2"], + }, } + def mock_get_pack_basepath(pack): """ Mock get_pack_basepath function which just returns pack n ame @@ -94,8 +95,8 @@ def mock_get_pack_metadata(pack_dir): return metadata -@mock.patch('pack_mgmt.get_pack_warnings.get_pack_base_path', mock_get_pack_basepath) -@mock.patch('pack_mgmt.get_pack_warnings.get_pack_metadata', mock_get_pack_metadata) +@mock.patch("pack_mgmt.get_pack_warnings.get_pack_base_path", mock_get_pack_basepath) +@mock.patch("pack_mgmt.get_pack_warnings.get_pack_metadata", mock_get_pack_metadata) class GetPackWarningsTestCase(BaseActionTestCase): action_cls = GetPackWarnings @@ -107,15 +108,15 @@ def test_run_get_pack_warnings_py3_pack(self): packs_status = {"py3": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_py2_pack(self): action = self.get_action_instance() packs_status = {"py2": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(len(result['warning_list']), 1) - warning = result['warning_list'][0] + self.assertEqual(len(result["warning_list"]), 1) + warning = result["warning_list"][0] self.assertTrue("DEPRECATION WARNING" in warning) self.assertTrue("Pack py2 only supports Python 2" in warning) @@ -124,28 +125,32 @@ def test_run_get_pack_warnings_py23_pack(self): packs_status = {"py23": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_pynone_pack(self): action = self.get_action_instance() packs_status = {"pynone": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_multiple_pack(self): action = self.get_action_instance() - packs_status = {"py2": "Success.", - "py23": "Success.", - "py22": "Success."} + packs_status = {"py2": "Success.", "py23": "Success.", "py22": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(len(result['warning_list']), 2) - warning0 = result['warning_list'][0] - warning1 = result['warning_list'][1] + self.assertEqual(len(result["warning_list"]), 2) + warning0 = result["warning_list"][0] + warning1 = result["warning_list"][1] self.assertTrue("DEPRECATION WARNING" in warning0) self.assertTrue("DEPRECATION WARNING" in warning1) - self.assertTrue(("Pack py2 only supports Python 2" in warning0 and - "Pack py22 only supports Python 2" in warning1) or - ("Pack py22 only supports Python 2" in warning0 and - "Pack py2 only supports Python 2" in warning1)) + self.assertTrue( + ( + "Pack py2 only supports Python 2" in warning0 + and "Pack py22 only supports Python 2" in warning1 + ) + or ( + "Pack py22 only supports Python 2" in warning0 + and "Pack py2 only supports Python 2" in warning1 + ) + ) diff --git a/contrib/packs/tests/test_virtualenv_setup_prerun.py b/contrib/packs/tests/test_virtualenv_setup_prerun.py index 63b27410f6..0097ecd8fe 100644 --- a/contrib/packs/tests/test_virtualenv_setup_prerun.py +++ b/contrib/packs/tests/test_virtualenv_setup_prerun.py @@ -28,21 +28,26 @@ def setUp(self): def test_run_with_pack_list(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'}, - packs_list=['test3', 'test4']) + result = action.run( + packs_status={"test1": "Success.", "test2": "Success."}, + packs_list=["test3", "test4"], + ) - self.assertEqual(result, ['test3', 'test4', 'test1', 'test2']) + self.assertEqual(result, ["test3", "test4", "test1", "test2"]) def test_run_with_none_pack_list(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'}, - packs_list=None) + result = action.run( + packs_status={"test1": "Success.", "test2": "Success."}, packs_list=None + ) - self.assertEqual(result, ['test1', 'test2']) + self.assertEqual(result, ["test1", "test2"]) def test_run_with_failed_status(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Failed.', 'test2': 'Success.'}, - packs_list=['test3', 'test4']) + result = action.run( + packs_status={"test1": "Failed.", "test2": "Success."}, + packs_list=["test3", "test4"], + ) - self.assertEqual(result, ['test3', 'test4', 'test2']) + self.assertEqual(result, ["test3", "test4", "test2"]) diff --git a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py +++ b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py index 39cb873136..e71c12d004 100644 --- a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py +++ b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py @@ -50,26 +50,16 @@ from st2common.util.config_loader import get_config from st2common.util.ujson import fast_deepcopy -__all__ = [ - 'ActionChainRunner', - 'ChainHolder', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ActionChainRunner", "ChainHolder", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RESULTS_KEY = '__results' -JINJA_START_MARKERS = [ - '{{', - '{%' -] -PUBLISHED_VARS_KEY = 'published' +RESULTS_KEY = "__results" +JINJA_START_MARKERS = ["{{", "{%"] +PUBLISHED_VARS_KEY = "published" class ChainHolder(object): - def __init__(self, chainspec, chainname): self.actionchain = actionchain.ActionChain(**chainspec) self.chainname = chainname @@ -78,17 +68,21 @@ def __init__(self, chainspec, chainname): default = self._get_default(self.actionchain) self.actionchain.default = default - LOG.debug('Using %s as default for %s.', self.actionchain.default, self.chainname) + LOG.debug( + "Using %s as default for %s.", self.actionchain.default, self.chainname + ) if not self.actionchain.default: - raise Exception('Failed to find default node in %s.' % (self.chainname)) + raise Exception("Failed to find default node in %s." % (self.chainname)) self.vars = {} def init_vars(self, action_parameters, action_context=None): if self.actionchain.vars: - self.vars = self._get_rendered_vars(self.actionchain.vars, - action_parameters=action_parameters, - action_context=action_context) + self.vars = self._get_rendered_vars( + self.actionchain.vars, + action_parameters=action_parameters, + action_context=action_context, + ) def restore_vars(self, ctx_vars): self.vars.update(fast_deepcopy(ctx_vars)) @@ -107,28 +101,37 @@ def validate(self): on_failure_node_name = node.on_failure # Check "on-success" path - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=on_success_node_name) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=on_success_node_name + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "on-success" in ' - 'task "%s".' % (on_success_node_name, node.name)) + msg = ( + 'Unable to find node with name "%s" referenced in "on-success" in ' + 'task "%s".' % (on_success_node_name, node.name) + ) raise ValueError(msg) # Check "on-failure" path - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=on_failure_node_name) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=on_failure_node_name + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "on-failure" in ' - 'task "%s".' % (on_failure_node_name, node.name)) + msg = ( + 'Unable to find node with name "%s" referenced in "on-failure" in ' + 'task "%s".' % (on_failure_node_name, node.name) + ) raise ValueError(msg) # check if node specified in default is valid. if self.actionchain.default: - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=self.actionchain.default) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=self.actionchain.default + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "default".' % - self.actionchain.default) + msg = ( + 'Unable to find node with name "%s" referenced in "default".' + % self.actionchain.default + ) raise ValueError(msg) return True @@ -147,8 +150,12 @@ def _get_default(action_chain): # 2. There are no fragments in the chain. all_nodes = ChainHolder._get_all_nodes(action_chain=action_chain) node_names = set(all_nodes) - on_success_nodes = ChainHolder._get_all_on_success_nodes(action_chain=action_chain) - on_failure_nodes = ChainHolder._get_all_on_failure_nodes(action_chain=action_chain) + on_success_nodes = ChainHolder._get_all_on_success_nodes( + action_chain=action_chain + ) + on_failure_nodes = ChainHolder._get_all_on_failure_nodes( + action_chain=action_chain + ) referenced_nodes = on_success_nodes | on_failure_nodes possible_default_nodes = node_names - referenced_nodes if possible_default_nodes: @@ -210,19 +217,25 @@ def _get_rendered_vars(vars, action_parameters, action_context): return {} action_context = action_context or {} - user = action_context.get('user', cfg.CONF.system_user.user) + user = action_context.get("user", cfg.CONF.system_user.user) context = {} - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE), - kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( - scope=kv_constants.FULL_USER_SCOPE, user=user) + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ), + kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( + scope=kv_constants.FULL_USER_SCOPE, user=user + ), + } } - }) + ) context.update(action_parameters) - LOG.info('Rendering action chain vars. Mapping = %s; Context = %s', vars, context) + LOG.info( + "Rendering action chain vars. Mapping = %s; Context = %s", vars, context + ) return jinja_utils.render_values(mapping=vars, context=context) def get_node(self, node_name=None, raise_on_failure=False): @@ -233,22 +246,22 @@ def get_node(self, node_name=None, raise_on_failure=False): return node if raise_on_failure: raise runner_exc.ActionRunnerException( - 'Unable to find node with name "%s".' % (node_name)) + 'Unable to find node with name "%s".' % (node_name) + ) return None - def get_next_node(self, curr_node_name=None, condition='on-success'): + def get_next_node(self, curr_node_name=None, condition="on-success"): if not curr_node_name: return self.get_node(self.actionchain.default) current_node = self.get_node(curr_node_name) - if condition == 'on-success': + if condition == "on-success": return self.get_node(current_node.on_success, raise_on_failure=True) - elif condition == 'on-failure': + elif condition == "on-failure": return self.get_node(current_node.on_failure, raise_on_failure=True) - raise runner_exc.ActionRunnerException('Unknown condition %s.' % condition) + raise runner_exc.ActionRunnerException("Unknown condition %s." % condition) class ActionChainRunner(ActionRunner): - def __init__(self, runner_id): super(ActionChainRunner, self).__init__(runner_id=runner_id) self.chain_holder = None @@ -261,16 +274,20 @@ def pre_run(self): super(ActionChainRunner, self).pre_run() chainspec_file = self.entry_point - LOG.debug('Reading action chain from %s for action %s.', chainspec_file, - self.action) + LOG.debug( + "Reading action chain from %s for action %s.", chainspec_file, self.action + ) try: - chainspec = self._meta_loader.load(file_path=chainspec_file, - expected_type=dict) + chainspec = self._meta_loader.load( + file_path=chainspec_file, expected_type=dict + ) except Exception as e: - message = ('Failed to parse action chain definition from "%s": %s' % - (chainspec_file, six.text_type(e))) - LOG.exception('Failed to load action chain definition.') + message = 'Failed to parse action chain definition from "%s": %s' % ( + chainspec_file, + six.text_type(e), + ) + LOG.exception("Failed to load action chain definition.") raise runner_exc.ActionRunnerPreRunError(message) try: @@ -279,20 +296,22 @@ def pre_run(self): # preserve the whole nasty jsonschema message as that is better to get to the # root cause message = six.text_type(e) - LOG.exception('Failed to instantiate ActionChain.') + LOG.exception("Failed to instantiate ActionChain.") raise runner_exc.ActionRunnerPreRunError(message) except Exception as e: message = six.text_type(e) - LOG.exception('Failed to instantiate ActionChain.') + LOG.exception("Failed to instantiate ActionChain.") raise runner_exc.ActionRunnerPreRunError(message) # Runner attributes are set lazily. So these steps # should happen outside the constructor. - if getattr(self, 'liveaction', None): - self._chain_notify = getattr(self.liveaction, 'notify', None) + if getattr(self, "liveaction", None): + self._chain_notify = getattr(self.liveaction, "notify", None) if self.runner_parameters: - self._skip_notify_tasks = self.runner_parameters.get('skip_notify', []) - self._display_published = self.runner_parameters.get('display_published', True) + self._skip_notify_tasks = self.runner_parameters.get("skip_notify", []) + self._display_published = self.runner_parameters.get( + "display_published", True + ) # Perform some pre-run chain validation try: @@ -308,34 +327,38 @@ def cancel(self): # Identify the list of action executions that are workflows and cascade pause. for child_exec_id in self.execution.children: child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True) - if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and - child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES): + if ( + child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES + and child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES + ): action_service.request_cancellation( - LiveAction.get(id=child_exec.liveaction['id']), - self.context.get('user', None) + LiveAction.get(id=child_exec.liveaction["id"]), + self.context.get("user", None), ) return ( action_constants.LIVEACTION_STATUS_CANCELING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def pause(self): # Identify the list of action executions that are workflows and cascade pause. for child_exec_id in self.execution.children: child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True) - if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and - child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING): + if ( + child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES + and child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING + ): action_service.request_pause( - LiveAction.get(id=child_exec.liveaction['id']), - self.context.get('user', None) + LiveAction.get(id=child_exec.liveaction["id"]), + self.context.get("user", None), ) return ( action_constants.LIVEACTION_STATUS_PAUSING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def resume(self): @@ -344,7 +367,7 @@ def resume(self): self.runner_type.runner_parameters, self.action.parameters, self.liveaction.parameters, - self.liveaction.context + self.liveaction.context, ) # Assign runner parameters needed for pre-run. @@ -357,9 +380,7 @@ def resume(self): # Change the status of the liveaction from resuming to running. self.liveaction = action_service.update_status( - self.liveaction, - action_constants.LIVEACTION_STATUS_RUNNING, - publish=False + self.liveaction, action_constants.LIVEACTION_STATUS_RUNNING, publish=False ) # Run the action chain. @@ -370,13 +391,15 @@ def _run_chain(self, action_parameters, resuming=False): chain_status = action_constants.LIVEACTION_STATUS_FAILED # Result holds the final result that the chain store in the database. - result = {'tasks': []} + result = {"tasks": []} # Save published variables into the result if specified. if self._display_published: result[PUBLISHED_VARS_KEY] = {} - context_result = {} # Holds result which is used for the template context purposes + context_result = ( + {} + ) # Holds result which is used for the template context purposes top_level_error = None # Stores a reference to a top level error action_node = None last_task = None @@ -384,11 +407,12 @@ def _run_chain(self, action_parameters, resuming=False): try: # Initialize vars with the action parameters. # This allows action parameers to be referenced from vars. - self.chain_holder.init_vars(action_parameters=action_parameters, - action_context=self.context) + self.chain_holder.init_vars( + action_parameters=action_parameters, action_context=self.context + ) except Exception as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed initializing ``vars`` in chain.' + m = "Failed initializing ``vars`` in chain." LOG.exception(m) top_level_error = self._format_error(e, m) result.update(top_level_error) @@ -397,28 +421,32 @@ def _run_chain(self, action_parameters, resuming=False): # Restore state on resuming an existing chain execution. if resuming: # Restore vars is any from the liveaction. - ctx_vars = self.liveaction.context.pop('vars', {}) + ctx_vars = self.liveaction.context.pop("vars", {}) self.chain_holder.restore_vars(ctx_vars) # Restore result if any from the liveaction. - if self.liveaction and hasattr(self.liveaction, 'result') and self.liveaction.result: + if ( + self.liveaction + and hasattr(self.liveaction, "result") + and self.liveaction.result + ): result = self.liveaction.result # Initialize or rebuild existing context_result from liveaction # which holds the result used for resolving context in Jinja template. - for task in result.get('tasks', []): - context_result[task['name']] = task['result'] + for task in result.get("tasks", []): + context_result[task["name"]] = task["result"] # Restore or initialize the top_level_error # that stores a reference to a top level error. - if 'error' in result or 'traceback' in result: + if "error" in result or "traceback" in result: top_level_error = { - 'error': result.get('error'), - 'traceback': result.get('traceback') + "error": result.get("error"), + "traceback": result.get("traceback"), } # If there are no executed tasks in the chain, then get the first node. - if len(result['tasks']) <= 0: + if len(result["tasks"]) <= 0: try: action_node = self.chain_holder.get_next_node() except Exception as e: @@ -433,21 +461,24 @@ def _run_chain(self, action_parameters, resuming=False): # Otherwise, figure out the last task executed and # its state to determine where to begin executing. else: - last_task = result['tasks'][-1] - action_node = self.chain_holder.get_node(last_task['name']) - liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id']) + last_task = result["tasks"][-1] + action_node = self.chain_holder.get_node(last_task["name"]) + liveaction = action_db_util.get_liveaction_by_id(last_task["liveaction_id"]) # If the liveaction of the last task has changed, update the result entry. - if liveaction.status != last_task['state']: + if liveaction.status != last_task["state"]: updated_task_result = self._get_updated_action_exec_result( - action_node, liveaction, last_task) - del result['tasks'][-1] - result['tasks'].append(updated_task_result) + action_node, liveaction, last_task + ) + del result["tasks"][-1] + result["tasks"].append(updated_task_result) # Also need to update context_result so the updated result # is available to Jinja expressions - updated_task_name = updated_task_result['name'] - context_result[updated_task_name]['result'] = updated_task_result['result'] + updated_task_name = updated_task_result["name"] + context_result[updated_task_name]["result"] = updated_task_result[ + "result" + ] # If the last task was canceled, then canceled the chain altogether. if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: @@ -463,42 +494,52 @@ def _run_chain(self, action_parameters, resuming=False): if liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED: chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_node = self.chain_holder.get_next_node( - last_task['name'], condition='on-success') + last_task["name"], condition="on-success" + ) # If the last task failed, then get the next on-failure action node. if liveaction.status in action_constants.LIVEACTION_FAILED_STATES: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - last_task['name'], condition='on-failure') + last_task["name"], condition="on-failure" + ) # Setup parent context. - parent_context = { - 'execution_id': self.execution_id - } + parent_context = {"execution_id": self.execution_id} - if getattr(self.liveaction, 'context', None): + if getattr(self.liveaction, "context", None): parent_context.update(self.liveaction.context) # Run the action chain until there are no more tasks. while action_node: error = None liveaction = None - last_task = result['tasks'][-1] if len(result['tasks']) > 0 else None + last_task = result["tasks"][-1] if len(result["tasks"]) > 0 else None created_at = date_utils.get_datetime_utc_now() try: # If last task was paused, then fetch the liveaction and resume it first. - if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED: - liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id']) - del result['tasks'][-1] + if ( + last_task + and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED + ): + liveaction = action_db_util.get_liveaction_by_id( + last_task["liveaction_id"] + ) + del result["tasks"][-1] else: liveaction = self._get_next_action( - action_node=action_node, parent_context=parent_context, - action_params=action_parameters, context_result=context_result) + action_node=action_node, + parent_context=parent_context, + action_params=action_parameters, + context_result=context_result, + ) except action_exc.InvalidActionReferencedException as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = ('Failed to run task "%s". Action with reference "%s" doesn\'t exist.' % - (action_node.name, action_node.ref)) + m = ( + 'Failed to run task "%s". Action with reference "%s" doesn\'t exist.' + % (action_node.name, action_node.ref) + ) LOG.exception(m) top_level_error = self._format_error(e, m) break @@ -506,24 +547,41 @@ def _run_chain(self, action_parameters, resuming=False): # Rendering parameters failed before we even got to running this action, # abort and fail the whole action chain chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed to run task "%s". Parameter rendering failed.' % action_node.name + m = ( + 'Failed to run task "%s". Parameter rendering failed.' + % action_node.name + ) LOG.exception(m) top_level_error = self._format_error(e, m) break except db_exc.StackStormDBObjectNotFoundError as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed to resume task "%s". Unable to find liveaction.' % action_node.name + m = ( + 'Failed to resume task "%s". Unable to find liveaction.' + % action_node.name + ) LOG.exception(m) top_level_error = self._format_error(e, m) break try: # If last task was paused, then fetch the liveaction and resume it first. - if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED: - LOG.info('Resume task %s for chain %s.', action_node.name, self.liveaction.id) + if ( + last_task + and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED + ): + LOG.info( + "Resume task %s for chain %s.", + action_node.name, + self.liveaction.id, + ) liveaction = self._resume_action(liveaction) else: - LOG.info('Run task %s for chain %s.', action_node.name, self.liveaction.id) + LOG.info( + "Run task %s for chain %s.", + action_node.name, + self.liveaction.id, + ) liveaction = self._run_action(liveaction) except Exception as e: # Save the traceback and error message @@ -537,9 +595,12 @@ def _run_chain(self, action_parameters, resuming=False): # Render and publish variables rendered_publish_vars = ActionChainRunner._render_publish_vars( - action_node=action_node, action_parameters=action_parameters, - execution_result=liveaction.result, previous_execution_results=context_result, - chain_vars=self.chain_holder.vars) + action_node=action_node, + action_parameters=action_parameters, + execution_result=liveaction.result, + previous_execution_results=context_result, + chain_vars=self.chain_holder.vars, + ) if rendered_publish_vars: self.chain_holder.vars.update(rendered_publish_vars) @@ -550,49 +611,68 @@ def _run_chain(self, action_parameters, resuming=False): updated_at = date_utils.get_datetime_utc_now() task_result = self._format_action_exec_result( - action_node, - liveaction, - created_at, - updated_at, - error=error + action_node, liveaction, created_at, updated_at, error=error ) - result['tasks'].append(task_result) + result["tasks"].append(task_result) try: if not liveaction: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_TIMED_OUT: + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status + == action_constants.LIVEACTION_STATUS_TIMED_OUT + ): chain_status = action_constants.LIVEACTION_STATUS_TIMED_OUT action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: - LOG.info('Chain execution (%s) canceled because task "%s" is canceled.', - self.liveaction_id, action_node.name) + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED + ): + LOG.info( + 'Chain execution (%s) canceled because task "%s" is canceled.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_CANCELED action_node = None elif liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED: - LOG.info('Chain execution (%s) paused because task "%s" is paused.', - self.liveaction_id, action_node.name) + LOG.info( + 'Chain execution (%s) paused because task "%s" is paused.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() action_node = None - elif liveaction.status == action_constants.LIVEACTION_STATUS_PENDING: - LOG.info('Chain execution (%s) paused because task "%s" is pending.', - self.liveaction_id, action_node.name) + elif ( + liveaction.status == action_constants.LIVEACTION_STATUS_PENDING + ): + LOG.info( + 'Chain execution (%s) paused because task "%s" is pending.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() action_node = None elif liveaction.status in action_constants.LIVEACTION_FAILED_STATES: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED: + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status + == action_constants.LIVEACTION_STATUS_SUCCEEDED + ): chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-success') + action_node.name, condition="on-success" + ) else: action_node = None except Exception as e: @@ -604,12 +684,12 @@ def _run_chain(self, action_parameters, resuming=False): break if action_service.is_action_canceled_or_canceling(self.liveaction.id): - LOG.info('Chain execution (%s) canceled by user.', self.liveaction.id) + LOG.info("Chain execution (%s) canceled by user.", self.liveaction.id) chain_status = action_constants.LIVEACTION_STATUS_CANCELED return (chain_status, result, None) if action_service.is_action_paused_or_pausing(self.liveaction.id): - LOG.info('Chain execution (%s) paused by user.', self.liveaction.id) + LOG.info("Chain execution (%s) paused by user.", self.liveaction.id) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() return (chain_status, result, self.liveaction.context) @@ -621,17 +701,22 @@ def _run_chain(self, action_parameters, resuming=False): def _format_error(self, e, msg): return { - 'error': '%s. %s' % (msg, six.text_type(e)), - 'traceback': traceback.format_exc(10) + "error": "%s. %s" % (msg, six.text_type(e)), + "traceback": traceback.format_exc(10), } def _save_vars(self): # Save the context vars in the liveaction context. - self.liveaction.context['vars'] = self.chain_holder.vars + self.liveaction.context["vars"] = self.chain_holder.vars @staticmethod - def _render_publish_vars(action_node, action_parameters, execution_result, - previous_execution_results, chain_vars): + def _render_publish_vars( + action_node, + action_parameters, + execution_result, + previous_execution_results, + chain_vars, + ): """ If no output is specified on the action_node the output is the entire execution_result. If any output is specified then only those variables are published as output of an @@ -649,36 +734,48 @@ def _render_publish_vars(action_node, action_parameters, execution_result, context.update(chain_vars) context.update({RESULTS_KEY: previous_execution_results}) - context.update({ - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.SYSTEM_SCOPE) - }) - - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { + context.update( + { kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE) + scope=kv_constants.SYSTEM_SCOPE + ) } - }) + ) + + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ) + } + } + ) try: - rendered_result = jinja_utils.render_values(mapping=action_node.publish, - context=context) + rendered_result = jinja_utils.render_values( + mapping=action_node.publish, context=context + ) except Exception as e: - key = getattr(e, 'key', None) - value = getattr(e, 'value', None) - msg = ('Failed rendering value for publish parameter "%s" in task "%s" ' - '(template string=%s): %s' % (key, action_node.name, value, six.text_type(e))) + key = getattr(e, "key", None) + value = getattr(e, "value", None) + msg = ( + 'Failed rendering value for publish parameter "%s" in task "%s" ' + "(template string=%s): %s" + % (key, action_node.name, value, six.text_type(e)) + ) raise action_exc.ParameterRenderingFailedException(msg) return rendered_result @staticmethod - def _resolve_params(action_node, original_parameters, results, chain_vars, chain_context): + def _resolve_params( + action_node, original_parameters, results, chain_vars, chain_context + ): # setup context with original parameters and the intermediate results. - chain_parent = chain_context.get('parent', {}) - pack = chain_parent.get('pack') - user = chain_parent.get('user') + chain_parent = chain_context.get("parent", {}) + pack = chain_parent.get("pack") + user = chain_parent.get("user") config = get_config(pack, user) @@ -688,34 +785,47 @@ def _resolve_params(action_node, original_parameters, results, chain_vars, chain context.update(chain_vars) context.update({RESULTS_KEY: results}) - context.update({ - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.SYSTEM_SCOPE) - }) - - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { + context.update( + { kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE) + scope=kv_constants.SYSTEM_SCOPE + ) } - }) + ) + + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ) + } + } + ) context.update({action_constants.ACTION_CONTEXT_KV_PREFIX: chain_context}) context.update({pack_constants.PACK_CONFIG_CONTEXT_KV_PREFIX: config}) try: - rendered_params = jinja_utils.render_values(mapping=action_node.get_parameters(), - context=context) + rendered_params = jinja_utils.render_values( + mapping=action_node.get_parameters(), context=context + ) except Exception as e: LOG.exception('Jinja rendering for parameter "%s" failed.' % (e.key)) - key = getattr(e, 'key', None) - value = getattr(e, 'value', None) - msg = ('Failed rendering value for action parameter "%s" in task "%s" ' - '(template string=%s): %s') % (key, action_node.name, value, six.text_type(e)) + key = getattr(e, "key", None) + value = getattr(e, "value", None) + msg = ( + 'Failed rendering value for action parameter "%s" in task "%s" ' + "(template string=%s): %s" + ) % (key, action_node.name, value, six.text_type(e)) raise action_exc.ParameterRenderingFailedException(msg) - LOG.debug('Rendered params: %s: Type: %s', rendered_params, type(rendered_params)) + LOG.debug( + "Rendered params: %s: Type: %s", rendered_params, type(rendered_params) + ) return rendered_params - def _get_next_action(self, action_node, parent_context, action_params, context_result): + def _get_next_action( + self, action_node, parent_context, action_params, context_result + ): # Verify that the referenced action exists # TODO: We do another lookup in cast_param, refactor to reduce number of lookups task_name = action_node.name @@ -723,18 +833,25 @@ def _get_next_action(self, action_node, parent_context, action_params, context_r action_db = action_db_util.get_action_by_ref(ref=action_ref) if not action_db: - error = 'Task :: %s - Action with ref %s not registered.' % (task_name, action_ref) + error = "Task :: %s - Action with ref %s not registered." % ( + task_name, + action_ref, + ) raise action_exc.InvalidActionReferencedException(error) resolved_params = ActionChainRunner._resolve_params( - action_node=action_node, original_parameters=action_params, - results=context_result, chain_vars=self.chain_holder.vars, - chain_context={'parent': parent_context}) + action_node=action_node, + original_parameters=action_params, + results=context_result, + chain_vars=self.chain_holder.vars, + chain_context={"parent": parent_context}, + ) liveaction = self._build_liveaction_object( action_node=action_node, resolved_params=resolved_params, - parent_context=parent_context) + parent_context=parent_context, + ) return liveaction @@ -747,13 +864,16 @@ def _run_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0): liveaction, _ = action_service.request(liveaction) except Exception as e: liveaction.status = action_constants.LIVEACTION_STATUS_FAILED - LOG.exception('Failed to schedule liveaction.') + LOG.exception("Failed to schedule liveaction.") raise e - while (wait_for_completion and liveaction.status not in ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED, - action_constants.LIVEACTION_STATUS_PENDING])): + while wait_for_completion and liveaction.status not in ( + action_constants.LIVEACTION_COMPLETED_STATES + + [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PENDING, + ] + ): eventlet.sleep(sleep_delay) liveaction = action_db_util.get_liveaction_by_id(liveaction.id) @@ -765,16 +885,17 @@ def _resume_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0): :type sleep_delay: ``float`` """ try: - user = self.context.get('user', None) + user = self.context.get("user", None) liveaction, _ = action_service.request_resume(liveaction, user) except Exception as e: liveaction.status = action_constants.LIVEACTION_STATUS_FAILED - LOG.exception('Failed to schedule liveaction.') + LOG.exception("Failed to schedule liveaction.") raise e - while (wait_for_completion and liveaction.status not in ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED])): + while wait_for_completion and liveaction.status not in ( + action_constants.LIVEACTION_COMPLETED_STATES + + [action_constants.LIVEACTION_STATUS_PAUSED] + ): eventlet.sleep(sleep_delay) liveaction = action_db_util.get_liveaction_by_id(liveaction.id) @@ -787,14 +908,12 @@ def _build_liveaction_object(self, action_node, resolved_params, parent_context) notify = self._get_notify(action_node) if notify: liveaction.notify = notify - LOG.debug('%s: Task notify set to: %s', action_node.name, liveaction.notify) + LOG.debug("%s: Task notify set to: %s", action_node.name, liveaction.notify) - liveaction.context = { - 'parent': parent_context, - 'chain': vars(action_node) - } - liveaction.parameters = action_param_utils.cast_params(action_ref=action_node.ref, - params=resolved_params) + liveaction.context = {"parent": parent_context, "chain": vars(action_node)} + liveaction.parameters = action_param_utils.cast_params( + action_ref=action_node.ref, params=resolved_params + ) return liveaction def _get_notify(self, action_node): @@ -807,18 +926,23 @@ def _get_notify(self, action_node): return None - def _get_updated_action_exec_result(self, action_node, liveaction, prev_task_result): + def _get_updated_action_exec_result( + self, action_node, liveaction, prev_task_result + ): if liveaction.status in action_constants.LIVEACTION_COMPLETED_STATES: - created_at = isotime.parse(prev_task_result['created_at']) + created_at = isotime.parse(prev_task_result["created_at"]) updated_at = liveaction.end_timestamp else: - created_at = isotime.parse(prev_task_result['created_at']) - updated_at = isotime.parse(prev_task_result['updated_at']) + created_at = isotime.parse(prev_task_result["created_at"]) + updated_at = isotime.parse(prev_task_result["updated_at"]) - return self._format_action_exec_result(action_node, liveaction, created_at, updated_at) + return self._format_action_exec_result( + action_node, liveaction, created_at, updated_at + ) - def _format_action_exec_result(self, action_node, liveaction_db, created_at, updated_at, - error=None): + def _format_action_exec_result( + self, action_node, liveaction_db, created_at, updated_at, error=None + ): """ Format ActionExecution result so it can be used in the final action result output. @@ -833,24 +957,24 @@ def _format_action_exec_result(self, action_node, liveaction_db, created_at, upd if liveaction_db: execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - result['id'] = action_node.name - result['name'] = action_node.name - result['execution_id'] = str(execution_db.id) if execution_db else None - result['liveaction_id'] = str(liveaction_db.id) if liveaction_db else None - result['workflow'] = None + result["id"] = action_node.name + result["name"] = action_node.name + result["execution_id"] = str(execution_db.id) if execution_db else None + result["liveaction_id"] = str(liveaction_db.id) if liveaction_db else None + result["workflow"] = None - result['created_at'] = isotime.format(dt=created_at) - result['updated_at'] = isotime.format(dt=updated_at) + result["created_at"] = isotime.format(dt=created_at) + result["updated_at"] = isotime.format(dt=updated_at) if error or not liveaction_db: - result['state'] = action_constants.LIVEACTION_STATUS_FAILED + result["state"] = action_constants.LIVEACTION_STATUS_FAILED else: - result['state'] = liveaction_db.status + result["state"] = liveaction_db.status if error: - result['result'] = error + result["result"] = error else: - result['result'] = liveaction_db.result + result["result"] = liveaction_db.result return result @@ -860,4 +984,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('action_chain_runner')[0] + return get_runner_metadata("action_chain_runner")[0] diff --git a/contrib/runners/action_chain_runner/dist_utils.py b/contrib/runners/action_chain_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/action_chain_runner/dist_utils.py +++ b/contrib/runners/action_chain_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/action_chain_runner/setup.py b/contrib/runners/action_chain_runner/setup.py index 6c2043505c..7c96e1e1d1 100644 --- a/contrib/runners/action_chain_runner/setup.py +++ b/contrib/runners/action_chain_runner/setup.py @@ -26,31 +26,33 @@ from action_chain_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-action-chain', + name="stackstorm-runner-action-chain", version=__version__, - description=('Action-Chain workflow action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Action-Chain workflow action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'action_chain_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"action_chain_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'action-chain = action_chain_runner.action_chain_runner', + "st2common.runners.runner": [ + "action-chain = action_chain_runner.action_chain_runner", ], - } + }, ) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py index 32bb5c9249..9daed4fa90 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py @@ -39,99 +39,135 @@ class DummyActionExecution(object): - def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''): + def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""): self.id = None self.status = status self.result = result -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'actions': ['a1.yaml', 'a2.yaml', 'action_4_action_context_param.yaml'], - 'runners': ['testrunner1.yaml'] + "actions": ["a1.yaml", "a2.yaml", "action_4_action_context_param.yaml"], + "runners": ["testrunner1.yaml"], } -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] -ACTION_2 = MODELS['actions']['a2.yaml'] -ACTION_3 = MODELS['actions']['action_4_action_context_param.yaml'] -RUNNER = MODELS['runners']['testrunner1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] +ACTION_2 = MODELS["actions"]["a2.yaml"] +ACTION_3 = MODELS["actions"]["action_4_action_context_param.yaml"] +RUNNER = MODELS["runners"]["testrunner1.yaml"] CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain1.yaml') + FIXTURES_PACK, "actionchains", "chain1.yaml" +) CHAIN_2_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain2.yaml') + FIXTURES_PACK, "actionchains", "chain2.yaml" +) CHAIN_ACTION_CALL_NO_PARAMS_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_call_no_params.yaml') + FIXTURES_PACK, "actionchains", "chain_action_call_no_params.yaml" +) CHAIN_NO_DEFAULT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'no_default_chain.yaml') + FIXTURES_PACK, "actionchains", "no_default_chain.yaml" +) CHAIN_NO_DEFAULT_2 = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'no_default_chain_2.yaml') + FIXTURES_PACK, "actionchains", "no_default_chain_2.yaml" +) CHAIN_BAD_DEFAULT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'bad_default_chain.yaml') -CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_broken_on_success_path_static_task_name.yaml') -CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_broken_on_failure_path_static_task_name.yaml') + FIXTURES_PACK, "actionchains", "bad_default_chain.yaml" +) +CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, + "actionchains", + "chain_broken_on_success_path_static_task_name.yaml", + ) +) +CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, + "actionchains", + "chain_broken_on_failure_path_static_task_name.yaml", + ) +) CHAIN_FIRST_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_first_task_parameter_render_fail.yaml') + FIXTURES_PACK, "actionchains", "chain_first_task_parameter_render_fail.yaml" +) CHAIN_SECOND_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_second_task_parameter_render_fail.yaml') + FIXTURES_PACK, "actionchains", "chain_second_task_parameter_render_fail.yaml" +) CHAIN_LIST_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_list_template.yaml') + FIXTURES_PACK, "actionchains", "chain_list_template.yaml" +) CHAIN_DICT_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dict_template.yaml') + FIXTURES_PACK, "actionchains", "chain_dict_template.yaml" +) CHAIN_DEP_INPUT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dependent_input.yaml') + FIXTURES_PACK, "actionchains", "chain_dependent_input.yaml" +) CHAIN_DEP_RESULTS_INPUT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dep_result_input.yaml') + FIXTURES_PACK, "actionchains", "chain_dep_result_input.yaml" +) MALFORMED_CHAIN_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'malformedchain.yaml') + FIXTURES_PACK, "actionchains", "malformedchain.yaml" +) CHAIN_TYPED_PARAMS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_typed_params.yaml') + FIXTURES_PACK, "actionchains", "chain_typed_params.yaml" +) CHAIN_SYSTEM_PARAMS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_typed_system_params.yaml') + FIXTURES_PACK, "actionchains", "chain_typed_system_params.yaml" +) CHAIN_WITH_ACTIONPARAM_VARS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_actionparam_vars.yaml') + FIXTURES_PACK, "actionchains", "chain_with_actionparam_vars.yaml" +) CHAIN_WITH_SYSTEM_VARS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_system_vars.yaml') + FIXTURES_PACK, "actionchains", "chain_with_system_vars.yaml" +) CHAIN_WITH_PUBLISH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_publish.yaml') + FIXTURES_PACK, "actionchains", "chain_with_publish.yaml" +) CHAIN_WITH_PUBLISH_2 = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_publish_2.yaml') + FIXTURES_PACK, "actionchains", "chain_with_publish_2.yaml" +) CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_publish_params_rendering_failure.yaml') + FIXTURES_PACK, "actionchains", "chain_publish_params_rendering_failure.yaml" +) CHAIN_WITH_INVALID_ACTION = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_invalid_action.yaml') -CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_params_and_parameters.yaml') + FIXTURES_PACK, "actionchains", "chain_with_invalid_action.yaml" +) +CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, "actionchains", "chain_action_params_and_parameters.yaml" + ) +) CHAIN_ACTION_PARAMS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_params_attribute.yaml') + FIXTURES_PACK, "actionchains", "chain_action_params_attribute.yaml" +) CHAIN_ACTION_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_parameters_attribute.yaml') + FIXTURES_PACK, "actionchains", "chain_action_parameters_attribute.yaml" +) CHAIN_ACTION_INVALID_PARAMETER_TYPE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_invalid_parameter_type_passed_to_action.yaml') + FIXTURES_PACK, "actionchains", "chain_invalid_parameter_type_passed_to_action.yaml" +) -CHAIN_NOTIFY_API = {'notify': {'on-complete': {'message': 'foo happened.'}}} +CHAIN_NOTIFY_API = {"notify": {"on-complete": {"message": "foo happened."}}} CHAIN_NOTIFY_DB = NotificationsHelper.to_model(CHAIN_NOTIFY_API) @mock.patch.object( - action_db_util, - 'get_runnertype_by_name', - mock.MagicMock(return_value=RUNNER)) + action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER) +) @mock.patch.object( action_service, - 'is_action_canceled_or_canceling', - mock.MagicMock(return_value=False)) + "is_action_canceled_or_canceling", + mock.MagicMock(return_value=False), +) @mock.patch.object( - action_service, - 'is_action_paused_or_pausing', - mock.MagicMock(return_value=False)) + action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False) +) class TestActionChainRunner(ExecutionDbTestCase): - def test_runner_creation(self): runner = acr.get_runner() self.assertTrue(runner) @@ -143,18 +179,23 @@ def test_malformed_chain(self): chain_runner.entry_point = MALFORMED_CHAIN_PATH chain_runner.action = ACTION_1 chain_runner.pre_run() - self.assertTrue(False, 'Expected pre_run to fail.') + self.assertTrue(False, "Expected pre_run to fail.") except runnerexceptions.ActionRunnerPreRunError: self.assertTrue(True) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.liveaction.notify = CHAIN_NOTIFY_DB chain_runner.pre_run() @@ -163,9 +204,12 @@ def test_chain_runner_success_path(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_chain_second_task_times_out(self, request): # Second task in the chain times out so the action chain status should be timeout chain_runner = acr.get_runner() @@ -177,13 +221,15 @@ def test_chain_runner_chain_second_task_times_out(self, request): def mock_run_action(*args, **kwargs): original_live_action = args[0] liveaction = original_run_action(*args, **kwargs) - if original_live_action.action == 'wolfpack.a2': + if original_live_action.action == "wolfpack.a2": # Mock a timeout for second task liveaction.status = LIVEACTION_STATUS_TIMED_OUT return liveaction chain_runner._run_action = mock_run_action - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -193,9 +239,12 @@ def mock_run_action(*args, **kwargs): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_task_is_canceled_while_running(self, request): # Second task in the action is CANCELED, make sure runner doesn't get stuck in an infinite # loop @@ -207,7 +256,7 @@ def test_chain_runner_task_is_canceled_while_running(self, request): def mock_run_action(*args, **kwargs): original_live_action = args[0] - if original_live_action.action == 'wolfpack.a2': + if original_live_action.action == "wolfpack.a2": status = LIVEACTION_STATUS_CANCELED else: status = LIVEACTION_STATUS_SUCCEEDED @@ -216,7 +265,9 @@ def mock_run_action(*args, **kwargs): return liveaction chain_runner._run_action = mock_run_action - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -227,16 +278,21 @@ def mock_run_action(*args, **kwargs): # canceled self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_task_action_call_with_no_params(self, request): # Make sure that the runner doesn't explode if task definition contains # no "params" section chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_CALL_NO_PARAMS_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.liveaction.notify = CHAIN_NOTIFY_DB chain_runner.pre_run() @@ -245,14 +301,19 @@ def test_chain_runner_success_task_action_call_with_no_params(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_no_default(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_NO_DEFAULT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -264,9 +325,12 @@ def test_chain_runner_no_default(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_no_default_multiple_options(self, request): # subtle difference is that when there are multiple possible default nodes # the order per chain definition may not be preseved. This is really a @@ -274,7 +338,9 @@ def test_chain_runner_no_default_multiple_options(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_NO_DEFAULT_2 chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -286,29 +352,44 @@ def test_chain_runner_no_default_multiple_options(self, request): # based on the chain the callcount is known to be 2. self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_bad_default(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BAD_DEFAULT chain_runner.action = ACTION_1 - expected_msg = 'Unable to find node with name "bad_default" referenced in "default".' - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch('eventlet.sleep', mock.MagicMock()) - @mock.patch.object(action_db_util, 'get_liveaction_by_id', mock.MagicMock( - return_value=DummyActionExecution())) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None)) + expected_msg = ( + 'Unable to find node with name "bad_default" referenced in "default".' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch("eventlet.sleep", mock.MagicMock()) + @mock.patch.object( + action_db_util, + "get_liveaction_by_id", + mock.MagicMock(return_value=DummyActionExecution()), + ) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None), + ) def test_chain_runner_success_path_with_wait(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -316,15 +397,21 @@ def test_chain_runner_success_path_with_wait(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None), + ) def test_chain_runner_failure_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -333,42 +420,57 @@ def test_chain_runner_failure_path(self, request): # based on the chain the callcount is known to be 2. Not great but works. self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_broken_on_success_path_static_task_name(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME chain_runner.action = ACTION_1 - expected_msg = ('Unable to find node with name "c5" referenced in "on-success" ' - 'in task "c2"') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(), None)) + expected_msg = ( + 'Unable to find node with name "c5" referenced in "on-success" ' + 'in task "c2"' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_broken_on_failure_path_static_task_name(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME chain_runner.action = ACTION_1 - expected_msg = ('Unable to find node with name "c6" referenced in "on-failure" ' - 'in task "c2"') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', side_effect=RuntimeError('Test Failure.')) + expected_msg = ( + 'Unable to find node with name "c6" referenced in "on-failure" ' + 'in task "c2"' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", side_effect=RuntimeError("Test Failure.") + ) def test_chain_runner_action_exception(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, results, _ = chain_runner.run({}) @@ -379,102 +481,131 @@ def test_chain_runner_action_exception(self, request): self.assertEqual(request.call_count, 2) error_count = 0 - for task_result in results['tasks']: - if task_result['result'].get('error', None): + for task_result in results["tasks"]: + if task_result["result"].get("error", None): error_count += 1 self.assertEqual(error_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_str_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, {"p1": "1"}) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_list_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_LIST_TEMP_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, {"p1": "[2, 3, 4]"}) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_dict_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DICT_TEMP_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) expected_value = {"p1": {"p1.3": "[3, 4]", "p1.2": "2", "p1.1": "1"}} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'o1': '1'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"o1": "1"}), None), + ) def test_chain_runner_dependent_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DEP_INPUT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_values = [{u'p1': u'1'}, - {u'p1': u'1'}, - {u'p2': u'1', u'p3': u'1', u'p1': u'1'}] + expected_values = [{"p1": "1"}, {"p1": "1"}, {"p2": "1", "p3": "1", "p1": "1"}] # Each of the call_args must be one of for call_args in request.call_args_list: self.assertIn(call_args[0][0].parameters, expected_values) expected_values.remove(call_args[0][0].parameters) - self.assertEqual(len(expected_values), 0, 'Not all expected values received.') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'o1': '1'}), None)) + self.assertEqual(len(expected_values), 0, "Not all expected values received.") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"o1": "1"}), None), + ) def test_chain_runner_dependent_results_param(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DEP_RESULTS_INPUT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1}) + chain_runner.run({"s1": 1}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) if six.PY2: - expected_values = [{u'p1': u'1'}, - {u'p1': u'1'}, - {u'out': u"{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"}] + expected_values = [ + {"p1": "1"}, + {"p1": "1"}, + {"out": "{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"}, + ] else: - expected_values = [{'p1': '1'}, - {'p1': '1'}, - {'out': "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"}] + expected_values = [ + {"p1": "1"}, + {"p1": "1"}, + {"out": "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"}, + ] # Each of the call_args must be one of self.assertEqual(request.call_count, 3) @@ -482,104 +613,137 @@ def test_chain_runner_dependent_results_param(self, request): self.assertIn(call_args[0][0].parameters, expected_values) expected_values.remove(call_args[0][0].parameters) - self.assertEqual(len(expected_values), 0, 'Not all expected values received.') + self.assertEqual(len(expected_values), 0, "Not all expected values received.") - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(RunnerType, 'get_by_name', - mock.MagicMock(return_value=RUNNER)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object(RunnerType, "get_by_name", mock.MagicMock(return_value=RUNNER)) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_missing_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) - self.assertEqual(request.call_count, 0, 'No call expected.') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(request.call_count, 0, "No call expected.") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_failure_during_param_rendering_single_task(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, result, _ = chain_runner.run({}) # No tasks ran because rendering of parameters for the first task failed self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(result['tasks'], []) - self.assertIn('error', result) - self.assertIn('traceback', result) - self.assertIn('Failed to run task "c1". Parameter rendering failed', result['error']) - self.assertIn('Traceback', result['traceback']) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(result["tasks"], []) + self.assertIn("error", result) + self.assertIn("traceback", result) + self.assertIn( + 'Failed to run task "c1". Parameter rendering failed', result["error"] + ) + self.assertIn("Traceback", result["traceback"]) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_failure_during_param_rendering_multiple_tasks(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_SECOND_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, result, _ = chain_runner.run({}) # Verify that only first task has ran self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(len(result['tasks']), 1) - self.assertEqual(result['tasks'][0]['name'], 'c1') - - expected_error = ('Failed rendering value for action parameter "p1" in ' - 'task "c2" (template string={{s1}}):') - - self.assertIn('error', result) - self.assertIn('traceback', result) - self.assertIn('Failed to run task "c2". Parameter rendering failed', result['error']) - self.assertIn(expected_error, result['error']) - self.assertIn('Traceback', result['traceback']) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(len(result["tasks"]), 1) + self.assertEqual(result["tasks"][0]["name"], "c1") + + expected_error = ( + 'Failed rendering value for action parameter "p1" in ' + 'task "c2" (template string={{s1}}):' + ) + + self.assertIn("error", result) + self.assertIn("traceback", result) + self.assertIn( + 'Failed to run task "c2". Parameter rendering failed', result["error"] + ) + self.assertIn(expected_error, result["error"]) + self.assertIn("Traceback", result["traceback"]) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_typed_params(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_TYPED_PARAMS chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 'two', 's3': 3.14}) + chain_runner.run({"s1": 1, "s2": "two", "s3": 3.14}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'booltype': True, - 'inttype': 1, - 'numbertype': 3.14, - 'strtype': 'two', - 'arrtype': ['1', 'two'], - 'objtype': {'s2': 'two', - 'k1': '1'}} + expected_value = { + "booltype": True, + "inttype": 1, + "numbertype": 3.14, + "strtype": "two", + "arrtype": ["1", "two"], + "objtype": {"s2": "two", "k1": "1"}, + } mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_typed_system_params(self, request): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) kvps = [] try: - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='1'))) - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='two'))) + kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="1"))) + kvps.append( + KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="two")) + ) chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_SYSTEM_PARAMS chain_runner.action = ACTION_2 @@ -587,22 +751,28 @@ def test_chain_runner_typed_system_params(self, request): chain_runner.pre_run() chain_runner.run({}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two'} + expected_value = {"inttype": 1, "strtype": "two"} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) finally: for kvp in kvps: KeyValuePair.delete(kvp) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_vars_system_params(self, request): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) kvps = [] try: - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='two'))) + kvps.append( + KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="two")) + ) chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_SYSTEM_VARS chain_runner.action = ACTION_2 @@ -610,72 +780,88 @@ def test_chain_runner_vars_system_params(self, request): chain_runner.pre_run() chain_runner.run({}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two', - 'booltype': True} + expected_value = {"inttype": 1, "strtype": "two", "booltype": True} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) finally: for kvp in kvps: KeyValuePair.delete(kvp) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_vars_action_params(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_ACTIONPARAM_VARS chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'input_a': 'two'}) + chain_runner.run({"input_a": "two"}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two', - 'booltype': True} + expected_value = {"inttype": 1, "strtype": "two", "booltype": True} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_chain_runner_publish(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': True} + chain_runner.runner_parameters = {"display_published": True} chain_runner.pre_run() - action_parameters = {'action_param_1': 'test value 1'} + action_parameters = {"action_param_1": "test value 1"} _, result, _ = chain_runner.run(action_parameters=action_parameters) # We also assert that the action parameters are available in the # "publish" scope self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'published', - 'booltype': True, - 'published_action_param': action_parameters['action_param_1']} + expected_value = { + "inttype": 1, + "strtype": "published", + "booltype": True, + "published_action_param": action_parameters["action_param_1"], + } mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) # Assert that the variables are correctly published - self.assertEqual(result['published'], - {'published_action_param': u'test value 1', 'o1': u'published'}) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual( + result["published"], + {"published_action_param": "test value 1", "o1": "published"}, + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_publish_param_rendering_failure(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() @@ -685,16 +871,21 @@ def test_chain_runner_publish_param_rendering_failure(self, request): # TODO: Should we treat this as task error? Right now it bubbles all # the way up and it's not really consistent with action param # rendering failure - expected_error = ('Failed rendering value for publish parameter "p1" in ' - 'task "c2" (template string={{ not_defined }}):') + expected_error = ( + 'Failed rendering value for publish parameter "p1" in ' + 'task "c2" (template string={{ not_defined }}):' + ) self.assertIn(expected_error, six.text_type(e)) pass else: - self.fail('Exception was not thrown') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.fail("Exception was not thrown") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_INVALID_PARAMETER_TYPE @@ -702,48 +893,72 @@ def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request): chain_runner.pre_run() action_parameters = {} - expected_msg = (r'Failed to cast value "stringnotanarray" \(type: str\) for parameter ' - r'"arrtype" of type "array"') - self.assertRaisesRegexp(ValueError, expected_msg, chain_runner.run, - action_parameters=action_parameters) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=None)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + expected_msg = ( + r'Failed to cast value "stringnotanarray" \(type: str\) for parameter ' + r'"arrtype" of type "array"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + chain_runner.run, + action_parameters=action_parameters, + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=None) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_action_chain_runner_referenced_action_doesnt_exist(self, mock_request): # Action referenced by a task doesn't exist, should result in a top level error chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_INVALID_ACTION chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() action_parameters = {} status, output, _ = chain_runner.run(action_parameters=action_parameters) - expected_error = ('Failed to run task "c1". Action with reference "wolfpack.a2" ' - 'doesn\'t exist.') + expected_error = ( + 'Failed to run task "c1". Action with reference "wolfpack.a2" ' + "doesn't exist." + ) self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertIn(expected_error, output['error']) - self.assertIn('Traceback', output['traceback']) + self.assertIn(expected_error, output["error"]) + self.assertIn("Traceback", output["traceback"]) - def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided(self): + def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided( + self, + ): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE chain_runner.action = ACTION_2 - expected_msg = ('Either "params" or "parameters" attribute needs to be provided, but ' - 'not both') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, expected_msg, - chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + expected_msg = ( + 'Either "params" or "parameters" attribute needs to be provided, but ' + "not both" + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_params_and_parameters_attributes_both_work(self, _): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) # "params" attribute used chain_runner = acr.get_runner() @@ -756,10 +971,12 @@ def test_params_and_parameters_attributes_both_work(self, _): def mock_build_liveaction_object(action_node, resolved_params, parent_context): # Verify parameters are correctly passed to the action - self.assertEqual(resolved_params, {'pparams': 'v1'}) - original_build_liveaction_object(action_node=action_node, - resolved_params=resolved_params, - parent_context=parent_context) + self.assertEqual(resolved_params, {"pparams": "v1"}) + original_build_liveaction_object( + action_node=action_node, + resolved_params=resolved_params, + parent_context=parent_context, + ) chain_runner._build_liveaction_object = mock_build_liveaction_object @@ -776,10 +993,12 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context): def mock_build_liveaction_object(action_node, resolved_params, parent_context): # Verify parameters are correctly passed to the action - self.assertEqual(resolved_params, {'pparameters': 'v1'}) - original_build_liveaction_object(action_node=action_node, - resolved_params=resolved_params, - parent_context=parent_context) + self.assertEqual(resolved_params, {"pparameters": "v1"}) + original_build_liveaction_object( + action_node=action_node, + resolved_params=resolved_params, + parent_context=parent_context, + ) chain_runner._build_liveaction_object = mock_build_liveaction_object @@ -787,21 +1006,27 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context): status, output, _ = chain_runner.run(action_parameters=action_parameters) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_display_published_is_true_by_default(self, _): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) expected_published_values = { - 't1_publish_param_1': 'foo1', - 't1_publish_param_2': 'foo2', - 't1_publish_param_3': 'foo3', - 't2_publish_param_1': 'foo4', - 't2_publish_param_2': 'foo5', - 't2_publish_param_3': 'foo6', - 'publish_last_wins': 'bar_last', + "t1_publish_param_1": "foo1", + "t1_publish_param_2": "foo2", + "t1_publish_param_3": "foo3", + "t2_publish_param_1": "foo4", + "t2_publish_param_2": "foo5", + "t2_publish_param_3": "foo6", + "publish_last_wins": "bar_last", } # 1. display_published is True by default @@ -816,35 +1041,35 @@ def test_display_published_is_true_by_default(self, _): _, result, _ = chain_runner.run(action_parameters=action_parameters) # Assert that the variables are correctly published - self.assertEqual(result['published'], expected_published_values) + self.assertEqual(result["published"], expected_published_values) # 2. display_published is True by default so end result should be the same chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_2 chain_runner.action = ACTION_2 chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': True} + chain_runner.runner_parameters = {"display_published": True} chain_runner.pre_run() action_parameters = {} _, result, _ = chain_runner.run(action_parameters=action_parameters) # Assert that the variables are correctly published - self.assertEqual(result['published'], expected_published_values) + self.assertEqual(result["published"], expected_published_values) # 3. display_published is disabled chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_2 chain_runner.action = ACTION_2 chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': False} + chain_runner.runner_parameters = {"display_published": False} chain_runner.pre_run() action_parameters = {} _, result, _ = chain_runner.run(action_parameters=action_parameters) - self.assertNotIn('published', result) - self.assertEqual(result.get('published', {}), {}) + self.assertNotIn("published", result) + self.assertEqual(result.get("published", {}), {}) @classmethod def tearDownClass(cls): diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py index 7bba3606d8..dca88cf803 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py @@ -20,6 +20,7 @@ import tempfile from st2tests import config as test_config + test_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -40,39 +41,25 @@ TEST_FIXTURES = { - 'chains': [ - 'test_cancel.yaml', - 'test_cancel_with_subworkflow.yaml' - ], - 'actions': [ - 'test_cancel.yaml', - 'test_cancel_with_subworkflow.yaml' - ] + "chains": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"], + "actions": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"], } -TEST_PACK = 'action_chain_tests' -TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "action_chain_tests" +TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK -PACKS = [ - TEST_PACK_PATH, - fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"] -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object( - CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) -@mock.patch.object( - CUDPublisher, - 'publish_create', - mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), +) class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase): temp_file_path = None @@ -86,8 +73,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -98,7 +84,7 @@ def setUp(self): # Create temporary directory used by the tests _, self.temp_file_path = tempfile.mkstemp() - os.chmod(self.temp_file_path, 0o755) # nosec + os.chmod(self.temp_file_path, 0o755) # nosec def tearDown(self): if self.temp_file_path and os.path.exists(self.temp_file_path): @@ -110,7 +96,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100): # Wait until the execution has children. for i in range(0, retries): execution = ActionExecution.get_by_id(str(execution.id)) - if len(getattr(execution, 'children', [])) <= 0: + if len(getattr(execution, "children", [])) <= 0: eventlet.sleep(interval) continue @@ -123,34 +109,42 @@ def test_chain_cancel(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request action chain to cancel. - liveaction, execution = action_service.request_cancellation(liveaction, USERNAME) + liveaction, execution = action_service.request_cancellation( + liveaction, USERNAME + ) # Wait until the liveaction is canceling. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) def test_chain_cancel_cascade_to_subworkflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -159,14 +153,16 @@ def test_chain_cancel_cascade_to_subworkflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Wait for subworkflow to register. execution = self._wait_for_children(execution) @@ -174,44 +170,58 @@ def test_chain_cancel_cascade_to_subworkflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request action chain to cancel. - liveaction, execution = action_service.request_cancellation(liveaction, USERNAME) + liveaction, execution = action_service.request_cancellation( + liveaction, USERNAME + ) # Wait until the liveaction is canceling. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELING + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is canceling. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is canceled. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED + ) def test_chain_cancel_cascade_to_parent_workflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -220,14 +230,16 @@ def test_chain_cancel_cascade_to_parent_workflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Wait for subworkflow to register. execution = self._wait_for_children(execution) @@ -235,16 +247,22 @@ def test_chain_cancel_cascade_to_parent_workflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request subworkflow to cancel. - task1_live, task1_exec = action_service.request_cancellation(task1_live, USERNAME) + task1_live, task1_exec = action_service.request_cancellation( + task1_live, USERNAME + ) # Wait until the subworkflow is canceling. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) @@ -252,20 +270,26 @@ def test_chain_cancel_cascade_to_parent_workflow(self): # Wait until the subworkflow is canceled. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait until the parent liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) self.assertEqual(len(execution.children), 1) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED + ) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py index 193d6064a1..7997869b13 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py @@ -27,51 +27,53 @@ class DummyActionExecution(object): - def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''): + def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""): self.id = None self.status = status self.result = result -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -TEST_MODELS = { - 'actions': ['a1.yaml', 'a2.yaml'], - 'runners': ['testrunner1.yaml'] -} +TEST_MODELS = {"actions": ["a1.yaml", "a2.yaml"], "runners": ["testrunner1.yaml"]} -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] -ACTION_2 = MODELS['actions']['a2.yaml'] -RUNNER = MODELS['runners']['testrunner1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] +ACTION_2 = MODELS["actions"]["a2.yaml"] +RUNNER = MODELS["runners"]["testrunner1.yaml"] CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_notifications.yaml') + FIXTURES_PACK, "actionchains", "chain_with_notifications.yaml" +) @mock.patch.object( - action_db_util, - 'get_runnertype_by_name', - mock.MagicMock(return_value=RUNNER)) + action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER) +) @mock.patch.object( action_service, - 'is_action_canceled_or_canceling', - mock.MagicMock(return_value=False)) + "is_action_canceled_or_canceling", + mock.MagicMock(return_value=False), +) @mock.patch.object( - action_service, - 'is_action_paused_or_pausing', - mock.MagicMock(return_value=False)) + action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False) +) class TestActionChainNotifications(ExecutionDbTestCase): - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -79,8 +81,8 @@ def test_chain_runner_success_path(self, request): self.assertEqual(request.call_count, 2) first_call_args = request.call_args_list[0][0] liveaction_db = first_call_args[0] - self.assertTrue(liveaction_db.notify, 'Notify property expected.') + self.assertTrue(liveaction_db.notify, "Notify property expected.") second_call_args = request.call_args_list[1][0] liveaction_db = second_call_args[0] - self.assertFalse(liveaction_db.notify, 'Notify property not expected.') + self.assertFalse(liveaction_db.notify, "Notify property not expected.") diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py index 6fa4c6b456..d6278ca61a 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py @@ -25,96 +25,96 @@ class ActionChainRunnerResolveParamsTests(unittest2.TestCase): - def test_render_params_action_context(self): runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad' - }, - 'user': 'son', - 'k1': 'v1' + "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"}, + "user": "son", + "k1": "v1", } task_params = { - 'exec_id': {'default': '{{action_context.parent.execution_id}}'}, - 'k2': {}, - 'foo': {'default': 1} + "exec_id": {"default": "{{action_context.parent.execution_id}}"}, + "k2": {}, + "foo": {"default": 1}, } - action_node = Node(name='test_action_context_params', ref='core.local', params=task_params) + action_node = Node( + name="test_action_context_params", ref="core.local", params=task_params + ) rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.assertEqual(rendered_params['exec_id']['default'], 'some_awesome_exec_id') + self.assertEqual(rendered_params["exec_id"]["default"], "some_awesome_exec_id") def test_render_params_action_context_non_existent_member(self): runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad' - }, - 'user': 'son', - 'k1': 'v1' + "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"}, + "user": "son", + "k1": "v1", } task_params = { - 'exec_id': {'default': '{{action_context.parent.yo_gimme_tha_key}}'}, - 'k2': {}, - 'foo': {'default': 1} + "exec_id": {"default": "{{action_context.parent.yo_gimme_tha_key}}"}, + "k2": {}, + "foo": {"default": 1}, } - action_node = Node(name='test_action_context_params', ref='core.local', params=task_params) + action_node = Node( + name="test_action_context_params", ref="core.local", params=task_params + ) try: runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.fail('Should have thrown an instance of %s' % ParameterRenderingFailedException) + self.fail( + "Should have thrown an instance of %s" + % ParameterRenderingFailedException + ) except ParameterRenderingFailedException: pass def test_render_params_with_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: config_loader().get_config.return_value = { - 'amazing_config_value_fo_lyfe': 'no' + "amazing_config_value_fo_lyfe": "no" } runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad', - 'pack': 'mom' + "parent": { + "execution_id": "some_awesome_exec_id", + "user": "dad", + "pack": "mom", }, - 'user': 'son', + "user": "son", } task_params = { - 'config_val': '{{config_context.amazing_config_value_fo_lyfe}}' + "config_val": "{{config_context.amazing_config_value_fo_lyfe}}" } action_node = Node( - name='test_action_context_params', - ref='core.local', - params=task_params + name="test_action_context_params", ref="core.local", params=task_params + ) + rendered_params = runner._resolve_params( + action_node, {}, {}, {}, chain_context ) - rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.assertEqual(rendered_params['config_val'], 'no') + self.assertEqual(rendered_params["config_val"], "no") def test_init_params_vars_with_unicode_value(self): chain_spec = { - 'vars': { - 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'unicode_var_param': u'{{ param }}' + "vars": { + "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "unicode_var_param": "{{ param }}", }, - 'chain': [ + "chain": [ { - 'name': 'c1', - 'ref': 'core.local', - 'parameters': { - 'cmd': 'echo {{ unicode_var }}' - } + "name": "c1", + "ref": "core.local", + "parameters": {"cmd": "echo {{ unicode_var }}"}, } - ] + ], } - chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname='foo') - chain_holder.init_vars(action_parameters={'param': u'٩(̾●̮̮̃̾•̃̾)۶'}) + chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname="foo") + chain_holder.init_vars(action_parameters={"param": "٩(̾●̮̮̃̾•̃̾)۶"}) expected = { - 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'unicode_var_param': u'٩(̾●̮̮̃̾•̃̾)۶' + "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "unicode_var_param": "٩(̾●̮̮̃̾•̃̾)۶", } self.assertEqual(chain_holder.vars, expected) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py index c093c2061c..46f948d73a 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py @@ -20,6 +20,7 @@ import tempfile from st2tests import config as test_config + test_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -42,53 +43,45 @@ TEST_FIXTURES = { - 'chains': [ - 'test_pause_resume.yaml', - 'test_pause_resume_context_result', - 'test_pause_resume_with_published_vars.yaml', - 'test_pause_resume_with_error.yaml', - 'test_pause_resume_with_subworkflow.yaml', - 'test_pause_resume_with_context_access.yaml', - 'test_pause_resume_with_init_vars.yaml', - 'test_pause_resume_with_no_more_task.yaml', - 'test_pause_resume_last_task_failed_with_no_next_task.yaml' + "chains": [ + "test_pause_resume.yaml", + "test_pause_resume_context_result", + "test_pause_resume_with_published_vars.yaml", + "test_pause_resume_with_error.yaml", + "test_pause_resume_with_subworkflow.yaml", + "test_pause_resume_with_context_access.yaml", + "test_pause_resume_with_init_vars.yaml", + "test_pause_resume_with_no_more_task.yaml", + "test_pause_resume_last_task_failed_with_no_next_task.yaml", + ], + "actions": [ + "test_pause_resume.yaml", + "test_pause_resume_context_result", + "test_pause_resume_with_published_vars.yaml", + "test_pause_resume_with_error.yaml", + "test_pause_resume_with_subworkflow.yaml", + "test_pause_resume_with_context_access.yaml", + "test_pause_resume_with_init_vars.yaml", + "test_pause_resume_with_no_more_task.yaml", + "test_pause_resume_last_task_failed_with_no_next_task.yaml", ], - 'actions': [ - 'test_pause_resume.yaml', - 'test_pause_resume_context_result', - 'test_pause_resume_with_published_vars.yaml', - 'test_pause_resume_with_error.yaml', - 'test_pause_resume_with_subworkflow.yaml', - 'test_pause_resume_with_context_access.yaml', - 'test_pause_resume_with_init_vars.yaml', - 'test_pause_resume_with_no_more_task.yaml', - 'test_pause_resume_last_task_failed_with_no_next_task.yaml' - ] } -TEST_PACK = 'action_chain_tests' -TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "action_chain_tests" +TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK -PACKS = [ - TEST_PACK_PATH, - fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"] -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object( - CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) -@mock.patch.object( - CUDPublisher, - 'publish_create', - mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), +) class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase): temp_file_path = None @@ -102,8 +95,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -114,7 +106,7 @@ def setUp(self): # Create temporary directory used by the tests _, self.temp_file_path = tempfile.mkstemp() - os.chmod(self.temp_file_path, 0o755) # nosec + os.chmod(self.temp_file_path, 0o755) # nosec def tearDown(self): if self.temp_file_path and os.path.exists(self.temp_file_path): @@ -138,7 +130,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100): # Wait until the execution has children. for i in range(0, retries): execution = ActionExecution.get_by_id(str(execution.id)) - if len(getattr(execution, 'children', [])) <= 0: + if len(getattr(execution, "children", [])) <= 0: eventlet.sleep(interval) continue @@ -151,32 +143,42 @@ def test_chain_pause_resume(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -185,15 +187,19 @@ def test_chain_pause_resume(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) def test_chain_pause_resume_with_published_vars(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -202,32 +208,42 @@ def test_chain_pause_resume_with_published_vars(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_published_vars" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -236,17 +252,23 @@ def test_chain_pause_resume_with_published_vars(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertIn('published', liveaction.result) - self.assertDictEqual({'var1': 'foobar', 'var2': 'fubar'}, liveaction.result['published']) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertIn("published", liveaction.result) + self.assertDictEqual( + {"var1": "foobar", "var2": "fubar"}, liveaction.result["published"] + ) def test_chain_pause_resume_with_published_vars_display_false(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -255,32 +277,42 @@ def test_chain_pause_resume_with_published_vars_display_false(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars' - params = {'tempfile': path, 'message': 'foobar', 'display_published': False} + action = TEST_PACK + "." + "test_pause_resume_with_published_vars" + params = {"tempfile": path, "message": "foobar", "display_published": False} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -289,16 +321,20 @@ def test_chain_pause_resume_with_published_vars_display_false(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertNotIn('published', liveaction.result) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertNotIn("published", liveaction.result) def test_chain_pause_resume_with_error(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -307,32 +343,42 @@ def test_chain_pause_resume_with_error(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_error' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_error" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -341,19 +387,23 @@ def test_chain_pause_resume_with_error(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertTrue(liveaction.result['tasks'][0]['result']['failed']) - self.assertEqual(1, liveaction.result['tasks'][0]['result']['return_code']) - self.assertTrue(liveaction.result['tasks'][1]['result']['succeeded']) - self.assertEqual(0, liveaction.result['tasks'][1]['result']['return_code']) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertTrue(liveaction.result["tasks"][0]["result"]["failed"]) + self.assertEqual(1, liveaction.result["tasks"][0]["result"]["return_code"]) + self.assertTrue(liveaction.result["tasks"][1]["result"]["succeeded"]) + self.assertEqual(0, liveaction.result["tasks"][1]["result"]["return_code"]) def test_chain_pause_resume_cascade_to_subworkflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -362,14 +412,16 @@ def test_chain_pause_resume_cascade_to_subworkflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Wait for subworkflow to register. @@ -378,71 +430,97 @@ def test_chain_pause_resume_cascade_to_subworkflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is pausing. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request action chain to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 2) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 2) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED + ) def test_chain_pause_resume_cascade_to_parent_workflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -451,14 +529,16 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Wait for subworkflow to register. @@ -467,8 +547,10 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request subworkflow to pause. @@ -476,10 +558,14 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is pausing. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) @@ -487,39 +573,55 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait until the parent liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request subworkflow to resume. task1_live, task1_exec = action_service.request_resume(task1_live, USERNAME) # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # The parent workflow will stay paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED) # Wait for non-blocking threads to complete. @@ -527,30 +629,38 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Check liveaction result of the parent, which should stay the same # because only the subworkflow was resumed. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request parent workflow to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 2) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 2) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED + ) def test_chain_pause_resume_with_context_access(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -559,32 +669,42 @@ def test_chain_pause_resume_with_context_access(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_context_access' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_context_access" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -593,16 +713,20 @@ def test_chain_pause_resume_with_context_access(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 3) - self.assertEqual(liveaction.result['tasks'][2]['result']['stdout'], 'foobar') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 3) + self.assertEqual(liveaction.result["tasks"][2]["result"]["stdout"], "foobar") def test_chain_pause_resume_with_init_vars(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -611,32 +735,42 @@ def test_chain_pause_resume_with_init_vars(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_init_vars' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_init_vars" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -645,16 +779,20 @@ def test_chain_pause_resume_with_init_vars(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertEqual(liveaction.result['tasks'][1]['result']['stdout'], 'FOOBAR') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertEqual(liveaction.result["tasks"][1]["result"]["stdout"], "FOOBAR") def test_chain_pause_resume_with_no_more_task(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -663,32 +801,42 @@ def test_chain_pause_resume_with_no_more_task(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_no_more_task' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_no_more_task" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -697,15 +845,19 @@ def test_chain_pause_resume_with_no_more_task(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) def test_chain_pause_resume_last_task_failed_with_no_next_task(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -714,32 +866,44 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_last_task_failed_with_no_next_task' - params = {'tempfile': path, 'message': 'foobar'} + action = ( + TEST_PACK + "." + "test_pause_resume_last_task_failed_with_no_next_task" + ) + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -748,62 +912,70 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_FAILED) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) self.assertEqual( - liveaction.result['tasks'][0]['state'], - action_constants.LIVEACTION_STATUS_FAILED + liveaction.result["tasks"][0]["state"], + action_constants.LIVEACTION_STATUS_FAILED, ) def test_chain_pause_resume_status_change(self): # Tests context_result is updated when last task's status changes between pause and resume - action = TEST_PACK + '.' + 'test_pause_resume_context_result' + action = TEST_PACK + "." + "test_pause_resume_context_result" liveaction = LiveActionDB(action=action) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() - last_task_liveaction_id = liveaction.result['tasks'][-1]['liveaction_id'] + last_task_liveaction_id = liveaction.result["tasks"][-1]["liveaction_id"] action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_SUCCEEDED, end_timestamp=date_utils.get_datetime_utc_now(), - result={'foo': 'bar'}, - liveaction_id=last_task_liveaction_id + result={"foo": "bar"}, + liveaction_id=last_task_liveaction_id, ) # Request action chain to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertEqual( liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED, - str(liveaction) + str(liveaction), ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertEqual(liveaction.result['tasks'][0]['result']['foo'], 'bar') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertEqual(liveaction.result["tasks"][0]["result"]["foo"], "bar") diff --git a/contrib/runners/announcement_runner/announcement_runner/__init__.py b/contrib/runners/announcement_runner/announcement_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/announcement_runner/announcement_runner/__init__.py +++ b/contrib/runners/announcement_runner/announcement_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py index 6d219f2819..4782544c3c 100644 --- a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py +++ b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py @@ -24,12 +24,7 @@ from st2common.models.api.trace import TraceContext from st2common.transport.announcement import AnnouncementDispatcher -__all__ = [ - 'AnnouncementRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["AnnouncementRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -42,28 +37,28 @@ def __init__(self, runner_id): def pre_run(self): super(AnnouncementRunner, self).pre_run() - LOG.debug('Entering AnnouncementRunner.pre_run() for liveaction_id="%s"', - self.liveaction_id) + LOG.debug( + 'Entering AnnouncementRunner.pre_run() for liveaction_id="%s"', + self.liveaction_id, + ) - if not self.runner_parameters.get('experimental'): - message = ('Experimental flag is missing for action %s' % self.action.ref) - LOG.exception('Experimental runner is called without experimental flag.') + if not self.runner_parameters.get("experimental"): + message = "Experimental flag is missing for action %s" % self.action.ref + LOG.exception("Experimental runner is called without experimental flag.") raise runnerexceptions.ActionRunnerPreRunError(message) - self._route = self.runner_parameters.get('route') + self._route = self.runner_parameters.get("route") def run(self, action_parameters): - trace_context = self.liveaction.context.get('trace_context', None) + trace_context = self.liveaction.context.get("trace_context", None) if trace_context: trace_context = TraceContext(**trace_context) - self._dispatcher.dispatch(self._route, - payload=action_parameters, - trace_context=trace_context) + self._dispatcher.dispatch( + self._route, payload=action_parameters, trace_context=trace_context + ) - result = { - "output": action_parameters - } + result = {"output": action_parameters} result.update(action_parameters) return (LIVEACTION_STATUS_SUCCEEDED, result, None) @@ -74,4 +69,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('announcement_runner')[0] + return get_runner_metadata("announcement_runner")[0] diff --git a/contrib/runners/announcement_runner/dist_utils.py b/contrib/runners/announcement_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/announcement_runner/dist_utils.py +++ b/contrib/runners/announcement_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/announcement_runner/setup.py b/contrib/runners/announcement_runner/setup.py index efd60b14af..a72469ffea 100644 --- a/contrib/runners/announcement_runner/setup.py +++ b/contrib/runners/announcement_runner/setup.py @@ -26,30 +26,32 @@ from announcement_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-announcement', + name="stackstorm-runner-announcement", version=__version__, - description=('Announcement action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Announcement action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'announcement_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"announcement_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'announcement = announcement_runner.announcement_runner', + "st2common.runners.runner": [ + "announcement = announcement_runner.announcement_runner", ], - } + }, ) diff --git a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py index cc9c541015..9ad56a2115 100644 --- a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py +++ b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py @@ -26,69 +26,63 @@ mock_dispatcher = mock.Mock() -@mock.patch('st2common.transport.announcement.AnnouncementDispatcher.dispatch') +@mock.patch("st2common.transport.announcement.AnnouncementDispatcher.dispatch") class AnnouncementRunnerTestCase(RunnerTestCase): - @classmethod def setUpClass(cls): tests_config.parse_args() def test_runner_creation(self, dispatch): runner = announcement_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), announcement_runner.AnnouncementRunner, - 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), + announcement_runner.AnnouncementRunner, + "Creation failed. No instance.", + ) self.assertEqual(runner._dispatcher.dispatch, dispatch) def test_announcement(self, dispatch): runner = announcement_runner.get_runner() - runner.runner_parameters = { - 'experimental': True, - 'route': 'general' - } + runner.runner_parameters = {"experimental": True, "route": "general"} runner.liveaction = mock.Mock(context={}) runner.pre_run() - (status, result, _) = runner.run({'test': 'passed'}) + (status, result, _) = runner.run({"test": "passed"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(result) - self.assertEqual(result['test'], 'passed') - dispatch.assert_called_once_with('general', payload={'test': 'passed'}, - trace_context=None) + self.assertEqual(result["test"], "passed") + dispatch.assert_called_once_with( + "general", payload={"test": "passed"}, trace_context=None + ) def test_announcement_no_experimental(self, dispatch): runner = announcement_runner.get_runner() - runner.action = mock.Mock(ref='some.thing') - runner.runner_parameters = { - 'route': 'general' - } + runner.action = mock.Mock(ref="some.thing") + runner.runner_parameters = {"route": "general"} runner.liveaction = mock.Mock(context={}) - expected_msg = 'Experimental flag is missing for action some.thing' + expected_msg = "Experimental flag is missing for action some.thing" self.assertRaisesRegexp(Exception, expected_msg, runner.pre_run) - @mock.patch('st2common.models.api.trace.TraceContext.__new__') + @mock.patch("st2common.models.api.trace.TraceContext.__new__") def test_announcement_with_trace(self, context, dispatch): runner = announcement_runner.get_runner() - runner.runner_parameters = { - 'experimental': True, - 'route': 'general' - } - runner.liveaction = mock.Mock(context={ - 'trace_context': { - 'id_': 'a', - 'trace_tag': 'b' - } - }) + runner.runner_parameters = {"experimental": True, "route": "general"} + runner.liveaction = mock.Mock( + context={"trace_context": {"id_": "a", "trace_tag": "b"}} + ) runner.pre_run() - (status, result, _) = runner.run({'test': 'passed'}) + (status, result, _) = runner.run({"test": "passed"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(result) - self.assertEqual(result['test'], 'passed') - context.assert_called_once_with(TraceContext, - **runner.liveaction.context['trace_context']) - dispatch.assert_called_once_with('general', payload={'test': 'passed'}, - trace_context=context.return_value) + self.assertEqual(result["test"], "passed") + context.assert_called_once_with( + TraceContext, **runner.liveaction.context["trace_context"] + ) + dispatch.assert_called_once_with( + "general", payload={"test": "passed"}, trace_context=context.return_value + ) diff --git a/contrib/runners/http_runner/dist_utils.py b/contrib/runners/http_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/http_runner/dist_utils.py +++ b/contrib/runners/http_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/http_runner/http_runner/__init__.py b/contrib/runners/http_runner/http_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/http_runner/http_runner/__init__.py +++ b/contrib/runners/http_runner/http_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/http_runner/http_runner/http_runner.py b/contrib/runners/http_runner/http_runner/http_runner.py index b2ff115fc6..6a02c809b9 100644 --- a/contrib/runners/http_runner/http_runner/http_runner.py +++ b/contrib/runners/http_runner/http_runner/http_runner.py @@ -35,45 +35,36 @@ import six from six.moves import range -__all__ = [ - 'HttpRunner', - - 'HTTPClient', - - 'get_runner', - 'get_metadata' -] +__all__ = ["HttpRunner", "HTTPClient", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) SUCCESS_STATUS_CODES = [code for code in range(200, 207)] # Lookup constants for runner params -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_URL = 'url' -RUNNER_HEADERS = 'headers' # Debatable whether this should be action params. -RUNNER_COOKIES = 'cookies' -RUNNER_ALLOW_REDIRECTS = 'allow_redirects' -RUNNER_HTTP_PROXY = 'http_proxy' -RUNNER_HTTPS_PROXY = 'https_proxy' -RUNNER_VERIFY_SSL_CERT = 'verify_ssl_cert' -RUNNER_USERNAME = 'username' -RUNNER_PASSWORD = 'password' -RUNNER_URL_HOSTS_BLACKLIST = 'url_hosts_blacklist' -RUNNER_URL_HOSTS_WHITELIST = 'url_hosts_whitelist' +RUNNER_ON_BEHALF_USER = "user" +RUNNER_URL = "url" +RUNNER_HEADERS = "headers" # Debatable whether this should be action params. +RUNNER_COOKIES = "cookies" +RUNNER_ALLOW_REDIRECTS = "allow_redirects" +RUNNER_HTTP_PROXY = "http_proxy" +RUNNER_HTTPS_PROXY = "https_proxy" +RUNNER_VERIFY_SSL_CERT = "verify_ssl_cert" +RUNNER_USERNAME = "username" +RUNNER_PASSWORD = "password" +RUNNER_URL_HOSTS_BLACKLIST = "url_hosts_blacklist" +RUNNER_URL_HOSTS_WHITELIST = "url_hosts_whitelist" # Lookup constants for action params -ACTION_AUTH = 'auth' -ACTION_BODY = 'body' -ACTION_TIMEOUT = 'timeout' -ACTION_METHOD = 'method' -ACTION_QUERY_PARAMS = 'params' -FILE_NAME = 'file_name' -FILE_CONTENT = 'file_content' -FILE_CONTENT_TYPE = 'file_content_type' +ACTION_AUTH = "auth" +ACTION_BODY = "body" +ACTION_TIMEOUT = "timeout" +ACTION_METHOD = "method" +ACTION_QUERY_PARAMS = "params" +FILE_NAME = "file_name" +FILE_CONTENT = "file_content" +FILE_CONTENT_TYPE = "file_content_type" -RESPONSE_BODY_PARSE_FUNCTIONS = { - 'application/json': json.loads -} +RESPONSE_BODY_PARSE_FUNCTIONS = {"application/json": json.loads} class HttpRunner(ActionRunner): @@ -85,37 +76,48 @@ def __init__(self, runner_id): def pre_run(self): super(HttpRunner, self).pre_run() - LOG.debug('Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id) - self._on_behalf_user = self.runner_parameters.get(RUNNER_ON_BEHALF_USER, - self._on_behalf_user) + LOG.debug( + 'Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id + ) + self._on_behalf_user = self.runner_parameters.get( + RUNNER_ON_BEHALF_USER, self._on_behalf_user + ) self._url = self.runner_parameters.get(RUNNER_URL, None) self._headers = self.runner_parameters.get(RUNNER_HEADERS, {}) self._cookies = self.runner_parameters.get(RUNNER_COOKIES, None) - self._allow_redirects = self.runner_parameters.get(RUNNER_ALLOW_REDIRECTS, False) + self._allow_redirects = self.runner_parameters.get( + RUNNER_ALLOW_REDIRECTS, False + ) self._username = self.runner_parameters.get(RUNNER_USERNAME, None) self._password = self.runner_parameters.get(RUNNER_PASSWORD, None) self._http_proxy = self.runner_parameters.get(RUNNER_HTTP_PROXY, None) self._https_proxy = self.runner_parameters.get(RUNNER_HTTPS_PROXY, None) self._verify_ssl_cert = self.runner_parameters.get(RUNNER_VERIFY_SSL_CERT, None) - self._url_hosts_blacklist = self.runner_parameters.get(RUNNER_URL_HOSTS_BLACKLIST, []) - self._url_hosts_whitelist = self.runner_parameters.get(RUNNER_URL_HOSTS_WHITELIST, []) + self._url_hosts_blacklist = self.runner_parameters.get( + RUNNER_URL_HOSTS_BLACKLIST, [] + ) + self._url_hosts_whitelist = self.runner_parameters.get( + RUNNER_URL_HOSTS_WHITELIST, [] + ) def run(self, action_parameters): client = self._get_http_client(action_parameters) if self._url_hosts_blacklist and self._url_hosts_whitelist: - msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive. Only one should be provided.') + msg = ( + '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive. Only one should be provided." + ) raise ValueError(msg) try: result = client.run() except requests.exceptions.Timeout as e: - result = {'error': six.text_type(e)} + result = {"error": six.text_type(e)} status = LIVEACTION_STATUS_TIMED_OUT else: - status = HttpRunner._get_result_status(result.get('status_code', None)) + status = HttpRunner._get_result_status(result.get("status_code", None)) return (status, result, None) @@ -132,8 +134,8 @@ def _get_http_client(self, action_parameters): # Include our user agent and action name so requests can be tracked back headers = copy.deepcopy(self._headers) if self._headers else {} - headers['User-Agent'] = 'st2/v%s' % (st2_version) - headers['X-Stanley-Action'] = self.action_name + headers["User-Agent"] = "st2/v%s" % (st2_version) + headers["X-Stanley-Action"] = self.action_name if file_name and file_content: files = {} @@ -141,7 +143,7 @@ def _get_http_client(self, action_parameters): if file_content_type: value = (file_content, file_content_type) else: - value = (file_content) + value = file_content files[file_name] = value else: @@ -150,43 +152,72 @@ def _get_http_client(self, action_parameters): proxies = {} if self._http_proxy: - proxies['http'] = self._http_proxy + proxies["http"] = self._http_proxy if self._https_proxy: - proxies['https'] = self._https_proxy - - return HTTPClient(url=self._url, method=method, body=body, params=params, - headers=headers, cookies=self._cookies, auth=auth, - timeout=timeout, allow_redirects=self._allow_redirects, - proxies=proxies, files=files, verify=self._verify_ssl_cert, - username=self._username, password=self._password, - url_hosts_blacklist=self._url_hosts_blacklist, - url_hosts_whitelist=self._url_hosts_whitelist) + proxies["https"] = self._https_proxy + + return HTTPClient( + url=self._url, + method=method, + body=body, + params=params, + headers=headers, + cookies=self._cookies, + auth=auth, + timeout=timeout, + allow_redirects=self._allow_redirects, + proxies=proxies, + files=files, + verify=self._verify_ssl_cert, + username=self._username, + password=self._password, + url_hosts_blacklist=self._url_hosts_blacklist, + url_hosts_whitelist=self._url_hosts_whitelist, + ) @staticmethod def _get_result_status(status_code): - return LIVEACTION_STATUS_SUCCEEDED if status_code in SUCCESS_STATUS_CODES \ + return ( + LIVEACTION_STATUS_SUCCEEDED + if status_code in SUCCESS_STATUS_CODES else LIVEACTION_STATUS_FAILED + ) class HTTPClient(object): - def __init__(self, url=None, method=None, body='', params=None, headers=None, cookies=None, - auth=None, timeout=60, allow_redirects=False, proxies=None, - files=None, verify=False, username=None, password=None, - url_hosts_blacklist=None, url_hosts_whitelist=None): + def __init__( + self, + url=None, + method=None, + body="", + params=None, + headers=None, + cookies=None, + auth=None, + timeout=60, + allow_redirects=False, + proxies=None, + files=None, + verify=False, + username=None, + password=None, + url_hosts_blacklist=None, + url_hosts_whitelist=None, + ): if url is None: - raise Exception('URL must be specified.') + raise Exception("URL must be specified.") if method is None: if files or body: - method = 'POST' + method = "POST" else: - method = 'GET' + method = "GET" headers = headers or {} normalized_headers = self._normalize_headers(headers=headers) - if body and 'content-length' not in normalized_headers: - headers['Content-Length'] = str(len(body)) + if body and "content-length" not in normalized_headers: + headers["Content-Length"] = str(len(body)) self.url = url self.method = method @@ -207,8 +238,10 @@ def __init__(self, url=None, method=None, body='', params=None, headers=None, co self.url_hosts_whitelist = url_hosts_whitelist or [] if self.url_hosts_blacklist and self.url_hosts_whitelist: - msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive. Only one should be provided.') + msg = ( + '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive. Only one should be provided." + ) raise ValueError(msg) def run(self): @@ -235,7 +268,7 @@ def run(self): try: data = json.dumps(data) except ValueError: - msg = 'Request body (%s) can\'t be parsed as JSON' % (data) + msg = "Request body (%s) can't be parsed as JSON" % (data) raise ValueError(msg) else: data = self.body @@ -245,7 +278,7 @@ def run(self): # Ensure data is bytes since that what request expects if isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") resp = requests.request( self.method, @@ -259,19 +292,19 @@ def run(self): allow_redirects=self.allow_redirects, proxies=self.proxies, files=self.files, - verify=self.verify + verify=self.verify, ) headers = dict(resp.headers) body, parsed = self._parse_response_body(headers=headers, body=resp.text) - results['status_code'] = resp.status_code - results['body'] = body - results['parsed'] = parsed # flag which indicates if body has been parsed - results['headers'] = headers + results["status_code"] = resp.status_code + results["body"] = body + results["parsed"] = parsed # flag which indicates if body has been parsed + results["headers"] = headers return results except Exception as e: - LOG.exception('Exception making request to remote URL: %s, %s', self.url, e) + LOG.exception("Exception making request to remote URL: %s, %s", self.url, e) raise finally: if resp: @@ -285,27 +318,27 @@ def _parse_response_body(self, headers, body): :return: (parsed body, flag which indicates if body has been parsed) :rtype: (``object``, ``bool``) """ - body = body or '' + body = body or "" headers = self._normalize_headers(headers=headers) - content_type = headers.get('content-type', None) + content_type = headers.get("content-type", None) parsed = False if not content_type: return (body, parsed) # The header can also contain charset which we simply discard - content_type = content_type.split(';')[0] + content_type = content_type.split(";")[0] parse_func = RESPONSE_BODY_PARSE_FUNCTIONS.get(content_type, None) if not parse_func: return (body, parsed) - LOG.debug('Parsing body with content type: %s', content_type) + LOG.debug("Parsing body with content type: %s", content_type) try: body = parse_func(body) except Exception: - LOG.exception('Failed to parse body') + LOG.exception("Failed to parse body") else: parsed = True @@ -323,7 +356,7 @@ def _normalize_headers(self, headers): def _is_json_content(self): normalized = self._normalize_headers(self.headers) - return normalized.get('content-type', None) == 'application/json' + return normalized.get("content-type", None) == "application/json" def _cast_object(self, value): if isinstance(value, str) or isinstance(value, six.text_type): @@ -370,10 +403,10 @@ def _get_host_from_url(self, url): parsed = urlparse.urlparse(url) # Remove port and [] - host = parsed.netloc.replace('[', '').replace(']', '') + host = parsed.netloc.replace("[", "").replace("]", "") if parsed.port is not None: - host = host.replace(':%s' % (parsed.port), '') + host = host.replace(":%s" % (parsed.port), "") return host @@ -383,4 +416,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('http_runner')[0] + return get_runner_metadata("http_runner")[0] diff --git a/contrib/runners/http_runner/setup.py b/contrib/runners/http_runner/setup.py index 2b962da599..2a5c9e217b 100644 --- a/contrib/runners/http_runner/setup.py +++ b/contrib/runners/http_runner/setup.py @@ -26,30 +26,32 @@ from http_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-http', + name="stackstorm-runner-http", version=__version__, - description=('HTTP(s) action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "HTTP(s) action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'http_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"http_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'http-request = http_runner.http_runner', + "st2common.runners.runner": [ + "http-request = http_runner.http_runner", ], - } + }, ) diff --git a/contrib/runners/http_runner/tests/unit/test_http_runner.py b/contrib/runners/http_runner/tests/unit/test_http_runner.py index be64f6d420..9d2d99a7c1 100644 --- a/contrib/runners/http_runner/tests/unit/test_http_runner.py +++ b/contrib/runners/http_runner/tests/unit/test_http_runner.py @@ -28,16 +28,13 @@ import st2tests.config as tests_config -__all__ = [ - 'HTTPClientTestCase', - 'HTTPRunnerTestCase' -] +__all__ = ["HTTPClientTestCase", "HTTPRunnerTestCase"] if six.PY2: - EXPECTED_DATA = '' + EXPECTED_DATA = "" else: - EXPECTED_DATA = b'' + EXPECTED_DATA = b"" class MockResult(object): @@ -49,70 +46,70 @@ class HTTPClientTestCase(unittest2.TestCase): def setUpClass(cls): tests_config.parse_args() - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_parse_response_body(self, mock_requests): - client = HTTPClient(url='http://127.0.0.1') + client = HTTPClient(url="http://127.0.0.1") mock_result = MockResult() # Unknown content type, body should be returned raw - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['body'], mock_result.text) - self.assertEqual(result['status_code'], mock_result.status_code) - self.assertEqual(result['headers'], mock_result.headers) + self.assertEqual(result["body"], mock_result.text) + self.assertEqual(result["status_code"], mock_result.status_code) + self.assertEqual(result["headers"], mock_result.headers) # Unknown content type, JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.headers = {"Content-Type": "text/html"} mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['body'], mock_result.text) + self.assertEqual(result["body"], mock_result.text) # JSON content-type and JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_requests.request.return_value = mock_result result = client.run() - self.assertIsInstance(result['body'], dict) - self.assertEqual(result['body'], {'test1': 'val1'}) + self.assertIsInstance(result["body"], dict) + self.assertEqual(result["body"], {"test1": "val1"}) # JSON content-type with charset and JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'application/json; charset=UTF-8'} + mock_result.headers = {"Content-Type": "application/json; charset=UTF-8"} mock_requests.request.return_value = mock_result result = client.run() - self.assertIsInstance(result['body'], dict) - self.assertEqual(result['body'], {'test1': 'val1'}) + self.assertIsInstance(result["body"], dict) + self.assertEqual(result["body"], {"test1": "val1"}) # JSON content-type and invalid json body - mock_result.text = 'not json' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.text = "not json" + mock_result.headers = {"Content-Type": "application/json"} mock_requests.request.return_value = mock_result result = client.run() - self.assertNotIsInstance(result['body'], dict) - self.assertEqual(result['body'], mock_result.text) + self.assertNotIsInstance(result["body"], dict) + self.assertEqual(result["body"], mock_result.text) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_https_verify(self, mock_requests): - url = 'https://127.0.0.1:8888' + url = "https://127.0.0.1:8888" client = HTTPClient(url=url, verify=True) mock_result = MockResult() - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result @@ -121,23 +118,33 @@ def test_https_verify(self, mock_requests): self.assertTrue(client.verify) if six.PY2: - data = '' + data = "" else: - data = b'' + data = b"" mock_requests.request.assert_called_with( - 'GET', url, allow_redirects=False, auth=None, cookies=None, - data=data, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=True) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=None, + cookies=None, + data=data, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=True, + ) + + @mock.patch("http_runner.http_runner.requests") def test_https_verify_false(self, mock_requests): - url = 'https://127.0.0.1:8888' + url = "https://127.0.0.1:8888" client = HTTPClient(url=url) mock_result = MockResult() - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result @@ -146,182 +153,202 @@ def test_https_verify_false(self, mock_requests): self.assertFalse(client.verify) mock_requests.request.assert_called_with( - 'GET', url, allow_redirects=False, auth=None, cookies=None, - data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=False) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=None, + cookies=None, + data=EXPECTED_DATA, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=False, + ) + + @mock.patch("http_runner.http_runner.requests") def test_https_auth_basic(self, mock_requests): - url = 'https://127.0.0.1:8888' - username = 'misspiggy' - password = 'kermit' + url = "https://127.0.0.1:8888" + username = "misspiggy" + password = "kermit" client = HTTPClient(url=url, username=username, password=password) mock_result = MockResult() - mock_result.text = 'muppet show' - mock_result.headers = {'Authorization': 'bWlzc3BpZ2d5Omtlcm1pdA=='} + mock_result.text = "muppet show" + mock_result.headers = {"Authorization": "bWlzc3BpZ2d5Omtlcm1pdA=="} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['headers'], mock_result.headers) + self.assertEqual(result["headers"], mock_result.headers) mock_requests.request.assert_called_once_with( - 'GET', url, allow_redirects=False, auth=client.auth, cookies=None, - data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=False) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=client.auth, + cookies=None, + data=EXPECTED_DATA, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=False, + ) + + @mock.patch("http_runner.http_runner.requests") def test_http_unicode_body_data(self, mock_requests): - url = 'http://127.0.0.1:8888' - method = 'POST' + url = "http://127.0.0.1:8888" + method = "POST" mock_result = MockResult() # 1. String data headers = {} - body = 'žžžžž' - client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1) + body = "žžžžž" + client = HTTPClient( + url=url, method=method, headers=headers, body=body, timeout=0.1 + ) mock_result.text = '{"foo": "bar"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['status_code'], 200) + self.assertEqual(result["status_code"], 200) call_kwargs = mock_requests.request.call_args_list[0][1] - expected_data = u'žžžžž'.encode('utf-8') - self.assertEqual(call_kwargs['data'], expected_data) + expected_data = "žžžžž".encode("utf-8") + self.assertEqual(call_kwargs["data"], expected_data) # 1. Object / JSON data - body = { - 'foo': u'ažž' - } - headers = { - 'Content-Type': 'application/json; charset=utf-8' - } - client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1) + body = {"foo": "ažž"} + headers = {"Content-Type": "application/json; charset=utf-8"} + client = HTTPClient( + url=url, method=method, headers=headers, body=body, timeout=0.1 + ) mock_result.text = '{"foo": "bar"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['status_code'], 200) + self.assertEqual(result["status_code"], 200) call_kwargs = mock_requests.request.call_args_list[1][1] if six.PY2: - expected_data = { - 'foo': u'a\u017e\u017e' - } + expected_data = {"foo": "a\u017e\u017e"} else: expected_data = body - self.assertEqual(call_kwargs['data'], expected_data) + self.assertEqual(call_kwargs["data"], expected_data) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_blacklisted_url_url_hosts_blacklist_runner_parameter(self, mock_requests): # Black list is empty self.assertEqual(mock_requests.request.call_count, 0) - url = 'http://www.example.com' - client = HTTPClient(url=url, method='GET') + url = "http://www.example.com" + client = HTTPClient(url=url, method="GET") client.run() self.assertEqual(mock_requests.request.call_count, 1) # Blacklist is set url_hosts_blacklist = [ - 'example.com', - '127.0.0.1', - '::1', - '2001:0db8:85a3:0000:0000:8a2e:0370:7334' + "example.com", + "127.0.0.1", + "::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334", ] # Blacklisted urls urls = [ - 'https://example.com', - 'http://example.com', - 'http://example.com:81', - 'http://example.com:80', - 'http://example.com:9000', - 'http://[::1]:80/', - 'http://[::1]', - 'http://[::1]:9000', - 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', - 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000' + "https://example.com", + "http://example.com", + "http://example.com:81", + "http://example.com:80", + "http://example.com:9000", + "http://[::1]:80/", + "http://[::1]", + "http://[::1]:9000", + "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000", ] for url in urls: expected_msg = r'URL "%s" is blacklisted' % (re.escape(url)) - client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist) + client = HTTPClient( + url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist + ) self.assertRaisesRegexp(ValueError, expected_msg, client.run) # Non blacklisted URLs - urls = [ - 'https://example2.com', - 'http://example3.com', - 'http://example4.com:81' - ] + urls = ["https://example2.com", "http://example3.com", "http://example4.com:81"] for url in urls: mock_requests.request.reset_mock() self.assertEqual(mock_requests.request.call_count, 0) - client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist) + client = HTTPClient( + url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist + ) client.run() self.assertEqual(mock_requests.request.call_count, 1) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_requests): # Whitelist is empty self.assertEqual(mock_requests.request.call_count, 0) - url = 'http://www.example.com' - client = HTTPClient(url=url, method='GET') + url = "http://www.example.com" + client = HTTPClient(url=url, method="GET") client.run() self.assertEqual(mock_requests.request.call_count, 1) # Whitelist is set url_hosts_whitelist = [ - 'example.com', - '127.0.0.1', - '::1', - '2001:0db8:85a3:0000:0000:8a2e:0370:7334' + "example.com", + "127.0.0.1", + "::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334", ] # Non whitelisted urls urls = [ - 'https://www.google.com', - 'https://www.example2.com', - 'http://127.0.0.2' + "https://www.google.com", + "https://www.example2.com", + "http://127.0.0.2", ] for url in urls: expected_msg = r'URL "%s" is not whitelisted' % (re.escape(url)) - client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist) + client = HTTPClient( + url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist + ) self.assertRaisesRegexp(ValueError, expected_msg, client.run) # Whitelisted URLS urls = [ - 'https://example.com', - 'http://example.com', - 'http://example.com:81', - 'http://example.com:80', - 'http://example.com:9000', - 'http://[::1]:80/', - 'http://[::1]', - 'http://[::1]:9000', - 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', - 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000' + "https://example.com", + "http://example.com", + "http://example.com:81", + "http://example.com:80", + "http://example.com:9000", + "http://[::1]:80/", + "http://[::1]", + "http://[::1]:9000", + "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000", ] for url in urls: @@ -329,57 +356,71 @@ def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_request self.assertEqual(mock_requests.request.call_count, 0) - client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist) + client = HTTPClient( + url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist + ) client.run() self.assertEqual(mock_requests.request.call_count, 1) - def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self): - url = 'http://www.example.com' - - expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive.') - self.assertRaisesRegexp(ValueError, expected_msg, HTTPClient, url=url, method='GET', - url_hosts_blacklist=[url], url_hosts_whitelist=[url]) + def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive( + self, + ): + url = "http://www.example.com" + + expected_msg = ( + r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive." + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + HTTPClient, + url=url, + method="GET", + url_hosts_blacklist=[url], + url_hosts_whitelist=[url], + ) class HTTPRunnerTestCase(unittest2.TestCase): - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_get_success(self, mock_requests): mock_result = MockResult() # Unknown content type, body should be returned raw - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result - runner_parameters = { - 'url': 'http://www.example.com', - 'method': 'GET' - } - runner = HttpRunner('id') + runner_parameters = {"url": "http://www.example.com", "method": "GET"} + runner = HttpRunner("id") runner.runner_parameters = runner_parameters runner.pre_run() status, result, _ = runner.run({}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['body'], 'foo bar ponies') - self.assertEqual(result['status_code'], 200) - self.assertEqual(result['parsed'], False) + self.assertEqual(result["body"], "foo bar ponies") + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["parsed"], False) - def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self): + def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive( + self, + ): runner_parameters = { - 'url': 'http://www.example.com', - 'method': 'GET', - 'url_hosts_blacklist': ['http://127.0.0.1'], - 'url_hosts_whitelist': ['http://127.0.0.1'], + "url": "http://www.example.com", + "method": "GET", + "url_hosts_blacklist": ["http://127.0.0.1"], + "url_hosts_whitelist": ["http://127.0.0.1"], } - runner = HttpRunner('id') + runner = HttpRunner("id") runner.runner_parameters = runner_parameters runner.pre_run() - expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive.') + expected_msg = ( + r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.run, {}) diff --git a/contrib/runners/inquirer_runner/dist_utils.py b/contrib/runners/inquirer_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/inquirer_runner/dist_utils.py +++ b/contrib/runners/inquirer_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py +++ b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py index 6a5757bc44..af0f0c6f34 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py +++ b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py @@ -29,20 +29,16 @@ from st2common.util import action_db as action_utils -__all__ = [ - 'Inquirer', - 'get_runner', - 'get_metadata' -] +__all__ = ["Inquirer", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_SCHEMA = 'schema' -RUNNER_ROLES = 'roles' -RUNNER_USERS = 'users' -RUNNER_ROUTE = 'route' -RUNNER_TTL = 'ttl' +RUNNER_SCHEMA = "schema" +RUNNER_ROLES = "roles" +RUNNER_USERS = "users" +RUNNER_ROUTE = "route" +RUNNER_TTL = "ttl" DEFAULT_SCHEMA = { "title": "response_data", @@ -51,15 +47,14 @@ "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } - } + }, } class Inquirer(runners.ActionRunner): - """This runner implements the ability to ask for more input during a workflow - """ + """This runner implements the ability to ask for more input during a workflow""" def __init__(self, runner_id): super(Inquirer, self).__init__(runner_id=runner_id) @@ -83,14 +78,11 @@ def run(self, action_parameters): # Assemble and dispatch trigger trigger_ref = sys_db_models.ResourceReference.to_string_reference( - pack=trigger_constants.INQUIRY_TRIGGER['pack'], - name=trigger_constants.INQUIRY_TRIGGER['name'] + pack=trigger_constants.INQUIRY_TRIGGER["pack"], + name=trigger_constants.INQUIRY_TRIGGER["name"], ) - trigger_payload = { - "id": str(exc.id), - "route": self.route - } + trigger_payload = {"id": str(exc.id), "route": self.route} self.trigger_dispatcher.dispatch(trigger_ref, trigger_payload) @@ -99,7 +91,7 @@ def run(self, action_parameters): "roles": self.roles_param, "users": self.users_param, "route": self.route, - "ttl": self.ttl + "ttl": self.ttl, } return (action_constants.LIVEACTION_STATUS_PENDING, result, None) @@ -110,9 +102,10 @@ def post_run(self, status, result): # is made in the run method, but because the liveaction hasn't update to pending status # yet, there is a race condition where the pause request is mishandled. if status == action_constants.LIVEACTION_STATUS_PENDING: - pause_parent = ( - self.liveaction.context.get("parent") and - not workflow_service.is_action_execution_under_workflow_context(self.liveaction) + pause_parent = self.liveaction.context.get( + "parent" + ) and not workflow_service.is_action_execution_under_workflow_context( + self.liveaction ) # For action execution under Action Chain workflows, request the entire @@ -122,7 +115,9 @@ def post_run(self, status, result): # to pause the workflow. if pause_parent: root_liveaction = action_service.get_root_liveaction(self.liveaction) - action_service.request_pause(root_liveaction, self.context.get('user', None)) + action_service.request_pause( + root_liveaction, self.context.get("user", None) + ) # Invoke post run of parent for common post run related work. super(Inquirer, self).post_run(status, result) @@ -133,4 +128,4 @@ def get_runner(): def get_metadata(): - return runners.get_metadata('inquirer_runner')[0] + return runners.get_metadata("inquirer_runner")[0] diff --git a/contrib/runners/inquirer_runner/setup.py b/contrib/runners/inquirer_runner/setup.py index 9be54704f9..44d4a4d7f7 100644 --- a/contrib/runners/inquirer_runner/setup.py +++ b/contrib/runners/inquirer_runner/setup.py @@ -26,30 +26,32 @@ from inquirer_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-inquirer', + name="stackstorm-runner-inquirer", version=__version__, - description=('Inquirer action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Inquirer action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'inquirer_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"inquirer_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'inquirer = inquirer_runner.inquirer_runner', + "st2common.runners.runner": [ + "inquirer = inquirer_runner.inquirer_runner", ], - } + }, ) diff --git a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py index da9c70b78a..caa47bc53a 100644 --- a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py +++ b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py @@ -28,7 +28,7 @@ mock_exc_get = mock.Mock() -mock_exc_get.id = 'abcdef' +mock_exc_get.id = "abcdef" mock_inquiry_liveaction_db = mock.Mock() mock_inquiry_liveaction_db.result = {"response": {}} @@ -37,7 +37,7 @@ mock_action_utils.return_value = mock_inquiry_liveaction_db test_parent = mock.Mock() -test_parent.id = '1234567890' +test_parent.id = "1234567890" mock_get_root = mock.Mock() mock_get_root.return_value = test_parent @@ -45,38 +45,19 @@ mock_trigger_dispatcher = mock.Mock() mock_request_pause = mock.Mock() -test_user = 'st2admin' +test_user = "st2admin" -runner_params = { - "users": [], - "roles": [], - "route": "developers", - "schema": {} -} +runner_params = {"users": [], "roles": [], "route": "developers", "schema": {}} +@mock.patch.object(reactor_transport, "TriggerDispatcher", mock_trigger_dispatcher) +@mock.patch.object(action_utils, "get_liveaction_by_id", mock_action_utils) +@mock.patch.object(action_service, "request_pause", mock_request_pause) +@mock.patch.object(action_service, "get_root_liveaction", mock_get_root) @mock.patch.object( - reactor_transport, - 'TriggerDispatcher', - mock_trigger_dispatcher) -@mock.patch.object( - action_utils, - 'get_liveaction_by_id', - mock_action_utils) -@mock.patch.object( - action_service, - 'request_pause', - mock_request_pause) -@mock.patch.object( - action_service, - 'get_root_liveaction', - mock_get_root) -@mock.patch.object( - ex_db_access.ActionExecution, - 'get', - mock.MagicMock(return_value=mock_exc_get)) + ex_db_access.ActionExecution, "get", mock.MagicMock(return_value=mock_exc_get) +) class InquiryTestCase(st2tests.RunnerTestCase): - def tearDown(self): mock_trigger_dispatcher.reset_mock() mock_action_utils.reset_mock() @@ -85,17 +66,19 @@ def tearDown(self): def test_runner_creation(self): runner = inquirer_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), inquirer_runner.Inquirer, 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), inquirer_runner.Inquirer, "Creation failed. No instance." + ) def test_simple_inquiry(self): runner = inquirer_runner.get_runner() - runner.context = {'user': test_user} + runner.context = {"user": test_user} runner.action = self._get_mock_action_obj() runner.runner_parameters = runner_params runner.pre_run() - mock_inquiry_liveaction_db.context = {'parent': test_parent.id} + mock_inquiry_liveaction_db.context = {"parent": test_parent.id} runner.liveaction = mock_inquiry_liveaction_db (status, output, _) = runner.run({}) @@ -104,20 +87,16 @@ def test_simple_inquiry(self): self.assertEqual( output, { - 'users': [], - 'roles': [], - 'route': "developers", - 'schema': {}, - 'ttl': 1440 - } + "users": [], + "roles": [], + "route": "developers", + "schema": {}, + "ttl": 1440, + }, ) mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with( - 'core.st2.generic.inquiry', - { - 'id': mock_exc_get.id, - 'route': "developers" - } + "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"} ) runner.post_run(action_constants.LIVEACTION_STATUS_PENDING, {}) @@ -125,37 +104,28 @@ def test_simple_inquiry(self): mock_request_pause.assert_called_once_with(test_parent, test_user) def test_inquiry_no_parent(self): - """Should behave like a regular execution, but without requesting a pause - """ + """Should behave like a regular execution, but without requesting a pause""" runner = inquirer_runner.get_runner() - runner.context = { - 'user': 'st2admin' - } + runner.context = {"user": "st2admin"} runner.action = self._get_mock_action_obj() runner.runner_parameters = runner_params runner.pre_run() - mock_inquiry_liveaction_db.context = { - "parent": None - } + mock_inquiry_liveaction_db.context = {"parent": None} (status, output, _) = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_PENDING) self.assertEqual( output, { - 'users': [], - 'roles': [], - 'route': "developers", - 'schema': {}, - 'ttl': 1440 - } + "users": [], + "roles": [], + "route": "developers", + "schema": {}, + "ttl": 1440, + }, ) mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with( - 'core.st2.generic.inquiry', - { - 'id': mock_exc_get.id, - 'route': "developers" - } + "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"} ) mock_request_pause.assert_not_called() diff --git a/contrib/runners/local_runner/dist_utils.py b/contrib/runners/local_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/local_runner/dist_utils.py +++ b/contrib/runners/local_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/local_runner/local_runner/__init__.py b/contrib/runners/local_runner/local_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/local_runner/local_runner/__init__.py +++ b/contrib/runners/local_runner/local_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/local_runner/local_runner/base.py b/contrib/runners/local_runner/local_runner/base.py index 5bcf20137f..4fda9c1866 100644 --- a/contrib/runners/local_runner/local_runner/base.py +++ b/contrib/runners/local_runner/local_runner/base.py @@ -39,32 +39,36 @@ from st2common.services.action import store_execution_output_data from st2common.runners.utils import make_read_and_store_stream_func -__all__ = [ - 'BaseLocalShellRunner', - - 'RUNNER_COMMAND' -] +__all__ = ["BaseLocalShellRunner", "RUNNER_COMMAND"] LOG = logging.getLogger(__name__) -DEFAULT_KWARG_OP = '--' +DEFAULT_KWARG_OP = "--" LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0] # constants to lookup in runner_parameters. -RUNNER_SUDO = 'sudo' -RUNNER_SUDO_PASSWORD = 'sudo_password' -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_COMMAND = 'cmd' -RUNNER_CWD = 'cwd' -RUNNER_ENV = 'env' -RUNNER_KWARG_OP = 'kwarg_op' -RUNNER_TIMEOUT = 'timeout' +RUNNER_SUDO = "sudo" +RUNNER_SUDO_PASSWORD = "sudo_password" +RUNNER_ON_BEHALF_USER = "user" +RUNNER_COMMAND = "cmd" +RUNNER_CWD = "cwd" +RUNNER_ENV = "env" +RUNNER_KWARG_OP = "kwarg_op" +RUNNER_TIMEOUT = "timeout" PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP = { - str(exit_code_constants.SUCCESS_EXIT_CODE): action_constants.LIVEACTION_STATUS_SUCCEEDED, - str(exit_code_constants.FAILURE_EXIT_CODE): action_constants.LIVEACTION_STATUS_FAILED, - str(-1 * exit_code_constants.SIGKILL_EXIT_CODE): action_constants.LIVEACTION_STATUS_TIMED_OUT, - str(-1 * exit_code_constants.SIGTERM_EXIT_CODE): action_constants.LIVEACTION_STATUS_ABANDONED + str( + exit_code_constants.SUCCESS_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_SUCCEEDED, + str( + exit_code_constants.FAILURE_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_FAILED, + str( + -1 * exit_code_constants.SIGKILL_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_TIMED_OUT, + str( + -1 * exit_code_constants.SIGTERM_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_ABANDONED, } @@ -77,7 +81,8 @@ class BaseLocalShellRunner(ActionRunner, ShellRunnerMixin): Note: The user under which the action runner service is running (stanley user by default) needs to have pasworless sudo access set up. """ - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] + + KEYS_TO_TRANSFORM = ["stdout", "stderr"] def __init__(self, runner_id): super(BaseLocalShellRunner, self).__init__(runner_id=runner_id) @@ -87,14 +92,17 @@ def pre_run(self): self._sudo = self.runner_parameters.get(RUNNER_SUDO, False) self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None) - self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME) + self._on_behalf_user = self.context.get( + RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME + ) self._user = cfg.CONF.system_user.user self._cwd = self.runner_parameters.get(RUNNER_CWD, None) self._env = self.runner_parameters.get(RUNNER_ENV, {}) self._env = self._env or {} self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, DEFAULT_KWARG_OP) self._timeout = self.runner_parameters.get( - RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT) + RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT + ) def _run(self, action): env_vars = self._env @@ -110,8 +118,11 @@ def _run(self, action): # For consistency with the old Fabric based runner, make sure the file is executable if script_action: script_local_path_abs = self.entry_point - args = 'chmod +x %s ; %s' % (script_local_path_abs, args) - sanitized_args = 'chmod +x %s ; %s' % (script_local_path_abs, sanitized_args) + args = "chmod +x %s ; %s" % (script_local_path_abs, args) + sanitized_args = "chmod +x %s ; %s" % ( + script_local_path_abs, + sanitized_args, + ) env = os.environ.copy() @@ -122,22 +133,38 @@ def _run(self, action): st2_env_vars = self._get_common_action_env_variables() env.update(st2_env_vars) - LOG.info('Executing action via LocalRunner: %s', self.runner_id) - LOG.info('[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s' % - (action.name, action.action_exec_id, sanitized_args, action.user, action.sudo)) + LOG.info("Executing action via LocalRunner: %s", self.runner_id) + LOG.info( + "[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s" + % ( + action.name, + action.action_exec_id, + sanitized_args, + action.user, + action.sudo, + ) + ) stdout = StringIO() stderr = StringIO() - store_execution_stdout_line = functools.partial(store_execution_output_data, - output_type='stdout') - store_execution_stderr_line = functools.partial(store_execution_output_data, - output_type='stderr') + store_execution_stdout_line = functools.partial( + store_execution_output_data, output_type="stdout" + ) + store_execution_stderr_line = functools.partial( + store_execution_output_data, output_type="stderr" + ) - read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stdout_line) - read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stderr_line) + read_and_store_stdout = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stdout_line, + ) + read_and_store_stderr = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stderr_line, + ) subprocess = concurrency.get_subprocess_module() @@ -145,9 +172,10 @@ def _run(self, action): # Note: We don't need to explicitly escape the argument because we pass command as a list # to subprocess.Popen and all the arguments are escaped by the function. if self._sudo_password: - LOG.debug('Supplying sudo password via stdin') - echo_process = concurrency.subprocess_popen(['echo', self._sudo_password + '\n'], - stdout=subprocess.PIPE) + LOG.debug("Supplying sudo password via stdin") + echo_process = concurrency.subprocess_popen( + ["echo", self._sudo_password + "\n"], stdout=subprocess.PIPE + ) stdin = echo_process.stdout else: stdin = None @@ -161,57 +189,64 @@ def _run(self, action): # Ideally os.killpg should have done the trick but for some reason that failed. # Note: pkill will set the returncode to 143 so we don't need to explicitly set # it to some non-zero value. - exit_code, stdout, stderr, timed_out = shell.run_command(cmd=args, - stdin=stdin, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - cwd=self._cwd, - env=env, - timeout=self._timeout, - preexec_func=os.setsid, - kill_func=kill_process, - read_stdout_func=read_and_store_stdout, - read_stderr_func=read_and_store_stderr, - read_stdout_buffer=stdout, - read_stderr_buffer=stderr) + exit_code, stdout, stderr, timed_out = shell.run_command( + cmd=args, + stdin=stdin, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + cwd=self._cwd, + env=env, + timeout=self._timeout, + preexec_func=os.setsid, + kill_func=kill_process, + read_stdout_func=read_and_store_stdout, + read_stderr_func=read_and_store_stderr, + read_stdout_buffer=stdout, + read_stderr_buffer=stderr, + ) error = None if timed_out: - error = 'Action failed to complete in %s seconds' % (self._timeout) + error = "Action failed to complete in %s seconds" % (self._timeout) exit_code = -1 * exit_code_constants.SIGKILL_EXIT_CODE # Detect if user provided an invalid sudo password or sudo is not configured for that user if self._sudo_password: - if re.search(r'sudo: \d+ incorrect password attempts', stderr): - match = re.search(r'\[sudo\] password for (.+?)\:', stderr) + if re.search(r"sudo: \d+ incorrect password attempts", stderr): + match = re.search(r"\[sudo\] password for (.+?)\:", stderr) if match: username = match.groups()[0] else: - username = 'unknown' + username = "unknown" - error = ('Invalid sudo password provided or sudo is not configured for this user ' - '(%s)' % (username)) + error = ( + "Invalid sudo password provided or sudo is not configured for this user " + "(%s)" % (username) + ) exit_code = -1 - succeeded = (exit_code == exit_code_constants.SUCCESS_EXIT_CODE) + succeeded = exit_code == exit_code_constants.SUCCESS_EXIT_CODE result = { - 'failed': not succeeded, - 'succeeded': succeeded, - 'return_code': exit_code, - 'stdout': strip_shell_chars(stdout), - 'stderr': strip_shell_chars(stderr) + "failed": not succeeded, + "succeeded": succeeded, + "return_code": exit_code, + "stdout": strip_shell_chars(stdout), + "stderr": strip_shell_chars(stderr), } if error: - result['error'] = error + result["error"] = error status = PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP.get( - str(exit_code), - action_constants.LIVEACTION_STATUS_FAILED + str(exit_code), action_constants.LIVEACTION_STATUS_FAILED ) - return (status, jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM), None) + return ( + status, + jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM), + None, + ) diff --git a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py index 4ae61f3225..cbf603de27 100644 --- a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py +++ b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py @@ -23,28 +23,25 @@ from local_runner.base import BaseLocalShellRunner from local_runner.base import RUNNER_COMMAND -__all__ = [ - 'LocalShellCommandRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["LocalShellCommandRunner", "get_runner", "get_metadata"] class LocalShellCommandRunner(BaseLocalShellRunner): def run(self, action_parameters): if self.entry_point: - raise ValueError('entry_point is only valid for local-shell-script runner') + raise ValueError("entry_point is only valid for local-shell-script runner") command = self.runner_parameters.get(RUNNER_COMMAND, None) - action = ShellCommandAction(name=self.action_name, - action_exec_id=str(self.liveaction_id), - command=command, - user=self._user, - env_vars=self._env, - sudo=self._sudo, - timeout=self._timeout, - sudo_password=self._sudo_password) + action = ShellCommandAction( + name=self.action_name, + action_exec_id=str(self.liveaction_id), + command=command, + user=self._user, + env_vars=self._env, + sudo=self._sudo, + timeout=self._timeout, + sudo_password=self._sudo_password, + ) return self._run(action=action) @@ -54,7 +51,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('local_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("local_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py index 24a0fe6ddb..257e457ca1 100644 --- a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py +++ b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py @@ -23,34 +23,31 @@ from local_runner.base import BaseLocalShellRunner -__all__ = [ - 'LocalShellScriptRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["LocalShellScriptRunner", "get_runner", "get_metadata"] class LocalShellScriptRunner(BaseLocalShellRunner, GitWorktreeActionRunner): def run(self, action_parameters): if not self.entry_point: - raise ValueError('Missing entry_point action metadata attribute') + raise ValueError("Missing entry_point action metadata attribute") script_local_path_abs = self.entry_point positional_args, named_args = self._get_script_args(action_parameters) named_args = self._transform_named_args(named_args) - action = ShellScriptAction(name=self.action_name, - action_exec_id=str(self.liveaction_id), - script_local_path_abs=script_local_path_abs, - named_args=named_args, - positional_args=positional_args, - user=self._user, - env_vars=self._env, - sudo=self._sudo, - timeout=self._timeout, - cwd=self._cwd, - sudo_password=self._sudo_password) + action = ShellScriptAction( + name=self.action_name, + action_exec_id=str(self.liveaction_id), + script_local_path_abs=script_local_path_abs, + named_args=named_args, + positional_args=positional_args, + user=self._user, + env_vars=self._env, + sudo=self._sudo, + timeout=self._timeout, + cwd=self._cwd, + sudo_password=self._sudo_password, + ) return self._run(action=action) @@ -60,7 +57,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('local_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("local_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/local_runner/setup.py b/contrib/runners/local_runner/setup.py index feb1cb6554..063314ab74 100644 --- a/contrib/runners/local_runner/setup.py +++ b/contrib/runners/local_runner/setup.py @@ -26,32 +26,34 @@ from local_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-local', + name="stackstorm-runner-local", version=__version__, - description=('Local Shell Command and Script action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Local Shell Command and Script action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'local_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"local_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'local-shell-cmd = local_runner.local_shell_command_runner', - 'local-shell-script = local_runner.local_shell_script_runner', + "st2common.runners.runner": [ + "local-shell-cmd = local_runner.local_shell_command_runner", + "local-shell-script = local_runner.local_shell_script_runner", ], - } + }, ) diff --git a/contrib/runners/local_runner/tests/integration/test_localrunner.py b/contrib/runners/local_runner/tests/integration/test_localrunner.py index 0e5a2f3efc..05c241f46b 100644 --- a/contrib/runners/local_runner/tests/integration/test_localrunner.py +++ b/contrib/runners/local_runner/tests/integration/test_localrunner.py @@ -22,6 +22,7 @@ import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() from st2common.constants import action as action_constants @@ -40,13 +41,10 @@ from local_runner.local_shell_command_runner import LocalShellCommandRunner from local_runner.local_shell_script_runner import LocalShellScriptRunner -__all__ = [ - 'LocalShellCommandRunnerTestCase', - 'LocalShellScriptRunnerTestCase' -] +__all__ = ["LocalShellCommandRunnerTestCase", "LocalShellScriptRunnerTestCase"] MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" class LocalShellCommandRunnerTestCase(RunnerTestCase, CleanDbTestCase): @@ -56,108 +54,115 @@ def setUp(self): super(LocalShellCommandRunnerTestCase, self).setUp() # False is a default behavior so end result should be the same - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False) + cfg.CONF.set_override( + name="stream_output", group="actionrunner", override=False + ) def test_shell_command_action_basic(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo 10') + runner = self._get_runner(action_db, cmd="echo 10") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 10) + self.assertEqual(result["stdout"], 10) # End result should be the same when streaming is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Verify initial state output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 0) - runner = self._get_runner(action_db, cmd='echo 10') + runner = self._get_runner(action_db, cmd="echo 10") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 10) + self.assertEqual(result["stdout"], 10) output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 1) - self.assertEqual(output_dbs[0].output_type, 'stdout') - self.assertEqual(output_dbs[0].data, '10\n') + self.assertEqual(output_dbs[0].output_type, "stdout") + self.assertEqual(output_dbs[0].data, "10\n") def test_timeout(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] # smaller timeout == faster tests. - runner = self._get_runner(action_db, cmd='sleep 10', timeout=0.01) + runner = self._get_runner(action_db, cmd="sleep 10", timeout=0.01) runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_TIMED_OUT) @mock.patch.object( - shell, 'run_command', - mock.MagicMock(return_value=(-15, '', '', False))) + shell, "run_command", mock.MagicMock(return_value=(-15, "", "", False)) + ) def test_shutdown(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] - runner = self._get_runner(action_db, cmd='sleep 0.1') + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] + runner = self._get_runner(action_db, cmd="sleep 0.1") runner.pre_run() status, result, _ = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_ABANDONED) def test_common_st2_env_vars_are_available_to_the_action(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), get_full_public_api_url()) + self.assertEqual(result["stdout"].strip(), get_full_public_api_url()) - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_AUTH_TOKEN') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_AUTH_TOKEN") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), 'mock-token') + self.assertEqual(result["stdout"].strip(), "mock-token") def test_sudo_and_env_variable_preservation(self): # Verify that the environment environment are correctly preserved when running as a # root / non-system user # Note: This test will fail if SETENV option is not present in the sudoers file models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - cmd = 'echo `whoami` ; echo ${VAR1}' - env = {'VAR1': 'poniesponies'} + cmd = "echo `whoami` ; echo ${VAR1}" + env = {"VAR1": "poniesponies"} runner = self._get_runner(action_db, cmd=cmd, sudo=True, env=env) runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), 'root\nponiesponies') + self.assertEqual(result["stdout"].strip(), "root\nponiesponies") - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen): # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration @@ -165,78 +170,75 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop # No output to stdout and no result (implicit None) mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n', - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n', - 'stderr line 3\n' + "stdout line 1\n", + "stdout line 2\n", ] + mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"] mock_process = mock.Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=2) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=3) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=2 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=3 + ) models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3') - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2") + self.assertEqual( + result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3" + ) + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 2) self.assertEqual(output_dbs[0].data, mock_stdout[0]) self.assertEqual(output_dbs[1].data, mock_stdout[1]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 3) self.assertEqual(output_dbs[0].data, mock_stderr[0]) self.assertEqual(output_dbs[1].data, mock_stderr[1]) self.assertEqual(output_dbs[2].data, mock_stderr[2]) - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') - def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, mock_spawn, - mock_popen): + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") + def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action( + self, mock_spawn, mock_popen + ): # Verify that we correctly retrieve all the output and wait for stdout and stderr reading # threads for short running actions. models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration mock_spawn.side_effect = blocking_eventlet_spawn # No output to stdout and no result (implicit None) - mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n' - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n' - ] + mock_stdout = ["stdout line 1\n", "stdout line 2\n"] + mock_stderr = ["stderr line 1\n", "stderr line 2\n"] # We add a sleep to simulate action process exiting before we finish reading data from mock_process = mock.Mock() @@ -244,11 +246,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=2, - sleep_delay=1) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=2) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=2, sleep_delay=1 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=2 + ) for index in range(1, 4): mock_process.stdout.closed = False @@ -263,12 +266,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2') - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2") + self.assertEqual(result["stderr"], "stderr line 1\nstderr line 2") + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") if index == 1: db_index_1 = 0 @@ -287,7 +290,7 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, self.assertEqual(output_dbs[db_index_1].data, mock_stdout[0]) self.assertEqual(output_dbs[db_index_2].data, mock_stdout[1]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), (index * 2)) self.assertEqual(output_dbs[db_index_1].data, mock_stderr[0]) self.assertEqual(output_dbs[db_index_2].data, mock_stderr[1]) @@ -295,16 +298,13 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): # Verify that sudo password is correctly passed to sudo binary via stdin models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - sudo_passwords = [ - 'pass 1', - 'sudopass', - '$sudo p@ss 2' - ] + sudo_passwords = ["pass 1", "sudopass", "$sudo p@ss 2"] - cmd = ('{ read sudopass; echo $sudopass; }') + cmd = "{ read sudopass; echo $sudopass; }" # without sudo for sudo_password in sudo_passwords: @@ -314,9 +314,8 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): status, result, _ = runner.run({}) runner.post_run(status, result) - self.assertEqual(status, - action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], sudo_password) + self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual(result["stdout"], sudo_password) # with sudo for sudo_password in sudo_passwords: @@ -327,12 +326,13 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): status, result, _ = runner.run({}) runner.post_run(status, result) - self.assertEqual(status, - action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], sudo_password) + self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual(result["stdout"], sudo_password) # Verify new process which provides password via stdin to the command is created - with mock.patch('st2common.util.concurrency.subprocess_popen') as mock_subproc_popen: + with mock.patch( + "st2common.util.concurrency.subprocess_popen" + ) as mock_subproc_popen: index = 0 for sudo_password in sudo_passwords: runner = self._get_runner(action_db, cmd=cmd) @@ -349,58 +349,67 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): index += 1 - self.assertEqual(call_args[0][0], ['echo', '%s\n' % (sudo_password)]) + self.assertEqual(call_args[0][0], ["echo", "%s\n" % (sudo_password)]) self.assertEqual(index, len(sudo_passwords)) def test_shell_command_invalid_stdout_password(self): # Simulate message printed to stderr by sudo when invalid sudo password is provided models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] - - cmd = ('echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:' - ' Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password ' - 'attempts" 1>&2; exit 1') + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] + + cmd = ( + 'echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:' + " Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password " + 'attempts" 1>&2; exit 1' + ) runner = self._get_runner(action_db, cmd=cmd) runner.pre_run() - runner._sudo_password = 'pass' + runner._sudo_password = "pass" status, result, _ = runner.run({}) runner.post_run(status, result) - expected_error = ('Invalid sudo password provided or sudo is not configured for this ' - 'user (bar)') + expected_error = ( + "Invalid sudo password provided or sudo is not configured for this " + "user (bar)" + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(result['error'], expected_error) - self.assertEqual(result['stdout'], '') + self.assertEqual(result["error"], expected_error) + self.assertEqual(result["stdout"], "") @staticmethod - def _get_runner(action_db, - entry_point=None, - cmd=None, - on_behalf_user=None, - user=None, - kwarg_op=local_runner.DEFAULT_KWARG_OP, - timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT, - sudo=False, - env=None): + def _get_runner( + action_db, + entry_point=None, + cmd=None, + on_behalf_user=None, + user=None, + kwarg_op=local_runner.DEFAULT_KWARG_OP, + timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT, + sudo=False, + env=None, + ): runner = LocalShellCommandRunner(uuid.uuid4().hex) runner.execution = MOCK_EXECUTION runner.action = action_db runner.action_name = action_db.name runner.liveaction_id = uuid.uuid4().hex runner.entry_point = entry_point - runner.runner_parameters = {local_runner.RUNNER_COMMAND: cmd, - local_runner.RUNNER_SUDO: sudo, - local_runner.RUNNER_ENV: env, - local_runner.RUNNER_ON_BEHALF_USER: user, - local_runner.RUNNER_KWARG_OP: kwarg_op, - local_runner.RUNNER_TIMEOUT: timeout} + runner.runner_parameters = { + local_runner.RUNNER_COMMAND: cmd, + local_runner.RUNNER_SUDO: sudo, + local_runner.RUNNER_ENV: env, + local_runner.RUNNER_ON_BEHALF_USER: user, + local_runner.RUNNER_KWARG_OP: kwarg_op, + local_runner.RUNNER_TIMEOUT: timeout, + } runner.context = dict() runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner @@ -411,22 +420,27 @@ def setUp(self): super(LocalShellScriptRunnerTestCase, self).setUp() # False is a default behavior so end result should be the same - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False) + cfg.CONF.set_override( + name="stream_output", group="actionrunner", override=False + ) def test_script_with_parameters_parameter_serialization(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']}) - action_db = models['actions']['local_script_with_params.yaml'] - entry_point = os.path.join(get_fixtures_base_path(), - 'generic/actions/local_script_with_params.sh') + fixtures_pack="generic", + fixtures_dict={"actions": ["local_script_with_params.yaml"]}, + ) + action_db = models["actions"]["local_script_with_params.yaml"] + entry_point = os.path.join( + get_fixtures_base_path(), "generic/actions/local_script_with_params.sh" + ) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -434,20 +448,20 @@ def test_script_with_parameters_parameter_serialization(self): status, result, _ = runner.run(action_parameters=action_parameters) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=test string', result['stdout']) - self.assertIn('PARAM_INTEGER=1', result['stdout']) - self.assertIn('PARAM_FLOAT=2.55', result['stdout']) - self.assertIn('PARAM_BOOLEAN=1', result['stdout']) - self.assertIn('PARAM_LIST=a,b,c', result['stdout']) - self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout']) + self.assertIn("PARAM_STRING=test string", result["stdout"]) + self.assertIn("PARAM_INTEGER=1", result["stdout"]) + self.assertIn("PARAM_FLOAT=2.55", result["stdout"]) + self.assertIn("PARAM_BOOLEAN=1", result["stdout"]) + self.assertIn("PARAM_LIST=a,b,c", result["stdout"]) + self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"]) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': False, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": False, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -456,12 +470,12 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_BOOLEAN=0', result['stdout']) + self.assertIn("PARAM_BOOLEAN=0", result["stdout"]) action_parameters = { - 'param_string': '', - 'param_integer': None, - 'param_float': None, + "param_string": "", + "param_integer": None, + "param_float": None, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -470,24 +484,24 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=\n', result['stdout']) - self.assertIn('PARAM_INTEGER=\n', result['stdout']) - self.assertIn('PARAM_FLOAT=\n', result['stdout']) + self.assertIn("PARAM_STRING=\n", result["stdout"]) + self.assertIn("PARAM_INTEGER=\n", result["stdout"]) + self.assertIn("PARAM_FLOAT=\n", result["stdout"]) # End result should be the same when streaming is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Verify initial state output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 0) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -496,26 +510,26 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=test string', result['stdout']) - self.assertIn('PARAM_INTEGER=1', result['stdout']) - self.assertIn('PARAM_FLOAT=2.55', result['stdout']) - self.assertIn('PARAM_BOOLEAN=1', result['stdout']) - self.assertIn('PARAM_LIST=a,b,c', result['stdout']) - self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout']) - - output_dbs = ActionExecutionOutput.query(output_type='stdout') + self.assertIn("PARAM_STRING=test string", result["stdout"]) + self.assertIn("PARAM_INTEGER=1", result["stdout"]) + self.assertIn("PARAM_FLOAT=2.55", result["stdout"]) + self.assertIn("PARAM_BOOLEAN=1", result["stdout"]) + self.assertIn("PARAM_LIST=a,b,c", result["stdout"]) + self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"]) + + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 6) - self.assertEqual(output_dbs[0].data, 'PARAM_STRING=test string\n') + self.assertEqual(output_dbs[0].data, "PARAM_STRING=test string\n") self.assertEqual(output_dbs[5].data, 'PARAM_OBJECT={"foo": "bar"}\n') - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 0) - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen): # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration @@ -523,40 +537,41 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop # No output to stdout and no result (implicit None) mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n', - 'stdout line 3\n', - 'stdout line 4\n' - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n', - 'stderr line 3\n' + "stdout line 1\n", + "stdout line 2\n", + "stdout line 3\n", + "stdout line 4\n", ] + mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"] mock_process = mock.Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=4) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=3) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=4 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=3 + ) models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']}) - action_db = models['actions']['local_script_with_params.yaml'] - entry_point = os.path.join(get_fixtures_base_path(), - 'generic/actions/local_script_with_params.sh') + fixtures_pack="generic", + fixtures_dict={"actions": ["local_script_with_params.yaml"]}, + ) + action_db = models["actions"]["local_script_with_params.yaml"] + entry_point = os.path.join( + get_fixtures_base_path(), "generic/actions/local_script_with_params.sh" + ) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -564,20 +579,24 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop status, result, _ = runner.run(action_parameters=action_parameters) runner.post_run(status, result) - self.assertEqual(result['stdout'], - 'stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3') - self.assertEqual(result['return_code'], 0) + self.assertEqual( + result["stdout"], + "stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4", + ) + self.assertEqual( + result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3" + ) + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 4) self.assertEqual(output_dbs[0].data, mock_stdout[0]) self.assertEqual(output_dbs[1].data, mock_stdout[1]) self.assertEqual(output_dbs[2].data, mock_stdout[2]) self.assertEqual(output_dbs[3].data, mock_stdout[3]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 3) self.assertEqual(output_dbs[0].data, mock_stderr[0]) self.assertEqual(output_dbs[1].data, mock_stderr[1]) @@ -585,30 +604,36 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop def test_shell_script_action(self): models = self.fixtures_loader.load_models( - fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']}) - action_db = models['actions']['text_gen.yml'] + fixtures_pack="localrunner_pack", + fixtures_dict={"actions": ["text_gen.yml"]}, + ) + action_db = models["actions"]["text_gen.yml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'localrunner_pack', 'actions', 'text_gen.py') + "localrunner_pack", "actions", "text_gen.py" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() - status, result, _ = runner.run({'chars': 1000}) + status, result, _ = runner.run({"chars": 1000}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(len(result['stdout']), 1000) + self.assertEqual(len(result["stdout"]), 1000) def test_large_stdout(self): models = self.fixtures_loader.load_models( - fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']}) - action_db = models['actions']['text_gen.yml'] + fixtures_pack="localrunner_pack", + fixtures_dict={"actions": ["text_gen.yml"]}, + ) + action_db = models["actions"]["text_gen.yml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'localrunner_pack', 'actions', 'text_gen.py') + "localrunner_pack", "actions", "text_gen.py" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() char_count = 10 ** 6 # Note 10^7 succeeds but ends up being slow. - status, result, _ = runner.run({'chars': char_count}) + status, result, _ = runner.run({"chars": char_count}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(len(result['stdout']), char_count) + self.assertEqual(len(result["stdout"]), char_count) def _get_runner(self, action_db, entry_point): runner = LocalShellScriptRunner(uuid.uuid4().hex) @@ -622,5 +647,5 @@ def _get_runner(self, action_db, entry_point): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/runners/noop_runner/dist_utils.py b/contrib/runners/noop_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/noop_runner/dist_utils.py +++ b/contrib/runners/noop_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/noop_runner/noop_runner/__init__.py b/contrib/runners/noop_runner/noop_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/noop_runner/noop_runner/__init__.py +++ b/contrib/runners/noop_runner/noop_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/noop_runner/noop_runner/noop_runner.py b/contrib/runners/noop_runner/noop_runner/noop_runner.py index 0eb745218a..b4dda10fd5 100644 --- a/contrib/runners/noop_runner/noop_runner/noop_runner.py +++ b/contrib/runners/noop_runner/noop_runner/noop_runner.py @@ -22,12 +22,7 @@ from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED import st2common.util.jsonify as jsonify -__all__ = [ - 'NoopRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["NoopRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -36,7 +31,8 @@ class NoopRunner(ActionRunner): """ Runner which does absolutely nothing. No-op action. """ - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] + + KEYS_TO_TRANSFORM = ["stdout", "stderr"] def __init__(self, runner_id): super(NoopRunner, self).__init__(runner_id=runner_id) @@ -46,14 +42,15 @@ def pre_run(self): def run(self, action_parameters): - LOG.info('Executing action via NoopRunner: %s', self.runner_id) - LOG.info('[Action info] name: %s, Id: %s', - self.action_name, str(self.execution_id)) + LOG.info("Executing action via NoopRunner: %s", self.runner_id) + LOG.info( + "[Action info] name: %s, Id: %s", self.action_name, str(self.execution_id) + ) result = { - 'failed': False, - 'succeeded': True, - 'return_code': 0, + "failed": False, + "succeeded": True, + "return_code": 0, } status = LIVEACTION_STATUS_SUCCEEDED @@ -65,4 +62,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('noop_runner')[0] + return get_runner_metadata("noop_runner")[0] diff --git a/contrib/runners/noop_runner/setup.py b/contrib/runners/noop_runner/setup.py index 30b00bd68b..94b518c55f 100644 --- a/contrib/runners/noop_runner/setup.py +++ b/contrib/runners/noop_runner/setup.py @@ -26,30 +26,30 @@ from noop_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-noop', + name="stackstorm-runner-noop", version=__version__, - description=('No-Op action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=("No-Op action runner for StackStorm event-driven automation platform"), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'noop_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"noop_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'noop = noop_runner.noop_runner', + "st2common.runners.runner": [ + "noop = noop_runner.noop_runner", ], - } + }, ) diff --git a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py index 6783404ffb..98c66c33cd 100644 --- a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py +++ b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py @@ -19,6 +19,7 @@ import mock import st2tests.config as tests_config + tests_config.parse_args() from unittest2 import TestCase @@ -33,16 +34,17 @@ class TestNoopRunner(TestCase): def test_noop_command_executes(self): models = TestNoopRunner.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['noop.yaml']}) + fixtures_pack="generic", fixtures_dict={"actions": ["noop.yaml"]} + ) - action_db = models['actions']['noop.yaml'] + action_db = models["actions"]["noop.yaml"] runner = TestNoopRunner._get_runner(action_db) status, result, _ = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['failed'], False) - self.assertEqual(result['succeeded'], True) - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["failed"], False) + self.assertEqual(result["succeeded"], True) + self.assertEqual(result["return_code"], 0) @staticmethod def _get_runner(action_db): @@ -55,5 +57,5 @@ def _get_runner(action_db): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/runners/orquesta_runner/dist_utils.py b/contrib/runners/orquesta_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/orquesta_runner/dist_utils.py +++ b/contrib/runners/orquesta_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py index f5986392d7..e71dcafd40 100644 --- a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py +++ b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py @@ -33,15 +33,15 @@ def format_task_result(instances): instance = instances[-1] return { - 'task_execution_id': str(instance.id), - 'workflow_execution_id': instance.workflow_execution, - 'task_name': instance.task_id, - 'task_id': instance.task_id, - 'route': instance.task_route, - 'result': instance.result, - 'status': instance.status, - 'start_timestamp': str(instance.start_timestamp), - 'end_timestamp': str(instance.end_timestamp) + "task_execution_id": str(instance.id), + "workflow_execution_id": instance.workflow_execution, + "task_name": instance.task_id, + "task_id": instance.task_id, + "route": instance.task_route, + "result": instance.result, + "status": instance.status, + "start_timestamp": str(instance.start_timestamp), + "end_timestamp": str(instance.end_timestamp), } @@ -54,17 +54,17 @@ def task(context, task_id=None, route=None): current_task = {} if task_id is None: - task_id = current_task['id'] + task_id = current_task["id"] if route is None: - route = current_task.get('route', 0) + route = current_task.get("route", 0) try: - workflow_state = context['__state'] or {} + workflow_state = context["__state"] or {} except KeyError: workflow_state = {} - task_state_pointers = workflow_state.get('tasks') or {} + task_state_pointers = workflow_state.get("tasks") or {} task_state_entry_uid = constants.TASK_STATE_ROUTE_FORMAT % (task_id, str(route)) task_state_entry_idx = task_state_pointers.get(task_state_entry_uid) @@ -72,9 +72,11 @@ def task(context, task_id=None, route=None): # use an earlier route before the split to find the specific task. if task_state_entry_idx is None: if route > 0: - current_route_details = workflow_state['routes'][route] + current_route_details = workflow_state["routes"][route] # Reverse the list because we want to start with the next longest route. - for idx, prev_route_details in enumerate(reversed(workflow_state['routes'][:route])): + for idx, prev_route_details in enumerate( + reversed(workflow_state["routes"][:route]) + ): if len(set(prev_route_details) - set(current_route_details)) == 0: # The index is from a reversed list so need to calculate # the index of the item in the list before the reverse. @@ -83,17 +85,15 @@ def task(context, task_id=None, route=None): else: # Otherwise, get the task flow entry and use the # task id and route to query the database. - task_state_seqs = workflow_state.get('sequence') or [] + task_state_seqs = workflow_state.get("sequence") or [] task_state_entry = task_state_seqs[task_state_entry_idx] - route = task_state_entry['route'] - st2_ctx = context['__vars']['st2'] - workflow_execution_id = st2_ctx['workflow_execution_id'] + route = task_state_entry["route"] + st2_ctx = context["__vars"]["st2"] + workflow_execution_id = st2_ctx["workflow_execution_id"] # Query the database by the workflow execution ID, task ID, and task route. instances = wf_db_access.TaskExecution.query( - workflow_execution=workflow_execution_id, - task_id=task_id, - task_route=route + workflow_execution=workflow_execution_id, task_id=task_id, task_route=route ) if not instances: diff --git a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py index 35cae92cd7..ed23507a1b 100644 --- a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py +++ b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py @@ -29,26 +29,28 @@ def st2kv_(context, key, **kwargs): if not isinstance(key, six.string_types): - raise TypeError('Given key is not typeof string.') + raise TypeError("Given key is not typeof string.") - decrypt = kwargs.get('decrypt', False) + decrypt = kwargs.get("decrypt", False) if not isinstance(decrypt, bool): - raise TypeError('Decrypt parameter is not typeof bool.') + raise TypeError("Decrypt parameter is not typeof bool.") try: - username = context['__vars']['st2']['user'] + username = context["__vars"]["st2"]["user"] except KeyError: - raise KeyError('Could not get user from context.') + raise KeyError("Could not get user from context.") try: user_db = auth_db_access.User.get(username) except Exception as e: - raise Exception('Failed to retrieve User object for user "%s", "%s"' % - (username, six.text_type(e))) + raise Exception( + 'Failed to retrieve User object for user "%s", "%s"' + % (username, six.text_type(e)) + ) - has_default = 'default' in kwargs - default_value = kwargs.get('default') + has_default = "default" in kwargs + default_value = kwargs.get("default") try: return kvp_util.get_key(key=key, user_db=user_db, decrypt=decrypt) diff --git a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py index b59642609e..62f2492ae4 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py @@ -37,71 +37,72 @@ from st2common.util import api as api_util from st2common.util import ujson -__all__ = [ - 'OrquestaRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["OrquestaRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) class OrquestaRunner(runners.AsyncActionRunner): - @staticmethod def get_workflow_definition(entry_point): - with open(entry_point, 'r') as def_file: + with open(entry_point, "r") as def_file: return def_file.read() def _get_notify_config(self): return ( - notify_api_models.NotificationsHelper.from_model(notify_model=self.liveaction.notify) + notify_api_models.NotificationsHelper.from_model( + notify_model=self.liveaction.notify + ) if self.liveaction.notify else None ) def _construct_context(self, wf_ex): ctx = ujson.fast_deepcopy(self.context) - ctx['workflow_execution'] = str(wf_ex.id) + ctx["workflow_execution"] = str(wf_ex.id) return ctx def _construct_st2_context(self): st2_ctx = { - 'st2': { - 'action_execution_id': str(self.execution.id), - 'api_url': api_util.get_full_public_api_url(), - 'user': self.execution.context.get('user', cfg.CONF.system_user.user), - 'pack': self.execution.context.get('pack', None), - 'action': self.execution.action.get('ref', None), - 'runner': self.execution.action.get('runner_type', None) + "st2": { + "action_execution_id": str(self.execution.id), + "api_url": api_util.get_full_public_api_url(), + "user": self.execution.context.get("user", cfg.CONF.system_user.user), + "pack": self.execution.context.get("pack", None), + "action": self.execution.action.get("ref", None), + "runner": self.execution.action.get("runner_type", None), } } - if self.execution.context.get('api_user'): - st2_ctx['st2']['api_user'] = self.execution.context.get('api_user') + if self.execution.context.get("api_user"): + st2_ctx["st2"]["api_user"] = self.execution.context.get("api_user") - if self.execution.context.get('source_channel'): - st2_ctx['st2']['source_channel'] = self.execution.context.get('source_channel') + if self.execution.context.get("source_channel"): + st2_ctx["st2"]["source_channel"] = self.execution.context.get( + "source_channel" + ) if self.execution.context: - st2_ctx['parent'] = self.execution.context + st2_ctx["parent"] = self.execution.context return st2_ctx def _handle_workflow_return_value(self, wf_ex_db): if wf_ex_db.status in wf_statuses.COMPLETED_STATUSES: status = wf_ex_db.status - result = {'output': wf_ex_db.output or None} + result = {"output": wf_ex_db.output or None} if wf_ex_db.status in wf_statuses.ABENDED_STATUSES: - result['errors'] = wf_ex_db.errors + result["errors"] = wf_ex_db.errors for wf_ex_error in wf_ex_db.errors: - msg = 'Workflow execution completed with errors.' - wf_svc.update_progress(wf_ex_db, '%s %s' % (msg, str(wf_ex_error)), log=False) - LOG.error('[%s] %s', str(self.execution.id), msg, extra=wf_ex_error) + msg = "Workflow execution completed with errors." + wf_svc.update_progress( + wf_ex_db, "%s %s" % (msg, str(wf_ex_error)), log=False + ) + LOG.error("[%s] %s", str(self.execution.id), msg, extra=wf_ex_error) return (status, result, self.context) @@ -115,8 +116,8 @@ def _handle_workflow_return_value(self, wf_ex_db): def run(self, action_parameters): # If there is an action execution reference for rerun and there is task specified, # then rerun the existing workflow execution. - rerun_options = self.context.get('re-run', {}) - rerun_task_options = rerun_options.get('tasks', []) + rerun_options = self.context.get("re-run", {}) + rerun_task_options = rerun_options.get("tasks", []) if self.rerun_ex_ref and rerun_task_options: return self.rerun_workflow(self.rerun_ex_ref, options=rerun_options) @@ -131,14 +132,16 @@ def start_workflow(self, action_parameters): # Request workflow execution. st2_ctx = self._construct_st2_context() notify_cfg = self._get_notify_config() - wf_ex_db = wf_svc.request(wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg) + wf_ex_db = wf_svc.request( + wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg + ) except wf_exc.WorkflowInspectionError as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': e.args[1], 'output': None} + result = {"errors": e.args[1], "output": None} return (status, result, self.context) except Exception as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': [{'message': six.text_type(e)}], 'output': None} + result = {"errors": [{"message": six.text_type(e)}], "output": None} return (status, result, self.context) return self._handle_workflow_return_value(wf_ex_db) @@ -146,13 +149,13 @@ def start_workflow(self, action_parameters): def rerun_workflow(self, ac_ex_ref, options=None): try: # Request rerun of workflow execution. - wf_ex_id = ac_ex_ref.context.get('workflow_execution') + wf_ex_id = ac_ex_ref.context.get("workflow_execution") st2_ctx = self._construct_st2_context() - st2_ctx['workflow_execution_id'] = wf_ex_id + st2_ctx["workflow_execution_id"] = wf_ex_id wf_ex_db = wf_svc.request_rerun(self.execution, st2_ctx, options=options) except Exception as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': [{'message': six.text_type(e)}], 'output': None} + result = {"errors": [{"message": six.text_type(e)}], "output": None} return (status, result, self.context) return self._handle_workflow_return_value(wf_ex_db) @@ -160,8 +163,8 @@ def rerun_workflow(self, ac_ex_ref, options=None): @staticmethod def task_pauseable(ac_ex): wf_ex_pauseable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING ) return wf_ex_pauseable @@ -175,26 +178,24 @@ def pause(self): child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_pauseable(child_ex): ac_svc.request_pause( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) - if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active(self.liveaction.id): + if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active( + self.liveaction.id + ): status = ac_const.LIVEACTION_STATUS_PAUSING else: status = ac_const.LIVEACTION_STATUS_PAUSED - return ( - status, - self.liveaction.result, - self.liveaction.context - ) + return (status, self.liveaction.result, self.liveaction.context) @staticmethod def task_resumeable(ac_ex): wf_ex_resumeable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED ) return wf_ex_resumeable @@ -208,26 +209,26 @@ def resume(self): child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_resumeable(child_ex): ac_svc.request_resume( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) return ( wf_ex_db.status if wf_ex_db else ac_const.LIVEACTION_STATUS_RUNNING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) @staticmethod def task_cancelable(ac_ex): wf_ex_cancelable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES ) ac_ex_cancelable = ( - ac_ex.runner['name'] not in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES + ac_ex.runner["name"] not in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES ) return wf_ex_cancelable or ac_ex_cancelable @@ -242,8 +243,10 @@ def cancel(self): # If workflow execution is not found because the action execution is cancelled # before the workflow execution is created or if the workflow execution is # already completed, then ignore the exception and proceed with cancellation. - except (wf_svc_exc.WorkflowExecutionNotFoundException, - wf_svc_exc.WorkflowExecutionIsCompletedException): + except ( + wf_svc_exc.WorkflowExecutionNotFoundException, + wf_svc_exc.WorkflowExecutionIsCompletedException, + ): pass # If there is an unknown exception, then log the error. Continue with the # cancelation sequence below to cancel children and determine final status. @@ -253,19 +256,22 @@ def cancel(self): # execution will be in an unknown state. except Exception: _, ex, tb = sys.exc_info() - msg = 'Error encountered when canceling workflow execution.' - LOG.exception('[%s] %s', str(self.execution.id), msg) - msg = 'Error encountered when canceling workflow execution. %s' + msg = "Error encountered when canceling workflow execution." + LOG.exception("[%s] %s", str(self.execution.id), msg) + msg = "Error encountered when canceling workflow execution. %s" wf_svc.update_progress(wf_ex_db, msg % str(ex), log=False) - result = {'error': msg % str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": msg % str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } # Request cancellation of tasks that are workflows and still running. for child_ex_id in self.execution.children: child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_cancelable(child_ex): ac_svc.request_cancellation( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) status = ( @@ -277,7 +283,7 @@ def cancel(self): return ( status, result if result else self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) @@ -286,4 +292,4 @@ def get_runner(): def get_metadata(): - return runners.get_metadata('orquesta_runner')[0] + return runners.get_metadata("orquesta_runner")[0] diff --git a/contrib/runners/orquesta_runner/setup.py b/contrib/runners/orquesta_runner/setup.py index 5dac5ed34e..859a8b6050 100644 --- a/contrib/runners/orquesta_runner/setup.py +++ b/contrib/runners/orquesta_runner/setup.py @@ -26,62 +26,64 @@ from orquesta_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-orquesta', + name="stackstorm-runner-orquesta", version=__version__, - description='Orquesta workflow runner for StackStorm event-driven automation platform', - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="Orquesta workflow runner for StackStorm event-driven automation platform", + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'orquesta_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"orquesta_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'orquesta = orquesta_runner.orquesta_runner', + "st2common.runners.runner": [ + "orquesta = orquesta_runner.orquesta_runner", ], - 'orquesta.expressions.functions': [ - 'st2kv = orquesta_functions.st2kv:st2kv_', - 'task = orquesta_functions.runtime:task', - 'basename = st2common.expressions.functions.path:basename', - 'dirname = st2common.expressions.functions.path:dirname', - 'from_json_string = st2common.expressions.functions.data:from_json_string', - 'from_yaml_string = st2common.expressions.functions.data:from_yaml_string', - 'json_dump = st2common.expressions.functions.data:to_json_string', - 'json_parse = st2common.expressions.functions.data:from_json_string', - 'json_escape = st2common.expressions.functions.data:json_escape', - 'jsonpath_query = st2common.expressions.functions.data:jsonpath_query', - 'regex_match = st2common.expressions.functions.regex:regex_match', - 'regex_replace = st2common.expressions.functions.regex:regex_replace', - 'regex_search = st2common.expressions.functions.regex:regex_search', - 'regex_substring = st2common.expressions.functions.regex:regex_substring', - ('to_human_time_from_seconds = ' - 'st2common.expressions.functions.time:to_human_time_from_seconds'), - 'to_json_string = st2common.expressions.functions.data:to_json_string', - 'to_yaml_string = st2common.expressions.functions.data:to_yaml_string', - 'use_none = st2common.expressions.functions.data:use_none', - 'version_compare = st2common.expressions.functions.version:version_compare', - 'version_more_than = st2common.expressions.functions.version:version_more_than', - 'version_less_than = st2common.expressions.functions.version:version_less_than', - 'version_equal = st2common.expressions.functions.version:version_equal', - 'version_match = st2common.expressions.functions.version:version_match', - 'version_bump_major = st2common.expressions.functions.version:version_bump_major', - 'version_bump_minor = st2common.expressions.functions.version:version_bump_minor', - 'version_bump_patch = st2common.expressions.functions.version:version_bump_patch', - 'version_strip_patch = st2common.expressions.functions.version:version_strip_patch', - 'yaml_dump = st2common.expressions.functions.data:to_yaml_string', - 'yaml_parse = st2common.expressions.functions.data:from_yaml_string' + "orquesta.expressions.functions": [ + "st2kv = orquesta_functions.st2kv:st2kv_", + "task = orquesta_functions.runtime:task", + "basename = st2common.expressions.functions.path:basename", + "dirname = st2common.expressions.functions.path:dirname", + "from_json_string = st2common.expressions.functions.data:from_json_string", + "from_yaml_string = st2common.expressions.functions.data:from_yaml_string", + "json_dump = st2common.expressions.functions.data:to_json_string", + "json_parse = st2common.expressions.functions.data:from_json_string", + "json_escape = st2common.expressions.functions.data:json_escape", + "jsonpath_query = st2common.expressions.functions.data:jsonpath_query", + "regex_match = st2common.expressions.functions.regex:regex_match", + "regex_replace = st2common.expressions.functions.regex:regex_replace", + "regex_search = st2common.expressions.functions.regex:regex_search", + "regex_substring = st2common.expressions.functions.regex:regex_substring", + ( + "to_human_time_from_seconds = " + "st2common.expressions.functions.time:to_human_time_from_seconds" + ), + "to_json_string = st2common.expressions.functions.data:to_json_string", + "to_yaml_string = st2common.expressions.functions.data:to_yaml_string", + "use_none = st2common.expressions.functions.data:use_none", + "version_compare = st2common.expressions.functions.version:version_compare", + "version_more_than = st2common.expressions.functions.version:version_more_than", + "version_less_than = st2common.expressions.functions.version:version_less_than", + "version_equal = st2common.expressions.functions.version:version_equal", + "version_match = st2common.expressions.functions.version:version_match", + "version_bump_major = st2common.expressions.functions.version:version_bump_major", + "version_bump_minor = st2common.expressions.functions.version:version_bump_minor", + "version_bump_patch = st2common.expressions.functions.version:version_bump_patch", + "version_strip_patch = st2common.expressions.functions.version:version_strip_patch", + "yaml_dump = st2common.expressions.functions.data:to_yaml_string", + "yaml_parse = st2common.expressions.functions.data:from_yaml_string", ], - } + }, ) diff --git a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py index 8734bce072..0e273f6e83 100644 --- a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py +++ b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py @@ -21,78 +21,67 @@ class DatastoreFunctionTest(base.TestWorkflowExecution): @classmethod - def set_kvp(cls, name, value, scope='system', secret=False): + def set_kvp(cls, name, value, scope="system", secret=False): kvp = models.KeyValuePair( - id=name, - name=name, - value=value, - scope=scope, - secret=secret + id=name, name=name, value=value, scope=scope, secret=secret ) cls.st2client.keys.update(kvp) @classmethod - def del_kvp(cls, name, scope='system'): - kvp = models.KeyValuePair( - id=name, - name=name, - scope=scope - ) + def del_kvp(cls, name, scope="system"): + kvp = models.KeyValuePair(id=name, name=name, scope=scope) cls.st2client.keys.delete(kvp) def test_st2kv_system_scope(self): - key = 'lakshmi' - value = 'kanahansnasnasdlsajks' + key = "lakshmi" + value = "kanahansnasnasdlsajks" self.set_kvp(key, value) - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': 'system.%s' % key} + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_user_scope(self): - key = 'winson' - value = 'SoDiamondEng' + key = "winson" + value = "SoDiamondEng" - self.set_kvp(key, value, 'user') - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': key} + self.set_kvp(key, value, "user") + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) # self.del_kvp(key) def test_st2kv_decrypt(self): - key = 'kami' - value = 'eggplant' + key = "kami" + value = "eggplant" self.set_kvp(key, value, secret=True) - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) diff --git a/contrib/runners/orquesta_runner/tests/unit/base.py b/contrib/runners/orquesta_runner/tests/unit/base.py index dbd2895721..d3e518fab7 100644 --- a/contrib/runners/orquesta_runner/tests/unit/base.py +++ b/contrib/runners/orquesta_runner/tests/unit/base.py @@ -19,13 +19,13 @@ def get_wf_fixture_meta_data(fixture_pack_path, wf_meta_file_name): - wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name + wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name wf_meta_content = loader.load_meta_file(wf_meta_file_path) - wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name'] + wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"] return { - 'file_name': wf_meta_file_name, - 'file_path': wf_meta_file_path, - 'content': wf_meta_content, - 'name': wf_name + "file_name": wf_meta_file_name, + "file_path": wf_meta_file_path, + "content": wf_meta_content, + "name": wf_name, } diff --git a/contrib/runners/orquesta_runner/tests/unit/test_basic.py b/contrib/runners/orquesta_runner/tests/unit/test_basic.py index 7fc2255ed2..5f5c60a012 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_basic.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_basic.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -51,37 +52,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -91,8 +100,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -103,14 +111,15 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ @mock.patch.object( - runners_utils, - 'invoke_post_run', - mock.MagicMock(return_value=None)) + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_run_workflow(self): - username = 'stanley' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + username = "stanley" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # The main action execution for this workflow is not under the context of another workflow. @@ -120,9 +129,13 @@ def test_run_workflow(self): lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertTrue(lv_ac_db.action_is_workflow) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] # Check required attributes. @@ -134,26 +147,24 @@ def test_run_workflow(self): # Check context in the workflow execution. expected_wf_ex_ctx = { - 'st2': { - 'workflow_execution_id': str(wf_ex_db.id), - 'action_execution_id': str(ac_ex_db.id), - 'api_url': 'http://127.0.0.1/v1', - 'user': username, - 'pack': 'orquesta_tests', - 'action': 'orquesta_tests.sequential', - 'runner': 'orquesta' + "st2": { + "workflow_execution_id": str(wf_ex_db.id), + "action_execution_id": str(ac_ex_db.id), + "api_url": "http://127.0.0.1/v1", + "user": username, + "pack": "orquesta_tests", + "action": "orquesta_tests.sequential", + "runner": "orquesta", }, - 'parent': { - 'pack': 'orquesta_tests' - } + "parent": {"pack": "orquesta_tests"}, } self.assertDictEqual(wf_ex_db.context, expected_wf_ex_ctx) # Check context in the liveaction. expected_lv_ac_ctx = { - 'workflow_execution': str(wf_ex_db.id), - 'pack': 'orquesta_tests' + "workflow_execution": str(wf_ex_db.id), + "pack": "orquesta_tests", } self.assertDictEqual(lv_ac_db.context, expected_lv_ac_ctx) @@ -161,24 +172,26 @@ def test_run_workflow(self): # Check graph. self.assertIsNotNone(wf_ex_db.graph) self.assertIsInstance(wf_ex_db.graph, dict) - self.assertIn('nodes', wf_ex_db.graph) - self.assertIn('adjacency', wf_ex_db.graph) + self.assertIn("nodes", wf_ex_db.graph) + self.assertIn("adjacency", wf_ex_db.graph) # Check task states. self.assertIsNotNone(wf_ex_db.state) self.assertIsInstance(wf_ex_db.state, dict) - self.assertIn('tasks', wf_ex_db.state) - self.assertIn('sequence', wf_ex_db.state) + self.assertIn("tasks", wf_ex_db.state) + self.assertIn("sequence", wf_ex_db.state) # Check input. self.assertDictEqual(wf_ex_db.input, wf_input) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.context.get('user'), username) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual(tk1_lv_ac_db.context.get("user"), username) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -192,11 +205,13 @@ def test_run_workflow(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.context.get('user'), username) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual(tk2_lv_ac_db.context.get("user"), username) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk2_ac_ex_db)) @@ -210,11 +225,13 @@ def test_run_workflow(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.context.get('user'), username) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual(tk3_lv_ac_db.context.get("user"), username) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk3_ac_ex_db)) @@ -234,48 +251,60 @@ def test_run_workflow(self): self.assertEqual(runners_utils.invoke_post_run.call_count, 1) # Check workflow output. - expected_output = {'msg': '%s, All your base are belong to us!' % wf_input['who']} + expected_output = { + "msg": "%s, All your base are belong to us!" % wf_input["who"] + } self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_run_workflow_with_unicode_input(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': '薩諾斯'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "薩諾斯"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) @@ -290,33 +319,41 @@ def test_run_workflow_with_unicode_input(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output. - wf_input_val = wf_input['who'].decode('utf-8') if six.PY2 else wf_input['who'] - expected_output = {'msg': '%s, All your base are belong to us!' % wf_input_val} + wf_input_val = wf_input["who"].decode("utf-8") if six.PY2 else wf_input["who"] + expected_output = {"msg": "%s, All your base are belong to us!" % wf_input_val} self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_run_workflow_action_config_context(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'config-context.yaml') + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "config-context.yaml") wf_input = {} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -332,59 +369,77 @@ def test_run_workflow_action_config_context(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Verify config_context works - self.assertEqual(wf_ex_db.output, {'msg': 'value of config key a'}) + self.assertEqual(wf_ex_db.output, {"msg": "value of config key a"}) def test_run_workflow_with_action_less_tasks(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'action-less-tasks.yaml') - wf_input = {'name': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "action-less-tasks.yaml" + ) + wf_input = {"name": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id)) + tk1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + ) self.assertEqual(len(tk1_ac_ex_dbs), 0) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. wf_svc.handle_action_execution_completion(tk2_ac_ex_db) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. wf_svc.handle_action_execution_completion(tk3_ac_ex_db) # Assert task4 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task4'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task4"} tk4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk4_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk4_ex_db.id)) + tk4_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk4_ex_db.id) + ) self.assertEqual(len(tk4_ac_ex_dbs), 0) self.assertEqual(tk4_ex_db.status, wf_statuses.SUCCEEDED) # Assert task5 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task5'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task5"} tk5_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk5_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk5_ex_db.id))[0] - tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction['id']) + tk5_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk5_ex_db.id) + )[0] + tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction["id"]) self.assertEqual(tk5_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -399,65 +454,95 @@ def test_run_workflow_with_action_less_tasks(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output. - expected_output = {'greeting': '%s, All your base are belong to us!' % wf_input['name']} - expected_output['greeting'] = expected_output['greeting'].upper() + expected_output = { + "greeting": "%s, All your base are belong to us!" % wf_input["name"] + } + expected_output["greeting"] = expected_output["greeting"].upper() self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) @mock.patch.object( - pc_svc, 'apply_post_run_policies', - mock.MagicMock(return_value=None)) + pc_svc, "apply_post_run_policies", mock.MagicMock(return_value=None) + ) def test_handle_action_execution_completion(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Manually notify action execution completion for the tasks. # Assert policies are not applied in the notifier. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] notifier.get_notifier().process(t1_t1_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 1) workflows.get_engine().process(t1_t1_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] notifier.get_notifier().process(t1_t2_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 2) workflows.get_engine().process(t1_t2_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] notifier.get_notifier().process(t1_t3_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 3) workflows.get_engine().process(t1_t3_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) @@ -466,19 +551,25 @@ def test_handle_action_execution_completion(self): t1_ac_ex_db = ex_db_access.ActionExecution.get_by_id(t1_ac_ex_db.id) notifier.get_notifier().process(t1_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) workflows.get_engine().process(t1_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t2_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + t2_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} t2_ex_db = wf_db_access.TaskExecution.query(**t2_ex_db_qry)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] self.assertEqual(t2_ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) notifier.get_notifier().process(t2_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) workflows.get_engine().process(t2_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py index 145bd1f3b4..b49fd0f77b 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,37 +46,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerCancelTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerCancelTest, cls).setUpClass() @@ -85,8 +94,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -96,15 +104,15 @@ def setUpClass(cls): def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=True)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True)) def test_cancel(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) requester = cfg.CONF.system_user.user lv_ac_db, ac_ex_db = ac_svc.request_cancellation(lv_ac_db, requester) @@ -112,23 +120,33 @@ def test_cancel(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING) def test_cancel_workflow_cascade_down_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the main workflow. @@ -145,23 +163,33 @@ def test_cancel_workflow_cascade_down_to_subworkflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_subworkflow_cascade_up_to_workflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the subworkflow. @@ -183,34 +211,50 @@ def test_cancel_subworkflow_cascade_up_to_workflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 2) - tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk1_ac_ex_dbs), 1) - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_dbs[0].liveaction['id']) + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk1_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) - tk2_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id)) + tk2_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + ) self.assertEqual(len(tk2_ac_ex_dbs), 1) - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_dbs[0].liveaction['id']) + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk2_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the subworkflow which should cascade up to the root. requester = cfg.CONF.system_user.user - tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation(tk1_lv_ac_db, requester) + tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation( + tk1_lv_ac_db, requester + ) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING) # Assert the main workflow is canceling. @@ -239,15 +283,21 @@ def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_before_wf_ex_db_created(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Delete the workfow execution to mock issue where the record has not been created yet. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False, dispatch_trigger=False) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + wf_db_access.WorkflowExecution.delete( + wf_ex_db, publish=False, dispatch_trigger=False + ) # Cancel the action execution. requester = cfg.CONF.system_user.user @@ -256,15 +306,19 @@ def test_cancel_before_wf_ex_db_created(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_after_wf_ex_db_completed(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Delete the workfow execution to mock issue where the workflow is already completed # but the liveaction and action execution have not had time to be updated. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] wf_ex_db.status = wf_ex_statuses.SUCCEEDED wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) @@ -275,14 +329,16 @@ def test_cancel_after_wf_ex_db_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - wf_svc, 'request_cancellation', - mock.MagicMock(side_effect=Exception('foobar'))) + wf_svc, "request_cancellation", mock.MagicMock(side_effect=Exception("foobar")) + ) def test_cancel_unexpected_exception(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Cancel the action execution. requester = cfg.CONF.system_user.user @@ -297,4 +353,6 @@ def test_cancel_unexpected_exception(self): # to raise an exception and the records will be stuck in a canceling # status and user is unable to easily clean up. self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) - self.assertIn('Error encountered when canceling', lv_ac_db.result.get('error', '')) + self.assertIn( + "Error encountered when canceling", lv_ac_db.result.get("error", "") + ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_context.py b/contrib/runners/orquesta_runner/tests/unit/test_context.py index 373f512e87..bce5a50873 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_context.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_context.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaContextTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaContextTest, cls).setUpClass() @@ -83,24 +92,31 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_runtime_context(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'runtime-context.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "runtime-context.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] # Complete the worklfow. wf_svc.handle_action_execution_completion(t1_ac_ex_db) @@ -113,59 +129,75 @@ def test_runtime_context(self): # Check result. expected_st2_ctx = { - 'action_execution_id': str(ac_ex_db.id), - 'api_url': 'http://127.0.0.1/v1', - 'user': 'stanley', - 'pack': 'orquesta_tests', - 'action': 'orquesta_tests.runtime-context', - 'runner': 'orquesta' + "action_execution_id": str(ac_ex_db.id), + "api_url": "http://127.0.0.1/v1", + "user": "stanley", + "pack": "orquesta_tests", + "action": "orquesta_tests.runtime-context", + "runner": "orquesta", } expected_st2_ctx_with_wf_ex_id = copy.deepcopy(expected_st2_ctx) - expected_st2_ctx_with_wf_ex_id['workflow_execution_id'] = str(wf_ex_db.id) + expected_st2_ctx_with_wf_ex_id["workflow_execution_id"] = str(wf_ex_db.id) expected_output = { - 'st2_ctx_at_input': expected_st2_ctx, - 'st2_ctx_at_vars': expected_st2_ctx, - 'st2_ctx_at_publish': expected_st2_ctx_with_wf_ex_id, - 'st2_ctx_at_output': expected_st2_ctx_with_wf_ex_id + "st2_ctx_at_input": expected_st2_ctx, + "st2_ctx_at_vars": expected_st2_ctx, + "st2_ctx_at_publish": expected_st2_ctx_with_wf_ex_id, + "st2_ctx_at_output": expected_st2_ctx_with_wf_ex_id, } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_sys_user(self): - wf_name = 'subworkflow-default-value-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "subworkflow-default-value-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -184,44 +216,60 @@ def test_action_context_sys_user(self): # Check result. expected_result = { - 'output': { - 'msg': 'stanley, All your base are belong to us!' - } + "output": {"msg": "stanley, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_api_user(self): - wf_name = 'subworkflow-default-value-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], context={'api_user': 'Thanos'}) + wf_name = "subworkflow-default-value-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], context={"api_user": "Thanos"} + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -239,45 +287,57 @@ def test_action_context_api_user(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check result. - expected_result = { - 'output': { - 'msg': 'Thanos, All your base are belong to us!' - } - } + expected_result = {"output": {"msg": "Thanos, All your base are belong to us!"}} self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_no_channel(self): - wf_name = 'subworkflow-source-channel-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "subworkflow-source-channel-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -296,45 +356,60 @@ def test_action_context_no_channel(self): # Check result. expected_result = { - 'output': { - 'msg': 'no_channel, All your base are belong to us!' - } + "output": {"msg": "no_channel, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_source_channel(self): - wf_name = 'subworkflow-source-channel-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], - context={'source_channel': 'general'}) + wf_name = "subworkflow-source-channel-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], context={"source_channel": "general"} + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -353,9 +428,7 @@ def test_action_context_source_channel(self): # Check result. expected_result = { - 'output': { - 'msg': 'general, All your base are belong to us!' - } + "output": {"msg": "general, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py index 00d26f0155..d1c0c249ab 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -47,37 +48,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -87,8 +96,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -99,22 +107,30 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ def assert_data_flow(self, data): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'data-flow.yaml') - wf_input = {'a1': data} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "data-flow.yaml") + wf_input = {"a1": data} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -127,10 +143,12 @@ def assert_data_flow(self, data): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -143,10 +161,12 @@ def assert_data_flow(self, data): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -164,20 +184,20 @@ def assert_data_flow(self, data): # Check workflow output. expected_output = { - 'a5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8'), - 'b5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8') + "a5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"), + "b5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"), } self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_string(self): - self.assert_data_flow('xyz') + self.assert_data_flow("xyz") def test_unicode_string(self): - self.assert_data_flow('床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉') + self.assert_data_flow("床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉") diff --git a/contrib/runners/orquesta_runner/tests/unit/test_delay.py b/contrib/runners/orquesta_runner/tests/unit/test_delay.py index 66834f9952..d2535c8f03 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_delay.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_delay.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerDelayTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerDelayTest, cls).setUpClass() @@ -83,8 +92,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -94,17 +102,25 @@ def test_delay(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'delay.yaml') - wf_input = {'delay': expected_delay_sec} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "delay.yaml") + wf_input = {"delay": expected_delay_sec} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) # Identify records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] # Assert delay value is rendered and assigned. @@ -116,20 +132,28 @@ def test_delay_for_with_items(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-delay.yaml') - wf_input = {'delay': expected_delay_sec} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items-delay.yaml") + wf_input = {"delay": expected_delay_sec} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert delay value is rendered and assigned. @@ -166,20 +190,30 @@ def test_delay_for_with_items_concurrency(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_input = {'concurrency': concurrency, 'delay': expected_delay_sec} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency-delay.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency, "delay": expected_delay_sec} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency-delay.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the first set of action executions from with items concurrency. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert the number of concurrent items is correct. @@ -211,7 +245,9 @@ def test_delay_for_with_items_concurrency(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Process the second set of action executions from with items concurrency. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert delay value is rendered and assigned only to the first set of action executions. diff --git a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py index 6f140040ca..d06d335993 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -48,41 +49,50 @@ from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaErrorHandlingTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ WorkflowExecutionDB, TaskExecutionDB, - ActionExecutionSchedulingQueueItemDB + ActionExecutionSchedulingQueueItemDB, ] @classmethod @@ -94,8 +104,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -104,78 +113,86 @@ def setUpClass(cls): def test_fail_inspection(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-inspection.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertIn('errors', lv_ac_db.result) - self.assertListEqual(lv_ac_db.result['errors'], expected_errors) + self.assertIn("errors", lv_ac_db.result) + self.assertListEqual(lv_ac_db.result["errors"], expected_errors) def test_fail_input_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-input-rendering.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-input-rendering.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -188,28 +205,36 @@ def test_fail_input_rendering(self): def test_fail_vars_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-vars-rendering.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-vars-rendering.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -222,30 +247,38 @@ def test_fail_vars_rendering(self): def test_fail_start_task_action(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().func.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-start-task-action.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-start-task-action.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -258,31 +291,37 @@ def test_fail_start_task_action(self): def test_fail_start_task_input_expr_eval(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().msg1.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().msg1.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-start-task-input-expr-eval.yaml' + wf_file = "fail-start-task-input-expr-eval.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -294,37 +333,40 @@ def test_fail_start_task_input_expr_eval(self): def test_fail_start_task_input_value_type(self): if six.PY3: - msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"." else: - msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"." - msg = 'ValueError: ' + msg + msg = "ValueError: " + msg expected_errors = [ - { - 'type': 'error', - 'message': msg, - 'task_id': 'task1', - 'route': 0 - } + {"type": "error", "message": msg, "task_id": "task1", "route": 0} ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-start-task-input-value-type.yaml' + wf_file = "fail-start-task-input-value-type.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - wf_input = {'var1': {'x': 'foobar'}} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"var1": {"x": "foobar"}} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert workflow and task executions failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] self.assertEqual(tk_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk_ex_db.result, {'errors': expected_errors}) + self.assertDictEqual(tk_ex_db.result, {"errors": expected_errors}) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -337,29 +379,35 @@ def test_fail_start_task_input_value_type(self): def test_fail_next_task_action(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().func.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task2', - 'route': 0 + "task_id": "task2", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-action.yaml') + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-task-action.yaml") - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -370,7 +418,9 @@ def test_fail_next_task_action(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -383,29 +433,37 @@ def test_fail_next_task_action(self): def test_fail_next_task_input_expr_eval(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().msg2.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().msg2.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task2', - 'route': 0 + "task_id": "task2", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-input-expr-eval.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-input-expr-eval.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -416,7 +474,9 @@ def test_fail_next_task_input_expr_eval(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -428,34 +488,37 @@ def test_fail_next_task_input_expr_eval(self): def test_fail_next_task_input_value_type(self): if six.PY3: - msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"." else: - msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"." - msg = 'ValueError: ' + msg + msg = "ValueError: " + msg expected_errors = [ - { - 'type': 'error', - 'message': msg, - 'task_id': 'task2', - 'route': 0 - } + {"type": "error", "message": msg, "task_id": "task2", "route": 0} ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-task-input-value-type.yaml' + wf_file = "fail-task-input-value-type.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - wf_input = {'var1': {'x': 'foobar'}} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"var1": {"x": "foobar"}} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed and workflow execution is still running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) @@ -465,11 +528,13 @@ def test_fail_next_task_input_value_type(self): # Assert workflow execution and task2 execution failed. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) - tk2_ex_db = wf_db_access.TaskExecution.query(task_id='task2')[0] + tk2_ex_db = wf_db_access.TaskExecution.query(task_id="task2")[0] self.assertEqual(tk2_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk2_ex_db.result, {'errors': expected_errors}) + self.assertDictEqual(tk2_ex_db.result, {"errors": expected_errors}) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -482,37 +547,47 @@ def test_fail_next_task_input_value_type(self): def test_fail_task_execution(self): expected_errors = [ { - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'task_id': 'task1', - 'result': { - 'stdout': '', - 'stderr': 'boom!', - 'return_code': 1, - 'failed': True, - 'succeeded': False - } + "type": "error", + "message": "Execution failed. See result for details.", + "task_id": "task1", + "result": { + "stdout": "", + "stderr": "boom!", + "return_code": 1, + "failed": True, + "succeeded": False, + }, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-execution.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-execution.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Process task1. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) # Assert workflow state and result. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -525,28 +600,36 @@ def test_fail_task_execution(self): def test_fail_task_transition(self): expected_errors = [ { - 'type': 'error', - 'message': ( + "type": "error", + "message": ( "YaqlEvaluationException: Unable to resolve key 'foobar' in expression " "'<% succeeded() and result().foobar %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-transition.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-transition.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -557,7 +640,9 @@ def test_fail_task_transition(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -570,29 +655,37 @@ def test_fail_task_transition(self): def test_fail_task_publish(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% foobar() %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% foobar() %>'. NoFunctionRegisteredException: " 'Unknown function "foobar"' ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-publish.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-publish.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -603,7 +696,9 @@ def test_fail_task_publish(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -616,26 +711,34 @@ def test_fail_task_publish(self): def test_fail_output_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-output-rendering.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-output-rendering.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -646,7 +749,9 @@ def test_fail_output_rendering(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -657,50 +762,51 @@ def test_fail_output_rendering(self): self.assertDictEqual(ac_ex_db.result, expected_result) def test_output_on_error(self): - expected_output = { - 'progress': 25 - } + expected_output = {"progress": 25} expected_errors = [ { - 'type': 'error', - 'task_id': 'task2', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "type": "error", + "task_id": "task2", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_result = { - 'errors': expected_errors, - 'output': expected_output - } + expected_result = {"errors": expected_errors, "output": expected_output} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'output-on-error.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "output-on-error.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 is already completed and workflow execution is still running. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed and workflow execution has failed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) @@ -718,26 +824,32 @@ def test_output_on_error(self): self.assertDictEqual(ac_ex_db.result, expected_result) def test_fail_manually(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-manually.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-manually.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 and workflow execution failed due to fail in the task transition. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) # Assert log task is scheduled even though the workflow execution failed manually. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'log'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "log"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -746,38 +858,44 @@ def test_fail_manually(self): # Check errors and output. expected_errors = [ { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", }, { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, + }, ] - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) def test_fail_manually_with_recovery_failure(self): - wf_file = 'fail-manually-with-recovery-failure.yaml' + wf_file = "fail-manually-with-recovery-failure.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 and workflow execution failed due to fail in the task transition. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -785,10 +903,12 @@ def test_fail_manually_with_recovery_failure(self): # Assert recover task is scheduled even though the workflow execution failed manually. # The recover task in the workflow is setup to fail. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'recover'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "recover"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -797,61 +917,70 @@ def test_fail_manually_with_recovery_failure(self): # Check errors and output. expected_errors = [ { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", }, { - 'task_id': 'recover', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "recover", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, + }, ] - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) @mock.patch.object( - runners_utils, - 'invoke_post_run', - mock.MagicMock(return_value=None)) + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_include_result_to_error_log(self): - username = 'stanley' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + username = "stanley" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.context.get('user'), username) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual(tk1_lv_ac_db.context.get("user"), username) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually override and fail the action execution and write some result. @@ -862,11 +991,13 @@ def test_include_result_to_error_log(self): tk1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED, result=result, - publish=False + publish=False, ) - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) self.assertDictEqual(tk1_lv_ac_db.result, result) @@ -882,14 +1013,10 @@ def test_include_result_to_error_log(self): # Assert result is included in the error log. expected_errors = [ { - 'message': 'Execution failed. See result for details.', - 'type': 'error', - 'task_id': 'task1', - 'result': { - '127.0.0.1': { - 'hostname': 'foobar' - } - } + "message": "Execution failed. See result for details.", + "type": "error", + "task_id": "task1", + "result": {"127.0.0.1": {"hostname": "foobar"}}, } ] diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py index d8c416f13a..faa92bd03a 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -44,37 +45,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaFunctionTest, cls).setUpClass() @@ -84,30 +93,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def _execute_workflow(self, wf_name, expected_output): - wf_file = wf_name + '.yaml' + wf_file = wf_name + ".yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -123,149 +137,139 @@ def _execute_workflow(self, wf_name, expected_output): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output, liveaction result, and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(wf_ex_db.output, expected_output) self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_data_functions_in_yaql(self): - wf_name = 'yaql-data-functions' + wf_name = "yaql-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_none_str': data_funcs.NONE_MAGIC_VALUE, - 'data_str': 'foobar' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_none_str": data_funcs.NONE_MAGIC_VALUE, + "data_str": "foobar", } self._execute_workflow(wf_name, expected_output) def test_data_functions_in_jinja(self): - wf_name = 'jinja-data-functions' + wf_name = "jinja-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}', - 'data_none_str': data_funcs.NONE_MAGIC_VALUE, - 'data_str': 'foobar', - 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_pipe_str_1": '{"foo": {"bar": "foobar"}}', + "data_none_str": data_funcs.NONE_MAGIC_VALUE, + "data_str": "foobar", + "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n", } self._execute_workflow(wf_name, expected_output) def test_path_functions_in_yaql(self): - wf_name = 'yaql-path-functions' + wf_name = "yaql-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} self._execute_workflow(wf_name, expected_output) def test_path_functions_in_jinja(self): - wf_name = 'jinja-path-functions' + wf_name = "jinja-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} self._execute_workflow(wf_name, expected_output) def test_regex_functions_in_yaql(self): - wf_name = 'yaql-regex-functions' + wf_name = "yaql-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } self._execute_workflow(wf_name, expected_output) def test_regex_functions_in_jinja(self): - wf_name = 'jinja-regex-functions' + wf_name = "jinja-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } self._execute_workflow(wf_name, expected_output) def test_time_functions_in_yaql(self): - wf_name = 'yaql-time-functions' + wf_name = "yaql-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} self._execute_workflow(wf_name, expected_output) def test_time_functions_in_jinja(self): - wf_name = 'jinja-time-functions' + wf_name = "jinja-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} self._execute_workflow(wf_name, expected_output) def test_version_functions_in_yaql(self): - wf_name = 'yaql-version-functions' + wf_name = "yaql-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } self._execute_workflow(wf_name, expected_output) def test_version_functions_in_jinja(self): - wf_name = 'jinja-version-functions' + wf_name = "jinja-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } self._execute_workflow(wf_name, expected_output) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py index 846afa19f0..3004857bee 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from orquesta_functions import st2kv @@ -37,14 +38,13 @@ from st2common.util import keyvalue as kvp_util -MOCK_CTX = {'__vars': {'st2': {'user': 'stanley'}}} -MOCK_CTX_NO_USER = {'__vars': {'st2': {}}} +MOCK_CTX = {"__vars": {"st2": {"user": "stanley"}}} +MOCK_CTX_NO_USER = {"__vars": {"st2": {}}} class DatastoreFunctionTest(unittest2.TestCase): - def test_missing_user_context(self): - self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, 'foo') + self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, "foo") def test_invalid_input(self): self.assertRaises(TypeError, st2kv.st2kv_, None, 123) @@ -55,35 +55,29 @@ def test_invalid_input(self): class UserScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(UserScopeDatastoreFunctionTest, cls).setUpClass() - user = auth_db.UserDB(name='stanley') + user = auth_db.UserDB(name="stanley") user.save() scope = kvp_const.FULL_USER_SCOPE cls.kvps = {} # Plain keys - keys = { - 'stanley:foo': 'bar', - 'stanley:foo_empty': '', - 'stanley:foo_null': None - } + keys = {"stanley:foo": "bar", "stanley:foo_empty": "", "stanley:foo_null": None} for k, v in six.iteritems(keys): instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) # Secret key - keys = { - 'stanley:fu': 'bar', - 'stanley:fu_empty': '' - } + keys = {"stanley:fu": "bar", "stanley:fu_empty": ""} for k, v in six.iteritems(keys): value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v) - instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True) + instance = kvp_db.KeyValuePairDB( + name=k, value=value, scope=scope, secret=True + ) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) @classmethod @@ -94,9 +88,9 @@ def tearDownClass(cls): super(UserScopeDatastoreFunctionTest, cls).tearDownClass() def test_key_exists(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo'), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo_empty'), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foo_null')) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo"), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo_empty"), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foo_null")) def test_key_does_not_exist(self): self.assertRaisesRegexp( @@ -104,65 +98,61 @@ def test_key_does_not_exist(self): 'The key ".*" does not exist in the StackStorm datastore.', st2kv.st2kv_, MOCK_CTX, - 'foobar' + "foobar", ) def test_key_does_not_exist_but_return_default(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default='foosball'), 'foosball') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default=''), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foobar', default=None)) + self.assertEqual( + st2kv.st2kv_(MOCK_CTX, "foobar", default="foosball"), "foosball" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foobar", default=""), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foobar", default=None)) def test_key_decrypt(self): - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu'), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=False), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=True), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty'), '') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=False), '') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=True), '') + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu"), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=False), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=True), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty"), "") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=False), "") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=True), "") @mock.patch.object( - kvp_util, 'get_key', - mock.MagicMock(side_effect=Exception('Mock failure.'))) + kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure.")) + ) def test_get_key_exception(self): self.assertRaisesRegexp( exc.ExpressionEvaluationException, - 'Mock failure.', + "Mock failure.", st2kv.st2kv_, MOCK_CTX, - 'foo' + "foo", ) class SystemScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(SystemScopeDatastoreFunctionTest, cls).setUpClass() - user = auth_db.UserDB(name='stanley') + user = auth_db.UserDB(name="stanley") user.save() scope = kvp_const.FULL_SYSTEM_SCOPE cls.kvps = {} # Plain key - keys = { - 'foo': 'bar', - 'foo_empty': '', - 'foo_null': None - } + keys = {"foo": "bar", "foo_empty": "", "foo_null": None} for k, v in six.iteritems(keys): instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) # Secret key - keys = { - 'fu': 'bar', - 'fu_empty': '' - } + keys = {"fu": "bar", "fu_empty": ""} for k, v in six.iteritems(keys): value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v) - instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True) + instance = kvp_db.KeyValuePairDB( + name=k, value=value, scope=scope, secret=True + ) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) @classmethod @@ -173,9 +163,9 @@ def tearDownClass(cls): super(SystemScopeDatastoreFunctionTest, cls).tearDownClass() def test_key_exists(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo'), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo_empty'), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foo_null')) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo"), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo_empty"), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foo_null")) def test_key_does_not_exist(self): self.assertRaisesRegexp( @@ -183,30 +173,34 @@ def test_key_does_not_exist(self): 'The key ".*" does not exist in the StackStorm datastore.', st2kv.st2kv_, MOCK_CTX, - 'foo' + "foo", ) def test_key_does_not_exist_but_return_default(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default='foosball'), 'foosball') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=''), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=None)) + self.assertEqual( + st2kv.st2kv_(MOCK_CTX, "system.foobar", default="foosball"), "foosball" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=""), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=None)) def test_key_decrypt(self): - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu'), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=False), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=True), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty'), '') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=False), '') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=True), '') + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu"), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=False), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=True), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty"), "") + self.assertNotEqual( + st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=False), "" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=True), "") @mock.patch.object( - kvp_util, 'get_key', - mock.MagicMock(side_effect=Exception('Mock failure.'))) + kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure.")) + ) def test_get_key_exception(self): self.assertRaisesRegexp( exc.ExpressionEvaluationException, - 'Mock failure.', + "Mock failure.", st2kv.st2kv_, MOCK_CTX, - 'system.foo' + "system.foo", ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py index 146e7ee39e..46ffb861e3 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaFunctionTest, cls).setUpClass() @@ -83,42 +92,57 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) - def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, - expected_status=wf_statuses.SUCCEEDED, expected_errors=None): - wf_file = wf_name + '.yaml' + def _execute_workflow( + self, + wf_name, + expected_task_sequence, + expected_output, + expected_status=wf_statuses.SUCCEEDED, + expected_errors=None, + ): + wf_file = wf_name + ".yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) for task_id, route in expected_task_sequence: tk_ex_dbs = wf_db_access.TaskExecution.query( - workflow_execution=str(wf_ex_db.id), - task_id=task_id, - task_route=route + workflow_execution=str(wf_ex_db.id), task_id=task_id, task_route=route ) if len(tk_ex_dbs) <= 0: break - tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[len(tk_ex_dbs) - 1] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[ + len(tk_ex_dbs) - 1 + ] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_db.liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db)) + self.assertTrue( + wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db) + ) wf_svc.handle_action_execution_completion(tk_ac_ex_db) @@ -131,10 +155,10 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, self.assertEqual(ac_ex_db.status, expected_status) # Check workflow output, liveaction result, and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} if expected_errors is not None: - expected_result['errors'] = expected_errors + expected_result["errors"] = expected_errors if expected_output is not None: self.assertDictEqual(wf_ex_db.output, expected_output) @@ -143,83 +167,81 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, self.assertDictEqual(ac_ex_db.result, expected_result) def test_task_functions_in_yaql(self): - wf_name = 'yaql-task-functions' + wf_name = "yaql-task-functions" expected_task_sequence = [ - ('task1', 0), - ('task3', 0), - ('task6', 0), - ('task7', 0), - ('task2', 0), - ('task4', 0), - ('task8', 1), - ('task8', 2), - ('task4', 0), - ('task9', 1), - ('task9', 2), - ('task5', 0) + ("task1", 0), + ("task3", 0), + ("task6", 0), + ("task7", 0), + ("task2", 0), + ("task4", 0), + ("task8", 1), + ("task8", 2), + ("task4", 0), + ("task9", 1), + ("task9", 2), + ("task5", 0), ] expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } self._execute_workflow(wf_name, expected_task_sequence, expected_output) def test_task_functions_in_jinja(self): - wf_name = 'jinja-task-functions' + wf_name = "jinja-task-functions" expected_task_sequence = [ - ('task1', 0), - ('task3', 0), - ('task6', 0), - ('task7', 0), - ('task2', 0), - ('task4', 0), - ('task8', 1), - ('task8', 2), - ('task4', 0), - ('task9', 1), - ('task9', 2), - ('task5', 0) + ("task1", 0), + ("task3", 0), + ("task6", 0), + ("task7", 0), + ("task2", 0), + ("task4", 0), + ("task8", 1), + ("task8", 2), + ("task4", 0), + ("task9", 1), + ("task9", 2), + ("task5", 0), ] expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } self._execute_workflow(wf_name, expected_task_sequence, expected_output) def test_task_nonexistent_in_yaql(self): - wf_name = 'yaql-task-nonexistent' + wf_name = "yaql-task-nonexistent" - expected_task_sequence = [ - ('task1', 0) - ] + expected_task_sequence = [("task1", 0)] expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% task("task0") %>\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% task(\"task0\") %>'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] @@ -228,29 +250,27 @@ def test_task_nonexistent_in_yaql(self): expected_task_sequence, expected_output, expected_status=ac_const.LIVEACTION_STATUS_FAILED, - expected_errors=expected_errors + expected_errors=expected_errors, ) def test_task_nonexistent_in_jinja(self): - wf_name = 'jinja-task-nonexistent' + wf_name = "jinja-task-nonexistent" - expected_task_sequence = [ - ('task1', 0) - ] + expected_task_sequence = [("task1", 0)] expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'JinjaEvaluationException: Unable to evaluate expression ' - '\'{{ task("task0") }}\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "JinjaEvaluationException: Unable to evaluate expression " + "'{{ task(\"task0\") }}'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] @@ -259,5 +279,5 @@ def test_task_nonexistent_in_jinja(self): expected_task_sequence, expected_output, expected_status=ac_const.LIVEACTION_STATUS_FAILED, - expected_errors=expected_errors + expected_errors=expected_errors, ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py index 3e84d7bce8..8dfdf24a84 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,37 +46,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -85,30 +94,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_inquiry(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-approval.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "ask-approval.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -118,10 +132,15 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -133,12 +152,16 @@ def test_inquiry(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -148,11 +171,15 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) @@ -162,22 +189,30 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_consecutive_inquiries(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-consecutive-approvals.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-consecutive-approvals.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -187,10 +222,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -202,12 +242,16 @@ def test_consecutive_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -217,10 +261,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_confirmation'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_confirmation", + } t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) @@ -232,12 +281,16 @@ def test_consecutive_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id)) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id)) - self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id)) self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -247,11 +300,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0] - t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id']) - self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t4_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t4_ex_db.id) + )[0] + t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"]) + self.assertEqual( + t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t4_ac_ex_db) t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id) self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED) @@ -261,22 +318,30 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_parallel_inquiries(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-parallel-approvals.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-parallel-approvals.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -286,10 +351,12 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jack'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jack"} t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -300,10 +367,12 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.PAUSING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jill'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jill"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) @@ -315,12 +384,16 @@ def test_parallel_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -332,12 +405,16 @@ def test_parallel_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id)) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id)) - self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id)) self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -347,11 +424,15 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0] - t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id']) - self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t4_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t4_ex_db.id) + )[0] + t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"]) + self.assertEqual( + t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t4_ac_ex_db) t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id) self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED) @@ -361,22 +442,30 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_nested_inquiry(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-nested-approval.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-nested-approval.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -386,23 +475,36 @@ def test_nested_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the subworkflow is already started. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.RUNNING) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Process task1 of subworkflow. - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "start"} t2_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] - t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t1_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] + t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t1_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t1_ac_ex_db) t2_t1_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t1_ex_db.id) self.assertEqual(t2_t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -410,11 +512,20 @@ def test_nested_inquiry(self): self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Process inquiry task of subworkflow and assert the subworkflow is paused. - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(t2_wf_ex_db.id), + "task_id": "get_approval", + } t2_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] - t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t2_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] + t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t2_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING + ) workflows.get_engine().process(t2_t2_ac_ex_db) t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t2_ex_db.id) self.assertEqual(t2_t2_ex_db.status, wf_statuses.PENDING) @@ -422,8 +533,10 @@ def test_nested_inquiry(self): self.assertEqual(t2_wf_ex_db.status, wf_statuses.PAUSED) # Process the corresponding task in parent workflow and assert the task is paused. - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSED) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -435,34 +548,50 @@ def test_nested_inquiry(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_t2_lv_ac_db.id)) - self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_t2_ac_ex_db.id)) - self.assertEqual(t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t2_ac_ex_db) t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_t2_ex_db.id)) - self.assertEqual(t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Assert the main workflow is running again. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete the rest of the subworkflow - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "finish"} t2_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] - t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t3_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] + t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t3_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t3_ac_ex_db) t2_t3_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t3_ex_db.id) self.assertEqual(t2_t3_ex_db.status, wf_statuses.SUCCEEDED) t2_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t2_wf_ex_db.id)) self.assertEqual(t2_wf_ex_db.status, wf_statuses.SUCCEEDED) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED) @@ -470,11 +599,15 @@ def test_nested_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete the rest of the main workflow - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_notify.py b/contrib/runners/orquesta_runner/tests/unit/test_notify.py index dc8131f100..6ca125d855 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_notify.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_notify.py @@ -25,6 +25,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -47,57 +48,60 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] MOCK_NOTIFY = { - 'on-complete': { - 'data': { - 'source_channel': 'baloney', - 'user': 'lakstorm' - }, - 'routes': [ - 'hubot' - ] + "on-complete": { + "data": {"source_channel": "baloney", "user": "lakstorm"}, + "routes": ["hubot"], } } @mock.patch.object( - notifier.Notifier, - '_post_notify_triggers', - mock.MagicMock(return_value=None)) + notifier.Notifier, "_post_notify_triggers", mock.MagicMock(return_value=None) +) @mock.patch.object( - notifier.Notifier, - '_post_generic_trigger', - mock.MagicMock(return_value=None)) + notifier.Notifier, "_post_generic_trigger", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update)) + "publish_update", + mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaNotifyTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaNotifyTest, cls).setUpClass() @@ -107,177 +111,181 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_no_notify(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. self.assertDictEqual(wf_ex_db.notify, {}) def test_no_notify_task_list(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': [] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": []} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_custom_notify_task_list(self): - wf_input = {'notify': ['task1']} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"notify": ["task1"]} + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': wf_input['notify'] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": wf_input["notify"]} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_default_notify_task_list(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'notify.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "notify.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': ['task1', 'task2', 'task3'] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": ["task1", "task2", "task3"]} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_notify_task_list_bad_item_value(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) expected_schema_failure_test_cases = [ - 'task1', # Notify must be type of list. - [123], # Item has to be type of string. - [''], # String value cannot be empty. - [' '], # String value cannot be just spaces. - [' '], # String value cannot be just tabs. - ['init task'], # String value cannot have space. - ['init-task'], # String value cannot have dash. - ['task1', 'task1'] # String values have to be unique. + "task1", # Notify must be type of list. + [123], # Item has to be type of string. + [""], # String value cannot be empty. + [" "], # String value cannot be just spaces. + [" "], # String value cannot be just tabs. + ["init task"], # String value cannot have space. + ["init-task"], # String value cannot have dash. + ["task1", "task1"], # String values have to be unique. ] for notify_tasks in expected_schema_failure_test_cases: - lv_ac_db.parameters = {'notify': notify_tasks} + lv_ac_db.parameters = {"notify": notify_tasks} try: self.assertRaises( - jsonschema.ValidationError, - action_service.request, - lv_ac_db + jsonschema.ValidationError, action_service.request, lv_ac_db ) except Exception as e: - raise AssertionError('%s: %s' % (six.text_type(e), notify_tasks)) + raise AssertionError("%s: %s" % (six.text_type(e), notify_tasks)) def test_notify_task_list_nonexistent_task(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) - lv_ac_db.parameters = {'notify': ['init_task']} + lv_ac_db.parameters = {"notify": ["init_task"]} lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) expected_result = { - 'output': None, - 'errors': [ + "output": None, + "errors": [ { - 'message': ( - 'The following tasks in the notify parameter do not ' - 'exist in the workflow definition: init_task.' + "message": ( + "The following tasks in the notify parameter do not " + "exist in the workflow definition: init_task." ) } - ] + ], } self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) self.assertDictEqual(lv_ac_db.result, expected_result) def test_notify_task_list_item_value(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) - expected_schema_success_test_cases = [ - [], - ['task1'], - ['task1', 'task2'] - ] + expected_schema_success_test_cases = [[], ["task1"], ["task1", "task2"]] for notify_tasks in expected_schema_success_test_cases: - lv_ac_db.parameters = {'notify': notify_tasks} + lv_ac_db.parameters = {"notify": notify_tasks} lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) + self.assertEqual( + lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING + ) def test_cascade_notify_to_tasks(self): - wf_input = {'notify': ['task2']} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"notify": ["task2"]} + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert task1 notify is not set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertIsNone(tk1_lv_ac_db.notify) - self.assertEqual(tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertFalse(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() @@ -289,13 +297,19 @@ def test_cascade_notify_to_tasks(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 notify is set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - notify = notify_api_models.NotificationsHelper.from_model(notify_model=tk2_lv_ac_db.notify) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + notify = notify_api_models.NotificationsHelper.from_model( + notify_model=tk2_lv_ac_db.notify + ) self.assertEqual(notify, MOCK_NOTIFY) - self.assertEqual(tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertTrue(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() @@ -307,12 +321,16 @@ def test_cascade_notify_to_tasks(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 notify is not set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertIsNone(tk3_lv_ac_db.notify) - self.assertEqual(tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertFalse(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() diff --git a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py index 5bae5bab27..f23084b527 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py @@ -22,6 +22,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,12 +46,14 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] FAIL_SCHEMA = { @@ -61,25 +64,32 @@ @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(RunnerTestCase, st2tests.ExecutionDbTestCase): @classmethod def setUpClass(cls): @@ -90,8 +100,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -102,28 +111,40 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ def test_adherence_to_output_schema(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential_with_schema.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "sequential_with_schema.yaml" + ) + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk3_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) @@ -134,30 +155,39 @@ def test_adherence_to_output_schema(self): def test_fail_incorrect_output_schema(self): wf_meta = base.get_wf_fixture_meta_data( - TEST_PACK_PATH, - 'sequential_with_broken_schema.yaml' + TEST_PACK_PATH, "sequential_with_broken_schema.yaml" + ) + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input ) - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk3_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) @@ -167,9 +197,9 @@ def test_fail_incorrect_output_schema(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) expected_result = { - 'error': "Additional properties are not allowed", - 'message': 'Error validating output. See error output for more details.' + "error": "Additional properties are not allowed", + "message": "Error validating output. See error output for more details.", } - self.assertIn(expected_result['error'], ac_ex_db.result['error']) - self.assertEqual(expected_result['message'], ac_ex_db.result['message']) + self.assertIn(expected_result["error"], ac_ex_db.result["error"]) + self.assertEqual(expected_result["message"], ac_ex_db.result["message"]) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py index c2021b379e..6ade390029 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerPauseResumeTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerPauseResumeTest, cls).setUpClass() @@ -86,8 +95,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -97,56 +105,68 @@ def setUpClass(cls): def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=False)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False)) def test_pause(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=True)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True)) def test_pause_with_active_children(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) def test_pause_subworkflow_not_cascade_up_to_workflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow. - tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause(tk_lv_ac_db, cfg.CONF.system_user.user) + tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause( + tk_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -154,38 +174,52 @@ def test_pause_subworkflow_not_cascade_up_to_workflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) def test_pause_workflow_cascade_down_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) wf_ex_db = wf_ex_dbs[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) tk_ex_db = tk_ex_dbs[0] - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) tk_ac_ex_db = tk_ac_ex_dbs[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Identify the records for the subworkflow. - sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id)) + sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(tk_ac_ex_db.id) + ) self.assertEqual(len(sub_wf_ex_dbs), 1) sub_wf_ex_db = sub_wf_ex_dbs[0] - sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id)) + sub_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(sub_wf_ex_db.id) + ) self.assertEqual(len(sub_tk_ex_dbs), 1) sub_tk_ex_db = sub_tk_ex_dbs[0] - sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id)) + sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(sub_tk_ex_db.id) + ) self.assertEqual(len(sub_tk_ac_ex_dbs), 1) # Pause the main workflow and assert it is pausing because subworkflow is still running. @@ -213,32 +247,48 @@ def test_pause_workflow_cascade_down_to_subworkflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_pause_subworkflow_while_another_subworkflow_running(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -246,12 +296,16 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -267,18 +321,30 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -293,32 +359,48 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_pause_subworkflow_while_another_subworkflow_completed(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -326,18 +408,30 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -352,12 +446,16 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the target subworkflow is still pausing. - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -372,15 +470,15 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=False)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False)) def test_resume(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Pause the workflow. lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) @@ -388,63 +486,93 @@ def test_resume(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Identify the records for the running task(s) and manually complete it. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_ac_ex_dbs[0].status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk_ac_ex_dbs[0]) # Ensure the workflow is paused. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result + ) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(wf_ex_dbs[0].status, wf_statuses.PAUSED) # Resume the workflow. lv_ac_db, ac_ex_db = ac_svc.request_resume(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(wf_ex_dbs[0].status, wf_statuses.RUNNING) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 2) def test_resume_cascade_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) wf_ex_db = wf_ex_dbs[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) tk_ex_db = tk_ex_dbs[0] - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) tk_ac_ex_db = tk_ac_ex_dbs[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Identify the records for the subworkflow. - sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id)) + sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(tk_ac_ex_db.id) + ) self.assertEqual(len(sub_wf_ex_dbs), 1) sub_wf_ex_db = sub_wf_ex_dbs[0] - sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id)) + sub_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(sub_wf_ex_db.id) + ) self.assertEqual(len(sub_tk_ex_dbs), 1) sub_tk_ex_db = sub_tk_ex_dbs[0] - sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id)) + sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(sub_tk_ex_db.id) + ) self.assertEqual(len(sub_tk_ac_ex_dbs), 1) # Pause the main workflow and assert it is pausing because subworkflow is still running. @@ -481,32 +609,48 @@ def test_resume_cascade_to_subworkflow(self): self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) def test_resume_from_each_subworkflow_when_parent_is_paused(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause one of the subworkflows. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -514,12 +658,16 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -535,11 +683,13 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Pause the other subworkflow. - t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause(t2_lv_ac_db, cfg.CONF.system_user.user) + t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause( + t2_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -547,8 +697,12 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -564,7 +718,9 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -573,11 +729,19 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -592,32 +756,48 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_resume_from_subworkflow_when_parent_is_paused(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -625,12 +805,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -646,18 +830,30 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -672,7 +868,9 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -681,11 +879,19 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -696,12 +902,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): workflows.get_engine().process(t1_ac_ex_db) # Assert task3 has started and completed. - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 3) - t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(t3_ac_ex_db) @@ -710,32 +920,48 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_resume_from_subworkflow_when_parent_is_running(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -743,12 +969,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -764,11 +994,13 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -777,15 +1009,23 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -796,14 +1036,26 @@ def test_resume_from_subworkflow_when_parent_is_running(self): workflows.get_engine().process(t1_ac_ex_db) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -814,12 +1066,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self): workflows.get_engine().process(t2_ac_ex_db) # Assert task3 has started and completed. - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 3) - t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(t3_ac_ex_db) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_policies.py b/contrib/runners/orquesta_runner/tests/unit/test_policies.py index 2595609f63..81ab639262 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_policies.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_policies.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -86,7 +95,7 @@ def setUpClass(cls): policiesregistrar.register_policy_types(st2common) # Register test pack(s). - registrar_options = {'use_pack_cache': False, 'fail_on_failure': True} + registrar_options = {"use_pack_cache": False, "fail_on_failure": True} actions_registrar = actionsregistrar.ActionsRegistrar(**registrar_options) policies_registrar = policiesregistrar.PolicyRegistrar(**registrar_options) @@ -106,27 +115,37 @@ def tearDown(self): ac_ex_db.delete() def test_retry_policy_applied_on_workflow_failure(self): - wf_name = 'sequential' - wf_ac_ref = TEST_PACK + '.' + wf_name - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "sequential" + wf_ac_ref = TEST_PACK + "." + wf_name + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Ensure there is only one execution recorded. self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 1) # Identify the records for the workflow and task. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] # Manually set the status to fail. ac_svc.update_status(t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED) t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) notifier.get_notifier().process(t1_ac_ex_db) workflows.get_engine().process(t1_ac_ex_db) @@ -140,32 +159,48 @@ def test_retry_policy_applied_on_workflow_failure(self): self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 2) def test_no_retry_policy_applied_on_task_failure(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Ensure there is only one execution for the task. - tk_ac_ref = TEST_PACK + '.' + 'sequential' + tk_ac_ref = TEST_PACK + "." + "sequential" self.assertEqual(len(lv_db_access.LiveAction.query(action=tk_ac_ref)), 1) # Fail the subtask of the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_lv_ac_db = lv_db_access.LiveAction.query( + task_execution=str(t1_t1_ex_db.id) + )[0] ac_svc.update_status(t1_t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED) - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] self.assertEqual(t1_t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) notifier.get_notifier().process(t1_t1_ac_ex_db) workflows.get_engine().process(t1_t1_ac_ex_db) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py index 59f2f94d08..191f3a0681 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py @@ -20,6 +20,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from local_runner import local_shell_command_runner @@ -41,41 +42,57 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {}) -RUNNER_RESULT_RUNNING = (action_constants.LIVEACTION_STATUS_RUNNING, {'stdout': '...'}, {}) -RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {}) +RUNNER_RESULT_RUNNING = ( + action_constants.LIVEACTION_STATUS_RUNNING, + {"stdout": "..."}, + {}, +) +RUNNER_RESULT_SUCCEEDED = ( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + {"stdout": "foobar"}, + {}, +) @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestRunnerTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(OrquestRunnerTest, cls).setUpClass() @@ -85,28 +102,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_rerun_workflow(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -121,18 +145,15 @@ def test_rerun_workflow(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_FAILED) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) # Assert the workflow reran ok and is running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db2.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -140,33 +161,45 @@ def test_rerun_workflow(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1 and make sure it succeeds. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk1_ex_dbs), 2) tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp) tk1_ex_db = tk1_ex_dbs[-1] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_missing_workflow_execution_id(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -184,49 +217,52 @@ def test_rerun_with_missing_workflow_execution_id(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Manually delete the workflow_execution_id from context of the action execution. - lv_ac_db1.context.pop('workflow_execution') + lv_ac_db1.context.pop("workflow_execution") lv_ac_db1 = lv_db_access.LiveAction.add_or_update(lv_ac_db1, publish=False) ac_ex_db1 = execution_service.update_execution(lv_ac_db1, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( - 'Unable to rerun workflow execution because ' - 'workflow_execution_id is not provided.' + "Unable to rerun workflow execution because " + "workflow_execution_id is not provided." ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_invalid_workflow_execution(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -244,45 +280,50 @@ def test_rerun_with_invalid_workflow_execution(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( 'Unable to rerun workflow execution "%s" because ' - 'it does not exist.' % str(wf_ex_db.id) + "it does not exist." % str(wf_ex_db.id) ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING]), + ) def test_rerun_workflow_still_running(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING + ) # Assert workflow is still running. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -293,47 +334,52 @@ def test_rerun_workflow_still_running(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_RUNNING) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( 'Unable to rerun workflow execution "%s" because ' - 'it is not in a completed state.' % str(wf_ex_db.id) + "it is not in a completed state." % str(wf_ex_db.id) ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - workflow_service, 'request_rerun', - mock.MagicMock(side_effect=Exception('Unexpected.'))) + workflow_service, + "request_rerun", + mock.MagicMock(side_effect=Exception("Unexpected.")), + ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_unexpected_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -351,62 +397,75 @@ def test_rerun_with_unexpected_error(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) - expected_error = 'Unexpected.' + expected_error = "Unexpected." # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED), + ) def test_rerun_workflow_already_succeeded(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED) @@ -420,18 +479,15 @@ def test_rerun_workflow_already_succeeded(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) # Assert the workflow reran ok and is running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db2.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -439,40 +495,52 @@ def test_rerun_workflow_already_succeeded(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert there are two task1 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk1_ex_dbs), 2) tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp) tk1_ex_db = tk1_ex_dbs[-1] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Assert there are two task2 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk2_ex_dbs), 2) tk2_ex_dbs = sorted(tk2_ex_dbs, key=lambda x: x.start_timestamp) tk2_ex_db = tk2_ex_dbs[-1] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Assert there are two task3 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk3_ex_dbs), 2) tk3_ex_dbs = sorted(tk3_ex_dbs, key=lambda x: x.start_timestamp) tk3_ex_db = tk3_ex_dbs[-1] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py index cc0846d733..6676874586 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py @@ -25,6 +25,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -48,37 +49,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaWithItemsTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaWithItemsTest, cls).setUpClass() @@ -88,8 +97,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -101,35 +109,34 @@ def get_runner_class(cls, runner_name): def set_execution_status(self, lv_ac_db_id, status): lv_ac_db = action_utils.update_liveaction_status( - status=status, - liveaction_id=lv_ac_db_id, - publish=False + status=status, liveaction_id=lv_ac_db_id, publish=False ) - ac_ex_db = execution_service.update_execution( - lv_ac_db, - publish=False - ) + ac_ex_db = execution_service.update_execution(lv_ac_db, publish=False) return lv_ac_db, ac_ex_db def test_with_items(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -155,20 +162,26 @@ def test_with_items(self): def test_with_items_failure(self): num_items = 10 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-failure.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-failure.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -195,52 +208,68 @@ def test_with_items_failure(self): def test_with_items_empty_list(self): items = [] num_items = len(items) - wf_input = {'members': items} + wf_input = {"members": items} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Wait for the liveaction to complete. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED) + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Retrieve records from database. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) # Ensure there is no action executions for the task and the task is already completed. self.assertEqual(len(t1_ac_ex_dbs), num_items) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) - self.assertDictEqual(t1_ex_db.result, {'items': []}) + self.assertDictEqual(t1_ex_db.result, {"items": []}) # Assert the main workflow is completed. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertDictEqual(lv_ac_db.result, {'output': {'items': []}}) + self.assertDictEqual(lv_ac_db.result, {"output": {"items": []}}) def test_with_items_concurrency(self): num_items = 3 concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the first set of action executions from with items concurrency. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), concurrency) @@ -261,7 +290,9 @@ def test_with_items_concurrency(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Process the second set of action executions from with items concurrency. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -287,30 +318,37 @@ def test_with_items_concurrency(self): def test_with_items_cancellation(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -328,11 +366,12 @@ def test_with_items_cancellation(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -353,31 +392,40 @@ def test_with_items_cancellation(self): def test_with_items_concurrency_cancellation(self): concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), concurrency) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -395,11 +443,12 @@ def test_with_items_concurrency_cancellation(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -420,30 +469,37 @@ def test_with_items_concurrency_cancellation(self): def test_with_items_pause_and_resume(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -461,11 +517,12 @@ def test_with_items_pause_and_resume(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -498,31 +555,40 @@ def test_with_items_concurrency_pause_and_resume(self): num_items = 3 concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), concurrency) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -540,11 +606,12 @@ def test_with_items_concurrency_pause_and_resume(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -572,7 +639,9 @@ def test_with_items_concurrency_pause_and_resume(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check new set of action execution is scheduled. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Manually process the last action execution. @@ -585,20 +654,34 @@ def test_with_items_concurrency_pause_and_resume(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) def test_subworkflow_with_items_empty_list(self): - wf_input = {'members': []} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-empty-parent.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"members": []} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-empty-parent.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] - self.assertEqual(t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] + self.assertEqual( + t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertEqual(t1_wf_ex_db.status, wf_statuses.SUCCEEDED) # Manually processing completion of the subworkflow in task1. diff --git a/contrib/runners/python_runner/dist_utils.py b/contrib/runners/python_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/python_runner/dist_utils.py +++ b/contrib/runners/python_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/python_runner/python_runner/__init__.py b/contrib/runners/python_runner/python_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/python_runner/python_runner/__init__.py +++ b/contrib/runners/python_runner/python_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/python_runner/python_runner/python_action_wrapper.py b/contrib/runners/python_runner/python_runner/python_action_wrapper.py index 119f6bdf84..b9ae0757b3 100644 --- a/contrib/runners/python_runner/python_runner/python_action_wrapper.py +++ b/contrib/runners/python_runner/python_runner/python_action_wrapper.py @@ -18,7 +18,8 @@ # Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7 import warnings from cryptography.utils import CryptographyDeprecationWarning -warnings.filterwarnings('ignore', category=CryptographyDeprecationWarning) + +warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import os import sys @@ -33,8 +34,8 @@ # lives gets added to sys.path and we don't want that. # Note: We need to use just the suffix, because full path is different depending if the process # is ran in virtualenv or not -RUNNERS_PATH_SUFFIX = 'st2common/runners' -if __name__ == '__main__': +RUNNERS_PATH_SUFFIX = "st2common/runners" +if __name__ == "__main__": script_path = sys.path[0] if RUNNERS_PATH_SUFFIX in script_path: sys.path.pop(0) @@ -61,10 +62,7 @@ from st2common.constants.runners import PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL -__all__ = [ - 'PythonActionWrapper', - 'ActionService' -] +__all__ = ["PythonActionWrapper", "ActionService"] LOG = logging.getLogger(__name__) @@ -104,15 +102,18 @@ def datastore_service(self): # duration of the action lifetime action_name = self._action_wrapper._class_name log_level = self._action_wrapper._log_level - logger = get_logger_for_python_runner_action(action_name=action_name, - log_level=log_level) + logger = get_logger_for_python_runner_action( + action_name=action_name, log_level=log_level + ) pack_name = self._action_wrapper._pack class_name = self._action_wrapper._class_name - auth_token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) - self._datastore_service = ActionDatastoreService(logger=logger, - pack_name=pack_name, - class_name=class_name, - auth_token=auth_token) + auth_token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) + self._datastore_service = ActionDatastoreService( + logger=logger, + pack_name=pack_name, + class_name=class_name, + auth_token=auth_token, + ) return self._datastore_service ################################## @@ -130,20 +131,32 @@ def list_values(self, local=True, prefix=None): return self.datastore_service.list_values(local=local, prefix=prefix) def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): - return self.datastore_service.get_value(name=name, local=local, scope=scope, - decrypt=decrypt) + return self.datastore_service.get_value( + name=name, local=local, scope=scope, decrypt=decrypt + ) - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): - return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local, - scope=scope, encrypt=encrypt) + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): + return self.datastore_service.set_value( + name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt + ) def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): return self.datastore_service.delete_value(name=name, local=local, scope=scope) class PythonActionWrapper(object): - def __init__(self, pack, file_path, config=None, parameters=None, user=None, parent_args=None, - log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL): + def __init__( + self, + pack, + file_path, + config=None, + parameters=None, + user=None, + parent_args=None, + log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + ): """ :param pack: Name of the pack this action belongs to. :type pack: ``str`` @@ -173,19 +186,22 @@ def __init__(self, pack, file_path, config=None, parameters=None, user=None, par self._log_level = log_level self._class_name = None - self._logger = logging.getLogger('PythonActionWrapper') + self._logger = logging.getLogger("PythonActionWrapper") try: st2common_config.parse_args(args=self._parent_args) except Exception as e: - LOG.debug('Failed to parse config using parent args (parent_args=%s): %s' % - (str(self._parent_args), six.text_type(e))) + LOG.debug( + "Failed to parse config using parent args (parent_args=%s): %s" + % (str(self._parent_args), six.text_type(e)) + ) # Note: We can only set a default user value if one is not provided after parsing the # config if not self._user: # Note: We use late import to avoid performance overhead from oslo_config import cfg + self._user = cfg.CONF.system_user.user def run(self): @@ -201,26 +217,25 @@ def run(self): action_status = None action_result = output - action_output = { - 'result': action_result, - 'status': None - } + action_output = {"result": action_result, "status": None} if action_status is not None and not isinstance(action_status, bool): - sys.stderr.write('Status returned from the action run() method must either be ' - 'True or False, got: %s\n' % (action_status)) + sys.stderr.write( + "Status returned from the action run() method must either be " + "True or False, got: %s\n" % (action_status) + ) sys.stderr.write(INVALID_STATUS_ERROR_MESSAGE) sys.exit(PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE) if action_status is not None and isinstance(action_status, bool): - action_output['status'] = action_status + action_output["status"] = action_status # Special case if result object is not JSON serializable - aka user wanted to return a # non-simple type (e.g. class instance or other non-JSON serializable type) try: - json.dumps(action_output['result']) + json.dumps(action_output["result"]) except TypeError: - action_output['result'] = str(action_output['result']) + action_output["result"] = str(action_output["result"]) try: print_output = json.dumps(action_output) @@ -229,7 +244,7 @@ def run(self): # Print output to stdout so the parent can capture it sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER) - sys.stdout.write(print_output + '\n') + sys.stdout.write(print_output + "\n") sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER) sys.stdout.flush() @@ -238,17 +253,22 @@ def _get_action_instance(self): actions_cls = action_loader.register_plugin(Action, self._file_path) except Exception as e: tb_msg = traceback.format_exc() - msg = ('Failed to load action class from file "%s" (action file most likely doesn\'t ' - 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to load action class from file "%s" (action file most likely doesn\'t ' + "exist or contains invalid syntax): %s" + % (self._file_path, six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) action_cls = actions_cls[0] if actions_cls and len(actions_cls) > 0 else None if not action_cls: - raise Exception('File "%s" has no action class or the file doesn\'t exist.' % - (self._file_path)) + raise Exception( + 'File "%s" has no action class or the file doesn\'t exist.' + % (self._file_path) + ) # Retrieve name of the action class # Note - we need to either use cls.__name_ or inspect.getmro(cls)[0].__name__ to @@ -256,31 +276,45 @@ def _get_action_instance(self): self._class_name = action_cls.__name__ action_service = ActionService(action_wrapper=self) - action_instance = get_action_class_instance(action_cls=action_cls, - config=self._config, - action_service=action_service) + action_instance = get_action_class_instance( + action_cls=action_cls, config=self._config, action_service=action_service + ) return action_instance -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Python action runner process wrapper') - parser.add_argument('--pack', required=True, - help='Name of the pack this action belongs to') - parser.add_argument('--file-path', required=True, - help='Path to the action module') - parser.add_argument('--config', required=False, - help='Pack config serialized as JSON') - parser.add_argument('--parameters', required=False, - help='Serialized action parameters') - parser.add_argument('--stdin-parameters', required=False, action='store_true', - help='Serialized action parameters via stdin') - parser.add_argument('--user', required=False, - help='User who triggered the action execution') - parser.add_argument('--parent-args', required=False, - help='Command line arguments passed to the parent process serialized as ' - ' JSON') - parser.add_argument('--log-level', required=False, default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, - help='Log level for actions') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Python action runner process wrapper") + parser.add_argument( + "--pack", required=True, help="Name of the pack this action belongs to" + ) + parser.add_argument("--file-path", required=True, help="Path to the action module") + parser.add_argument( + "--config", required=False, help="Pack config serialized as JSON" + ) + parser.add_argument( + "--parameters", required=False, help="Serialized action parameters" + ) + parser.add_argument( + "--stdin-parameters", + required=False, + action="store_true", + help="Serialized action parameters via stdin", + ) + parser.add_argument( + "--user", required=False, help="User who triggered the action execution" + ) + parser.add_argument( + "--parent-args", + required=False, + help="Command line arguments passed to the parent process serialized as " + " JSON", + ) + parser.add_argument( + "--log-level", + required=False, + default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + help="Log level for actions", + ) args = parser.parse_args() config = json.loads(args.config) if args.config else {} @@ -289,46 +323,54 @@ def _get_action_instance(self): log_level = args.log_level if not isinstance(config, dict): - raise ValueError('Pack config needs to be a dictionary') + raise ValueError("Pack config needs to be a dictionary") parameters = {} if args.parameters: - LOG.debug('Getting parameters from argument') + LOG.debug("Getting parameters from argument") args_parameters = args.parameters args_parameters = json.loads(args_parameters) if args_parameters else {} parameters.update(args_parameters) if args.stdin_parameters: - LOG.debug('Getting parameters from stdin') + LOG.debug("Getting parameters from stdin") i, _, _ = select.select([sys.stdin], [], [], READ_STDIN_INPUT_TIMEOUT) if not i: - raise ValueError(('No input received and timed out while waiting for ' - 'parameters from stdin')) + raise ValueError( + ( + "No input received and timed out while waiting for " + "parameters from stdin" + ) + ) stdin_data = sys.stdin.readline().strip() try: stdin_parameters = json.loads(stdin_data) - stdin_parameters = stdin_parameters.get('parameters', {}) + stdin_parameters = stdin_parameters.get("parameters", {}) except Exception as e: - msg = ('Failed to parse parameters from stdin. Expected a JSON object with ' - '"parameters" attribute: %s' % (six.text_type(e))) + msg = ( + "Failed to parse parameters from stdin. Expected a JSON object with " + '"parameters" attribute: %s' % (six.text_type(e)) + ) raise ValueError(msg) parameters.update(stdin_parameters) - LOG.debug('Received parameters: %s', parameters) + LOG.debug("Received parameters: %s", parameters) assert isinstance(parent_args, list) - obj = PythonActionWrapper(pack=args.pack, - file_path=args.file_path, - config=config, - parameters=parameters, - user=user, - parent_args=parent_args, - log_level=log_level) + obj = PythonActionWrapper( + pack=args.pack, + file_path=args.file_path, + config=config, + parameters=parameters, + user=user, + parent_args=parent_args, + log_level=log_level, + ) obj.run() diff --git a/contrib/runners/python_runner/python_runner/python_runner.py b/contrib/runners/python_runner/python_runner/python_runner.py index fd412c890e..b11668e000 100644 --- a/contrib/runners/python_runner/python_runner/python_runner.py +++ b/contrib/runners/python_runner/python_runner/python_runner.py @@ -58,34 +58,39 @@ from python_runner import python_action_wrapper __all__ = [ - 'PythonRunner', - - 'get_runner', - 'get_metadata', + "PythonRunner", + "get_runner", + "get_metadata", ] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_ENV = 'env' -RUNNER_TIMEOUT = 'timeout' -RUNNER_LOG_LEVEL = 'log_level' +RUNNER_ENV = "env" +RUNNER_TIMEOUT = "timeout" +RUNNER_LOG_LEVEL = "log_level" # Environment variables which can't be specified by the user BLACKLISTED_ENV_VARS = [ # We don't allow user to override PYTHONPATH since this would break things - 'pythonpath' + "pythonpath" ] BASE_DIR = os.path.dirname(os.path.abspath(python_action_wrapper.__file__)) -WRAPPER_SCRIPT_NAME = 'python_action_wrapper.py' +WRAPPER_SCRIPT_NAME = "python_action_wrapper.py" WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME) class PythonRunner(GitWorktreeActionRunner): - - def __init__(self, runner_id, config=None, timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT, - log_level=None, sandbox=True, use_parent_args=True): + def __init__( + self, + runner_id, + config=None, + timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT, + log_level=None, + sandbox=True, + use_parent_args=True, + ): """ :param timeout: Action execution timeout in seconds. @@ -123,36 +128,42 @@ def pre_run(self): self._log_level = cfg.CONF.actionrunner.python_runner_log_level def run(self, action_parameters): - LOG.debug('Running pythonrunner.') - LOG.debug('Getting pack name.') + LOG.debug("Running pythonrunner.") + LOG.debug("Getting pack name.") pack = self.get_pack_ref() - LOG.debug('Getting user.') + LOG.debug("Getting user.") user = self.get_user() - LOG.debug('Serializing parameters.') - serialized_parameters = json.dumps(action_parameters if action_parameters else {}) - LOG.debug('Getting virtualenv_path.') + LOG.debug("Serializing parameters.") + serialized_parameters = json.dumps( + action_parameters if action_parameters else {} + ) + LOG.debug("Getting virtualenv_path.") virtualenv_path = get_sandbox_virtualenv_path(pack=pack) - LOG.debug('Getting python path.') + LOG.debug("Getting python path.") if self._sandbox: python_path = get_sandbox_python_binary_path(pack=pack) else: python_path = sys.executable - LOG.debug('Checking virtualenv path.') + LOG.debug("Checking virtualenv path.") if virtualenv_path and not os.path.isdir(virtualenv_path): - format_values = {'pack': pack, 'virtualenv_path': virtualenv_path} + format_values = {"pack": pack, "virtualenv_path": virtualenv_path} msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values - LOG.error('virtualenv_path set but not a directory: %s', msg) + LOG.error("virtualenv_path set but not a directory: %s", msg) raise Exception(msg) - LOG.debug('Checking entry_point.') + LOG.debug("Checking entry_point.") if not self.entry_point: - LOG.error('Action "%s" is missing entry_point attribute' % (self.action.name)) - raise Exception('Action "%s" is missing entry_point attribute' % (self.action.name)) + LOG.error( + 'Action "%s" is missing entry_point attribute' % (self.action.name) + ) + raise Exception( + 'Action "%s" is missing entry_point attribute' % (self.action.name) + ) # Note: We pass config as command line args so the actual wrapper process is standalone # and doesn't need access to db - LOG.debug('Setting args.') + LOG.debug("Setting args.") if self._use_parent_args: parent_args = json.dumps(sys.argv[1:]) @@ -161,12 +172,12 @@ def run(self, action_parameters): args = [ python_path, - '-u', # unbuffered mode so streaming mode works as expected + "-u", # unbuffered mode so streaming mode works as expected WRAPPER_SCRIPT_PATH, - '--pack=%s' % (pack), - '--file-path=%s' % (self.entry_point), - '--user=%s' % (user), - '--parent-args=%s' % (parent_args), + "--pack=%s" % (pack), + "--file-path=%s" % (self.entry_point), + "--user=%s" % (user), + "--parent-args=%s" % (parent_args), ] subprocess = concurrency.get_subprocess_module() @@ -178,35 +189,36 @@ def run(self, action_parameters): stdin_params = None if len(serialized_parameters) >= MAX_PARAM_LENGTH: stdin = subprocess.PIPE - LOG.debug('Parameters are too big...changing to stdin') + LOG.debug("Parameters are too big...changing to stdin") stdin_params = '{"parameters": %s}\n' % (serialized_parameters) - args.append('--stdin-parameters') + args.append("--stdin-parameters") else: - LOG.debug('Parameters are just right...adding them to arguments') - args.append('--parameters=%s' % (serialized_parameters)) + LOG.debug("Parameters are just right...adding them to arguments") + args.append("--parameters=%s" % (serialized_parameters)) if self._config: - args.append('--config=%s' % (json.dumps(self._config))) + args.append("--config=%s" % (json.dumps(self._config))) if self._log_level != PYTHON_RUNNER_DEFAULT_LOG_LEVEL: # We only pass --log-level parameter if non default log level value is specified - args.append('--log-level=%s' % (self._log_level)) + args.append("--log-level=%s" % (self._log_level)) # We need to ensure all the st2 dependencies are also available to the subprocess - LOG.debug('Setting env.') + LOG.debug("Setting env.") env = os.environ.copy() - env['PATH'] = get_sandbox_path(virtualenv_path=virtualenv_path) + env["PATH"] = get_sandbox_path(virtualenv_path=virtualenv_path) sandbox_python_path = get_sandbox_python_path_for_python_action( - pack=pack, - inherit_from_parent=True, - inherit_parent_virtualenv=True) + pack=pack, inherit_from_parent=True, inherit_parent_virtualenv=True + ) if self._enable_common_pack_libs: try: pack_common_libs_path = self._get_pack_common_libs_path(pack_ref=pack) except Exception as e: - LOG.debug('Failed to retrieve pack common lib path: %s' % (six.text_type(e))) + LOG.debug( + "Failed to retrieve pack common lib path: %s" % (six.text_type(e)) + ) # There is no MongoDB connection available in Lambda and pack common lib # functionality is not also mandatory for Lambda so we simply ignore those errors. # Note: We should eventually refactor this code to make runner standalone and not @@ -217,13 +229,13 @@ def run(self, action_parameters): pack_common_libs_path = None # Remove leading : (if any) - if sandbox_python_path.startswith(':'): + if sandbox_python_path.startswith(":"): sandbox_python_path = sandbox_python_path[1:] if self._enable_common_pack_libs and pack_common_libs_path: - sandbox_python_path = pack_common_libs_path + ':' + sandbox_python_path + sandbox_python_path = pack_common_libs_path + ":" + sandbox_python_path - env['PYTHONPATH'] = sandbox_python_path + env["PYTHONPATH"] = sandbox_python_path # Include user provided environment variables (if any) user_env_vars = self._get_env_vars() @@ -238,40 +250,53 @@ def run(self, action_parameters): stdout = StringIO() stderr = StringIO() - store_execution_stdout_line = functools.partial(store_execution_output_data, - output_type='stdout') - store_execution_stderr_line = functools.partial(store_execution_output_data, - output_type='stderr') - - read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stdout_line) - read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stderr_line) + store_execution_stdout_line = functools.partial( + store_execution_output_data, output_type="stdout" + ) + store_execution_stderr_line = functools.partial( + store_execution_output_data, output_type="stderr" + ) + + read_and_store_stdout = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stdout_line, + ) + read_and_store_stderr = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stderr_line, + ) command_string = list2cmdline(args) if stdin_params: - command_string = 'echo %s | %s' % (quote_unix(stdin_params), command_string) + command_string = "echo %s | %s" % (quote_unix(stdin_params), command_string) bufsize = cfg.CONF.actionrunner.stream_output_buffer_size - LOG.debug('Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s' % (bufsize, env['PATH'], - env['PYTHONPATH'], - command_string)) - exit_code, stdout, stderr, timed_out = run_command(cmd=args, - stdin=stdin, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False, - env=env, - timeout=self._timeout, - read_stdout_func=read_and_store_stdout, - read_stderr_func=read_and_store_stderr, - read_stdout_buffer=stdout, - read_stderr_buffer=stderr, - stdin_value=stdin_params, - bufsize=bufsize) - LOG.debug('Returning values: %s, %s, %s, %s', exit_code, stdout, stderr, timed_out) - LOG.debug('Returning.') + LOG.debug( + "Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s" + % (bufsize, env["PATH"], env["PYTHONPATH"], command_string) + ) + exit_code, stdout, stderr, timed_out = run_command( + cmd=args, + stdin=stdin, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + env=env, + timeout=self._timeout, + read_stdout_func=read_and_store_stdout, + read_stderr_func=read_and_store_stderr, + read_stdout_buffer=stdout, + read_stderr_buffer=stderr, + stdin_value=stdin_params, + bufsize=bufsize, + ) + LOG.debug( + "Returning values: %s, %s, %s, %s", exit_code, stdout, stderr, timed_out + ) + LOG.debug("Returning.") return self._get_output_values(exit_code, stdout, stderr, timed_out) def _get_pack_common_libs_path(self, pack_ref): @@ -280,7 +305,9 @@ def _get_pack_common_libs_path(self, pack_ref): (if used). """ worktree_path = self.git_worktree_path - pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref) + pack_common_libs_path = get_pack_common_libs_path_for_pack_ref( + pack_ref=pack_ref + ) if not worktree_path: return pack_common_libs_path @@ -288,18 +315,20 @@ def _get_pack_common_libs_path(self, pack_ref): # Modify the path so it uses git worktree directory pack_base_path = get_pack_base_path(pack_name=pack_ref) - new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, '') + new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, "") # Remove leading slash (if any) - if new_pack_common_libs_path.startswith('/'): + if new_pack_common_libs_path.startswith("/"): new_pack_common_libs_path = new_pack_common_libs_path[1:] - new_pack_common_libs_path = os.path.join(worktree_path, new_pack_common_libs_path) + new_pack_common_libs_path = os.path.join( + worktree_path, new_pack_common_libs_path + ) # Check to prevent directory traversal common_prefix = os.path.commonprefix([worktree_path, new_pack_common_libs_path]) if common_prefix != worktree_path: - raise ValueError('pack libs path is not located inside the pack directory') + raise ValueError("pack libs path is not located inside the pack directory") return new_pack_common_libs_path @@ -312,7 +341,7 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): :rtype: ``tuple`` """ if timed_out: - error = 'Action failed to complete in %s seconds' % (self._timeout) + error = "Action failed to complete in %s seconds" % (self._timeout) else: error = None @@ -335,16 +364,18 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): action_result = json.loads(action_result) except Exception as e: # Failed to de-serialize the result, probably it contains non-simple type or similar - LOG.warning('Failed to de-serialize result "%s": %s' % (str(action_result), - six.text_type(e))) + LOG.warning( + 'Failed to de-serialize result "%s": %s' + % (str(action_result), six.text_type(e)) + ) if action_result: if isinstance(action_result, dict): - result = action_result.get('result', None) - status = action_result.get('status', None) + result = action_result.get("result", None) + status = action_result.get("status", None) else: # Failed to de-serialize action result aka result is a string - match = re.search("'result': (.*?)$", action_result or '') + match = re.search("'result': (.*?)$", action_result or "") if match: action_result = match.groups()[0] @@ -352,21 +383,22 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): result = action_result status = None else: - result = 'None' + result = "None" status = None output = { - 'stdout': stdout, - 'stderr': stderr, - 'exit_code': exit_code, - 'result': result + "stdout": stdout, + "stderr": stderr, + "exit_code": exit_code, + "result": result, } if error: - output['error'] = error + output["error"] = error - status = self._get_final_status(action_status=status, timed_out=timed_out, - exit_code=exit_code) + status = self._get_final_status( + action_status=status, timed_out=timed_out, exit_code=exit_code + ) return (status, output, None) def _get_final_status(self, action_status, timed_out, exit_code): @@ -415,8 +447,10 @@ def _get_env_vars(self): to_delete.append(key) for key in to_delete: - LOG.debug('User specified environment variable "%s" which is being ignored...' % - (key)) + LOG.debug( + 'User specified environment variable "%s" which is being ignored...' + % (key) + ) del env_vars[key] return env_vars @@ -441,4 +475,4 @@ def get_runner(config=None): def get_metadata(): - return get_runner_metadata('python_runner')[0] + return get_runner_metadata("python_runner")[0] diff --git a/contrib/runners/python_runner/setup.py b/contrib/runners/python_runner/setup.py index c1a5d6c20a..04e55a31c0 100644 --- a/contrib/runners/python_runner/setup.py +++ b/contrib/runners/python_runner/setup.py @@ -26,30 +26,30 @@ from python_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-python', + name="stackstorm-runner-python", version=__version__, - description='Python action runner for StackStorm event-driven automation platform', - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="Python action runner for StackStorm event-driven automation platform", + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'python_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"python_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'python-script = python_runner.python_runner', + "st2common.runners.runner": [ + "python-script = python_runner.python_runner", ], - } + }, ) diff --git a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py index 27e42ecc5a..e1d39361a2 100644 --- a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py +++ b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py @@ -42,49 +42,53 @@ from st2common.util.shell import run_command from six.moves import range -__all__ = [ - 'PythonRunnerActionWrapperProcessTestCase' -] +__all__ = ["PythonRunnerActionWrapperProcessTestCase"] # Maximum limit for the process wrapper script execution time (in seconds) WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT = 0.31 -ASSERTION_ERROR_MESSAGE = (""" +ASSERTION_ERROR_MESSAGE = """ Python wrapper process script took more than %s seconds to execute (%s). This most likely means that a direct or in-direct import of a module which takes a long time to load has been added (e.g. jsonschema, pecan, kombu, etc). Please review recently changed and added code for potential slow import issues and refactor / re-organize code if possible. -""".strip()) +""".strip() BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, - '../../../python_runner/python_runner/python_action_wrapper.py') +WRAPPER_SCRIPT_PATH = os.path.join( + BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py" +) WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH) -TIME_BINARY_PATH = find_executable('time') +TIME_BINARY_PATH = find_executable("time") TIME_BINARY_AVAILABLE = TIME_BINARY_PATH is not None -@unittest2.skipIf(not TIME_BINARY_PATH, 'time binary not available') +@unittest2.skipIf(not TIME_BINARY_PATH, "time binary not available") class PythonRunnerActionWrapperProcessTestCase(unittest2.TestCase): def test_process_wrapper_exits_in_reasonable_timeframe(self): # 1. Verify wrapper script path is correct and file exists self.assertTrue(os.path.isfile(WRAPPER_SCRIPT_PATH)) # 2. First run it without time to verify path is valid - command_string = 'python %s --file-path=foo.py' % (WRAPPER_SCRIPT_PATH) + command_string = "python %s --file-path=foo.py" % (WRAPPER_SCRIPT_PATH) _, _, stderr = run_command(command_string, shell=True) - self.assertIn('usage: python_action_wrapper.py', stderr) + self.assertIn("usage: python_action_wrapper.py", stderr) - expected_msg_1 = 'python_action_wrapper.py: error: argument --pack is required' - expected_msg_2 = ('python_action_wrapper.py: error: the following arguments are ' - 'required: --pack') + expected_msg_1 = "python_action_wrapper.py: error: argument --pack is required" + expected_msg_2 = ( + "python_action_wrapper.py: error: the following arguments are " + "required: --pack" + ) self.assertTrue(expected_msg_1 in stderr or expected_msg_2 in stderr) # 3. Now time it - command_string = '%s -f "%%e" python %s' % (TIME_BINARY_PATH, WRAPPER_SCRIPT_PATH) + command_string = '%s -f "%%e" python %s' % ( + TIME_BINARY_PATH, + WRAPPER_SCRIPT_PATH, + ) # Do multiple runs and average it run_times = [] @@ -92,14 +96,18 @@ def test_process_wrapper_exits_in_reasonable_timeframe(self): count = 8 for i in range(0, count): _, _, stderr = run_command(command_string, shell=True) - stderr = stderr.strip().split('\n')[-1] + stderr = stderr.strip().split("\n")[-1] run_time_seconds = float(stderr) run_times.append(run_time_seconds) - avg_run_time_seconds = (sum(run_times) / count) - assertion_msg = ASSERTION_ERROR_MESSAGE % (WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, - avg_run_time_seconds) - self.assertTrue(avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg) + avg_run_time_seconds = sum(run_times) / count + assertion_msg = ASSERTION_ERROR_MESSAGE % ( + WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, + avg_run_time_seconds, + ) + self.assertTrue( + avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg + ) def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self): # Test case which verifies that actions with large configs and a lot of parameters work @@ -107,48 +115,55 @@ def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self): # upper limit on the size. config = {} for index in range(0, 50): - config['key_%s' % (index)] = 'value value foo %s' % (index) + config["key_%s" % (index)] = "value value foo %s" % (index) config = json.dumps(config) parameters = {} for index in range(0, 30): - parameters['param_foo_%s' % (index)] = 'some param value %s' % (index) + parameters["param_foo_%s" % (index)] = "some param value %s" % (index) parameters = json.dumps(parameters) - file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py") - command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--parameters=\'%s\'' % - (WRAPPER_SCRIPT_PATH, file_path, config, parameters)) + command_string = ( + "python %s --pack=dummy --file-path=%s --config='%s' " + "--parameters='%s'" % (WRAPPER_SCRIPT_PATH, file_path, config, parameters) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) self.assertEqual(exit_code, 0) self.assertIn('"status"', stdout) def test_stdin_params_timeout_no_stdin_data_provided(self): config = {} - file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py") # try running in a sub-shell to ensure that the stdin is empty - command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--stdin-parameters' % - (WRAPPER_SCRIPT_PATH, file_path, config)) + command_string = ( + "python %s --pack=dummy --file-path=%s --config='%s' " + "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) - expected_msg = ('ValueError: No input received and timed out while waiting for parameters ' - 'from stdin') + expected_msg = ( + "ValueError: No input received and timed out while waiting for parameters " + "from stdin" + ) self.assertEqual(exit_code, 1) self.assertIn(expected_msg, stderr) def test_stdin_params_invalid_format_friendly_error(self): config = {} - file_path = os.path.join(BASE_DIR, '../../../contrib/examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../contrib/examples/actions/noop.py") # Not a valid JSON string - command_string = ('echo "invalid" | python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--stdin-parameters' % - (WRAPPER_SCRIPT_PATH, file_path, config)) + command_string = ( + "echo \"invalid\" | python %s --pack=dummy --file-path=%s --config='%s' " + "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) - expected_msg = ('ValueError: Failed to parse parameters from stdin. Expected a JSON ' - 'object with "parameters" attribute') + expected_msg = ( + "ValueError: Failed to parse parameters from stdin. Expected a JSON " + 'object with "parameters" attribute' + ) self.assertEqual(exit_code, 1) self.assertIn(expected_msg, stderr) diff --git a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py index 328a4a0fc0..a6d300be23 100644 --- a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py +++ b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py @@ -30,13 +30,12 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'PythonRunnerBehaviorTestCase' -] +__all__ = ["PythonRunnerBehaviorTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, - '../../../python_runner/python_runner/python_action_wrapper.py') +WRAPPER_SCRIPT_PATH = os.path.join( + BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py" +) WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH) @@ -46,24 +45,24 @@ def setUp(self): config.parse_args() dir_path = tempfile.mkdtemp() - cfg.CONF.set_override(name='base_path', override=dir_path, group='system') + cfg.CONF.set_override(name="base_path", override=dir_path, group="system") self.base_path = dir_path - self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/') + self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/") # Make sure dir is deleted on tearDown self.to_delete_directories.append(self.base_path) def test_priority_of_loading_library_after_setup_pack_virtualenv(self): - ''' + """ This test checks priority of loading library, whether the library which is specified in the 'requirements.txt' of pack is loaded when a same name module is also specified in the 'requirements.txt' of st2, at a subprocess in ActionRunner. To test above, this uses 'get_library_path.py' action in 'test_library_dependencies' pack. This action returns file-path of imported module which is specified by 'module' parameter. - ''' - pack_name = 'test_library_dependencies' + """ + pack_name = "test_library_dependencies" # Before calling action, this sets up virtualenv for test pack. This pack has # requirements.txt wihch only writes 'six' module. @@ -72,20 +71,25 @@ def test_priority_of_loading_library_after_setup_pack_virtualenv(self): # This test suite expects that loaded six module is located under the virtualenv library, # because 'six' is written in the requirements.txt of 'test_library_dependencies' pack. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'}) - self.assertEqual(output['result'].find(self.virtualenvs_path), 0) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "six"} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), 0) # Conversely, this expects that 'mock' module file-path is not under sandbox library, # but the parent process's library path, because that is not under the pack's virtualenv. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'mock'}) - self.assertEqual(output['result'].find(self.virtualenvs_path), -1) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "mock"} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), -1) # While a module which is in the pack's virtualenv library is specified at 'module' # parameter of the action, this test suite expects that file-path under the parent's # library is returned when 'sandbox' parameter of PythonRunner is False. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'}, - {'_sandbox': False}) - self.assertEqual(output['result'].find(self.virtualenvs_path), -1) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "six"}, {"_sandbox": False} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), -1) def _run_action(self, pack, action, params, runner_params={}): action_db = mock.Mock() @@ -99,7 +103,8 @@ def _run_action(self, pack, action, params, runner_params={}): for key, value in runner_params.items(): setattr(runner, key, value) - runner.entry_point = os.path.join(get_fixtures_base_path(), - 'packs/%s/actions/%s' % (pack, action)) + runner.entry_point = os.path.join( + get_fixtures_base_path(), "packs/%s/actions/%s" % (pack, action) + ) runner.pre_run() return runner.run(params) diff --git a/contrib/runners/python_runner/tests/unit/test_output_schema.py b/contrib/runners/python_runner/tests/unit/test_output_schema.py index 218ba669a6..218a8f0732 100644 --- a/contrib/runners/python_runner/tests/unit/test_output_schema.py +++ b/contrib/runners/python_runner/tests/unit/test_output_schema.py @@ -33,15 +33,16 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/pascal_row.py') +PASCAL_ROW_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py" +) MOCK_SYS = mock.Mock() MOCK_SYS.argv = [] MOCK_SYS.executable = sys.executable MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" FAIL_SCHEMA = { "notvalid": { @@ -50,7 +51,7 @@ } -@mock.patch('python_runner.python_runner.sys', MOCK_SYS) +@mock.patch("python_runner.python_runner.sys", MOCK_SYS) class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase): register_packs = True register_pack_configs = True @@ -61,29 +62,23 @@ def setUpClass(cls): assert_submodules_are_checked_out() def test_adherence_to_output_schema(self): - config = self.loader(os.path.join(BASE_DIR, '../../runner.yaml')) + config = self.loader(os.path.join(BASE_DIR, "../../runner.yaml")) runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) - output_schema._validate_runner( - config[0]['output_schema'], - output - ) + (status, output, _) = runner.run({"row_index": 5}) + output_schema._validate_runner(config[0]["output_schema"], output) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1]) + self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1]) def test_fail_incorrect_output_schema(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) + (status, output, _) = runner.run({"row_index": 5}) with self.assertRaises(jsonschema.ValidationError): - output_schema._validate_runner( - FAIL_SCHEMA, - output - ) + output_schema._validate_runner(FAIL_SCHEMA, output) def _get_mock_runner_obj(self, pack=None, sandbox=None): runner = python_runner.get_runner() @@ -106,10 +101,8 @@ def _get_mock_action_obj(self): Pack gets set to the system pack so the action doesn't require a separate virtualenv. """ action = mock.Mock() - action.ref = 'dummy.action' + action.ref = "dummy.action" action.pack = SYSTEM_PACK_NAME - action.entry_point = 'foo.py' - action.runner_type = { - 'name': 'python-script' - } + action.entry_point = "foo.py" + action.runner_type = {"name": "python-script"} return action diff --git a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py index 8d55f8262d..940d087af6 100644 --- a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py +++ b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py @@ -29,7 +29,10 @@ from st2common.runners.utils import get_action_class_instance from st2common.services import config as config_service from st2common.constants.action import ACTION_OUTPUT_RESULT_DELIMITER -from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED +from st2common.constants.action import ( + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, +) from st2common.constants.action import LIVEACTION_STATUS_TIMED_OUT from st2common.constants.action import MAX_PARAM_LENGTH from st2common.constants.pack import SYSTEM_PACK_NAME @@ -43,29 +46,49 @@ import st2tests.base as tests_base -PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/pascal_row.py') -ECHOER_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/echoer.py') -TEST_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/test.py') -PATHS_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/python_paths.py') -ACTION_1_PATH = os.path.join(tests_base.get_fixtures_path(), - 'packs/dummy_pack_9/actions/list_repos_doesnt_exist.py') -ACTION_2_PATH = os.path.join(tests_base.get_fixtures_path(), - 'packs/dummy_pack_9/actions/invalid_syntax.py') -NON_SIMPLE_TYPE_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/non_simple_type.py') -PRINT_VERSION_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs', - 'test_content_version/actions/print_version.py') -PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs', - 'test_content_version/actions/print_version_local_import.py') - -PRINT_CONFIG_ITEM_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/print_config_item_doesnt_exist.py') -PRINT_TO_STDOUT_STDERR_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/print_to_stdout_and_stderr.py') +PASCAL_ROW_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py" +) +ECHOER_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/echoer.py" +) +TEST_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/test.py" +) +PATHS_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/python_paths.py" +) +ACTION_1_PATH = os.path.join( + tests_base.get_fixtures_path(), + "packs/dummy_pack_9/actions/list_repos_doesnt_exist.py", +) +ACTION_2_PATH = os.path.join( + tests_base.get_fixtures_path(), "packs/dummy_pack_9/actions/invalid_syntax.py" +) +NON_SIMPLE_TYPE_ACTION = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/non_simple_type.py" +) +PRINT_VERSION_ACTION = os.path.join( + tests_base.get_fixtures_path(), + "packs", + "test_content_version/actions/print_version.py", +) +PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join( + tests_base.get_fixtures_path(), + "packs", + "test_content_version/actions/print_version_local_import.py", +) + +PRINT_CONFIG_ITEM_ACTION = os.path.join( + tests_base.get_resources_path(), + "packs", + "pythonactions/actions/print_config_item_doesnt_exist.py", +) +PRINT_TO_STDOUT_STDERR_ACTION = os.path.join( + tests_base.get_resources_path(), + "packs", + "pythonactions/actions/print_to_stdout_and_stderr.py", +) # Note: runner inherits parent args which doesn't work with tests since test pass additional @@ -75,10 +98,10 @@ mock_sys.executable = sys.executable MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" -@mock.patch('python_runner.python_runner.sys', mock_sys) +@mock.patch("python_runner.python_runner.sys", mock_sys) class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase): register_packs = True register_pack_configs = True @@ -90,8 +113,10 @@ def setUpClass(cls): def test_runner_creation(self): runner = python_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), python_runner.PythonRunner, 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), python_runner.PythonRunner, "Creation failed. No instance." + ) def test_action_returns_non_serializable_result(self): # Actions returns non-simple type which can't be serialized, verify result is simple str() @@ -105,33 +130,37 @@ def test_action_returns_non_serializable_result(self): self.assertIsNotNone(output) if six.PY2: - expected_result_re = (r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': " - r"}\]") + expected_result_re = ( + r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': " + r"}\]" + ) else: - expected_result_re = (r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': " - r"}\]") + expected_result_re = ( + r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': " + r"}\]" + ) - match = re.match(expected_result_re, output['result']) + match = re.match(expected_result_re, output["result"]) self.assertTrue(match) def test_simple_action_with_result_no_status(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) + (status, output, _) = runner.run({"row_index": 5}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1]) + self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1]) def test_simple_action_with_result_as_None_no_status(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'b'}) + (status, output, _) = runner.run({"row_index": "b"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['exit_code'], 0) - self.assertEqual(output['result'], None) + self.assertEqual(output["exit_code"], 0) + self.assertEqual(output["result"], None) def test_simple_action_timeout(self): timeout = 0 @@ -139,30 +168,30 @@ def test_simple_action_timeout(self): runner.runner_parameters = {python_runner.RUNNER_TIMEOUT: timeout} runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 4}) + (status, output, _) = runner.run({"row_index": 4}) self.assertEqual(status, LIVEACTION_STATUS_TIMED_OUT) self.assertIsNotNone(output) - self.assertEqual(output['result'], 'None') - self.assertEqual(output['error'], 'Action failed to complete in 0 seconds') - self.assertEqual(output['exit_code'], -9) + self.assertEqual(output["result"], "None") + self.assertEqual(output["error"], "Action failed to complete in 0 seconds") + self.assertEqual(output["exit_code"], -9) def test_simple_action_with_status_succeeded(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 4}) + (status, output, _) = runner.run({"row_index": 4}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 4, 6, 4, 1]) + self.assertEqual(output["result"], [1, 4, 6, 4, 1]) def test_simple_action_with_status_failed(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'a'}) + (status, output, _) = runner.run({"row_index": "a"}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertEqual(output['result'], "This is suppose to fail don't worry!!") + self.assertEqual(output["result"], "This is suppose to fail don't worry!!") def test_simple_action_with_status_complex_type_returned_for_result(self): # Result containing a complex type shouldn't break the returning a tuple with status @@ -170,78 +199,79 @@ def test_simple_action_with_status_complex_type_returned_for_result(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'complex_type'}) + (status, output, _) = runner.run({"row_index": "complex_type"}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertIn('.*" % - runner.git_worktree_path) - self.assertRegexpMatches(output['stdout'].strip(), expected_stdout) + expected_stdout = ( + ".*" + % runner.git_worktree_path + ) + self.assertRegexpMatches(output["stdout"].strip(), expected_stdout) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_old_git_version(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ git: 'worktree' is not a git command. See 'git --help'. -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'v0.10.0'} + runner.runner_parameters = {"content_version": "v0.10.0"} - expected_msg = (r'Failed to create git worktree for pack "core": Installed git version ' - 'doesn\'t support git worktree command. To be able to utilize this ' - 'functionality you need to use git >= 2.5.0.') + expected_msg = ( + r'Failed to create git worktree for pack "core": Installed git version ' + "doesn't support git worktree command. To be able to utilize this " + "functionality you need to use git >= 2.5.0." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_pack_repo_not_git_repository(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ fatal: Not a git repository (or any parent up to mount point /home) Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set). -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'v0.10.0'} - - expected_msg = (r'Failed to create git worktree for pack "core": Pack directory ' - '".*" is not a ' - 'git repository. To utilize this functionality, pack directory needs to ' - 'be a git repository.') + runner.runner_parameters = {"content_version": "v0.10.0"} + + expected_msg = ( + r'Failed to create git worktree for pack "core": Pack directory ' + '".*" is not a ' + "git repository. To utilize this functionality, pack directory needs to " + "be a git repository." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_invalid_git_revision(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ fatal: invalid reference: vinvalid -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'vinvalid'} + runner.runner_parameters = {"content_version": "vinvalid"} - expected_msg = (r'Failed to create git worktree for pack "core": Invalid content_version ' - '"vinvalid" provided. Make sure that git repository is up ' - 'to date and contains that revision.') + expected_msg = ( + r'Failed to create git worktree for pack "core": Invalid content_version ' + '"vinvalid" provided. Make sure that git repository is up ' + "to date and contains that revision." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) def test_missing_config_item_user_friendly_error(self): @@ -953,10 +1051,12 @@ def test_missing_config_item_user_friendly_error(self): self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertIn('{}', output['stdout']) - self.assertIn('default_value', output['stdout']) - self.assertIn('Config for pack "core" is missing key "key"', output['stderr']) - self.assertIn('make sure you run "st2ctl reload --register-configs"', output['stderr']) + self.assertIn("{}", output["stdout"]) + self.assertIn("default_value", output["stdout"]) + self.assertIn('Config for pack "core" is missing key "key"', output["stderr"]) + self.assertIn( + 'make sure you run "st2ctl reload --register-configs"', output["stderr"] + ) def _get_mock_runner_obj(self, pack=None, sandbox=None): runner = python_runner.get_runner() @@ -972,22 +1072,25 @@ def _get_mock_runner_obj(self, pack=None, sandbox=None): return runner - @mock.patch('st2actions.container.base.ActionExecution.get', mock.Mock()) + @mock.patch("st2actions.container.base.ActionExecution.get", mock.Mock()) def _get_mock_runner_obj_from_container(self, pack, user, sandbox=None): container = RunnerContainer() runnertype_db = mock.Mock() - runnertype_db.name = 'python-script' - runnertype_db.runner_package = 'python_runner' - runnertype_db.runner_module = 'python_runner' + runnertype_db.name = "python-script" + runnertype_db.runner_package = "python_runner" + runnertype_db.runner_module = "python_runner" action_db = mock.Mock() action_db.pack = pack - action_db.entry_point = 'foo.py' + action_db.entry_point = "foo.py" liveaction_db = mock.Mock() - liveaction_db.id = '123' - liveaction_db.context = {'user': user} - runner = container._get_runner(runner_type_db=runnertype_db, action_db=action_db, - liveaction_db=liveaction_db) + liveaction_db.id = "123" + liveaction_db.context = {"user": user} + runner = container._get_runner( + runner_type_db=runnertype_db, + action_db=action_db, + liveaction_db=liveaction_db, + ) runner.execution = MOCK_EXECUTION runner.action = action_db runner.runner_parameters = {} @@ -1004,10 +1107,8 @@ def _get_mock_action_obj(self): Pack gets set to the system pack so the action doesn't require a separate virtualenv. """ action = mock.Mock() - action.ref = 'dummy.action' + action.ref = "dummy.action" action.pack = SYSTEM_PACK_NAME - action.entry_point = 'foo.py' - action.runner_type = { - 'name': 'python-script' - } + action.entry_point = "foo.py" + action.runner_type = {"name": "python-script"} return action diff --git a/contrib/runners/remote_runner/dist_utils.py b/contrib/runners/remote_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/remote_runner/dist_utils.py +++ b/contrib/runners/remote_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/remote_runner/remote_runner/__init__.py b/contrib/runners/remote_runner/remote_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/remote_runner/remote_runner/__init__.py +++ b/contrib/runners/remote_runner/remote_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py index 09382d9125..60880a0431 100644 --- a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py +++ b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py @@ -24,12 +24,7 @@ from st2common.runners.base import get_metadata as get_runner_metadata from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction -__all__ = [ - 'ParamikoRemoteCommandRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ParamikoRemoteCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -38,42 +33,52 @@ class ParamikoRemoteCommandRunner(BaseParallelSSHRunner): def run(self, action_parameters): remote_action = self._get_remote_action(action_parameters) - LOG.debug('Executing remote command action.', extra={'_action_params': remote_action}) + LOG.debug( + "Executing remote command action.", extra={"_action_params": remote_action} + ) result = self._run(remote_action) - LOG.debug('Executed remote_action.', extra={'_result': result}) - status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure) + LOG.debug("Executed remote_action.", extra={"_result": result}) + status = self._get_result_status( + result, cfg.CONF.ssh_runner.allow_partial_failure + ) return (status, result, None) def _run(self, remote_action): command = remote_action.get_full_command_string() - return self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout()) + return self._parallel_ssh_client.run( + command, timeout=remote_action.get_timeout() + ) def _get_remote_action(self, action_paramaters): # remote script actions with entry_point don't make sense, user probably wanted to use # "remote-shell-script" action if self.entry_point: - msg = ('Action "%s" specified "entry_point" attribute. Perhaps wanted to use ' - '"remote-shell-script" runner?' % (self.action_name)) + msg = ( + 'Action "%s" specified "entry_point" attribute. Perhaps wanted to use ' + '"remote-shell-script" runner?' % (self.action_name) + ) raise Exception(msg) command = self.runner_parameters.get(RUNNER_COMMAND, None) env_vars = self._get_env_vars() - return ParamikoRemoteCommandAction(self.action_name, - str(self.liveaction_id), - command, - env_vars=env_vars, - on_behalf_user=self._on_behalf_user, - user=self._username, - password=self._password, - private_key=self._private_key, - passphrase=self._passphrase, - hosts=self._hosts, - parallel=self._parallel, - sudo=self._sudo, - sudo_password=self._sudo_password, - timeout=self._timeout, - cwd=self._cwd) + return ParamikoRemoteCommandAction( + self.action_name, + str(self.liveaction_id), + command, + env_vars=env_vars, + on_behalf_user=self._on_behalf_user, + user=self._username, + password=self._password, + private_key=self._private_key, + passphrase=self._passphrase, + hosts=self._hosts, + parallel=self._parallel, + sudo=self._sudo, + sudo_password=self._sudo_password, + timeout=self._timeout, + cwd=self._cwd, + ) def get_runner(): @@ -81,7 +86,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('remote_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("remote_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py index 292f391850..e71e8f6314 100644 --- a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py +++ b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py @@ -27,12 +27,7 @@ from st2common.runners.base import get_metadata as get_runner_metadata from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction -__all__ = [ - 'ParamikoRemoteScriptRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ParamikoRemoteScriptRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -41,10 +36,12 @@ class ParamikoRemoteScriptRunner(BaseParallelSSHRunner): def run(self, action_parameters): remote_action = self._get_remote_action(action_parameters) - LOG.debug('Executing remote action.', extra={'_action_params': remote_action}) + LOG.debug("Executing remote action.", extra={"_action_params": remote_action}) result = self._run(remote_action) - LOG.debug('Executed remote action.', extra={'_result': result}) - status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure) + LOG.debug("Executed remote action.", extra={"_result": result}) + status = self._get_result_status( + result, cfg.CONF.ssh_runner.allow_partial_failure + ) return (status, result, None) @@ -54,109 +51,133 @@ def _run(self, remote_action): except: # If for whatever reason there is a top level exception, # we just bail here. - error = 'Failed copying content to remote boxes.' + error = "Failed copying content to remote boxes." LOG.exception(error) _, ex, tb = sys.exc_info() - copy_results = self._generate_error_results(' '.join([error, str(ex)]), tb) + copy_results = self._generate_error_results(" ".join([error, str(ex)]), tb) return copy_results try: exec_results = self._run_script_on_remote_host(remote_action) try: remote_dir = remote_action.get_remote_base_dir() - LOG.debug('Deleting remote execution dir.', extra={'_remote_dir': remote_dir}) - delete_results = self._parallel_ssh_client.delete_dir(path=remote_dir, - force=True) - LOG.debug('Deleted remote execution dir.', extra={'_result': delete_results}) + LOG.debug( + "Deleting remote execution dir.", extra={"_remote_dir": remote_dir} + ) + delete_results = self._parallel_ssh_client.delete_dir( + path=remote_dir, force=True + ) + LOG.debug( + "Deleted remote execution dir.", extra={"_result": delete_results} + ) except: - LOG.exception('Failed deleting remote dir.', extra={'_remote_dir': remote_dir}) + LOG.exception( + "Failed deleting remote dir.", extra={"_remote_dir": remote_dir} + ) finally: return exec_results except: - error = 'Failed executing script on remote boxes.' - LOG.exception(error, extra={'_action_params': remote_action}) + error = "Failed executing script on remote boxes." + LOG.exception(error, extra={"_action_params": remote_action}) _, ex, tb = sys.exc_info() - exec_results = self._generate_error_results(' '.join([error, str(ex)]), tb) + exec_results = self._generate_error_results(" ".join([error, str(ex)]), tb) return exec_results def _copy_artifacts(self, remote_action): # First create remote execution directory. remote_dir = remote_action.get_remote_base_dir() - LOG.debug('Creating remote execution dir.', extra={'_path': remote_dir}) - mkdir_result = self._parallel_ssh_client.mkdir(path=remote_action.get_remote_base_dir()) + LOG.debug("Creating remote execution dir.", extra={"_path": remote_dir}) + mkdir_result = self._parallel_ssh_client.mkdir( + path=remote_action.get_remote_base_dir() + ) # Copy the script to remote dir in remote host. local_script_abs_path = remote_action.get_local_script_abs_path() remote_script_abs_path = remote_action.get_remote_script_abs_path() file_mode = 0o744 - extra = {'_local_script': local_script_abs_path, '_remote_script': remote_script_abs_path, - 'mode': file_mode} - LOG.debug('Copying local script to remote box.', extra=extra) - put_result_1 = self._parallel_ssh_client.put(local_path=local_script_abs_path, - remote_path=remote_script_abs_path, - mirror_local_mode=False, mode=file_mode) + extra = { + "_local_script": local_script_abs_path, + "_remote_script": remote_script_abs_path, + "mode": file_mode, + } + LOG.debug("Copying local script to remote box.", extra=extra) + put_result_1 = self._parallel_ssh_client.put( + local_path=local_script_abs_path, + remote_path=remote_script_abs_path, + mirror_local_mode=False, + mode=file_mode, + ) # If `lib` exist for the script, copy that to remote host. local_libs_path = remote_action.get_local_libs_path_abs() if os.path.exists(local_libs_path): - extra = {'_local_libs': local_libs_path, '_remote_path': remote_dir} - LOG.debug('Copying libs to remote host.', extra=extra) - put_result_2 = self._parallel_ssh_client.put(local_path=local_libs_path, - remote_path=remote_dir, - mirror_local_mode=True) + extra = {"_local_libs": local_libs_path, "_remote_path": remote_dir} + LOG.debug("Copying libs to remote host.", extra=extra) + put_result_2 = self._parallel_ssh_client.put( + local_path=local_libs_path, + remote_path=remote_dir, + mirror_local_mode=True, + ) result = mkdir_result or put_result_1 or put_result_2 return result def _run_script_on_remote_host(self, remote_action): command = remote_action.get_full_command_string() - LOG.info('Command to run: %s', command) - results = self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout()) - LOG.debug('Results from script: %s', results) + LOG.info("Command to run: %s", command) + results = self._parallel_ssh_client.run( + command, timeout=remote_action.get_timeout() + ) + LOG.debug("Results from script: %s", results) return results def _get_remote_action(self, action_parameters): # remote script actions without entry_point don't make sense, user probably wanted to use # "remote-shell-cmd" action if not self.entry_point: - msg = ('Action "%s" is missing "entry_point" attribute. Perhaps wanted to use ' - '"remote-shell-script" runner?' % (self.action_name)) + msg = ( + 'Action "%s" is missing "entry_point" attribute. Perhaps wanted to use ' + '"remote-shell-script" runner?' % (self.action_name) + ) raise Exception(msg) script_local_path_abs = self.entry_point pos_args, named_args = self._get_script_args(action_parameters) named_args = self._transform_named_args(named_args) env_vars = self._get_env_vars() - remote_dir = self.runner_parameters.get(RUNNER_REMOTE_DIR, - cfg.CONF.ssh_runner.remote_dir) + remote_dir = self.runner_parameters.get( + RUNNER_REMOTE_DIR, cfg.CONF.ssh_runner.remote_dir + ) remote_dir = os.path.join(remote_dir, self.liveaction_id) - return ParamikoRemoteScriptAction(self.action_name, - str(self.liveaction_id), - script_local_path_abs, - self.libs_dir_path, - named_args=named_args, - positional_args=pos_args, - env_vars=env_vars, - on_behalf_user=self._on_behalf_user, - user=self._username, - password=self._password, - private_key=self._private_key, - remote_dir=remote_dir, - hosts=self._hosts, - parallel=self._parallel, - sudo=self._sudo, - sudo_password=self._sudo_password, - timeout=self._timeout, - cwd=self._cwd) + return ParamikoRemoteScriptAction( + self.action_name, + str(self.liveaction_id), + script_local_path_abs, + self.libs_dir_path, + named_args=named_args, + positional_args=pos_args, + env_vars=env_vars, + on_behalf_user=self._on_behalf_user, + user=self._username, + password=self._password, + private_key=self._private_key, + remote_dir=remote_dir, + hosts=self._hosts, + parallel=self._parallel, + sudo=self._sudo, + sudo_password=self._sudo_password, + timeout=self._timeout, + cwd=self._cwd, + ) @staticmethod def _generate_error_results(error, tb): error_dict = { - 'error': error, - 'traceback': ''.join(traceback.format_tb(tb, 20)) if tb else '', - 'failed': True, - 'succeeded': False, - 'return_code': 255 + "error": error, + "traceback": "".join(traceback.format_tb(tb, 20)) if tb else "", + "failed": True, + "succeeded": False, + "return_code": 255, } return error_dict @@ -166,7 +187,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('remote_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("remote_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/remote_runner/setup.py b/contrib/runners/remote_runner/setup.py index cdd61b68b1..3e83437aff 100644 --- a/contrib/runners/remote_runner/setup.py +++ b/contrib/runners/remote_runner/setup.py @@ -26,32 +26,34 @@ from remote_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-remote', + name="stackstorm-runner-remote", version=__version__, - description=('Remote SSH shell command and script action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Remote SSH shell command and script action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'remote_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"remote_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'remote-shell-cmd = remote_runner.remote_command_runner', - 'remote-shell-script = remote_runner.remote_script_runner', + "st2common.runners.runner": [ + "remote-shell-cmd = remote_runner.remote_command_runner", + "remote-shell-script = remote_runner.remote_script_runner", ], - } + }, ) diff --git a/contrib/runners/winrm_runner/dist_utils.py b/contrib/runners/winrm_runner/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/contrib/runners/winrm_runner/dist_utils.py +++ b/contrib/runners/winrm_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/winrm_runner/setup.py b/contrib/runners/winrm_runner/setup.py index f3f014277b..53d7b952e1 100644 --- a/contrib/runners/winrm_runner/setup.py +++ b/contrib/runners/winrm_runner/setup.py @@ -26,33 +26,35 @@ from winrm_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-winrm', + name="stackstorm-runner-winrm", version=__version__, - description=('WinRM shell command and PowerShell script action runner for' - ' the StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "WinRM shell command and PowerShell script action runner for" + " the StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'winrm_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"winrm_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'winrm-cmd = winrm_runner.winrm_command_runner', - 'winrm-ps-cmd = winrm_runner.winrm_ps_command_runner', - 'winrm-ps-script = winrm_runner.winrm_ps_script_runner', + "st2common.runners.runner": [ + "winrm-cmd = winrm_runner.winrm_command_runner", + "winrm-ps-cmd = winrm_runner.winrm_ps_command_runner", + "winrm-ps-script = winrm_runner.winrm_ps_script_runner", ], - } + }, ) diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py index 0803b3e25a..1ff9f2ce1d 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py @@ -32,157 +32,170 @@ class WinRmBaseTestCase(RunnerTestCase): - def setUp(self): super(WinRmBaseTestCase, self).setUpClass() self._runner = winrm_ps_command_runner.get_runner() def _init_runner(self): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'xyz987'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "xyz987", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() def test_win_rm_runner_timout_error(self): - error = WinRmRunnerTimoutError('test_response') + error = WinRmRunnerTimoutError("test_response") self.assertIsInstance(error, Exception) - self.assertEqual(error.response, 'test_response') + self.assertEqual(error.response, "test_response") with self.assertRaises(WinRmRunnerTimoutError): - raise WinRmRunnerTimoutError('test raising') + raise WinRmRunnerTimoutError("test raising") def test_init(self): - runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef') + runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'timeout': 99, - 'port': 1234, - 'scheme': 'http', - 'transport': 'ntlm', - 'verify_ssl_cert': False, - 'cwd': 'C:\\Test', - 'env': {'TEST_VAR': 'TEST_VALUE'}, - 'kwarg_op': '/'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "timeout": 99, + "port": 1234, + "scheme": "http", + "transport": "ntlm", + "verify_ssl_cert": False, + "cwd": "C:\\Test", + "env": {"TEST_VAR": "TEST_VALUE"}, + "kwarg_op": "/", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._session, None) - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 99) self.assertEqual(self._runner._read_timeout, 100) self.assertEqual(self._runner._port, 1234) - self.assertEqual(self._runner._scheme, 'http') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:1234/wsman') + self.assertEqual(self._runner._scheme, "http") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:1234/wsman") self.assertEqual(self._runner._verify_ssl, False) - self.assertEqual(self._runner._server_cert_validation, 'ignore') - self.assertEqual(self._runner._cwd, 'C:\\Test') - self.assertEqual(self._runner._env, {'TEST_VAR': 'TEST_VALUE'}) - self.assertEqual(self._runner._kwarg_op, '/') + self.assertEqual(self._runner._server_cert_validation, "ignore") + self.assertEqual(self._runner._cwd, "C:\\Test") + self.assertEqual(self._runner._env, {"TEST_VAR": "TEST_VALUE"}) + self.assertEqual(self._runner._kwarg_op, "/") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_defaults(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 60) self.assertEqual(self._runner._read_timeout, 61) self.assertEqual(self._runner._port, 5986) - self.assertEqual(self._runner._scheme, 'https') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'https://host@domain.tld:5986/wsman') + self.assertEqual(self._runner._scheme, "https") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "https://host@domain.tld:5986/wsman") self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") self.assertEqual(self._runner._cwd, None) self.assertEqual(self._runner._env, {}) - self.assertEqual(self._runner._kwarg_op, '-') + self.assertEqual(self._runner._kwarg_op, "-") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_5985_force_http(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'port': 5985, - 'scheme': 'https'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "port": 5985, + "scheme": "https", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 60) self.assertEqual(self._runner._read_timeout, 61) # ensure port is still 5985 self.assertEqual(self._runner._port, 5985) # ensure scheme is set back to http - self.assertEqual(self._runner._scheme, 'http') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:5985/wsman') + self.assertEqual(self._runner._scheme, "http") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:5985/wsman") self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") self.assertEqual(self._runner._cwd, None) self.assertEqual(self._runner._env, {}) - self.assertEqual(self._runner._kwarg_op, '-') + self.assertEqual(self._runner._kwarg_op, "-") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_none_env(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'env': None} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "env": None, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() # ensure that env is set to {} even though we passed in None self.assertEqual(self._runner._env, {}) - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_ssl_verify_true(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'verify_ssl_cert': True} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "verify_ssl_cert": True, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_ssl_verify_false(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'verify_ssl_cert': False} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "verify_ssl_cert": False, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._verify_ssl, False) - self.assertEqual(self._runner._server_cert_validation, 'ignore') + self.assertEqual(self._runner._server_cert_validation, "ignore") - @mock.patch('winrm_runner.winrm_base.Session') + @mock.patch("winrm_runner.winrm_base.Session") def test_get_session(self, mock_session): self._runner._session = None - self._runner._winrm_url = 'https://host@domain.tld:5986/wsman' - self._runner._username = 'user@domain.tld' - self._runner._password = 'abc123' - self._runner._transport = 'ntlm' - self._runner._server_cert_validation = 'validate' + self._runner._winrm_url = "https://host@domain.tld:5986/wsman" + self._runner._username = "user@domain.tld" + self._runner._password = "abc123" + self._runner._transport = "ntlm" + self._runner._server_cert_validation = "validate" self._runner._timeout = 60 self._runner._read_timeout = 61 mock_session.return_value = "session" @@ -190,12 +203,14 @@ def test_get_session(self, mock_session): result = self._runner._get_session() self.assertEqual(result, "session") self.assertEqual(result, self._runner._session) - mock_session.assert_called_with('https://host@domain.tld:5986/wsman', - auth=('user@domain.tld', 'abc123'), - transport='ntlm', - server_cert_validation='validate', - operation_timeout_sec=60, - read_timeout_sec=61) + mock_session.assert_called_with( + "https://host@domain.tld:5986/wsman", + auth=("user@domain.tld", "abc123"), + transport="ntlm", + server_cert_validation="validate", + operation_timeout_sec=60, + read_timeout_sec=61, + ) # ensure calling _get_session again doesn't create a new one, it reuses the existing old_session = self._runner._session @@ -206,18 +221,18 @@ def test_winrm_get_command_output(self): self._runner._timeout = 0 mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 123, False), - (b'output2', b'error2', 456, False), - (b'output3', b'error3', 789, True) + (b"output1", b"error1", 123, False), + (b"output2", b"error2", 456, False), + (b"output3", b"error3", 789, True), ] result = self._runner._winrm_get_command_output(mock_protocol, 567, 890) - self.assertEqual(result, (b'output1output2output3', b'error1error2error3', 789)) + self.assertEqual(result, (b"output1output2output3", b"error1error2error3", 789)) mock_protocol._raw_get_command_output.assert_has_calls = [ mock.call(567, 890), mock.call(567, 890), - mock.call(567, 890) + mock.call(567, 890), ] def test_winrm_get_command_output_timeout(self): @@ -227,7 +242,7 @@ def test_winrm_get_command_output_timeout(self): def sleep_for_timeout(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout @@ -235,9 +250,11 @@ def sleep_for_timeout(*args, **kwargs): self._runner._winrm_get_command_output(mock_protocol, 567, 890) timeout_exception = cm.exception - self.assertEqual(timeout_exception.response.std_out, b'output1') - self.assertEqual(timeout_exception.response.std_err, b'error1') - self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE) + self.assertEqual(timeout_exception.response.std_out, b"output1") + self.assertEqual(timeout_exception.response.std_err, b"error1") + self.assertEqual( + timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE + ) mock_protocol._raw_get_command_output.assert_called_with(567, 890) def test_winrm_get_command_output_operation_timeout(self): @@ -255,292 +272,354 @@ def sleep_for_timeout_then_raise(*args, **kwargs): self._runner._winrm_get_command_output(mock_protocol, 567, 890) timeout_exception = cm.exception - self.assertEqual(timeout_exception.response.std_out, b'') - self.assertEqual(timeout_exception.response.std_err, b'') - self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE) + self.assertEqual(timeout_exception.response.std_out, b"") + self.assertEqual(timeout_exception.response.std_err, b"") + self.assertEqual( + timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE + ) mock_protocol._raw_get_command_output.assert_called_with(567, 890) def test_winrm_run_cmd(self): mock_protocol = mock.MagicMock() mock_protocol.open_shell.return_value = 123 mock_protocol.run_command.return_value = 456 - mock_protocol._raw_get_command_output.return_value = (b'output', b'error', 9, True) + mock_protocol._raw_get_command_output.return_value = ( + b"output", + b"error", + 9, + True, + ) mock_session = mock.MagicMock(protocol=mock_protocol) self._init_runner() - result = self._runner._winrm_run_cmd(mock_session, "fake-command", - args=['arg1', 'arg2'], - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - expected_response = Response((b'output', b'error', 9)) + result = self._runner._winrm_run_cmd( + mock_session, + "fake-command", + args=["arg1", "arg2"], + env={"PATH": "C:\\st2\\bin"}, + cwd="C:\\st2", + ) + expected_response = Response((b"output", b"error", 9)) expected_response.timeout = False self.assertEqual(result.__dict__, expected_response.__dict__) - mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'}, - working_directory='C:\\st2') - mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2']) + mock_protocol.open_shell.assert_called_with( + env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2" + ) + mock_protocol.run_command.assert_called_with( + 123, "fake-command", ["arg1", "arg2"] + ) mock_protocol._raw_get_command_output.assert_called_with(123, 456) mock_protocol.cleanup_command.assert_called_with(123, 456) mock_protocol.close_shell.assert_called_with(123) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output") def test_winrm_run_cmd_timeout(self, mock_get_command_output): mock_protocol = mock.MagicMock() mock_protocol.open_shell.return_value = 123 mock_protocol.run_command.return_value = 456 mock_session = mock.MagicMock(protocol=mock_protocol) - mock_get_command_output.side_effect = WinRmRunnerTimoutError(Response(('', '', 5))) + mock_get_command_output.side_effect = WinRmRunnerTimoutError( + Response(("", "", 5)) + ) self._init_runner() - result = self._runner._winrm_run_cmd(mock_session, "fake-command", - args=['arg1', 'arg2'], - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - expected_response = Response(('', '', 5)) + result = self._runner._winrm_run_cmd( + mock_session, + "fake-command", + args=["arg1", "arg2"], + env={"PATH": "C:\\st2\\bin"}, + cwd="C:\\st2", + ) + expected_response = Response(("", "", 5)) expected_response.timeout = True self.assertEqual(result.__dict__, expected_response.__dict__) - mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'}, - working_directory='C:\\st2') - mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2']) + mock_protocol.open_shell.assert_called_with( + env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2" + ) + mock_protocol.run_command.assert_called_with( + 123, "fake-command", ["arg1", "arg2"] + ) mock_protocol.cleanup_command.assert_called_with(123, 456) mock_protocol.close_shell.assert_called_with(123) def test_winrm_encode(self): - result = self._runner._winrm_encode('hello world') + result = self._runner._winrm_encode("hello world") # result translated into UTF-16 little-endian - self.assertEqual(result, 'aABlAGwAbABvACAAdwBvAHIAbABkAA==') + self.assertEqual(result, "aABlAGwAbABvACAAdwBvAHIAbABkAA==") def test_winrm_ps_cmd(self): - result = self._runner._winrm_ps_cmd('abc123==') - self.assertEqual(result, 'powershell -encodedcommand abc123==') + result = self._runner._winrm_ps_cmd("abc123==") + self.assertEqual(result, "powershell -encodedcommand abc123==") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd") def test_winrm_run_ps(self, mock_run_cmd): - mock_run_cmd.return_value = Response(('output', '', 3)) + mock_run_cmd.return_value = Response(("output", "", 3)) script = "Get-ADUser stanley" - result = self._runner._winrm_run_ps("session", script, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + result = self._runner._winrm_run_ps( + "session", script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - self.assertEqual(result.__dict__, - Response(('output', '', 3)).__dict__) - expected_ps = ('powershell -encodedcommand ' + - b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii')) - mock_run_cmd.assert_called_with("session", - expected_ps, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + self.assertEqual(result.__dict__, Response(("output", "", 3)).__dict__) + expected_ps = "powershell -encodedcommand " + b64encode( + "Get-ADUser stanley".encode("utf_16_le") + ).decode("ascii") + mock_run_cmd.assert_called_with( + "session", expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd") def test_winrm_run_ps_clean_stderr(self, mock_run_cmd): - mock_run_cmd.return_value = Response(('output', 'error', 3)) + mock_run_cmd.return_value = Response(("output", "error", 3)) mock_session = mock.MagicMock() - mock_session._clean_error_msg.return_value = 'e' + mock_session._clean_error_msg.return_value = "e" script = "Get-ADUser stanley" - result = self._runner._winrm_run_ps(mock_session, script, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + result = self._runner._winrm_run_ps( + mock_session, script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - self.assertEqual(result.__dict__, - Response(('output', 'e', 3)).__dict__) - expected_ps = ('powershell -encodedcommand ' + - b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii')) - mock_run_cmd.assert_called_with(mock_session, - expected_ps, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - mock_session._clean_error_msg.assert_called_with('error') + self.assertEqual(result.__dict__, Response(("output", "e", 3)).__dict__) + expected_ps = "powershell -encodedcommand " + b64encode( + "Get-ADUser stanley".encode("utf_16_le") + ).decode("ascii") + mock_run_cmd.assert_called_with( + mock_session, expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) + mock_session._clean_error_msg.assert_called_with("error") def test_translate_response_success(self): - response = Response(('output1', 'error1', 0)) + response = Response(("output1", "error1", 0)) response.timeout = False result = self._runner._translate_response(response) - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) def test_translate_response_failure(self): - response = Response(('output1', 'error1', 123)) + response = Response(("output1", "error1", 123)) response.timeout = False result = self._runner._translate_response(response) - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 123, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 123, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) def test_translate_response_timeout(self): - response = Response(('output1', 'error1', 123)) + response = Response(("output1", "error1", 123)) response.timeout = True result = self._runner._translate_response(response) - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_make_tmp_dir(self, mock_run_ps_or_raise): - mock_run_ps_or_raise.return_value = {'stdout': ' expected \n'} + mock_run_ps_or_raise.return_value = {"stdout": " expected \n"} - result = self._runner._make_tmp_dir('C:\\Windows\\Temp') - self.assertEqual(result, 'expected') - mock_run_ps_or_raise.assert_called_with('''$parent = C:\\Windows\\Temp + result = self._runner._make_tmp_dir("C:\\Windows\\Temp") + self.assertEqual(result, "expected") + mock_run_ps_or_raise.assert_called_with( + """$parent = C:\\Windows\\Temp $name = [System.IO.Path]::GetRandomFileName() $path = Join-Path $parent $name New-Item -ItemType Directory -Path $path | Out-Null -$path''', - ("Unable to make temporary directory for" - " powershell script")) +$path""", + ("Unable to make temporary directory for" " powershell script"), + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_rm_dir(self, mock_run_ps_or_raise): - self._runner._rm_dir('C:\\Windows\\Temp\\testtmpdir') + self._runner._rm_dir("C:\\Windows\\Temp\\testtmpdir") mock_run_ps_or_raise.assert_called_with( 'Remove-Item -Force -Recurse -Path "C:\\Windows\\Temp\\testtmpdir"', - "Unable to remove temporary directory for powershell script") + "Unable to remove temporary directory for powershell script", + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('winrm_runner.winrm_base.open') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("winrm_runner.winrm_base.open") + @mock.patch("os.path.exists") def test_upload_chunk_file(self, mock_os_path_exists, mock_open, mock_upload_chunk): mock_os_path_exists.return_value = True mock_src_file = mock.MagicMock() mock_src_file.read.return_value = "test data" mock_open.return_value.__enter__.return_value = mock_src_file - self._runner._upload('/opt/data/test.ps1', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('/opt/data/test.ps1') - mock_open.assert_called_with('/opt/data/test.ps1', 'r') + self._runner._upload("/opt/data/test.ps1", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("/opt/data/test.ps1") + mock_open.assert_called_with("/opt/data/test.ps1", "r") mock_src_file.read.assert_called_with() - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'test data') - ]) + mock_upload_chunk.assert_has_calls( + [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")] + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("os.path.exists") def test_upload_chunk_data(self, mock_os_path_exists, mock_upload_chunk): mock_os_path_exists.return_value = False - self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('test data') - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'test data') - ]) + self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("test data") + mock_upload_chunk.assert_has_calls( + [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")] + ) - @mock.patch('winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES', 2) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES", 2) + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("os.path.exists") def test_upload_chunk_multiple_chunks(self, mock_os_path_exists, mock_upload_chunk): mock_os_path_exists.return_value = False - self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('test data') - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'te'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'st'), - mock.call('C:\\Windows\\Temp\\test.ps1', ' d'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'at'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'a'), - ]) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("test data") + mock_upload_chunk.assert_has_calls( + [ + mock.call("C:\\Windows\\Temp\\test.ps1", "te"), + mock.call("C:\\Windows\\Temp\\test.ps1", "st"), + mock.call("C:\\Windows\\Temp\\test.ps1", " d"), + mock.call("C:\\Windows\\Temp\\test.ps1", "at"), + mock.call("C:\\Windows\\Temp\\test.ps1", "a"), + ] + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_upload_chunk(self, mock_run_ps_or_raise): - self._runner._upload_chunk('C:\\Windows\\Temp\\testtmp.ps1', 'hello world') + self._runner._upload_chunk("C:\\Windows\\Temp\\testtmp.ps1", "hello world") mock_run_ps_or_raise.assert_called_with( - '''$filePath = "C:\\Windows\\Temp\\testtmp.ps1" + """$filePath = "C:\\Windows\\Temp\\testtmp.ps1" $s = @" aGVsbG8gd29ybGQ= "@ $data = [System.Convert]::FromBase64String($s) Add-Content -value $data -encoding byte -path $filePath -''', - "Failed to upload chunk of powershell script") +""", + "Failed to upload chunk of powershell script", + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir") def test_tmp_script(self, mock_make_tmp_dir, mock_upload, mock_rm_dir): - mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123' - - with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp: - self.assertEqual(tmp, 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp') - mock_upload.assert_called_with('Get-ChildItem', - 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123') - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir') - def test_tmp_script_cleans_up_when_raises(self, mock_make_tmp_dir, mock_upload, - mock_rm_dir): - mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123' + mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123" + + with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp: + self.assertEqual(tmp, "C:\\Windows\\Temp\\abc123\\script.ps1") + mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp") + mock_upload.assert_called_with( + "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1" + ) + mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123") + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir") + def test_tmp_script_cleans_up_when_raises( + self, mock_make_tmp_dir, mock_upload, mock_rm_dir + ): + mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123" mock_upload.side_effect = RuntimeError with self.assertRaises(RuntimeError): - with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp: + with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp: self.assertEqual(tmp, "can never get here") - mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp') - mock_upload.assert_called_with('Get-ChildItem', - 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123') + mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp") + mock_upload.assert_called_with( + "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1" + ) + mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123") - @mock.patch('winrm.Protocol') + @mock.patch("winrm.Protocol") def test_run_cmd(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_cmd_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_cmd_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -548,61 +627,82 @@ def test_run_cmd_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -610,91 +710,113 @@ def test_run_ps_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode") def test_run_ps_params(self, mock_winrm_encode, mock_run_ps): - mock_winrm_encode.return_value = 'xyz123==' + mock_winrm_encode.return_value = "xyz123==" mock_run_ps.return_value = "expected" self._init_runner() - result = self._runner.run_ps("Get-Location", '-param1 value1 arg1') + result = self._runner.run_ps("Get-Location", "-param1 value1 arg1") self.assertEqual(result, "expected") - mock_winrm_encode.assert_called_with('& {Get-Location} -param1 value1 arg1') - mock_run_ps.assert_called_with('xyz123==', is_b64=True) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script') - def test_run_ps_large_command_convert_to_script(self, mock_run_ps_script, - mock_winrm_ps_cmd): + mock_winrm_encode.assert_called_with("& {Get-Location} -param1 value1 arg1") + mock_run_ps.assert_called_with("xyz123==", is_b64=True) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script") + def test_run_ps_large_command_convert_to_script( + self, mock_run_ps_script, mock_winrm_ps_cmd + ): mock_run_ps_script.return_value = "expected" # max length of a command in powershelll - script = 'powershell -encodedcommand ' - script += '#' * (WINRM_MAX_CMD_LENGTH + 1 - len(script)) + script = "powershell -encodedcommand " + script += "#" * (WINRM_MAX_CMD_LENGTH + 1 - len(script)) mock_winrm_ps_cmd.return_value = script self._init_runner() - result = self._runner.run_ps('$PSVersionTable') + result = self._runner.run_ps("$PSVersionTable") self.assertEqual(result, "expected") - mock_run_ps_script.assert_called_with('$PSVersionTable', None) + mock_run_ps_script.assert_called_with("$PSVersionTable", None) - @mock.patch('winrm.Protocol') + @mock.patch("winrm.Protocol") def test__run_ps(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test__run_ps_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test__run_ps_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -702,238 +824,236 @@ def test__run_ps_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps") def test__run_ps_b64_default(self, mock_winrm_run_ps): - mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0, - timeout=False, - std_out='output1', - std_err='error1') + mock_winrm_run_ps.return_value = mock.MagicMock( + status_code=0, timeout=False, std_out="output1", std_err="error1" + ) self._init_runner() result = self._runner._run_ps("$PSVersionTable") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - mock_winrm_run_ps.assert_called_with(self._runner._session, - '$PSVersionTable', - env={}, - cwd=None, - is_b64=False) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + mock_winrm_run_ps.assert_called_with( + self._runner._session, "$PSVersionTable", env={}, cwd=None, is_b64=False + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps") def test__run_ps_b64_true(self, mock_winrm_run_ps): - mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0, - timeout=False, - std_out='output1', - std_err='error1') + mock_winrm_run_ps.return_value = mock.MagicMock( + status_code=0, timeout=False, std_out="output1", std_err="error1" + ) self._init_runner() result = self._runner._run_ps("xyz123", is_b64=True) - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - mock_winrm_run_ps.assert_called_with(self._runner._session, - 'xyz123', - env={}, - cwd=None, - is_b64=True) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + mock_winrm_run_ps.assert_called_with( + self._runner._session, "xyz123", env={}, cwd=None, is_b64=True + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script") def test__run_ps_script(self, mock_tmp_script, mock_run_ps): - mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1' - mock_run_ps.return_value = 'expected' + mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1" + mock_run_ps.return_value = "expected" self._init_runner() result = self._runner._run_ps_script("$PSVersionTable") - self.assertEqual(result, 'expected') - mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()', - '$PSVersionTable') - mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1}') + self.assertEqual(result, "expected") + mock_tmp_script.assert_called_with( + "[System.IO.Path]::GetTempPath()", "$PSVersionTable" + ) + mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1}") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script") def test__run_ps_script_with_params(self, mock_tmp_script, mock_run_ps): - mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1' - mock_run_ps.return_value = 'expected' + mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1" + mock_run_ps.return_value = "expected" self._init_runner() - result = self._runner._run_ps_script("Get-ChildItem", '-param1 value1 arg1') - self.assertEqual(result, 'expected') - mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()', - 'Get-ChildItem') - mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1} -param1 value1 arg1') + result = self._runner._run_ps_script("Get-ChildItem", "-param1 value1 arg1") + self.assertEqual(result, "expected") + mock_tmp_script.assert_called_with( + "[System.IO.Path]::GetTempPath()", "Get-ChildItem" + ) + mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1} -param1 value1 arg1") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") def test__run_ps_or_raise(self, mock_run_ps): - mock_run_ps.return_value = ('success', - { - 'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output', - 'stderr': 'error', - }, - None) + mock_run_ps.return_value = ( + "success", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output", + "stderr": "error", + }, + None, + ) self._init_runner() - result = self._runner._run_ps_or_raise('Get-ChildItem', 'my error message') - self.assertEqual(result, { - 'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output', - 'stderr': 'error', - }) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') + result = self._runner._run_ps_or_raise("Get-ChildItem", "my error message") + self.assertEqual( + result, + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output", + "stderr": "error", + }, + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") def test__run_ps_or_raise_raises_on_failure(self, mock_run_ps): - mock_run_ps.return_value = ('success', - { - 'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output', - 'stderr': 'error', - }, - None) + mock_run_ps.return_value = ( + "success", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output", + "stderr": "error", + }, + None, + ) self._init_runner() with self.assertRaises(RuntimeError): - self._runner._run_ps_or_raise('Get-ChildItem', 'my error message') + self._runner._run_ps_or_raise("Get-ChildItem", "my error message") def test_multireplace(self): - multireplace_map = {'a': 'x', - 'c': 'y', - 'aaa': 'z'} - result = self._runner._multireplace('aaaccaa', multireplace_map) - self.assertEqual(result, 'zyyxx') + multireplace_map = {"a": "x", "c": "y", "aaa": "z"} + result = self._runner._multireplace("aaaccaa", multireplace_map) + self.assertEqual(result, "zyyxx") def test_multireplace_powershell(self): - param_str = ( - '\n' - '\r' - '\t' - '\a' - '\b' - '\f' - '\v' - '"' - '\'' - '`' - '\0' - '$' - ) + param_str = "\n" "\r" "\t" "\a" "\b" "\f" "\v" '"' "'" "`" "\0" "$" result = self._runner._multireplace(param_str, PS_ESCAPE_SEQUENCES) - self.assertEqual(result, ( - '`n' - '`r' - '`t' - '`a' - '`b' - '`f' - '`v' - '`"' - '`\'' - '``' - '`0' - '`$' - )) + self.assertEqual( + result, ("`n" "`r" "`t" "`a" "`b" "`f" "`v" '`"' "`'" "``" "`0" "`$") + ) def test_param_to_ps_none(self): # test None/null param = None result = self._runner._param_to_ps(param) - self.assertEqual(result, '$null') + self.assertEqual(result, "$null") def test_param_to_ps_string(self): # test ascii - param_str = 'StackStorm 1234' + param_str = "StackStorm 1234" result = self._runner._param_to_ps(param_str) self.assertEqual(result, '"StackStorm 1234"') # test escaped - param_str = '\n\r\t' + param_str = "\n\r\t" result = self._runner._param_to_ps(param_str) self.assertEqual(result, '"`n`r`t"') def test_param_to_ps_bool(self): # test True result = self._runner._param_to_ps(True) - self.assertEqual(result, '$true') + self.assertEqual(result, "$true") # test False result = self._runner._param_to_ps(False) - self.assertEqual(result, '$false') + self.assertEqual(result, "$false") def test_param_to_ps_integer(self): result = self._runner._param_to_ps(9876) - self.assertEqual(result, '9876') + self.assertEqual(result, "9876") result = self._runner._param_to_ps(-765) - self.assertEqual(result, '-765') + self.assertEqual(result, "-765") def test_param_to_ps_float(self): result = self._runner._param_to_ps(98.76) - self.assertEqual(result, '98.76') + self.assertEqual(result, "98.76") result = self._runner._param_to_ps(-76.5) - self.assertEqual(result, '-76.5') + self.assertEqual(result, "-76.5") def test_param_to_ps_list(self): - input_list = ['StackStorm Test String', - '`\0$', - True, - 99] + input_list = ["StackStorm Test String", "`\0$", True, 99] result = self._runner._param_to_ps(input_list) self.assertEqual(result, '@("StackStorm Test String", "```0`$", $true, 99)') def test_param_to_ps_list_nested(self): - input_list = [['a'], ['b'], [['c']]] + input_list = [["a"], ["b"], [["c"]]] result = self._runner._param_to_ps(input_list) self.assertEqual(result, '@(@("a"), @("b"), @(@("c")))') def test_param_to_ps_dict(self): input_list = collections.OrderedDict( - [('str key', 'Value String'), - ('esc str\n', '\b\f\v"'), - (False, True), - (11, 99), - (18.3, 12.34)]) + [ + ("str key", "Value String"), + ("esc str\n", '\b\f\v"'), + (False, True), + (11, 99), + (18.3, 12.34), + ] + ) result = self._runner._param_to_ps(input_list) expected_str = ( '@{"str key" = "Value String"; ' - '"esc str`n" = "`b`f`v`\""; ' - '$false = $true; ' - '11 = 99; ' - '18.3 = 12.34}' + '"esc str`n" = "`b`f`v`""; ' + "$false = $true; " + "11 = 99; " + "18.3 = 12.34}" ) self.assertEqual(result, expected_str) def test_param_to_ps_dict_nexted(self): input_list = collections.OrderedDict( - [('a', {'deep_a': 'value'}), - ('b', {'deep_b': {'deep_deep_b': 'value'}})]) + [("a", {"deep_a": "value"}), ("b", {"deep_b": {"deep_deep_b": "value"}})] + ) result = self._runner._param_to_ps(input_list) expected_str = ( '@{"a" = @{"deep_a" = "value"}; ' @@ -945,21 +1065,22 @@ def test_param_to_ps_deep_nested_dict_outer(self): #### # dict as outer container input_dict = collections.OrderedDict( - [('a', [{'deep_a': 'value'}, - {'deep_b': ['a', 'b', 'c']}])]) + [("a", [{"deep_a": "value"}, {"deep_b": ["a", "b", "c"]}])] + ) result = self._runner._param_to_ps(input_dict) expected_str = ( - '@{"a" = @(@{"deep_a" = "value"}, ' - '@{"deep_b" = @("a", "b", "c")})}' + '@{"a" = @(@{"deep_a" = "value"}, ' '@{"deep_b" = @("a", "b", "c")})}' ) self.assertEqual(result, expected_str) def test_param_to_ps_deep_nested_list_outer(self): #### # list as outer container - input_list = [{'deep_a': 'value'}, - {'deep_b': ['a', 'b', 'c']}, - {'deep_c': [{'x': 'y'}]}] + input_list = [ + {"deep_a": "value"}, + {"deep_b": ["a", "b", "c"]}, + {"deep_c": [{"x": "y"}]}, + ] result = self._runner._param_to_ps(input_list) expected_str = ( '@(@{"deep_a" = "value"}, ' @@ -969,45 +1090,48 @@ def test_param_to_ps_deep_nested_list_outer(self): self.assertEqual(result, expected_str) def test_transform_params_to_ps(self): - positional_args = [1, 'a', '\n'] + positional_args = [1, "a", "\n"] named_args = collections.OrderedDict( - [('a', 'value1'), - ('b', True), - ('c', ['x', 'y']), - ('d', {'z': 'w'})] + [("a", "value1"), ("b", True), ("c", ["x", "y"]), ("d", {"z": "w"})] ) - result_pos, result_named = self._runner._transform_params_to_ps(positional_args, - named_args) - self.assertEqual(result_pos, ['1', '"a"', '"`n"']) - self.assertEqual(result_named, collections.OrderedDict([ - ('a', '"value1"'), - ('b', '$true'), - ('c', '@("x", "y")'), - ('d', '@{"z" = "w"}')])) + result_pos, result_named = self._runner._transform_params_to_ps( + positional_args, named_args + ) + self.assertEqual(result_pos, ["1", '"a"', '"`n"']) + self.assertEqual( + result_named, + collections.OrderedDict( + [ + ("a", '"value1"'), + ("b", "$true"), + ("c", '@("x", "y")'), + ("d", '@{"z" = "w"}'), + ] + ), + ) def test_transform_params_to_ps_none(self): positional_args = None named_args = None - result_pos, result_named = self._runner._transform_params_to_ps(positional_args, - named_args) + result_pos, result_named = self._runner._transform_params_to_ps( + positional_args, named_args + ) self.assertEqual(result_pos, None) self.assertEqual(result_named, None) def test_create_ps_params_string(self): - positional_args = [1, 'a', '\n'] + positional_args = [1, "a", "\n"] named_args = collections.OrderedDict( - [('-a', 'value1'), - ('-b', True), - ('-c', ['x', 'y']), - ('-d', {'z': 'w'})] + [("-a", "value1"), ("-b", True), ("-c", ["x", "y"]), ("-d", {"z": "w"})] ) result = self._runner.create_ps_params_string(positional_args, named_args) - self.assertEqual(result, - '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"') + self.assertEqual( + result, '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"' + ) def test_create_ps_params_string_none(self): positional_args = None diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py index 9ff36a1b47..78365a333b 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py @@ -23,23 +23,22 @@ class WinRmCommandRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmCommandRunnerTestCase, self).setUpClass() self._runner = winrm_command_runner.get_runner() def test_init(self): - runner = winrm_command_runner.WinRmCommandRunner('abcdef') + runner = winrm_command_runner.WinRmCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd') + @mock.patch("winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd") def test_run(self, mock_run_cmd): - mock_run_cmd.return_value = 'expected' + mock_run_cmd.return_value = "expected" - self._runner.runner_parameters = {'cmd': 'ipconfig /all'} + self._runner.runner_parameters = {"cmd": "ipconfig /all"} result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_cmd.assert_called_with('ipconfig /all') + self.assertEqual(result, "expected") + mock_run_cmd.assert_called_with("ipconfig /all") diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py index d6bae23e2c..90d9e95abd 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py @@ -23,23 +23,22 @@ class WinRmPsCommandRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmPsCommandRunnerTestCase, self).setUpClass() self._runner = winrm_ps_command_runner.get_runner() def test_init(self): - runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef') + runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps') + @mock.patch("winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps") def test_run(self, mock_run_ps): - mock_run_ps.return_value = 'expected' + mock_run_ps.return_value = "expected" - self._runner.runner_parameters = {'cmd': 'Get-ADUser stanley'} + self._runner.runner_parameters = {"cmd": "Get-ADUser stanley"} result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_ps.assert_called_with('Get-ADUser stanley') + self.assertEqual(result, "expected") + mock_run_ps.assert_called_with("Get-ADUser stanley") diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py index b3c1e14034..c1414c25e7 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py @@ -22,39 +22,41 @@ from winrm_runner import winrm_ps_script_runner from winrm_runner.winrm_base import WinRmBaseRunner -FIXTURES_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') +FIXTURES_PATH = os.path.join(os.path.dirname(__file__), "fixtures") POWERSHELL_SCRIPT_PATH = os.path.join(FIXTURES_PATH, "TestScript.ps1") class WinRmPsScriptRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmPsScriptRunnerTestCase, self).setUpClass() self._runner = winrm_ps_script_runner.get_runner() def test_init(self): - runner = winrm_ps_script_runner.WinRmPsScriptRunner('abcdef') + runner = winrm_ps_script_runner.WinRmPsScriptRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args') - @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps') + @mock.patch( + "winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args" + ) + @mock.patch("winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps") def test_run(self, mock_run_ps, mock_get_script_args): - mock_run_ps.return_value = 'expected' - pos_args = [1, 'abc'] + mock_run_ps.return_value = "expected" + pos_args = [1, "abc"] named_args = {"d": {"test": ["\r", True, 3]}} mock_get_script_args.return_value = (pos_args, named_args) self._runner.entry_point = POWERSHELL_SCRIPT_PATH self._runner.runner_parameters = {} - self._runner._kwarg_op = '-' + self._runner._kwarg_op = "-" result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_ps.assert_called_with('''[CmdletBinding()] + self.assertEqual(result, "expected") + mock_run_ps.assert_called_with( + """[CmdletBinding()] Param( [bool]$p_bool, [int]$p_integer, @@ -77,5 +79,6 @@ def test_run(self, mock_run_ps, mock_get_script_args): Write-Output "p_obj = $($p_obj | ConvertTo-Json -Compress)" Write-Output "p_pos0 = $p_pos0" Write-Output "p_pos1 = $p_pos1" -''', - '-d @{"test" = @("`r", $true, 3)} 1 "abc"') +""", + '-d @{"test" = @("`r", $true, 3)} 1 "abc"', + ) diff --git a/contrib/runners/winrm_runner/winrm_runner/__init__.py b/contrib/runners/winrm_runner/winrm_runner/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/contrib/runners/winrm_runner/winrm_runner/__init__.py +++ b/contrib/runners/winrm_runner/winrm_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py index fb26e49db6..9bebbedc7b 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py @@ -32,7 +32,7 @@ from winrm.exceptions import WinRMOperationTimeoutError __all__ = [ - 'WinRmBaseRunner', + "WinRmBaseRunner", ] LOG = logging.getLogger(__name__) @@ -49,7 +49,7 @@ RUNNER_USERNAME = "username" RUNNER_VERIFY_SSL = "verify_ssl_cert" -WINRM_DEFAULT_TMP_DIR_PS = '[System.IO.Path]::GetTempPath()' +WINRM_DEFAULT_TMP_DIR_PS = "[System.IO.Path]::GetTempPath()" # maximum cmdline length for systems >= Windows XP # https://support.microsoft.com/en-us/help/830473/command-prompt-cmd-exe-command-line-string-limitation WINRM_MAX_CMD_LENGTH = 8191 @@ -76,28 +76,28 @@ # Compiled list from the following sources: # https://ss64.com/ps/syntax-esc.html # https://www.techotopia.com/index.php/Windows_PowerShell_1.0_String_Quoting_and_Escape_Sequences#PowerShell_Special_Escape_Sequences -PS_ESCAPE_SEQUENCES = {'\n': '`n', - '\r': '`r', - '\t': '`t', - '\a': '`a', - '\b': '`b', - '\f': '`f', - '\v': '`v', - '"': '`"', - '\'': '`\'', - '`': '``', - '\0': '`0', - '$': '`$'} +PS_ESCAPE_SEQUENCES = { + "\n": "`n", + "\r": "`r", + "\t": "`t", + "\a": "`a", + "\b": "`b", + "\f": "`f", + "\v": "`v", + '"': '`"', + "'": "`'", + "`": "``", + "\0": "`0", + "$": "`$", +} class WinRmRunnerTimoutError(Exception): - def __init__(self, response): self.response = response class WinRmBaseRunner(ActionRunner): - def pre_run(self): super(WinRmBaseRunner, self).pre_run() @@ -107,12 +107,16 @@ def pre_run(self): self._username = self.runner_parameters[RUNNER_USERNAME] self._password = self.runner_parameters[RUNNER_PASSWORD] self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT, DEFAULT_TIMEOUT) - self._read_timeout = self._timeout + 1 # read_timeout must be > operation_timeout + self._read_timeout = ( + self._timeout + 1 + ) # read_timeout must be > operation_timeout # default to https port 5986 over ntlm self._port = self.runner_parameters.get(RUNNER_PORT, DEFAULT_PORT) self._scheme = self.runner_parameters.get(RUNNER_SCHEME, DEFAULT_SCHEME) - self._transport = self.runner_parameters.get(RUNNER_TRANSPORT, DEFAULT_TRANSPORT) + self._transport = self.runner_parameters.get( + RUNNER_TRANSPORT, DEFAULT_TRANSPORT + ) # if connecting to the HTTP port then we must use "http" as the scheme # in the URL @@ -120,10 +124,14 @@ def pre_run(self): self._scheme = "http" # construct the URL for connecting to WinRM on the host - self._winrm_url = "{}://{}:{}/wsman".format(self._scheme, self._host, self._port) + self._winrm_url = "{}://{}:{}/wsman".format( + self._scheme, self._host, self._port + ) # default to verifying SSL certs - self._verify_ssl = self.runner_parameters.get(RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL) + self._verify_ssl = self.runner_parameters.get( + RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL + ) self._server_cert_validation = "validate" if self._verify_ssl else "ignore" # additional parameters @@ -136,12 +144,14 @@ def _get_session(self): # cache session (only create if it doesn't exist yet) if not self._session: LOG.debug("Connecting via WinRM to url: {}".format(self._winrm_url)) - self._session = Session(self._winrm_url, - auth=(self._username, self._password), - transport=self._transport, - server_cert_validation=self._server_cert_validation, - operation_timeout_sec=self._timeout, - read_timeout_sec=self._read_timeout) + self._session = Session( + self._winrm_url, + auth=(self._username, self._password), + transport=self._transport, + server_cert_validation=self._server_cert_validation, + operation_timeout_sec=self._timeout, + read_timeout_sec=self._read_timeout, + ) return self._session def _winrm_get_command_output(self, protocol, shell_id, command_id): @@ -154,37 +164,46 @@ def _winrm_get_command_output(self, protocol, shell_id, command_id): while not command_done: # check if we need to timeout (StackStorm custom) current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if self._timeout and (elapsed_time > self._timeout): - raise WinRmRunnerTimoutError(Response((b''.join(stdout_buffer), - b''.join(stderr_buffer), - WINRM_TIMEOUT_EXIT_CODE))) + raise WinRmRunnerTimoutError( + Response( + ( + b"".join(stdout_buffer), + b"".join(stderr_buffer), + WINRM_TIMEOUT_EXIT_CODE, + ) + ) + ) # end stackstorm custom try: - stdout, stderr, return_code, command_done = \ - protocol._raw_get_command_output(shell_id, command_id) + ( + stdout, + stderr, + return_code, + command_done, + ) = protocol._raw_get_command_output(shell_id, command_id) stdout_buffer.append(stdout) stderr_buffer.append(stderr) except WinRMOperationTimeoutError: # this is an expected error when waiting for a long-running process, # just silently retry pass - return b''.join(stdout_buffer), b''.join(stderr_buffer), return_code + return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None): # NOTE: this is copied from pywinrm because it doesn't support # passing env and working_directory from the Session.run_cmd. # It also doesn't support timeouts. All of these things have been # added - shell_id = session.protocol.open_shell(env_vars=env, - working_directory=cwd) + shell_id = session.protocol.open_shell(env_vars=env, working_directory=cwd) command_id = session.protocol.run_command(shell_id, command, args) # try/catch is for custom timeout handing (StackStorm custom) try: - rs = Response(self._winrm_get_command_output(session.protocol, - shell_id, - command_id)) + rs = Response( + self._winrm_get_command_output(session.protocol, shell_id, command_id) + ) rs.timeout = False except WinRmRunnerTimoutError as e: rs = e.response @@ -195,37 +214,34 @@ def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None): return rs def _winrm_encode(self, script): - return b64encode(script.encode('utf_16_le')).decode('ascii') + return b64encode(script.encode("utf_16_le")).decode("ascii") def _winrm_ps_cmd(self, encoded_ps): - return 'powershell -encodedcommand {0}'.format(encoded_ps) + return "powershell -encodedcommand {0}".format(encoded_ps) def _winrm_run_ps(self, session, script, env=None, cwd=None, is_b64=False): # NOTE: this is copied from pywinrm because it doesn't support # passing env and working_directory from the Session.run_ps # encode the script in UTF only if it isn't passed in encoded - LOG.debug('_winrm_run_ps() - script size = {}'.format(len(script))) + LOG.debug("_winrm_run_ps() - script size = {}".format(len(script))) encoded_ps = script if is_b64 else self._winrm_encode(script) ps_cmd = self._winrm_ps_cmd(encoded_ps) - LOG.debug('_winrm_run_ps() - ps cmd size = {}'.format(len(ps_cmd))) - rs = self._winrm_run_cmd(session, - ps_cmd, - env=env, - cwd=cwd) + LOG.debug("_winrm_run_ps() - ps cmd size = {}".format(len(ps_cmd))) + rs = self._winrm_run_cmd(session, ps_cmd, env=env, cwd=cwd) if len(rs.std_err): # if there was an error message, clean it it up and make it human # readable if isinstance(rs.std_err, bytes): # decode bytes into utf-8 because of a bug in pywinrm # real fix is here: https://github.com/diyan/pywinrm/pull/222/files - rs.std_err = rs.std_err.decode('utf-8') + rs.std_err = rs.std_err.decode("utf-8") rs.std_err = session._clean_error_msg(rs.std_err) return rs def _translate_response(self, response): # check exit status for errors - succeeded = (response.status_code == exit_code_constants.SUCCESS_EXIT_CODE) + succeeded = response.status_code == exit_code_constants.SUCCESS_EXIT_CODE status = action_constants.LIVEACTION_STATUS_SUCCEEDED status_code = response.status_code if response.timeout: @@ -236,39 +252,46 @@ def _translate_response(self, response): # create result result = { - 'failed': not succeeded, - 'succeeded': succeeded, - 'return_code': status_code, - 'stdout': response.std_out, - 'stderr': response.std_err + "failed": not succeeded, + "succeeded": succeeded, + "return_code": status_code, + "stdout": response.std_out, + "stderr": response.std_err, } # Ensure stdout and stderr is always a string - if isinstance(result['stdout'], six.binary_type): - result['stdout'] = result['stdout'].decode('utf-8') + if isinstance(result["stdout"], six.binary_type): + result["stdout"] = result["stdout"].decode("utf-8") - if isinstance(result['stderr'], six.binary_type): - result['stderr'] = result['stderr'].decode('utf-8') + if isinstance(result["stderr"], six.binary_type): + result["stderr"] = result["stderr"].decode("utf-8") # automatically convert result stdout/stderr from JSON strings to # objects so they can be used natively return (status, jsonify.json_loads(result, RESULT_KEYS_TO_TRANSFORM), None) def _make_tmp_dir(self, parent): - LOG.debug("Creating temporary directory for WinRM script in parent: {}".format(parent)) + LOG.debug( + "Creating temporary directory for WinRM script in parent: {}".format(parent) + ) ps = """$parent = {parent} $name = [System.IO.Path]::GetRandomFileName() $path = Join-Path $parent $name New-Item -ItemType Directory -Path $path | Out-Null -$path""".format(parent=parent) - result = self._run_ps_or_raise(ps, ("Unable to make temporary directory for" - " powershell script")) +$path""".format( + parent=parent + ) + result = self._run_ps_or_raise( + ps, ("Unable to make temporary directory for" " powershell script") + ) # strip to remove trailing newline and whitespace (if any) - return result['stdout'].strip() + return result["stdout"].strip() def _rm_dir(self, directory): ps = 'Remove-Item -Force -Recurse -Path "{}"'.format(directory) - self._run_ps_or_raise(ps, "Unable to remove temporary directory for powershell script") + self._run_ps_or_raise( + ps, "Unable to remove temporary directory for powershell script" + ) def _upload(self, src_path_or_data, dst_path): src_data = None @@ -276,7 +299,7 @@ def _upload(self, src_path_or_data, dst_path): # if this is a path, then read the data from the path if os.path.exists(src_path_or_data): LOG.debug("WinRM uploading local file: {}".format(src_path_or_data)) - with open(src_path_or_data, 'r') as src_file: + with open(src_path_or_data, "r") as src_file: src_data = src_file.read() else: LOG.debug("WinRM uploading data from a string") @@ -285,14 +308,19 @@ def _upload(self, src_path_or_data, dst_path): # upload the data in chunks such that each chunk doesn't exceed the # max command size of the windows command line for i in range(0, len(src_data), WINRM_UPLOAD_CHUNK_SIZE_BYTES): - LOG.debug("WinRM uploading data bytes: {}-{}". - format(i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES))) - self._upload_chunk(dst_path, src_data[i:(i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)]) + LOG.debug( + "WinRM uploading data bytes: {}-{}".format( + i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES) + ) + ) + self._upload_chunk( + dst_path, src_data[i : (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)] + ) def _upload_chunk(self, dst_path, src_data): # adapted from https://github.com/diyan/pywinrm/issues/18 if not isinstance(src_data, six.binary_type): - src_data = src_data.encode('utf-8') + src_data = src_data.encode("utf-8") ps = """$filePath = "{dst_path}" $s = @" @@ -300,10 +328,11 @@ def _upload_chunk(self, dst_path, src_data): "@ $data = [System.Convert]::FromBase64String($s) Add-Content -value $data -encoding byte -path $filePath -""".format(dst_path=dst_path, - b64_data=base64.b64encode(src_data).decode('utf-8')) +""".format( + dst_path=dst_path, b64_data=base64.b64encode(src_data).decode("utf-8") + ) - LOG.debug('WinRM uploading chunk, size = {}'.format(len(ps))) + LOG.debug("WinRM uploading chunk, size = {}".format(len(ps))) self._run_ps_or_raise(ps, "Failed to upload chunk of powershell script") @contextmanager @@ -335,7 +364,7 @@ def run_cmd(self, cmd): def run_ps(self, script, params=None): # temporary directory for the powershell script if params: - powershell = '& {%s} %s' % (script, params) + powershell = "& {%s} %s" % (script, params) else: powershell = script encoded_ps = self._winrm_encode(powershell) @@ -346,9 +375,12 @@ def run_ps(self, script, params=None): # else we need to upload the script to a temporary file and execute it, # then remove the temporary file if len(ps_cmd) <= WINRM_MAX_CMD_LENGTH: - LOG.info(("WinRM powershell command size {} is > {}, the max size of a" - " powershell command. Converting to a script execution.") - .format(WINRM_MAX_CMD_LENGTH, len(ps_cmd))) + LOG.info( + ( + "WinRM powershell command size {} is > {}, the max size of a" + " powershell command. Converting to a script execution." + ).format(WINRM_MAX_CMD_LENGTH, len(ps_cmd)) + ) return self._run_ps(encoded_ps, is_b64=True) else: return self._run_ps_script(script, params) @@ -360,8 +392,9 @@ def _run_ps(self, powershell, is_b64=False): # connect session = self._get_session() # execute - response = self._winrm_run_ps(session, powershell, env=self._env, cwd=self._cwd, - is_b64=is_b64) + response = self._winrm_run_ps( + session, powershell, env=self._env, cwd=self._cwd, is_b64=is_b64 + ) # create triplet from WinRM response return self._translate_response(response) @@ -383,12 +416,12 @@ def _run_ps_or_raise(self, ps, error_msg): response = self._run_ps(ps) # response is a tuple: (status, result, None) result = response[1] - if result['failed']: - raise RuntimeError(("{}:\n" - "stdout = {}\n\n" - "stderr = {}").format(error_msg, - result['stdout'], - result['stderr'])) + if result["failed"]: + raise RuntimeError( + ("{}:\n" "stdout = {}\n\n" "stderr = {}").format( + error_msg, result["stdout"], result["stderr"] + ) + ) return result def _multireplace(self, string, replacements): @@ -407,7 +440,7 @@ def _multireplace(self, string, replacements): substrs = sorted(replacements, key=len, reverse=True) # Create a big OR regex that matches any of the substrings to replace - regexp = re.compile('|'.join([re.escape(s) for s in substrs])) + regexp = re.compile("|".join([re.escape(s) for s in substrs])) # For each match, look up the new string in the replacements return regexp.sub(lambda match: replacements[match.group(0)], string) @@ -426,8 +459,12 @@ def _param_to_ps(self, param): ps_str += ")" elif isinstance(param, dict): ps_str = "@{" - ps_str += "; ".join([(self._param_to_ps(k) + ' = ' + self._param_to_ps(v)) - for k, v in six.iteritems(param)]) + ps_str += "; ".join( + [ + (self._param_to_ps(k) + " = " + self._param_to_ps(v)) + for k, v in six.iteritems(param) + ] + ) ps_str += "}" else: ps_str = str(param) @@ -446,12 +483,15 @@ def _transform_params_to_ps(self, positional_args, named_args): def create_ps_params_string(self, positional_args, named_args): # convert the script parameters into powershell strings - positional_args, named_args = self._transform_params_to_ps(positional_args, - named_args) + positional_args, named_args = self._transform_params_to_ps( + positional_args, named_args + ) # concatenate them into a long string ps_params_str = "" if named_args: - ps_params_str += " " .join([(k + " " + v) for k, v in six.iteritems(named_args)]) + ps_params_str += " ".join( + [(k + " " + v) for k, v in six.iteritems(named_args)] + ) ps_params_str += " " if positional_args: ps_params_str += " ".join(positional_args) diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py index d09e5ce7d6..1239f3efd5 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py @@ -20,19 +20,14 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmCommandRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RUNNER_COMMAND = 'cmd' +RUNNER_COMMAND = "cmd" class WinRmCommandRunner(WinRmBaseRunner): - def run(self, action_parameters): cmd_command = self.runner_parameters[RUNNER_COMMAND] @@ -45,7 +40,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py index f49db2b09e..e6d0a37e2f 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py @@ -20,19 +20,14 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmPsCommandRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmPsCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RUNNER_COMMAND = 'cmd' +RUNNER_COMMAND = "cmd" class WinRmPsCommandRunner(WinRmBaseRunner): - def run(self, action_parameters): powershell_command = self.runner_parameters[RUNNER_COMMAND] @@ -45,7 +40,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py index 9f156bd8c9..ff162b7aee 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py @@ -21,23 +21,18 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmPsScriptRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmPsScriptRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) class WinRmPsScriptRunner(WinRmBaseRunner, ShellRunnerMixin): - def run(self, action_parameters): if not self.entry_point: - raise ValueError('Missing entry_point action metadata attribute') + raise ValueError("Missing entry_point action metadata attribute") # read in the script contents from the local file - with open(self.entry_point, 'r') as script_file: + with open(self.entry_point, "r") as script_file: ps_script = script_file.read() # extract script parameters specified in the action metadata file @@ -57,7 +52,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/lint-configs/python/.flake8 b/lint-configs/python/.flake8 index f3cc01b319..4edeebe162 100644 --- a/lint-configs/python/.flake8 +++ b/lint-configs/python/.flake8 @@ -2,7 +2,10 @@ max-line-length = 100 # L102 - apache license header enable-extensions = L101,L102 -ignore = E128,E402,E722,W504 +# We ignore some rules which conflict with black +# E203 - whitespace before ':' - in direct conflict with black rule +# W503 line break before binary operator - in direct conflict with black rule +ignore = E128,E402,E722,W504,E203,W503 exclude=*.egg/*,build,dist # Configuration for flake8-copyright extension diff --git a/pylint_plugins/api_models.py b/pylint_plugins/api_models.py index 398a664d40..4e14095f71 100644 --- a/pylint_plugins/api_models.py +++ b/pylint_plugins/api_models.py @@ -29,9 +29,7 @@ from astroid import scoped_nodes # A list of class names for which we want to skip the checks -CLASS_NAME_BLACKLIST = [ - 'ExecutionSpecificationAPI' -] +CLASS_NAME_BLACKLIST = ["ExecutionSpecificationAPI"] def register(linter): @@ -42,11 +40,11 @@ def transform(cls): if cls.name in CLASS_NAME_BLACKLIST: return - if cls.name.endswith('API') or 'schema' in cls.locals: + if cls.name.endswith("API") or "schema" in cls.locals: # This is a class which defines attributes in "schema" variable using json schema. # Those attributes are then assigned during run time inside the constructor fqdn = cls.qname() - module_name, class_name = fqdn.rsplit('.', 1) + module_name, class_name = fqdn.rsplit(".", 1) module = __import__(module_name, fromlist=[class_name]) actual_cls = getattr(module, class_name) @@ -57,29 +55,31 @@ def transform(cls): # Not a class we are interested in return - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - property_name = property_name.replace('-', '_') # Note: We do the same in Python code - property_type = property_data.get('type', None) + property_name = property_name.replace( + "-", "_" + ) # Note: We do the same in Python code + property_type = property_data.get("type", None) if isinstance(property_type, (list, tuple)): # Hack for attributes with multiple types (e.g. string, null) property_type = property_type[0] - if property_type == 'object': + if property_type == "object": node = nodes.Dict() - elif property_type == 'array': + elif property_type == "array": node = nodes.List() - elif property_type == 'integer': - node = scoped_nodes.builtin_lookup('int')[1][0] - elif property_type == 'number': - node = scoped_nodes.builtin_lookup('float')[1][0] - elif property_type == 'string': - node = scoped_nodes.builtin_lookup('str')[1][0] - elif property_type == 'boolean': - node = scoped_nodes.builtin_lookup('bool')[1][0] - elif property_type == 'null': - node = scoped_nodes.builtin_lookup('None')[1][0] + elif property_type == "integer": + node = scoped_nodes.builtin_lookup("int")[1][0] + elif property_type == "number": + node = scoped_nodes.builtin_lookup("float")[1][0] + elif property_type == "string": + node = scoped_nodes.builtin_lookup("str")[1][0] + elif property_type == "boolean": + node = scoped_nodes.builtin_lookup("bool")[1][0] + elif property_type == "null": + node = scoped_nodes.builtin_lookup("None")[1][0] else: # Unknown type node = astroid.ClassDef(property_name, None) diff --git a/pylint_plugins/db_models.py b/pylint_plugins/db_models.py index 241e9ea582..da9251462e 100644 --- a/pylint_plugins/db_models.py +++ b/pylint_plugins/db_models.py @@ -23,8 +23,7 @@ from astroid import nodes # A list of class names for which we want to skip the checks -CLASS_NAME_BLACKLIST = [ -] +CLASS_NAME_BLACKLIST = [] def register(linter): @@ -35,14 +34,14 @@ def transform(cls): if cls.name in CLASS_NAME_BLACKLIST: return - if cls.name == 'StormFoundationDB': + if cls.name == "StormFoundationDB": # _fields get added automagically by mongoengine - if '_fields' not in cls.locals: - cls.locals['_fields'] = [nodes.Dict()] + if "_fields" not in cls.locals: + cls.locals["_fields"] = [nodes.Dict()] - if cls.name.endswith('DB'): + if cls.name.endswith("DB"): # mongoengine explicitly declared "id" field on each class so we teach pylint about that - property_name = 'id' + property_name = "id" node = astroid.ClassDef(property_name, None) cls.locals[property_name] = [node] diff --git a/scripts/dist_utils.py b/scripts/dist_utils.py index ba73f554c6..c0af527b6b 100644 --- a/scripts/dist_utils.py +++ b/scripts/dist_utils.py @@ -47,17 +47,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -68,15 +68,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -85,10 +85,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -102,30 +104,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -135,8 +139,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -150,7 +154,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -159,14 +163,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/scripts/dist_utils_old.py b/scripts/dist_utils_old.py index 5dfadb1bef..da38f6edbf 100644 --- a/scripts/dist_utils_old.py +++ b/scripts/dist_utils_old.py @@ -35,17 +35,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" try: import pip from pip import __version__ as pip_version except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) try: @@ -57,28 +57,30 @@ try: from pip._internal.req.req_file import parse_requirements except ImportError as e: - print('Failed to import parse_requirements from pip: %s' % (text_type(e))) - print('Using pip: %s' % (str(pip_version))) + print("Failed to import parse_requirements from pip: %s" % (text_type(e))) + print("Using pip: %s" % (str(pip_version))) sys.exit(1) __all__ = [ - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) @@ -90,7 +92,7 @@ def fetch_requirements(requirements_file_path): reqs = [] for req in parse_requirements(requirements_file_path, session=False): # Note: req.url was used before 9.0.0 and req.link is used in all the recent versions - link = getattr(req, 'link', getattr(req, 'url', None)) + link = getattr(req, "link", getattr(req, "url", None)) if link: links.append(str(link)) reqs.append(str(req.req)) @@ -104,7 +106,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -113,14 +115,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/scripts/fixate-requirements.py b/scripts/fixate-requirements.py index dd5c8d2505..4277c986f8 100755 --- a/scripts/fixate-requirements.py +++ b/scripts/fixate-requirements.py @@ -43,18 +43,18 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 OSCWD = os.path.abspath(os.curdir) -GET_PIP = ' curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = " curl https://bootstrap.pypa.io/get-pip.py | python" try: import pip from pip import __version__ as pip_version except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) try: @@ -66,24 +66,43 @@ try: from pip._internal.req.req_file import parse_requirements except ImportError as e: - print('Failed to import parse_requirements from pip: %s' % (text_type(e))) - print('Using pip: %s' % (str(pip_version))) + print("Failed to import parse_requirements from pip: %s" % (text_type(e))) + print("Using pip: %s" % (str(pip_version))) sys.exit(1) def parse_args(): - parser = argparse.ArgumentParser(description='Tool for requirements.txt generation.') - parser.add_argument('-s', '--source-requirements', nargs='+', - required=True, - help='Specify paths to requirements file(s). ' - 'In case several requirements files are given their content is merged.') - parser.add_argument('-f', '--fixed-requirements', required=True, - help='Specify path to fixed-requirements.txt file.') - parser.add_argument('-o', '--output-file', default='requirements.txt', - help='Specify path to the resulting requirements file.') - parser.add_argument('--skip', default=None, - help=('Comma delimited list of requirements to not ' - 'include in the generated file.')) + parser = argparse.ArgumentParser( + description="Tool for requirements.txt generation." + ) + parser.add_argument( + "-s", + "--source-requirements", + nargs="+", + required=True, + help="Specify paths to requirements file(s). " + "In case several requirements files are given their content is merged.", + ) + parser.add_argument( + "-f", + "--fixed-requirements", + required=True, + help="Specify path to fixed-requirements.txt file.", + ) + parser.add_argument( + "-o", + "--output-file", + default="requirements.txt", + help="Specify path to the resulting requirements file.", + ) + parser.add_argument( + "--skip", + default=None, + help=( + "Comma delimited list of requirements to not " + "include in the generated file." + ), + ) if len(sys.argv) < 2: parser.print_help() sys.exit(1) @@ -91,9 +110,11 @@ def parse_args(): def check_pip_version(): - if StrictVersion(pip.__version__) < StrictVersion('6.1.0'): - print("Upgrade pip, your version `{0}' " - "is outdated:\n".format(pip.__version__), GET_PIP) + if StrictVersion(pip.__version__) < StrictVersion("6.1.0"): + print( + "Upgrade pip, your version `{0}' " "is outdated:\n".format(pip.__version__), + GET_PIP, + ) sys.exit(1) @@ -129,13 +150,14 @@ def merge_source_requirements(sources): elif req.link: merged_requirements.append(req) else: - raise RuntimeError('Unexpected requirement {0}'.format(req)) + raise RuntimeError("Unexpected requirement {0}".format(req)) return merged_requirements -def write_requirements(sources=None, fixed_requirements=None, output_file=None, - skip=None): +def write_requirements( + sources=None, fixed_requirements=None, output_file=None, skip=None +): """ Write resulting requirements taking versions from the fixed_requirements. """ @@ -153,7 +175,9 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, continue if project_name in fixedreq_hash: - raise ValueError('Duplicate definition for dependency "%s"' % (project_name)) + raise ValueError( + 'Duplicate definition for dependency "%s"' % (project_name) + ) fixedreq_hash[project_name] = req @@ -169,7 +193,7 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, rline = str(req.link) if req.editable: - rline = '-e %s' % (rline) + rline = "-e %s" % (rline) elif req.req: project = req.name req_obj = fixedreq_hash.get(project, req) @@ -184,30 +208,40 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, # Sort the lines to guarantee a stable order lines_to_write = sorted(lines_to_write) - data = '\n'.join(lines_to_write) + '\n' - with open(output_file, 'w') as fp: - fp.write('# Don\'t edit this file. It\'s generated automatically!\n') - fp.write('# If you want to update global dependencies, modify fixed-requirements.txt\n') - fp.write('# and then run \'make requirements\' to update requirements.txt for all\n') - fp.write('# components.\n') - fp.write('# If you want to update depdencies for a single component, modify the\n') - fp.write('# in-requirements.txt for that component and then run \'make requirements\' to\n') - fp.write('# update the component requirements.txt\n') + data = "\n".join(lines_to_write) + "\n" + with open(output_file, "w") as fp: + fp.write("# Don't edit this file. It's generated automatically!\n") + fp.write( + "# If you want to update global dependencies, modify fixed-requirements.txt\n" + ) + fp.write( + "# and then run 'make requirements' to update requirements.txt for all\n" + ) + fp.write("# components.\n") + fp.write( + "# If you want to update depdencies for a single component, modify the\n" + ) + fp.write( + "# in-requirements.txt for that component and then run 'make requirements' to\n" + ) + fp.write("# update the component requirements.txt\n") fp.write(data) - print('Requirements written to: {0}'.format(output_file)) + print("Requirements written to: {0}".format(output_file)) -if __name__ == '__main__': +if __name__ == "__main__": check_pip_version() args = parse_args() - if args['skip']: - skip = args['skip'].split(',') + if args["skip"]: + skip = args["skip"].split(",") else: skip = None - write_requirements(sources=args['source_requirements'], - fixed_requirements=args['fixed_requirements'], - output_file=args['output_file'], - skip=skip) + write_requirements( + sources=args["source_requirements"], + fixed_requirements=args["fixed_requirements"], + output_file=args["output_file"], + skip=skip, + ) diff --git a/st2actions/dist_utils.py b/st2actions/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2actions/dist_utils.py +++ b/st2actions/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2actions/setup.py b/st2actions/setup.py index 6fcb2cde92..a4e8c12790 100644 --- a/st2actions/setup.py +++ b/st2actions/setup.py @@ -23,9 +23,9 @@ from dist_utils import apply_vagrant_workaround from st2actions import __version__ -ST2_COMPONENT = 'st2actions' +ST2_COMPONENT = "st2actions" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,21 +33,23 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2actionrunner', - 'bin/st2notifier', - 'bin/st2workflowengine', - 'bin/st2scheduler', - ] + "bin/st2actionrunner", + "bin/st2notifier", + "bin/st2workflowengine", + "bin/st2scheduler", + ], ) diff --git a/st2actions/st2actions/__init__.py b/st2actions/st2actions/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2actions/st2actions/__init__.py +++ b/st2actions/st2actions/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2actions/st2actions/cmd/actionrunner.py b/st2actions/st2actions/cmd/actionrunner.py index 457bf45e03..6aa339115a 100644 --- a/st2actions/st2actions/cmd/actionrunner.py +++ b/st2actions/st2actions/cmd/actionrunner.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -30,15 +31,12 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup_sigterm_handler(): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and allow for component cleanup. sys.exit(0) @@ -49,18 +47,22 @@ def sigterm_handler(signum=None, frame=None): def _setup(): - capabilities = { - 'name': 'actionrunner', - 'type': 'passive' - } - common_setup(service='actionrunner', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "actionrunner", "type": "passive"} + common_setup( + service="actionrunner", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) _setup_sigterm_handler() def _run_worker(): - LOG.info('(PID=%s) Worker started.', os.getpid()) + LOG.info("(PID=%s) Worker started.", os.getpid()) action_worker = worker.get_worker() @@ -68,20 +70,20 @@ def _run_worker(): action_worker.start() action_worker.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Worker stopped.', os.getpid()) + LOG.info("(PID=%s) Worker stopped.", os.getpid()) errors = False try: action_worker.shutdown() except: - LOG.exception('Unable to shutdown worker.') + LOG.exception("Unable to shutdown worker.") errors = True if errors: return 1 except: - LOG.exception('(PID=%s) Worker unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Worker unexpectedly stopped.", os.getpid()) return 1 return 0 @@ -98,7 +100,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Worker quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Worker quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/scheduler.py b/st2actions/st2actions/cmd/scheduler.py index b3c972b654..df6dd768db 100644 --- a/st2actions/st2actions/cmd/scheduler.py +++ b/st2actions/st2actions/cmd/scheduler.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -28,9 +29,7 @@ from st2common.service_setup import teardown as common_teardown from st2common.service_setup import setup as common_setup -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -46,23 +45,27 @@ def sigterm_handler(signum=None, frame=None): def _setup(): - capabilities = { - 'name': 'scheduler', - 'type': 'passive' - } - common_setup(service='scheduler', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "scheduler", "type": "passive"} + common_setup( + service="scheduler", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) _setup_sigterm_handler() def _run_scheduler(): - LOG.info('(PID=%s) Scheduler started.', os.getpid()) + LOG.info("(PID=%s) Scheduler started.", os.getpid()) # Lazy load these so that decorator metrics are in place from st2actions.scheduler import ( handler as scheduler_handler, - entrypoint as scheduler_entrypoint + entrypoint as scheduler_entrypoint, ) handler = scheduler_handler.get_handler() @@ -73,14 +76,18 @@ def _run_scheduler(): try: handler._cleanup_policy_delayed() except Exception: - LOG.exception('(PID=%s) Scheduler unable to perform migration cleanup.', os.getpid()) + LOG.exception( + "(PID=%s) Scheduler unable to perform migration cleanup.", os.getpid() + ) # TODO: Remove this try block for _fix_missing_action_execution_id in v3.2. # This is a temporary fix to auto-populate action_execution_id. try: handler._fix_missing_action_execution_id() except Exception: - LOG.exception('(PID=%s) Scheduler unable to populate action_execution_id.', os.getpid()) + LOG.exception( + "(PID=%s) Scheduler unable to populate action_execution_id.", os.getpid() + ) try: handler.start() @@ -89,7 +96,7 @@ def _run_scheduler(): # Wait on handler first since entrypoint is more durable. handler.wait() or entrypoint.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Scheduler stopped.', os.getpid()) + LOG.info("(PID=%s) Scheduler stopped.", os.getpid()) errors = False @@ -97,13 +104,13 @@ def _run_scheduler(): handler.shutdown() entrypoint.shutdown() except: - LOG.exception('Unable to shutdown scheduler.') + LOG.exception("Unable to shutdown scheduler.") errors = True if errors: return 1 except: - LOG.exception('(PID=%s) Scheduler unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Scheduler unexpectedly stopped.", os.getpid()) try: handler.shutdown() @@ -127,7 +134,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Scheduler quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Scheduler quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/st2notifier.py b/st2actions/st2actions/cmd/st2notifier.py index fdf74f5bf1..7f1ccc7222 100644 --- a/st2actions/st2actions/cmd/st2notifier.py +++ b/st2actions/st2actions/cmd/st2notifier.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -27,29 +28,31 @@ from st2actions.notifier import config from st2actions.notifier import notifier -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup(): - capabilities = { - 'name': 'notifier', - 'type': 'passive' - } - common_setup(service='notifier', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "notifier", "type": "passive"} + common_setup( + service="notifier", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) def _run_worker(): - LOG.info('(PID=%s) Actions notifier started.', os.getpid()) + LOG.info("(PID=%s) Actions notifier started.", os.getpid()) actions_notifier = notifier.get_notifier() try: actions_notifier.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Actions notifier stopped.', os.getpid()) + LOG.info("(PID=%s) Actions notifier stopped.", os.getpid()) actions_notifier.shutdown() return 0 @@ -65,7 +68,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Results tracker quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Results tracker quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/workflow_engine.py b/st2actions/st2actions/cmd/workflow_engine.py index 361d6ce9e1..f51296b4b0 100644 --- a/st2actions/st2actions/cmd/workflow_engine.py +++ b/st2actions/st2actions/cmd/workflow_engine.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -32,15 +33,12 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def setup_sigterm_handler(): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and allow for component cleanup. sys.exit(0) @@ -51,35 +49,32 @@ def sigterm_handler(signum=None, frame=None): def setup(): - capabilities = { - 'name': 'workflowengine', - 'type': 'passive' - } + capabilities = {"name": "workflowengine", "type": "passive"} common_setup( - service='workflow_engine', + service="workflow_engine", config=config, setup_db=True, register_mq_exchanges=True, register_signal_handlers=True, service_registry=True, - capabilities=capabilities + capabilities=capabilities, ) setup_sigterm_handler() def run_server(): - LOG.info('(PID=%s) Workflow engine started.', os.getpid()) + LOG.info("(PID=%s) Workflow engine started.", os.getpid()) engine = workflows.get_engine() try: engine.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Workflow engine stopped.', os.getpid()) + LOG.info("(PID=%s) Workflow engine stopped.", os.getpid()) engine.shutdown() except: - LOG.exception('(PID=%s) Workflow engine unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Workflow engine unexpectedly stopped.", os.getpid()) return 1 return 0 @@ -97,7 +92,7 @@ def main(): sys.exit(exit_code) except Exception: traceback.print_exc() - LOG.exception('(PID=%s) Workflow engine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Workflow engine quit due to exception.", os.getpid()) return 1 finally: teardown() diff --git a/st2actions/st2actions/config.py b/st2actions/st2actions/config.py index b4e83a5306..14dc2c4f58 100644 --- a/st2actions/st2actions/config.py +++ b/st2actions/st2actions/config.py @@ -28,8 +28,11 @@ def parse_args(args=None): - CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): diff --git a/st2actions/st2actions/container/base.py b/st2actions/st2actions/container/base.py index a350a3dd69..7f2b50f0c7 100644 --- a/st2actions/st2actions/container/base.py +++ b/st2actions/st2actions/container/base.py @@ -30,8 +30,8 @@ from st2common.models.system.action import ResolvedActionParameters from st2common.persistence.execution import ActionExecution from st2common.services import access, executions, queries -from st2common.util.action_db import (get_action_by_ref, get_runnertype_by_name) -from st2common.util.action_db import (update_liveaction_status, get_liveaction_by_id) +from st2common.util.action_db import get_action_by_ref, get_runnertype_by_name +from st2common.util.action_db import update_liveaction_status, get_liveaction_by_id from st2common.util import param as param_utils from st2common.util.config_loader import ContentPackConfigLoader from st2common.metrics.base import CounterWithTimer @@ -42,30 +42,28 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'RunnerContainer', - 'get_runner_container' -] +__all__ = ["RunnerContainer", "get_runner_container"] class RunnerContainer(object): - def dispatch(self, liveaction_db): action_db = get_action_by_ref(liveaction_db.action) if not action_db: - raise Exception('Action %s not found in DB.' % (liveaction_db.action)) + raise Exception("Action %s not found in DB." % (liveaction_db.action)) - liveaction_db.context['pack'] = action_db.pack + liveaction_db.context["pack"] = action_db.pack - runner_type_db = get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = get_runnertype_by_name(action_db.runner_type["name"]) - extra = {'liveaction_db': liveaction_db, 'runner_type_db': runner_type_db} - LOG.info('Dispatching Action to a runner', extra=extra) + extra = {"liveaction_db": liveaction_db, "runner_type_db": runner_type_db} + LOG.info("Dispatching Action to a runner", extra=extra) # Get runner instance. runner = self._get_runner(runner_type_db, action_db, liveaction_db) - LOG.debug('Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner) + LOG.debug( + 'Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner + ) # Process the request. funcs = { @@ -74,12 +72,12 @@ def dispatch(self, liveaction_db): action_constants.LIVEACTION_STATUS_RUNNING: self._do_run, action_constants.LIVEACTION_STATUS_CANCELING: self._do_cancel, action_constants.LIVEACTION_STATUS_PAUSING: self._do_pause, - action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume + action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume, } if liveaction_db.status not in funcs: raise actionrunner.ActionRunnerDispatchError( - 'Action runner is unable to dispatch the liveaction because it is ' + "Action runner is unable to dispatch the liveaction because it is " 'in an unsupported status of "%s".' % liveaction_db.status ) @@ -94,7 +92,8 @@ def _do_run(self, runner): runner.auth_token = self._create_auth_token( context=runner.context, action_db=runner.action, - liveaction_db=runner.liveaction) + liveaction_db=runner.liveaction, + ) try: # Finalized parameters are resolved and then rendered. This process could @@ -104,13 +103,14 @@ def _do_run(self, runner): runner.runner_type.runner_parameters, runner.action.parameters, runner.liveaction.parameters, - runner.liveaction.context) + runner.liveaction.context, + ) runner.runner_parameters = runner_params except ParamException as e: raise actionrunner.ActionRunnerException(six.text_type(e)) - LOG.debug('Performing pre-run for runner: %s', runner.runner_id) + LOG.debug("Performing pre-run for runner: %s", runner.runner_id) runner.pre_run() # Mask secret parameters in the log context @@ -118,90 +118,117 @@ def _do_run(self, runner): action_db=runner.action, runner_type_db=runner.runner_type, runner_parameters=runner_params, - action_parameters=action_params) + action_parameters=action_params, + ) - extra = {'runner': runner, 'parameters': resolved_action_params} - LOG.debug('Performing run for runner: %s' % (runner.runner_id), extra=extra) + extra = {"runner": runner, "parameters": resolved_action_params} + LOG.debug("Performing run for runner: %s" % (runner.runner_id), extra=extra) - with CounterWithTimer(key='action.executions'): - with CounterWithTimer(key='action.%s.executions' % (runner.action.ref)): + with CounterWithTimer(key="action.executions"): + with CounterWithTimer(key="action.%s.executions" % (runner.action.ref)): (status, result, context) = runner.run(action_params) result = jsonify.try_loads(result) action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES - if (isinstance(runner, PollingAsyncActionRunner) and - runner.is_polling_enabled() and not action_completed): + if ( + isinstance(runner, PollingAsyncActionRunner) + and runner.is_polling_enabled() + and not action_completed + ): queries.setup_query(runner.liveaction.id, runner.runner_type, context) except: - LOG.exception('Failed to run action.') + LOG.exception("Failed to run action.") _, ex, tb = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = None finally: # Log action completion - extra = {'result': result, 'status': status} + extra = {"result": result, "status": status} LOG.debug('Action "%s" completed.' % (runner.action.name), extra=extra) # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) + LOG.debug("Performing post_run for runner: %s", runner.runner_id) runner.post_run(status=status, result=result) - LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result}) - LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction}) + LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result}) + LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction}) return runner.liveaction def _do_cancel(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing cancel for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing cancel for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.cancel() # Update the final status of liveaction and corresponding action execution. # The status is updated here because we want to keep the workflow running # as is if the cancel operation failed. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} - LOG.exception('Failed to cancel action %s.' % (runner.liveaction.id), extra=result) + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } + LOG.exception( + "Failed to cancel action %s." % (runner.liveaction.id), extra=result + ) finally: # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) - result = {'error': 'Execution canceled by user.'} + LOG.debug("Performing post_run for runner: %s", runner.runner_id) + result = {"error": "Execution canceled by user."} runner.post_run(status=runner.liveaction.status, result=result) return runner.liveaction def _do_pause(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing pause for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing pause for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.pause() except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. status = action_constants.LIVEACTION_STATUS_FAILED - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = runner.liveaction.context - LOG.exception('Failed to pause action %s.' % (runner.liveaction.id), extra=result) + LOG.exception( + "Failed to pause action %s." % (runner.liveaction.id), extra=result + ) finally: # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) @@ -210,35 +237,47 @@ def _do_pause(self, runner): def _do_resume(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing resume for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing resume for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.resume() result = jsonify.try_loads(result) action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES - if (isinstance(runner, PollingAsyncActionRunner) and - runner.is_polling_enabled() and not action_completed): + if ( + isinstance(runner, PollingAsyncActionRunner) + and runner.is_polling_enabled() + and not action_completed + ): queries.setup_query(runner.liveaction.id, runner.runner_type, context) except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. status = action_constants.LIVEACTION_STATUS_FAILED - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = runner.liveaction.context - LOG.exception('Failed to resume action %s.' % (runner.liveaction.id), extra=result) + LOG.exception( + "Failed to resume action %s." % (runner.liveaction.id), extra=result + ) finally: # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) + LOG.debug("Performing post_run for runner: %s", runner.runner_id) runner.post_run(status=status, result=result) - LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result}) - LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction}) + LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result}) + LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction}) return runner.liveaction @@ -260,7 +299,7 @@ def _clean_up_auth_token(self, runner, status): try: self._delete_auth_token(runner.auth_token) except: - LOG.exception('Unable to clean-up auth_token.') + LOG.exception("Unable to clean-up auth_token.") return True @@ -273,8 +312,8 @@ def _update_live_action_db(self, liveaction_id, status, result, context): liveaction_db = get_liveaction_by_id(liveaction_id) state_changed = ( - liveaction_db.status != status and - liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES + liveaction_db.status != status + and liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES ) if status in action_constants.LIVEACTION_COMPLETED_STATES: @@ -287,64 +326,69 @@ def _update_live_action_db(self, liveaction_id, status, result, context): result=result, context=context, end_timestamp=end_timestamp, - liveaction_db=liveaction_db + liveaction_db=liveaction_db, ) return (liveaction_db, state_changed) def _update_status(self, liveaction_id, status, result, context): try: - LOG.debug('Setting status: %s for liveaction: %s', status, liveaction_id) + LOG.debug("Setting status: %s for liveaction: %s", status, liveaction_id) liveaction_db, state_changed = self._update_live_action_db( - liveaction_id, status, result, context) + liveaction_id, status, result, context + ) except Exception as e: LOG.exception( - 'Cannot update liveaction ' - '(id: %s, status: %s, result: %s).' % ( - liveaction_id, status, result) + "Cannot update liveaction " + "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result) ) raise e try: executions.update_execution(liveaction_db, publish=state_changed) - extra = {'liveaction_db': liveaction_db} - LOG.debug('Updated liveaction after run', extra=extra) + extra = {"liveaction_db": liveaction_db} + LOG.debug("Updated liveaction after run", extra=extra) except Exception as e: LOG.exception( - 'Cannot update action execution for liveaction ' - '(id: %s, status: %s, result: %s).' % ( - liveaction_id, status, result) + "Cannot update action execution for liveaction " + "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result) ) raise e return liveaction_db def _get_entry_point_abs_path(self, pack, entry_point): - return content_utils.get_entry_point_abs_path(pack=pack, entry_point=entry_point) + return content_utils.get_entry_point_abs_path( + pack=pack, entry_point=entry_point + ) def _get_action_libs_abs_path(self, pack, entry_point): - return content_utils.get_action_libs_abs_path(pack=pack, entry_point=entry_point) + return content_utils.get_action_libs_abs_path( + pack=pack, entry_point=entry_point + ) def _get_rerun_reference(self, context): - execution_id = context.get('re-run', {}).get('ref') + execution_id = context.get("re-run", {}).get("ref") return ActionExecution.get_by_id(execution_id) if execution_id else None def _get_runner(self, runner_type_db, action_db, liveaction_db): - resolved_entry_point = self._get_entry_point_abs_path(action_db.pack, action_db.entry_point) - context = getattr(liveaction_db, 'context', dict()) - user = context.get('user', cfg.CONF.system_user.user) + resolved_entry_point = self._get_entry_point_abs_path( + action_db.pack, action_db.entry_point + ) + context = getattr(liveaction_db, "context", dict()) + user = context.get("user", cfg.CONF.system_user.user) config = None # Note: Right now configs are only supported by the Python runner actions - if (runner_type_db.name == 'python-script' or - runner_type_db.runner_module == 'python_runner'): - LOG.debug('Loading config from pack for python runner.') + if ( + runner_type_db.name == "python-script" + or runner_type_db.runner_module == "python_runner" + ): + LOG.debug("Loading config from pack for python runner.") config_loader = ContentPackConfigLoader(pack_name=action_db.pack, user=user) config = config_loader.get_config() - runner = get_runner( - name=runner_type_db.name, - config=config) + runner = get_runner(name=runner_type_db.name, config=config) # TODO: Pass those arguments to the constructor instead of late # assignment, late assignment is awful @@ -357,13 +401,16 @@ def _get_runner(self, runner_type_db, action_db, liveaction_db): runner.execution_id = str(runner.execution.id) runner.entry_point = resolved_entry_point runner.context = context - runner.callback = getattr(liveaction_db, 'callback', dict()) - runner.libs_dir_path = self._get_action_libs_abs_path(action_db.pack, - action_db.entry_point) + runner.callback = getattr(liveaction_db, "callback", dict()) + runner.libs_dir_path = self._get_action_libs_abs_path( + action_db.pack, action_db.entry_point + ) # For re-run, get the ActionExecutionDB in which the re-run is based on. - rerun_ref_id = runner.context.get('re-run', {}).get('ref') - runner.rerun_ex_ref = ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None + rerun_ref_id = runner.context.get("re-run", {}).get("ref") + runner.rerun_ex_ref = ( + ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None + ) return runner @@ -371,19 +418,20 @@ def _create_auth_token(self, context, action_db, liveaction_db): if not context: return None - user = context.get('user', None) + user = context.get("user", None) if not user: return None metadata = { - 'service': 'actions_container', - 'action_name': action_db.name, - 'live_action_id': str(liveaction_db.id) - + "service": "actions_container", + "action_name": action_db.name, + "live_action_id": str(liveaction_db.id), } ttl = cfg.CONF.auth.service_token_ttl - token_db = access.create_token(username=user, ttl=ttl, metadata=metadata, service=True) + token_db = access.create_token( + username=user, ttl=ttl, metadata=metadata, service=True + ) return token_db def _delete_auth_token(self, auth_token): diff --git a/st2actions/st2actions/notifier/config.py b/st2actions/st2actions/notifier/config.py index 6c0162f310..0322179bbc 100644 --- a/st2actions/st2actions/notifier/config.py +++ b/st2actions/st2actions/notifier/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,11 +50,13 @@ def _register_common_opts(): def _register_notifier_opts(): notifier_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.notifier.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.notifier.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(notifier_opts, group='notifier') + CONF.register_opts(notifier_opts, group="notifier") register_opts() diff --git a/st2actions/st2actions/notifier/notifier.py b/st2actions/st2actions/notifier/notifier.py index 37db830e52..ea1a537733 100644 --- a/st2actions/st2actions/notifier/notifier.py +++ b/st2actions/st2actions/notifier/notifier.py @@ -42,22 +42,23 @@ from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX from st2common.constants.action import ACTION_PARAMETERS_KV_PREFIX from st2common.constants.action import ACTION_RESULTS_KV_PREFIX -from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, SYSTEM_SCOPE, DATASTORE_PARENT_SCOPE +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + SYSTEM_SCOPE, + DATASTORE_PARENT_SCOPE, +) from st2common.services.keyvalues import KeyValueLookup from st2common.transport.queues import NOTIFIER_ACTIONUPDATE_WORK_QUEUE from st2common.metrics.base import CounterWithTimer from st2common.metrics.base import Timer -__all__ = [ - 'Notifier', - 'get_notifier' -] +__all__ = ["Notifier", "get_notifier"] LOG = logging.getLogger(__name__) # XXX: Fix this nasty positional dependency. -ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0] -NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1] +ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0] +NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1] class Notifier(consumers.MessageHandler): @@ -69,35 +70,40 @@ def __init__(self, connection, queues, trigger_dispatcher=None): trigger_dispatcher = TriggerDispatcher(LOG) self._trigger_dispatcher = trigger_dispatcher self._notify_trigger = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER_TYPE['pack'], - name=NOTIFY_TRIGGER_TYPE['name']) + pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"] + ) self._action_trigger = ResourceReference.to_string_reference( - pack=ACTION_TRIGGER_TYPE['pack'], - name=ACTION_TRIGGER_TYPE['name']) + pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"] + ) - @CounterWithTimer(key='notifier.action.executions') + @CounterWithTimer(key="notifier.action.executions") def process(self, execution_db): execution_id = str(execution_db.id) - extra = {'execution': execution_db} + extra = {"execution": execution_db} LOG.debug('Processing action execution "%s".', execution_id, extra=extra) # Get the corresponding liveaction record. - liveaction_db = LiveAction.get_by_id(execution_db.liveaction['id']) + liveaction_db = LiveAction.get_by_id(execution_db.liveaction["id"]) if execution_db.status in LIVEACTION_COMPLETED_STATES: # If the action execution is executed under an orquesta workflow, policies for the # action execution will be applied by the workflow engine. A policy may affect the # final state of the action execution thereby impacting the state of the workflow. - if not workflow_service.is_action_execution_under_workflow_context(execution_db): - with CounterWithTimer(key='notifier.apply_post_run_policies'): + if not workflow_service.is_action_execution_under_workflow_context( + execution_db + ): + with CounterWithTimer(key="notifier.apply_post_run_policies"): policy_service.apply_post_run_policies(liveaction_db) if liveaction_db.notify: - with CounterWithTimer(key='notifier.notify_trigger.post'): - self._post_notify_triggers(liveaction_db=liveaction_db, - execution_db=execution_db) + with CounterWithTimer(key="notifier.notify_trigger.post"): + self._post_notify_triggers( + liveaction_db=liveaction_db, execution_db=execution_db + ) - self._post_generic_trigger(liveaction_db=liveaction_db, execution_db=execution_db) + self._post_generic_trigger( + liveaction_db=liveaction_db, execution_db=execution_db + ) def _get_execution_for_liveaction(self, liveaction): execution = ActionExecution.get(liveaction__id=str(liveaction.id)) @@ -108,39 +114,52 @@ def _get_execution_for_liveaction(self, liveaction): return execution def _post_notify_triggers(self, liveaction_db=None, execution_db=None): - notify = getattr(liveaction_db, 'notify', None) + notify = getattr(liveaction_db, "notify", None) if not notify: return if notify.on_complete: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_complete, - default_message_suffix='completed.') + default_message_suffix="completed.", + ) if liveaction_db.status == LIVEACTION_STATUS_SUCCEEDED and notify.on_success: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_success, - default_message_suffix='succeeded.') + default_message_suffix="succeeded.", + ) if liveaction_db.status in LIVEACTION_FAILED_STATES and notify.on_failure: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_failure, - default_message_suffix='failed.') + default_message_suffix="failed.", + ) - def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None, - notify_subsection=None, - default_message_suffix=None): - routes = (getattr(notify_subsection, 'routes') or - getattr(notify_subsection, 'channels', [])) or [] + def _post_notify_subsection_triggers( + self, + liveaction_db=None, + execution_db=None, + notify_subsection=None, + default_message_suffix=None, + ): + routes = ( + getattr(notify_subsection, "routes") + or getattr(notify_subsection, "channels", []) + ) or [] execution_id = str(execution_db.id) if routes and len(routes) >= 1: payload = {} message = notify_subsection.message or ( - 'Action ' + liveaction_db.action + ' ' + default_message_suffix) + "Action " + liveaction_db.action + " " + default_message_suffix + ) data = notify_subsection.data or {} jinja_context = self._build_jinja_context( @@ -148,17 +167,18 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None ) try: - with Timer(key='notifier.transform_message'): - message = self._transform_message(message=message, - context=jinja_context) + with Timer(key="notifier.transform_message"): + message = self._transform_message( + message=message, context=jinja_context + ) except: - LOG.exception('Failed (Jinja) transforming `message`.') + LOG.exception("Failed (Jinja) transforming `message`.") try: - with Timer(key='notifier.transform_data'): + with Timer(key="notifier.transform_data"): data = self._transform_data(data=data, context=jinja_context) except: - LOG.exception('Failed (Jinja) transforming `data`.') + LOG.exception("Failed (Jinja) transforming `data`.") # At this point convert result to a string. This restricts the rulesengines # ability to introspect the result. On the other handle atleast a json usable @@ -166,69 +186,82 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None # to a string representation it uses str(...) which make it impossible to # parse the result as json any longer. # TODO: Use to_serializable_dict - data['result'] = json.dumps(liveaction_db.result) + data["result"] = json.dumps(liveaction_db.result) - payload['message'] = message - payload['data'] = data - payload['execution_id'] = execution_id - payload['status'] = liveaction_db.status - payload['start_timestamp'] = isotime.format(liveaction_db.start_timestamp) + payload["message"] = message + payload["data"] = data + payload["execution_id"] = execution_id + payload["status"] = liveaction_db.status + payload["start_timestamp"] = isotime.format(liveaction_db.start_timestamp) try: - payload['end_timestamp'] = isotime.format(liveaction_db.end_timestamp) + payload["end_timestamp"] = isotime.format(liveaction_db.end_timestamp) except AttributeError: # This can be raised if liveaction.end_timestamp is None, which is caused # when policy cancels a request due to concurrency # In this case, use datetime.now() instead - payload['end_timestamp'] = isotime.format(datetime.utcnow()) + payload["end_timestamp"] = isotime.format(datetime.utcnow()) - payload['action_ref'] = liveaction_db.action - payload['runner_ref'] = self._get_runner_ref(liveaction_db.action) + payload["action_ref"] = liveaction_db.action + payload["runner_ref"] = self._get_runner_ref(liveaction_db.action) trace_context = self._get_trace_context(execution_id=execution_id) failed_routes = [] for route in routes: try: - payload['route'] = route + payload["route"] = route # Deprecated. Only for backward compatibility reasons. - payload['channel'] = route - LOG.debug('POSTing %s for %s. Payload - %s.', NOTIFY_TRIGGER_TYPE['name'], - liveaction_db.id, payload) - - with CounterWithTimer(key='notifier.notify_trigger.dispatch'): - self._trigger_dispatcher.dispatch(self._notify_trigger, payload=payload, - trace_context=trace_context) + payload["channel"] = route + LOG.debug( + "POSTing %s for %s. Payload - %s.", + NOTIFY_TRIGGER_TYPE["name"], + liveaction_db.id, + payload, + ) + + with CounterWithTimer(key="notifier.notify_trigger.dispatch"): + self._trigger_dispatcher.dispatch( + self._notify_trigger, + payload=payload, + trace_context=trace_context, + ) except: failed_routes.append(route) if len(failed_routes) > 0: - raise Exception('Failed notifications to routes: %s' % ', '.join(failed_routes)) + raise Exception( + "Failed notifications to routes: %s" % ", ".join(failed_routes) + ) def _build_jinja_context(self, liveaction_db, execution_db): context = {} - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) context.update({ACTION_PARAMETERS_KV_PREFIX: liveaction_db.parameters}) context.update({ACTION_CONTEXT_KV_PREFIX: liveaction_db.context}) context.update({ACTION_RESULTS_KV_PREFIX: execution_db.result}) return context def _transform_message(self, message, context=None): - mapping = {'message': message} + mapping = {"message": message} context = context or {} - return (jinja_utils.render_values(mapping=mapping, context=context)).get('message', - message) + return (jinja_utils.render_values(mapping=mapping, context=context)).get( + "message", message + ) def _transform_data(self, data, context=None): return jinja_utils.render_values(mapping=data, context=context) def _get_trace_context(self, execution_id): trace_db = trace_service.get_trace_db_by_action_execution( - action_execution_id=execution_id) + action_execution_id=execution_id + ) if trace_db: return TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag) # If no trace_context is found then do not create a new one here. If necessary @@ -237,38 +270,48 @@ def _get_trace_context(self, execution_id): def _post_generic_trigger(self, liveaction_db=None, execution_db=None): if not cfg.CONF.action_sensor.enable: - LOG.debug('Action trigger is disabled, skipping trigger dispatch...') + LOG.debug("Action trigger is disabled, skipping trigger dispatch...") return execution_id = str(execution_db.id) - extra = {'execution': execution_db} + extra = {"execution": execution_db} target_statuses = cfg.CONF.action_sensor.emit_when if execution_db.status not in target_statuses: msg = 'Skip action execution "%s" because state "%s" is not in %s' - LOG.debug(msg % (execution_id, execution_db.status, target_statuses), extra=extra) + LOG.debug( + msg % (execution_id, execution_db.status, target_statuses), extra=extra + ) return - with CounterWithTimer(key='notifier.generic_trigger.post'): - payload = {'execution_id': execution_id, - 'status': liveaction_db.status, - 'start_timestamp': str(liveaction_db.start_timestamp), - # deprecate 'action_name' at some point and switch to 'action_ref' - 'action_name': liveaction_db.action, - 'action_ref': liveaction_db.action, - 'runner_ref': self._get_runner_ref(liveaction_db.action), - 'parameters': liveaction_db.get_masked_parameters(), - 'result': liveaction_db.result} + with CounterWithTimer(key="notifier.generic_trigger.post"): + payload = { + "execution_id": execution_id, + "status": liveaction_db.status, + "start_timestamp": str(liveaction_db.start_timestamp), + # deprecate 'action_name' at some point and switch to 'action_ref' + "action_name": liveaction_db.action, + "action_ref": liveaction_db.action, + "runner_ref": self._get_runner_ref(liveaction_db.action), + "parameters": liveaction_db.get_masked_parameters(), + "result": liveaction_db.result, + } # Use execution_id to extract trace rather than liveaction. execution_id # will look-up an exact TraceDB while liveaction depending on context # may not end up going to the DB. trace_context = self._get_trace_context(execution_id=execution_id) - LOG.debug('POSTing %s for %s. Payload - %s. TraceContext - %s', - ACTION_TRIGGER_TYPE['name'], liveaction_db.id, payload, trace_context) + LOG.debug( + "POSTing %s for %s. Payload - %s. TraceContext - %s", + ACTION_TRIGGER_TYPE["name"], + liveaction_db.id, + payload, + trace_context, + ) - with CounterWithTimer(key='notifier.generic_trigger.dispatch'): - self._trigger_dispatcher.dispatch(self._action_trigger, payload=payload, - trace_context=trace_context) + with CounterWithTimer(key="notifier.generic_trigger.dispatch"): + self._trigger_dispatcher.dispatch( + self._action_trigger, payload=payload, trace_context=trace_context + ) def _get_runner_ref(self, action_ref): """ @@ -277,10 +320,13 @@ def _get_runner_ref(self, action_ref): :rtype: ``str`` """ action = Action.get_by_ref(action_ref) - return action['runner_type']['name'] + return action["runner_type"]["name"] def get_notifier(): with transport_utils.get_connection() as conn: - return Notifier(conn, [NOTIFIER_ACTIONUPDATE_WORK_QUEUE], - trigger_dispatcher=TriggerDispatcher(LOG)) + return Notifier( + conn, + [NOTIFIER_ACTIONUPDATE_WORK_QUEUE], + trigger_dispatcher=TriggerDispatcher(LOG), + ) diff --git a/st2actions/st2actions/policies/concurrency.py b/st2actions/st2actions/policies/concurrency.py index 4f98b093c7..cf47ed0b69 100644 --- a/st2actions/st2actions/policies/concurrency.py +++ b/st2actions/st2actions/policies/concurrency.py @@ -22,53 +22,64 @@ from st2common.services import action as action_service -__all__ = [ - 'ConcurrencyApplicator' -] +__all__ = ["ConcurrencyApplicator"] LOG = logging.getLogger(__name__) class ConcurrencyApplicator(BaseConcurrencyApplicator): - - def __init__(self, policy_ref, policy_type, threshold=0, action='delay'): - super(ConcurrencyApplicator, self).__init__(policy_ref=policy_ref, policy_type=policy_type, - threshold=threshold, - action=action) + def __init__(self, policy_ref, policy_type, threshold=0, action="delay"): + super(ConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=threshold, + action=action, + ) def _get_lock_uid(self, target): - values = {'policy_type': self._policy_type, 'action': target.action} + values = {"policy_type": self._policy_type, "action": target.action} return self._get_lock_name(values=values) def _apply_before(self, target): # Get the count of scheduled instances of the action. scheduled = action_access.LiveAction.count( - action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED) + action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED + ) # Get the count of running instances of the action. running = action_access.LiveAction.count( - action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING) + action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING + ) count = scheduled + running # Mark the execution as scheduled if threshold is not reached or delayed otherwise. if count < self.threshold: - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is not reached. Action execution will be scheduled.', - count, target.action, self._policy_ref) + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is not reached. Action execution will be scheduled.", + count, + target.action, + self._policy_ref, + ) status = action_constants.LIVEACTION_STATUS_REQUESTED else: - action = 'delayed' if self.policy_action == 'delay' else 'canceled' - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is reached. Action execution will be %s.', - count, target.action, self._policy_ref, action) + action = "delayed" if self.policy_action == "delay" else "canceled" + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is reached. Action execution will be %s.", + count, + target.action, + self._policy_ref, + action, + ) status = self._get_status_for_policy_action(action=self.policy_action) # Update the status in the database. Publish status for cancellation so the # appropriate runner can cancel the execution. Other statuses are not published # because they will be picked up by the worker(s) to be processed again, # leading to duplicate action executions. - publish = (status == action_constants.LIVEACTION_STATUS_CANCELING) + publish = status == action_constants.LIVEACTION_STATUS_CANCELING target = action_service.update_status(target, status, publish=publish) return target @@ -78,13 +89,17 @@ def apply_before(self, target): valid_states = [ action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] # Exit if target not in valid state. if target.status not in valid_states: - LOG.debug('The live action is not in a valid state therefore the policy ' - '"%s" cannot be applied. %s', self._policy_ref, target) + LOG.debug( + "The live action is not in a valid state therefore the policy " + '"%s" cannot be applied. %s', + self._policy_ref, + target, + ) return target target = self._apply_before(target) diff --git a/st2actions/st2actions/policies/concurrency_by_attr.py b/st2actions/st2actions/policies/concurrency_by_attr.py index 7c9ee1dabc..ea3f9cd421 100644 --- a/st2actions/st2actions/policies/concurrency_by_attr.py +++ b/st2actions/st2actions/policies/concurrency_by_attr.py @@ -25,38 +25,41 @@ from st2common.policies.concurrency import BaseConcurrencyApplicator from st2common.services import coordination -__all__ = [ - 'ConcurrencyByAttributeApplicator' -] +__all__ = ["ConcurrencyByAttributeApplicator"] LOG = logging.getLogger(__name__) class ConcurrencyByAttributeApplicator(BaseConcurrencyApplicator): - - def __init__(self, policy_ref, policy_type, threshold=0, action='delay', attributes=None): - super(ConcurrencyByAttributeApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type, - threshold=threshold, - action=action) + def __init__( + self, policy_ref, policy_type, threshold=0, action="delay", attributes=None + ): + super(ConcurrencyByAttributeApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=threshold, + action=action, + ) self.attributes = attributes or [] def _get_lock_uid(self, target): meta = { - 'policy_type': self._policy_type, - 'action': target.action, - 'attributes': self.attributes + "policy_type": self._policy_type, + "action": target.action, + "attributes": self.attributes, } return json.dumps(meta) def _get_filters(self, target): - filters = {('parameters__%s' % k): v - for k, v in six.iteritems(target.parameters) - if k in self.attributes} + filters = { + ("parameters__%s" % k): v + for k, v in six.iteritems(target.parameters) + if k in self.attributes + } - filters['action'] = target.action - filters['status'] = None + filters["action"] = target.action + filters["status"] = None return filters @@ -65,54 +68,71 @@ def _apply_before(self, target): filters = self._get_filters(target) # Get the count of scheduled instances of the action. - filters['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED + filters["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED scheduled = action_access.LiveAction.count(**filters) # Get the count of running instances of the action. - filters['status'] = action_constants.LIVEACTION_STATUS_RUNNING + filters["status"] = action_constants.LIVEACTION_STATUS_RUNNING running = action_access.LiveAction.count(**filters) count = scheduled + running # Mark the execution as scheduled if threshold is not reached or delayed otherwise. if count < self.threshold: - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is not reached. Action execution will be scheduled.', - count, target.action, self._policy_ref) + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is not reached. Action execution will be scheduled.", + count, + target.action, + self._policy_ref, + ) status = action_constants.LIVEACTION_STATUS_REQUESTED else: - action = 'delayed' if self.policy_action == 'delay' else 'canceled' - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is reached. Action execution will be %s.', - count, target.action, self._policy_ref, action) + action = "delayed" if self.policy_action == "delay" else "canceled" + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is reached. Action execution will be %s.", + count, + target.action, + self._policy_ref, + action, + ) status = self._get_status_for_policy_action(action=self.policy_action) # Update the status in the database. Publish status for cancellation so the # appropriate runner can cancel the execution. Other statuses are not published # because they will be picked up by the worker(s) to be processed again, # leading to duplicate action executions. - publish = (status == action_constants.LIVEACTION_STATUS_CANCELING) + publish = status == action_constants.LIVEACTION_STATUS_CANCELING target = action_service.update_status(target, status, publish=publish) return target def apply_before(self, target): - target = super(ConcurrencyByAttributeApplicator, self).apply_before(target=target) + target = super(ConcurrencyByAttributeApplicator, self).apply_before( + target=target + ) valid_states = [ action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] # Exit if target not in valid state. if target.status not in valid_states: - LOG.debug('The live action is not schedulable therefore the policy ' - '"%s" cannot be applied. %s', self._policy_ref, target) + LOG.debug( + "The live action is not schedulable therefore the policy " + '"%s" cannot be applied. %s', + self._policy_ref, + target, + ) return target # Warn users that the coordination service is not configured. if not coordination.configured(): - LOG.warn('Coordination service is not configured. Policy enforcement is best effort.') + LOG.warn( + "Coordination service is not configured. Policy enforcement is best effort." + ) target = self._apply_before(target) diff --git a/st2actions/st2actions/policies/retry.py b/st2actions/st2actions/policies/retry.py index 85775d4f13..abbbd70453 100644 --- a/st2actions/st2actions/policies/retry.py +++ b/st2actions/st2actions/policies/retry.py @@ -27,22 +27,16 @@ from st2common.util.enum import Enum from st2common.policies.base import ResourcePolicyApplicator -__all__ = [ - 'RetryOnPolicy', - 'ExecutionRetryPolicyApplicator' -] +__all__ = ["RetryOnPolicy", "ExecutionRetryPolicyApplicator"] LOG = logging.getLogger(__name__) -VALID_RETRY_STATUSES = [ - LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_TIMED_OUT -] +VALID_RETRY_STATUSES = [LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT] class RetryOnPolicy(Enum): - FAILURE = 'failure' # Retry on execution failure - TIMEOUT = 'timeout' # Retry on execution timeout + FAILURE = "failure" # Retry on execution failure + TIMEOUT = "timeout" # Retry on execution timeout class ExecutionRetryPolicyApplicator(ResourcePolicyApplicator): @@ -57,8 +51,9 @@ def __init__(self, policy_ref, policy_type, retry_on, max_retry_count=2, delay=0 :param delay: How long to wait before retrying an execution. :type delay: ``float`` """ - super(ExecutionRetryPolicyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type) + super(ExecutionRetryPolicyApplicator, self).__init__( + policy_ref=policy_ref, policy_type=policy_type + ) self.retry_on = retry_on self.max_retry_count = max_retry_count @@ -71,27 +66,33 @@ def apply_after(self, target): if self._is_live_action_part_of_workflow_action(live_action_db): LOG.warning( - 'Retry cannot be applied to this liveaction because it is executed under a ' - 'workflow. Use workflow specific retry functionality where applicable. %s', - live_action_db + "Retry cannot be applied to this liveaction because it is executed under a " + "workflow. Use workflow specific retry functionality where applicable. %s", + live_action_db, ) return target retry_count = self._get_live_action_retry_count(live_action_db=live_action_db) - extra = {'live_action_db': live_action_db, 'policy_ref': self._policy_ref, - 'retry_on': self.retry_on, 'max_retry_count': self.max_retry_count, - 'current_retry_count': retry_count} + extra = { + "live_action_db": live_action_db, + "policy_ref": self._policy_ref, + "retry_on": self.retry_on, + "max_retry_count": self.max_retry_count, + "current_retry_count": retry_count, + } if live_action_db.status not in VALID_RETRY_STATUSES: # Currently we only support retrying on failed action - LOG.debug('Liveaction not in a valid retry state, not checking retry policy', - extra=extra) + LOG.debug( + "Liveaction not in a valid retry state, not checking retry policy", + extra=extra, + ) return target if (retry_count + 1) > self.max_retry_count: - LOG.info('Maximum retry count has been reached, not retrying', extra=extra) + LOG.info("Maximum retry count has been reached, not retrying", extra=extra) return target has_failed = live_action_db.status == LIVEACTION_STATUS_FAILED @@ -100,34 +101,50 @@ def apply_after(self, target): # TODO: This is not crash and restart safe, switch to using "DELAYED" # status if self.delay > 0: - re_run_live_action = functools.partial(eventlet.spawn_after, self.delay, - self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + eventlet.spawn_after, + self.delay, + self._re_run_live_action, + live_action_db=live_action_db, + ) else: # Even if delay is 0, use a small delay (0.1 seconds) to prevent busy wait - re_run_live_action = functools.partial(eventlet.spawn_after, 0.1, - self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + eventlet.spawn_after, + 0.1, + self._re_run_live_action, + live_action_db=live_action_db, + ) - re_run_live_action = functools.partial(self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + self._re_run_live_action, live_action_db=live_action_db + ) if has_failed and self.retry_on == RetryOnPolicy.FAILURE: - extra['failure'] = True - LOG.info('Policy matched (failure), retrying action execution in %s seconds...' % - (self.delay), extra=extra) + extra["failure"] = True + LOG.info( + "Policy matched (failure), retrying action execution in %s seconds..." + % (self.delay), + extra=extra, + ) re_run_live_action() return target if has_timed_out and self.retry_on == RetryOnPolicy.TIMEOUT: - extra['timeout'] = True - LOG.info('Policy matched (timeout), retrying action execution in %s seconds...' % - (self.delay), extra=extra) + extra["timeout"] = True + LOG.info( + "Policy matched (timeout), retrying action execution in %s seconds..." + % (self.delay), + extra=extra, + ) re_run_live_action() return target - LOG.info('Invalid status "%s" for live action "%s", wont retry' % - (live_action_db.status, str(live_action_db.id)), extra=extra) + LOG.info( + 'Invalid status "%s" for live action "%s", wont retry' + % (live_action_db.status, str(live_action_db.id)), + extra=extra, + ) return target @@ -137,9 +154,9 @@ def _is_live_action_part_of_workflow_action(self, live_action_db): :rtype: ``dict`` """ - context = getattr(live_action_db, 'context', {}) - parent = context.get('parent', {}) - is_wf_action = (parent is not None and parent != {}) + context = getattr(live_action_db, "context", {}) + parent = context.get("parent", {}) + is_wf_action = parent is not None and parent != {} return is_wf_action @@ -151,8 +168,8 @@ def _get_live_action_retry_count(self, live_action_db): """ # TODO: Ideally we would store retry_count in zookeeper or similar and use locking so we # can run multiple instances of st2notififer - context = getattr(live_action_db, 'context', {}) - retry_count = context.get('policies', {}).get('retry', {}).get('retry_count', 0) + context = getattr(live_action_db, "context", {}) + retry_count = context.get("policies", {}).get("retry", {}).get("retry_count", 0) return retry_count @@ -160,17 +177,18 @@ def _re_run_live_action(self, live_action_db): retry_count = self._get_live_action_retry_count(live_action_db=live_action_db) # Add additional policy specific info to the context - context = getattr(live_action_db, 'context', {}) + context = getattr(live_action_db, "context", {}) new_context = copy.deepcopy(context) - new_context['policies'] = {} - new_context['policies']['retry'] = { - 'applied_policy': self._policy_ref, - 'retry_count': (retry_count + 1), - 'retried_liveaction_id': str(live_action_db.id) + new_context["policies"] = {} + new_context["policies"]["retry"] = { + "applied_policy": self._policy_ref, + "retry_count": (retry_count + 1), + "retried_liveaction_id": str(live_action_db.id), } action_ref = live_action_db.action parameters = live_action_db.parameters - new_live_action_db = LiveActionDB(action=action_ref, parameters=parameters, - context=new_context) + new_live_action_db = LiveActionDB( + action=action_ref, parameters=parameters, context=new_context + ) _, action_execution_db = action_services.request(new_live_action_db) return action_execution_db diff --git a/st2actions/st2actions/runners/pythonrunner.py b/st2actions/st2actions/runners/pythonrunner.py index 215edd83c8..33a3f3ec39 100644 --- a/st2actions/st2actions/runners/pythonrunner.py +++ b/st2actions/st2actions/runners/pythonrunner.py @@ -16,6 +16,4 @@ from __future__ import absolute_import from st2common.runners.base_action import Action -__all__ = [ - 'Action' -] +__all__ = ["Action"] diff --git a/st2actions/st2actions/scheduler/config.py b/st2actions/st2actions/scheduler/config.py index a991403a9b..8df6c3ff3e 100644 --- a/st2actions/st2actions/scheduler/config.py +++ b/st2actions/st2actions/scheduler/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=sys_constants.VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=sys_constants.VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,36 +50,48 @@ def _register_common_opts(): def _register_service_opts(): scheduler_opts = [ cfg.StrOpt( - 'logging', - default='/etc/st2/logging.scheduler.conf', - help='Location of the logging configuration file.' + "logging", + default="/etc/st2/logging.scheduler.conf", + help="Location of the logging configuration file.", ), cfg.FloatOpt( - 'execution_scheduling_timeout_threshold_min', default=1, - help='How long GC to search back in minutes for orphaned scheduled actions'), + "execution_scheduling_timeout_threshold_min", + default=1, + help="How long GC to search back in minutes for orphaned scheduled actions", + ), cfg.IntOpt( - 'pool_size', default=10, - help='The size of the pool used by the scheduler for scheduling executions.'), + "pool_size", + default=10, + help="The size of the pool used by the scheduler for scheduling executions.", + ), cfg.FloatOpt( - 'sleep_interval', default=0.10, - help='How long (in seconds) to sleep between each action scheduler main loop run ' - 'interval.'), + "sleep_interval", + default=0.10, + help="How long (in seconds) to sleep between each action scheduler main loop run " + "interval.", + ), cfg.FloatOpt( - 'gc_interval', default=10, - help='How often (in seconds) to look for zombie execution requests before rescheduling ' - 'them.'), + "gc_interval", + default=10, + help="How often (in seconds) to look for zombie execution requests before rescheduling " + "them.", + ), cfg.IntOpt( - 'retry_max_attempt', default=10, - help='The maximum number of attempts that the scheduler retries on error.'), + "retry_max_attempt", + default=10, + help="The maximum number of attempts that the scheduler retries on error.", + ), cfg.IntOpt( - 'retry_wait_msec', default=3000, - help='The number of milliseconds to wait in between retries.') + "retry_wait_msec", + default=3000, + help="The number of milliseconds to wait in between retries.", + ), ] - cfg.CONF.register_opts(scheduler_opts, group='scheduler') + cfg.CONF.register_opts(scheduler_opts, group="scheduler") try: register_opts() except cfg.DuplicateOptError: - LOG.exception('The scheduler configuration options are already parsed and loaded.') + LOG.exception("The scheduler configuration options are already parsed and loaded.") diff --git a/st2actions/st2actions/scheduler/entrypoint.py b/st2actions/st2actions/scheduler/entrypoint.py index ee8a76f2d1..14d816ded3 100644 --- a/st2actions/st2actions/scheduler/entrypoint.py +++ b/st2actions/st2actions/scheduler/entrypoint.py @@ -29,10 +29,7 @@ from st2common.persistence.execution_queue import ActionExecutionSchedulingQueue from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB -__all__ = [ - 'SchedulerEntrypoint', - 'get_scheduler_entrypoint' -] +__all__ = ["SchedulerEntrypoint", "get_scheduler_entrypoint"] LOG = logging.getLogger(__name__) @@ -43,6 +40,7 @@ class SchedulerEntrypoint(consumers.MessageHandler): SchedulerEntrypoint subscribes to the Action scheduler request queue and places new Live Actions into the scheduling queue collection for scheduling on action runners. """ + message_type = LiveActionDB def process(self, request): @@ -53,18 +51,25 @@ def process(self, request): :type request: ``st2common.models.db.liveaction.LiveActionDB`` """ if request.status != action_constants.LIVEACTION_STATUS_REQUESTED: - LOG.info('%s is ignoring %s (id=%s) with "%s" status.', - self.__class__.__name__, type(request), request.id, request.status) + LOG.info( + '%s is ignoring %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(request), + request.id, + request.status, + ) return try: liveaction_db = action_utils.get_liveaction_by_id(str(request.id)) except StackStormDBObjectNotFoundError: - LOG.exception('Failed to find liveaction %s in the database.', str(request.id)) + LOG.exception( + "Failed to find liveaction %s in the database.", str(request.id) + ) raise query = { - 'liveaction_id': str(liveaction_db.id), + "liveaction_id": str(liveaction_db.id), } queued_requests = ActionExecutionSchedulingQueue.query(**query) @@ -75,17 +80,16 @@ def process(self, request): if liveaction_db.delay and liveaction_db.delay > 0: liveaction_db = action_service.update_status( - liveaction_db, - action_constants.LIVEACTION_STATUS_DELAYED, - publish=False + liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False ) execution_queue_item_db = self._create_execution_queue_item_db_from_liveaction( - liveaction_db, - delay=liveaction_db.delay + liveaction_db, delay=liveaction_db.delay ) - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) return execution_queue_item_db @@ -99,9 +103,8 @@ def _create_execution_queue_item_db_from_liveaction(self, liveaction, delay=None execution_queue_item_db.action_execution_id = str(execution.id) execution_queue_item_db.liveaction_id = str(liveaction.id) execution_queue_item_db.original_start_timestamp = liveaction.start_timestamp - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - liveaction.start_timestamp, - delay or 0 + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time(liveaction.start_timestamp, delay or 0) ) execution_queue_item_db.delay = delay diff --git a/st2actions/st2actions/scheduler/handler.py b/st2actions/st2actions/scheduler/handler.py index 76d54066a9..e39871db3e 100644 --- a/st2actions/st2actions/scheduler/handler.py +++ b/st2actions/st2actions/scheduler/handler.py @@ -37,10 +37,7 @@ from st2common.metrics import base as metrics from st2common.exceptions import db as db_exc -__all__ = [ - 'ActionExecutionSchedulingQueueHandler', - 'get_handler' -] +__all__ = ["ActionExecutionSchedulingQueueHandler", "get_handler"] LOG = logging.getLogger(__name__) @@ -61,14 +58,15 @@ def __init__(self): # fast (< 5 seconds). If an item is still being marked as processing it likely indicates # that the scheduler process which was processing that item crashed or similar so we need # to mark it as "handling=False" so some other scheduler process can pick it up. - self._execution_scheduling_timeout_threshold_ms = \ + self._execution_scheduling_timeout_threshold_ms = ( cfg.CONF.scheduler.execution_scheduling_timeout_threshold_min * 60 * 1000 + ) self._coordinator = coordination_service.get_coordinator(start_heart=True) self._main_thread = None self._cleanup_thread = None def run(self): - LOG.debug('Starting scheduler handler...') + LOG.debug("Starting scheduler handler...") while not self._shutdown: eventlet.greenthread.sleep(cfg.CONF.scheduler.sleep_interval) @@ -77,7 +75,8 @@ def run(self): @retrying.retry( retry_on_exception=service_utils.retry_on_exceptions, stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt, - wait_fixed=cfg.CONF.scheduler.retry_wait_msec) + wait_fixed=cfg.CONF.scheduler.retry_wait_msec, + ) def process(self): execution_queue_item_db = self._get_next_execution() @@ -85,7 +84,7 @@ def process(self): self._pool.spawn(self._handle_execution, execution_queue_item_db) def cleanup(self): - LOG.debug('Starting scheduler garbage collection...') + LOG.debug("Starting scheduler garbage collection...") while not self._shutdown: eventlet.greenthread.sleep(cfg.CONF.scheduler.gc_interval) @@ -99,11 +98,11 @@ def _reset_handling_flag(self): False so other scheduler can pick it up. """ query = { - 'scheduled_start_timestamp__lte': date.append_milliseconds_to_time( + "scheduled_start_timestamp__lte": date.append_milliseconds_to_time( date.get_datetime_utc_now(), - -self._execution_scheduling_timeout_threshold_ms + -self._execution_scheduling_timeout_threshold_ms, ), - 'handling': True + "handling": True, } execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**query) or [] @@ -112,17 +111,19 @@ def _reset_handling_flag(self): execution_queue_item_db.handling = False try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) LOG.info( '[%s] Removing lock for orphaned execution queue item "%s".', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) except db_exc.StackStormDBObjectWriteConflictError: LOG.info( '[%s] Execution queue item "%s" updated during garbage collection.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) # TODO: Remove this function for fixing missing action_execution_id in v3.2. @@ -132,7 +133,9 @@ def _fix_missing_action_execution_id(self): """ Auto-populate the action_execution_id in ActionExecutionSchedulingQueue if empty. """ - for entry in ActionExecutionSchedulingQueue.query(action_execution_id__in=['', None]): + for entry in ActionExecutionSchedulingQueue.query( + action_execution_id__in=["", None] + ): execution_db = ActionExecution.get(liveaction__id=entry.liveaction_id) if not execution_db: @@ -152,23 +155,27 @@ def _cleanup_policy_delayed(self): moved back into requested status. """ - policy_delayed_liveaction_dbs = LiveAction.query(status='policy-delayed') or [] + policy_delayed_liveaction_dbs = LiveAction.query(status="policy-delayed") or [] for liveaction_db in policy_delayed_liveaction_dbs: - ex_que_qry = {'liveaction_id': str(liveaction_db.id), 'handling': False} - execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**ex_que_qry) or [] + ex_que_qry = {"liveaction_id": str(liveaction_db.id), "handling": False} + execution_queue_item_dbs = ( + ActionExecutionSchedulingQueue.query(**ex_que_qry) or [] + ) for execution_queue_item_db in execution_queue_item_dbs: # Mark the entry in the scheduling queue for handling. try: execution_queue_item_db.handling = True - execution_queue_item_db = ActionExecutionSchedulingQueue.add_or_update( - execution_queue_item_db, publish=False) + execution_queue_item_db = ( + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) + ) except db_exc.StackStormDBObjectWriteConflictError: - msg = ( - '[%s] Item "%s" is currently being processed by another scheduler.' % - (execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id)) + msg = '[%s] Item "%s" is currently being processed by another scheduler.' % ( + execution_queue_item_db.action_execution_id, + str(execution_queue_item_db.id), ) LOG.error(msg) raise Exception(msg) @@ -177,7 +184,7 @@ def _cleanup_policy_delayed(self): LOG.info( '[%s] Removing policy-delayed entry "%s" from the scheduling queue.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) @@ -186,18 +193,20 @@ def _cleanup_policy_delayed(self): LOG.info( '[%s] Removing policy-delayed entry "%s" from the scheduling queue.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) liveaction_db = action_service.update_status( - liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED) + liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED + ) execution_service.update_execution(liveaction_db) @retrying.retry( retry_on_exception=service_utils.retry_on_exceptions, stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt, - wait_fixed=cfg.CONF.scheduler.retry_wait_msec) + wait_fixed=cfg.CONF.scheduler.retry_wait_msec, + ) def _handle_garbage_collection(self): self._reset_handling_flag() @@ -212,13 +221,10 @@ def _get_next_execution(self): due to a policy. """ query = { - 'scheduled_start_timestamp__lte': date.get_datetime_utc_now(), - 'handling': False, - 'limit': 1, - 'order_by': [ - '+scheduled_start_timestamp', - '+original_start_timestamp' - ] + "scheduled_start_timestamp__lte": date.get_datetime_utc_now(), + "handling": False, + "limit": 1, + "order_by": ["+scheduled_start_timestamp", "+original_start_timestamp"], } execution_queue_item_db = ActionExecutionSchedulingQueue.query(**query).first() @@ -229,45 +235,52 @@ def _get_next_execution(self): # Mark that this scheduler process is currently handling (processing) that request # NOTE: This operation is atomic (CAS) msg = '[%s] Retrieved item "%s" from scheduling queue.' - LOG.info(msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id) + LOG.info( + msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id + ) execution_queue_item_db.handling = True try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) return execution_queue_item_db except db_exc.StackStormDBObjectWriteConflictError: LOG.info( '[%s] Item "%s" is already handled by another scheduler.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) return None - @metrics.CounterWithTimer(key='scheduler.handle_execution') + @metrics.CounterWithTimer(key="scheduler.handle_execution") def _handle_execution(self, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Scheduling Liveaction "%s".', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) try: liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) except StackStormDBObjectNotFoundError: msg = '[%s] Failed to find liveaction "%s" in the database (queue_item_id=%s).' - LOG.exception(msg, action_execution_id, liveaction_id, queue_item_id, extra=extra) + LOG.exception( + msg, action_execution_id, liveaction_id, queue_item_id, extra=extra + ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) raise # Identify if the action has policies that require locking. action_has_policies_require_lock = policy_service.has_policies( - liveaction_db, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + liveaction_db, policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK ) # Acquire a distributed lock if the referenced action has specific policies attached. @@ -275,9 +288,9 @@ def _handle_execution(self, execution_queue_item_db): # Warn users that the coordination service is not configured. if not coordination_service.configured(): LOG.warn( - '[%s] Coordination backend is not configured. ' - 'Policy enforcement is best effort.', - action_execution_id + "[%s] Coordination backend is not configured. " + "Policy enforcement is best effort.", + action_execution_id, ) # Acquire a distributed lock before querying the database to make sure that only one @@ -304,11 +317,14 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Liveaction "%s" has status "%s" before applying policies.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) # Apply policies defined for the action. @@ -316,13 +332,18 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): LOG.info( '[%s] Liveaction "%s" has status "%s" after applying policies.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) if liveaction_db.status == action_constants.LIVEACTION_STATUS_DELAYED: LOG.info( '[%s] Liveaction "%s" is delayed and scheduling queue is updated.', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) liveaction_db = action_service.update_status( @@ -330,23 +351,30 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): ) execution_queue_item_db.handling = False - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - date.get_datetime_utc_now(), - POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time( + date.get_datetime_utc_now(), + POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS, + ) ) try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) except db_exc.StackStormDBObjectWriteConflictError: LOG.warning( - '[%s] Database write conflict on updating scheduling queue.', - action_execution_id, extra=extra + "[%s] Database write conflict on updating scheduling queue.", + action_execution_id, + extra=extra, ) return - if (liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES or - liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES): + if ( + liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES + or liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES + ): ActionExecutionSchedulingQueue.delete(execution_queue_item_db) return @@ -356,33 +384,41 @@ def _delay(self, liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Liveaction "%s" is delayed and scheduling queue is updated.', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) liveaction_db = action_service.update_status( liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False ) - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - date.get_datetime_utc_now(), - POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time( + date.get_datetime_utc_now(), POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + ) ) try: execution_queue_item_db.handling = False - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) except db_exc.StackStormDBObjectWriteConflictError: LOG.warning( - '[%s] Database write conflict on updating scheduling queue.', - action_execution_id, extra=extra + "[%s] Database write conflict on updating scheduling queue.", + action_execution_id, + extra=extra, ) def _schedule(self, liveaction_db, execution_queue_item_db): - if self._is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): + if self._is_execution_queue_item_runnable( + liveaction_db, execution_queue_item_db + ): self._update_to_scheduled(liveaction_db, execution_queue_item_db) @staticmethod @@ -396,7 +432,7 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): valid_status = [ action_constants.LIVEACTION_STATUS_REQUESTED, action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] if liveaction_db.status in valid_status: @@ -405,11 +441,14 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Ignoring Liveaction "%s" with status "%s" after policies are applied.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) @@ -421,18 +460,26 @@ def _update_to_scheduled(liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} # Update liveaction status to "scheduled". LOG.info( '[%s] Liveaction "%s" with status "%s" is updated to status "scheduled."', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) - if liveaction_db.status in [action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED]: + if liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_REQUESTED, + action_constants.LIVEACTION_STATUS_DELAYED, + ]: liveaction_db = action_service.update_status( - liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED, publish=False) + liveaction_db, + action_constants.LIVEACTION_STATUS_SCHEDULED, + publish=False, + ) # Publish the "scheduled" status here manually. Otherwise, there could be a # race condition with the update of the action_execution_db if the execution diff --git a/st2actions/st2actions/worker.py b/st2actions/st2actions/worker.py index 3147ce1aae..1741d60724 100644 --- a/st2actions/st2actions/worker.py +++ b/st2actions/st2actions/worker.py @@ -34,10 +34,7 @@ from st2common.transport import queues -__all__ = [ - 'ActionExecutionDispatcher', - 'get_worker' -] +__all__ = ["ActionExecutionDispatcher", "get_worker"] LOG = logging.getLogger(__name__) @@ -46,14 +43,14 @@ queues.ACTIONRUNNER_WORK_QUEUE, queues.ACTIONRUNNER_CANCEL_QUEUE, queues.ACTIONRUNNER_PAUSE_QUEUE, - queues.ACTIONRUNNER_RESUME_QUEUE + queues.ACTIONRUNNER_RESUME_QUEUE, ] ACTIONRUNNER_DISPATCHABLE_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_CANCELING, action_constants.LIVEACTION_STATUS_PAUSING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] @@ -83,41 +80,54 @@ def process(self, liveaction): """ if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: - LOG.info('%s is not executing %s (id=%s) with "%s" status.', - self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status) + LOG.info( + '%s is not executing %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(liveaction), + liveaction.id, + liveaction.status, + ) if not liveaction.result: updated_liveaction = action_utils.update_liveaction_status( status=liveaction.status, - result={'message': 'Action execution canceled by user.'}, - liveaction_id=liveaction.id) + result={"message": "Action execution canceled by user."}, + liveaction_id=liveaction.id, + ) executions.update_execution(updated_liveaction) return if liveaction.status not in ACTIONRUNNER_DISPATCHABLE_STATES: - LOG.info('%s is not dispatching %s (id=%s) with "%s" status.', - self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status) + LOG.info( + '%s is not dispatching %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(liveaction), + liveaction.id, + liveaction.status, + ) return try: liveaction_db = action_utils.get_liveaction_by_id(liveaction.id) except StackStormDBObjectNotFoundError: - LOG.exception('Failed to find liveaction %s in the database.', liveaction.id) + LOG.exception( + "Failed to find liveaction %s in the database.", liveaction.id + ) raise if liveaction.status != liveaction_db.status: LOG.warning( - 'The status of liveaction %s has changed from %s to %s ' - 'while in the queue waiting for processing.', + "The status of liveaction %s has changed from %s to %s " + "while in the queue waiting for processing.", liveaction.id, liveaction.status, - liveaction_db.status + liveaction_db.status, ) dispatchers = { action_constants.LIVEACTION_STATUS_SCHEDULED: self._run_action, action_constants.LIVEACTION_STATUS_CANCELING: self._cancel_action, action_constants.LIVEACTION_STATUS_PAUSING: self._pause_action, - action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action + action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action, } return dispatchers[liveaction.status](liveaction) @@ -130,7 +140,7 @@ def shutdown(self): try: executions.abandon_execution_if_incomplete(liveaction_id=liveaction_id) except: - LOG.exception('Failed to abandon liveaction %s.', liveaction_id) + LOG.exception("Failed to abandon liveaction %s.", liveaction_id) def _run_action(self, liveaction_db): # stamp liveaction with process_info @@ -140,35 +150,49 @@ def _run_action(self, liveaction_db): liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_RUNNING, runner_info=runner_info, - liveaction_id=liveaction_db.id) + liveaction_id=liveaction_db.id, + ) self._running_liveactions.add(liveaction_db.id) action_execution_db = executions.update_execution(liveaction_db) # Launch action - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Launching action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Launching action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) - - extra = {'liveaction_db': liveaction_db} + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) + + extra = {"liveaction_db": liveaction_db} try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) if not result and not liveaction_db.action_is_workflow: - raise ActionRunnerException('Failed to execute action.') + raise ActionRunnerException("Failed to execute action.") except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra) + extra["error"] = str(ex) + LOG.info( + 'Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra + ) liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_FAILED, liveaction_id=liveaction_db.id, - result={'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) executions.update_execution(liveaction_db) raise finally: @@ -182,66 +206,98 @@ def _run_action(self, liveaction_db): def _cancel_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Canceling action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Canceling action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to cancel action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to cancel action execution %s." % (liveaction_db.id), + extra=extra, + ) raise return result def _pause_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Pausing action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Pausing action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to pause action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to pause action execution %s." % (liveaction_db.id), extra=extra + ) raise return result def _resume_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Resuming action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Resuming action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to resume action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to resume action execution %s." % (liveaction_db.id), + extra=extra, + ) raise # Cascade the resume upstream if action execution is child of an orquesta workflow. # The action service request_resume function is not used here because we do not want # other peer subworkflows to be resumed. - if 'orquesta' in action_execution_db.context and 'parent' in action_execution_db.context: + if ( + "orquesta" in action_execution_db.context + and "parent" in action_execution_db.context + ): wf_svc.handle_action_execution_resume(action_execution_db) return result diff --git a/st2actions/st2actions/workflows/config.py b/st2actions/st2actions/workflows/config.py index 0d2556f67a..6854323ddd 100644 --- a/st2actions/st2actions/workflows/config.py +++ b/st2actions/st2actions/workflows/config.py @@ -23,8 +23,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=sys_constants.VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=sys_constants.VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -43,13 +46,13 @@ def _register_common_opts(): def _register_service_opts(): wf_engine_opts = [ cfg.StrOpt( - 'logging', - default='/etc/st2/logging.workflowengine.conf', - help='Location of the logging configuration file.' + "logging", + default="/etc/st2/logging.workflowengine.conf", + help="Location of the logging configuration file.", ) ] - cfg.CONF.register_opts(wf_engine_opts, group='workflow_engine') + cfg.CONF.register_opts(wf_engine_opts, group="workflow_engine") register_opts() diff --git a/st2actions/st2actions/workflows/workflows.py b/st2actions/st2actions/workflows/workflows.py index 0351998025..2151c7d440 100644 --- a/st2actions/st2actions/workflows/workflows.py +++ b/st2actions/st2actions/workflows/workflows.py @@ -37,17 +37,16 @@ WORKFLOW_EXECUTION_QUEUES = [ queues.WORKFLOW_EXECUTION_WORK_QUEUE, queues.WORKFLOW_EXECUTION_RESUME_QUEUE, - queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE + queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE, ] class WorkflowExecutionHandler(consumers.VariableMessageHandler): - def __init__(self, connection, queues): super(WorkflowExecutionHandler, self).__init__(connection, queues) def handle_workflow_execution_with_instrumentation(wf_ex_db): - with metrics.CounterWithTimer(key='orquesta.workflow.executions'): + with metrics.CounterWithTimer(key="orquesta.workflow.executions"): return self.handle_workflow_execution(wf_ex_db=wf_ex_db) def handle_action_execution_with_instrumentation(ac_ex_db): @@ -55,27 +54,27 @@ def handle_action_execution_with_instrumentation(ac_ex_db): if not wf_svc.is_action_execution_under_workflow_context(ac_ex_db): return - with metrics.CounterWithTimer(key='orquesta.action.executions'): + with metrics.CounterWithTimer(key="orquesta.action.executions"): return self.handle_action_execution(ac_ex_db=ac_ex_db) self.message_types = { wf_db_models.WorkflowExecutionDB: handle_workflow_execution_with_instrumentation, - ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation + ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation, } def get_queue_consumer(self, connection, queues): # We want to use a special ActionsQueueConsumer which uses 2 dispatcher pools return consumers.VariableMessageQueueConsumer( - connection=connection, - queues=queues, - handler=self + connection=connection, queues=queues, handler=self ) def process(self, message): handler_function = self.message_types.get(type(message), None) if not handler_function: - msg = 'Handler function for message type "%s" is not defined.' % type(message) + msg = 'Handler function for message type "%s" is not defined.' % type( + message + ) raise ValueError(msg) try: @@ -90,43 +89,45 @@ def process(self, message): def fail_workflow_execution(self, message, exception): # Prepare attributes based on message type. if isinstance(message, wf_db_models.WorkflowExecutionDB): - msg_type = 'workflow' + msg_type = "workflow" wf_ex_db = message wf_ex_id = str(wf_ex_db.id) task = None else: - msg_type = 'task' + msg_type = "task" ac_ex_db = message - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) - task = {'id': task_ex_db.task_id, 'route': task_ex_db.task_route} + task = {"id": task_ex_db.task_id, "route": task_ex_db.task_route} # Log the error. - msg = 'Unknown error while processing %s execution. %s: %s' + msg = "Unknown error while processing %s execution. %s: %s" wf_svc.update_progress( wf_ex_db, msg % (msg_type, exception.__class__.__name__, str(exception)), - severity='error' + severity="error", ) # Fail the task execution so it's marked correctly in the # conductor state to allow for task rerun if needed. if isinstance(message, ex_db_models.ActionExecutionDB): msg = 'Unknown error while processing %s execution. Failing task execution "%s".' - wf_svc.update_progress(wf_ex_db, msg % (msg_type, task_ex_id), severity='error') + wf_svc.update_progress( + wf_ex_db, msg % (msg_type, task_ex_id), severity="error" + ) wf_svc.update_task_execution(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.update_task_state(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED) # Fail the workflow execution. msg = 'Unknown error while processing %s execution. Failing workflow execution "%s".' - wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity='error') + wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity="error") wf_svc.fail_workflow_execution(wf_ex_id, exception, task=task) def handle_workflow_execution(self, wf_ex_db): # Request the next set of tasks to execute. - wf_svc.update_progress(wf_ex_db, 'Processing request for workflow execution.') + wf_svc.update_progress(wf_ex_db, "Processing request for workflow execution.") wf_svc.request_next_tasks(wf_ex_db) def handle_action_execution(self, ac_ex_db): @@ -135,16 +136,17 @@ def handle_action_execution(self, ac_ex_db): return # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) - msg = ( - 'Action execution "%s" for task "%s" is updated and in "%s" state.' % - (str(ac_ex_db.id), task_ex_db.task_id, ac_ex_db.status) + msg = 'Action execution "%s" for task "%s" is updated and in "%s" state.' % ( + str(ac_ex_db.id), + task_ex_db.task_id, + ac_ex_db.status, ) wf_svc.update_progress(wf_ex_db, msg) @@ -152,9 +154,13 @@ def handle_action_execution(self, ac_ex_db): if task_ex_db.status in statuses.COMPLETED_STATUSES: msg = ( 'Action execution "%s" for task "%s", route "%s", is not processed ' - 'because task execution "%s" is already in completed state "%s".' % ( - str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route), - str(task_ex_db.id), task_ex_db.status + 'because task execution "%s" is already in completed state "%s".' + % ( + str(ac_ex_db.id), + task_ex_db.task_id, + str(task_ex_db.task_route), + str(task_ex_db.id), + task_ex_db.status, ) ) wf_svc.update_progress(wf_ex_db, msg) @@ -175,7 +181,7 @@ def handle_action_execution(self, ac_ex_db): return # Apply post run policies. - lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction['id']) + lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction["id"]) pc_svc.apply_post_run_policies(lv_ac_db) # Process completion of the action execution. diff --git a/st2actions/tests/unit/policies/test_base.py b/st2actions/tests/unit/policies/test_base.py index fcf3aef40d..2e5003d89c 100644 --- a/st2actions/tests/unit/policies/test_base.py +++ b/st2actions/tests/unit/policies/test_base.py @@ -17,6 +17,7 @@ import mock from st2tests import config as test_config + test_config.parse_args() import st2common @@ -32,28 +33,21 @@ from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'SchedulerPoliciesTestCase', - 'NotifierPoliciesTestCase' -] +__all__ = ["SchedulerPoliciesTestCase", "NotifierPoliciesTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES_1 = { - 'actions': [ - 'action1.yaml' + "actions": ["action1.yaml"], + "policies": [ + "policy_4.yaml", ], - 'policies': [ - 'policy_4.yaml', - ] } TEST_FIXTURES_2 = { - 'actions': [ - 'action1.yaml' + "actions": ["action1.yaml"], + "policies": [ + "policy_1.yaml", ], - 'policies': [ - 'policy_1.yaml', - ] } @@ -73,15 +67,14 @@ def setUp(self): register_policy_types(st2common) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES_2) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_2 + ) # Policy with "post_run" application - self.policy_db = models['policies']['policy_1.yaml'] + self.policy_db = models["policies"]["policy_1.yaml"] - @mock.patch.object( - policies, 'get_driver', - mock.MagicMock(return_value=None)) + @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None)) def test_disabled_policy_not_applied_on_pre_run(self): ########## # First test a scenario where policy is enabled @@ -91,7 +84,9 @@ def test_disabled_policy_not_applied_on_pre_run(self): # Post run hasn't been called yet, call count should be 0 self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_pre_run_policies(live_action_db) @@ -108,7 +103,9 @@ def test_disabled_policy_not_applied_on_pre_run(self): self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_pre_run_policies(live_action_db) @@ -133,15 +130,14 @@ def setUp(self): register_policy_types(st2common) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES_1) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_1 + ) # Policy with "post_run" application - self.policy_db = models['policies']['policy_4.yaml'] + self.policy_db = models["policies"]["policy_4.yaml"] - @mock.patch.object( - policies, 'get_driver', - mock.MagicMock(return_value=None)) + @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None)) def test_disabled_policy_not_applied_on_post_run(self): ########## # First test a scenario where policy is enabled @@ -151,7 +147,9 @@ def test_disabled_policy_not_applied_on_post_run(self): # Post run hasn't been called yet, call count should be 0 self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_post_run_policies(live_action_db) @@ -168,7 +166,9 @@ def test_disabled_policy_not_applied_on_post_run(self): self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_post_run_policies(live_action_db) diff --git a/st2actions/tests/unit/policies/test_concurrency.py b/st2actions/tests/unit/policies/test_concurrency.py index 670c38d839..f22a0303cd 100644 --- a/st2actions/tests/unit/policies/test_concurrency.py +++ b/st2actions/tests/unit/policies/test_concurrency.py @@ -42,40 +42,40 @@ from st2tests.mocks.runners import runner -__all__ = [ - 'ConcurrencyPolicyTestCase' -] +__all__ = ["ConcurrencyPolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_5.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_1.yaml", "policy_5.yaml"], } -NON_EMPTY_RESULT = 'non-empty' -MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None) +NON_EMPTY_RESULT = "non-empty" +MOCK_RUN_RETURN_VALUE = ( + action_constants.LIVEACTION_STATUS_RUNNING, + NON_EMPTY_RESULT, + None, +) SCHEDULED_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_SUCCEEDED + action_constants.LIVEACTION_STATUS_SUCCEEDED, ] -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ConcurrencyPolicyTestCase(EventletTestCase, ExecutionDbTestCase): @classmethod def setUpClass(cls): @@ -93,8 +93,7 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @classmethod def tearDownClass(cls): @@ -106,10 +105,15 @@ def tearDownClass(cls): # NOTE: This monkey patch needs to happen again here because during tests for some reason this # method gets unpatched (test doing reload() or similar) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def tearDown(self): for liveaction in LiveAction.get_all(): - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @staticmethod def _process_scheduling_queue(): @@ -117,64 +121,82 @@ def _process_scheduling_queue(): scheduling_queue.get_handler()._handle_execution(queued_req) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_delay_executions(self): # Ensure the concurrency policy is accurate. - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo-last'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo-last"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed async, wait for the liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Mark one of the scheduled/running execution as completed. action_service.update_status( - scheduled[0], - action_constants.LIVEACTION_STATUS_SUCCEEDED, - publish=True + scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True ) expected_num_pubs += 1 # Tally succeeded state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -185,52 +207,74 @@ def test_over_threshold_delay_executions(self): # Since states are being processed async, wait for the liveaction to be scheduled. liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Check the status changes. execution = ActionExecution.get(liveaction__id=str(liveaction.id)) - expected_status_changes = ['requested', 'delayed', 'requested', 'scheduled', 'running'] - actual_status_changes = [entry['status'] for entry in execution.log] + expected_status_changes = [ + "requested", + "delayed", + "requested", + "scheduled", + "running", + ] + actual_status_changes = [entry["status"] for entry in execution.log] self.assertListEqual(actual_status_changes, expected_status_changes) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_cancel_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.cancel') - self.assertEqual(policy_db.parameters['action'], 'cancel') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.cancel") + self.assertEqual(policy_db.parameters["action"], "cancel") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-2', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-2", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be canceled since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -240,67 +284,91 @@ def test_over_threshold_cancel_executions(self): LiveActionPublisher.publish_state.assert_has_calls(calls) expected_num_pubs += 2 # Tally canceling and canceled state changes. expected_num_exec += 0 # This request will not be scheduled for execution. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Assert the action is canceled. liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_on_cancellation(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed async, wait for the liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Cancel execution. - action_service.request_cancellation(scheduled[0], 'stanley') + action_service.request_cancellation(scheduled[0], "stanley") expected_num_pubs += 2 # Tally the canceling and canceled states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -312,5 +380,7 @@ def test_on_cancellation(self): # Execution is expected to be rescheduled. liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertIn(liveaction.status, SCHEDULED_STATES) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) diff --git a/st2actions/tests/unit/policies/test_concurrency_by_attr.py b/st2actions/tests/unit/policies/test_concurrency_by_attr.py index b576e3a669..98cfc3a4dc 100644 --- a/st2actions/tests/unit/policies/test_concurrency_by_attr.py +++ b/st2actions/tests/unit/policies/test_concurrency_by_attr.py @@ -39,42 +39,41 @@ from st2tests.mocks.runners import runner from six.moves import range -__all__ = [ - 'ConcurrencyByAttributePolicyTestCase' -] +__all__ = ["ConcurrencyByAttributePolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_3.yaml', - 'policy_7.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_3.yaml", "policy_7.yaml"], } -NON_EMPTY_RESULT = 'non-empty' -MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None) +NON_EMPTY_RESULT = "non-empty" +MOCK_RUN_RETURN_VALUE = ( + action_constants.LIVEACTION_STATUS_RUNNING, + NON_EMPTY_RESULT, + None, +) SCHEDULED_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_SUCCEEDED + action_constants.LIVEACTION_STATUS_SUCCEEDED, ] -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) -@mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ConcurrencyByAttributePolicyTestCase(EventletTestCase, ExecutionDbTestCase): - @classmethod def setUpClass(cls): EventletTestCase.setUpClass() @@ -91,8 +90,7 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @classmethod def tearDownClass(cls): @@ -104,10 +102,15 @@ def tearDownClass(cls): # NOTE: This monkey patch needs to happen again here because during tests for some reason this # method gets unpatched (test doing reload() or similar) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def tearDown(self): for liveaction in LiveAction.get_all(): - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @staticmethod def _process_scheduling_queue(): @@ -115,58 +118,80 @@ def _process_scheduling_queue(): scheduling_queue.get_handler()._handle_execution(queued_req) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_delay_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed asynchronously, wait for the # liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be scheduled since concurrency threshold is not reached. # The execution with actionstr "fu" is over the threshold but actionstr "bar" is not. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "bar"} + ) liveaction, _ = action_service.request(liveaction) # Run the scheduler to schedule action executions. @@ -177,18 +202,20 @@ def test_over_threshold_delay_executions(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # This request is expected to be executed. expected_num_pubs += 3 # Tally requested, scheduled, and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Mark one of the execution as completed. action_service.update_status( - scheduled[0], - action_constants.LIVEACTION_STATUS_SUCCEEDED, - publish=True + scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True ) expected_num_pubs += 1 # Tally succeeded state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -197,47 +224,65 @@ def test_over_threshold_delay_executions(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # The delayed request is expected to be executed. expected_num_pubs += 2 # Tally scheduled and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_cancel_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.attr.cancel') - self.assertEqual(policy_db.parameters['action'], 'cancel') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.attr.cancel") + self.assertEqual(policy_db.parameters["action"], "cancel") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -247,7 +292,9 @@ def test_over_threshold_cancel_executions(self): LiveActionPublisher.publish_state.assert_has_calls(calls) expected_num_pubs += 2 # Tally canceling and canceled state changes. expected_num_exec += 0 # This request will not be scheduled for execution. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Assert the action is canceled. @@ -255,58 +302,80 @@ def test_over_threshold_cancel_executions(self): self.assertEqual(canceled.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_on_cancellation(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed asynchronously, wait for the # liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) delayed = liveaction expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be scheduled since concurrency threshold is not reached. # The execution with actionstr "fu" is over the threshold but actionstr "bar" is not. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "bar"} + ) liveaction, _ = action_service.request(liveaction) # Run the scheduler to schedule action executions. @@ -317,13 +386,17 @@ def test_on_cancellation(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # This request is expected to be executed. expected_num_pubs += 3 # Tally requested, scheduled, and running states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Cancel execution. - action_service.request_cancellation(scheduled[0], 'stanley') + action_service.request_cancellation(scheduled[0], "stanley") expected_num_pubs += 2 # Tally the canceling and canceled states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -331,7 +404,9 @@ def test_on_cancellation(self): # Once capacity freed up, the delayed execution is published as requested again. expected_num_exec += 1 # The delayed request is expected to be executed. expected_num_pubs += 2 # Tally scheduled and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Since states are being processed asynchronously, wait for the diff --git a/st2actions/tests/unit/policies/test_retry_policy.py b/st2actions/tests/unit/policies/test_retry_policy.py index 6b6f0f0cc4..21371c6a02 100644 --- a/st2actions/tests/unit/policies/test_retry_policy.py +++ b/st2actions/tests/unit/policies/test_retry_policy.py @@ -35,19 +35,10 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RetryPolicyTestCase' -] +__all__ = ["RetryPolicyTestCase"] -PACK = 'generic' -TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ], - 'policies': [ - 'policy_4.yaml' - ] -} +PACK = "generic" +TEST_FIXTURES = {"actions": ["action1.yaml"], "policies": ["policy_4.yaml"]} class RetryPolicyTestCase(CleanDbTestCase): @@ -66,18 +57,21 @@ def setUp(self): register_policy_types(st2actions) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES + ) # Instantiate policy applicator we will use in the tests - policy_db = models['policies']['policy_4.yaml'] - retry_on = policy_db.parameters['retry_on'] - max_retry_count = policy_db.parameters['max_retry_count'] - self.policy = ExecutionRetryPolicyApplicator(policy_ref='test_policy', - policy_type='action.retry', - retry_on=retry_on, - max_retry_count=max_retry_count, - delay=0) + policy_db = models["policies"]["policy_4.yaml"] + retry_on = policy_db.parameters["retry_on"] + max_retry_count = policy_db.parameters["max_retry_count"] + self.policy = ExecutionRetryPolicyApplicator( + policy_ref="test_policy", + policy_type="action.retry", + retry_on=retry_on, + max_retry_count=max_retry_count, + delay=0, + ) def test_retry_on_timeout_no_retry_since_no_timeout_reached(self): # Verify initial state @@ -85,7 +79,9 @@ def test_retry_on_timeout_no_retry_since_no_timeout_reached(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which succeeds - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_SUCCEEDED @@ -110,7 +106,9 @@ def test_retry_on_timeout_first_retry_is_successful(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT @@ -130,14 +128,16 @@ def test_retry_on_timeout_first_retry_is_successful(self): self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[0].liveaction['id'] + original_liveaction_id = action_execution_dbs[0].liveaction["id"] context = action_execution_dbs[1].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 1) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 1) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) # Simulate success of second action so no it shouldn't be retried anymore live_action_db = live_action_dbs[1] @@ -161,7 +161,9 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT @@ -181,14 +183,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[0].liveaction['id'] + original_liveaction_id = action_execution_dbs[0].liveaction["id"] context = action_execution_dbs[1].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 1) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 1) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) # Simulate timeout of second action which should cause another retry live_action_db = live_action_dbs[1] @@ -212,14 +216,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertEqual(action_execution_dbs[2].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[1].liveaction['id'] + original_liveaction_id = action_execution_dbs[1].liveaction["id"] context = action_execution_dbs[2].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 2) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 2) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) def test_retry_on_timeout_max_retries_reached(self): # Verify initial state @@ -227,12 +233,14 @@ def test_retry_on_timeout_max_retries_reached(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT - live_action_db.context['policies'] = {} - live_action_db.context['policies']['retry'] = {'retry_count': 2} + live_action_db.context["policies"] = {} + live_action_db.context["policies"]["retry"] = {"retry_count": 2} execution_db.status = LIVEACTION_STATUS_TIMED_OUT LiveAction.add_or_update(live_action_db) ActionExecution.add_or_update(execution_db) @@ -248,8 +256,10 @@ def test_retry_on_timeout_max_retries_reached(self): self.assertEqual(action_execution_dbs[0].status, LIVEACTION_STATUS_TIMED_OUT) @mock.patch.object( - trace_service, 'get_trace_db_by_live_action', - mock.MagicMock(return_value=(None, None))) + trace_service, + "get_trace_db_by_live_action", + mock.MagicMock(return_value=(None, None)), + ) def test_no_retry_on_workflow_task(self): # Verify initial state self.assertSequenceEqual(LiveAction.get_all(), []) @@ -257,9 +267,9 @@ def test_no_retry_on_workflow_task(self): # Start a mock action which times out live_action_db = LiveActionDB( - action='wolfpack.action-1', - parameters={'actionstr': 'foo'}, - context={'parent': {'execution_id': 'abcde'}} + action="wolfpack.action-1", + parameters={"actionstr": "foo"}, + context={"parent": {"execution_id": "abcde"}}, ) live_action_db, execution_db = action_service.request(live_action_db) @@ -268,7 +278,7 @@ def test_no_retry_on_workflow_task(self): # Expire the workflow instance. live_action_db.status = LIVEACTION_STATUS_TIMED_OUT - live_action_db.context['policies'] = {} + live_action_db.context["policies"] = {} execution_db.status = LIVEACTION_STATUS_TIMED_OUT LiveAction.add_or_update(live_action_db) ActionExecution.add_or_update(execution_db) @@ -297,10 +307,12 @@ def test_no_retry_on_non_applicable_statuses(self): LIVEACTION_STATUS_CANCELED, ] - action_ref = 'wolfpack.action-1' + action_ref = "wolfpack.action-1" for status in non_retry_statuses: - liveaction = LiveActionDB(action=action_ref, parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action=action_ref, parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = status diff --git a/st2actions/tests/unit/test_action_runner_worker.py b/st2actions/tests/unit/test_action_runner_worker.py index 4f2494a431..1d0c7bbbd0 100644 --- a/st2actions/tests/unit/test_action_runner_worker.py +++ b/st2actions/tests/unit/test_action_runner_worker.py @@ -21,11 +21,10 @@ from st2common.models.db.liveaction import LiveActionDB from st2tests import config as test_config + test_config.parse_args() -__all__ = [ - 'ActionsQueueConsumerTestCase' -] +__all__ = ["ActionsQueueConsumerTestCase"] class ActionsQueueConsumerTestCase(TestCase): @@ -38,7 +37,9 @@ def test_process_right_dispatcher_is_used(self): consumer._workflows_dispatcher = Mock() consumer._actions_dispatcher = Mock() - body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=False) + body = LiveActionDB( + status="scheduled", action="core.local", action_is_workflow=False + ) message = Mock() consumer.process(body=body, message=message) @@ -49,7 +50,9 @@ def test_process_right_dispatcher_is_used(self): consumer._workflows_dispatcher = Mock() consumer._actions_dispatcher = Mock() - body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=True) + body = LiveActionDB( + status="scheduled", action="core.local", action_is_workflow=True + ) message = Mock() consumer.process(body=body, message=message) diff --git a/st2actions/tests/unit/test_actions_registrar.py b/st2actions/tests/unit/test_actions_registrar.py index c4d2771268..cc9da33299 100644 --- a/st2actions/tests/unit/test_actions_registrar.py +++ b/st2actions/tests/unit/test_actions_registrar.py @@ -31,18 +31,24 @@ import st2tests.fixturesloader as fixtures_loader from st2tests.fixturesloader import get_fixtures_base_path -MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name='run-local', runner_module='st2.runners.local') +MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name="run-local", runner_module="st2.runners.local") # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(get_fixtures_base_path(), 'generic'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(get_fixtures_base_path(), "generic")), +) class ActionsRegistrarTest(tests_base.DbTestCase): - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_all_actions(self): try: packs_base_path = fixtures_loader.get_fixtures_base_path() @@ -50,111 +56,157 @@ def test_register_all_actions(self): actions_registrar.register_actions(packs_base_paths=[packs_base_path]) except Exception as e: print(six.text_type(e)) - self.fail('All actions must be registered without exceptions.') + self.fail("All actions must be registered without exceptions.") else: all_actions_in_db = Action.get_all() self.assertTrue(len(all_actions_in_db) > 0) # Assert metadata_file field is populated - expected_path = 'actions/action-with-no-parameters.yaml' + expected_path = "actions/action-with-no-parameters.yaml" self.assertEqual(all_actions_in_db[0].metadata_file, expected_path) def test_register_actions_from_bad_pack(self): packs_base_path = tests_base.get_fixtures_path() try: actions_registrar.register_actions(packs_base_paths=[packs_base_path]) - self.fail('Should have thrown.') + self.fail("Should have thrown.") except: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_pack_name_missing(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_3_pack_missing.yaml') - registrar._register_action('dummy', action_file) + "generic", "actions", "action_3_pack_missing.yaml" + ) + registrar._register_action("dummy", action_file) action_name = None - with open(action_file, 'r') as fd: + with open(action_file, "r") as fd: content = yaml.safe_load(fd) - action_name = str(content['name']) + action_name = str(content["name"]) action_db = Action.get_by_name(action_name) - expected_msg = 'Content pack must be set to dummy' - self.assertEqual(action_db.pack, 'dummy', expected_msg) + expected_msg = "Content pack must be set to dummy" + self.assertEqual(action_db.pack, "dummy", expected_msg) Action.delete(action_db) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_with_no_params(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action-with-no-parameters.yaml') - - self.assertEqual(registrar._register_action('dummy', action_file), None) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action-with-no-parameters.yaml" + ) + + self.assertEqual(registrar._register_action("dummy", action_file), None) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_invalid_parameter_type_attribute(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_invalid_param_type.yaml') - - expected_msg = '\'list\' is not valid under any of the given schema' - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_action, - 'dummy', action_file) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action_invalid_param_type.yaml" + ) + + expected_msg = "'list' is not valid under any of the given schema" + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_action, + "dummy", + action_file, + ) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_invalid_parameter_name(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_invalid_parameter_name.yaml') - - expected_msg = ('Parameter name "action-name" is invalid. Valid characters for ' - 'parameter name are') - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_action, - 'generic', action_file) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action_invalid_parameter_name.yaml" + ) + + expected_msg = ( + 'Parameter name "action-name" is invalid. Valid characters for ' + "parameter name are" + ) + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_action, + "generic", + action_file, + ) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_invalid_params_schema(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action-invalid-schema-params.yaml') + "generic", "actions", "action-invalid-schema-params.yaml" + ) try: - registrar._register_action('generic', action_file) - self.fail('Invalid action schema. Should have failed.') + registrar._register_action("generic", action_file) + self.fail("Invalid action schema. Should have failed.") except jsonschema.ValidationError: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_action_update(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action1.yaml') - registrar._register_action('wolfpack', action_file) + "generic", "actions", "action1.yaml" + ) + registrar._register_action("wolfpack", action_file) # try registering again. this should not throw errors. - registrar._register_action('wolfpack', action_file) + registrar._register_action("wolfpack", action_file) action_name = None - with open(action_file, 'r') as fd: + with open(action_file, "r") as fd: content = yaml.safe_load(fd) - action_name = str(content['name']) + action_name = str(content["name"]) action_db = Action.get_by_name(action_name) - expected_msg = 'Content pack must be set to wolfpack' - self.assertEqual(action_db.pack, 'wolfpack', expected_msg) + expected_msg = "Content pack must be set to wolfpack" + self.assertEqual(action_db.pack, "wolfpack", expected_msg) Action.delete(action_db) diff --git a/st2actions/tests/unit/test_async_runner.py b/st2actions/tests/unit/test_async_runner.py index 0409202903..31258fae4e 100644 --- a/st2actions/tests/unit/test_async_runner.py +++ b/st2actions/tests/unit/test_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2actions/tests/unit/test_execution_cancellation.py b/st2actions/tests/unit/test_execution_cancellation.py index 6a130e2fe7..e6c51159ef 100644 --- a/st2actions/tests/unit/test_execution_cancellation.py +++ b/st2actions/tests/unit/test_execution_cancellation.py @@ -22,6 +22,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.constants import action as action_constants @@ -42,35 +43,32 @@ from st2tests.mocks.liveaction import MockLiveActionPublisherNonBlocking from st2tests.mocks.runners import runner -__all__ = [ - 'ExecutionCancellationTestCase' -] +__all__ = ["ExecutionCancellationTestCase"] -TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ] -} +TEST_FIXTURES = {"actions": ["action1.yaml"]} -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ExecutionCancellationTestCase(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(ExecutionCancellationTestCase, cls).setUpClass() - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) Action.add_or_update(ActionAPI.to_model(instance)) @@ -80,62 +78,84 @@ def tearDown(self): # Ensure all liveactions are canceled at end of each test. for liveaction in LiveAction.get_all(): action_service.update_status( - liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @classmethod def get_runner_class(cls, runner_name): return runners.get_runner(runner_name).__class__ @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) - @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), + ) + @mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) + ) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def test_basic_cancel(self): - runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None) + runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None) mock_runner_run = mock.Mock(return_value=runner_run_result) - with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_RUNNING + liveaction, action_constants.LIVEACTION_STATUS_RUNNING ) # Cancel execution. action_service.request_cancellation(liveaction, cfg.CONF.system_user.user) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_CANCELED + liveaction, action_constants.LIVEACTION_STATUS_CANCELED ) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create)) + CUDPublisher, + "publish_create", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create), + ) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(side_effect=Exception('Mock cancellation failure.'))) - @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + runners.ActionRunner, + "cancel", + mock.MagicMock(side_effect=Exception("Mock cancellation failure.")), + ) + @mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) + ) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def test_failed_cancel(self): - runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None) + runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None) mock_runner_run = mock.Mock(return_value=runner_run_result) - with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_RUNNING + liveaction, action_constants.LIVEACTION_STATUS_RUNNING ) # Cancel execution. @@ -144,22 +164,28 @@ def test_failed_cancel(self): # Cancellation failed and execution state remains "canceling". runners.ActionRunner.cancel.assert_called_once_with() liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING + ) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(return_value=None)) + runners.ActionRunner, "cancel", mock.MagicMock(return_value=None) + ) def test_noop_cancel(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Cancel execution. action_service.request_cancellation(liveaction, cfg.CONF.system_user.user) @@ -171,22 +197,28 @@ def test_noop_cancel(self): self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(return_value=None)) + runners.ActionRunner, "cancel", mock.MagicMock(return_value=None) + ) def test_cancel_delayed_execution(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Manually update the liveaction from requested to delayed to mock concurrency policy. - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED) @@ -200,27 +232,33 @@ def test_cancel_delayed_execution(self): self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - trace_service, 'get_trace_db_by_live_action', - mock.MagicMock(return_value=(None, None))) + trace_service, + "get_trace_db_by_live_action", + mock.MagicMock(return_value=(None, None)), + ) def test_cancel_delayed_execution_with_parent(self): liveaction = LiveActionDB( - action='wolfpack.action-1', - parameters={'actionstr': 'foo'}, - context={'parent': {'execution_id': uuid.uuid4().hex}} + action="wolfpack.action-1", + parameters={"actionstr": "foo"}, + context={"parent": {"execution_id": uuid.uuid4().hex}}, ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Manually update the liveaction from requested to delayed to mock concurrency policy. - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED) @@ -230,4 +268,6 @@ def test_cancel_delayed_execution_with_parent(self): # Cancel is only called when liveaction is still in running state. # Otherwise, the cancellation is only a state change. liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING + ) diff --git a/st2actions/tests/unit/test_executions.py b/st2actions/tests/unit/test_executions.py index f143631e42..64bde6b654 100644 --- a/st2actions/tests/unit/test_executions.py +++ b/st2actions/tests/unit/test_executions.py @@ -20,6 +20,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() import st2common.bootstrap.runnersregistrar as runners_registrar @@ -53,47 +54,57 @@ @mock.patch.object( - LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=(action_constants.LIVEACTION_STATUS_FAILED, 'Non-empty', None))) + LocalShellCommandRunner, + "run", + mock.MagicMock( + return_value=(action_constants.LIVEACTION_STATUS_FAILED, "Non-empty", None) + ), +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create)) + CUDPublisher, + "publish_create", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create), +) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), +) class TestActionExecutionHistoryWorker(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(TestActionExecutionHistoryWorker, cls).setUpClass() runners_registrar.register_runners() - action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['local'])) + action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["local"])) Action.add_or_update(ActionAPI.to_model(action_local)) - action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['chain'])) - action_chain.entry_point = fixture.PATH + '/chain.yaml' + action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"])) + action_chain.entry_point = fixture.PATH + "/chain.yaml" Action.add_or_update(ActionAPI.to_model(action_chain)) def tearDown(self): - MOCK_FAIL_EXECUTION_CREATE = False # noqa + MOCK_FAIL_EXECUTION_CREATE = False # noqa super(TestActionExecutionHistoryWorker, self).tearDown() def test_basic_execution(self): - liveaction = LiveActionDB(action='executions.local', parameters={'cmd': 'uname -a'}) + liveaction = LiveActionDB( + action="executions.local", parameters={"cmd": "uname -a"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) self.assertDictEqual(execution.trigger, {}) self.assertDictEqual(execution.trigger_type, {}) self.assertDictEqual(execution.trigger_instance, {}) self.assertDictEqual(execution.rule, {}) - action = action_utils.get_action_by_ref('executions.local') + action = action_utils.get_action_by_ref("executions.local") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -101,26 +112,27 @@ def test_basic_execution(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) def test_basic_execution_history_create_failed(self): - MOCK_FAIL_EXECUTION_CREATE = True # noqa + MOCK_FAIL_EXECUTION_CREATE = True # noqa self.test_basic_execution() def test_chained_executions(self): - liveaction = LiveActionDB(action='executions.chain') + liveaction = LiveActionDB(action="executions.chain") liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) - action = action_utils.get_action_by_ref('executions.chain') + action = action_utils.get_action_by_ref("executions.chain") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -128,56 +140,69 @@ def test_chained_executions(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) self.assertGreater(len(execution.children), 0) for child in execution.children: record = ActionExecution.get(id=child, raise_exception=True) self.assertEqual(record.parent, str(execution.id)) - self.assertEqual(record.action['name'], 'local') - self.assertEqual(record.runner['name'], 'local-shell-cmd') + self.assertEqual(record.action["name"], "local") + self.assertEqual(record.runner["name"], "local-shell-cmd") def test_triggered_execution(self): docs = { - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance'])} + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]), + } # Trigger an action execution. trigger_type = TriggerType.add_or_update( - TriggerTypeAPI.to_model(TriggerTypeAPI(**docs['trigger_type']))) - trigger = Trigger.add_or_update(TriggerAPI.to_model(TriggerAPI(**docs['trigger']))) - rule = RuleAPI.to_model(RuleAPI(**docs['rule'])) + TriggerTypeAPI.to_model(TriggerTypeAPI(**docs["trigger_type"])) + ) + trigger = Trigger.add_or_update( + TriggerAPI.to_model(TriggerAPI(**docs["trigger"])) + ) + rule = RuleAPI.to_model(RuleAPI(**docs["rule"])) rule.trigger = reference.get_str_resource_ref_from_model(trigger) rule = Rule.add_or_update(rule) trigger_instance = TriggerInstance.add_or_update( - TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs['trigger_instance']))) + TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs["trigger_instance"])) + ) trace_service.add_or_update_given_trace_context( - trace_context={'trace_tag': 'test_triggered_execution_trace'}, - trigger_instances=[str(trigger_instance.id)]) + trace_context={"trace_tag": "test_triggered_execution_trace"}, + trigger_instances=[str(trigger_instance.id)], + ) enforcer = RuleEnforcer(trigger_instance, rule) enforcer.enforce() # Wait for the action execution to complete and then confirm outcome. - liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id)) + liveaction = LiveAction.get( + context__trigger_instance__id=str(trigger_instance.id) + ) self.assertIsNotNone(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger))) - self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))) - self.assertDictEqual(execution.trigger_instance, - vars(TriggerInstanceAPI.from_model(trigger_instance))) + self.assertDictEqual( + execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)) + ) + self.assertDictEqual( + execution.trigger_instance, + vars(TriggerInstanceAPI.from_model(trigger_instance)), + ) self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule))) action = action_utils.get_action_by_ref(liveaction.action) self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -185,8 +210,8 @@ def test_triggered_execution(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) def _get_action_execution(self, **kwargs): return ActionExecution.get(**kwargs) diff --git a/st2actions/tests/unit/test_notifier.py b/st2actions/tests/unit/test_notifier.py index fa1af31ca8..b648d7fad3 100644 --- a/st2actions/tests/unit/test_notifier.py +++ b/st2actions/tests/unit/test_notifier.py @@ -20,6 +20,7 @@ import mock import st2tests.config as tests_config + tests_config.parse_args() from st2actions.notifier.notifier import Notifier @@ -41,77 +42,96 @@ from st2common.util import isotime from st2tests.base import CleanDbTestCase -ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0] -NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1] -MOCK_EXECUTION = ActionExecutionDB(id=bson.ObjectId(), result={'stdout': 'stuff happens'}) +ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0] +NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1] +MOCK_EXECUTION = ActionExecutionDB( + id=bson.ObjectId(), result={"stdout": "stuff happens"} +) class NotifierTestCase(CleanDbTestCase): - class MockDispatcher(object): def __init__(self, tester): self.tester = tester self.notify_trigger = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER_TYPE['pack'], - name=NOTIFY_TRIGGER_TYPE['name']) + pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"] + ) self.action_trigger = ResourceReference.to_string_reference( - pack=ACTION_TRIGGER_TYPE['pack'], - name=ACTION_TRIGGER_TYPE['name']) + pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"] + ) def dispatch(self, *args, **kwargs): try: self.tester.assertEqual(len(args), 1) - self.tester.assertTrue('payload' in kwargs) - payload = kwargs['payload'] + self.tester.assertTrue("payload" in kwargs) + payload = kwargs["payload"] if args[0] == self.notify_trigger: - self.tester.assertEqual(payload['status'], 'succeeded') - self.tester.assertTrue('execution_id' in payload) - self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id)) - self.tester.assertTrue('start_timestamp' in payload) - self.tester.assertTrue('end_timestamp' in payload) - self.tester.assertEqual('core.local', payload['action_ref']) - self.tester.assertEqual('Action succeeded.', payload['message']) - self.tester.assertTrue('data' in payload) - self.tester.assertTrue('local-shell-cmd', payload['runner_ref']) + self.tester.assertEqual(payload["status"], "succeeded") + self.tester.assertTrue("execution_id" in payload) + self.tester.assertEqual( + payload["execution_id"], str(MOCK_EXECUTION.id) + ) + self.tester.assertTrue("start_timestamp" in payload) + self.tester.assertTrue("end_timestamp" in payload) + self.tester.assertEqual("core.local", payload["action_ref"]) + self.tester.assertEqual("Action succeeded.", payload["message"]) + self.tester.assertTrue("data" in payload) + self.tester.assertTrue("local-shell-cmd", payload["runner_ref"]) if args[0] == self.action_trigger: - self.tester.assertEqual(payload['status'], 'succeeded') - self.tester.assertTrue('execution_id' in payload) - self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id)) - self.tester.assertTrue('start_timestamp' in payload) - self.tester.assertEqual('core.local', payload['action_name']) - self.tester.assertEqual('core.local', payload['action_ref']) - self.tester.assertTrue('result' in payload) - self.tester.assertTrue('parameters' in payload) - self.tester.assertTrue('local-shell-cmd', payload['runner_ref']) + self.tester.assertEqual(payload["status"], "succeeded") + self.tester.assertTrue("execution_id" in payload) + self.tester.assertEqual( + payload["execution_id"], str(MOCK_EXECUTION.id) + ) + self.tester.assertTrue("start_timestamp" in payload) + self.tester.assertEqual("core.local", payload["action_name"]) + self.tester.assertEqual("core.local", payload["action_ref"]) + self.tester.assertTrue("result" in payload) + self.tester.assertTrue("parameters" in payload) + self.tester.assertTrue("local-shell-cmd", payload["runner_ref"]) except Exception: - self.tester.fail('Test failed') - - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}, - parameters={}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) + self.tester.fail("Test failed") + + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", + name="local", + runner_type={"name": "local-shell-cmd"}, + parameters={}, + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) def test_notify_triggers(self): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' + liveaction_db.description = "" + liveaction_db.status = "succeeded" liveaction_db.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - liveaction_db.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + liveaction_db.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() - liveaction_db.end_timestamp = \ - (liveaction_db.start_timestamp + datetime.timedelta(seconds=50)) + liveaction_db.end_timestamp = ( + liveaction_db.start_timestamp + datetime.timedelta(seconds=50) + ) LiveAction.add_or_update(liveaction_db) execution = MOCK_EXECUTION @@ -122,26 +142,39 @@ def test_notify_triggers(self): notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher) notifier.process(execution) - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}, - parameters={}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", + name="local", + runner_type={"name": "local-shell-cmd"}, + parameters={}, + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) def test_notify_triggers_end_timestamp_none(self): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' + liveaction_db.description = "" + liveaction_db.status = "succeeded" liveaction_db.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - liveaction_db.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + liveaction_db.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() # This tests for end_timestamp being set to None, which can happen when a policy cancels @@ -159,30 +192,48 @@ def test_notify_triggers_end_timestamp_none(self): notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher) notifier.process(execution) - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={'runner_foo': 'foo'}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_post_generic_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", name="local", runner_type={"name": "local-shell-cmd"} + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock( + return_value=RunnerTypeDB( + name="foo", runner_parameters={"runner_foo": "foo"} + ) + ), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object( + Notifier, "_post_generic_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_notify_triggers_jinja_patterns(self, dispatch): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' - liveaction_db.parameters = {'cmd': 'mamma mia', 'runner_foo': 'foo'} - on_success = NotificationSubSchema(message='Command {{action_parameters.cmd}} succeeded.', - data={'stdout': '{{action_results.stdout}}'}) + liveaction_db.description = "" + liveaction_db.status = "succeeded" + liveaction_db.parameters = {"cmd": "mamma mia", "runner_foo": "foo"} + on_success = NotificationSubSchema( + message="Command {{action_parameters.cmd}} succeeded.", + data={"stdout": "{{action_results.stdout}}"}, + ) liveaction_db.notify = NotificationSchema(on_success=on_success) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() - liveaction_db.end_timestamp = \ - (liveaction_db.start_timestamp + datetime.timedelta(seconds=50)) + liveaction_db.end_timestamp = ( + liveaction_db.start_timestamp + datetime.timedelta(seconds=50) + ) LiveAction.add_or_update(liveaction_db) @@ -192,26 +243,31 @@ def test_notify_triggers_jinja_patterns(self, dispatch): notifier = Notifier(connection=None, queues=[]) notifier.process(execution) - exp = {'status': 'succeeded', - 'start_timestamp': isotime.format(liveaction_db.start_timestamp), - 'route': 'notify.default', 'runner_ref': 'local-shell-cmd', - 'channel': 'notify.default', 'message': u'Command mamma mia succeeded.', - 'data': {'result': '{}', 'stdout': 'stuff happens'}, - 'action_ref': u'core.local', - 'execution_id': str(MOCK_EXECUTION.id), - 'end_timestamp': isotime.format(liveaction_db.end_timestamp)} - dispatch.assert_called_once_with('core.st2.generic.notifytrigger', payload=exp, - trace_context={}) + exp = { + "status": "succeeded", + "start_timestamp": isotime.format(liveaction_db.start_timestamp), + "route": "notify.default", + "runner_ref": "local-shell-cmd", + "channel": "notify.default", + "message": "Command mamma mia succeeded.", + "data": {"result": "{}", "stdout": "stuff happens"}, + "action_ref": "core.local", + "execution_id": str(MOCK_EXECUTION.id), + "end_timestamp": isotime.format(liveaction_db.end_timestamp), + } + dispatch.assert_called_once_with( + "core.st2.generic.notifytrigger", payload=exp, trace_context={} + ) notifier.process(execution) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch): for status in LIVEACTION_STATUSES: - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -221,28 +277,34 @@ def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch): notifier._post_generic_trigger(liveaction_db, execution) if status in LIVEACTION_COMPLETED_STATES: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(dispatch.call_count, len(LIVEACTION_COMPLETED_STATES)) - @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock( - emit_when=['scheduled', 'pending', 'abandoned'])) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch( + "oslo_config.cfg.CONF.action_sensor", + mock.MagicMock(emit_when=["scheduled", "pending", "abandoned"]), + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_generic_trigger_with_emit_condition(self, dispatch): for status in LIVEACTION_STATUSES: - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -251,36 +313,45 @@ def test_post_generic_trigger_with_emit_condition(self, dispatch): notifier = Notifier(connection=None, queues=[]) notifier._post_generic_trigger(liveaction_db, execution) - if status in ['scheduled', 'pending', 'abandoned']: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + if status in ["scheduled", "pending", "abandoned"]: + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(dispatch.call_count, 3) - @mock.patch('oslo_config.cfg.CONF.action_sensor.enable', mock.MagicMock( - return_value=True)) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') - @mock.patch('st2actions.notifier.notifier.LiveAction') - @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock()) - def test_process_post_generic_notify_trigger_on_completed_state_default(self, - mock_LiveAction, mock_dispatch): + @mock.patch( + "oslo_config.cfg.CONF.action_sensor.enable", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") + @mock.patch("st2actions.notifier.notifier.LiveAction") + @mock.patch( + "st2actions.notifier.notifier.policy_service.apply_post_run_policies", + mock.Mock(), + ) + def test_process_post_generic_notify_trigger_on_completed_state_default( + self, mock_LiveAction, mock_dispatch + ): # Verify that generic action trigger is posted on all completed states when action sensor # is enabled for status in LIVEACTION_STATUSES: notifier = Notifier(connection=None, queues=[]) - liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local') + liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -292,35 +363,45 @@ def test_process_post_generic_notify_trigger_on_completed_state_default(self, notifier.process(execution) if status in LIVEACTION_COMPLETED_STATES: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - mock_dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + mock_dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(mock_dispatch.call_count, len(LIVEACTION_COMPLETED_STATES)) - @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock( - enable=True, emit_when=['scheduled', 'pending', 'abandoned'])) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') - @mock.patch('st2actions.notifier.notifier.LiveAction') - @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock()) - def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self, - mock_LiveAction, mock_dispatch): + @mock.patch( + "oslo_config.cfg.CONF.action_sensor", + mock.MagicMock(enable=True, emit_when=["scheduled", "pending", "abandoned"]), + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") + @mock.patch("st2actions.notifier.notifier.LiveAction") + @mock.patch( + "st2actions.notifier.notifier.policy_service.apply_post_run_policies", + mock.Mock(), + ) + def test_process_post_generic_notify_trigger_on_custom_emit_when_states( + self, mock_LiveAction, mock_dispatch + ): # Verify that generic action trigger is posted on all completed states when action sensor # is enabled for status in LIVEACTION_STATUSES: notifier = Notifier(connection=None, queues=[]) - liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local') + liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -331,15 +412,19 @@ def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self, notifier = Notifier(connection=None, queues=[]) notifier.process(execution) - if status in ['scheduled', 'pending', 'abandoned']: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - mock_dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + if status in ["scheduled", "pending", "abandoned"]: + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + mock_dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(mock_dispatch.call_count, 3) diff --git a/st2actions/tests/unit/test_parallel_ssh.py b/st2actions/tests/unit/test_parallel_ssh.py index bf8c1df87b..67052a53e0 100644 --- a/st2actions/tests/unit/test_parallel_ssh.py +++ b/st2actions/tests/unit/test_parallel_ssh.py @@ -17,13 +17,14 @@ import json import os -from mock import (patch, Mock, MagicMock) +from mock import patch, Mock, MagicMock import unittest2 from st2common.runners.parallel_ssh import ParallelSSHClient from st2common.runners.paramiko_ssh import ParamikoSSHClient from st2common.runners.paramiko_ssh import SSHCommandTimeoutError import st2tests.config as tests_config + tests_config.parse_args() MOCK_STDERR_SUDO_PASSWORD_ERROR = """ @@ -35,251 +36,294 @@ class ParallelSSHTests(unittest2.TestCase): - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_password(self): - hosts = ['localhost', '127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - password='ubuntu', - connect=False) + hosts = ["localhost", "127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", password="ubuntu", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'password': 'ubuntu', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "password": "ubuntu", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: - expected_conn['hostname'] = host - client._hosts_client[host].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = host + client._hosts_client[host].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_random_ports(self): - hosts = ['localhost:22', '127.0.0.1:55', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - password='ubuntu', - connect=False) + hosts = ["localhost:22", "127.0.0.1:55", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", password="ubuntu", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'password': 'ubuntu', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "password": "ubuntu", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: hostname, port = client._get_host_port_info(host) - expected_conn['hostname'] = hostname - expected_conn['port'] = port - client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = hostname + expected_conn["port"] = port + client._hosts_client[hostname].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_key(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=False) + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'key_filename': '~/.ssh/id_rsa', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "key_filename": "~/.ssh/id_rsa", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: hostname, port = client._get_host_port_info(host) - expected_conn['hostname'] = hostname - expected_conn['port'] = port - client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = hostname + expected_conn["port"] = port + client._hosts_client[hostname].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_bastion(self): - hosts = ['localhost', '127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - bastion_host='testing_bastion_host', - connect=False) + hosts = ["localhost", "127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, + user="ubuntu", + pkey_file="~/.ssh/id_rsa", + bastion_host="testing_bastion_host", + connect=False, + ) client.connect() for host in hosts: hostname, _ = client._get_host_port_info(host) - self.assertEqual(client._hosts_client[hostname].bastion_host, 'testing_bastion_host') + self.assertEqual( + client._hosts_client[hostname].bastion_host, "testing_bastion_host" + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock(return_value=('/home/ubuntu', '', 0))) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "run", MagicMock(return_value=("/home/ubuntu", "", 0)) + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_run_command(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.run('pwd', timeout=60) - expected_kwargs = { - 'timeout': 60, - 'call_line_handler_func': True - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.run("pwd", timeout=60) + expected_kwargs = {"timeout": 60, "call_line_handler_func": True} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].run.assert_called_with('pwd', **expected_kwargs) + client._hosts_client[hostname].run.assert_called_with( + "pwd", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_run_command_timeout(self): # Make sure stdout and stderr is included on timeout - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10, - stdout='a', - stderr='b', - ssh_connect_timeout=30)) + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + mock_run = Mock( + side_effect=SSHCommandTimeoutError( + cmd="pwd", timeout=10, stdout="a", stderr="b", ssh_connect_timeout=30 + ) + ) for host in hosts: hostname, _ = client._get_host_port_info(host) host_client = client._hosts_client[host] host_client.run = mock_run - results = client.run('pwd') + results = client.run("pwd") for host in hosts: result = results[host] - self.assertEqual(result['failed'], True) - self.assertEqual(result['stdout'], 'a') - self.assertEqual(result['stderr'], 'b') - self.assertEqual(result['return_code'], -9) + self.assertEqual(result["failed"], True) + self.assertEqual(result["stdout"], "a") + self.assertEqual(result["stderr"], "b") + self.assertEqual(result["return_code"], -9) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'put', MagicMock(return_value={})) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "put", MagicMock(return_value={})) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_put(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.put('/local/stuff', '/remote/stuff', mode=0o744) - expected_kwargs = { - 'mode': 0o744, - 'mirror_local_mode': False - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.put("/local/stuff", "/remote/stuff", mode=0o744) + expected_kwargs = {"mode": 0o744, "mirror_local_mode": False} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].put.assert_called_with('/local/stuff', '/remote/stuff', - **expected_kwargs) + client._hosts_client[hostname].put.assert_called_with( + "/local/stuff", "/remote/stuff", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'delete_file', MagicMock(return_value={})) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "delete_file", MagicMock(return_value={})) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_file(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.delete_file('/remote/stuff') + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.delete_file("/remote/stuff") for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].delete_file.assert_called_with('/remote/stuff') + client._hosts_client[hostname].delete_file.assert_called_with( + "/remote/stuff" + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'delete_dir', MagicMock(return_value={})) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "delete_dir", MagicMock(return_value={})) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_dir(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.delete_dir('/remote/stuff/', force=True) - expected_kwargs = { - 'force': True, - 'timeout': None - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.delete_dir("/remote/stuff/", force=True) + expected_kwargs = {"force": True, "timeout": None} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].delete_dir.assert_called_with('/remote/stuff/', - **expected_kwargs) + client._hosts_client[hostname].delete_dir.assert_called_with( + "/remote/stuff/", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_host_port_info(self): - client = ParallelSSHClient(hosts=['dummy'], - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) + client = ParallelSSHClient( + hosts=["dummy"], user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) # No port case. Port should be 22. - host_str = '1.2.3.4' + host_str = "1.2.3.4" host, port = client._get_host_port_info(host_str) self.assertEqual(host, host_str) self.assertEqual(port, 22) # IPv6 with square brackets with port specified. - host_str = '[fec2::10]:55' + host_str = "[fec2::10]:55" host, port = client._get_host_port_info(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, 55) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock( - return_value=(json.dumps({'foo': 'bar'}), '', 0)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "run", + MagicMock(return_value=(json.dumps({"foo": "bar"}), "", 0)), + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), ) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) def test_run_command_json_output_transformed_to_object(self): - hosts = ['127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - results = client.run('stuff', timeout=60) - self.assertIn('127.0.0.1', results) - self.assertDictEqual(results['127.0.0.1']['stdout'], {'foo': 'bar'}) + hosts = ["127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + results = client.run("stuff", timeout=60) + self.assertIn("127.0.0.1", results) + self.assertDictEqual(results["127.0.0.1"]["stdout"], {"foo": "bar"}) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock( - return_value=('', MOCK_STDERR_SUDO_PASSWORD_ERROR, 0)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "run", + MagicMock(return_value=("", MOCK_STDERR_SUDO_PASSWORD_ERROR, 0)), + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), ) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) def test_run_sudo_password_user_friendly_error(self): - hosts = ['127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True, - sudo_password=True) - results = client.run('stuff', timeout=60) + hosts = ["127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, + user="ubuntu", + pkey_file="~/.ssh/id_rsa", + connect=True, + sudo_password=True, + ) + results = client.run("stuff", timeout=60) - expected_error = ('Failed executing command "stuff" on host "127.0.0.1" ' - 'Invalid sudo password provided or sudo is not configured for ' - 'this user (bar)') + expected_error = ( + 'Failed executing command "stuff" on host "127.0.0.1" ' + "Invalid sudo password provided or sudo is not configured for " + "this user (bar)" + ) - self.assertIn('127.0.0.1', results) - self.assertEqual(results['127.0.0.1']['succeeded'], False) - self.assertEqual(results['127.0.0.1']['failed'], True) - self.assertIn(expected_error, results['127.0.0.1']['error']) + self.assertIn("127.0.0.1", results) + self.assertEqual(results["127.0.0.1"]["succeeded"], False) + self.assertEqual(results["127.0.0.1"]["failed"], True) + self.assertIn(expected_error, results["127.0.0.1"]["error"]) diff --git a/st2actions/tests/unit/test_paramiko_remote_script_runner.py b/st2actions/tests/unit/test_paramiko_remote_script_runner.py index 1246f1cbe2..1bf67a9503 100644 --- a/st2actions/tests/unit/test_paramiko_remote_script_runner.py +++ b/st2actions/tests/unit/test_paramiko_remote_script_runner.py @@ -21,6 +21,7 @@ # XXX: There is an import dependency. Config needs to setup # before importing remote_script_runner classes. import st2tests.config as tests_config + tests_config.parse_args() from st2common.util import jsonify @@ -35,234 +36,254 @@ from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'ParamikoScriptRunnerTestCase' -] +__all__ = ["ParamikoScriptRunnerTestCase"] -FIXTURES_PACK = 'generic' -TEST_MODELS = { - 'actions': ['a1.yaml'] -} +FIXTURES_PACK = "generic" +TEST_MODELS = {"actions": ["a1.yaml"]} -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] class ParamikoScriptRunnerTestCase(unittest2.TestCase): - @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock) - @patch.object(jsonify, 'json_loads', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={})) + @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock) + @patch.object(jsonify, "json_loads", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "run", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={})) def test_cwd_used_correctly(self): remote_action = ParamikoRemoteScriptAction( - 'foo-script', bson.ObjectId(), - script_local_path_abs='/home/stanley/shiz_storm.py', + "foo-script", + bson.ObjectId(), + script_local_path_abs="/home/stanley/shiz_storm.py", script_local_libs_path_abs=None, - named_args={}, positional_args=['blank space'], env_vars={}, - on_behalf_user='svetlana', user='stanley', - private_key='---SOME RSA KEY---', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args={}, + positional_args=["blank space"], + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + private_key="---SOME RSA KEY---", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", + ) + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") + paramiko_runner._parallel_ssh_client = ParallelSSHClient( + ["127.0.0.1"], "stanley" ) - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') - paramiko_runner._parallel_ssh_client = ParallelSSHClient(['127.0.0.1'], 'stanley') paramiko_runner._run_script_on_remote_host(remote_action) exp_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'" - ParallelSSHClient.run.assert_called_with(exp_cmd, - timeout=None) + ParallelSSHClient.run.assert_called_with(exp_cmd, timeout=None) def test_username_invalid_private_key(self): - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") paramiko_runner.runner_parameters = { - 'username': 'test_user', - 'hosts': '127.0.0.1', - 'private_key': 'invalid private key', + "username": "test_user", + "hosts": "127.0.0.1", + "private_key": "invalid private key", } paramiko_runner.context = {} self.assertRaises(NoHostsConnectedToException, paramiko_runner.pre_run) - @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock) - @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={})) + @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock) + @patch.object(ParallelSSHClient, "run", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={})) def test_top_level_error_is_correctly_reported(self): # Verify that a top-level error doesn't cause an exception to be thrown. # In a top-level error case, result dict doesn't contain entry per host - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") paramiko_runner.runner_parameters = { - 'username': 'test_user', - 'hosts': '127.0.0.1' + "username": "test_user", + "hosts": "127.0.0.1", } paramiko_runner.action = ACTION_1 - paramiko_runner.liveaction_id = 'foo' - paramiko_runner.entry_point = 'foo' + paramiko_runner.liveaction_id = "foo" + paramiko_runner.entry_point = "foo" paramiko_runner.context = {} - paramiko_runner._cwd = '/tmp' - paramiko_runner._copy_artifacts = Mock(side_effect=Exception('fail!')) + paramiko_runner._cwd = "/tmp" + paramiko_runner._copy_artifacts = Mock(side_effect=Exception("fail!")) status, result, _ = paramiko_runner.run(action_parameters={}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(result['failed'], True) - self.assertEqual(result['succeeded'], False) - self.assertIn('Failed copying content to remote boxes', result['error']) + self.assertEqual(result["failed"], True) + self.assertEqual(result["succeeded"], False) + self.assertIn("Failed copying content to remote boxes", result["error"]) def test_command_construction_correct_default_parameter_values_are_used(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, - }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} - action_db = ActionDB(pack='dummy', name='action') + action_db = ActionDB(pack="dummy", name="action") - runner = ParamikoRemoteScriptRunner('id') + runner = ParamikoRemoteScriptRunner("id") runner.runner_parameters = {} runner.action = action_db # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo' + expected = "cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo" self.assertEqual(command_string, expected) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob' + expected = "cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob" self.assertEqual(command_string, expected) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc' + expected = "cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc" self.assertEqual(command_string, expected) diff --git a/st2actions/tests/unit/test_paramiko_ssh.py b/st2actions/tests/unit/test_paramiko_ssh.py index 7335f11a7e..eadc4a477a 100644 --- a/st2actions/tests/unit/test_paramiko_ssh.py +++ b/st2actions/tests/unit/test_paramiko_ssh.py @@ -28,363 +28,456 @@ from st2common.runners.paramiko_ssh import ParamikoSSHClient from st2tests.fixturesloader import get_resources_base_path import st2tests.config as tests_config + tests_config.parse_args() -__all__ = [ - 'ParamikoSSHClientTestCase' -] +__all__ = ["ParamikoSSHClientTestCase"] class ParamikoSSHClientTestCase(unittest2.TestCase): - - @patch('paramiko.SSHClient', Mock) + @patch("paramiko.SSHClient", Mock) def setUp(self): """ Creates the object patching the actual connection. """ - cfg.CONF.set_override(name='ssh_key_file', override=None, group='system_user') - cfg.CONF.set_override(name='use_ssh_config', override=False, group='ssh_runner') - cfg.CONF.set_override(name='ssh_connect_timeout', override=30, group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org', - 'port': 8822, - 'username': 'ubuntu', - 'key_files': '~/.ssh/ubuntu_ssh', - 'timeout': 30} + cfg.CONF.set_override(name="ssh_key_file", override=None, group="system_user") + cfg.CONF.set_override(name="use_ssh_config", override=False, group="ssh_runner") + cfg.CONF.set_override( + name="ssh_connect_timeout", override=30, group="ssh_runner" + ) + + conn_params = { + "hostname": "dummy.host.org", + "port": 8822, + "username": "ubuntu", + "key_files": "~/.ssh/ubuntu_ssh", + "timeout": 30, + } self.ssh_cli = ParamikoSSHClient(**conn_params) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch('paramiko.ProxyCommand') + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch("paramiko.ProxyCommand") def test_set_proxycommand(self, mock_ProxyCommand): """ Loads proxy commands from ssh config file """ - ssh_config_file_path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_ssh_config') - cfg.CONF.set_override(name='ssh_config_file_path', - override=ssh_config_file_path, - group='ssh_runner') - cfg.CONF.set_override(name='use_ssh_config', override=True, - group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org', 'username': 'ubuntu', 'password': 'foo'} + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "dummy_ssh_config" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") + + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "foo", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - mock_ProxyCommand.assert_called_once_with('ssh -q -W dummy.host.org:22 dummy_bastion') - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch('paramiko.ProxyCommand') + mock_ProxyCommand.assert_called_once_with( + "ssh -q -W dummy.host.org:22 dummy_bastion" + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch("paramiko.ProxyCommand") def test_fail_set_proxycommand(self, mock_ProxyCommand): """ Loads proxy commands from ssh config file """ - ssh_config_file_path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_ssh_config_fail') - cfg.CONF.set_override(name='ssh_config_file_path', - override=ssh_config_file_path, - group='ssh_runner') - cfg.CONF.set_override(name='use_ssh_config', - override=True, group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org'} + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "dummy_ssh_config_fail" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") + + conn_params = {"hostname": "dummy.host.org"} mock = ParamikoSSHClient(**conn_params) self.assertRaises(Exception, mock.connect) mock_ProxyCommand.assert_not_called() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_password(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'password': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "password": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_deprecated_key_argument(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) def test_key_files_and_key_material_arguments_are_mutual_exclusive(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa', - 'key_material': 'key'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + "key_material": "key", + } - expected_msg = ('key_files and key_material arguments are mutually exclusive. ' - 'Supply only one.') + expected_msg = ( + "key_files and key_material arguments are mutually exclusive. " + "Supply only one." + ) client = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(ValueError, expected_msg, - client.connect) + self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_argument(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + } mock = ParamikoSSHClient(**conn_params) mock.connect() pkey = paramiko.RSAKey.from_private_key(StringIO(private_key)) - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'pkey': pkey, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "pkey": pkey, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_argument_invalid_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) - expected_msg = 'Invalid or unsupported key type' - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) + expected_msg = "Invalid or unsupported key type" + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_passphrase_no_key_provided(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "passphrase": "testphrase", + } - expected_msg = 'passphrase should accompany private key material' + expected_msg = "passphrase should accompany private key material" client = ParamikoSSHClient(**conn_params) self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) + @patch("paramiko.SSHClient", Mock) def test_passphrase_not_provided_for_encrypted_key_file(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': path} + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": path, + } mock = ParamikoSSHClient(**conn_params) - self.assertRaises(paramiko.ssh_exception.PasswordRequiredException, mock.connect) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + self.assertRaises( + paramiko.ssh_exception.PasswordRequiredException, mock.connect + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_key_with_passphrase_success(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() # Key material provided - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key, - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + "passphrase": "testphrase", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), 'testphrase') - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'pkey': pkey, - 'timeout': 30, - 'port': 22} + pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), "testphrase") + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "pkey": pkey, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) # Path to private key file provided - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': path, - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": path, + "passphrase": "testphrase", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': path, - 'password': 'testphrase', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": path, + "password": "testphrase", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_passphrase_and_no_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "passphrase": "testphrase", + } - expected_msg = 'passphrase should accompany private key material' + expected_msg = "passphrase should accompany private key material" client = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(ValueError, expected_msg, - client.connect) + self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_incorrect_passphrase(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key, - 'passphrase': 'incorrect'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + "passphrase": "incorrect", + } mock = ParamikoSSHClient(**conn_params) - expected_msg = 'Invalid passphrase or invalid/unsupported key type' - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + expected_msg = "Invalid passphrase or invalid/unsupported key type" + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_contains_path_not_contents(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} - key_materials = [ - '~/.ssh/id_rsa', - '/tmp/id_rsa', - 'C:\\id_rsa' - ] + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} + key_materials = ["~/.ssh/id_rsa", "/tmp/id_rsa", "C:\\id_rsa"] - expected_msg = ('"private_key" parameter needs to contain private key data / content and ' - 'not a path') + expected_msg = ( + '"private_key" parameter needs to contain private key data / content and ' + "not a path" + ) for key_material in key_materials: conn_params = conn_params.copy() - conn_params['key_material'] = key_material + conn_params["key_material"] = key_material mock = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_key_via_bastion(self): - conn_params = {'hostname': 'dummy.host.org', - 'bastion_host': 'bastion.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "bastion_host": "bastion.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_bastion_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'bastion.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_bastion_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "bastion.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.bastion_client.connect.assert_called_once_with(**expected_bastion_conn) - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22, - 'sock': mock.bastion_socket} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + "sock": mock.bastion_socket, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_password_and_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'password': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "password": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_without_credentials(self): """ Initialize object with no credentials. @@ -394,44 +487,54 @@ def test_create_without_credentials(self): the final parameters at the last moment when we explicitly try to connect, all the credentials should be set to None. """ - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) self.assertEqual(mock.password, None) self.assertEqual(mock.key_material, None) self.assertEqual(mock.key_files, None) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_without_credentials_use_default_key(self): # No credentials are provided by default stanley ssh key exists so it should use that - cfg.CONF.set_override(name='ssh_key_file', override='stanley_rsa', group='system_user') + cfg.CONF.set_override( + name="ssh_key_file", override="stanley_rsa", group="system_user" + ) - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'hostname': 'dummy.host.org', - 'key_filename': 'stanley_rsa', - 'allow_agent': False, - 'look_for_keys': False, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "hostname": "dummy.host.org", + "key_filename": "stanley_rsa", + "allow_agent": False, + "look_for_keys": False, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_consume_stdout', - MagicMock(return_value=StringIO(''))) - @patch.object(ParamikoSSHClient, '_consume_stderr', - MagicMock(return_value=StringIO(''))) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO("")) + ) + @patch.object( + ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO("")) + ) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_basic_usage_absolute_path(self): """ Basic execution. @@ -443,13 +546,15 @@ def test_basic_usage_absolute_path(self): # Connect behavior mock.connect() mock_cli = mock.client # The actual mocked object: SSHClient - expected_conn = {'username': 'ubuntu', - 'key_filename': '~/.ssh/ubuntu_ssh', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'timeout': 28, - 'port': 8822} + expected_conn = { + "username": "ubuntu", + "key_filename": "~/.ssh/ubuntu_ssh", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "timeout": 28, + "port": 8822, + } mock_cli.connect.assert_called_once_with(**expected_conn) mock.put(sd, sd, mirror_local_mode=False) @@ -458,21 +563,23 @@ def test_basic_usage_absolute_path(self): mock.run(sd) # Make assertions over 'run' method - mock_cli.get_transport().open_session().exec_command \ - .assert_called_once_with(sd) + mock_cli.get_transport().open_session().exec_command.assert_called_once_with(sd) mock.close() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_script(self): """ Provide a basic test with 'delete' action. """ mock = self.ssh_cli # script to execute - sd = '/root/random_script.sh' + sd = "/root/random_script.sh" mock.connect() @@ -482,91 +589,110 @@ def test_delete_script(self): mock.close() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch.object(ParamikoSSHClient, 'exists', return_value=False) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch.object(ParamikoSSHClient, "exists", return_value=False) def test_put_dir(self, *args): mock = self.ssh_cli mock.connect() - local_dir = os.path.join(get_resources_base_path(), 'packs') - mock.put_dir(local_path=local_dir, remote_path='/tmp') + local_dir = os.path.join(get_resources_base_path(), "packs") + mock.put_dir(local_path=local_dir, remote_path="/tmp") mock_cli = mock.client # The actual mocked object: SSHClient # Assert that expected dirs are created on remote box. - calls = [call('/tmp/packs/pythonactions'), call('/tmp/packs/pythonactions/actions')] + calls = [ + call("/tmp/packs/pythonactions"), + call("/tmp/packs/pythonactions/actions"), + ] mock_cli.open_sftp().mkdir.assert_has_calls(calls, any_order=True) # Assert that expected files are copied to remote box. - local_file = os.path.join(get_resources_base_path(), - 'packs/pythonactions/actions/pascal_row.py') - remote_file = os.path.join('/tmp', 'packs/pythonactions/actions/pascal_row.py') + local_file = os.path.join( + get_resources_base_path(), "packs/pythonactions/actions/pascal_row.py" + ) + remote_file = os.path.join("/tmp", "packs/pythonactions/actions/pascal_row.py") calls = [call(local_file, remote_file)] mock_cli.open_sftp().put.assert_has_calls(calls, any_order=True) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_consume_stdout(self): # Test utf-8 decoding of ``stdout`` still works fine when reading CHUNK_SIZE splits a # multi-byte utf-8 character in the middle. We should wait to collect all bytes # and finally decode. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.CHUNK_SIZE = 1 chan = Mock() chan.recv_ready.side_effect = [True, True, True, True, False] - chan.recv.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88'] + chan.recv.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"] try: - b'\xF0'.decode('utf-8') - self.fail('Test fixture is not right.') + b"\xF0".decode("utf-8") + self.fail("Test fixture is not right.") except UnicodeDecodeError: pass stdout = mock._consume_stdout(chan) - self.assertEqual(u'\U00010348', stdout.getvalue()) + self.assertEqual("\U00010348", stdout.getvalue()) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_consume_stderr(self): # Test utf-8 decoding of ``stderr`` still works fine when reading CHUNK_SIZE splits a # multi-byte utf-8 character in the middle. We should wait to collect all bytes # and finally decode. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.CHUNK_SIZE = 1 chan = Mock() chan.recv_stderr_ready.side_effect = [True, True, True, True, False] - chan.recv_stderr.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88'] + chan.recv_stderr.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"] try: - b'\xF0'.decode('utf-8') - self.fail('Test fixture is not right.') + b"\xF0".decode("utf-8") + self.fail("Test fixture is not right.") except UnicodeDecodeError: pass stderr = mock._consume_stderr(chan) - self.assertEqual(u'\U00010348', stderr.getvalue()) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_consume_stdout', - MagicMock(return_value=StringIO(''))) - @patch.object(ParamikoSSHClient, '_consume_stderr', - MagicMock(return_value=StringIO(''))) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + self.assertEqual("\U00010348", stderr.getvalue()) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO("")) + ) + @patch.object( + ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO("")) + ) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_sftp_connection_is_only_established_if_required(self): # Verify that SFTP connection is lazily established only if and when needed. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', 'password': 'ubuntu'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + } # Verify sftp connection and client hasn't been established yet client = ParamikoSSHClient(**conn_params) @@ -577,7 +703,7 @@ def test_sftp_connection_is_only_established_if_required(self): # run method doesn't require sftp access so it shouldn't establish connection client = ParamikoSSHClient(**conn_params) client.connect() - client.run(cmd='whoami') + client.run(cmd="whoami") self.assertIsNone(client.sftp_client) @@ -585,7 +711,7 @@ def test_sftp_connection_is_only_established_if_required(self): # put client = ParamikoSSHClient(**conn_params) client.connect() - path = '/root/random_script.sh' + path = "/root/random_script.sh" client.put(path, path, mirror_local_mode=False) self.assertIsNotNone(client.sftp_client) @@ -593,14 +719,14 @@ def test_sftp_connection_is_only_established_if_required(self): # exists client = ParamikoSSHClient(**conn_params) client.connect() - client.exists('/root/somepath.txt') + client.exists("/root/somepath.txt") self.assertIsNotNone(client.sftp_client) # mkdir client = ParamikoSSHClient(**conn_params) client.connect() - client.mkdir('/root/somedirfoo') + client.mkdir("/root/somedirfoo") self.assertIsNotNone(client.sftp_client) @@ -614,26 +740,26 @@ def test_sftp_connection_is_only_established_if_required(self): # Verify SFTP connection is closed if it's opened client = ParamikoSSHClient(**conn_params) client.connect() - client.mkdir('/root/somedirfoo') + client.mkdir("/root/somedirfoo") self.assertIsNotNone(client.sftp_client) client.close() self.assertEqual(client.sftp_client.close.call_count, 1) - @patch('paramiko.SSHClient', Mock) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) + @patch("paramiko.SSHClient", Mock) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) def test_handle_stdout_and_stderr_line_funcs(self): mock_handle_stdout_line_func = mock.Mock() mock_handle_stderr_line_func = mock.Mock() conn_params = { - 'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu', - 'handle_stdout_line_func': mock_handle_stdout_line_func, - 'handle_stderr_line_func': mock_handle_stderr_line_func + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + "handle_stdout_line_func": mock_handle_stdout_line_func, + "handle_stderr_line_func": mock_handle_stderr_line_func, } client = ParamikoSSHClient(**conn_params) client.connect() @@ -654,6 +780,7 @@ def mock_recv_ready(): return True return False + return mock_recv_ready def mock_recv_stderr_ready_factory(chan): @@ -665,12 +792,13 @@ def mock_recv_stderr_ready(): return True return False + return mock_recv_stderr_ready mock_chan.recv_ready = mock_recv_ready_factory(mock_chan) mock_chan.recv_stderr_ready = mock_recv_stderr_ready_factory(mock_chan) - mock_chan.recv.return_value = 'stdout 1\nstdout 2\nstdout 3' - mock_chan.recv_stderr.return_value = 'stderr 1\nstderr 2\nstderr 3' + mock_chan.recv.return_value = "stdout 1\nstdout 2\nstdout 3" + mock_chan.recv_stderr.return_value = "stderr 1\nstderr 2\nstderr 3" # call_line_handler_func is False so handler functions shouldn't be called client.run(cmd='echo "test"', call_line_handler_func=False) @@ -686,132 +814,176 @@ def mock_recv_stderr_ready(): client.run(cmd='echo "test"', call_line_handler_func=True) self.assertEqual(mock_handle_stdout_line_func.call_count, 3) - self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n') + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[0][1]["line"], "stdout 1\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[1][1]["line"], "stdout 2\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[2][1]["line"], "stdout 3\n" + ) self.assertEqual(mock_handle_stderr_line_func.call_count, 3) - self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n') - - @patch('paramiko.SSHClient') + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[0][1]["line"], "stdout 1\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[1][1]["line"], "stdout 2\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[2][1]["line"], "stdout 3\n" + ) + + @patch("paramiko.SSHClient") def test_use_ssh_config_port_value_provided_in_the_config(self, mock_sshclient): - cfg.CONF.set_override(name='use_ssh_config', override=True, group='ssh_runner') + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', 'empty_config') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "empty_config" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) # 1. Default port is used (not explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], 22) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': None, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": None, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], 22) # 2. Default port is used (explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': DEFAULT_SSH_PORT, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": DEFAULT_SSH_PORT, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], DEFAULT_SSH_PORT) - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], DEFAULT_SSH_PORT) + self.assertEqual(call_kwargs["port"], 22) # 3. Custom port is used (explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': 5555, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": 5555, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 5555) + self.assertEqual(call_kwargs["port"], 5555) # 4. Custom port is specified in the ssh config (it has precedence over default port) - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', - 'ssh_config_custom_port') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "ssh_config_custom_port" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 6677) + self.assertEqual(call_kwargs["port"], 6677) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': DEFAULT_SSH_PORT} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": DEFAULT_SSH_PORT, + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 6677) + self.assertEqual(call_kwargs["port"], 6677) # 5. Custom port is specified in ssh config, but one is also provided via runner parameter # (runner parameter one has precedence) - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', - 'ssh_config_custom_port') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "ssh_config_custom_port" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': 9999} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": 9999, + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 9999) + self.assertEqual(call_kwargs["port"], 9999) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_socket_closed(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) # Make sure .close() doesn't actually call anything real @@ -840,13 +1012,18 @@ def test_socket_closed(self): self.assertEqual(ssh_client.bastion_socket.close.call_count, 1) self.assertEqual(ssh_client.bastion_client.close.call_count, 1) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_socket_not_closed_if_none(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) # Make sure .close() doesn't actually call anything real diff --git a/st2actions/tests/unit/test_paramiko_ssh_runner.py b/st2actions/tests/unit/test_paramiko_ssh_runner.py index 8467264c5b..f42746d602 100644 --- a/st2actions/tests/unit/test_paramiko_ssh_runner.py +++ b/st2actions/tests/unit/test_paramiko_ssh_runner.py @@ -29,6 +29,7 @@ import st2tests.config as tests_config from st2tests.fixturesloader import get_resources_base_path + tests_config.parse_args() @@ -38,195 +39,192 @@ def run(self): class ParamikoSSHRunnerTestCase(unittest2.TestCase): - @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient') + @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient") def test_pre_run(self, mock_client): # Test case which verifies that ParamikoSSHClient is instantiated with the correct arguments - private_key_path = os.path.join(get_resources_base_path(), 'ssh', 'dummy_rsa') + private_key_path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa") - with open(private_key_path, 'r') as fp: + with open(private_key_path, "r") as fp: private_key = fp.read() # Username and password provided - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser1', - RUNNER_PASSWORD: 'somepassword' + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser1", + RUNNER_PASSWORD: "somepassword", } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser1', - 'password': 'somepassword', - 'port': None, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser1", + "password": "somepassword", + "port": None, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as raw key material - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser2', + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser2", RUNNER_PRIVATE_KEY: private_key, - RUNNER_SSH_PORT: 22 + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser2', - 'pkey_material': private_key, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser2", + "pkey_material": private_key, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as raw key material + passphrase - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost21', - RUNNER_USERNAME: 'someuser21', + RUNNER_HOSTS: "localhost21", + RUNNER_USERNAME: "someuser21", RUNNER_PRIVATE_KEY: private_key, - RUNNER_PASSPHRASE: 'passphrase21', - RUNNER_SSH_PORT: 22 + RUNNER_PASSPHRASE: "passphrase21", + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost21'], - 'user': 'someuser21', - 'pkey_material': private_key, - 'passphrase': 'passphrase21', - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost21"], + "user": "someuser21", + "pkey_material": private_key, + "passphrase": "passphrase21", + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as path to the private key file - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser3', + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser3", RUNNER_PRIVATE_KEY: private_key_path, - RUNNER_SSH_PORT: 22 + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser3', - 'pkey_file': private_key_path, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser3", + "pkey_file": private_key_path, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as path to the private key file + passphrase - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost31', - RUNNER_USERNAME: 'someuser31', + RUNNER_HOSTS: "localhost31", + RUNNER_USERNAME: "someuser31", RUNNER_PRIVATE_KEY: private_key_path, - RUNNER_PASSPHRASE: 'passphrase31', - RUNNER_SSH_PORT: 22 + RUNNER_PASSPHRASE: "passphrase31", + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost31'], - 'user': 'someuser31', - 'pkey_file': private_key_path, - 'passphrase': 'passphrase31', - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost31"], + "user": "someuser31", + "pkey_file": private_key_path, + "passphrase": "passphrase31", + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # No password or private key provided, should default to system user private key - runner = Runner('id') + runner = Runner("id") runner.context = {} - runner_parameters = { - RUNNER_HOSTS: 'localhost4', - RUNNER_SSH_PORT: 22 - } + runner_parameters = {RUNNER_HOSTS: "localhost4", RUNNER_SSH_PORT: 22} runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost4'], - 'user': None, - 'pkey_file': None, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost4"], + "user": None, + "pkey_file": None, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) - @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient') + @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient") def test_post_run(self, mock_client): # Verify that the SSH connections are closed on post_run - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser1', - RUNNER_PASSWORD: 'somepassword' + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser1", + RUNNER_PASSWORD: "somepassword", } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser1', - 'password': 'somepassword', - 'port': None, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser1", + "password": "somepassword", + "port": None, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) self.assertEqual(runner._parallel_ssh_client.close.call_count, 0) diff --git a/st2actions/tests/unit/test_policies.py b/st2actions/tests/unit/test_policies.py index 4be7af59e5..f16ffbcb0b 100644 --- a/st2actions/tests/unit/test_policies.py +++ b/st2actions/tests/unit/test_policies.py @@ -37,37 +37,34 @@ TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ], - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "actions": ["action1.yaml"], + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) -@mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), +) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) class SchedulingPolicyTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(SchedulingPolicyTest, cls).setUpClass() @@ -75,15 +72,15 @@ def setUpClass(cls): # Register runners runners_registrar.register_runners() - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) Action.add_or_update(ActionAPI.to_model(instance)) - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) - for _, fixture in six.iteritems(FIXTURES['policies']): + for _, fixture in six.iteritems(FIXTURES["policies"]): instance = PolicyAPI(**fixture) Policy.add_or_update(PolicyAPI.to_model(instance)) @@ -91,35 +88,54 @@ def tearDown(self): # Ensure all liveactions are canceled at end of each test. for liveaction in LiveAction.get_all(): action_service.update_status( - liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @mock.patch.object( - FakeConcurrencyApplicator, 'apply_before', + FakeConcurrencyApplicator, + "apply_before", mock.MagicMock( - side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before)) + side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before + ), + ) @mock.patch.object( - RaiseExceptionApplicator, 'apply_before', - mock.MagicMock( - side_effect=RaiseExceptionApplicator(None, None).apply_before)) + RaiseExceptionApplicator, + "apply_before", + mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_before), + ) @mock.patch.object( - FakeConcurrencyApplicator, 'apply_after', + FakeConcurrencyApplicator, + "apply_after", mock.MagicMock( - side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after)) + side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after + ), + ) @mock.patch.object( - RaiseExceptionApplicator, 'apply_after', - mock.MagicMock( - side_effect=RaiseExceptionApplicator(None, None).apply_after)) + RaiseExceptionApplicator, + "apply_after", + mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_after), + ) def test_apply(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) FakeConcurrencyApplicator.apply_before.assert_called_once_with(liveaction) RaiseExceptionApplicator.apply_before.assert_called_once_with(liveaction) FakeConcurrencyApplicator.apply_after.assert_called_once_with(liveaction) RaiseExceptionApplicator.apply_after.assert_called_once_with(liveaction) - @mock.patch.object(FakeConcurrencyApplicator, 'get_threshold', mock.MagicMock(return_value=0)) + @mock.patch.object( + FakeConcurrencyApplicator, "get_threshold", mock.MagicMock(return_value=0) + ) def test_enforce(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) diff --git a/st2actions/tests/unit/test_polling_async_runner.py b/st2actions/tests/unit/test_polling_async_runner.py index 435f7eb9b6..c48bb9aa67 100644 --- a/st2actions/tests/unit/test_polling_async_runner.py +++ b/st2actions/tests/unit/test_polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2actions/tests/unit/test_queue_consumers.py b/st2actions/tests/unit/test_queue_consumers.py index 1550a82e22..80d3a09c26 100644 --- a/st2actions/tests/unit/test_queue_consumers.py +++ b/st2actions/tests/unit/test_queue_consumers.py @@ -18,6 +18,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() import mock @@ -39,16 +40,13 @@ from st2tests.base import ExecutionDbTestCase -PACKS = [ - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core"] -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -@mock.patch.object(executions, 'update_execution', mock.MagicMock()) -@mock.patch.object(Message, 'ack', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +@mock.patch.object(executions, "update_execution", mock.MagicMock()) +@mock.patch.object(Message, "ack", mock.MagicMock()) class QueueConsumerTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(QueueConsumerTest, cls).setUpClass() @@ -58,8 +56,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -71,14 +68,16 @@ def __init__(self, *args, **kwargs): self.scheduling_queue = scheduling_queue.get_handler() self.dispatcher = worker.get_worker() - def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED): - action_db = action_utils.get_action_by_ref('core.noop') + def _create_liveaction_db( + self, status=action_constants.LIVEACTION_STATUS_REQUESTED + ): + action_db = action_utils.get_action_by_ref("core.noop") liveaction_db = LiveActionDB( action=action_db.ref, parameters=None, start_timestamp=date_utils.get_datetime_utc_now(), - status=status + status=status, ) liveaction_db = action.LiveAction.add_or_update(liveaction_db, publish=False) @@ -91,15 +90,16 @@ def _process_request(self, liveaction_db): queued_request = self.scheduling_queue._get_next_execution() self.scheduling_queue._handle_execution(queued_request) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value={'key': 'value'})) + @mock.patch.object( + RunnerContainer, "dispatch", mock.MagicMock(return_value={"key": "value"}) + ) def test_execute(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.assertDictEqual(scheduled_liveaction_db.runner_info, {}) @@ -107,54 +107,56 @@ def test_execute(self): dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) self.assertGreater(len(list(dispatched_liveaction_db.runner_info.keys())), 0) self.assertEqual( - dispatched_liveaction_db.status, - action_constants.LIVEACTION_STATUS_RUNNING + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_RUNNING ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(side_effect=Exception('Boom!'))) + @mock.patch.object( + RunnerContainer, "dispatch", mock.MagicMock(side_effect=Exception("Boom!")) + ) def test_execute_failure(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db) dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) - self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None)) + @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None)) def test_execute_no_result(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db) dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) - self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None)) + @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None)) def test_execute_cancelation(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_CANCELED, - liveaction_id=liveaction_db.id + liveaction_id=liveaction_db.id, ) canceled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) @@ -162,11 +164,10 @@ def test_execute_cancelation(self): dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) self.assertEqual( - dispatched_liveaction_db.status, - action_constants.LIVEACTION_STATUS_CANCELED + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_CANCELED ) self.assertDictEqual( dispatched_liveaction_db.result, - {'message': 'Action execution canceled by user.'} + {"message": "Action execution canceled by user."}, ) diff --git a/st2actions/tests/unit/test_remote_runners.py b/st2actions/tests/unit/test_remote_runners.py index 26d75cb5dc..7f84165dbb 100644 --- a/st2actions/tests/unit/test_remote_runners.py +++ b/st2actions/tests/unit/test_remote_runners.py @@ -16,6 +16,7 @@ # XXX: FabricRunner import depends on config being setup. from __future__ import absolute_import import st2tests.config as tests_config + tests_config.parse_args() from unittest2 import TestCase @@ -26,12 +27,20 @@ class RemoteScriptActionTestCase(TestCase): def test_parameter_formatting(self): # Only named args - named_args = {'--foo1': 'bar1', '--foo2': 'bar2', '--foo3': True, - '--foo4': False} + named_args = { + "--foo1": "bar1", + "--foo2": "bar2", + "--foo3": True, + "--foo4": False, + } - action = RemoteScriptAction(name='foo', action_exec_id='dummy', - script_local_path_abs='test.py', - script_local_libs_path_abs='/', - remote_dir='/tmp', - named_args=named_args, positional_args=None) - self.assertEqual(action.command, '/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3') + action = RemoteScriptAction( + name="foo", + action_exec_id="dummy", + script_local_path_abs="test.py", + script_local_libs_path_abs="/", + remote_dir="/tmp", + named_args=named_args, + positional_args=None, + ) + self.assertEqual(action.command, "/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3") diff --git a/st2actions/tests/unit/test_runner_container.py b/st2actions/tests/unit/test_runner_container.py index 3ccfb7a4ea..f17eeceb71 100644 --- a/st2actions/tests/unit/test_runner_container.py +++ b/st2actions/tests/unit/test_runner_container.py @@ -21,7 +21,10 @@ from st2common.constants import action as action_constants from st2common.runners.base import get_runner -from st2common.exceptions.actionrunner import ActionRunnerCreateError, ActionRunnerDispatchError +from st2common.exceptions.actionrunner import ( + ActionRunnerCreateError, + ActionRunnerDispatchError, +) from st2common.models.system.common import ResourceReference from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.runner import RunnerTypeDB @@ -34,6 +37,7 @@ from st2tests.base import DbTestCase import st2tests.config as tests_config + tests_config.parse_args() from st2tests.fixturesloader import FixturesLoader @@ -44,39 +48,43 @@ from st2actions.container.base import get_runner_container TEST_FIXTURES = { - 'runners': [ - 'run-local.yaml', - 'testrunner1.yaml', - 'testfailingrunner1.yaml', - 'testasyncrunner1.yaml', - 'testasyncrunner2.yaml' + "runners": [ + "run-local.yaml", + "testrunner1.yaml", + "testfailingrunner1.yaml", + "testasyncrunner1.yaml", + "testasyncrunner2.yaml", + ], + "actions": [ + "local.yaml", + "action1.yaml", + "async_action1.yaml", + "async_action2.yaml", + "action-invalid-runner.yaml", ], - 'actions': [ - 'local.yaml', - 'action1.yaml', - 'async_action1.yaml', - 'async_action2.yaml', - 'action-invalid-runner.yaml' - ] } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" NON_UTF8_RESULT = { - 'stderr': '', - 'stdout': '\x82\n', - 'succeeded': True, - 'failed': False, - 'return_code': 0 + "stderr": "", + "stdout": "\x82\n", + "succeeded": True, + "failed": False, + "return_code": 0, } from st2tests.mocks.runners import runner from st2tests.mocks.runners import polling_async_runner -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RunnerContainerTest(DbTestCase): action_db = None async_action_db = None @@ -88,30 +96,38 @@ class RunnerContainerTest(DbTestCase): def setUpClass(cls): super(RunnerContainerTest, cls).setUpClass() - cfg.CONF.set_override(name='validate_output_schema', override=False, group='system') + cfg.CONF.set_override( + name="validate_output_schema", override=False, group="system" + ) models = RunnerContainerTest.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RunnerContainerTest.runnertype_db = models['runners']['testrunner1.yaml'] - RunnerContainerTest.action_db = models['actions']['action1.yaml'] - RunnerContainerTest.local_action_db = models['actions']['local.yaml'] - RunnerContainerTest.async_action_db = models['actions']['async_action1.yaml'] - RunnerContainerTest.polling_async_action_db = models['actions']['async_action2.yaml'] - RunnerContainerTest.failingaction_db = models['actions']['action-invalid-runner.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RunnerContainerTest.runnertype_db = models["runners"]["testrunner1.yaml"] + RunnerContainerTest.action_db = models["actions"]["action1.yaml"] + RunnerContainerTest.local_action_db = models["actions"]["local.yaml"] + RunnerContainerTest.async_action_db = models["actions"]["async_action1.yaml"] + RunnerContainerTest.polling_async_action_db = models["actions"][ + "async_action2.yaml" + ] + RunnerContainerTest.failingaction_db = models["actions"][ + "action-invalid-runner.yaml" + ] @classmethod def tearDownClass(cls): RunnerContainerTest.fixtures_loader.delete_fixtures_from_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) super(RunnerContainerTest, cls).tearDownClass() def test_get_runner_module(self): - runner = get_runner(name='local-shell-script') - self.assertIsNotNone(runner, 'TestRunner must be valid.') + runner = get_runner(name="local-shell-script") + self.assertIsNotNone(runner, "TestRunner must be valid.") def test_pre_run_runner_is_disabled(self): runnertype_db = RunnerContainerTest.runnertype_db - runner = get_runner(name='local-shell-cmd') + runner = get_runner(name="local-shell-cmd") runner.runner_type = runnertype_db runner.runner_type.enabled = False @@ -119,10 +135,12 @@ def test_pre_run_runner_is_disabled(self): expected_msg = 'Runner "test-runner-1" has been disabled by the administrator' self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action(self): + def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action( + self, + ): params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } global global_runner @@ -141,15 +159,17 @@ def mock_get_runner(*args, **kwargs): liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) - liveaction_db.context = {'user': 'user_joe_1'} + liveaction_db.context = {"user": "user_joe_1"} executions.create_execution_object(liveaction_db) runner_container._get_runner = mock_get_runner - self.assertEqual(getattr(global_runner, 'auth_token', None), None) + self.assertEqual(getattr(global_runner, "auth_token", None), None) runner_container.dispatch(liveaction_db) - self.assertEqual(global_runner.auth_token.user, 'user_joe_1') - self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container') + self.assertEqual(global_runner.auth_token.user, "user_joe_1") + self.assertEqual( + global_runner.auth_token.metadata["service"], "actions_container" + ) runner_container._get_runner = original_get_runner @@ -160,23 +180,25 @@ def mock_get_runner(*args, **kwargs): liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) - liveaction_db.context = {'user': 'user_mark_2'} + liveaction_db.context = {"user": "user_mark_2"} executions.create_execution_object(liveaction_db) original_get_runner = runner_container._get_runner runner_container._get_runner = mock_get_runner - self.assertEqual(getattr(global_runner, 'auth_token', None), None) + self.assertEqual(getattr(global_runner, "auth_token", None), None) runner_container.dispatch(liveaction_db) - self.assertEqual(global_runner.auth_token.user, 'user_mark_2') - self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container') + self.assertEqual(global_runner.auth_token.user, "user_mark_2") + self.assertEqual( + global_runner.auth_token.metadata["service"], "actions_container" + ) def test_post_run_is_always_called_after_run(self): # 1. post_run should be called on success, failure, etc. runner_container = get_runner_container() params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -191,6 +213,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -200,10 +223,7 @@ def mock_get_runner(*args, **kwargs): # 2. Verify post_run is called if run() throws runner_container = get_runner_container() - params = { - 'actionstr': 'bar', - 'raise': True - } + params = {"actionstr": "bar", "raise": True} liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) @@ -216,6 +236,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -225,10 +246,10 @@ def mock_get_runner(*args, **kwargs): # 2. Verify post_run is also called if _delete_auth_token throws runner_container = get_runner_container() - runner_container._delete_auth_token = mock.Mock(side_effect=ValueError('throw')) + runner_container._delete_auth_token = mock.Mock(side_effect=ValueError("throw")) params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -242,6 +263,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -250,43 +272,42 @@ def mock_get_runner(*args, **kwargs): self.assertTrue(global_runner.post_run_called) def test_get_runner_module_fail(self): - runnertype_db = RunnerTypeDB(name='dummy', runner_module='absent.module') + runnertype_db = RunnerTypeDB(name="dummy", runner_module="absent.module") runner = None try: - runner = get_runner(runnertype_db.runner_module, runnertype_db.runner_module) + runner = get_runner( + runnertype_db.runner_module, runnertype_db.runner_module + ) except ActionRunnerCreateError: pass - self.assertFalse(runner, 'TestRunner must be valid.') + self.assertFalse(runner, "TestRunner must be valid.") def test_dispatch(self): runner_container = get_runner_container() - params = { - 'actionstr': 'bar' - } - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "bar"} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran successfully. runner_container.dispatch(liveaction_db) liveaction_db = LiveAction.get_by_id(liveaction_db.id) result = liveaction_db.result - self.assertTrue(result.get('action_params').get('actionint') == 10) - self.assertTrue(result.get('action_params').get('actionstr') == 'bar') + self.assertTrue(result.get("action_params").get("actionint") == 10) + self.assertTrue(result.get("action_params").get("actionstr") == "bar") # Assert that context is written correctly. - context = { - 'user': 'stanley', - 'third_party_system': { - 'ref_id': '1234' - } - } + context = {"user": "stanley", "third_party_system": {"ref_id": "1234"}} self.assertDictEqual(liveaction_db.context, context) def test_dispatch_unsupported_status(self): runner_container = get_runner_container() - params = {'actionstr': 'bar'} - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "bar"} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) @@ -295,86 +316,74 @@ def test_dispatch_unsupported_status(self): # Assert exception is raised on dispatch. self.assertRaises( - ActionRunnerDispatchError, - runner_container.dispatch, - liveaction_db + ActionRunnerDispatchError, runner_container.dispatch, liveaction_db ) def test_dispatch_runner_failure(self): runner_container = get_runner_container() - params = { - 'actionstr': 'bar' - } + params = {"actionstr": "bar"} liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) runner_container.dispatch(liveaction_db) # pickup updated liveaction_db liveaction_db = LiveAction.get_by_id(liveaction_db.id) - self.assertIn('error', liveaction_db.result) - self.assertIn('traceback', liveaction_db.result) + self.assertIn("error", liveaction_db.result) + self.assertIn("traceback", liveaction_db.result) def test_dispatch_override_default_action_params(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20 - } - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "foo", "actionint": 20} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran successfully. runner_container.dispatch(liveaction_db) liveaction_db = LiveAction.get_by_id(liveaction_db.id) result = liveaction_db.result - self.assertTrue(result.get('action_params').get('actionint') == 20) - self.assertTrue(result.get('action_params').get('actionstr') == 'foo') + self.assertTrue(result.get("action_params").get("actionint") == 20) + self.assertTrue(result.get("action_params").get("actionstr") == "foo") def test_state_db_created_for_polling_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.polling_async_action_db, - params + RunnerContainerTest.polling_async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran without exceptions. - with mock.patch('st2actions.container.base.get_runner', - mock.Mock(return_value=polling_async_runner.get_runner())): + with mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=polling_async_runner.get_runner()), + ): runner_container.dispatch(liveaction_db) states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) > 0, 'There should be a state db object.') - self.assertTrue(len(found) == 1, 'There should only be one state db object.') + self.assertTrue(len(found) > 0, "There should be a state db object.") + self.assertTrue(len(found) == 1, "There should only be one state db object.") self.assertIsNotNone(found[0].query_context) self.assertIsNotNone(found[0].query_module) @mock.patch.object( PollingAsyncActionRunner, - 'is_polling_enabled', - mock.MagicMock(return_value=False)) + "is_polling_enabled", + mock.MagicMock(return_value=False), + ) def test_state_db_not_created_for_disabled_polling_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.polling_async_action_db, - params + RunnerContainerTest.polling_async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -385,20 +394,15 @@ def test_state_db_not_created_for_disabled_polling_async_actions(self): states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) == 0, 'There should not be a state db object.') + self.assertTrue(len(found) == 0, "There should not be a state db object.") def test_state_db_not_created_for_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.async_action_db, - params + RunnerContainerTest.async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -409,17 +413,21 @@ def test_state_db_not_created_for_async_actions(self): states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) == 0, 'There should not be a state db object.') + self.assertTrue(len(found) == 0, "There should not be a state db object.") def _get_liveaction_model(self, action_db, params): status = action_constants.LIVEACTION_STATUS_REQUESTED start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db def _get_failingaction_exec_db_model(self, params): @@ -427,12 +435,17 @@ def _get_failingaction_exec_db_model(self, params): start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference( name=RunnerContainerTest.failingaction_db.name, - pack=RunnerContainerTest.failingaction_db.pack).ref + pack=RunnerContainerTest.failingaction_db.pack, + ).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db def _get_output_schema_exec_db_model(self, params): @@ -440,10 +453,15 @@ def _get_output_schema_exec_db_model(self, params): start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference( name=RunnerContainerTest.schema_output_action_db.name, - pack=RunnerContainerTest.schema_output_action_db.pack).ref + pack=RunnerContainerTest.schema_output_action_db.pack, + ).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db diff --git a/st2actions/tests/unit/test_scheduler.py b/st2actions/tests/unit/test_scheduler.py index c23568eea1..1a7d4b9beb 100644 --- a/st2actions/tests/unit/test_scheduler.py +++ b/st2actions/tests/unit/test_scheduler.py @@ -20,6 +20,7 @@ import eventlet from st2tests import config as test_config + test_config.parse_args() import st2common @@ -45,31 +46,28 @@ LIVE_ACTION = { - 'parameters': { - 'cmd': 'echo ":dat_face:"', + "parameters": { + "cmd": 'echo ":dat_face:"', }, - 'action': 'core.local', - 'status': 'requested' + "action": "core.local", + "status": "requested", } -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_3.yaml', - 'policy_7.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_3.yaml", "policy_7.yaml"], } @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), +) class ActionExecutionSchedulingQueueItemDBTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): ExecutionDbTestCase.setUpClass() @@ -81,18 +79,21 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def setUp(self): super(ActionExecutionSchedulingQueueItemDBTest, self).setUp() self.scheduler = scheduling.get_scheduler_entrypoint() self.scheduling_queue = scheduling_queue.get_handler() - def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED): - action_ref = 'wolfpack.action-1' - parameters = {'actionstr': 'fu'} - liveaction_db = LiveActionDB(action=action_ref, parameters=parameters, status=status) + def _create_liveaction_db( + self, status=action_constants.LIVEACTION_STATUS_REQUESTED + ): + action_ref = "wolfpack.action-1" + parameters = {"actionstr": "fu"} + liveaction_db = LiveActionDB( + action=action_ref, parameters=parameters, status=status + ) liveaction_db = LiveAction.add_or_update(liveaction_db) execution_service.create_execution_object(liveaction_db, publish=False) @@ -108,7 +109,9 @@ def test_create_from_liveaction(self): delay, ) - delay_date = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay) + delay_date = date.append_milliseconds_to_time( + liveaction_db.start_timestamp, delay + ) self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB) self.assertEqual(schedule_q_db.scheduled_start_timestamp, delay_date) @@ -125,12 +128,14 @@ def test_next_execution(self): for delay in delays: liveaction_db = self._create_liveaction_db() - delayed_start = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay) + delayed_start = date.append_milliseconds_to_time( + liveaction_db.start_timestamp, delay + ) test_case = { - 'liveaction': liveaction_db, - 'delay': delay, - 'delayed_start': delayed_start + "liveaction": liveaction_db, + "delay": delay, + "delayed_start": delayed_start, } test_cases.append(test_case) @@ -139,8 +144,8 @@ def test_next_execution(self): schedule_q_dbs.append( ActionExecutionSchedulingQueue.add_or_update( self.scheduler._create_execution_queue_item_db_from_liveaction( - test_case['liveaction'], - test_case['delay'], + test_case["liveaction"], + test_case["delay"], ) ) ) @@ -152,22 +157,24 @@ def test_next_execution(self): test_case = test_cases[index] date_mock = mock.MagicMock() - date_mock.get_datetime_utc_now.return_value = test_case['delayed_start'] + date_mock.get_datetime_utc_now.return_value = test_case["delayed_start"] date_mock.append_milliseconds_to_time = date.append_milliseconds_to_time - with mock.patch('st2actions.scheduler.handler.date', date_mock): + with mock.patch("st2actions.scheduler.handler.date", date_mock): schedule_q_db = self.scheduling_queue._get_next_execution() ActionExecutionSchedulingQueue.delete(schedule_q_db) self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB) - self.assertEqual(schedule_q_db.delay, test_case['delay']) - self.assertEqual(schedule_q_db.liveaction_id, str(test_case['liveaction'].id)) + self.assertEqual(schedule_q_db.delay, test_case["delay"]) + self.assertEqual( + schedule_q_db.liveaction_id, str(test_case["liveaction"].id) + ) # NOTE: We can't directly assert on the timestamp due to the delays on the code and # timing variance scheduled_start_timestamp = schedule_q_db.scheduled_start_timestamp - test_case_start_timestamp = test_case['delayed_start'] - start_timestamp_diff = (scheduled_start_timestamp - test_case_start_timestamp) + test_case_start_timestamp = test_case["delayed_start"] + start_timestamp_diff = scheduled_start_timestamp - test_case_start_timestamp self.assertTrue(start_timestamp_diff <= datetime.timedelta(seconds=1)) def test_next_executions_empty(self): @@ -227,9 +234,11 @@ def test_garbage_collection(self): schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(schedule_q_db) - @mock.patch('st2actions.scheduler.handler.action_service') - @mock.patch('st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete') - def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_action_service): + @mock.patch("st2actions.scheduler.handler.action_service") + @mock.patch("st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete") + def test_processing_when_task_completed( + self, mock_execution_queue_delete, mock_action_service + ): self.reset() liveaction_db = self._create_liveaction_db() @@ -245,7 +254,7 @@ def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_ mock_execution_queue_delete.assert_called_once() ActionExecutionSchedulingQueue.delete(schedule_q_db) - @mock.patch('st2actions.scheduler.handler.LOG') + @mock.patch("st2actions.scheduler.handler.LOG") def test_failed_next_item(self, mocked_logger): self.reset() @@ -258,15 +267,17 @@ def test_failed_next_item(self, mocked_logger): schedule_q_db = ActionExecutionSchedulingQueue.add_or_update(schedule_q_db) with mock.patch( - 'st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update', - side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db) + "st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update", + side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db), ): schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNone(schedule_q_db) self.assertEqual(mocked_logger.info.call_count, 2) call_args = mocked_logger.info.call_args_list[1][0] - self.assertEqual(r'[%s] Item "%s" is already handled by another scheduler.', call_args[0]) + self.assertEqual( + r'[%s] Item "%s" is already handled by another scheduler.', call_args[0] + ) schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(schedule_q_db) @@ -288,33 +299,39 @@ def test_cleanup_policy_delayed(self): # Manually update the liveaction to policy-delayed status. # Using action_service.update_status will throw an exception on the # deprecated action_constants.LIVEACTION_STATUS_POLICY_DELAYED. - liveaction_db.status = 'policy-delayed' + liveaction_db.status = "policy-delayed" liveaction_db = LiveAction.add_or_update(liveaction_db) execution_db = execution_service.update_execution(liveaction_db) # Check that the execution status is set to policy-delayed. liveaction_db = LiveAction.get_by_id(str(liveaction_db.id)) - self.assertEqual(liveaction_db.status, 'policy-delayed') + self.assertEqual(liveaction_db.status, "policy-delayed") execution_db = ActionExecution.get_by_id(str(execution_db.id)) - self.assertEqual(execution_db.status, 'policy-delayed') + self.assertEqual(execution_db.status, "policy-delayed") # Run the clean up logic. self.scheduling_queue._cleanup_policy_delayed() # Check that the execution status is reset to requested. liveaction_db = LiveAction.get_by_id(str(liveaction_db.id)) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) execution_db = ActionExecution.get_by_id(str(execution_db.id)) - self.assertEqual(execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # The old entry should have been deleted. Since the execution is # reset to requested, there should be a new scheduling entry. new_schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(new_schedule_q_db) self.assertNotEqual(str(schedule_q_db.id), str(new_schedule_q_db.id)) - self.assertEqual(schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id) + self.assertEqual( + schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id + ) self.assertEqual(schedule_q_db.liveaction_id, new_schedule_q_db.liveaction_id) # TODO: Remove this test case for populating action_execution_id in v3.2. diff --git a/st2actions/tests/unit/test_scheduler_entrypoint.py b/st2actions/tests/unit/test_scheduler_entrypoint.py index ddcba287e7..2bc535d99d 100644 --- a/st2actions/tests/unit/test_scheduler_entrypoint.py +++ b/st2actions/tests/unit/test_scheduler_entrypoint.py @@ -17,6 +17,7 @@ import mock from st2tests import config as test_config + test_config.parse_args() from st2actions.cmd.scheduler import _run_scheduler @@ -25,32 +26,30 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'SchedulerServiceEntryPointTestCase' -] +__all__ = ["SchedulerServiceEntryPointTestCase"] def mock_handler_run(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('handler run exception') + raise Exception("handler run exception") def mock_handler_cleanup(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('handler clean exception') + raise Exception("handler clean exception") def mock_entrypoint_start(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('entrypoint start exception') + raise Exception("entrypoint start exception") class SchedulerServiceEntryPointTestCase(CleanDbTestCase): - @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'run', mock_handler_run) - @mock.patch('st2actions.cmd.scheduler.LOG') + @mock.patch.object(ActionExecutionSchedulingQueueHandler, "run", mock_handler_run) + @mock.patch("st2actions.cmd.scheduler.LOG") def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_log): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() @@ -58,26 +57,32 @@ def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_lo self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) - - @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'cleanup', mock_handler_cleanup) - @mock.patch('st2actions.cmd.scheduler.LOG') - def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup(self, mock_log): + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) + + @mock.patch.object( + ActionExecutionSchedulingQueueHandler, "cleanup", mock_handler_cleanup + ) + @mock.patch("st2actions.cmd.scheduler.LOG") + def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup( + self, mock_log + ): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) - @mock.patch.object(SchedulerEntrypoint, 'start', mock_entrypoint_start) - @mock.patch('st2actions.cmd.scheduler.LOG') - def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start(self, mock_log): + @mock.patch.object(SchedulerEntrypoint, "start", mock_entrypoint_start) + @mock.patch("st2actions.cmd.scheduler.LOG") + def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start( + self, mock_log + ): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) diff --git a/st2actions/tests/unit/test_scheduler_retry.py b/st2actions/tests/unit/test_scheduler_retry.py index e47a2ad3eb..ad1f221df1 100644 --- a/st2actions/tests/unit/test_scheduler_retry.py +++ b/st2actions/tests/unit/test_scheduler_retry.py @@ -19,6 +19,7 @@ import uuid from st2tests import config as test_config + test_config.parse_args() from st2actions.scheduler import handler @@ -27,22 +28,23 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'SchedulerHandlerRetryTestCase' -] +__all__ = ["SchedulerHandlerRetryTestCase"] -MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB(liveaction_id=uuid.uuid4().hex) +MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB( + liveaction_id=uuid.uuid4().hex +) class SchedulerHandlerRetryTestCase(CleanDbTestCase): - - @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM])) @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock( + side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM] + ), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retry_connection_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() scheduling_queue_handler.process() @@ -52,69 +54,88 @@ def test_handler_retry_connection_error(self): eventlet.GreenPool.spawn.assert_has_calls(calls) @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3)) - @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retries_exhausted(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() - self.assertRaises(pymongo.errors.ConnectionFailure, scheduling_queue_handler.process) + self.assertRaises( + pymongo.errors.ConnectionFailure, scheduling_queue_handler.process + ) self.assertEqual(eventlet.GreenPool.spawn.call_count, 0) @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=KeyError())) - @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock(side_effect=KeyError()), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retry_unexpected_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() self.assertRaises(KeyError, scheduling_queue_handler.process) self.assertEqual(eventlet.GreenPool.spawn.call_count, 0) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]])) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock( + side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]] + ), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_retry_connection_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() scheduling_queue_handler._handle_garbage_collection() # Make sure retry occurs and that _handle_execution in process is called. calls = [mock.call(MOCK_QUEUE_ITEM, publish=False)] - ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls(calls) + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls( + calls + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_retries_exhausted(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() self.assertRaises( pymongo.errors.ConnectionFailure, - scheduling_queue_handler._handle_garbage_collection + scheduling_queue_handler._handle_garbage_collection, ) - self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0) + self.assertEqual( + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0 + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=KeyError())) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock(side_effect=KeyError()), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_unexpected_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() - self.assertRaises( - KeyError, - scheduling_queue_handler._handle_garbage_collection - ) + self.assertRaises(KeyError, scheduling_queue_handler._handle_garbage_collection) - self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0) + self.assertEqual( + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0 + ) diff --git a/st2actions/tests/unit/test_worker.py b/st2actions/tests/unit/test_worker.py index 19ffd69553..d8637b9ac7 100644 --- a/st2actions/tests/unit/test_worker.py +++ b/st2actions/tests/unit/test_worker.py @@ -36,16 +36,20 @@ from st2tests.fixturesloader import FixturesLoader import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() -TEST_FIXTURES = { - 'actions': ['local.yaml'] -} +TEST_FIXTURES = {"actions": ["local.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -NON_UTF8_RESULT = {'stderr': '', 'stdout': '\x82\n', 'succeeded': True, 'failed': False, - 'return_code': 0} +NON_UTF8_RESULT = { + "stderr": "", + "stdout": "\x82\n", + "succeeded": True, + "failed": False, + "return_code": 0, +} class WorkerTestCase(DbTestCase): @@ -58,28 +62,42 @@ def setUpClass(cls): runners_registrar.register_runners() models = WorkerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - WorkerTestCase.local_action_db = models['actions']['local.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + WorkerTestCase.local_action_db = models["actions"]["local.yaml"] def _get_liveaction_model(self, action_db, params): status = action_constants.LIVEACTION_STATUS_REQUESTED start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db - @mock.patch.object(LocalShellCommandRunner, 'run', mock.MagicMock( - return_value=(action_constants.LIVEACTION_STATUS_SUCCEEDED, NON_UTF8_RESULT, None))) + @mock.patch.object( + LocalShellCommandRunner, + "run", + mock.MagicMock( + return_value=( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + NON_UTF8_RESULT, + None, + ) + ), + ) def test_non_utf8_action_result_string(self): action_worker = actions_worker.get_worker() - params = { - 'cmd': "python -c 'print \"\\x82\"'" - } - liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params) + params = {"cmd": "python -c 'print \"\\x82\"'"} + liveaction_db = self._get_liveaction_model( + WorkerTestCase.local_action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) execution_db = executions.create_execution_object(liveaction_db) @@ -87,11 +105,15 @@ def test_non_utf8_action_result_string(self): action_worker._run_action(liveaction_db) except InvalidStringData: liveaction_db = LiveAction.get_by_id(liveaction_db.id) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertIn('error', liveaction_db.result) - self.assertIn('traceback', liveaction_db.result) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) + self.assertIn("error", liveaction_db.result) + self.assertIn("traceback", liveaction_db.result) execution_db = ActionExecution.get_by_id(execution_db.id) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) def test_worker_shutdown(self): action_worker = actions_worker.get_worker() @@ -107,8 +129,10 @@ def test_worker_shutdown(self): self.assertTrue(os.path.isfile(temp_file)) # Launch the action execution in a separate thread. - params = {'cmd': 'while [ -e \'%s\' ]; do sleep 0.1; done' % temp_file} - liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params) + params = {"cmd": "while [ -e '%s' ]; do sleep 0.1; done" % temp_file} + liveaction_db = self._get_liveaction_model( + WorkerTestCase.local_action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) runner_thread = eventlet.spawn(action_worker._run_action, liveaction_db) @@ -127,8 +151,11 @@ def test_worker_shutdown(self): # Verify that _running_liveactions is empty and the liveaction is abandoned. self.assertEqual(len(action_worker._running_liveactions), 0) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_ABANDONED, - str(liveaction_db)) + self.assertEqual( + liveaction_db.status, + action_constants.LIVEACTION_STATUS_ABANDONED, + str(liveaction_db), + ) # Make sure the temporary file has been deleted. self.assertFalse(os.path.isfile(temp_file)) diff --git a/st2actions/tests/unit/test_workflow_engine.py b/st2actions/tests/unit/test_workflow_engine.py index 916682d569..b8e4fae83f 100644 --- a/st2actions/tests/unit/test_workflow_engine.py +++ b/st2actions/tests/unit/test_workflow_engine.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2actions.workflows import workflows @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class WorkflowExecutionHandlerTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionHandlerTest, cls).setUpClass() @@ -86,50 +95,57 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_process(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) # Assert the workflow has completed successfully with expected output. - expected_output = {'msg': 'Stanley, All your base are belong to us!'} + expected_output = {"msg": "Stanley, All your base are belong to us!"} wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) self.assertDictEqual(wf_ex_db.output, expected_output) @@ -137,37 +153,43 @@ def test_process(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @mock.patch.object( - coordination_service.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coordination_service.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) def test_process_error_handling(self): expected_errors = [ { - 'message': 'Execution failed. See result for details.', - 'type': 'error', - 'task_id': 'task1' + "message": "Execution failed. See result for details.", + "type": "error", + "task_id": "task1", }, { - 'type': 'error', - 'message': 'ToozConnectionError: foobar', - 'task_id': 'task1', - 'route': 0 - } + "type": "error", + "message": "ToozConnectionError: foobar", + "task_id": "task1", + "route": 0, + }, ] - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] workflows.get_engine().process(t1_ac_ex_db) # Assert the task is marked as failed. @@ -182,36 +204,42 @@ def test_process_error_handling(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) @mock.patch.object( - coordination_service.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coordination_service.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) @mock.patch.object( workflows.WorkflowExecutionHandler, - 'fail_workflow_execution', - mock.MagicMock(side_effect=Exception('Unexpected error.'))) + "fail_workflow_execution", + mock.MagicMock(side_effect=Exception("Unexpected error.")), + ) def test_process_error_handling_has_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] self.assertRaisesRegexp( - Exception, - 'Unexpected error.', - workflows.get_engine().process, - t1_ac_ex_db + Exception, "Unexpected error.", workflows.get_engine().process, t1_ac_ex_db ) - self.assertTrue(workflows.WorkflowExecutionHandler.fail_workflow_execution.called) + self.assertTrue( + workflows.WorkflowExecutionHandler.fail_workflow_execution.called + ) # Since error handling failed, the workflow will have status of running. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) diff --git a/st2api/dist_utils.py b/st2api/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2api/dist_utils.py +++ b/st2api/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2api/setup.py b/st2api/setup.py index 932f2e90f4..b0cfa24067 100644 --- a/st2api/setup.py +++ b/st2api/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2api import __version__ -ST2_COMPONENT = 'st2api' +ST2_COMPONENT = "st2api" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2api' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2api"], ) diff --git a/st2api/st2api/__init__.py b/st2api/st2api/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2api/st2api/__init__.py +++ b/st2api/st2api/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2api/st2api/app.py b/st2api/st2api/app.py index 5b10e58c3f..2483d0ef9e 100644 --- a/st2api/st2api/app.py +++ b/st2api/st2api/app.py @@ -36,55 +36,60 @@ def setup_app(config=None): config = config or {} - LOG.info('Creating st2api: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2api: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # NOTE: We only want to perform this logic in the WSGI worker st2api_config.register_opts() capabilities = { - 'name': 'api', - 'listen_host': cfg.CONF.api.host, - 'listen_port': cfg.CONF.api.port, - 'type': 'active' + "name": "api", + "listen_host": cfg.CONF.api.host, + "listen_port": cfg.CONF.api.port, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='api', config=st2api_config, setup_db=True, - register_mq_exchanges=True, - register_signal_handlers=True, - register_internal_trigger_types=True, - run_migrations=True, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="api", + config=st2api_config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + run_migrations=True, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) # Additional pre-run time checks validate_rbac_is_correctly_configured() - router = Router(debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable, - is_gunicorn=is_gunicorn) + router = Router( + debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn + ) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") transforms = { - '^/api/v1/$': ['/v1'], - '^/api/v1/': ['/', '/v1/'], - '^/api/v1/executions': ['/actionexecutions', '/v1/actionexecutions'], - '^/api/exp/': ['/exp/'] + "^/api/v1/$": ["/v1"], + "^/api/v1/": ["/", "/v1/"], + "^/api/v1/executions": ["/actionexecutions", "/v1/actionexecutions"], + "^/api/exp/": ["/exp/"], } router.add_spec(spec, transforms=transforms) app = router.as_wsgi # Order is important. Check middleware for detailed explanation. - app = StreamingMiddleware(app, path_whitelist=['/v1/executions/*/output*']) + app = StreamingMiddleware(app, path_whitelist=["/v1/executions/*/output*"]) app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='api') + app = ResponseInstrumentationMiddleware(app, router, service_name="api") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='api') + app = RequestInstrumentationMiddleware(app, router, service_name="api") return app diff --git a/st2api/st2api/cmd/__init__.py b/st2api/st2api/cmd/__init__.py index 4e28bca433..0b9307922a 100644 --- a/st2api/st2api/cmd/__init__.py +++ b/st2api/st2api/cmd/__init__.py @@ -15,4 +15,4 @@ from st2api.cmd import api -__all__ = ['api'] +__all__ = ["api"] diff --git a/st2api/st2api/cmd/api.py b/st2api/st2api/cmd/api.py index 73d3520444..1cf01d0544 100644 --- a/st2api/st2api/cmd/api.py +++ b/st2api/st2api/cmd/api.py @@ -21,6 +21,7 @@ # See https://github.com/StackStorm/st2/issues/4832 and https://github.com/gevent/gevent/issues/1016 # for details. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import eventlet @@ -31,14 +32,13 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown from st2api import config + config.register_opts() from st2api import app from st2api.validation import validate_rbac_is_correctly_configured -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -48,15 +48,22 @@ def _setup(): capabilities = { - 'name': 'api', - 'listen_host': cfg.CONF.api.host, - 'listen_port': cfg.CONF.api.port, - 'type': 'active' + "name": "api", + "listen_host": cfg.CONF.api.host, + "listen_port": cfg.CONF.api.port, + "type": "active", } - common_setup(service='api', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=True, - service_registry=True, capabilities=capabilities) + common_setup( + service="api", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + service_registry=True, + capabilities=capabilities, + ) # Additional pre-run time checks validate_rbac_is_correctly_configured() @@ -66,13 +73,15 @@ def _run_server(): host = cfg.CONF.api.host port = cfg.CONF.api.port - LOG.info('(PID=%s) ST2 API is serving on http://%s:%s.', os.getpid(), host, port) + LOG.info("(PID=%s) ST2 API is serving on http://%s:%s.", os.getpid(), host, port) max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS worker_pool = eventlet.GreenPool(max_pool_size) sock = eventlet.listen((host, port)) - wsgi.server(sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False) + wsgi.server( + sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False + ) return 0 @@ -87,7 +96,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) ST2 API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2api/st2api/config.py b/st2api/st2api/config.py index 71378da9ad..35a21d87d5 100644 --- a/st2api/st2api/config.py +++ b/st2api/st2api/config.py @@ -32,8 +32,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -52,32 +55,38 @@ def get_logging_config_path(): def _register_app_opts(): # Note "host", "port", "allow_origin", "mask_secrets" options are registered as part of # st2common config since they are also used outside st2api - static_root = os.path.join(cfg.CONF.system.base_path, 'static') - template_path = os.path.join(BASE_DIR, 'templates/') + static_root = os.path.join(cfg.CONF.system.base_path, "static") + template_path = os.path.join(BASE_DIR, "templates/") pecan_opts = [ cfg.StrOpt( - 'root', default='st2api.controllers.root.RootController', - help='Action root controller'), - cfg.StrOpt('static_root', default=static_root), - cfg.StrOpt('template_path', default=template_path), - cfg.ListOpt('modules', default=['st2api']), - cfg.BoolOpt('debug', default=False), - cfg.BoolOpt('auth_enable', default=True), - cfg.DictOpt('errors', default={'__force_dict__': True}) + "root", + default="st2api.controllers.root.RootController", + help="Action root controller", + ), + cfg.StrOpt("static_root", default=static_root), + cfg.StrOpt("template_path", default=template_path), + cfg.ListOpt("modules", default=["st2api"]), + cfg.BoolOpt("debug", default=False), + cfg.BoolOpt("auth_enable", default=True), + cfg.DictOpt("errors", default={"__force_dict__": True}), ] - CONF.register_opts(pecan_opts, group='api_pecan') + CONF.register_opts(pecan_opts, group="api_pecan") logging_opts = [ - cfg.BoolOpt('debug', default=False), + cfg.BoolOpt("debug", default=False), cfg.StrOpt( - 'logging', default='/etc/st2/logging.api.conf', - help='location of the logging.conf file'), + "logging", + default="/etc/st2/logging.api.conf", + help="location of the logging.conf file", + ), cfg.IntOpt( - 'max_page_size', default=100, - help='Maximum limit (page size) argument which can be ' - 'specified by the user in a query string.') + "max_page_size", + default=100, + help="Maximum limit (page size) argument which can be " + "specified by the user in a query string.", + ), ] - CONF.register_opts(logging_opts, group='api') + CONF.register_opts(logging_opts, group="api") diff --git a/st2api/st2api/controllers/base.py b/st2api/st2api/controllers/base.py index a3f24e2f0f..e4f13d8f1e 100644 --- a/st2api/st2api/controllers/base.py +++ b/st2api/st2api/controllers/base.py @@ -20,9 +20,7 @@ from st2api.controllers.controller_transforms import transform_to_bool from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'BaseRestControllerMixin' -] +__all__ = ["BaseRestControllerMixin"] class BaseRestControllerMixin(object): @@ -41,7 +39,9 @@ def _parse_query_params(self, request): return query_params - def _get_query_param_value(self, request, param_name, param_type, default_value=None): + def _get_query_param_value( + self, request, param_name, param_type, default_value=None + ): """ Return a value for the provided query param and optionally cast it for boolean types. @@ -61,7 +61,7 @@ def _get_query_param_value(self, request, param_name, param_type, default_value= query_params = self._parse_query_params(request=request) value = query_params.get(param_name, default_value) - if param_type == 'bool' and isinstance(value, six.string_types): + if param_type == "bool" and isinstance(value, six.string_types): value = transform_to_bool(value) return value diff --git a/st2api/st2api/controllers/controller_transforms.py b/st2api/st2api/controllers/controller_transforms.py index 8afff88da6..0ca51a0a75 100644 --- a/st2api/st2api/controllers/controller_transforms.py +++ b/st2api/st2api/controllers/controller_transforms.py @@ -14,9 +14,7 @@ # limitations under the License. -__all__ = [ - 'transform_to_bool' -] +__all__ = ["transform_to_bool"] def transform_to_bool(value): @@ -27,8 +25,8 @@ def transform_to_bool(value): Any other representation will be rejected. """ - if value in ['1', 'true', 'True', True]: + if value in ["1", "true", "True", True]: return True - elif value in ['0', 'false', 'False', False]: + elif value in ["0", "false", "False", False]: return False raise ValueError('Invalid bool representation "%s" provided.' % value) diff --git a/st2api/st2api/controllers/resource.py b/st2api/st2api/controllers/resource.py index 72611a90dc..a2391ff9aa 100644 --- a/st2api/st2api/controllers/resource.py +++ b/st2api/st2api/controllers/resource.py @@ -35,21 +35,19 @@ LOG = logging.getLogger(__name__) -RESERVED_QUERY_PARAMS = { - 'id': 'id', - 'name': 'name', - 'sort': 'order_by' -} +RESERVED_QUERY_PARAMS = {"id": "id", "name": "name", "sort": "order_by"} def split_id_value(value): if not value or isinstance(value, (list, tuple)): return value - split = value.split(',') + split = value.split(",") if len(split) > 100: - raise ValueError('Maximum of 100 items can be provided for a query parameter value') + raise ValueError( + "Maximum of 100 items can be provided for a query parameter value" + ) return split @@ -57,7 +55,7 @@ def split_id_value(value): DEFAULT_FILTER_TRANSFORM_FUNCTIONS = { # Support for filtering on multiple ids when a commona delimited string is provided # (e.g. ?id=1,2,3) - 'id': split_id_value + "id": split_id_value } @@ -65,14 +63,14 @@ def parameter_validation(validator, properties, instance, schema): parameter_specific_schema = { "description": "Input parameters for the action.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_parameters_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_parameters_schema()}, + "additionalProperties": False, + "default": {}, } - parameter_specific_validator = util_schema.CustomValidator(parameter_specific_schema) + parameter_specific_validator = util_schema.CustomValidator( + parameter_specific_schema + ) for error in parameter_specific_validator.iter_errors(instance=instance): yield error @@ -91,18 +89,16 @@ class ResourceController(object): # ?include_attributes filter. Those attributes need to be included because a lot of code # depends on compound references and primary keys. In addition to that, it's needed for secrets # masking to work, etc. - mandatory_include_fields_retrieve = ['id'] + mandatory_include_fields_retrieve = ["id"] # A list of fields which are always included in the response when ?include_attributes filter is # used. Those are things such as primary keys and similar. - mandatory_include_fields_response = ['id'] + mandatory_include_fields_response = ["id"] # Default number of items returned per page if no limit is explicitly provided default_limit = 100 - query_options = { - 'sort': [] - } + query_options = {"sort": []} # A list of optional transformation functions for user provided filter values filter_transform_functions = {} @@ -120,7 +116,9 @@ def __init__(self): self.supported_filters = copy.deepcopy(self.__class__.supported_filters) self.supported_filters.update(RESERVED_QUERY_PARAMS) - self.filter_transform_functions = copy.deepcopy(self.__class__.filter_transform_functions) + self.filter_transform_functions = copy.deepcopy( + self.__class__.filter_transform_functions + ) self.filter_transform_functions.update(DEFAULT_FILTER_TRANSFORM_FUNCTIONS) self.get_one_db_method = self._get_by_name_or_id @@ -130,9 +128,19 @@ def __init__(self): def max_limit(self): return cfg.CONF.api.max_page_size - def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=None, - sort=None, offset=0, limit=None, query_options=None, - from_model_kwargs=None, raw_filters=None, requester_user=None): + def _get_all( + self, + exclude_fields=None, + include_fields=None, + advanced_filters=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -144,8 +152,10 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No query_options = query_options if query_options else self.query_options if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_fields) @@ -153,18 +163,18 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No # TODO: Why do we use comma delimited string, user can just specify # multiple values using ?sort=foo&sort=bar and we get a list back - sort = sort.split(',') if sort else [] + sort = sort.split(",") if sort else [] db_sort_values = [] for sort_key in sort: - if sort_key.startswith('-'): - direction = '-' + if sort_key.startswith("-"): + direction = "-" sort_key = sort_key[1:] - elif sort_key.startswith('+'): - direction = '+' + elif sort_key.startswith("+"): + direction = "+" sort_key = sort_key[1:] else: - direction = '' + direction = "" if sort_key not in self.supported_filters: # Skip unsupported sort key @@ -173,12 +183,12 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No sort_value = direction + self.supported_filters[sort_key] db_sort_values.append(sort_value) - default_sort_values = copy.copy(query_options.get('sort')) - raw_filters['sort'] = db_sort_values if db_sort_values else default_sort_values + default_sort_values = copy.copy(query_options.get("sort")) + raw_filters["sort"] = db_sort_values if db_sort_values else default_sort_values # TODO: To protect us from DoS, we need to make max_limit mandatory offset = int(offset) - if offset >= 2**31: + if offset >= 2 ** 31: raise ValueError('Offset "%s" specified is more than 32-bit int' % (offset)) limit = validate_limit_query_param(limit=limit, requester_user=requester_user) @@ -195,32 +205,35 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No value_transform_function = value_transform_function or (lambda value: value) filter_value = value_transform_function(value=filter_value) - if k in ['id', 'name'] and isinstance(filter_value, list): - filters[k + '__in'] = filter_value + if k in ["id", "name"] and isinstance(filter_value, list): + filters[k + "__in"] = filter_value else: - field_name_split = v.split('.') + field_name_split = v.split(".") # Make sure filter value is a list when using "in" filter - if field_name_split[-1] == 'in' and not isinstance(filter_value, (list, tuple)): + if field_name_split[-1] == "in" and not isinstance( + filter_value, (list, tuple) + ): filter_value = [filter_value] - filters['__'.join(field_name_split)] = filter_value + filters["__".join(field_name_split)] = filter_value if advanced_filters: - for token in advanced_filters.split(' '): + for token in advanced_filters.split(" "): try: - [k, v] = token.split(':', 1) + [k, v] = token.split(":", 1) except ValueError: raise ValueError('invalid format for filter "%s"' % token) - path = k.split('.') + path = k.split(".") try: self.model.model._lookup_field(path) - filters['__'.join(path)] = v + filters["__".join(path)] = v except LookUpError as e: raise ValueError(six.text_type(e)) - instances = self.access.query(exclude_fields=exclude_fields, only_fields=include_fields, - **filters) + instances = self.access.query( + exclude_fields=exclude_fields, only_fields=include_fields, **filters + ) if limit == 1: # Perform the filtering on the DB side instances = instances.limit(limit) @@ -228,44 +241,65 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resources_model_filter(model=self.model, - instances=instances, - offset=offset, - eop=eop, - requester_user=requester_user, - **from_model_kwargs) + result = self.resources_model_filter( + model=self.model, + instances=instances, + offset=offset, + eop=eop, + requester_user=requester_user, + **from_model_kwargs, + ) resp = Response(json=result) - resp.headers['X-Total-Count'] = str(instances.count()) + resp.headers["X-Total-Count"] = str(instances.count()) if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp - def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0, - **from_model_kwargs): + def resources_model_filter( + self, + model, + instances, + requester_user=None, + offset=0, + eop=0, + **from_model_kwargs, + ): """ Method which converts DB objects to API objects and performs any additional filtering. """ result = [] for instance in instances[offset:eop]: - item = self.resource_model_filter(model=model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + item = self.resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) result.append(item) return result - def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs): + def resource_model_filter( + self, model, instance, requester_user=None, **from_model_kwargs + ): """ Method which converts DB object to API object and performs any additional filtering. """ item = model.from_model(instance, **from_model_kwargs) return item - def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=None, - include_fields=None, from_model_kwargs=None): + def _get_one_by_id( + self, + id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -273,14 +307,17 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non :type include_fields: ``list`` """ - instance = self._get_by_id(resource_id=id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_id( + resource_id=id, exclude_fields=exclude_fields, include_fields=include_fields + ) if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) if not instance: msg = 'Unable to identify resource with id "%s".' % id @@ -289,21 +326,35 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resource_model_filter(model=self.model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + result = self.resource_model_filter( + model=self.model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not result: - LOG.debug('Not returning the result because RBAC resource isolation is enabled and ' - 'current user doesn\'t match the resource user') - raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user, - resource_api_or_db=instance, - permission_type=permission_type) + LOG.debug( + "Not returning the result because RBAC resource isolation is enabled and " + "current user doesn't match the resource user" + ) + raise ResourceAccessDeniedPermissionIsolationError( + user_db=requester_user, + resource_api_or_db=instance, + permission_type=permission_type, + ) return result - def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, - exclude_fields=None, include_fields=None, from_model_kwargs=None): + def _get_one_by_name_or_id( + self, + name_or_id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -311,14 +362,19 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, :type include_fields: ``list`` """ - instance = self._get_by_name_or_id(name_or_id=name_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_name_or_id( + name_or_id=name_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) if not instance: msg = 'Unable to identify resource with name_or_id "%s".' % (name_or_id) @@ -330,10 +386,14 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, return result - def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None, - from_model_kwargs=None): - instance = self._get_by_pack_ref(pack_ref=pack_ref, exclude_fields=exclude_fields, - include_fields=include_fields) + def _get_one_by_pack_ref( + self, pack_ref, exclude_fields=None, include_fields=None, from_model_kwargs=None + ): + instance = self._get_by_pack_ref( + pack_ref=pack_ref, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not instance: msg = 'Unable to identify resource with pack_ref "%s".' % (pack_ref) @@ -347,8 +407,11 @@ def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=Non def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(id=resource_id, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + id=resource_id, + exclude_fields=exclude_fields, + only_fields=include_fields, + ) except ValidationError: resource_db = None @@ -356,8 +419,11 @@ def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None): def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(name=resource_name, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + name=resource_name, + exclude_fields=exclude_fields, + only_fields=include_fields, + ) except Exception: resource_db = None @@ -365,8 +431,9 @@ def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None): def _get_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(pack=pack_ref, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + pack=pack_ref, exclude_fields=exclude_fields, only_fields=include_fields + ) except Exception: resource_db = None @@ -376,13 +443,17 @@ def _get_by_name_or_id(self, name_or_id, exclude_fields=None, include_fields=Non """ Retrieve resource object by an id of a name. """ - resource_db = self._get_by_id(resource_id=name_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_id( + resource_id=name_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not resource_db: # Try name - resource_db = self._get_by_name(resource_name=name_or_id, - exclude_fields=exclude_fields) + resource_db = self._get_by_name( + resource_name=name_or_id, exclude_fields=exclude_fields + ) if not resource_db: msg = 'Resource with a name or id "%s" not found' % (name_or_id) @@ -402,11 +473,16 @@ def _get_one_by_scope_and_name(self, scope, name, from_model_kwargs=None): """ instance = self.access.get_by_scope_and_name(scope=scope, name=name) if not instance: - msg = 'KeyValuePair with name: %s and scope: %s not found in db.' % (name, scope) + msg = "KeyValuePair with name: %s and scope: %s not found in db." % ( + name, + scope, + ) raise StackStormDBObjectNotFoundError(msg) from_model_kwargs = from_model_kwargs or {} result = self.model.from_model(instance, **from_model_kwargs) - LOG.debug('GET with scope=%s and name=%s, client_result=%s', scope, name, result) + LOG.debug( + "GET with scope=%s and name=%s, client_result=%s", scope, name, result + ) return result @@ -422,7 +498,7 @@ def _validate_exclude_fields(self, exclude_fields): for field in exclude_fields: if field not in self.valid_exclude_attributes: - msg = ('Invalid or unsupported exclude attribute specified: %s' % (field)) + msg = "Invalid or unsupported exclude attribute specified: %s" % (field) raise ValueError(msg) return exclude_fields @@ -438,7 +514,7 @@ def _validate_include_fields(self, include_fields): for field in self.mandatory_include_fields_retrieve: # Don't add mandatory field if user already requested the whole dict object (e.g. user # requests action and action.parameters is a mandatory field) - partial_field = field.split('.')[0] + partial_field = field.split(".")[0] if partial_field in include_fields: continue @@ -456,20 +532,38 @@ class BaseResourceIsolationControllerMixin(object): users). """ - def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0, - **from_model_kwargs): + def resources_model_filter( + self, + model, + instances, + requester_user=None, + offset=0, + eop=0, + **from_model_kwargs, + ): # RBAC or permission isolation is disabled, bail out if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation): - result = super(BaseResourceIsolationControllerMixin, self).resources_model_filter( - model=model, instances=instances, requester_user=requester_user, - offset=offset, eop=eop, **from_model_kwargs) + result = super( + BaseResourceIsolationControllerMixin, self + ).resources_model_filter( + model=model, + instances=instances, + requester_user=requester_user, + offset=offset, + eop=eop, + **from_model_kwargs, + ) return result result = [] for instance in instances[offset:eop]: - item = self.resource_model_filter(model=model, instance=instance, - requester_user=requester_user, **from_model_kwargs) + item = self.resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not item: continue @@ -478,18 +572,25 @@ def resources_model_filter(self, model, instances, requester_user=None, offset=0 return result - def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs): + def resource_model_filter( + self, model, instance, requester_user=None, **from_model_kwargs + ): # RBAC or permission isolation is disabled, bail out if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation): - result = super(BaseResourceIsolationControllerMixin, self).resource_model_filter( - model=model, instance=instance, requester_user=requester_user, - **from_model_kwargs) + result = super( + BaseResourceIsolationControllerMixin, self + ).resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) return result rbac_utils = get_rbac_backend().get_utils_class() user_is_admin = rbac_utils.user_is_admin(user_db=requester_user) - user_is_system_user = (requester_user.name == cfg.CONF.system_user.user) + user_is_system_user = requester_user.name == cfg.CONF.system_user.user item = model.from_model(instance, **from_model_kwargs) @@ -497,7 +598,7 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod if user_is_admin or user_is_system_user: return item - user = item.context.get('user', None) + user = item.context.get("user", None) if user and (user == requester_user.name): return item @@ -506,21 +607,31 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod class ContentPackResourceController(ResourceController): # name and pack are mandatory because they compromise primary key - reference (.) - mandatory_include_fields_retrieve = ['pack', 'name'] + mandatory_include_fields_retrieve = ["pack", "name"] # A list of fields which are always included in the response. Those are things such as primary # keys and similar - mandatory_include_fields_response = ['id', 'ref'] + mandatory_include_fields_response = ["id", "ref"] def __init__(self): super(ContentPackResourceController, self).__init__() self.get_one_db_method = self._get_by_ref_or_id - def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=None, - include_fields=None, from_model_kwargs=None): + def _get_one( + self, + ref_or_id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): try: - instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_ref_or_id( + ref_or_id=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) except Exception as e: LOG.exception(six.text_type(e)) abort(http_client.NOT_FOUND, six.text_type(e)) @@ -528,40 +639,59 @@ def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=No if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) # Perform resource isolation check (if supported) from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resource_model_filter(model=self.model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + result = self.resource_model_filter( + model=self.model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not result: - LOG.debug('Not returning the result because RBAC resource isolation is enabled and ' - 'current user doesn\'t match the resource user') - raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user, - resource_api_or_db=instance, - permission_type=permission_type) + LOG.debug( + "Not returning the result because RBAC resource isolation is enabled and " + "current user doesn't match the resource user" + ) + raise ResourceAccessDeniedPermissionIsolationError( + user_db=requester_user, + resource_api_or_db=instance, + permission_type=permission_type, + ) return Response(json=result) - def _get_all(self, exclude_fields=None, include_fields=None, - sort=None, offset=0, limit=None, query_options=None, - from_model_kwargs=None, raw_filters=None, requester_user=None): - resp = super(ContentPackResourceController, - self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + def _get_all( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): + resp = super(ContentPackResourceController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) return resp @@ -574,8 +704,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) """ if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) if ResourceReference.is_resource_reference(ref_or_id): @@ -585,11 +717,17 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) is_reference = False if is_reference: - resource_db = self._get_by_ref(resource_ref=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_ref( + resource_ref=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) else: - resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_id( + resource_id=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not resource_db: msg = 'Resource with a reference or id "%s" not found' % (ref_or_id) @@ -599,8 +737,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None): if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) try: @@ -608,9 +748,12 @@ def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None): except Exception: return None - resource_db = self.access.query(name=ref.name, pack=ref.pack, - exclude_fields=exclude_fields, - only_fields=include_fields).first() + resource_db = self.access.query( + name=ref.name, + pack=ref.pack, + exclude_fields=exclude_fields, + only_fields=include_fields, + ).first() return resource_db @@ -629,25 +772,29 @@ def validate_limit_query_param(limit, requester_user=None): if int(limit) == -1: if not user_is_admin: # Only admins can specify limit -1 - message = ('Administrator access required to be able to specify limit=-1 and ' - 'retrieve all the records') - raise AccessDeniedError(message=message, - user_db=requester_user) + message = ( + "Administrator access required to be able to specify limit=-1 and " + "retrieve all the records" + ) + raise AccessDeniedError(message=message, user_db=requester_user) return 0 elif int(limit) <= -2: msg = 'Limit, "%s" specified, must be a positive number.' % (limit) raise ValueError(msg) elif int(limit) > cfg.CONF.api.max_page_size and not user_is_admin: - msg = ('Limit "%s" specified, maximum value is "%s"' % (limit, - cfg.CONF.api.max_page_size)) + msg = 'Limit "%s" specified, maximum value is "%s"' % ( + limit, + cfg.CONF.api.max_page_size, + ) - raise AccessDeniedError(message=msg, - user_db=requester_user) + raise AccessDeniedError(message=msg, user_db=requester_user) # Disable n = 0 elif limit == 0: - msg = ('Limit, "%s" specified, must be a positive number or -1 for full result set.' % - (limit)) + msg = ( + 'Limit, "%s" specified, must be a positive number or -1 for full result set.' + % (limit) + ) raise ValueError(msg) return limit diff --git a/st2api/st2api/controllers/root.py b/st2api/st2api/controllers/root.py index c2db487b02..2d5d953afa 100644 --- a/st2api/st2api/controllers/root.py +++ b/st2api/st2api/controllers/root.py @@ -15,23 +15,21 @@ from st2common import __version__ -__all__ = [ - 'RootController' -] +__all__ = ["RootController"] class RootController(object): def index(self): data = {} - if 'dev' in __version__: - docs_url = 'http://docs.stackstorm.com/latest' + if "dev" in __version__: + docs_url = "http://docs.stackstorm.com/latest" else: - docs_version = '.'.join(__version__.split('.')[:2]) - docs_url = 'http://docs.stackstorm.com/%s' % docs_version + docs_version = ".".join(__version__.split(".")[:2]) + docs_url = "http://docs.stackstorm.com/%s" % docs_version - data['version'] = __version__ - data['docs_url'] = docs_url + data["version"] = __version__ + data["docs_url"] = docs_url return data diff --git a/st2api/st2api/controllers/v1/action_views.py b/st2api/st2api/controllers/v1/action_views.py index d1701ebfbf..2e528b5b13 100644 --- a/st2api/st2api/controllers/v1/action_views.py +++ b/st2api/st2api/controllers/v1/action_views.py @@ -33,11 +33,7 @@ from st2common.router import abort from st2common.router import Response -__all__ = [ - 'OverviewController', - 'ParametersViewController', - 'EntryPointController' -] +__all__ = ["OverviewController", "ParametersViewController", "EntryPointController"] http_client = six.moves.http_client @@ -45,7 +41,6 @@ class LookupUtils(object): - @staticmethod def _get_action_by_id(id): try: @@ -75,31 +70,33 @@ def _get_runner_by_name(name): class ParametersViewController(object): - def get_one(self, action_id, requester_user): return self._get_one(action_id, requester_user=requester_user) @staticmethod def _get_one(action_id, requester_user): """ - List merged action & runner parameters by action id. + List merged action & runner parameters by action id. - Handle: - GET /actions/views/parameters/1 + Handle: + GET /actions/views/parameters/1 """ action_db = LookupUtils._get_action_by_id(action_id) permission_type = PermissionType.ACTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) - runner_db = LookupUtils._get_runner_by_name(action_db.runner_type['name']) + runner_db = LookupUtils._get_runner_by_name(action_db.runner_type["name"]) all_params = action_param_utils.get_params_view( - action_db=action_db, runner_db=runner_db, merged_only=True) + action_db=action_db, runner_db=runner_db, merged_only=True + ) - return {'parameters': all_params} + return {"parameters": all_params} class OverviewController(resource.ContentPackResourceController): @@ -107,47 +104,54 @@ class OverviewController(resource.ContentPackResourceController): access = Action supported_filters = {} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - mandatory_include_fields_retrieve = [ - 'pack', - 'name', - 'parameters', - 'runner_type' - ] + mandatory_include_fields_retrieve = ["pack", "name", "parameters", "runner_type"] def get_one(self, ref_or_id, requester_user): """ - List action by id. + List action by id. - Handle: - GET /actions/views/overview/1 + Handle: + GET /actions/views/overview/1 """ - resp = super(OverviewController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=PermissionType.ACTION_VIEW) + resp = super(OverviewController, self)._get_one( + ref_or_id, + requester_user=requester_user, + permission_type=PermissionType.ACTION_VIEW, + ) action_api = ActionAPI(**resp.json) - result = self._transform_action_api(action_api=action_api, requester_user=requester_user) + result = self._transform_action_api( + action_api=action_api, requester_user=requester_user + ) resp.json = result return resp - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): """ - List all actions. + List all actions. - Handles requests: - GET /actions/views/overview + Handles requests: + GET /actions/views/overview """ - resp = super(OverviewController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + resp = super(OverviewController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) runner_type_names = set([]) action_ids = [] @@ -164,9 +168,12 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o # N * 2 additional queries # 1. Retrieve all the respective runner objects - we only need parameters - runner_type_dbs = RunnerType.query(name__in=runner_type_names, - only_fields=['name', 'runner_parameters']) - runner_type_dbs = dict([(runner_db.name, runner_db) for runner_db in runner_type_dbs]) + runner_type_dbs = RunnerType.query( + name__in=runner_type_names, only_fields=["name", "runner_parameters"] + ) + runner_type_dbs = dict( + [(runner_db.name, runner_db) for runner_db in runner_type_dbs] + ) # 2. Retrieve all the respective action objects - we only need parameters action_dbs = dict([(action_db.id, action_db) for action_db in result]) @@ -174,9 +181,9 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o for action_api in result: action_db = action_dbs.get(action_api.id, None) runner_db = runner_type_dbs.get(action_api.runner_type, None) - all_params = action_param_utils.get_params_view(action_db=action_db, - runner_db=runner_db, - merged_only=True) + all_params = action_param_utils.get_params_view( + action_db=action_db, runner_db=runner_db, merged_only=True + ) action_api.parameters = all_params resp.json = result @@ -185,9 +192,10 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o @staticmethod def _transform_action_api(action_api, requester_user): action_id = action_api.id - result = ParametersViewController._get_one(action_id=action_id, - requester_user=requester_user) - action_api.parameters = result.get('parameters', {}) + result = ParametersViewController._get_one( + action_id=action_id, requester_user=requester_user + ) + action_api.parameters = result.get("parameters", {}) return action_api @@ -202,35 +210,38 @@ def get_all(self): def get_one(self, ref_or_id, requester_user): """ - Outputs the file associated with action entry_point + Outputs the file associated with action entry_point - Handles requests: - GET /actions/views/entry_point/1 + Handles requests: + GET /actions/views/entry_point/1 """ - LOG.info('GET /actions/views/entry_point with ref_or_id=%s', ref_or_id) + LOG.info("GET /actions/views/entry_point with ref_or_id=%s", ref_or_id) action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) permission_type = PermissionType.ACTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) - pack = getattr(action_db, 'pack', None) - entry_point = getattr(action_db, 'entry_point', None) + pack = getattr(action_db, "pack", None) + entry_point = getattr(action_db, "entry_point", None) abs_path = utils.get_entry_point_abs_path(pack, entry_point) if not abs_path: - raise StackStormDBObjectNotFoundError('Action ref_or_id=%s has no entry_point to output' - % ref_or_id) + raise StackStormDBObjectNotFoundError( + "Action ref_or_id=%s has no entry_point to output" % ref_or_id + ) - with codecs.open(abs_path, 'r') as fp: + with codecs.open(abs_path, "r") as fp: content = fp.read() # Ensure content is utf-8 if isinstance(content, six.binary_type): - content = content.decode('utf-8') + content = content.decode("utf-8") try: content_type = mimetypes.guess_type(abs_path)[0] @@ -240,15 +251,15 @@ def get_one(self, ref_or_id, requester_user): # Special case if /etc/mime.types doesn't contain entry for yaml, py if not content_type: _, extension = os.path.splitext(abs_path) - if extension in ['.yaml', '.yml']: - content_type = 'application/x-yaml' - elif extension in ['.py']: - content_type = 'application/x-python' + if extension in [".yaml", ".yml"]: + content_type = "application/x-yaml" + elif extension in [".py"]: + content_type = "application/x-python" else: - content_type = 'text/plain' + content_type = "text/plain" response = Response() - response.headers['Content-Type'] = content_type + response.headers["Content-Type"] = content_type response.text = content return response diff --git a/st2api/st2api/controllers/v1/actionalias.py b/st2api/st2api/controllers/v1/actionalias.py index 00e58675f9..5488300d6e 100644 --- a/st2api/st2api/controllers/v1/actionalias.py +++ b/st2api/st2api/controllers/v1/actionalias.py @@ -37,175 +37,219 @@ class ActionAliasController(resource.ContentPackResourceController): """ - Implements the RESTful interface for ActionAliases. + Implements the RESTful interface for ActionAliases. """ + model = ActionAliasAPI access = ActionAlias - supported_filters = { - 'name': 'name', - 'pack': 'pack' - } - - query_options = { - 'sort': ['pack', 'name'] - } - - _custom_actions = { - 'match': ['POST'], - 'help': ['POST'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, - sort=None, offset=0, limit=None, requester_user=None, **raw_filters): - return super(ActionAliasController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack"} + + query_options = {"sort": ["pack", "name"]} + + _custom_actions = {"match": ["POST"], "help": ["POST"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(ActionAliasController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.ACTION_ALIAS_VIEW - return super(ActionAliasController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=permission_type) + return super(ActionAliasController, self)._get_one( + ref_or_id, requester_user=requester_user, permission_type=permission_type + ) def match(self, action_alias_match_api): """ - Find a matching action alias. + Find a matching action alias. - Handles requests: - POST /actionalias/match + Handles requests: + POST /actionalias/match """ command = action_alias_match_api.command try: format_ = get_matching_alias(command=command) except ActionAliasAmbiguityException as e: - LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches)) + LOG.exception( + 'Command "%s" matched (%s) patterns.', e.command, len(e.matches) + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) # Convert ActionAliasDB to API - action_alias_api = ActionAliasAPI.from_model(format_['alias']) + action_alias_api = ActionAliasAPI.from_model(format_["alias"]) return { - 'actionalias': action_alias_api, - 'display': format_['display'], - 'representation': format_['representation'], + "actionalias": action_alias_api, + "display": format_["display"], + "representation": format_["representation"], } def help(self, filter, pack, limit, offset, **kwargs): """ - Get available help strings for action aliases. + Get available help strings for action aliases. - Handles requests: - GET /actionalias/help + Handles requests: + GET /actionalias/help """ try: aliases_resp = super(ActionAliasController, self)._get_all(**kwargs) aliases = [ActionAliasAPI(**alias) for alias in aliases_resp.json] - return generate_helpstring_result(aliases, filter, pack, int(limit), int(offset)) + return generate_helpstring_result( + aliases, filter, pack, int(limit), int(offset) + ) except (TypeError) as e: - LOG.exception('Helpstring request contains an invalid data type: %s.', six.text_type(e)) + LOG.exception( + "Helpstring request contains an invalid data type: %s.", + six.text_type(e), + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) def post(self, action_alias, requester_user): """ - Create a new ActionAlias. + Create a new ActionAlias. - Handles requests: - POST /actionalias/ + Handles requests: + POST /actionalias/ """ permission_type = PermissionType.ACTION_ALIAS_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=action_alias, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=action_alias, + permission_type=permission_type, + ) try: action_alias_db = ActionAliasAPI.to_model(action_alias) - LOG.debug('/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s', - action_alias_db) + LOG.debug( + "/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s", + action_alias_db, + ) action_alias_db = ActionAlias.add_or_update(action_alias_db) except (ValidationError, ValueError, ValueValidationException) as e: - LOG.exception('Validation failed for action alias data=%s.', action_alias) + LOG.exception("Validation failed for action alias data=%s.", action_alias) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias created. ActionAlias.id=%s' % (action_alias_db.id), extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias created. ActionAlias.id=%s" % (action_alias_db.id), + extra=extra, + ) action_alias_api = ActionAliasAPI.from_model(action_alias_db) return Response(json=action_alias_api, status=http_client.CREATED) def put(self, action_alias, ref_or_id, requester_user): """ - Update an action alias. + Update an action alias. - Handles requests: - PUT /actionalias/1 + Handles requests: + PUT /actionalias/1 """ action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('PUT /actionalias/ lookup with id=%s found object: %s', ref_or_id, - action_alias_db) + LOG.debug( + "PUT /actionalias/ lookup with id=%s found object: %s", + ref_or_id, + action_alias_db, + ) permission_type = PermissionType.ACTION_ALIAS_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_alias_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_alias_db, + permission_type=permission_type, + ) - if not hasattr(action_alias, 'id'): + if not hasattr(action_alias, "id"): action_alias.id = None try: - if action_alias.id is not None and action_alias.id != '' and \ - action_alias.id != ref_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - action_alias.id, ref_or_id) + if ( + action_alias.id is not None + and action_alias.id != "" + and action_alias.id != ref_or_id + ): + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + action_alias.id, + ref_or_id, + ) old_action_alias_db = action_alias_db action_alias_db = ActionAliasAPI.to_model(action_alias) action_alias_db.id = ref_or_id action_alias_db = ActionAlias.add_or_update(action_alias_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for action alias data=%s', action_alias) + LOG.exception("Validation failed for action alias data=%s", action_alias) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_action_alias_db': old_action_alias_db, 'new_action_alias_db': action_alias_db} - LOG.audit('Action alias updated. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra) + extra = { + "old_action_alias_db": old_action_alias_db, + "new_action_alias_db": action_alias_db, + } + LOG.audit( + "Action alias updated. ActionAlias.id=%s." % (action_alias_db.id), + extra=extra, + ) action_alias_api = ActionAliasAPI.from_model(action_alias_db) return action_alias_api def delete(self, ref_or_id, requester_user): """ - Delete an action alias. + Delete an action alias. - Handles requests: - DELETE /actionalias/1 + Handles requests: + DELETE /actionalias/1 """ action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('DELETE /actionalias/ lookup with id=%s found object: %s', ref_or_id, - action_alias_db) + LOG.debug( + "DELETE /actionalias/ lookup with id=%s found object: %s", + ref_or_id, + action_alias_db, + ) permission_type = PermissionType.ACTION_ALIAS_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_alias_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_alias_db, + permission_type=permission_type, + ) try: ActionAlias.delete(action_alias_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s".', - ref_or_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s".', + ref_or_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias deleted. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias deleted. ActionAlias.id=%s." % (action_alias_db.id), + extra=extra, + ) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/actionexecutions.py b/st2api/st2api/controllers/v1/actionexecutions.py index b0aa4e9e1d..3cc7741b2d 100644 --- a/st2api/st2api/controllers/v1/actionexecutions.py +++ b/st2api/st2api/controllers/v1/actionexecutions.py @@ -54,18 +54,15 @@ from st2common.rbac.types import PermissionType from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'ActionExecutionsController' -] +__all__ = ["ActionExecutionsController"] LOG = logging.getLogger(__name__) # Note: We initialize filters here and not in the constructor SUPPORTED_EXECUTIONS_FILTERS = copy.deepcopy(SUPPORTED_FILTERS) -SUPPORTED_EXECUTIONS_FILTERS.update({ - 'timestamp_gt': 'start_timestamp.gt', - 'timestamp_lt': 'start_timestamp.lt' -}) +SUPPORTED_EXECUTIONS_FILTERS.update( + {"timestamp_gt": "start_timestamp.gt", "timestamp_lt": "start_timestamp.lt"} +) MONITOR_THREAD_EMPTY_Q_SLEEP_TIME = 5 MONITOR_THREAD_NO_WORKERS_SLEEP_TIME = 1 @@ -82,29 +79,24 @@ class ActionExecutionsControllerMixin(BaseRestControllerMixin): # Those two attributes are mandatory so we can correctly determine and mask secret execution # parameters mandatory_include_fields_retrieve = [ - 'action.parameters', - 'runner.runner_parameters', - 'parameters', - + "action.parameters", + "runner.runner_parameters", + "parameters", # Attributes below are mandatory for RBAC installations - 'action.pack', - 'action.uid', - + "action.pack", + "action.uid", # Required when rbac.permission_isolation is enabled - 'context' + "context", ] # A list of attributes which can be specified using ?exclude_attributes filter # NOTE: Allowing user to exclude attribute such as action and runner would break secrets # masking - valid_exclude_attributes = [ - 'result', - 'trigger_instance', - 'status' - ] + valid_exclude_attributes = ["result", "trigger_instance", "status"] - def _handle_schedule_execution(self, liveaction_api, requester_user, context_string=None, - show_secrets=False): + def _handle_schedule_execution( + self, liveaction_api, requester_user, context_string=None, show_secrets=False + ): """ :param liveaction: LiveActionAPI object. :type liveaction: :class:`LiveActionAPI` @@ -124,101 +116,129 @@ def _handle_schedule_execution(self, liveaction_api, requester_user, context_str # Assert the permissions permission_type = PermissionType.ACTION_EXECUTE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) # Validate that the authenticated user is admin if user query param is provided user = liveaction_api.user or requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) try: - return self._schedule_execution(liveaction=liveaction_api, - requester_user=requester_user, - user=user, - context_string=context_string, - show_secrets=show_secrets, - action_db=action_db) + return self._schedule_execution( + liveaction=liveaction_api, + requester_user=requester_user, + user=user, + context_string=context_string, + show_secrets=show_secrets, + action_db=action_db, + ) except ValueError as e: - LOG.exception('Unable to execute action.') + LOG.exception("Unable to execute action.") abort(http_client.BAD_REQUEST, six.text_type(e)) except jsonschema.ValidationError as e: - LOG.exception('Unable to execute action. Parameter validation failed.') - abort(http_client.BAD_REQUEST, re.sub("u'([^']*)'", r"'\1'", - getattr(e, 'message', six.text_type(e)))) + LOG.exception("Unable to execute action. Parameter validation failed.") + abort( + http_client.BAD_REQUEST, + re.sub("u'([^']*)'", r"'\1'", getattr(e, "message", six.text_type(e))), + ) except trace_exc.TraceNotFoundException as e: abort(http_client.BAD_REQUEST, six.text_type(e)) except validation_exc.ValueValidationException as e: raise e except Exception as e: - LOG.exception('Unable to execute action. Unexpected error encountered.') + LOG.exception("Unable to execute action. Unexpected error encountered.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) - def _schedule_execution(self, liveaction, requester_user, action_db, user=None, - context_string=None, show_secrets=False): + def _schedule_execution( + self, + liveaction, + requester_user, + action_db, + user=None, + context_string=None, + show_secrets=False, + ): # Initialize execution context if it does not exist. - if not hasattr(liveaction, 'context'): + if not hasattr(liveaction, "context"): liveaction.context = dict() - liveaction.context['user'] = user - liveaction.context['pack'] = action_db.pack + liveaction.context["user"] = user + liveaction.context["pack"] = action_db.pack - LOG.debug('User is: %s' % liveaction.context['user']) + LOG.debug("User is: %s" % liveaction.context["user"]) # Retrieve other st2 context from request header. if context_string: context = try_loads(context_string) if not isinstance(context, dict): - raise ValueError('Unable to convert st2-context from the headers into JSON.') + raise ValueError( + "Unable to convert st2-context from the headers into JSON." + ) liveaction.context.update(context) # Include RBAC context (if RBAC is available and enabled) if cfg.CONF.rbac.enable: user_db = UserDB(name=user) rbac_service = get_rbac_backend().get_service_class() - role_dbs = rbac_service.get_roles_for_user(user_db=user_db, include_remote=True) + role_dbs = rbac_service.get_roles_for_user( + user_db=user_db, include_remote=True + ) roles = [role_db.name for role_db in role_dbs] - liveaction.context['rbac'] = { - 'user': user, - 'roles': roles - } + liveaction.context["rbac"] = {"user": user, "roles": roles} # Schedule the action execution. liveaction_db = LiveActionAPI.to_model(liveaction) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) try: liveaction_db.parameters = param_utils.render_live_params( - runnertype_db.runner_parameters, action_db.parameters, liveaction_db.parameters, - liveaction_db.context) + runnertype_db.runner_parameters, + action_db.parameters, + liveaction_db.parameters, + liveaction_db.context, + ) except param_exc.ParamException: # We still need to create a request, so liveaction_db is assigned an ID liveaction_db, actionexecution_db = action_service.create_request( liveaction=liveaction_db, action_db=action_db, - runnertype_db=runnertype_db) + runnertype_db=runnertype_db, + ) # By this point the execution is already in the DB therefore need to mark it failed. _, e, tb = sys.exc_info() action_service.update_status( liveaction=liveaction_db, new_status=action_constants.LIVEACTION_STATUS_FAILED, - result={'error': six.text_type(e), - 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": six.text_type(e), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) # Might be a good idea to return the actual ActionExecution rather than bubble up # the exception. raise validation_exc.ValueValidationException(six.text_type(e)) # The request should be created after the above call to render_live_params # so any templates in live parameters have a chance to render. - liveaction_db, actionexecution_db = action_service.create_request(liveaction=liveaction_db, - action_db=action_db, - runnertype_db=runnertype_db) + liveaction_db, actionexecution_db = action_service.create_request( + liveaction=liveaction_db, action_db=action_db, runnertype_db=runnertype_db + ) - _, actionexecution_db = action_service.publish_request(liveaction_db, actionexecution_db) + _, actionexecution_db = action_service.publish_request( + liveaction_db, actionexecution_db + ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets) + execution_api = ActionExecutionAPI.from_model( + actionexecution_db, mask_secrets=mask_secrets + ) return Response(json=execution_api, status=http_client.CREATED) @@ -231,25 +251,33 @@ def _get_result_object(self, id): :rtype: ``dict`` """ - fields = ['result'] - action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get() + fields = ["result"] + action_exec_db = ( + self.access.impl.model.objects.filter(id=id).only(*fields).get() + ) return action_exec_db.result - def _get_children(self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False): + def _get_children( + self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False + ): # make sure depth is int. Url encoding will make it a string and needs to # be converted back in that case. depth = int(depth) - LOG.debug('retrieving children for id: %s with depth: %s', id_, depth) - descendants = execution_service.get_descendants(actionexecution_id=id_, - descendant_depth=depth, - result_fmt=result_fmt) + LOG.debug("retrieving children for id: %s with depth: %s", id_, depth) + descendants = execution_service.get_descendants( + actionexecution_id=id_, descendant_depth=depth, result_fmt=result_fmt + ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - return [self.model.from_model(descendant, mask_secrets=mask_secrets) for - descendant in descendants] + return [ + self.model.from_model(descendant, mask_secrets=mask_secrets) + for descendant in descendants + ] -class BaseActionExecutionNestedController(ActionExecutionsControllerMixin, ResourceController): +class BaseActionExecutionNestedController( + ActionExecutionsControllerMixin, ResourceController +): # Note: We need to override "get_one" and "get_all" to return 404 since nested controller # don't implement thos methods @@ -265,24 +293,36 @@ def get_one(self, id): class ActionExecutionChildrenController(BaseActionExecutionNestedController): - def get_one(self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False): + def get_one( + self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False + ): """ Retrieve children for the provided action execution. :rtype: ``list`` """ - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) id = str(execution_db.id) - return self._get_children(id_=id, depth=depth, result_fmt=result_fmt, - requester_user=requester_user, show_secrets=show_secrets) + return self._get_children( + id_=id, + depth=depth, + result_fmt=result_fmt, + requester_user=requester_user, + show_secrets=show_secrets, + ) class ActionExecutionAttributeController(BaseActionExecutionNestedController): - valid_exclude_attributes = ['action__pack', 'action__uid'] + \ - ActionExecutionsControllerMixin.valid_exclude_attributes + valid_exclude_attributes = [ + "action__pack", + "action__uid", + ] + ActionExecutionsControllerMixin.valid_exclude_attributes def get(self, id, attribute, requester_user): """ @@ -294,76 +334,94 @@ def get(self, id, attribute, requester_user): :rtype: ``dict`` """ - fields = [attribute, 'action__pack', 'action__uid'] + fields = [attribute, "action__pack", "action__uid"] try: fields = self._validate_exclude_fields(fields) except ValueError: - valid_attributes = ', '.join(ActionExecutionsControllerMixin.valid_exclude_attributes) - msg = ('Invalid attribute "%s" specified. Valid attributes are: %s' % - (attribute, valid_attributes)) + valid_attributes = ", ".join( + ActionExecutionsControllerMixin.valid_exclude_attributes + ) + msg = 'Invalid attribute "%s" specified. Valid attributes are: %s' % ( + attribute, + valid_attributes, + ) raise ValueError(msg) - action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get() + action_exec_db = ( + self.access.impl.model.objects.filter(id=id).only(*fields).get() + ) permission_type = PermissionType.EXECUTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_exec_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_exec_db, + permission_type=permission_type, + ) result = getattr(action_exec_db, attribute, None) return Response(json=result, status=http_client.OK) -class ActionExecutionOutputController(ActionExecutionsControllerMixin, ResourceController): - supported_filters = { - 'output_type': 'output_type' - } +class ActionExecutionOutputController( + ActionExecutionsControllerMixin, ResourceController +): + supported_filters = {"output_type": "output_type"} exclude_fields = [] - def get_one(self, id, output_type='all', output_format='raw', existing_only=False, - requester_user=None): + def get_one( + self, + id, + output_type="all", + output_format="raw", + existing_only=False, + requester_user=None, + ): # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).first() + if id == "last": + execution_db = ActionExecution.query().order_by("-id").limit(1).first() if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) execution_id = str(execution_db.id) query_filters = {} - if output_type and output_type != 'all': - query_filters['output_type'] = output_type + if output_type and output_type != "all": + query_filters["output_type"] = output_type def existing_output_iter(): # Consume and return all of the existing lines # pylint: disable=no-member - output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters) + output_dbs = ActionExecutionOutput.query( + execution_id=execution_id, **query_filters + ) - output = ''.join([output_db.data for output_db in output_dbs]) - yield six.binary_type(output.encode('utf-8')) + output = "".join([output_db.data for output_db in output_dbs]) + yield six.binary_type(output.encode("utf-8")) def make_response(): app_iter = existing_output_iter() - res = Response(content_type='text/plain', app_iter=app_iter) + res = Response(content_type="text/plain", app_iter=app_iter) return res res = make_response() return res -class ActionExecutionReRunController(ActionExecutionsControllerMixin, ResourceController): +class ActionExecutionReRunController( + ActionExecutionsControllerMixin, ResourceController +): supported_filters = {} - exclude_fields = [ - 'result', - 'trigger_instance' - ] + exclude_fields = ["result", "trigger_instance"] class ExecutionSpecificationAPI(object): def __init__(self, parameters=None, tasks=None, reset=None, user=None): @@ -374,8 +432,10 @@ def __init__(self, parameters=None, tasks=None, reset=None, user=None): def validate(self): if (self.tasks or self.reset) and self.parameters: - raise ValueError('Parameters override is not supported when ' - 're-running task(s) for a workflow.') + raise ValueError( + "Parameters override is not supported when " + "re-running task(s) for a workflow." + ) if self.parameters: assert isinstance(self.parameters, dict) @@ -387,7 +447,9 @@ def validate(self): assert isinstance(self.reset, list) if list(set(self.reset) - set(self.tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) return self @@ -401,8 +463,10 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) """ if (spec_api.tasks or spec_api.reset) and spec_api.parameters: - raise ValueError('Parameters override is not supported when ' - 're-running task(s) for a workflow.') + raise ValueError( + "Parameters override is not supported when " + "re-running task(s) for a workflow." + ) if spec_api.parameters: assert isinstance(spec_api.parameters, dict) @@ -414,7 +478,9 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) assert isinstance(spec_api.reset, list) if list(set(spec_api.reset) - set(spec_api.tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) delay = None @@ -422,59 +488,69 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) delay = spec_api.delay no_merge = cast_argument_value(value_type=bool, value=no_merge) - existing_execution = self._get_one_by_id(id=id, exclude_fields=self.exclude_fields, - requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + existing_execution = self._get_one_by_id( + id=id, + exclude_fields=self.exclude_fields, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) - if spec_api.tasks and \ - existing_execution.runner['name'] != 'orquesta': - raise ValueError('Task option is only supported for Orquesta workflows.') + if spec_api.tasks and existing_execution.runner["name"] != "orquesta": + raise ValueError("Task option is only supported for Orquesta workflows.") # Merge in any parameters provided by the user new_parameters = {} if not no_merge: - new_parameters.update(getattr(existing_execution, 'parameters', {})) + new_parameters.update(getattr(existing_execution, "parameters", {})) new_parameters.update(spec_api.parameters) # Create object for the new execution - action_ref = existing_execution.action['ref'] + action_ref = existing_execution.action["ref"] # Include additional option(s) for the execution context = { - 're-run': { - 'ref': id, + "re-run": { + "ref": id, } } if spec_api.tasks: - context['re-run']['tasks'] = spec_api.tasks + context["re-run"]["tasks"] = spec_api.tasks if spec_api.reset: - context['re-run']['reset'] = spec_api.reset + context["re-run"]["reset"] = spec_api.reset # Add trace to the new execution trace = trace_service.get_trace_db_by_action_execution( - action_execution_id=existing_execution.id) + action_execution_id=existing_execution.id + ) if trace: - context['trace_context'] = {'id_': str(trace.id)} - - new_liveaction_api = LiveActionCreateAPI(action=action_ref, - context=context, - parameters=new_parameters, - user=spec_api.user, - delay=delay) - - return self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user, - show_secrets=show_secrets) - - -class ActionExecutionsController(BaseResourceIsolationControllerMixin, - ActionExecutionsControllerMixin, ResourceController): + context["trace_context"] = {"id_": str(trace.id)} + + new_liveaction_api = LiveActionCreateAPI( + action=action_ref, + context=context, + parameters=new_parameters, + user=spec_api.user, + delay=delay, + ) + + return self._handle_schedule_execution( + liveaction_api=new_liveaction_api, + requester_user=requester_user, + show_secrets=show_secrets, + ) + + +class ActionExecutionsController( + BaseResourceIsolationControllerMixin, + ActionExecutionsControllerMixin, + ResourceController, +): """ - Implements the RESTful web endpoint that handles - the lifecycle of ActionExecutions in the system. + Implements the RESTful web endpoint that handles + the lifecycle of ActionExecutions in the system. """ # Nested controllers @@ -485,17 +561,25 @@ class ActionExecutionsController(BaseResourceIsolationControllerMixin, re_run = ActionExecutionReRunController() # ResourceController attributes - query_options = { - 'sort': ['-start_timestamp', 'action.ref'] - } + query_options = {"sort": ["-start_timestamp", "action.ref"]} supported_filters = SUPPORTED_EXECUTIONS_FILTERS filter_transform_functions = { - 'timestamp_gt': lambda value: isotime.parse(value=value), - 'timestamp_lt': lambda value: isotime.parse(value=value) + "timestamp_gt": lambda value: isotime.parse(value=value), + "timestamp_lt": lambda value: isotime.parse(value=value), } - def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0, limit=None, - show_secrets=False, include_attributes=None, advanced_filters=None, **raw_filters): + def get_all( + self, + requester_user, + exclude_attributes=None, + sort=None, + offset=0, + limit=None, + show_secrets=False, + include_attributes=None, + advanced_filters=None, + **raw_filters, + ): """ List all executions. @@ -508,27 +592,37 @@ def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0, # Use a custom sort order when filtering on a timestamp so we return a correct result as # expected by the user query_options = None - if raw_filters.get('timestamp_lt', None) or raw_filters.get('sort_desc', None): - query_options = {'sort': ['-start_timestamp', 'action.ref']} - elif raw_filters.get('timestamp_gt', None) or raw_filters.get('sort_asc', None): - query_options = {'sort': ['+start_timestamp', 'action.ref']} + if raw_filters.get("timestamp_lt", None) or raw_filters.get("sort_desc", None): + query_options = {"sort": ["-start_timestamp", "action.ref"]} + elif raw_filters.get("timestamp_gt", None) or raw_filters.get("sort_asc", None): + query_options = {"sort": ["+start_timestamp", "action.ref"]} from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - return self._get_action_executions(exclude_fields=exclude_attributes, - include_fields=include_attributes, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - advanced_filters=advanced_filters, - requester_user=requester_user) - - def get_one(self, id, requester_user, exclude_attributes=None, include_attributes=None, - show_secrets=False): + return self._get_action_executions( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + advanced_filters=advanced_filters, + requester_user=requester_user, + ) + + def get_one( + self, + id, + requester_user, + exclude_attributes=None, + include_attributes=None, + show_secrets=False, + ): """ Retrieve a single execution. @@ -538,33 +632,48 @@ def get_one(self, id, requester_user, exclude_attributes=None, include_attribute :param exclude_attributes: List of attributes to exclude from the object. :type exclude_attributes: ``list`` """ - exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_attributes) - include_fields = self._validate_include_fields(include_fields=include_attributes) + exclude_fields = self._validate_exclude_fields( + exclude_fields=exclude_attributes + ) + include_fields = self._validate_include_fields( + include_fields=include_attributes + ) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).only('id').first() + if id == "last": + execution_db = ( + ActionExecution.query().order_by("-id").limit(1).only("id").first() + ) if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - return self._get_one_by_id(id=id, exclude_fields=exclude_fields, - include_fields=include_fields, - requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_VIEW) - - def post(self, liveaction_api, requester_user, context_string=None, show_secrets=False): - return self._handle_schedule_execution(liveaction_api=liveaction_api, - requester_user=requester_user, - context_string=context_string, - show_secrets=show_secrets) + return self._get_one_by_id( + id=id, + exclude_fields=exclude_fields, + include_fields=include_fields, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_VIEW, + ) + + def post( + self, liveaction_api, requester_user, context_string=None, show_secrets=False + ): + return self._handle_schedule_execution( + liveaction_api=liveaction_api, + requester_user=requester_user, + context_string=context_string, + show_secrets=show_secrets, + ) def put(self, id, liveaction_api, requester_user, show_secrets=False): """ @@ -578,76 +687,118 @@ def put(self, id, liveaction_api, requester_user, show_secrets=False): requester_user = UserDB(cfg.CONF.system_user.user) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - execution_api = self._get_one_by_id(id=id, requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_STOP) + execution_api = self._get_one_by_id( + id=id, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_STOP, + ) if not execution_api: - abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id) + abort(http_client.NOT_FOUND, "Execution with id %s not found." % id) - liveaction_id = execution_api.liveaction['id'] + liveaction_id = execution_api.liveaction["id"] if not liveaction_id: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) try: liveaction_db = LiveAction.get_by_id(liveaction_id) except: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES: - abort(http_client.BAD_REQUEST, 'Execution is already in completed state.') + abort(http_client.BAD_REQUEST, "Execution is already in completed state.") def update_status(liveaction_api, liveaction_db): status = liveaction_api.status - result = getattr(liveaction_api, 'result', None) + result = getattr(liveaction_api, "result", None) liveaction_db = action_service.update_status(liveaction_db, status, result) - actionexecution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) + actionexecution_db = ActionExecution.get( + liveaction__id=str(liveaction_db.id) + ) return (liveaction_db, actionexecution_db) try: - if (liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING and - liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED): + if ( + liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING + and liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED + ): if action_service.is_children_active(liveaction_id): liveaction_api.status = action_constants.LIVEACTION_STATUS_CANCELING - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) - elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING or - liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED): + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) + elif ( + liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING + or liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED + ): liveaction_db, actionexecution_db = action_service.request_cancellation( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) - elif (liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING and - liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED): + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) + elif ( + liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING + and liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED + ): if action_service.is_children_active(liveaction_id): liveaction_api.status = action_constants.LIVEACTION_STATUS_PAUSING - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) - elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING or - liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED): + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) + elif ( + liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING + or liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED + ): liveaction_db, actionexecution_db = action_service.request_pause( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) elif liveaction_api.status == action_constants.LIVEACTION_STATUS_RESUMING: liveaction_db, actionexecution_db = action_service.request_resume( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) else: - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) except runner_exc.InvalidActionRunnerOperationError as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) - abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) + abort( + http_client.BAD_REQUEST, + "Failed updating execution. %s" % six.text_type(e), + ) except runner_exc.UnexpectedActionExecutionStatusError as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) - abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) + abort( + http_client.BAD_REQUEST, + "Failed updating execution. %s" % six.text_type(e), + ) except Exception as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) abort( http_client.INTERNAL_SERVER_ERROR, - 'Failed updating execution due to unexpected error.' + "Failed updating execution due to unexpected error.", ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets) + execution_api = ActionExecutionAPI.from_model( + actionexecution_db, mask_secrets=mask_secrets + ) return execution_api @@ -663,50 +814,76 @@ def delete(self, id, requester_user, show_secrets=False): requester_user = UserDB(cfg.CONF.system_user.user) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - execution_api = self._get_one_by_id(id=id, requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_STOP) + execution_api = self._get_one_by_id( + id=id, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_STOP, + ) if not execution_api: - abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id) + abort(http_client.NOT_FOUND, "Execution with id %s not found." % id) - liveaction_id = execution_api.liveaction['id'] + liveaction_id = execution_api.liveaction["id"] if not liveaction_id: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) try: liveaction_db = LiveAction.get_by_id(liveaction_id) except: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) if liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELED: LOG.info( 'Action %s already in "canceled" state; \ - returning execution object.' % liveaction_db.id + returning execution object.' + % liveaction_db.id ) return execution_api if liveaction_db.status not in action_constants.LIVEACTION_CANCELABLE_STATES: - abort(http_client.OK, 'Action cannot be canceled. State = %s.' % liveaction_db.status) + abort( + http_client.OK, + "Action cannot be canceled. State = %s." % liveaction_db.status, + ) try: (liveaction_db, execution_db) = action_service.request_cancellation( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) except: - LOG.exception('Failed requesting cancellation for liveaction %s.', liveaction_db.id) - abort(http_client.INTERNAL_SERVER_ERROR, 'Failed canceling execution.') - - return ActionExecutionAPI.from_model(execution_db, - mask_secrets=from_model_kwargs['mask_secrets']) - - def _get_action_executions(self, exclude_fields=None, include_fields=None, - sort=None, offset=0, limit=None, advanced_filters=None, - query_options=None, raw_filters=None, from_model_kwargs=None, - requester_user=None): + LOG.exception( + "Failed requesting cancellation for liveaction %s.", liveaction_db.id + ) + abort(http_client.INTERNAL_SERVER_ERROR, "Failed canceling execution.") + + return ActionExecutionAPI.from_model( + execution_db, mask_secrets=from_model_kwargs["mask_secrets"] + ) + + def _get_action_executions( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + advanced_filters=None, + query_options=None, + raw_filters=None, + from_model_kwargs=None, + requester_user=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -717,18 +894,25 @@ def _get_action_executions(self, exclude_fields=None, include_fields=None, limit = int(limit) - LOG.debug('Retrieving all action executions with filters=%s,exclude_fields=%s,' - 'include_fields=%s', raw_filters, exclude_fields, include_fields) - return super(ActionExecutionsController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - advanced_filters=advanced_filters, - requester_user=requester_user) + LOG.debug( + "Retrieving all action executions with filters=%s,exclude_fields=%s," + "include_fields=%s", + raw_filters, + exclude_fields, + include_fields, + ) + return super(ActionExecutionsController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + advanced_filters=advanced_filters, + requester_user=requester_user, + ) action_executions_controller = ActionExecutionsController() diff --git a/st2api/st2api/controllers/v1/actions.py b/st2api/st2api/controllers/v1/actions.py index 1746e84b83..c78667076d 100644 --- a/st2api/st2api/controllers/v1/actions.py +++ b/st2api/st2api/controllers/v1/actions.py @@ -53,91 +53,102 @@ class ActionsController(resource.ContentPackResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of Actions in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Actions in the system. """ + views = ActionViewsController() model = ActionAPI access = Action - supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'tags': 'tags.name' - } + supported_filters = {"name": "name", "pack": "pack", "tags": "tags.name"} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - valid_exclude_attributes = [ - 'parameters', - 'notify' - ] + valid_exclude_attributes = ["parameters", "notify"] def __init__(self, *args, **kwargs): super(ActionsController, self).__init__(*args, **kwargs) self._trigger_dispatcher = TriggerDispatcher(LOG) - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(ActionsController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(ActionsController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): - return super(ActionsController, self)._get_one(ref_or_id, requester_user=requester_user, - permission_type=PermissionType.ACTION_VIEW) + return super(ActionsController, self)._get_one( + ref_or_id, + requester_user=requester_user, + permission_type=PermissionType.ACTION_VIEW, + ) def post(self, action, requester_user): """ - Create a new action. + Create a new action. - Handles requests: - POST /actions/ + Handles requests: + POST /actions/ """ permission_type = PermissionType.ACTION_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=action, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, resource_api=action, permission_type=permission_type + ) try: # Perform validation validate_not_part_of_system_pack(action) action_validator.validate_action(action) - except (ValidationError, ValueError, - ValueValidationException, InvalidActionParameterException) as e: - LOG.exception('Unable to create action data=%s', action) + except ( + ValidationError, + ValueError, + ValueValidationException, + InvalidActionParameterException, + ) as e: + LOG.exception("Unable to create action data=%s", action) abort(http_client.BAD_REQUEST, six.text_type(e)) return # Write pack data files to disk (if any are provided) - data_files = getattr(action, 'data_files', []) + data_files = getattr(action, "data_files", []) written_data_files = [] if data_files: - written_data_files = self._handle_data_files(pack_ref=action.pack, - data_files=data_files) + written_data_files = self._handle_data_files( + pack_ref=action.pack, data_files=data_files + ) action_model = ActionAPI.to_model(action) - LOG.debug('/actions/ POST verified ActionAPI object=%s', action) + LOG.debug("/actions/ POST verified ActionAPI object=%s", action) action_db = Action.add_or_update(action_model) - LOG.debug('/actions/ POST saved ActionDB object=%s', action_db) + LOG.debug("/actions/ POST saved ActionDB object=%s", action_db) # Dispatch an internal trigger for each written data file. This way user # automate comitting this files to git using StackStorm rule if written_data_files: - self._dispatch_trigger_for_written_data_files(action_db=action_db, - written_data_files=written_data_files) + self._dispatch_trigger_for_written_data_files( + action_db=action_db, written_data_files=written_data_files + ) - extra = {'acion_db': action_db} - LOG.audit('Action created. Action.id=%s' % (action_db.id), extra=extra) + extra = {"acion_db": action_db} + LOG.audit("Action created. Action.id=%s" % (action_db.id), extra=extra) action_api = ActionAPI.from_model(action_db) return Response(json=action_api, status=http_client.CREATED) @@ -148,13 +159,15 @@ def put(self, action, ref_or_id, requester_user): # Assert permissions permission_type = PermissionType.ACTION_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) action_id = action_db.id - if not getattr(action, 'pack', None): + if not getattr(action, "pack", None): action.pack = action_db.pack # Perform validation @@ -162,70 +175,81 @@ def put(self, action, ref_or_id, requester_user): action_validator.validate_action(action) # Write pack data files to disk (if any are provided) - data_files = getattr(action, 'data_files', []) + data_files = getattr(action, "data_files", []) written_data_files = [] if data_files: - written_data_files = self._handle_data_files(pack_ref=action.pack, - data_files=data_files) + written_data_files = self._handle_data_files( + pack_ref=action.pack, data_files=data_files + ) try: action_db = ActionAPI.to_model(action) - LOG.debug('/actions/ PUT incoming action: %s', action_db) + LOG.debug("/actions/ PUT incoming action: %s", action_db) action_db.id = action_id action_db = Action.add_or_update(action_db) - LOG.debug('/actions/ PUT after add_or_update: %s', action_db) + LOG.debug("/actions/ PUT after add_or_update: %s", action_db) except (ValidationError, ValueError) as e: - LOG.exception('Unable to update action data=%s', action) + LOG.exception("Unable to update action data=%s", action) abort(http_client.BAD_REQUEST, six.text_type(e)) return # Dispatch an internal trigger for each written data file. This way user # automate committing this files to git using StackStorm rule if written_data_files: - self._dispatch_trigger_for_written_data_files(action_db=action_db, - written_data_files=written_data_files) + self._dispatch_trigger_for_written_data_files( + action_db=action_db, written_data_files=written_data_files + ) action_api = ActionAPI.from_model(action_db) - LOG.debug('PUT /actions/ client_result=%s', action_api) + LOG.debug("PUT /actions/ client_result=%s", action_api) return action_api def delete(self, ref_or_id, requester_user): """ - Delete an action. + Delete an action. - Handles requests: - POST /actions/1?_method=delete - DELETE /actions/1 - DELETE /actions/mypack.myaction + Handles requests: + POST /actions/1?_method=delete + DELETE /actions/1 + DELETE /actions/mypack.myaction """ action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) action_id = action_db.id permission_type = PermissionType.ACTION_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) try: validate_not_part_of_system_pack(action_db) except ValueValidationException as e: abort(http_client.BAD_REQUEST, six.text_type(e)) - LOG.debug('DELETE /actions/ lookup with ref_or_id=%s found object: %s', - ref_or_id, action_db) + LOG.debug( + "DELETE /actions/ lookup with ref_or_id=%s found object: %s", + ref_or_id, + action_db, + ) try: Action.delete(action_db) except Exception as e: - LOG.error('Database delete encountered exception during delete of id="%s". ' - 'Exception was %s', action_id, e) + LOG.error( + 'Database delete encountered exception during delete of id="%s". ' + "Exception was %s", + action_id, + e, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'action_db': action_db} - LOG.audit('Action deleted. Action.id=%s' % (action_db.id), extra=extra) + extra = {"action_db": action_db} + LOG.audit("Action deleted. Action.id=%s" % (action_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) def _handle_data_files(self, pack_ref, data_files): @@ -238,13 +262,17 @@ def _handle_data_files(self, pack_ref, data_files): 2. Updates affected PackDB model """ # Write files to disk - written_file_paths = self._write_data_files_to_disk(pack_ref=pack_ref, - data_files=data_files) + written_file_paths = self._write_data_files_to_disk( + pack_ref=pack_ref, data_files=data_files + ) # Update affected PackDB model (update a list of files) # Update PackDB - self._update_pack_model(pack_ref=pack_ref, data_files=data_files, - written_file_paths=written_file_paths) + self._update_pack_model( + pack_ref=pack_ref, + data_files=data_files, + written_file_paths=written_file_paths, + ) return written_file_paths @@ -255,23 +283,27 @@ def _write_data_files_to_disk(self, pack_ref, data_files): written_file_paths = [] for data_file in data_files: - file_path = data_file['file_path'] - content = data_file['content'] + file_path = data_file["file_path"] + content = data_file["content"] - file_path = get_pack_resource_file_abs_path(pack_ref=pack_ref, - resource_type='action', - file_path=file_path) + file_path = get_pack_resource_file_abs_path( + pack_ref=pack_ref, resource_type="action", file_path=file_path + ) LOG.debug('Writing data file "%s" to "%s"' % (str(data_file), file_path)) try: - self._write_data_file(pack_ref=pack_ref, file_path=file_path, content=content) + self._write_data_file( + pack_ref=pack_ref, file_path=file_path, content=content + ) except (OSError, IOError) as e: # Throw a more user-friendly exception on Permission denied error if e.errno == errno.EACCES: - msg = ('Unable to write data to "%s" (permission denied). Make sure ' - 'permissions for that pack directory are configured correctly so ' - 'st2api can write to it.' % (file_path)) + msg = ( + 'Unable to write data to "%s" (permission denied). Make sure ' + "permissions for that pack directory are configured correctly so " + "st2api can write to it." % (file_path) + ) raise ValueError(msg) raise e @@ -285,7 +317,9 @@ def _update_pack_model(self, pack_ref, data_files, written_file_paths): """ file_paths = [] # A list of paths relative to the pack directory for new files for file_path in written_file_paths: - file_path = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) + file_path = get_relative_path_to_pack_file( + pack_ref=pack_ref, file_path=file_path + ) file_paths.append(file_path) pack_db = Pack.get_by_ref(pack_ref) @@ -314,18 +348,18 @@ def _write_data_file(self, pack_ref, file_path, content): mode = stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH os.makedirs(directory, mode) - with open(file_path, 'w') as fp: + with open(file_path, "w") as fp: fp.write(content) def _dispatch_trigger_for_written_data_files(self, action_db, written_data_files): - trigger = ACTION_FILE_WRITTEN_TRIGGER['name'] + trigger = ACTION_FILE_WRITTEN_TRIGGER["name"] host_info = get_host_info() for file_path in written_data_files: payload = { - 'ref': action_db.ref, - 'file_path': file_path, - 'host_info': host_info + "ref": action_db.ref, + "file_path": file_path, + "host_info": host_info, } self._trigger_dispatcher.dispatch(trigger=trigger, payload=payload) diff --git a/st2api/st2api/controllers/v1/aliasexecution.py b/st2api/st2api/controllers/v1/aliasexecution.py index 7ecc14d62e..ecbea2028e 100644 --- a/st2api/st2api/controllers/v1/aliasexecution.py +++ b/st2api/st2api/controllers/v1/aliasexecution.py @@ -30,7 +30,9 @@ from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.notification import NotificationSchema, NotificationSubSchema from st2common.models.utils import action_param_utils -from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db +from st2common.models.utils.action_alias_utils import ( + extract_parameters_for_action_alias_db, +) from st2common.models.utils.action_alias_utils import inject_immutable_parameters from st2common.persistence.actionalias import ActionAlias from st2common.services import action as action_service @@ -53,57 +55,60 @@ def cast_array(value): # Already a list, no casting needed nor wanted. return value - return [v.strip() for v in value.split(',')] + return [v.strip() for v in value.split(",")] CAST_OVERRIDES = { - 'array': cast_array, + "array": cast_array, } class ActionAliasExecutionController(BaseRestControllerMixin): def match_and_execute(self, input_api, requester_user, show_secrets=False): """ - Try to find a matching alias and if one is found, schedule a new - execution by parsing parameters from the provided command against - the matched alias. + Try to find a matching alias and if one is found, schedule a new + execution by parsing parameters from the provided command against + the matched alias. - Handles requests: - POST /aliasexecution/match_and_execute + Handles requests: + POST /aliasexecution/match_and_execute """ command = input_api.command try: format_ = get_matching_alias(command=command) except ActionAliasAmbiguityException as e: - LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches)) + LOG.exception( + 'Command "%s" matched (%s) patterns.', e.command, len(e.matches) + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) - action_alias_db = format_['alias'] - representation = format_['representation'] + action_alias_db = format_["alias"] + representation = format_["representation"] params = { - 'name': action_alias_db.name, - 'format': representation, - 'command': command, - 'user': input_api.user, - 'source_channel': input_api.source_channel, + "name": action_alias_db.name, + "format": representation, + "command": command, + "user": input_api.user, + "source_channel": input_api.source_channel, } # Add in any additional parameters provided by the user if input_api.notification_channel: - params['notification_channel'] = input_api.notification_channel + params["notification_channel"] = input_api.notification_channel if input_api.notification_route: - params['notification_route'] = input_api.notification_route + params["notification_route"] = input_api.notification_route alias_execution_api = AliasMatchAndExecuteInputAPI(**params) results = self._post( payload=alias_execution_api, requester_user=requester_user, show_secrets=show_secrets, - match_multiple=format_['match_multiple']) - return Response(json={'results': results}, status=http_client.CREATED) + match_multiple=format_["match_multiple"], + ) + return Response(json={"results": results}, status=http_client.CREATED) def _post(self, payload, requester_user, show_secrets=False, match_multiple=False): action_alias_name = payload.name if payload else None @@ -115,8 +120,8 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - format_str = payload.format or '' - command = payload.command or '' + format_str = payload.format or "" + command = payload.command or "" try: action_alias_db = ActionAlias.get_by_name(action_alias_name) @@ -124,7 +129,9 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals action_alias_db = None if not action_alias_db: - msg = 'Unable to identify action alias with name "%s".' % (action_alias_name) + msg = 'Unable to identify action alias with name "%s".' % ( + action_alias_name + ) abort(http_client.NOT_FOUND, msg) return @@ -138,132 +145,163 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals action_alias_db=action_alias_db, format_str=format_str, param_stream=command, - match_multiple=match_multiple) + match_multiple=match_multiple, + ) else: multiple_execution_parameters = [ extract_parameters_for_action_alias_db( action_alias_db=action_alias_db, format_str=format_str, param_stream=command, - match_multiple=match_multiple) + match_multiple=match_multiple, + ) ] notify = self._get_notify_field(payload) context = { - 'action_alias_ref': reference.get_ref_from_model(action_alias_db), - 'api_user': payload.user, - 'user': requester_user.name, - 'source_channel': payload.source_channel, + "action_alias_ref": reference.get_ref_from_model(action_alias_db), + "api_user": payload.user, + "user": requester_user.name, + "source_channel": payload.source_channel, } inject_immutable_parameters( action_alias_db=action_alias_db, multiple_execution_parameters=multiple_execution_parameters, - action_context=context) + action_context=context, + ) results = [] for execution_parameters in multiple_execution_parameters: - execution = self._schedule_execution(action_alias_db=action_alias_db, - params=execution_parameters, - notify=notify, - context=context, - show_secrets=show_secrets, - requester_user=requester_user) + execution = self._schedule_execution( + action_alias_db=action_alias_db, + params=execution_parameters, + notify=notify, + context=context, + show_secrets=show_secrets, + requester_user=requester_user, + ) result = { - 'execution': execution, - 'actionalias': ActionAliasAPI.from_model(action_alias_db) + "execution": execution, + "actionalias": ActionAliasAPI.from_model(action_alias_db), } if action_alias_db.ack: try: - if 'format' in action_alias_db.ack: - message = render({'alias': action_alias_db.ack['format']}, result)['alias'] + if "format" in action_alias_db.ack: + message = render( + {"alias": action_alias_db.ack["format"]}, result + )["alias"] - result.update({ - 'message': message - }) + result.update({"message": message}) except UndefinedError as e: - result.update({ - 'message': ('Cannot render "format" in field "ack" for alias. ' + - six.text_type(e)) - }) + result.update( + { + "message": ( + 'Cannot render "format" in field "ack" for alias. ' + + six.text_type(e) + ) + } + ) try: - if 'extra' in action_alias_db.ack: - result.update({ - 'extra': render(action_alias_db.ack['extra'], result) - }) + if "extra" in action_alias_db.ack: + result.update( + {"extra": render(action_alias_db.ack["extra"], result)} + ) except UndefinedError as e: - result.update({ - 'extra': ('Cannot render "extra" in field "ack" for alias. ' + - six.text_type(e)) - }) + result.update( + { + "extra": ( + 'Cannot render "extra" in field "ack" for alias. ' + + six.text_type(e) + ) + } + ) results.append(result) return results def post(self, payload, requester_user, show_secrets=False): - results = self._post(payload, requester_user, show_secrets, match_multiple=False) + results = self._post( + payload, requester_user, show_secrets, match_multiple=False + ) return Response(json=results[0], status=http_client.CREATED) def _tokenize_alias_execution(self, alias_execution): - tokens = alias_execution.strip().split(' ', 1) + tokens = alias_execution.strip().split(" ", 1) return (tokens[0], tokens[1] if len(tokens) > 1 else None) def _get_notify_field(self, payload): on_complete = NotificationSubSchema() - route = (getattr(payload, 'notification_route', None) or - getattr(payload, 'notification_channel', None)) + route = getattr(payload, "notification_route", None) or getattr( + payload, "notification_channel", None + ) on_complete.routes = [route] on_complete.data = { - 'user': payload.user, - 'source_channel': payload.source_channel, - 'source_context': getattr(payload, 'source_context', None), + "user": payload.user, + "source_channel": payload.source_channel, + "source_context": getattr(payload, "source_context", None), } notify = NotificationSchema() notify.on_complete = on_complete return notify - def _schedule_execution(self, action_alias_db, params, notify, context, requester_user, - show_secrets): + def _schedule_execution( + self, action_alias_db, params, notify, context, requester_user, show_secrets + ): action_ref = action_alias_db.action_ref action_db = action_utils.get_action_by_ref(action_ref) if not action_db: - raise StackStormDBObjectNotFoundError('Action with ref "%s" not found ' % (action_ref)) + raise StackStormDBObjectNotFoundError( + 'Action with ref "%s" not found ' % (action_ref) + ) rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.ACTION_EXECUTE - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) try: # prior to shipping off the params cast them to the right type. - params = action_param_utils.cast_params(action_ref=action_alias_db.action_ref, - params=params, - cast_overrides=CAST_OVERRIDES) + params = action_param_utils.cast_params( + action_ref=action_alias_db.action_ref, + params=params, + cast_overrides=CAST_OVERRIDES, + ) if not context: context = { - 'action_alias_ref': reference.get_ref_from_model(action_alias_db), - 'user': get_system_username() + "action_alias_ref": reference.get_ref_from_model(action_alias_db), + "user": get_system_username(), } - liveaction = LiveActionDB(action=action_alias_db.action_ref, context=context, - parameters=params, notify=notify) + liveaction = LiveActionDB( + action=action_alias_db.action_ref, + context=context, + parameters=params, + notify=notify, + ) _, action_execution_db = action_service.request(liveaction) - mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - return ActionExecutionAPI.from_model(action_execution_db, mask_secrets=mask_secrets) + mask_secrets = self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) + return ActionExecutionAPI.from_model( + action_execution_db, mask_secrets=mask_secrets + ) except ValueError as e: - LOG.exception('Unable to execute action.') + LOG.exception("Unable to execute action.") abort(http_client.BAD_REQUEST, six.text_type(e)) except jsonschema.ValidationError as e: - LOG.exception('Unable to execute action. Parameter validation failed.') + LOG.exception("Unable to execute action. Parameter validation failed.") abort(http_client.BAD_REQUEST, six.text_type(e)) except Exception as e: - LOG.exception('Unable to execute action. Unexpected error encountered.') + LOG.exception("Unable to execute action. Unexpected error encountered.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) diff --git a/st2api/st2api/controllers/v1/auth.py b/st2api/st2api/controllers/v1/auth.py index 909c8ff4fe..d4741c4bf1 100644 --- a/st2api/st2api/controllers/v1/auth.py +++ b/st2api/st2api/controllers/v1/auth.py @@ -37,9 +37,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'ApiKeyController' -] +__all__ = ["ApiKeyController"] # See st2common.rbac.resolvers.ApiKeyPermissionResolver#user_has_resource_db_permission for resaon @@ -49,13 +47,9 @@ class ApiKeyController(BaseRestControllerMixin): Implements the REST endpoint for managing the key value store. """ - supported_filters = { - 'user': 'user' - } + supported_filters = {"user": "user"} - query_options = { - 'sort': ['user'] - } + query_options = {"sort": ["user"]} def __init__(self): super(ApiKeyController, self).__init__() @@ -63,31 +57,36 @@ def __init__(self): def get_one(self, api_key_id_or_key, requester_user, show_secrets=None): """ - List api keys. + List api keys. - Handle: - GET /apikeys/1 + Handle: + GET /apikeys/1 """ api_key_db = None try: api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key) except ApiKeyNotFoundError: - msg = ('ApiKey matching %s for reference and id not found.' % (api_key_id_or_key)) + msg = "ApiKey matching %s for reference and id not found." % ( + api_key_id_or_key + ) LOG.exception(msg) abort(http_client.NOT_FOUND, msg) permission_type = PermissionType.API_KEY_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) try: - mask_secrets = self._get_mask_secrets(show_secrets=show_secrets, - requester_user=requester_user) + mask_secrets = self._get_mask_secrets( + show_secrets=show_secrets, requester_user=requester_user + ) return ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) except (ValidationError, ValueError) as e: - LOG.exception('Failed to serialize API key.') + LOG.exception("Failed to serialize API key.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) @property @@ -96,29 +95,34 @@ def max_limit(self): def get_all(self, requester_user, show_secrets=None, limit=None, offset=0): """ - List all keys. + List all keys. - Handles requests: - GET /apikeys/ + Handles requests: + GET /apikeys/ """ - mask_secrets = self._get_mask_secrets(show_secrets=show_secrets, - requester_user=requester_user) + mask_secrets = self._get_mask_secrets( + show_secrets=show_secrets, requester_user=requester_user + ) - limit = resource.validate_limit_query_param(limit, requester_user=requester_user) + limit = resource.validate_limit_query_param( + limit, requester_user=requester_user + ) try: api_key_dbs = ApiKey.get_all(limit=limit, offset=offset) - api_keys = [ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) - for api_key_db in api_key_dbs] + api_keys = [ + ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) + for api_key_db in api_key_dbs + ] except OverflowError: msg = 'Offset "%s" specified is more than 32 bit int' % (offset) raise ValueError(msg) resp = Response(json=api_keys) - resp.headers['X-Total-Count'] = str(api_key_dbs.count()) + resp.headers["X-Total-Count"] = str(api_key_dbs.count()) if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp @@ -129,14 +133,16 @@ def post(self, api_key_api, requester_user): permission_type = PermissionType.API_KEY_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=api_key_api, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=api_key_api, + permission_type=permission_type, + ) api_key_db = None api_key = None try: - if not getattr(api_key_api, 'user', None): + if not getattr(api_key_api, "user", None): if requester_user: api_key_api.user = requester_user.name else: @@ -148,22 +154,22 @@ def post(self, api_key_api, requester_user): user_db = UserDB(name=api_key_api.user) User.add_or_update(user_db) - extra = {'username': api_key_api.user, 'user': user_db} + extra = {"username": api_key_api.user, "user": user_db} LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra) # If key_hash is provided use that and do not create a new key. The assumption # is user already has the original api-key - if not getattr(api_key_api, 'key_hash', None): + if not getattr(api_key_api, "key_hash", None): api_key, api_key_hash = auth_util.generate_api_key_and_hash() # store key_hash in DB api_key_api.key_hash = api_key_hash api_key_db = ApiKey.add_or_update(ApiKeyAPI.to_model(api_key_api)) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for api_key data=%s.', api_key_api) + LOG.exception("Validation failed for api_key data=%s.", api_key_api) abort(http_client.BAD_REQUEST, six.text_type(e)) - extra = {'api_key_db': api_key_db} - LOG.audit('ApiKey created. ApiKey.id=%s' % (api_key_db.id), extra=extra) + extra = {"api_key_db": api_key_db} + LOG.audit("ApiKey created. ApiKey.id=%s" % (api_key_db.id), extra=extra) api_key_create_response_api = ApiKeyCreateResponseAPI.from_model(api_key_db) # Return real api_key back to user. A one-way hash of the api_key is stored in the DB @@ -178,9 +184,11 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): permission_type = PermissionType.API_KEY_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) old_api_key_db = api_key_db api_key_db = ApiKeyAPI.to_model(api_key_api) @@ -191,7 +199,7 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): user_db = UserDB(name=api_key_api.user) User.add_or_update(user_db) - extra = {'username': api_key_api.user, 'user': user_db} + extra = {"username": api_key_api.user, "user": user_db} LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra) # Passing in key_hash as MASKED_ATTRIBUTE_VALUE is expected since we do not @@ -203,36 +211,38 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): # Rather than silently ignore any update to key_hash it is better to explicitly # disallow and notify user. if old_api_key_db.key_hash != api_key_db.key_hash: - raise ValueError('Update of key_hash is not allowed.') + raise ValueError("Update of key_hash is not allowed.") api_key_db.id = old_api_key_db.id api_key_db = ApiKey.add_or_update(api_key_db) - extra = {'old_api_key_db': old_api_key_db, 'new_api_key_db': api_key_db} - LOG.audit('API Key updated. ApiKey.id=%s.' % (api_key_db.id), extra=extra) + extra = {"old_api_key_db": old_api_key_db, "new_api_key_db": api_key_db} + LOG.audit("API Key updated. ApiKey.id=%s." % (api_key_db.id), extra=extra) api_key_api = ApiKeyAPI.from_model(api_key_db) return api_key_api def delete(self, api_key_id_or_key, requester_user): """ - Delete the key value pair. + Delete the key value pair. - Handles requests: - DELETE /apikeys/1 + Handles requests: + DELETE /apikeys/1 """ api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key) permission_type = PermissionType.API_KEY_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) ApiKey.delete(api_key_db) - extra = {'api_key_db': api_key_db} - LOG.audit('ApiKey deleted. ApiKey.id=%s' % (api_key_db.id), extra=extra) + extra = {"api_key_db": api_key_db} + LOG.audit("ApiKey deleted. ApiKey.id=%s" % (api_key_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/execution_views.py b/st2api/st2api/controllers/v1/execution_views.py index 9a61bdc321..f4240b94ab 100644 --- a/st2api/st2api/controllers/v1/execution_views.py +++ b/st2api/st2api/controllers/v1/execution_views.py @@ -29,51 +29,51 @@ # response. Failure to do so will eventually result in Chrome hanging out while opening History # tab of st2web. SUPPORTED_FILTERS = { - 'action': 'action.ref', - 'status': 'status', - 'liveaction': 'liveaction.id', - 'parent': 'parent', - 'rule': 'rule.name', - 'runner': 'runner.name', - 'timestamp': 'start_timestamp', - 'trigger': 'trigger.name', - 'trigger_type': 'trigger_type.name', - 'trigger_instance': 'trigger_instance.id', - 'user': 'context.user' + "action": "action.ref", + "status": "status", + "liveaction": "liveaction.id", + "parent": "parent", + "rule": "rule.name", + "runner": "runner.name", + "timestamp": "start_timestamp", + "trigger": "trigger.name", + "trigger_type": "trigger_type.name", + "trigger_instance": "trigger_instance.id", + "user": "context.user", } # A list of fields for which null (None) is a valid value which we include in the list of valid # filters. FILTERS_WITH_VALID_NULL_VALUES = [ - 'parent', - 'rule', - 'trigger', - 'trigger_type', - 'trigger_instance' + "parent", + "rule", + "trigger", + "trigger_type", + "trigger_instance", ] # List of filters that are too broad to distinct by them and are very likely to represent 1 to 1 # relation between filter and particular history record. -IGNORE_FILTERS = ['parent', 'timestamp', 'liveaction', 'trigger_instance'] +IGNORE_FILTERS = ["parent", "timestamp", "liveaction", "trigger_instance"] class FiltersController(object): def get_all(self, types=None): """ - List all distinct filters. + List all distinct filters. - Handles requests: - GET /executions/views/filters[?types=action,rule] + Handles requests: + GET /executions/views/filters[?types=action,rule] - :param types: Comma delimited string of filter types to output. - :type types: ``str`` + :param types: Comma delimited string of filter types to output. + :type types: ``str`` """ filters = {} for name, field in six.iteritems(SUPPORTED_FILTERS): if name not in IGNORE_FILTERS and (not types or name in types): if name not in FILTERS_WITH_VALID_NULL_VALUES: - query = {field.replace('.', '__'): {'$ne': None}} + query = {field.replace(".", "__"): {"$ne": None}} else: query = {} diff --git a/st2api/st2api/controllers/v1/inquiries.py b/st2api/st2api/controllers/v1/inquiries.py index a892076917..fb3bf2e3f0 100644 --- a/st2api/st2api/controllers/v1/inquiries.py +++ b/st2api/st2api/controllers/v1/inquiries.py @@ -34,13 +34,11 @@ from st2common.services import inquiry as inquiry_service -__all__ = [ - 'InquiriesController' -] +__all__ = ["InquiriesController"] LOG = logging.getLogger(__name__) -INQUIRY_RUNNER = 'inquirer' +INQUIRY_RUNNER = "inquirer" class InquiriesController(ResourceController): @@ -55,12 +53,18 @@ class InquiriesController(ResourceController): model = inqy_api_models.InquiryAPI access = ex_db_access.ActionExecution - def get_all(self, exclude_attributes=None, include_attributes=None, requester_user=None, - limit=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + requester_user=None, + limit=None, + **raw_filters, + ): """Retrieve multiple Inquiries - Handles requests: - GET /inquiries/ + Handles requests: + GET /inquiries/ """ # NOTE: This controller retrieves execution objects and returns a new model composed of @@ -70,13 +74,13 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us # filtering before returning the response. raw_inquiries = super(InquiriesController, self)._get_all( exclude_fields=[], - include_fields=['id', 'result'], + include_fields=["id", "result"], limit=limit, raw_filters={ - 'status': action_constants.LIVEACTION_STATUS_PENDING, - 'runner': INQUIRY_RUNNER + "status": action_constants.LIVEACTION_STATUS_PENDING, + "runner": INQUIRY_RUNNER, }, - requester_user=requester_user + requester_user=requester_user, ) # Since "model" is set to InquiryAPI (for good reasons), _get_all returns a list of @@ -90,18 +94,18 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us # Repackage into Response with correct headers resp = api_router.Response(json=inquiries) - resp.headers['X-Total-Count'] = raw_inquiries.headers['X-Total-Count'] + resp.headers["X-Total-Count"] = raw_inquiries.headers["X-Total-Count"] if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp def get_one(self, inquiry_id, requester_user=None): """Retrieve a single Inquiry - Handles requests: - GET /inquiries/ + Handles requests: + GET /inquiries/ """ # Retrieve the inquiry by id. @@ -110,7 +114,7 @@ def get_one(self, inquiry_id, requester_user=None): inquiry = self._get_one_by_id( id=inquiry_id, requester_user=requester_user, - permission_type=rbac_types.PermissionType.INQUIRY_VIEW + permission_type=rbac_types.PermissionType.INQUIRY_VIEW, ) except db_exceptions.StackStormDBObjectNotFoundError as e: LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id) @@ -132,15 +136,18 @@ def get_one(self, inquiry_id, requester_user=None): def put(self, inquiry_id, response_data, requester_user): """Provide response data to an Inquiry - In general, provided the response data validates against the provided - schema, and the user has the appropriate permissions to respond, - this will set the Inquiry execution to a successful status, and resume - the parent workflow. + In general, provided the response data validates against the provided + schema, and the user has the appropriate permissions to respond, + this will set the Inquiry execution to a successful status, and resume + the parent workflow. - Handles requests: - PUT /inquiries/ + Handles requests: + PUT /inquiries/ """ - LOG.debug("Inquiry %s received response payload: %s" % (inquiry_id, response_data.response)) + LOG.debug( + "Inquiry %s received response payload: %s" + % (inquiry_id, response_data.response) + ) # Set requester to system user if not provided. if not requester_user: @@ -151,7 +158,7 @@ def put(self, inquiry_id, response_data, requester_user): inquiry = self._get_one_by_id( id=inquiry_id, requester_user=requester_user, - permission_type=rbac_types.PermissionType.INQUIRY_RESPOND + permission_type=rbac_types.PermissionType.INQUIRY_RESPOND, ) except db_exceptions.StackStormDBObjectNotFoundError as e: LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id) @@ -186,18 +193,23 @@ def put(self, inquiry_id, response_data, requester_user): # Respond to inquiry and update if there is a partial response. try: - inquiry_service.respond(inquiry, response_data.response, requester=requester_user) + inquiry_service.respond( + inquiry, response_data.response, requester=requester_user + ) except Exception as e: LOG.exception('Fail to update response for inquiry "%s".' % inquiry_id) api_router.abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) - return { - 'id': inquiry_id, - 'response': response_data.response - } + return {"id": inquiry_id, "response": response_data.response} - def _get_one_by_id(self, id, requester_user, permission_type, - exclude_fields=None, from_model_kwargs=None): + def _get_one_by_id( + self, + id, + requester_user, + permission_type, + exclude_fields=None, + from_model_kwargs=None, + ): """Override ResourceController._get_one_by_id to contain scope of Inquiries UID hack :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -215,8 +227,11 @@ def _get_one_by_id(self, id, requester_user, permission_type, # "inquiry:". # # TODO (mierdin): All of this should be removed once Inquiries get their own DB model. - if (execution_db and getattr(execution_db, 'runner', None) and - execution_db.runner.get('runner_module') == INQUIRY_RUNNER): + if ( + execution_db + and getattr(execution_db, "runner", None) + and execution_db.runner.get("runner_module") == INQUIRY_RUNNER + ): execution_db.get_uid = get_uid LOG.debug('Checking permission on inquiry "%s".' % id) @@ -226,7 +241,7 @@ def _get_one_by_id(self, id, requester_user, permission_type, rbac_utils.assert_user_has_resource_db_permission( user_db=requester_user, resource_db=execution_db, - permission_type=permission_type + permission_type=permission_type, ) from_model_kwargs = from_model_kwargs or {} @@ -237,9 +252,8 @@ def _get_one_by_id(self, id, requester_user, permission_type, def get_uid(): - """Inquiry UID hack for RBAC - """ - return 'inquiry' + """Inquiry UID hack for RBAC""" + return "inquiry" inquiries_controller = InquiriesController() diff --git a/st2api/st2api/controllers/v1/keyvalue.py b/st2api/st2api/controllers/v1/keyvalue.py index eab8cb025a..2bd8449e24 100644 --- a/st2api/st2api/controllers/v1/keyvalue.py +++ b/st2api/st2api/controllers/v1/keyvalue.py @@ -24,7 +24,10 @@ from st2common.constants.keyvalue import ALL_SCOPE, FULL_SYSTEM_SCOPE, SYSTEM_SCOPE from st2common.constants.keyvalue import FULL_USER_SCOPE, USER_SCOPE, ALLOWED_SCOPES from st2common.exceptions.db import StackStormDBObjectNotFoundError -from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException +from st2common.exceptions.keyvalue import ( + CryptoKeyNotSetupException, + InvalidScopeException, +) from st2common.models.api.keyvalue import KeyValuePairAPI from st2common.models.db.auth import UserDB from st2common.persistence.keyvalue import KeyValuePair @@ -40,9 +43,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'KeyValuePairController' -] +__all__ = ["KeyValuePairController"] class KeyValuePairController(ResourceController): @@ -52,22 +53,21 @@ class KeyValuePairController(ResourceController): model = KeyValuePairAPI access = KeyValuePair - supported_filters = { - 'prefix': 'name__startswith', - 'scope': 'scope' - } + supported_filters = {"prefix": "name__startswith", "scope": "scope"} def __init__(self): super(KeyValuePairController, self).__init__() self._coordinator = coordination.get_coordinator() self.get_one_db_method = self._get_by_name - def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False): + def get_one( + self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False + ): """ - List key by name. + List key by name. - Handle: - GET /keys/key1 + Handle: + GET /keys/key1 """ if not scope: # Default to system scope @@ -84,8 +84,9 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr self._validate_scope(scope=scope) # User needs to be either admin or requesting item for itself - self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, - requester_user=requester_user) + self._validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, requester_user=requester_user + ) user_query_param_filter = bool(user) @@ -95,45 +96,56 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr rbac_utils = get_rbac_backend().get_utils_class() # Validate that the authenticated user is admin if user query param is provided - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) # Additional guard to ensure there is no information leakage across users is_admin = rbac_utils.user_is_admin(user_db=requester_user) if is_admin and user_query_param_filter: # Retrieve values scoped to the provided user - user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE, user=user) + user_scope_prefix = get_key_reference( + name=name, scope=USER_SCOPE, user=user + ) else: # RBAC not enabled or user is not an admin, retrieve user scoped values for the # current user - user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE, - user=current_user) + user_scope_prefix = get_key_reference( + name=name, scope=USER_SCOPE, user=current_user + ) if scope == FULL_USER_SCOPE: key_ref = user_scope_prefix elif scope == FULL_SYSTEM_SCOPE: key_ref = get_key_reference(scope=FULL_SYSTEM_SCOPE, name=name, user=user) else: - raise ValueError('Invalid scope: %s' % (scope)) + raise ValueError("Invalid scope: %s" % (scope)) - from_model_kwargs = {'mask_secrets': not decrypt} + from_model_kwargs = {"mask_secrets": not decrypt} kvp_api = self._get_one_by_scope_and_name( - name=key_ref, - scope=scope, - from_model_kwargs=from_model_kwargs + name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs ) return kvp_api - def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=None, - decrypt=False, sort=None, offset=0, limit=None, **raw_filters): + def get_all( + self, + requester_user, + prefix=None, + scope=FULL_SYSTEM_SCOPE, + user=None, + decrypt=False, + sort=None, + offset=0, + limit=None, + **raw_filters, + ): """ - List all keys. + List all keys. - Handles requests: - GET /keys/ + Handles requests: + GET /keys/ """ if not scope: # Default to system scope @@ -152,8 +164,9 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non self._validate_all_scope(scope=scope, requester_user=requester_user) # User needs to be either admin or requesting items for themselves - self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, - requester_user=requester_user) + self._validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, requester_user=requester_user + ) user_query_param_filter = bool(user) @@ -163,15 +176,15 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non rbac_utils = get_rbac_backend().get_utils_class() # Validate that the authenticated user is admin if user query param is provided - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) - from_model_kwargs = {'mask_secrets': not decrypt} + from_model_kwargs = {"mask_secrets": not decrypt} if scope and scope not in ALL_SCOPE: self._validate_scope(scope=scope) - raw_filters['scope'] = scope + raw_filters["scope"] = scope # Set prefix which will be used for user-scoped items. # NOTE: It's very important raw_filters['prefix'] is set when requesting user scoped items @@ -180,47 +193,52 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non if is_admin and user_query_param_filter: # Retrieve values scoped to the provided user - user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE, user=user) + user_scope_prefix = get_key_reference( + name=prefix or "", scope=USER_SCOPE, user=user + ) else: # RBAC not enabled or user is not an admin, retrieve user scoped values for the # current user - user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE, - user=current_user) + user_scope_prefix = get_key_reference( + name=prefix or "", scope=USER_SCOPE, user=current_user + ) if scope == ALL_SCOPE: # Special case for ALL_SCOPE # 1. Retrieve system scoped values - raw_filters['scope'] = FULL_SYSTEM_SCOPE - raw_filters['prefix'] = prefix + raw_filters["scope"] = FULL_SYSTEM_SCOPE + raw_filters["prefix"] = prefix - assert 'scope' in raw_filters + assert "scope" in raw_filters kvp_apis_system = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) # 2. Retrieve user scoped items for current user or for all the users (depending if the # authenticated user is admin and if ?user is provided) - raw_filters['scope'] = FULL_USER_SCOPE + raw_filters["scope"] = FULL_USER_SCOPE if cfg.CONF.rbac.enable and is_admin and not user_query_param_filter: # Admin user retrieving user-scoped items for all the users - raw_filters['prefix'] = prefix or '' + raw_filters["prefix"] = prefix or "" else: - raw_filters['prefix'] = user_scope_prefix + raw_filters["prefix"] = user_scope_prefix - assert 'scope' in raw_filters - assert 'prefix' in raw_filters + assert "scope" in raw_filters + assert "prefix" in raw_filters kvp_apis_user = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) # Combine the result kvp_apis = [] @@ -228,31 +246,33 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non kvp_apis.extend(kvp_apis_user.json or []) elif scope in [USER_SCOPE, FULL_USER_SCOPE]: # Make sure we only returned values scoped to current user - prefix = get_key_reference(name=prefix or '', scope=scope, user=user) - raw_filters['prefix'] = user_scope_prefix + prefix = get_key_reference(name=prefix or "", scope=scope, user=user) + raw_filters["prefix"] = user_scope_prefix - assert 'scope' in raw_filters - assert 'prefix' in raw_filters + assert "scope" in raw_filters + assert "prefix" in raw_filters kvp_apis = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) elif scope in [SYSTEM_SCOPE, FULL_SYSTEM_SCOPE]: - raw_filters['prefix'] = prefix + raw_filters["prefix"] = prefix - assert 'scope' in raw_filters + assert "scope" in raw_filters kvp_apis = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) else: - raise ValueError('Invalid scope: %s' % (scope)) + raise ValueError("Invalid scope: %s" % (scope)) return kvp_apis @@ -266,42 +286,42 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - scope = getattr(kvp, 'scope', scope) + scope = getattr(kvp, "scope", scope) scope = get_datastore_full_scope(scope) self._validate_scope(scope=scope) - user = getattr(kvp, 'user', requester_user.name) or requester_user.name + user = getattr(kvp, "user", requester_user.name) or requester_user.name # Validate that the authenticated user is admin if user query param is provided rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) # Validate that encrypted option can only be used by admins - encrypted = getattr(kvp, 'encrypted', False) - self._validate_encrypted_query_parameter(encrypted=encrypted, scope=scope, - requester_user=requester_user) + encrypted = getattr(kvp, "encrypted", False) + self._validate_encrypted_query_parameter( + encrypted=encrypted, scope=scope, requester_user=requester_user + ) key_ref = get_key_reference(scope=scope, name=name, user=user) lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope) - LOG.debug('PUT scope: %s, name: %s', scope, name) + LOG.debug("PUT scope: %s, name: %s", scope, name) # TODO: Custom permission check since the key doesn't need to exist here # Note: We use lock to avoid a race with self._coordinator.get_lock(lock_name): try: existing_kvp_api = self._get_one_by_scope_and_name( - scope=scope, - name=key_ref + scope=scope, name=key_ref ) except StackStormDBObjectNotFoundError: existing_kvp_api = None # st2client sends invalid id when initially setting a key so we ignore those - id_ = kvp.__dict__.get('id', None) + id_ = kvp.__dict__.get("id", None) if not existing_kvp_api and id_ and not bson.ObjectId.is_valid(id_): - del kvp.__dict__['id'] + del kvp.__dict__["id"] kvp.name = key_ref kvp.scope = scope @@ -314,7 +334,7 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): kvp_db = KeyValuePair.add_or_update(kvp_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for key value data=%s', kvp) + LOG.exception("Validation failed for key value data=%s", kvp) abort(http_client.BAD_REQUEST, six.text_type(e)) return except CryptoKeyNotSetupException as e: @@ -325,18 +345,18 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): LOG.exception(six.text_type(e)) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'kvp_db': kvp_db} - LOG.audit('KeyValuePair updated. KeyValuePair.id=%s' % (kvp_db.id), extra=extra) + extra = {"kvp_db": kvp_db} + LOG.audit("KeyValuePair updated. KeyValuePair.id=%s" % (kvp_db.id), extra=extra) kvp_api = KeyValuePairAPI.from_model(kvp_db) return kvp_api def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None): """ - Delete the key value pair. + Delete the key value pair. - Handles requests: - DELETE /keys/1 + Handles requests: + DELETE /keys/1 """ if not scope: scope = FULL_SYSTEM_SCOPE @@ -351,37 +371,42 @@ def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None): # Validate that the authenticated user is admin if user query param is provided rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) key_ref = get_key_reference(scope=scope, name=name, user=user) lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope) # Note: We use lock to avoid a race with self._coordinator.get_lock(lock_name): - from_model_kwargs = {'mask_secrets': True} + from_model_kwargs = {"mask_secrets": True} kvp_api = self._get_one_by_scope_and_name( - name=key_ref, - scope=scope, - from_model_kwargs=from_model_kwargs + name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs ) kvp_db = KeyValuePairAPI.to_model(kvp_api) - LOG.debug('DELETE /keys/ lookup with scope=%s name=%s found object: %s', - scope, name, kvp_db) + LOG.debug( + "DELETE /keys/ lookup with scope=%s name=%s found object: %s", + scope, + name, + kvp_db, + ) try: KeyValuePair.delete(kvp_db) except Exception as e: - LOG.exception('Database delete encountered exception during ' - 'delete of name="%s". ', name) + LOG.exception( + "Database delete encountered exception during " + 'delete of name="%s". ', + name, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'kvp_db': kvp_db} - LOG.audit('KeyValuePair deleted. KeyValuePair.id=%s' % (kvp_db.id), extra=extra) + extra = {"kvp_db": kvp_db} + LOG.audit("KeyValuePair deleted. KeyValuePair.id=%s" % (kvp_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) @@ -392,7 +417,7 @@ def _get_lock_name_for_key(self, name, scope=FULL_SYSTEM_SCOPE): :param name: Datastore item name (PK). :type name: ``str`` """ - lock_name = six.b('kvp-crud-%s.%s' % (scope, name)) + lock_name = six.b("kvp-crud-%s.%s" % (scope, name)) return lock_name def _validate_all_scope(self, scope, requester_user): @@ -400,7 +425,7 @@ def _validate_all_scope(self, scope, requester_user): Validate that "all" scope can only be provided by admins on RBAC installations. """ scope = get_datastore_full_scope(scope) - is_all_scope = (scope == ALL_SCOPE) + is_all_scope = scope == ALL_SCOPE rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) @@ -415,22 +440,25 @@ def _validate_decrypt_query_parameter(self, decrypt, scope, requester_user): """ rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) - is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE) + is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE if decrypt and (not is_user_scope and not is_admin): - msg = 'Decrypt option requires administrator access' + msg = "Decrypt option requires administrator access" raise AccessDeniedError(message=msg, user_db=requester_user) def _validate_encrypted_query_parameter(self, encrypted, scope, requester_user): rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) if encrypted and not is_admin: - msg = 'Pre-encrypted option requires administrator access' + msg = "Pre-encrypted option requires administrator access" raise AccessDeniedError(message=msg, user_db=requester_user) def _validate_scope(self, scope): if scope not in ALLOWED_SCOPES: - msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES) + msg = "Scope %s is not in allowed scopes list: %s." % ( + scope, + ALLOWED_SCOPES, + ) raise ValueError(msg) diff --git a/st2api/st2api/controllers/v1/pack_config_schemas.py b/st2api/st2api/controllers/v1/pack_config_schemas.py index 551573e12e..933a7ab500 100644 --- a/st2api/st2api/controllers/v1/pack_config_schemas.py +++ b/st2api/st2api/controllers/v1/pack_config_schemas.py @@ -23,9 +23,7 @@ http_client = six.moves.http_client -__all__ = [ - 'PackConfigSchemasController' -] +__all__ = ["PackConfigSchemasController"] class PackConfigSchemasController(ResourceController): @@ -40,7 +38,9 @@ def __init__(self): # this case, RBAC is checked on the parent PackDB object self.get_one_db_method = packs_service.get_pack_by_ref - def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters): + def get_all( + self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters + ): """ Retrieve config schema for all the packs. @@ -48,11 +48,13 @@ def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_fi GET /config_schema/ """ - return super(PackConfigSchemasController, self)._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return super(PackConfigSchemasController, self)._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, pack_ref, requester_user): """ @@ -61,7 +63,9 @@ def get_one(self, pack_ref, requester_user): Handles requests: GET /config_schema/ """ - packs_controller._get_one_by_ref_or_id(ref_or_id=pack_ref, requester_user=requester_user) + packs_controller._get_one_by_ref_or_id( + ref_or_id=pack_ref, requester_user=requester_user + ) return self._get_one_by_pack_ref(pack_ref=pack_ref) diff --git a/st2api/st2api/controllers/v1/pack_configs.py b/st2api/st2api/controllers/v1/pack_configs.py index 6eb18c7a34..4123a3cb22 100644 --- a/st2api/st2api/controllers/v1/pack_configs.py +++ b/st2api/st2api/controllers/v1/pack_configs.py @@ -35,9 +35,7 @@ http_client = six.moves.http_client -__all__ = [ - 'PackConfigsController' -] +__all__ = ["PackConfigsController"] LOG = logging.getLogger(__name__) @@ -54,8 +52,15 @@ def __init__(self): # this case, RBAC is checked on the parent PackDB object self.get_one_db_method = packs_service.get_pack_by_ref - def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets=False, - **raw_filters): + def get_all( + self, + requester_user, + sort=None, + offset=0, + limit=None, + show_secrets=False, + **raw_filters, + ): """ Retrieve configs for all the packs. @@ -63,14 +68,18 @@ def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets= GET /configs/ """ from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - return super(PackConfigsController, self)._get_all(sort=sort, - offset=offset, - limit=limit, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + return super(PackConfigsController, self)._get_all( + sort=sort, + offset=offset, + limit=limit, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, pack_ref, requester_user, show_secrets=False): """ @@ -80,7 +89,9 @@ def get_one(self, pack_ref, requester_user, show_secrets=False): GET /configs/ """ from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } try: instance = packs_service.get_pack_by_ref(pack_ref=pack_ref) @@ -89,18 +100,22 @@ def get_one(self, pack_ref, requester_user, show_secrets=False): abort(http_client.NOT_FOUND, msg) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=PermissionType.PACK_VIEW, + ) - return self._get_one_by_pack_ref(pack_ref=pack_ref, from_model_kwargs=from_model_kwargs) + return self._get_one_by_pack_ref( + pack_ref=pack_ref, from_model_kwargs=from_model_kwargs + ) def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False): """ - Create a new config for a pack. + Create a new config for a pack. - Handles requests: - POST /configs/ + Handles requests: + POST /configs/ """ try: @@ -121,9 +136,9 @@ def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False) def _dump_config_to_disk(self, config_api): config_content = yaml.safe_dump(config_api.values, default_flow_style=False) - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % config_api.pack) - with open(config_path, 'w') as f: + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % config_api.pack) + with open(config_path, "w") as f: f.write(config_content) diff --git a/st2api/st2api/controllers/v1/pack_views.py b/st2api/st2api/controllers/v1/pack_views.py index 5e8b310c33..4fd6f9dd3a 100644 --- a/st2api/st2api/controllers/v1/pack_views.py +++ b/st2api/st2api/controllers/v1/pack_views.py @@ -33,10 +33,7 @@ http_client = six.moves.http_client -__all__ = [ - 'FilesController', - 'FileController' -] +__all__ = ["FilesController", "FileController"] http_client = six.moves.http_client @@ -46,12 +43,10 @@ # Maximum file size in bytes. If the file on disk is larger then this value, we don't include it # in the response. This prevents DDoS / exhaustion attacks. -MAX_FILE_SIZE = (500 * 1000) +MAX_FILE_SIZE = 500 * 1000 # File paths in the file controller for which RBAC checks are not performed -WHITELISTED_FILE_PATHS = [ - 'icon.png' -] +WHITELISTED_FILE_PATHS = ["icon.png"] class BaseFileController(BasePacksController): @@ -76,7 +71,7 @@ def _get_file_stats(self, file_path): return file_stats.st_size, file_stats.st_mtime def _get_file_content(self, file_path): - with codecs.open(file_path, 'rb') as fp: + with codecs.open(file_path, "rb") as fp: content = fp.read() return content @@ -105,17 +100,19 @@ def __init__(self): def get_one(self, ref_or_id, requester_user): """ - Outputs the content of all the files inside the pack. + Outputs the content of all the files inside the pack. - Handles requests: - GET /packs/views/files/ + Handles requests: + GET /packs/views/files/ """ pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=pack_db, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=pack_db, + permission_type=PermissionType.PACK_VIEW, + ) if not pack_db: msg = 'Pack with ref_or_id "%s" does not exist' % (ref_or_id) @@ -126,15 +123,19 @@ def get_one(self, ref_or_id, requester_user): result = [] for file_path in pack_files: - normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path) + normalized_file_path = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path + ) if not normalized_file_path or not os.path.isfile(normalized_file_path): # Ignore references to files which don't exist on disk continue file_size = self._get_file_size(file_path=normalized_file_path) if file_size is not None and file_size > MAX_FILE_SIZE: - LOG.debug('Skipping file "%s" which size exceeds max file size (%s bytes)' % - (normalized_file_path, MAX_FILE_SIZE)) + LOG.debug( + 'Skipping file "%s" which size exceeds max file size (%s bytes)' + % (normalized_file_path, MAX_FILE_SIZE) + ) continue content = self._get_file_content(file_path=normalized_file_path) @@ -144,10 +145,7 @@ def get_one(self, ref_or_id, requester_user): LOG.debug('Skipping binary file "%s"' % (normalized_file_path)) continue - item = { - 'file_path': file_path, - 'content': content - } + item = {"file_path": file_path, "content": content} result.append(item) return result @@ -173,13 +171,19 @@ class FileController(BaseFileController): Controller which allows user to retrieve content of a specific file in a pack. """ - def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, - if_modified_since=None): + def get_one( + self, + ref_or_id, + file_path, + requester_user, + if_none_match=None, + if_modified_since=None, + ): """ - Outputs the content of a specific file in a pack. + Outputs the content of a specific file in a pack. - Handles requests: - GET /packs/views/file// + Handles requests: + GET /packs/views/file// """ pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) @@ -188,7 +192,7 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, raise StackStormDBObjectNotFoundError(msg) if not file_path: - raise ValueError('Missing file path') + raise ValueError("Missing file path") pack_ref = pack_db.ref @@ -196,11 +200,15 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, permission_type = PermissionType.PACK_VIEW if file_path not in WHITELISTED_FILE_PATHS: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=pack_db, - permission_type=permission_type) - - normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=pack_db, + permission_type=permission_type, + ) + + normalized_file_path = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path + ) if not normalized_file_path or not os.path.isfile(normalized_file_path): # Ignore references to files which don't exist on disk raise StackStormDBObjectNotFoundError('File "%s" not found' % (file_path)) @@ -209,24 +217,28 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, response = Response() - if not self._is_file_changed(file_mtime, - if_none_match=if_none_match, - if_modified_since=if_modified_since): + if not self._is_file_changed( + file_mtime, if_none_match=if_none_match, if_modified_since=if_modified_since + ): response.status = http_client.NOT_MODIFIED else: if file_size is not None and file_size > MAX_FILE_SIZE: - msg = ('File %s exceeds maximum allowed file size (%s bytes)' % - (file_path, MAX_FILE_SIZE)) + msg = "File %s exceeds maximum allowed file size (%s bytes)" % ( + file_path, + MAX_FILE_SIZE, + ) raise ValueError(msg) - content_type = mimetypes.guess_type(normalized_file_path)[0] or \ - 'application/octet-stream' + content_type = ( + mimetypes.guess_type(normalized_file_path)[0] + or "application/octet-stream" + ) - response.headers['Content-Type'] = content_type + response.headers["Content-Type"] = content_type response.body = self._get_file_content(file_path=normalized_file_path) - response.headers['Last-Modified'] = format_date_time(file_mtime) - response.headers['ETag'] = repr(file_mtime) + response.headers["Last-Modified"] = format_date_time(file_mtime) + response.headers["ETag"] = repr(file_mtime) return response diff --git a/st2api/st2api/controllers/v1/packs.py b/st2api/st2api/controllers/v1/packs.py index 6193a3f01f..75da16e5e5 100644 --- a/st2api/st2api/controllers/v1/packs.py +++ b/st2api/st2api/controllers/v1/packs.py @@ -52,115 +52,119 @@ http_client = six.moves.http_client -__all__ = [ - 'PacksController', - 'BasePacksController', - 'ENTITIES' -] +__all__ = ["PacksController", "BasePacksController", "ENTITIES"] LOG = logging.getLogger(__name__) # Note: The order those are defined it's important so they are registered in # the same order as they are in st2-register-content. # We also need to use list of tuples to preserve the order. -ENTITIES = OrderedDict([ - ('trigger', (TriggersRegistrar, 'triggers')), - ('sensor', (SensorsRegistrar, 'sensors')), - ('action', (ActionsRegistrar, 'actions')), - ('rule', (RulesRegistrar, 'rules')), - ('alias', (AliasesRegistrar, 'aliases')), - ('policy', (PolicyRegistrar, 'policies')), - ('config', (ConfigsRegistrar, 'configs')) -]) +ENTITIES = OrderedDict( + [ + ("trigger", (TriggersRegistrar, "triggers")), + ("sensor", (SensorsRegistrar, "sensors")), + ("action", (ActionsRegistrar, "actions")), + ("rule", (RulesRegistrar, "rules")), + ("alias", (AliasesRegistrar, "aliases")), + ("policy", (PolicyRegistrar, "policies")), + ("config", (ConfigsRegistrar, "configs")), + ] +) def _get_proxy_config(): - LOG.debug('Loading proxy configuration from env variables %s.', os.environ) - http_proxy = os.environ.get('http_proxy', None) - https_proxy = os.environ.get('https_proxy', None) - no_proxy = os.environ.get('no_proxy', None) - proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None) + LOG.debug("Loading proxy configuration from env variables %s.", os.environ) + http_proxy = os.environ.get("http_proxy", None) + https_proxy = os.environ.get("https_proxy", None) + no_proxy = os.environ.get("no_proxy", None) + proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None) proxy_config = { - 'http_proxy': http_proxy, - 'https_proxy': https_proxy, - 'proxy_ca_bundle_path': proxy_ca_bundle_path, - 'no_proxy': no_proxy + "http_proxy": http_proxy, + "https_proxy": https_proxy, + "proxy_ca_bundle_path": proxy_ca_bundle_path, + "no_proxy": no_proxy, } - LOG.debug('Proxy configuration: %s', proxy_config) + LOG.debug("Proxy configuration: %s", proxy_config) return proxy_config class PackInstallController(ActionExecutionsControllerMixin): - def post(self, pack_install_request, requester_user=None): parameters = { - 'packs': pack_install_request.packs, + "packs": pack_install_request.packs, } if pack_install_request.force: - parameters['force'] = True + parameters["force"] = True if pack_install_request.skip_dependencies: - parameters['skip_dependencies'] = True + parameters["skip_dependencies"] = True if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - new_liveaction_api = LiveActionCreateAPI(action='packs.install', - parameters=parameters, - user=requester_user.name) + new_liveaction_api = LiveActionCreateAPI( + action="packs.install", parameters=parameters, user=requester_user.name + ) - execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user) + execution_resp = self._handle_schedule_execution( + liveaction_api=new_liveaction_api, requester_user=requester_user + ) - exec_id = PackAsyncAPI(execution_id=execution_resp.json['id']) + exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"]) return Response(json=exec_id, status=http_client.ACCEPTED) class PackUninstallController(ActionExecutionsControllerMixin): - def post(self, pack_uninstall_request, ref_or_id=None, requester_user=None): if ref_or_id: - parameters = { - 'packs': [ref_or_id] - } + parameters = {"packs": [ref_or_id]} else: - parameters = { - 'packs': pack_uninstall_request.packs - } + parameters = {"packs": pack_uninstall_request.packs} if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - new_liveaction_api = LiveActionCreateAPI(action='packs.uninstall', - parameters=parameters, - user=requester_user.name) + new_liveaction_api = LiveActionCreateAPI( + action="packs.uninstall", parameters=parameters, user=requester_user.name + ) - execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user) + execution_resp = self._handle_schedule_execution( + liveaction_api=new_liveaction_api, requester_user=requester_user + ) - exec_id = PackAsyncAPI(execution_id=execution_resp.json['id']) + exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"]) return Response(json=exec_id, status=http_client.ACCEPTED) class PackRegisterController(object): - CONTENT_TYPES = ['runner', 'action', 'trigger', 'sensor', 'rule', - 'rule_type', 'alias', 'policy_type', 'policy', 'config'] + CONTENT_TYPES = [ + "runner", + "action", + "trigger", + "sensor", + "rule", + "rule_type", + "alias", + "policy_type", + "policy", + "config", + ] def post(self, pack_register_request): - if pack_register_request and hasattr(pack_register_request, 'types'): + if pack_register_request and hasattr(pack_register_request, "types"): types = pack_register_request.types - if 'all' in types: + if "all" in types: types = PackRegisterController.CONTENT_TYPES else: types = PackRegisterController.CONTENT_TYPES - if pack_register_request and hasattr(pack_register_request, 'packs'): + if pack_register_request and hasattr(pack_register_request, "packs"): packs = list(set(pack_register_request.packs)) else: packs = None @@ -168,64 +172,80 @@ def post(self, pack_register_request): result = defaultdict(int) # Register depended resources (actions depend on runners, rules depend on rule types, etc) - if ('runner' in types or 'runners' in types) or ('action' in types or 'actions' in types): - result['runners'] = runners_registrar.register_runners(experimental=True) - if ('rule_type' in types or 'rule_types' in types) or \ - ('rule' in types or 'rules' in types): - result['rule_types'] = rule_types_registrar.register_rule_types() - if ('policy_type' in types or 'policy_types' in types) or \ - ('policy' in types or 'policies' in types): - result['policy_types'] = policies_registrar.register_policy_types(st2common) + if ("runner" in types or "runners" in types) or ( + "action" in types or "actions" in types + ): + result["runners"] = runners_registrar.register_runners(experimental=True) + if ("rule_type" in types or "rule_types" in types) or ( + "rule" in types or "rules" in types + ): + result["rule_types"] = rule_types_registrar.register_rule_types() + if ("policy_type" in types or "policy_types" in types) or ( + "policy" in types or "policies" in types + ): + result["policy_types"] = policies_registrar.register_policy_types(st2common) use_pack_cache = False - fail_on_failure = getattr(pack_register_request, 'fail_on_failure', True) + fail_on_failure = getattr(pack_register_request, "fail_on_failure", True) for type, (Registrar, name) in six.iteritems(ENTITIES): if type in types or name in types: - registrar = Registrar(use_pack_cache=use_pack_cache, - use_runners_cache=True, - fail_on_failure=fail_on_failure) + registrar = Registrar( + use_pack_cache=use_pack_cache, + use_runners_cache=True, + fail_on_failure=fail_on_failure, + ) if packs: for pack in packs: pack_path = content_utils.get_pack_base_path(pack) try: - registered_count = registrar.register_from_pack(pack_dir=pack_path) + registered_count = registrar.register_from_pack( + pack_dir=pack_path + ) result[name] += registered_count except ValueError as e: # Throw more user-friendly exception if requsted pack doesn't exist - if re.match('Directory ".*?" doesn\'t exist', six.text_type(e)): - msg = 'Pack "%s" not found on disk: %s' % (pack, six.text_type(e)) + if re.match( + 'Directory ".*?" doesn\'t exist', six.text_type(e) + ): + msg = 'Pack "%s" not found on disk: %s' % ( + pack, + six.text_type(e), + ) raise ValueError(msg) raise e else: packs_base_paths = content_utils.get_packs_base_paths() - registered_count = registrar.register_from_packs(base_dirs=packs_base_paths) + registered_count = registrar.register_from_packs( + base_dirs=packs_base_paths + ) result[name] += registered_count return result class PackSearchController(object): - def post(self, pack_search_request): proxy_config = _get_proxy_config() - if hasattr(pack_search_request, 'query'): - packs = packs_service.search_pack_index(pack_search_request.query, - case_sensitive=False, - proxy_config=proxy_config) + if hasattr(pack_search_request, "query"): + packs = packs_service.search_pack_index( + pack_search_request.query, + case_sensitive=False, + proxy_config=proxy_config, + ) return [PackAPI(**pack) for pack in packs] else: - pack = packs_service.get_pack_from_index(pack_search_request.pack, - proxy_config=proxy_config) + pack = packs_service.get_pack_from_index( + pack_search_request.pack, proxy_config=proxy_config + ) return PackAPI(**pack) if pack else [] class IndexHealthController(object): - def get(self): """ Check if all listed indexes are healthy: they should be reachable, @@ -233,7 +253,9 @@ def get(self): """ proxy_config = _get_proxy_config() - _, status = packs_service.fetch_pack_index(allow_empty=True, proxy_config=proxy_config) + _, status = packs_service.fetch_pack_index( + allow_empty=True, proxy_config=proxy_config + ) health = { "indexes": { @@ -249,13 +271,13 @@ def get(self): } for index in status: - if index['error']: - error_count = health['indexes']['errors'].get(index['error'], 0) + 1 - health['indexes']['invalid'] += 1 - health['indexes']['errors'][index['error']] = error_count + if index["error"]: + error_count = health["indexes"]["errors"].get(index["error"], 0) + 1 + health["indexes"]["invalid"] += 1 + health["indexes"]["errors"][index["error"]] = error_count else: - health['indexes']['valid'] += 1 - health['packs']['count'] += index['packs'] + health["indexes"]["valid"] += 1 + health["packs"]["count"] += index["packs"] return health @@ -265,12 +287,16 @@ class BasePacksController(ResourceController): access = Pack def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None): - instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields) + instance = self._get_by_ref_or_id( + ref_or_id=ref_or_id, exclude_fields=exclude_fields + ) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=PermissionType.PACK_VIEW, + ) if not instance: msg = 'Unable to identify resource with ref_or_id "%s".' % (ref_or_id) @@ -282,7 +308,9 @@ def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None): return result def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None): - resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields) + resource_db = self._get_by_id( + resource_id=ref_or_id, exclude_fields=exclude_fields + ) if not resource_db: # Try ref @@ -302,7 +330,7 @@ def _get_by_ref(self, ref, exclude_fields=None): return resource_db -class PacksIndexController(): +class PacksIndexController: search = PackSearchController() health = IndexHealthController() @@ -311,10 +339,7 @@ def get_all(self): index, status = packs_service.fetch_pack_index(proxy_config=proxy_config) - return { - 'status': status, - 'index': index - } + return {"status": status, "index": index} class PacksController(BasePacksController): @@ -322,14 +347,9 @@ class PacksController(BasePacksController): model = PackAPI access = Pack - supported_filters = { - 'name': 'name', - 'ref': 'ref' - } + supported_filters = {"name": "name", "ref": "ref"} - query_options = { - 'sort': ['ref'] - } + query_options = {"sort": ["ref"]} # Nested controllers install = PackInstallController() @@ -342,18 +362,30 @@ def __init__(self): super(PacksController, self).__init__() self.get_one_db_method = self._get_by_ref_or_id - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(PacksController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(PacksController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): - return self._get_one_by_ref_or_id(ref_or_id=ref_or_id, requester_user=requester_user) + return self._get_one_by_ref_or_id( + ref_or_id=ref_or_id, requester_user=requester_user + ) packs_controller = PacksController() diff --git a/st2api/st2api/controllers/v1/policies.py b/st2api/st2api/controllers/v1/policies.py index 3fc488708b..aa57b7cf3d 100644 --- a/st2api/st2api/controllers/v1/policies.py +++ b/st2api/st2api/controllers/v1/policies.py @@ -37,54 +37,73 @@ class PolicyTypeController(resource.ResourceController): model = PolicyTypeAPI access = PolicyType - mandatory_include_fields_retrieve = ['id', 'name', 'resource_type'] + mandatory_include_fields_retrieve = ["id", "name", "resource_type"] - supported_filters = { - 'resource_type': 'resource_type' - } + supported_filters = {"resource_type": "resource_type"} - query_options = { - 'sort': ['resource_type', 'name'] - } + query_options = {"sort": ["resource_type", "name"]} def get_one(self, ref_or_id, requester_user): return self._get_one(ref_or_id, requester_user=requester_user) - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def _get_one(self, ref_or_id, requester_user): instance = self._get_by_ref_or_id(ref_or_id=ref_or_id) permission_type = PermissionType.POLICY_TYPE_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) result = self.model.from_model(instance) return result - def _get_all(self, exclude_fields=None, include_fields=None, sort=None, offset=0, limit=None, - query_options=None, from_model_kwargs=None, raw_filters=None, - requester_user=None): - - resp = super(PolicyTypeController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + def _get_all( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): + + resp = super(PolicyTypeController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) return resp @@ -114,7 +133,9 @@ def _get_by_ref(self, resource_ref): except Exception: return None - resource_db = self.access.query(name=ref.name, resource_type=ref.resource_type).first() + resource_db = self.access.query( + name=ref.name, resource_type=ref.resource_type + ).first() return resource_db @@ -123,77 +144,93 @@ class PolicyController(resource.ContentPackResourceController): access = Policy supported_filters = { - 'pack': 'pack', - 'resource_ref': 'resource_ref', - 'policy_type': 'policy_type' - } - - query_options = { - 'sort': ['pack', 'name'] + "pack": "pack", + "resource_ref": "resource_ref", + "policy_type": "policy_type", } - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + query_options = {"sort": ["pack", "name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.POLICY_VIEW - return self._get_one(ref_or_id, permission_type=permission_type, - requester_user=requester_user) + return self._get_one( + ref_or_id, permission_type=permission_type, requester_user=requester_user + ) def post(self, instance, requester_user): """ - Create a new policy. - Handles requests: - POST /policies/ + Create a new policy. + Handles requests: + POST /policies/ """ permission_type = PermissionType.POLICY_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=instance, + permission_type=permission_type, + ) - op = 'POST /policies/' + op = "POST /policies/" db_model = self.model.to_model(instance) - LOG.debug('%s verified object: %s', op, db_model) + LOG.debug("%s verified object: %s", op, db_model) db_model = self.access.add_or_update(db_model) - LOG.debug('%s created object: %s', op, db_model) - LOG.audit('Policy created. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s created object: %s", op, db_model) + LOG.audit( + "Policy created. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) exec_result = self.model.from_model(db_model) return Response(json=exec_result, status=http_client.CREATED) def put(self, instance, ref_or_id, requester_user): - op = 'PUT /policies/%s/' % ref_or_id + op = "PUT /policies/%s/" % ref_or_id db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('%s found object: %s', op, db_model) + LOG.debug("%s found object: %s", op, db_model) permission_type = PermissionType.POLICY_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=db_model, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=db_model, + permission_type=permission_type, + ) db_model_id = db_model.id try: validate_not_part_of_system_pack(db_model) except ValueValidationException as e: - LOG.exception('%s unable to update object from system pack.', op) + LOG.exception("%s unable to update object from system pack.", op) abort(http_client.BAD_REQUEST, six.text_type(e)) - if not getattr(instance, 'pack', None): + if not getattr(instance, "pack", None): instance.pack = db_model.pack try: @@ -201,12 +238,15 @@ def put(self, instance, ref_or_id, requester_user): db_model.id = db_model_id db_model = self.access.add_or_update(db_model) except (ValidationError, ValueError) as e: - LOG.exception('%s unable to update object: %s', op, db_model) + LOG.exception("%s unable to update object: %s", op, db_model) abort(http_client.BAD_REQUEST, six.text_type(e)) return - LOG.debug('%s updated object: %s', op, db_model) - LOG.audit('Policy updated. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s updated object: %s", op, db_model) + LOG.audit( + "Policy updated. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) exec_result = self.model.from_model(db_model) @@ -214,38 +254,43 @@ def put(self, instance, ref_or_id, requester_user): def delete(self, ref_or_id, requester_user): """ - Delete a policy. - Handles requests: - POST /policies/1?_method=delete - DELETE /policies/1 - DELETE /policies/mypack.mypolicy + Delete a policy. + Handles requests: + POST /policies/1?_method=delete + DELETE /policies/1 + DELETE /policies/mypack.mypolicy """ - op = 'DELETE /policies/%s/' % ref_or_id + op = "DELETE /policies/%s/" % ref_or_id db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('%s found object: %s', op, db_model) + LOG.debug("%s found object: %s", op, db_model) permission_type = PermissionType.POLICY_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=db_model, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=db_model, + permission_type=permission_type, + ) try: validate_not_part_of_system_pack(db_model) except ValueValidationException as e: - LOG.exception('%s unable to delete object from system pack.', op) + LOG.exception("%s unable to delete object from system pack.", op) abort(http_client.BAD_REQUEST, six.text_type(e)) try: self.access.delete(db_model) except Exception as e: - LOG.exception('%s unable to delete object: %s', op, db_model) + LOG.exception("%s unable to delete object: %s", op, db_model) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - LOG.debug('%s deleted object: %s', op, db_model) - LOG.audit('Policy deleted. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s deleted object: %s", op, db_model) + LOG.audit( + "Policy deleted. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) # return None return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/rbac.py b/st2api/st2api/controllers/v1/rbac.py index 0e8c1d4179..49a552f7dc 100644 --- a/st2api/st2api/controllers/v1/rbac.py +++ b/st2api/st2api/controllers/v1/rbac.py @@ -23,78 +23,76 @@ from st2common.rbac.backends import get_rbac_backend from st2common.router import exc -__all__ = [ - 'RolesController', - 'RoleAssignmentsController', - 'PermissionTypesController' -] +__all__ = ["RolesController", "RoleAssignmentsController", "PermissionTypesController"] class RolesController(ResourceController): model = RoleAPI access = Role - supported_filters = { - 'name': 'name', - 'system': 'system' - } + supported_filters = {"name": "name", "system": "system"} - query_options = { - 'sort': ['name'] - } + query_options = {"sort": ["name"]} def get_one(self, name_or_id, requester_user): rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) - return self._get_one_by_name_or_id(name_or_id=name_or_id, - permission_type=None, - requester_user=requester_user) + return self._get_one_by_name_or_id( + name_or_id=name_or_id, permission_type=None, requester_user=requester_user + ) def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters): rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) - return self._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return self._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) class RoleAssignmentsController(ResourceController): """ Meta controller for listing role assignments. """ + model = UserRoleAssignmentAPI access = UserRoleAssignment supported_filters = { - 'user': 'user', - 'role': 'role', - 'source': 'source', - 'remote': 'is_remote' + "user": "user", + "role": "role", + "source": "source", + "remote": "is_remote", } def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters): - user = raw_filters.get('user', None) + user = raw_filters.get("user", None) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user, - user=user) - - return self._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + rbac_utils.assert_user_is_admin_or_operating_on_own_resource( + user_db=requester_user, user=user + ) + + return self._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - result = self._get_one_by_id(id, - requester_user=requester_user, - permission_type=None) - user = getattr(result, 'user', None) + result = self._get_one_by_id( + id, requester_user=requester_user, permission_type=None + ) + user = getattr(result, "user", None) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_or_operating_on_own_resource( + user_db=requester_user, user=user + ) return result @@ -106,10 +104,10 @@ class PermissionTypesController(object): def get_all(self, requester_user): """ - List all the available permission types. + List all the available permission types. - Handles requests: - GET /rbac/permission_types + Handles requests: + GET /rbac/permission_types """ rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) @@ -119,10 +117,10 @@ def get_all(self, requester_user): def get_one(self, resource_type, requester_user): """ - List all the available permission types for a particular resource type. + List all the available permission types for a particular resource type. - Handles requests: - GET /rbac/permission_types/ + Handles requests: + GET /rbac/permission_types/ """ rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) @@ -131,7 +129,7 @@ def get_one(self, resource_type, requester_user): permission_types = all_permission_types.get(resource_type, None) if permission_types is None: - raise exc.HTTPNotFound('Invalid resource type: %s' % (resource_type)) + raise exc.HTTPNotFound("Invalid resource type: %s" % (resource_type)) return permission_types diff --git a/st2api/st2api/controllers/v1/rule_enforcement_views.py b/st2api/st2api/controllers/v1/rule_enforcement_views.py index 3d23d027a9..75831a917b 100644 --- a/st2api/st2api/controllers/v1/rule_enforcement_views.py +++ b/st2api/st2api/controllers/v1/rule_enforcement_views.py @@ -26,9 +26,7 @@ from st2api.controllers.resource import ResourceController -__all__ = [ - 'RuleEnforcementViewController' -] +__all__ = ["RuleEnforcementViewController"] class RuleEnforcementViewController(ResourceController): @@ -50,8 +48,16 @@ class RuleEnforcementViewController(ResourceController): supported_filters = SUPPORTED_FILTERS filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): rule_enforcement_apis = super(RuleEnforcementViewController, self)._get_all( exclude_fields=exclude_attributes, include_fields=include_attributes, @@ -59,16 +65,25 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) - rule_enforcement_apis.json = self._append_view_properties(rule_enforcement_apis.json) + rule_enforcement_apis.json = self._append_view_properties( + rule_enforcement_apis.json + ) return rule_enforcement_apis def get_one(self, id, requester_user): - rule_enforcement_api = super(RuleEnforcementViewController, - self)._get_one_by_id(id, requester_user=requester_user, - permission_type=PermissionType.RULE_ENFORCEMENT_VIEW) - rule_enforcement_api = self._append_view_properties([rule_enforcement_api.__json__()])[0] + rule_enforcement_api = super( + RuleEnforcementViewController, self + )._get_one_by_id( + id, + requester_user=requester_user, + permission_type=PermissionType.RULE_ENFORCEMENT_VIEW, + ) + rule_enforcement_api = self._append_view_properties( + [rule_enforcement_api.__json__()] + )[0] return rule_enforcement_api def _append_view_properties(self, rule_enforcement_apis): @@ -80,29 +95,29 @@ def _append_view_properties(self, rule_enforcement_apis): execution_ids = [] for rule_enforcement_api in rule_enforcement_apis: - if rule_enforcement_api.get('trigger_instance_id', None): - trigger_instance_ids.add(str(rule_enforcement_api['trigger_instance_id'])) + if rule_enforcement_api.get("trigger_instance_id", None): + trigger_instance_ids.add( + str(rule_enforcement_api["trigger_instance_id"]) + ) - if rule_enforcement_api.get('execution_id', None): - execution_ids.append(rule_enforcement_api['execution_id']) + if rule_enforcement_api.get("execution_id", None): + execution_ids.append(rule_enforcement_api["execution_id"]) # 1. Retrieve corresponding execution objects # NOTE: Executions contain a lot of field and could contain a lot of data so we only # retrieve fields we need only_fields = [ - 'id', - - 'action.ref', - 'action.parameters', - - 'runner.name', - 'runner.runner_parameters', - - 'parameters', - 'status' + "id", + "action.ref", + "action.parameters", + "runner.name", + "runner.runner_parameters", + "parameters", + "status", ] - execution_dbs = ActionExecution.query(id__in=execution_ids, - only_fields=only_fields) + execution_dbs = ActionExecution.query( + id__in=execution_ids, only_fields=only_fields + ) execution_dbs_by_id = {} for execution_db in execution_dbs: @@ -114,26 +129,32 @@ def _append_view_properties(self, rule_enforcement_apis): trigger_instance_dbs_by_id = {} for trigger_instance_db in trigger_instance_dbs: - trigger_instance_dbs_by_id[str(trigger_instance_db.id)] = trigger_instance_db + trigger_instance_dbs_by_id[ + str(trigger_instance_db.id) + ] = trigger_instance_db # Ammend rule enforcement objects with additional data for rule_enforcement_api in rule_enforcement_apis: - rule_enforcement_api['trigger_instance'] = {} - rule_enforcement_api['execution'] = {} + rule_enforcement_api["trigger_instance"] = {} + rule_enforcement_api["execution"] = {} - trigger_instance_id = rule_enforcement_api.get('trigger_instance_id', None) - execution_id = rule_enforcement_api.get('execution_id', None) + trigger_instance_id = rule_enforcement_api.get("trigger_instance_id", None) + execution_id = rule_enforcement_api.get("execution_id", None) - trigger_instance_db = trigger_instance_dbs_by_id.get(trigger_instance_id, None) + trigger_instance_db = trigger_instance_dbs_by_id.get( + trigger_instance_id, None + ) execution_db = execution_dbs_by_id.get(execution_id, None) if trigger_instance_db: - trigger_instance_api = TriggerInstanceAPI.from_model(trigger_instance_db) - rule_enforcement_api['trigger_instance'] = trigger_instance_api + trigger_instance_api = TriggerInstanceAPI.from_model( + trigger_instance_db + ) + rule_enforcement_api["trigger_instance"] = trigger_instance_api if execution_db: execution_api = ActionExecutionAPI.from_model(execution_db) - rule_enforcement_api['execution'] = execution_api + rule_enforcement_api["execution"] = execution_api return rule_enforcement_apis diff --git a/st2api/st2api/controllers/v1/rule_enforcements.py b/st2api/st2api/controllers/v1/rule_enforcements.py index 1c117558ca..f1c1f4c5b7 100644 --- a/st2api/st2api/controllers/v1/rule_enforcements.py +++ b/st2api/st2api/controllers/v1/rule_enforcements.py @@ -24,11 +24,10 @@ from st2api.controllers.resource import ResourceController __all__ = [ - 'RuleEnforcementController', - - 'SUPPORTED_FILTERS', - 'QUERY_OPTIONS', - 'FILTER_TRANSFORM_FUNCTIONS' + "RuleEnforcementController", + "SUPPORTED_FILTERS", + "QUERY_OPTIONS", + "FILTER_TRANSFORM_FUNCTIONS", ] @@ -38,23 +37,21 @@ SUPPORTED_FILTERS = { - 'rule_ref': 'rule.ref', - 'rule_id': 'rule.id', - 'execution': 'execution_id', - 'trigger_instance': 'trigger_instance_id', - 'enforced_at': 'enforced_at', - 'enforced_at_gt': 'enforced_at.gt', - 'enforced_at_lt': 'enforced_at.lt' + "rule_ref": "rule.ref", + "rule_id": "rule.id", + "execution": "execution_id", + "trigger_instance": "trigger_instance_id", + "enforced_at": "enforced_at", + "enforced_at_gt": "enforced_at.gt", + "enforced_at_lt": "enforced_at.lt", } -QUERY_OPTIONS = { - 'sort': ['-enforced_at', 'rule.ref'] -} +QUERY_OPTIONS = {"sort": ["-enforced_at", "rule.ref"]} FILTER_TRANSFORM_FUNCTIONS = { - 'enforced_at': lambda value: isotime.parse(value=value), - 'enforced_at_gt': lambda value: isotime.parse(value=value), - 'enforced_at_lt': lambda value: isotime.parse(value=value) + "enforced_at": lambda value: isotime.parse(value=value), + "enforced_at_gt": lambda value: isotime.parse(value=value), + "enforced_at_lt": lambda value: isotime.parse(value=value), } @@ -69,20 +66,32 @@ class RuleEnforcementController(ResourceController): supported_filters = SUPPORTED_FILTERS filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(RuleEnforcementController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(RuleEnforcementController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - return super(RuleEnforcementController, - self)._get_one_by_id(id, requester_user=requester_user, - permission_type=PermissionType.RULE_ENFORCEMENT_VIEW) + return super(RuleEnforcementController, self)._get_one_by_id( + id, + requester_user=requester_user, + permission_type=PermissionType.RULE_ENFORCEMENT_VIEW, + ) rule_enforcements_controller = RuleEnforcementController() diff --git a/st2api/st2api/controllers/v1/rule_views.py b/st2api/st2api/controllers/v1/rule_views.py index 70555149b7..39b4682c52 100644 --- a/st2api/st2api/controllers/v1/rule_views.py +++ b/st2api/st2api/controllers/v1/rule_views.py @@ -32,10 +32,12 @@ LOG = logging.getLogger(__name__) -__all__ = ['RuleViewController'] +__all__ = ["RuleViewController"] -class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResourceController): +class RuleViewController( + BaseResourceIsolationControllerMixin, ContentPackResourceController +): """ Add some extras to a Rule object to make it easier for UI to render a rule. The additions do not necessarily belong in the Rule itself but are still valuable augmentations. @@ -74,64 +76,78 @@ class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResour model = RuleViewAPI access = Rule - supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'user': 'context.user' - } - - query_options = { - 'sort': ['pack', 'name'] - } - - mandatory_include_fields_retrieve = ['pack', 'name', 'trigger'] - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - rules = super(RuleViewController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack", "user": "context.user"} + + query_options = {"sort": ["pack", "name"]} + + mandatory_include_fields_retrieve = ["pack", "name", "trigger"] + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + rules = super(RuleViewController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) result = self._append_view_properties(rules.json) rules.json = result return rules def get_one(self, ref_or_id, requester_user): - from_model_kwargs = {'mask_secrets': True} - rule = self._get_one(ref_or_id, permission_type=PermissionType.RULE_VIEW, - requester_user=requester_user, from_model_kwargs=from_model_kwargs) + from_model_kwargs = {"mask_secrets": True} + rule = self._get_one( + ref_or_id, + permission_type=PermissionType.RULE_VIEW, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + ) result = self._append_view_properties([rule.json])[0] rule.json = result return rule def _append_view_properties(self, rules): - action_by_refs, trigger_by_refs, trigger_type_by_refs = self._get_referenced_models(rules) + ( + action_by_refs, + trigger_by_refs, + trigger_type_by_refs, + ) = self._get_referenced_models(rules) for rule in rules: - action_ref = rule.get('action', {}).get('ref', None) - trigger_ref = rule.get('trigger', {}).get('ref', None) - trigger_type_ref = rule.get('trigger', {}).get('type', None) + action_ref = rule.get("action", {}).get("ref", None) + trigger_ref = rule.get("trigger", {}).get("ref", None) + trigger_type_ref = rule.get("trigger", {}).get("type", None) action_db = action_by_refs.get(action_ref, None) - if 'action' in rule: - rule['action']['description'] = action_db.description if action_db else '' + if "action" in rule: + rule["action"]["description"] = ( + action_db.description if action_db else "" + ) - if 'trigger' in rule: - rule['trigger']['description'] = '' + if "trigger" in rule: + rule["trigger"]["description"] = "" trigger_db = trigger_by_refs.get(trigger_ref, None) if trigger_db: - rule['trigger']['description'] = trigger_db.description + rule["trigger"]["description"] = trigger_db.description # If description is not found in trigger get description from TriggerType - if 'trigger' in rule and not rule['trigger']['description']: + if "trigger" in rule and not rule["trigger"]["description"]: trigger_type_db = trigger_type_by_refs.get(trigger_type_ref, None) if trigger_type_db: - rule['trigger']['description'] = trigger_type_db.description + rule["trigger"]["description"] = trigger_type_db.description return rules @@ -145,9 +161,9 @@ def _get_referenced_models(self, rules): trigger_type_refs = set() for rule in rules: - action_ref = rule.get('action', {}).get('ref', None) - trigger_ref = rule.get('trigger', {}).get('ref', None) - trigger_type_ref = rule.get('trigger', {}).get('type', None) + action_ref = rule.get("action", {}).get("ref", None) + trigger_ref = rule.get("trigger", {}).get("ref", None) + trigger_type_ref = rule.get("trigger", {}).get("type", None) if action_ref: action_refs.add(action_ref) @@ -164,27 +180,31 @@ def _get_referenced_models(self, rules): # The functions that will return args that can used to query. def ref_query_args(ref): - return {'ref': ref} + return {"ref": ref} def name_pack_query_args(ref): resource_ref = ResourceReference.from_string_reference(ref=ref) - return {'name': resource_ref.name, 'pack': resource_ref.pack} + return {"name": resource_ref.name, "pack": resource_ref.pack} - action_dbs = self._get_entities(model_persistence=Action, - refs=action_refs, - query_args=ref_query_args) + action_dbs = self._get_entities( + model_persistence=Action, refs=action_refs, query_args=ref_query_args + ) for action_db in action_dbs: action_by_refs[action_db.ref] = action_db - trigger_dbs = self._get_entities(model_persistence=Trigger, - refs=trigger_refs, - query_args=name_pack_query_args) + trigger_dbs = self._get_entities( + model_persistence=Trigger, + refs=trigger_refs, + query_args=name_pack_query_args, + ) for trigger_db in trigger_dbs: trigger_by_refs[trigger_db.get_reference().ref] = trigger_db - trigger_type_dbs = self._get_entities(model_persistence=TriggerType, - refs=trigger_type_refs, - query_args=name_pack_query_args) + trigger_type_dbs = self._get_entities( + model_persistence=TriggerType, + refs=trigger_type_refs, + query_args=name_pack_query_args, + ) for trigger_type_db in trigger_type_dbs: trigger_type_by_refs[trigger_type_db.get_reference().ref] = trigger_type_db diff --git a/st2api/st2api/controllers/v1/rules.py b/st2api/st2api/controllers/v1/rules.py index 5904f9140e..89f9e63531 100644 --- a/st2api/st2api/controllers/v1/rules.py +++ b/st2api/st2api/controllers/v1/rules.py @@ -34,124 +34,149 @@ from st2common.router import exc from st2common.router import abort from st2common.router import Response -from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count +from st2common.services.triggers import ( + cleanup_trigger_db_for_rule, + increment_trigger_ref_count, +) http_client = six.moves.http_client LOG = logging.getLogger(__name__) -class RuleController(BaseRestControllerMixin, BaseResourceIsolationControllerMixin, - ContentPackResourceController): +class RuleController( + BaseRestControllerMixin, + BaseResourceIsolationControllerMixin, + ContentPackResourceController, +): """ - Implements the RESTful web endpoint that handles - the lifecycle of Rules in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Rules in the system. """ + views = RuleViewController() model = RuleAPI access = Rule supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'action': 'action.ref', - 'trigger': 'trigger', - 'enabled': 'enabled', - 'user': 'context.user' + "name": "name", + "pack": "pack", + "action": "action.ref", + "trigger": "trigger", + "enabled": "enabled", + "user": "context.user", } - filter_transform_functions = { - 'enabled': transform_to_bool - } + filter_transform_functions = {"enabled": transform_to_bool} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - mandatory_include_fields_retrieve = ['pack', 'name', 'trigger'] + mandatory_include_fields_retrieve = ["pack", "name", "trigger"] - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, show_secrets=False, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + show_secrets=False, + requester_user=None, + **raw_filters, + ): from_model_kwargs = { - 'ignore_missing_trigger': True, - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "ignore_missing_trigger": True, + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ), } - return super(RuleController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return super(RuleController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user, show_secrets=False): from_model_kwargs = { - 'ignore_missing_trigger': True, - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "ignore_missing_trigger": True, + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ), } - return super(RuleController, self)._get_one(ref_or_id, from_model_kwargs=from_model_kwargs, - requester_user=requester_user, - permission_type=PermissionType.RULE_VIEW) + return super(RuleController, self)._get_one( + ref_or_id, + from_model_kwargs=from_model_kwargs, + requester_user=requester_user, + permission_type=PermissionType.RULE_VIEW, + ) def post(self, rule, requester_user): """ - Create a new rule. + Create a new rule. - Handles requests: - POST /rules/ + Handles requests: + POST /rules/ """ rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.RULE_CREATE - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=rule, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, resource_api=rule, permission_type=permission_type + ) if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) # Validate that the authenticated user is admin if user query param is provided user = requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) - if not hasattr(rule, 'context'): + if not hasattr(rule, "context"): rule.context = dict() - rule.context['user'] = user + rule.context["user"] = user try: rule_db = RuleAPI.to_model(rule) - LOG.debug('/rules/ POST verified RuleAPI and formulated RuleDB=%s', rule_db) + LOG.debug("/rules/ POST verified RuleAPI and formulated RuleDB=%s", rule_db) # Check referenced trigger and action permissions # Note: This needs to happen after "to_model" call since to_model performs some # validation (trigger exists, etc.) - rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user, - rule_api=rule) + rbac_utils.assert_user_has_rule_trigger_and_action_permission( + user_db=requester_user, rule_api=rule + ) rule_db = Rule.add_or_update(rule_db) # After the rule has been added modify the ref_count. This way a failure to add # the rule due to violated constraints will have no impact on ref_count. increment_trigger_ref_count(rule_api=rule) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for rule data=%s.', rule) + LOG.exception("Validation failed for rule data=%s.", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return except (ValueValidationException, jsonschema.ValidationError) as e: - LOG.exception('Validation failed for rule data=%s.', rule) + LOG.exception("Validation failed for rule data=%s.", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return except TriggerDoesNotExistException: - msg = ('Trigger "%s" defined in the rule does not exist in system or it\'s missing ' - 'required "parameters" attribute' % (rule.trigger['type'])) + msg = ( + 'Trigger "%s" defined in the rule does not exist in system or it\'s missing ' + 'required "parameters" attribute' % (rule.trigger["type"]) + ) LOG.exception(msg) abort(http_client.BAD_REQUEST, msg) return - extra = {'rule_db': rule_db} - LOG.audit('Rule created. Rule.id=%s' % (rule_db.id), extra=extra) + extra = {"rule_db": rule_db} + LOG.audit("Rule created. Rule.id=%s" % (rule_db.id), extra=extra) rule_api = RuleAPI.from_model(rule_db) return Response(json=rule_api, status=exc.HTTPCreated.code) @@ -161,27 +186,33 @@ def put(self, rule, rule_ref_or_id, requester_user): rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.RULE_MODIFY - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=rule, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, resource_db=rule, permission_type=permission_type + ) - LOG.debug('PUT /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db) + LOG.debug( + "PUT /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db + ) if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) # Validate that the authenticated user is admin if user query param is provided user = requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) - if not hasattr(rule, 'context'): + if not hasattr(rule, "context"): rule.context = dict() - rule.context['user'] = user + rule.context["user"] = user try: - if rule.id is not None and rule.id != '' and rule.id != rule_ref_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - rule.id, rule_ref_or_id) + if rule.id is not None and rule.id != "" and rule.id != rule_ref_or_id: + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + rule.id, + rule_ref_or_id, + ) old_rule_db = rule_db try: @@ -193,8 +224,9 @@ def put(self, rule, rule_ref_or_id, requester_user): # Check referenced trigger and action permissions # Note: This needs to happen after "to_model" call since to_model performs some # validation (trigger exists, etc.) - rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user, - rule_api=rule) + rbac_utils.assert_user_has_rule_trigger_and_action_permission( + user_db=requester_user, rule_api=rule + ) rule_db.id = rule_ref_or_id rule_db = Rule.add_or_update(rule_db) @@ -202,48 +234,52 @@ def put(self, rule, rule_ref_or_id, requester_user): # the rule due to violated constraints will have no impact on ref_count. increment_trigger_ref_count(rule_api=rule) except (ValueValidationException, jsonschema.ValidationError, ValueError) as e: - LOG.exception('Validation failed for rule data=%s', rule) + LOG.exception("Validation failed for rule data=%s", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return # use old_rule_db for cleanup. cleanup_trigger_db_for_rule(old_rule_db) - extra = {'old_rule_db': old_rule_db, 'new_rule_db': rule_db} - LOG.audit('Rule updated. Rule.id=%s.' % (rule_db.id), extra=extra) + extra = {"old_rule_db": old_rule_db, "new_rule_db": rule_db} + LOG.audit("Rule updated. Rule.id=%s." % (rule_db.id), extra=extra) rule_api = RuleAPI.from_model(rule_db) return rule_api def delete(self, rule_ref_or_id, requester_user): """ - Delete a rule. + Delete a rule. - Handles requests: - DELETE /rules/1 + Handles requests: + DELETE /rules/1 """ rule_db = self._get_by_ref_or_id(ref_or_id=rule_ref_or_id) permission_type = PermissionType.RULE_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=rule_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, resource_db=rule_db, permission_type=permission_type + ) - LOG.debug('DELETE /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db) + LOG.debug( + "DELETE /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db + ) try: Rule.delete(rule_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s".', - rule_ref_or_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s".', + rule_ref_or_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return # use old_rule_db for cleanup. cleanup_trigger_db_for_rule(rule_db) - extra = {'rule_db': rule_db} - LOG.audit('Rule deleted. Rule.id=%s.' % (rule_db.id), extra=extra) + extra = {"rule_db": rule_db} + LOG.audit("Rule deleted. Rule.id=%s." % (rule_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/ruletypes.py b/st2api/st2api/controllers/v1/ruletypes.py index dcf62069ef..267c192534 100644 --- a/st2api/st2api/controllers/v1/ruletypes.py +++ b/st2api/st2api/controllers/v1/ruletypes.py @@ -28,8 +28,8 @@ class RuleTypesController(object): """ - Implements the RESTful web endpoint that handles - the lifecycle of a RuleType in the system. + Implements the RESTful web endpoint that handles + the lifecycle of a RuleType in the system. """ @staticmethod @@ -46,15 +46,17 @@ def __get_by_name(name): try: return [RuleType.get_by_name(name)] except ValueError as e: - LOG.debug('Database lookup for name="%s" resulted in exception : %s.', name, e) + LOG.debug( + 'Database lookup for name="%s" resulted in exception : %s.', name, e + ) return [] def get_one(self, id): """ - List RuleType objects by id. + List RuleType objects by id. - Handle: - GET /ruletypes/1 + Handle: + GET /ruletypes/1 """ ruletype_db = RuleTypesController.__get_by_id(id) ruletype_api = RuleTypeAPI.from_model(ruletype_db) @@ -62,14 +64,15 @@ def get_one(self, id): def get_all(self): """ - List all RuleType objects. + List all RuleType objects. - Handles requests: - GET /ruletypes/ + Handles requests: + GET /ruletypes/ """ ruletype_dbs = RuleType.get_all() - ruletype_apis = [RuleTypeAPI.from_model(runnertype_db) - for runnertype_db in ruletype_dbs] + ruletype_apis = [ + RuleTypeAPI.from_model(runnertype_db) for runnertype_db in ruletype_dbs + ] return ruletype_apis diff --git a/st2api/st2api/controllers/v1/runnertypes.py b/st2api/st2api/controllers/v1/runnertypes.py index b947babd94..1c84b4425c 100644 --- a/st2api/st2api/controllers/v1/runnertypes.py +++ b/st2api/st2api/controllers/v1/runnertypes.py @@ -31,34 +31,42 @@ class RunnerTypesController(ResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of an RunnerType in the system. + Implements the RESTful web endpoint that handles + the lifecycle of an RunnerType in the system. """ model = RunnerTypeAPI access = RunnerType - supported_filters = { - 'name': 'name' - } - - query_options = { - 'sort': ['name'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(RunnerTypesController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name"} + + query_options = {"sort": ["name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(RunnerTypesController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, name_or_id, requester_user): - return self._get_one_by_name_or_id(name_or_id, - requester_user=requester_user, - permission_type=PermissionType.RUNNER_VIEW) + return self._get_one_by_name_or_id( + name_or_id, + requester_user=requester_user, + permission_type=PermissionType.RUNNER_VIEW, + ) def put(self, runner_type_api, name_or_id, requester_user): # Note: We only allow "enabled" attribute of the runner to be changed @@ -66,28 +74,41 @@ def put(self, runner_type_api, name_or_id, requester_user): permission_type = PermissionType.RUNNER_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=runner_type_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=runner_type_db, + permission_type=permission_type, + ) old_runner_type_db = runner_type_db - LOG.debug('PUT /runnertypes/ lookup with id=%s found object: %s', name_or_id, - runner_type_db) + LOG.debug( + "PUT /runnertypes/ lookup with id=%s found object: %s", + name_or_id, + runner_type_db, + ) try: if runner_type_api.id and runner_type_api.id != name_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - runner_type_api.id, name_or_id) + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + runner_type_api.id, + name_or_id, + ) runner_type_db.enabled = runner_type_api.enabled runner_type_db = RunnerType.add_or_update(runner_type_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for runner type data=%s', runner_type_api) + LOG.exception("Validation failed for runner type data=%s", runner_type_api) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_runner_type_db': old_runner_type_db, 'new_runner_type_db': runner_type_db} - LOG.audit('Runner Type updated. RunnerType.id=%s.' % (runner_type_db.id), extra=extra) + extra = { + "old_runner_type_db": old_runner_type_db, + "new_runner_type_db": runner_type_db, + } + LOG.audit( + "Runner Type updated. RunnerType.id=%s." % (runner_type_db.id), extra=extra + ) runner_type_api = RunnerTypeAPI.from_model(runner_type_db) return runner_type_api diff --git a/st2api/st2api/controllers/v1/sensors.py b/st2api/st2api/controllers/v1/sensors.py index a3a71853d8..b62b56c92d 100644 --- a/st2api/st2api/controllers/v1/sensors.py +++ b/st2api/st2api/controllers/v1/sensors.py @@ -36,35 +36,41 @@ class SensorTypeController(resource.ContentPackResourceController): model = SensorTypeAPI access = SensorType supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'enabled': 'enabled', - 'trigger': 'trigger_types' + "name": "name", + "pack": "pack", + "enabled": "enabled", + "trigger": "trigger_types", } - filter_transform_functions = { - 'enabled': transform_to_bool - } - - options = { - 'sort': ['pack', 'name'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(SensorTypeController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + filter_transform_functions = {"enabled": transform_to_bool} + + options = {"sort": ["pack", "name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(SensorTypeController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.SENSOR_VIEW - return super(SensorTypeController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=permission_type) + return super(SensorTypeController, self)._get_one( + ref_or_id, requester_user=requester_user, permission_type=permission_type + ) def put(self, sensor_type, ref_or_id, requester_user): # Note: Right now this function only supports updating of "enabled" @@ -76,9 +82,11 @@ def put(self, sensor_type, ref_or_id, requester_user): permission_type = PermissionType.SENSOR_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=sensor_type_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=sensor_type_db, + permission_type=permission_type, + ) sensor_type_id = sensor_type_db.id @@ -88,23 +96,23 @@ def put(self, sensor_type, ref_or_id, requester_user): abort(http_client.BAD_REQUEST, six.text_type(e)) return - if not getattr(sensor_type, 'pack', None): + if not getattr(sensor_type, "pack", None): sensor_type.pack = sensor_type_db.pack try: old_sensor_type_db = sensor_type_db sensor_type_db.id = sensor_type_id - sensor_type_db.enabled = getattr(sensor_type, 'enabled', False) + sensor_type_db.enabled = getattr(sensor_type, "enabled", False) sensor_type_db = SensorType.add_or_update(sensor_type_db) except (ValidationError, ValueError) as e: - LOG.exception('Unable to update sensor_type data=%s', sensor_type) + LOG.exception("Unable to update sensor_type data=%s", sensor_type) abort(http_client.BAD_REQUEST, six.text_type(e)) return extra = { - 'old_sensor_type_db': old_sensor_type_db, - 'new_sensor_type_db': sensor_type_db + "old_sensor_type_db": old_sensor_type_db, + "new_sensor_type_db": sensor_type_db, } - LOG.audit('Sensor updated. Sensor.id=%s.' % (sensor_type_db.id), extra=extra) + LOG.audit("Sensor updated. Sensor.id=%s." % (sensor_type_db.id), extra=extra) sensor_type_api = SensorTypeAPI.from_model(sensor_type_db) return sensor_type_api diff --git a/st2api/st2api/controllers/v1/service_registry.py b/st2api/st2api/controllers/v1/service_registry.py index d9ee9d542b..3a54563b25 100644 --- a/st2api/st2api/controllers/v1/service_registry.py +++ b/st2api/st2api/controllers/v1/service_registry.py @@ -22,8 +22,8 @@ from st2common.rbac.backends import get_rbac_backend __all__ = [ - 'ServiceRegistryGroupsController', - 'ServiceRegistryGroupMembersController', + "ServiceRegistryGroupsController", + "ServiceRegistryGroupMembersController", ] @@ -35,11 +35,9 @@ def get_all(self, requester_user): coordinator = coordination.get_coordinator() group_ids = list(coordinator.get_groups().get()) - group_ids = [item.decode('utf-8') for item in group_ids] + group_ids = [item.decode("utf-8") for item in group_ids] - result = { - 'groups': group_ids - } + result = {"groups": group_ids} return result @@ -51,26 +49,26 @@ def get_one(self, group_id, requester_user): coordinator = coordination.get_coordinator() if not isinstance(group_id, six.binary_type): - group_id = group_id.encode('utf-8') + group_id = group_id.encode("utf-8") try: member_ids = list(coordinator.get_members(group_id).get()) except GroupNotCreated: - msg = ('Group with ID "%s" not found.' % (group_id.decode('utf-8'))) + msg = 'Group with ID "%s" not found.' % (group_id.decode("utf-8")) raise StackStormDBObjectNotFoundError(msg) - result = { - 'members': [] - } + result = {"members": []} for member_id in member_ids: - capabilities = coordinator.get_member_capabilities(group_id, member_id).get() + capabilities = coordinator.get_member_capabilities( + group_id, member_id + ).get() item = { - 'group_id': group_id.decode('utf-8'), - 'member_id': member_id.decode('utf-8'), - 'capabilities': capabilities + "group_id": group_id.decode("utf-8"), + "member_id": member_id.decode("utf-8"), + "capabilities": capabilities, } - result['members'].append(item) + result["members"].append(item) return result diff --git a/st2api/st2api/controllers/v1/timers.py b/st2api/st2api/controllers/v1/timers.py index c91b80fec1..541957a099 100644 --- a/st2api/st2api/controllers/v1/timers.py +++ b/st2api/st2api/controllers/v1/timers.py @@ -30,17 +30,13 @@ from st2common.services.triggerwatcher import TriggerWatcher from st2common.router import abort -__all__ = [ - 'TimersController', - 'TimersHolder' -] +__all__ = ["TimersController", "TimersHolder"] LOG = logging.getLogger(__name__) class TimersHolder(object): - def __init__(self): self._timers = {} @@ -54,7 +50,7 @@ def get_all(self, timer_type=None): timer_triggers = [] for _, timer in iteritems(self._timers): - if not timer_type or timer['type'] == timer_type: + if not timer_type or timer["type"] == timer_type: timer_triggers.append(timer) return timer_triggers @@ -65,35 +61,37 @@ class TimersController(resource.ContentPackResourceController): access = Trigger supported_filters = { - 'type': 'type', + "type": "type", } - query_options = { - 'sort': ['type'] - } + query_options = {"sort": ["type"]} def __init__(self): self._timers = TimersHolder() self._trigger_types = TIMER_TRIGGER_TYPES.keys() queue_suffix = self.__class__.__name__ - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=queue_suffix, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=queue_suffix, + exclusive=True, + ) self._trigger_watcher.start() self._register_timer_trigger_types() self._allowed_timer_types = TIMER_TRIGGER_TYPES.keys() def get_all(self, timer_type=None): if timer_type and timer_type not in self._allowed_timer_types: - msg = 'Timer type %s not in supported types - %s.' % (timer_type, - self._allowed_timer_types) + msg = "Timer type %s not in supported types - %s." % ( + timer_type, + self._allowed_timer_types, + ) abort(http_client.BAD_REQUEST, msg) t_all = self._timers.get_all(timer_type=timer_type) - LOG.debug('Got timers: %s', t_all) + LOG.debug("Got timers: %s", t_all) return t_all def get_one(self, ref_or_id, requester_user): @@ -108,9 +106,11 @@ def get_one(self, ref_or_id, requester_user): resource_db = TimerDB(pack=trigger_db.pack, name=trigger_db.name) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=resource_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=resource_db, + permission_type=permission_type, + ) result = self.model.from_model(trigger_db) return result @@ -119,7 +119,7 @@ def add_trigger(self, trigger): # Note: Permission checking for creating and deleting a timer is done during rule # creation ref = self._get_timer_ref(trigger) - LOG.info('Started timer %s with parameters %s', ref, trigger['parameters']) + LOG.info("Started timer %s with parameters %s", ref, trigger["parameters"]) self._timers.add_trigger(ref, trigger) def update_trigger(self, trigger): @@ -130,14 +130,16 @@ def remove_trigger(self, trigger): # creation ref = self._get_timer_ref(trigger) self._timers.remove_trigger(ref, trigger) - LOG.info('Stopped timer %s with parameters %s.', ref, trigger['parameters']) + LOG.info("Stopped timer %s with parameters %s.", ref, trigger["parameters"]) def _register_timer_trigger_types(self): for trigger_type in TIMER_TRIGGER_TYPES.values(): trigger_service.create_trigger_type_db(trigger_type) def _get_timer_ref(self, trigger): - return ResourceReference.to_string_reference(pack=trigger['pack'], name=trigger['name']) + return ResourceReference.to_string_reference( + pack=trigger["pack"], name=trigger["name"] + ) ############################################## # Event handler methods for the trigger events diff --git a/st2api/st2api/controllers/v1/traces.py b/st2api/st2api/controllers/v1/traces.py index 91c6e95e4f..4ab1d02aa5 100644 --- a/st2api/st2api/controllers/v1/traces.py +++ b/st2api/st2api/controllers/v1/traces.py @@ -18,47 +18,53 @@ from st2common.persistence.trace import Trace from st2common.rbac.types import PermissionType -__all__ = [ - 'TracesController' -] +__all__ = ["TracesController"] class TracesController(ResourceController): model = TraceAPI access = Trace supported_filters = { - 'trace_tag': 'trace_tag', - 'execution': 'action_executions.object_id', - 'rule': 'rules.object_id', - 'trigger_instance': 'trigger_instances.object_id', + "trace_tag": "trace_tag", + "execution": "action_executions.object_id", + "rule": "rules.object_id", + "trigger_instance": "trigger_instances.object_id", } - query_options = { - 'sort': ['-start_timestamp', 'trace_tag'] - } + query_options = {"sort": ["-start_timestamp", "trace_tag"]} - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): # Use a custom sort order when filtering on a timestamp so we return a correct result as # expected by the user query_options = None - if 'sort_desc' in raw_filters and raw_filters['sort_desc'] == 'True': - query_options = {'sort': ['-start_timestamp', 'trace_tag']} - elif 'sort_asc' in raw_filters and raw_filters['sort_asc'] == 'True': - query_options = {'sort': ['+start_timestamp', 'trace_tag']} - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - requester_user=requester_user) + if "sort_desc" in raw_filters and raw_filters["sort_desc"] == "True": + query_options = {"sort": ["-start_timestamp", "trace_tag"]} + elif "sort_asc" in raw_filters and raw_filters["sort_asc"] == "True": + query_options = {"sort": ["+start_timestamp", "trace_tag"]} + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - return self._get_one_by_id(id, - requester_user=requester_user, - permission_type=PermissionType.TRACE_VIEW) + return self._get_one_by_id( + id, requester_user=requester_user, permission_type=PermissionType.TRACE_VIEW + ) traces_controller = TracesController() diff --git a/st2api/st2api/controllers/v1/triggers.py b/st2api/st2api/controllers/v1/triggers.py index 12c3f133ec..cbdc5ca66b 100644 --- a/st2api/st2api/controllers/v1/triggers.py +++ b/st2api/st2api/controllers/v1/triggers.py @@ -39,55 +39,64 @@ class TriggerTypeController(resource.ContentPackResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of TriggerTypes in the system. + Implements the RESTful web endpoint that handles + the lifecycle of TriggerTypes in the system. """ + model = TriggerTypeAPI access = TriggerType - supported_filters = { - 'name': 'name', - 'pack': 'pack' - } - - options = { - 'sort': ['pack', 'name'] - } - - query_options = { - 'sort': ['ref'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack"} + + options = {"sort": ["pack", "name"]} + + query_options = {"sort": ["ref"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, triggertype_ref_or_id): - return self._get_one(triggertype_ref_or_id, permission_type=None, requester_user=None) + return self._get_one( + triggertype_ref_or_id, permission_type=None, requester_user=None + ) def post(self, triggertype): """ - Create a new triggertype. + Create a new triggertype. - Handles requests: - POST /triggertypes/ + Handles requests: + POST /triggertypes/ """ try: triggertype_db = TriggerTypeAPI.to_model(triggertype) triggertype_db = TriggerType.add_or_update(triggertype_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for triggertype data=%s.', triggertype) + LOG.exception("Validation failed for triggertype data=%s.", triggertype) abort(http_client.BAD_REQUEST, six.text_type(e)) return else: - extra = {'triggertype_db': triggertype_db} - LOG.audit('TriggerType created. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = {"triggertype_db": triggertype_db} + LOG.audit( + "TriggerType created. TriggerType.id=%s" % (triggertype_db.id), + extra=extra, + ) if not triggertype_db.parameters_schema: TriggerTypeController._create_shadow_trigger(triggertype_db) @@ -106,34 +115,44 @@ def put(self, triggertype, triggertype_ref_or_id): try: triggertype_db = TriggerTypeAPI.to_model(triggertype) - if triggertype.id is not None and len(triggertype.id) > 0 and \ - triggertype.id != triggertype_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - triggertype.id, triggertype_id) + if ( + triggertype.id is not None + and len(triggertype.id) > 0 + and triggertype.id != triggertype_id + ): + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + triggertype.id, + triggertype_id, + ) triggertype_db.id = triggertype_id old_triggertype_db = triggertype_db triggertype_db = TriggerType.add_or_update(triggertype_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for triggertype data=%s', triggertype) + LOG.exception("Validation failed for triggertype data=%s", triggertype) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_triggertype_db': old_triggertype_db, 'new_triggertype_db': triggertype_db} - LOG.audit('TriggerType updated. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = { + "old_triggertype_db": old_triggertype_db, + "new_triggertype_db": triggertype_db, + } + LOG.audit( + "TriggerType updated. TriggerType.id=%s" % (triggertype_db.id), extra=extra + ) triggertype_api = TriggerTypeAPI.from_model(triggertype_db) return triggertype_api def delete(self, triggertype_ref_or_id): """ - Delete a triggertype. + Delete a triggertype. - Handles requests: - DELETE /triggertypes/1 - DELETE /triggertypes/pack.name + Handles requests: + DELETE /triggertypes/1 + DELETE /triggertypes/pack.name """ - LOG.info('DELETE /triggertypes/ with ref_or_id=%s', - triggertype_ref_or_id) + LOG.info("DELETE /triggertypes/ with ref_or_id=%s", triggertype_ref_or_id) triggertype_db = self._get_by_ref_or_id(ref_or_id=triggertype_ref_or_id) triggertype_id = triggertype_db.id @@ -146,13 +165,18 @@ def delete(self, triggertype_ref_or_id): try: TriggerType.delete(triggertype_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - triggertype_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + triggertype_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return else: - extra = {'triggertype': triggertype_db} - LOG.audit('TriggerType deleted. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = {"triggertype": triggertype_db} + LOG.audit( + "TriggerType deleted. TriggerType.id=%s" % (triggertype_db.id), + extra=extra, + ) if not triggertype_db.parameters_schema: TriggerTypeController._delete_shadow_trigger(triggertype_db) @@ -162,55 +186,70 @@ def delete(self, triggertype_ref_or_id): def _create_shadow_trigger(triggertype_db): try: trigger_type_ref = triggertype_db.get_reference().ref - trigger = {'name': triggertype_db.name, - 'pack': triggertype_db.pack, - 'type': trigger_type_ref, - 'parameters': {}} + trigger = { + "name": triggertype_db.name, + "pack": triggertype_db.pack, + "type": trigger_type_ref, + "parameters": {}, + } trigger_db = TriggerService.create_or_update_trigger_db(trigger) - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger created for parameter-less TriggerType. Trigger.id=%s' % - (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit( + "Trigger created for parameter-less TriggerType. Trigger.id=%s" + % (trigger_db.id), + extra=extra, + ) except (ValidationError, ValueError): - LOG.exception('Validation failed for trigger data=%s.', trigger) + LOG.exception("Validation failed for trigger data=%s.", trigger) # Not aborting as this is convenience. return except StackStormDBObjectConflictError as e: - LOG.warn('Trigger creation of "%s" failed with uniqueness conflict. Exception: %s', - trigger, six.text_type(e)) + LOG.warn( + 'Trigger creation of "%s" failed with uniqueness conflict. Exception: %s', + trigger, + six.text_type(e), + ) # Not aborting as this is convenience. return @staticmethod def _delete_shadow_trigger(triggertype_db): # shadow Trigger's have the same name as the shadowed TriggerType. - triggertype_ref = ResourceReference(name=triggertype_db.name, pack=triggertype_db.pack) + triggertype_ref = ResourceReference( + name=triggertype_db.name, pack=triggertype_db.pack + ) trigger_db = TriggerService.get_trigger_db_by_ref(triggertype_ref.ref) if not trigger_db: - LOG.warn('No shadow trigger found for %s. Will skip delete.', triggertype_db) + LOG.warn( + "No shadow trigger found for %s. Will skip delete.", triggertype_db + ) return try: Trigger.delete(trigger_db) except Exception: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - trigger_db.id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + trigger_db.id, + ) - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra) class TriggerController(object): """ - Implements the RESTful web endpoint that handles - the lifecycle of Triggers in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Triggers in the system. """ + def get_one(self, trigger_id): """ - List trigger by id. + List trigger by id. - Handle: - GET /triggers/1 + Handle: + GET /triggers/1 """ trigger_db = TriggerController.__get_by_id(trigger_id) trigger_api = TriggerAPI.from_model(trigger_db) @@ -218,10 +257,10 @@ def get_one(self, trigger_id): def get_all(self, requester_user=None): """ - List all triggers. + List all triggers. - Handles requests: - GET /triggers/ + Handles requests: + GET /triggers/ """ trigger_dbs = Trigger.get_all() trigger_apis = [TriggerAPI.from_model(trigger_db) for trigger_db in trigger_dbs] @@ -229,20 +268,20 @@ def get_all(self, requester_user=None): def post(self, trigger): """ - Create a new trigger. + Create a new trigger. - Handles requests: - POST /triggers/ + Handles requests: + POST /triggers/ """ try: trigger_db = TriggerService.create_trigger_db(trigger) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for trigger data=%s.', trigger) + LOG.exception("Validation failed for trigger data=%s.", trigger) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'trigger': trigger_db} - LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger": trigger_db} + LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra) trigger_api = TriggerAPI.from_model(trigger_db) return Response(json=trigger_api, status=http_client.CREATED) @@ -250,42 +289,47 @@ def post(self, trigger): def put(self, trigger, trigger_id): trigger_db = TriggerController.__get_by_id(trigger_id) try: - if trigger.id is not None and trigger.id != '' and trigger.id != trigger_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - trigger.id, trigger_id) + if trigger.id is not None and trigger.id != "" and trigger.id != trigger_id: + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + trigger.id, + trigger_id, + ) trigger_db = TriggerAPI.to_model(trigger) trigger_db.id = trigger_id trigger_db = Trigger.add_or_update(trigger_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for trigger data=%s', trigger) + LOG.exception("Validation failed for trigger data=%s", trigger) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_trigger_db': trigger, 'new_trigger_db': trigger_db} - LOG.audit('Trigger updated. Trigger.id=%s' % (trigger.id), extra=extra) + extra = {"old_trigger_db": trigger, "new_trigger_db": trigger_db} + LOG.audit("Trigger updated. Trigger.id=%s" % (trigger.id), extra=extra) trigger_api = TriggerAPI.from_model(trigger_db) return trigger_api def delete(self, trigger_id): """ - Delete a trigger. + Delete a trigger. - Handles requests: - DELETE /triggers/1 + Handles requests: + DELETE /triggers/1 """ - LOG.info('DELETE /triggers/ with id=%s', trigger_id) + LOG.info("DELETE /triggers/ with id=%s", trigger_id) trigger_db = TriggerController.__get_by_id(trigger_id) try: Trigger.delete(trigger_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - trigger_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + trigger_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) @@ -294,7 +338,9 @@ def __get_by_id(trigger_id): try: return Trigger.get_by_id(trigger_id) except (ValueError, ValidationError): - LOG.exception('Database lookup for id="%s" resulted in exception.', trigger_id) + LOG.exception( + 'Database lookup for id="%s" resulted in exception.', trigger_id + ) abort(http_client.NOT_FOUND) @staticmethod @@ -302,7 +348,11 @@ def __get_by_name(trigger_name): try: return [Trigger.get_by_name(trigger_name)] except ValueError as e: - LOG.debug('Database lookup for name="%s" resulted in exception : %s.', trigger_name, e) + LOG.debug( + 'Database lookup for name="%s" resulted in exception : %s.', + trigger_name, + e, + ) return [] @@ -311,7 +361,9 @@ class TriggerInstanceControllerMixin(object): access = TriggerInstance -class TriggerInstanceResendController(TriggerInstanceControllerMixin, resource.ResourceController): +class TriggerInstanceResendController( + TriggerInstanceControllerMixin, resource.ResourceController +): supported_filters = {} def __init__(self, *args, **kwargs): @@ -338,106 +390,130 @@ def post(self, trigger_instance_id): POST /triggerinstance//re_send """ # Note: We only really need parameters here - existing_trigger_instance = self._get_one_by_id(id=trigger_instance_id, - permission_type=None, - requester_user=None) + existing_trigger_instance = self._get_one_by_id( + id=trigger_instance_id, permission_type=None, requester_user=None + ) new_payload = copy.deepcopy(existing_trigger_instance.payload) - new_payload['__context'] = { - 'original_id': trigger_instance_id - } + new_payload["__context"] = {"original_id": trigger_instance_id} try: - self.trigger_dispatcher.dispatch(existing_trigger_instance.trigger, - new_payload) + self.trigger_dispatcher.dispatch( + existing_trigger_instance.trigger, new_payload + ) return { - 'message': 'Trigger instance %s succesfully re-sent.' % trigger_instance_id, - 'payload': new_payload + "message": "Trigger instance %s succesfully re-sent." + % trigger_instance_id, + "payload": new_payload, } except Exception as e: abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) -class TriggerInstanceController(TriggerInstanceControllerMixin, resource.ResourceController): +class TriggerInstanceController( + TriggerInstanceControllerMixin, resource.ResourceController +): """ - Implements the RESTful web endpoint that handles - the lifecycle of TriggerInstances in the system. + Implements the RESTful web endpoint that handles + the lifecycle of TriggerInstances in the system. """ + supported_filters = { - 'timestamp_gt': 'occurrence_time.gt', - 'timestamp_lt': 'occurrence_time.lt', - 'status': 'status', - 'trigger': 'trigger.in' + "timestamp_gt": "occurrence_time.gt", + "timestamp_lt": "occurrence_time.lt", + "status": "status", + "trigger": "trigger.in", } filter_transform_functions = { - 'timestamp_gt': lambda value: isotime.parse(value=value), - 'timestamp_lt': lambda value: isotime.parse(value=value) + "timestamp_gt": lambda value: isotime.parse(value=value), + "timestamp_lt": lambda value: isotime.parse(value=value), } - query_options = { - 'sort': ['-occurrence_time', 'trigger'] - } + query_options = {"sort": ["-occurrence_time", "trigger"]} def __init__(self): super(TriggerInstanceController, self).__init__() def get_one(self, instance_id): """ - List triggerinstance by instance_id. + List triggerinstance by instance_id. - Handle: - GET /triggerinstances/1 + Handle: + GET /triggerinstances/1 """ - return self._get_one_by_id(instance_id, permission_type=None, requester_user=None) - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + return self._get_one_by_id( + instance_id, permission_type=None, requester_user=None + ) + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): """ - List all triggerinstances. + List all triggerinstances. - Handles requests: - GET /triggerinstances/ + Handles requests: + GET /triggerinstances/ """ # If trigger_type filter is provided, filter based on the TriggerType via Trigger object - trigger_type_ref = raw_filters.get('trigger_type', None) + trigger_type_ref = raw_filters.get("trigger_type", None) if trigger_type_ref: # 1. Retrieve TriggerType object id which match this trigger_type ref - trigger_dbs = Trigger.query(type=trigger_type_ref, - only_fields=['ref', 'name', 'pack', 'type']) + trigger_dbs = Trigger.query( + type=trigger_type_ref, only_fields=["ref", "name", "pack", "type"] + ) trigger_refs = [trigger_db.ref for trigger_db in trigger_dbs] - raw_filters['trigger'] = trigger_refs + raw_filters["trigger"] = trigger_refs - if trigger_type_ref and len(raw_filters.get('trigger', [])) == 0: + if trigger_type_ref and len(raw_filters.get("trigger", [])) == 0: # Empty list means trigger_type_ref filter was provided, but we matched no Triggers so # we should return back empty result return [] - trigger_instances = self._get_trigger_instances(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + trigger_instances = self._get_trigger_instances( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) return trigger_instances - def _get_trigger_instances(self, exclude_fields=None, include_fields=None, sort=None, offset=0, - limit=None, raw_filters=None, requester_user=None): + def _get_trigger_instances( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + raw_filters=None, + requester_user=None, + ): if limit is None: limit = self.default_limit limit = int(limit) - LOG.debug('Retrieving all trigger instances with filters=%s', raw_filters) - return super(TriggerInstanceController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + LOG.debug("Retrieving all trigger instances with filters=%s", raw_filters) + return super(TriggerInstanceController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) triggertype_controller = TriggerTypeController() diff --git a/st2api/st2api/controllers/v1/user.py b/st2api/st2api/controllers/v1/user.py index e3de60b978..0593a13384 100644 --- a/st2api/st2api/controllers/v1/user.py +++ b/st2api/st2api/controllers/v1/user.py @@ -17,9 +17,7 @@ from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'UserController' -] +__all__ = ["UserController"] class UserController(object): @@ -43,21 +41,21 @@ def get(self, requester_user, auth_info): roles = [] data = { - 'username': requester_user.name, - 'authentication': { - 'method': auth_info['method'], - 'location': auth_info['location'] + "username": requester_user.name, + "authentication": { + "method": auth_info["method"], + "location": auth_info["location"], + }, + "rbac": { + "enabled": cfg.CONF.rbac.enable, + "roles": roles, + "is_admin": rbac_utils.user_is_admin(user_db=requester_user), }, - 'rbac': { - 'enabled': cfg.CONF.rbac.enable, - 'roles': roles, - 'is_admin': rbac_utils.user_is_admin(user_db=requester_user) - } } - if auth_info.get('token_expire', None): - token_expire = auth_info['token_expire'].strftime('%Y-%m-%dT%H:%M:%SZ') - data['authentication']['token_expire'] = token_expire + if auth_info.get("token_expire", None): + token_expire = auth_info["token_expire"].strftime("%Y-%m-%dT%H:%M:%SZ") + data["authentication"]["token_expire"] = token_expire return data diff --git a/st2api/st2api/controllers/v1/webhooks.py b/st2api/st2api/controllers/v1/webhooks.py index 35af0c8337..1985bb4dad 100644 --- a/st2api/st2api/controllers/v1/webhooks.py +++ b/st2api/st2api/controllers/v1/webhooks.py @@ -19,7 +19,10 @@ from six.moves import http_client from st2common import log as logging -from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME +from st2common.constants.auth import ( + HEADER_API_KEY_ATTRIBUTE_NAME, + HEADER_ATTRIBUTE_NAME, +) from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES from st2common.models.api.trace import TraceContext from st2common.models.api.trigger import TriggerAPI @@ -35,13 +38,14 @@ LOG = logging.getLogger(__name__) -TRACE_TAG_HEADER = 'St2-Trace-Tag' +TRACE_TAG_HEADER = "St2-Trace-Tag" class HooksHolder(object): """ Maintains a hook to TriggerDB mapping. """ + def __init__(self): self._triggers_by_hook = {} @@ -58,7 +62,7 @@ def remove_hook(self, hook, trigger): return False remove_index = -1 for idx, item in enumerate(self._triggers_by_hook[hook]): - if item['id'] == trigger['id']: + if item["id"] == trigger["id"]: remove_index = idx break if remove_index < 0: @@ -81,17 +85,19 @@ def get_all(self): class WebhooksController(object): def __init__(self, *args, **kwargs): self._hooks = HooksHolder() - self._base_url = '/webhooks/' + self._base_url = "/webhooks/" self._trigger_types = list(WEBHOOK_TRIGGER_TYPES.keys()) self._trigger_dispatcher_service = TriggerDispatcherService(LOG) queue_suffix = self.__class__.__name__ - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=queue_suffix, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=queue_suffix, + exclusive=True, + ) self._trigger_watcher.start() self._register_webhook_trigger_types() @@ -108,9 +114,11 @@ def get_one(self, url, requester_user): permission_type = PermissionType.WEBHOOK_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=WebhookDB(name=url), - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=WebhookDB(name=url), + permission_type=permission_type, + ) # For demonstration purpose return 1st return triggers[0] @@ -120,55 +128,62 @@ def post(self, hook, webhook_body_api, headers, requester_user): permission_type = PermissionType.WEBHOOK_SEND rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=WebhookDB(name=hook), - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=WebhookDB(name=hook), + permission_type=permission_type, + ) headers = self._get_headers_as_dict(headers) headers = self._filter_authentication_headers(headers) # If webhook contains a trace-tag use that else create create a unique trace-tag. - trace_context = self._create_trace_context(trace_tag=headers.pop(TRACE_TAG_HEADER, None), - hook=hook) + trace_context = self._create_trace_context( + trace_tag=headers.pop(TRACE_TAG_HEADER, None), hook=hook + ) - if hook == 'st2' or hook == 'st2/': + if hook == "st2" or hook == "st2/": # When using st2 or system webhook, body needs to always be a dict if not isinstance(body, dict): type_string = get_json_type_for_python_value(body) - msg = ('Webhook body needs to be an object, got: %s' % (type_string)) + msg = "Webhook body needs to be an object, got: %s" % (type_string) raise ValueError(msg) - trigger = body.get('trigger', None) - payload = body.get('payload', None) + trigger = body.get("trigger", None) + payload = body.get("payload", None) if not trigger: - msg = 'Trigger not specified.' + msg = "Trigger not specified." return abort(http_client.BAD_REQUEST, msg) - self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger, - payload=payload, - trace_context=trace_context, - throw_on_validation_error=True) + self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=True, + ) else: if not self._is_valid_hook(hook): - self._log_request('Invalid hook.', headers, body) - msg = 'Webhook %s not registered with st2' % hook + self._log_request("Invalid hook.", headers, body) + msg = "Webhook %s not registered with st2" % hook return abort(http_client.NOT_FOUND, msg) triggers = self._hooks.get_triggers_for_hook(hook) payload = {} - payload['headers'] = headers - payload['body'] = body + payload["headers"] = headers + payload["body"] = body # Dispatch trigger instance for each of the trigger found for trigger_dict in triggers: # TODO: Instead of dispatching the whole dict we should just # dispatch TriggerDB.ref or similar - self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger_dict, - payload=payload, - trace_context=trace_context, - throw_on_validation_error=True) + self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger_dict, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=True, + ) return Response(json=body, status=http_client.ACCEPTED) @@ -183,7 +198,7 @@ def _register_webhook_trigger_types(self): def _create_trace_context(self, trace_tag, hook): # if no trace_tag then create a unique one if not trace_tag: - trace_tag = 'webhook-%s-%s' % (hook, uuid.uuid4().hex) + trace_tag = "webhook-%s-%s" % (hook, uuid.uuid4().hex) return TraceContext(trace_tag=trace_tag) def add_trigger(self, trigger): @@ -191,7 +206,7 @@ def add_trigger(self, trigger): # Note: Permission checking for creating and deleting a webhook is done during rule # creation url = self._get_normalized_url(trigger) - LOG.info('Listening to endpoint: %s', urlparse.urljoin(self._base_url, url)) + LOG.info("Listening to endpoint: %s", urlparse.urljoin(self._base_url, url)) self._hooks.add_hook(url, trigger) def update_trigger(self, trigger): @@ -204,14 +219,16 @@ def remove_trigger(self, trigger): removed = self._hooks.remove_hook(url, trigger) if removed: - LOG.info('Stop listening to endpoint: %s', urlparse.urljoin(self._base_url, url)) + LOG.info( + "Stop listening to endpoint: %s", urlparse.urljoin(self._base_url, url) + ) def _get_normalized_url(self, trigger): """ remove the trailing and leading / so that the hook url and those coming from trigger parameters end up being the same. """ - return trigger['parameters']['url'].strip('/') + return trigger["parameters"]["url"].strip("/") def _get_headers_as_dict(self, headers): headers_dict = {} @@ -220,13 +237,13 @@ def _get_headers_as_dict(self, headers): return headers_dict def _filter_authentication_headers(self, headers): - auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, 'Cookie'] + auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, "Cookie"] return {key: value for key, value in headers.items() if key not in auth_headers} def _log_request(self, msg, headers, body, log_method=LOG.debug): headers = self._get_headers_as_dict(headers) body = str(body) - log_method('%s\n\trequest.header: %s.\n\trequest.body: %s.', msg, headers, body) + log_method("%s\n\trequest.header: %s.\n\trequest.body: %s.", msg, headers, body) ############################################## # Event handler methods for the trigger events diff --git a/st2api/st2api/controllers/v1/workflow_inspection.py b/st2api/st2api/controllers/v1/workflow_inspection.py index 1e5ee53d85..04d60dd2b1 100644 --- a/st2api/st2api/controllers/v1/workflow_inspection.py +++ b/st2api/st2api/controllers/v1/workflow_inspection.py @@ -30,13 +30,12 @@ class WorkflowInspectionController(object): - def mock_st2_ctx(self): st2_ctx = { - 'st2': { - 'api_url': api_utils.get_full_public_api_url(), - 'action_execution_id': uuid.uuid4().hex, - 'user': cfg.CONF.system_user.user + "st2": { + "api_url": api_utils.get_full_public_api_url(), + "action_execution_id": uuid.uuid4().hex, + "user": cfg.CONF.system_user.user, } } @@ -44,7 +43,7 @@ def mock_st2_ctx(self): def post(self, wf_def): # Load workflow definition into workflow spec model. - spec_module = specs_loader.get_spec_module('native') + spec_module = specs_loader.get_spec_module("native") wf_spec = spec_module.instantiate(wf_def) # Mock the st2 context that is typically passed to the workflow engine. diff --git a/st2api/st2api/validation.py b/st2api/st2api/validation.py index ae92d1d9cb..42120c57bf 100644 --- a/st2api/st2api/validation.py +++ b/st2api/st2api/validation.py @@ -15,9 +15,7 @@ from oslo_config import cfg -__all__ = [ - 'validate_rbac_is_correctly_configured' -] +__all__ = ["validate_rbac_is_correctly_configured"] def validate_rbac_is_correctly_configured(): @@ -28,24 +26,29 @@ def validate_rbac_is_correctly_configured(): return True from st2common.rbac.backends import get_available_backends + available_rbac_backends = get_available_backends() # 1. Verify auth is enabled if not cfg.CONF.auth.enable: - msg = ('Authentication is not enabled. RBAC only works when authentication is enabled. ' - 'You can either enable authentication or disable RBAC.') + msg = ( + "Authentication is not enabled. RBAC only works when authentication is enabled. " + "You can either enable authentication or disable RBAC." + ) raise ValueError(msg) # 2. Verify default backend is set - if cfg.CONF.rbac.backend != 'default': - msg = ('You have enabled RBAC, but RBAC backend is not set to "default". ' - 'For RBAC to work, you need to set ' - '"rbac.backend" config option to "default" and restart st2api service.') + if cfg.CONF.rbac.backend != "default": + msg = ( + 'You have enabled RBAC, but RBAC backend is not set to "default". ' + "For RBAC to work, you need to set " + '"rbac.backend" config option to "default" and restart st2api service.' + ) raise ValueError(msg) # 3. Verify default RBAC backend is available - if 'default' not in available_rbac_backends: - msg = ('"default" RBAC backend is not available.') + if "default" not in available_rbac_backends: + msg = '"default" RBAC backend is not available.' raise ValueError(msg) return True diff --git a/st2api/st2api/wsgi.py b/st2api/st2api/wsgi.py index b9c92b7bf4..79baf0f110 100644 --- a/st2api/st2api/wsgi.py +++ b/st2api/st2api/wsgi.py @@ -20,6 +20,7 @@ import os from st2common.util.monkey_patch import monkey_patch + # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things # including shutdown @@ -32,8 +33,11 @@ from st2api import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2api/tests/integration/test_gunicorn_configs.py b/st2api/tests/integration/test_gunicorn_configs.py index 65950bfa7c..9375cf3b85 100644 --- a/st2api/tests/integration/test_gunicorn_configs.py +++ b/st2api/tests/integration/test_gunicorn_configs.py @@ -28,38 +28,44 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") class GunicornWSGIEntryPointTestCase(IntegrationTestCase): - @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled') + @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled") def test_st2api_wsgi_entry_point(self): port = random.randint(10000, 30000) - cmd = ('gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port) + cmd = ( + 'gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' + % port + ) env = os.environ.copy() - env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH + env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid) try: self.add_process(process=process) eventlet.sleep(8) self.assertProcessIsRunning(process=process) - response = requests.get('http://127.0.0.1:%s/v1/actions' % (port)) + response = requests.get("http://127.0.0.1:%s/v1/actions" % (port)) self.assertEqual(response.status_code, http_client.OK) finally: kill_process(process) - @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled') + @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled") def test_st2auth(self): port = random.randint(10000, 30000) - cmd = ('gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port) + cmd = ( + 'gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' + % port + ) env = os.environ.copy() - env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH + env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid) try: self.add_process(process=process) eventlet.sleep(8) self.assertProcessIsRunning(process=process) - response = requests.post('http://127.0.0.1:%s/tokens' % (port)) + response = requests.post("http://127.0.0.1:%s/tokens" % (port)) self.assertEqual(response.status_code, http_client.UNAUTHORIZED) finally: kill_process(process) diff --git a/st2api/tests/unit/controllers/test_root.py b/st2api/tests/unit/controllers/test_root.py index d4172ce155..db4ea01713 100644 --- a/st2api/tests/unit/controllers/test_root.py +++ b/st2api/tests/unit/controllers/test_root.py @@ -15,15 +15,13 @@ from st2tests.api import FunctionalTest -__all__ = [ - 'RootControllerTestCase' -] +__all__ = ["RootControllerTestCase"] class RootControllerTestCase(FunctionalTest): def test_get_index(self): - paths = ['/', '/v1/', '/v1'] + paths = ["/", "/v1/", "/v1"] for path in paths: resp = self.app.get(path) - self.assertIn('version', resp.json) - self.assertIn('docs_url', resp.json) + self.assertIn("version", resp.json) + self.assertIn("docs_url", resp.json) diff --git a/st2api/tests/unit/controllers/v1/test_action_alias.py b/st2api/tests/unit/controllers/v1/test_action_alias.py index 299ce530e3..208ed082be 100644 --- a/st2api/tests/unit/controllers/v1/test_action_alias.py +++ b/st2api/tests/unit/controllers/v1/test_action_alias.py @@ -21,31 +21,33 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -FIXTURES_PACK = 'aliases' +FIXTURES_PACK = "aliases" TEST_MODELS = { - 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml'], - 'actions': ['action3.yaml', 'action4.yaml'] + "aliases": [ + "alias1.yaml", + "alias2.yaml", + "alias_with_undefined_jinja_in_ack_format.yaml", + ], + "actions": ["action3.yaml", "action4.yaml"], } TEST_LOAD_MODELS = { - 'aliases': ['alias3.yaml'], + "aliases": ["alias3.yaml"], } -GENERIC_FIXTURES_PACK = 'generic' +GENERIC_FIXTURES_PACK = "generic" -TEST_LOAD_MODELS_GENERIC = { - 'aliases': ['alias3.yaml'], - 'runners': ['testrunner1.yaml'] -} +TEST_LOAD_MODELS_GENERIC = {"aliases": ["alias3.yaml"], "runners": ["testrunner1.yaml"]} -class ActionAliasControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/actionalias' +class ActionAliasControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/actionalias" controller_cls = ActionAliasController - include_attribute_field_name = 'formats' - exclude_attribute_field_name = 'result' + include_attribute_field_name = "formats" + exclude_attribute_field_name = "result" models = None alias1 = None @@ -56,153 +58,186 @@ class ActionAliasControllerTestCase(FunctionalTest, @classmethod def setUpClass(cls): super(ActionAliasControllerTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.alias1 = cls.models['aliases']['alias1.yaml'] - cls.alias2 = cls.models['aliases']['alias2.yaml'] - - loaded_models = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_LOAD_MODELS) - cls.alias3 = loaded_models['aliases']['alias3.yaml'] - - FixturesLoader().save_fixtures_to_db(fixtures_pack=GENERIC_FIXTURES_PACK, - fixtures_dict={'aliases': ['alias7.yaml']}) - - loaded_models = FixturesLoader().load_models(fixtures_pack=GENERIC_FIXTURES_PACK, - fixtures_dict=TEST_LOAD_MODELS_GENERIC) - cls.alias3_generic = loaded_models['aliases']['alias3.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.alias1 = cls.models["aliases"]["alias1.yaml"] + cls.alias2 = cls.models["aliases"]["alias2.yaml"] + + loaded_models = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS + ) + cls.alias3 = loaded_models["aliases"]["alias3.yaml"] + + FixturesLoader().save_fixtures_to_db( + fixtures_pack=GENERIC_FIXTURES_PACK, + fixtures_dict={"aliases": ["alias7.yaml"]}, + ) + + loaded_models = FixturesLoader().load_models( + fixtures_pack=GENERIC_FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS_GENERIC + ) + cls.alias3_generic = loaded_models["aliases"]["alias3.yaml"] def test_get_all(self): - resp = self.app.get('/v1/actionalias') + resp = self.app.get("/v1/actionalias") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 4, '/v1/actionalias did not return all aliases.') - - retrieved_names = [alias['name'] for alias in resp.json] - - self.assertEqual(retrieved_names, [self.alias1.name, self.alias2.name, - 'alias_with_undefined_jinja_in_ack_format', - 'alias7'], - 'Incorrect aliases retrieved.') + self.assertEqual( + len(resp.json), 4, "/v1/actionalias did not return all aliases." + ) + + retrieved_names = [alias["name"] for alias in resp.json] + + self.assertEqual( + retrieved_names, + [ + self.alias1.name, + self.alias2.name, + "alias_with_undefined_jinja_in_ack_format", + "alias7", + ], + "Incorrect aliases retrieved.", + ) def test_get_all_query_param_filters(self): - resp = self.app.get('/v1/actionalias?pack=doesntexist') + resp = self.app.get("/v1/actionalias?pack=doesntexist") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/actionalias?pack=aliases') + resp = self.app.get("/v1/actionalias?pack=aliases") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 3) for alias_api in resp.json: - self.assertEqual(alias_api['pack'], 'aliases') + self.assertEqual(alias_api["pack"], "aliases") - resp = self.app.get('/v1/actionalias?pack=generic') + resp = self.app.get("/v1/actionalias?pack=generic") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) for alias_api in resp.json: - self.assertEqual(alias_api['pack'], 'generic') + self.assertEqual(alias_api["pack"], "generic") - resp = self.app.get('/v1/actionalias?name=doesntexist') + resp = self.app.get("/v1/actionalias?name=doesntexist") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/actionalias?name=alias2') + resp = self.app.get("/v1/actionalias?name=alias2") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'alias2') + self.assertEqual(resp.json[0]["name"], "alias2") def test_get_one(self): - resp = self.app.get('/v1/actionalias/%s' % self.alias1.id) + resp = self.app.get("/v1/actionalias/%s" % self.alias1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['name'], self.alias1.name, - 'Incorrect aliases retrieved.') + self.assertEqual( + resp.json["name"], self.alias1.name, "Incorrect aliases retrieved." + ) def test_post_delete(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id']) + get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"]) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['name'], self.alias3.name, - 'Incorrect aliases retrieved.') + self.assertEqual( + get_resp.json["name"], self.alias3.name, "Incorrect aliases retrieved." + ) - del_resp = self.__do_delete(post_resp.json['id']) + del_resp = self.__do_delete(post_resp.json["id"]) self.assertEqual(del_resp.status_int, 204) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id'], expect_errors=True) + get_resp = self.app.get( + "/v1/actionalias/%s" % post_resp.json["id"], expect_errors=True + ) self.assertEqual(get_resp.status_int, 404) def test_update_existing_alias(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['name'], self.alias3['name']) + self.assertEqual(post_resp.json["name"], self.alias3["name"]) data = vars(ActionAliasAPI.from_model(self.alias3)) - data['name'] = 'updated-alias-name' + data["name"] = "updated-alias-name" - put_resp = self.app.put_json('/v1/actionalias/%s' % post_resp.json['id'], data) - self.assertEqual(put_resp.json['name'], data['name']) + put_resp = self.app.put_json("/v1/actionalias/%s" % post_resp.json["id"], data) + self.assertEqual(put_resp.json["name"], data["name"]) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id']) - self.assertEqual(get_resp.json['name'], data['name']) + get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"]) + self.assertEqual(get_resp.json["name"], data["name"]) - del_resp = self.__do_delete(post_resp.json['id']) + del_resp = self.__do_delete(post_resp.json["id"]) self.assertEqual(del_resp.status_int, 204) def test_post_dup_name(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - post_resp_dup_name = self._do_post(vars(ActionAliasAPI.from_model(self.alias3_generic))) + post_resp_dup_name = self._do_post( + vars(ActionAliasAPI.from_model(self.alias3_generic)) + ) self.assertEqual(post_resp_dup_name.status_int, 201) - self.__do_delete(post_resp.json['id']) - self.__do_delete(post_resp_dup_name.json['id']) + self.__do_delete(post_resp.json["id"]) + self.__do_delete(post_resp_dup_name.json["id"]) def test_match(self): # No matching patterns - data = {'command': 'hello donny'} + data = {"command": "hello donny"} resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'hello donny' matched no patterns") + self.assertEqual( + str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns" + ) # More than one matching pattern - data = {'command': 'Lorem ipsum banana dolor sit pineapple amet.'} + data = {"command": "Lorem ipsum banana dolor sit pineapple amet."} resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'Lorem ipsum banana dolor sit pineapple amet.' " - "matched more than 1 pattern") + self.assertEqual( + str(resp.json["faultstring"]), + "Command 'Lorem ipsum banana dolor sit pineapple amet.' " + "matched more than 1 pattern", + ) # Single matching pattern - success - data = {'command': 'run whoami on localhost1'} + data = {"command": "run whoami on localhost1"} resp = self.app.post_json("/v1/actionalias/match", data) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['actionalias']['name'], - 'alias_with_undefined_jinja_in_ack_format') + self.assertEqual( + resp.json["actionalias"]["name"], "alias_with_undefined_jinja_in_ack_format" + ) def test_help(self): resp = self.app.get("/v1/actionalias/help") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json.get('available'), 5) + self.assertEqual(resp.json.get("available"), 5) def test_help_args(self): - resp = self.app.get("/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0") + resp = self.app.get( + "/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0" + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json.get('available'), 3) - self.assertEqual(len(resp.json.get('helpstrings')), 1) + self.assertEqual(resp.json.get("available"), 3) + self.assertEqual(len(resp.json.get("helpstrings")), 1) def _insert_mock_models(self): - alias_ids = [self.alias1['id'], self.alias2['id'], self.alias3['id'], - self.alias3_generic['id']] + alias_ids = [ + self.alias1["id"], + self.alias2["id"], + self.alias3["id"], + self.alias3_generic["id"], + ] return alias_ids def _delete_mock_models(self, object_ids): return None def _do_post(self, actionalias, expect_errors=False): - return self.app.post_json('/v1/actionalias', actionalias, expect_errors=expect_errors) + return self.app.post_json( + "/v1/actionalias", actionalias, expect_errors=expect_errors + ) def __do_delete(self, actionalias_id, expect_errors=False): - return self.app.delete('/v1/actionalias/%s' % actionalias_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actionalias/%s" % actionalias_id, expect_errors=expect_errors + ) diff --git a/st2api/tests/unit/controllers/v1/test_action_views.py b/st2api/tests/unit/controllers/v1/test_action_views.py index dbb9346662..a28219c04d 100644 --- a/st2api/tests/unit/controllers/v1/test_action_views.py +++ b/st2api/tests/unit/controllers/v1/test_action_views.py @@ -25,42 +25,44 @@ # ACTION_1: Good action definition. ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': 'test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_2: Good action definition. No content pack. ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': 'test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } -class ActionViewsOverviewControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/actions/views/overview' +class ActionViewsOverviewControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/actions/views/overview" controller_cls = OverviewController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "parameters" - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one(self): post_resp = self._do_post(ACTION_1) action_id = self._get_action_id(post_resp) @@ -71,8 +73,9 @@ def test_get_one(self): finally: self._do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_ref(self): post_resp = self._do_post(ACTION_1) action_id = self._get_action_id(post_resp) @@ -80,66 +83,85 @@ def test_get_one_ref(self): try: get_resp = self._do_get_one(action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['ref'], action_ref) + self.assertEqual(get_resp.json["ref"], action_ref) finally: self._do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_and_limit_minus_one(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview') + resp = self.app.get("/v1/actions/views/overview") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, - '/v1/actions/views/overview did not return all actions.') - resp = self.app.get('/v1/actions/views/overview/?limit=-1') + self.assertEqual( + len(resp.json), + 2, + "/v1/actions/views/overview did not return all actions.", + ) + resp = self.app.get("/v1/actions/views/overview/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, - '/v1/actions/views/overview did not return all actions.') + self.assertEqual( + len(resp.json), + 2, + "/v1/actions/views/overview did not return all actions.", + ) finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_negative_limit(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview/?limit=-22', expect_errors=True) + resp = self.app.get( + "/v1/actions/views/overview/?limit=-22", expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_filter_by_name(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview?name=%s' % str('st2.dummy.action2')) + resp = self.app.get( + "/v1/actions/views/overview?name=%s" % str("st2.dummy.action2") + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json[0]['id'], action_2_id, 'Filtering failed') + self.assertEqual(resp.json[0]["id"], action_2_id, "Filtering failed") finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_include_attributes_filter(self): - return super(ActionViewsOverviewControllerTestCase, self) \ - .test_get_all_include_attributes_filter() + return super( + ActionViewsOverviewControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_exclude_attributes_filter(self): - return super(ActionViewsOverviewControllerTestCase, self) \ - .test_get_all_include_attributes_filter() + return super( + ActionViewsOverviewControllerTestCase, self + ).test_get_all_include_attributes_filter() def _insert_mock_models(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) @@ -149,115 +171,141 @@ def _insert_mock_models(self): @staticmethod def _get_action_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def _get_action_ref(resp): - return '.'.join((resp.json['pack'], resp.json['name'])) + return ".".join((resp.json["pack"], resp.json["name"])) @staticmethod def _get_action_name(resp): - return resp.json['name'] + return resp.json["name"] def _do_get_one(self, action_id, expect_errors=False): - return self.app.get('/v1/actions/views/overview/%s' % action_id, - expect_errors=expect_errors) + return self.app.get( + "/v1/actions/views/overview/%s" % action_id, expect_errors=expect_errors + ) def _do_post(self, action, expect_errors=False): - return self.app.post_json('/v1/actions', action, expect_errors=expect_errors) + return self.app.post_json("/v1/actions", action, expect_errors=expect_errors) def _do_delete(self, action_id, expect_errors=False): - return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actions/%s" % action_id, expect_errors=expect_errors + ) class ActionViewsParametersControllerTestCase(FunctionalTest): - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] try: - get_resp = self.app.get('/v1/actions/views/parameters/%s' % action_id) + get_resp = self.app.get("/v1/actions/views/parameters/%s" % action_id) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) class ActionEntryPointViewControllerTestCase(FunctionalTest): - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_id) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_id) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file.yaml')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file.yaml"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_yaml_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.headers['Content-Type'], 'application/x-yaml') + self.assertEqual(get_resp.headers["Content-Type"], "application/x-yaml") finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value=__file__.replace('.pyc', '.py'))) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value=__file__.replace(".pyc", ".py")), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_python_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertIn(get_resp.headers['Content-Type'], ['application/x-python', - 'text/x-python']) + self.assertIn( + get_resp.headers["Content-Type"], + ["application/x-python", "text/x-python"], + ) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/file/does/not/exist')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/file/does/not/exist"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_text_plain_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.headers['Content-Type'], 'text/plain') + self.assertEqual(get_resp.headers["Content-Type"], "text/plain") finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) diff --git a/st2api/tests/unit/controllers/v1/test_actions.py b/st2api/tests/unit/controllers/v1/test_actions.py index c189803d2c..40e973c0a1 100644 --- a/st2api/tests/unit/controllers/v1/test_actions.py +++ b/st2api/tests/unit/controllers/v1/test_actions.py @@ -41,257 +41,259 @@ # ACTION_1: Good action definition. ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, }, - 'tags': [ - {'name': 'tag1', 'value': 'dont-care'}, - {'name': 'tag2', 'value': 'dont-care'} - ] + "tags": [ + {"name": "tag1", "value": "dont-care"}, + {"name": "tag2", "value": "dont-care"}, + ], } # ACTION_2: Good action definition. No content pack. ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } # ACTION_3: No enabled field ACTION_3 = { - 'name': 'st2.dummy.action3', - 'description': 'test description', - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action3", + "description": "test description", + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_4: Enabled field is False ACTION_4 = { - 'name': 'st2.dummy.action4', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action4", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_5: Invalid runner_type ACTION_5 = { - 'name': 'st2.dummy.action5', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'xyzxyz', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action5", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "xyzxyz", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_6: No description field. ACTION_6 = { - 'name': 'st2.dummy.action6', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action6", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_7: id field provided ACTION_7 = { - 'id': 'foobar', - 'name': 'st2.dummy.action7', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "id": "foobar", + "name": "st2.dummy.action7", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_8: id field provided ACTION_8 = { - 'name': 'st2.dummy.action8', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'cmd': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action8", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "cmd": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_9: Parameter dict has fields not part of JSONSchema spec. ACTION_9 = { - 'name': 'st2.dummy.action9', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1', 'dummyfield': True}, # dummyfield is invalid. - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action9", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": { + "type": "string", + "default": "A1", + "dummyfield": True, + }, # dummyfield is invalid. + "b": {"type": "string", "default": "B1"}, + }, } # Same name as ACTION_1. Different pack though. # Ensure that this remains the only action with pack == wolfpack1, # otherwise take care of the test test_get_one_using_pack_parameter ACTION_10 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # Good action with a system pack ACTION_11 = { - 'name': 'st2.dummy.action11', - 'pack': SYSTEM_PACK_NAME, - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action11", + "pack": SYSTEM_PACK_NAME, + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } # Good action inside dummy pack ACTION_12 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, }, - 'tags': [ - {'name': 'tag1', 'value': 'dont-care'}, - {'name': 'tag2', 'value': 'dont-care'} - ] + "tags": [ + {"name": "tag1", "value": "dont-care"}, + {"name": "tag2", "value": "dont-care"}, + ], } # Action with invalid parameter type attribute ACTION_13 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': ['string', 'object'], 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": ["string", "object"], "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } ACTION_14 = { - 'name': 'st2.dummy.action14', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'type': 'string'} - } + "name": "st2.dummy.action14", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"type": "string"}, + }, } ACTION_15 = { - 'name': 'st2.dummy.action15', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'default': True, 'immutable': True} - } + "name": "st2.dummy.action15", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"default": True, "immutable": True}, + }, } ACTION_WITH_NOTIFY = { - 'name': 'st2.dummy.action_notify_test', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'default': True, 'immutable': True} + "name": "st2.dummy.action_notify_test", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"default": True, "immutable": True}, }, - 'notify': { - 'on-complete': { - 'message': 'Woohoo! I completed!!!' - } - } + "notify": {"on-complete": {"message": "Woohoo! I completed!!!"}}, } -class ActionsControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase, - CleanFilesTestCase): - get_all_path = '/v1/actions' +class ActionsControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase, CleanFilesTestCase +): + get_all_path = "/v1/actions" controller_cls = ActionsController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "parameters" register_packs = True to_delete_files = [ - os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1/actions/filea.txt') + os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1/actions/filea.txt") ] - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_id(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) @@ -300,146 +302,169 @@ def test_get_one_using_id(self): self.assertEqual(self.__get_action_id(get_resp), action_id) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_ref(self): - ref = '.'.join([ACTION_1['pack'], ACTION_1['name']]) + ref = ".".join([ACTION_1["pack"], ACTION_1["name"]]) action_id = self.__get_action_id(self.__do_post(ACTION_1)) get_resp = self.__do_get_one(ref) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertEqual(get_resp.json['ref'], ref) + self.assertEqual(get_resp.json["ref"], ref) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_validate_params(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - expected_args = ACTION_1['parameters'] - self.assertEqual(get_resp.json['parameters'], expected_args) + expected_args = ACTION_1["parameters"] + self.assertEqual(get_resp.json["parameters"], expected_args) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_and_with_minus_one(self): - action_1_ref = '.'.join([ACTION_1['pack'], ACTION_1['name']]) + action_1_ref = ".".join([ACTION_1["pack"], ACTION_1["name"]]) action_1_id = self.__get_action_id(self.__do_post(ACTION_1)) action_2_id = self.__get_action_id(self.__do_post(ACTION_2)) - resp = self.app.get('/v1/actions') + resp = self.app.get("/v1/actions") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.") - item = [i for i in resp.json if i['id'] == action_1_id][0] - self.assertEqual(item['ref'], action_1_ref) + item = [i for i in resp.json if i["id"] == action_1_id][0] + self.assertEqual(item["ref"], action_1_ref) - resp = self.app.get('/v1/actions?limit=-1') + resp = self.app.get("/v1/actions?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.") - item = [i for i in resp.json if i['id'] == action_1_id][0] - self.assertEqual(item['ref'], action_1_ref) + item = [i for i in resp.json if i["id"] == action_1_id][0] + self.assertEqual(item["ref"], action_1_ref) self.__do_delete(action_1_id) self.__do_delete(action_2_id) - @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin', - mock.Mock(return_value=False)) + @mock.patch( + "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin", + mock.Mock(return_value=False), + ) def test_get_all_invalid_limit_too_large_none_admin(self): # limit > max_page_size, but user is not admin - resp = self.app.get('/v1/actions?limit=1000', expect_errors=True) + resp = self.app.get("/v1/actions?limit=1000", expect_errors=True) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], 'Limit "1000" specified, maximum value is' - ' "100"') + self.assertEqual( + resp.json["faultstring"], + 'Limit "1000" specified, maximum value is' ' "100"', + ) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/actions?limit=-22', expect_errors=True) + resp = self.app.get("/v1/actions?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') - - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) + + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_include_attributes_filter(self): - return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter() + return super( + ActionsControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_exclude_attributes_filter(self): - return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter() + return super( + ActionsControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_query(self): action_1_id = self.__get_action_id(self.__do_post(ACTION_1)) action_2_id = self.__get_action_id(self.__do_post(ACTION_2)) - resp = self.app.get('/v1/actions?name=%s' % ACTION_1['name']) + resp = self.app.get("/v1/actions?name=%s" % ACTION_1["name"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 1, "/v1/actions did not return all actions.") self.__do_delete(action_1_id) self.__do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_fail(self): - resp = self.app.get('/v1/actions/1', expect_errors=True) + resp = self.app.get("/v1/actions/1", expect_errors=True) self.assertEqual(resp.status_int, 404) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_delete(self): post_resp = self.__do_post(ACTION_1) self.assertEqual(post_resp.status_int, 201) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_action_with_bad_params(self): post_resp = self.__do_post(ACTION_9, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_no_description_field(self): post_resp = self.__do_post(ACTION_6) self.assertEqual(post_resp.status_int, 201) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_no_enable_field(self): post_resp = self.__do_post(ACTION_3) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'enabled', post_resp.body) + self.assertIn(b"enabled", post_resp.body) # If enabled field is not provided it should default to True data = json.loads(post_resp.body) - self.assertDictContainsSubset({'enabled': True}, data) + self.assertDictContainsSubset({"enabled": True}, data) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_false_enable_field(self): post_resp = self.__do_post(ACTION_4) self.assertEqual(post_resp.status_int, 201) data = json.loads(post_resp.body) - self.assertDictContainsSubset({'enabled': False}, data) + self.assertDictContainsSubset({"enabled": False}, data) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_name_unicode_action_already_exists(self): # Verify that exception messages containing unicode characters don't result in internal # server errors action = copy.deepcopy(ACTION_1) # NOTE: We explicitly don't prefix this string value with u"" - action['name'] = 'žactionćšžži💩' + action["name"] = "žactionćšžži💩" # 1. Initial creation post_resp = self.__do_post(action, expect_errors=True) @@ -448,54 +473,64 @@ def test_post_name_unicode_action_already_exists(self): # 2. Action already exists post_resp = self.__do_post(action, expect_errors=True) self.assertEqual(post_resp.status_int, 409) - self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring']) + self.assertIn( + "Tried to save duplicate unique keys", post_resp.json["faultstring"] + ) # 3. Action already exists (this time with unicode type) - action['name'] = u'žactionćšžži💩' + action["name"] = "žactionćšžži💩" post_resp = self.__do_post(action, expect_errors=True) self.assertEqual(post_resp.status_int, 409) - self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring']) + self.assertIn( + "Tried to save duplicate unique keys", post_resp.json["faultstring"] + ) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_parameter_type_is_array_and_invalid(self): post_resp = self.__do_post(ACTION_13, expect_errors=True) self.assertEqual(post_resp.status_int, 400) if six.PY3: - expected_error = b'[\'string\', \'object\'] is not valid under any of the given schemas' + expected_error = ( + b"['string', 'object'] is not valid under any of the given schemas" + ) else: - expected_error = \ - b'[u\'string\', u\'object\'] is not valid under any of the given schemas' + expected_error = ( + b"[u'string', u'object'] is not valid under any of the given schemas" + ) self.assertIn(expected_error, post_resp.body) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_discard_id_field(self): post_resp = self.__do_post(ACTION_7) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'id', post_resp.body) + self.assertIn(b"id", post_resp.body) data = json.loads(post_resp.body) # Verify that user-provided id is discarded. - self.assertNotEquals(data['id'], ACTION_7['id']) + self.assertNotEquals(data["id"], ACTION_7["id"]) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_duplicate(self): action_ids = [] post_resp = self.__do_post(ACTION_1) self.assertEqual(post_resp.status_int, 201) - action_in_db = Action.get_by_name(ACTION_1.get('name')) - self.assertIsNotNone(action_in_db, 'Action must be in db.') + action_in_db = Action.get_by_name(ACTION_1.get("name")) + self.assertIsNotNone(action_in_db, "Action must be in db.") action_ids.append(self.__get_action_id(post_resp)) post_resp = self.__do_post(ACTION_1, expect_errors=True) # Verify name conflict self.assertEqual(post_resp.status_int, 409) - self.assertEqual(post_resp.json['conflict-id'], action_ids[0]) + self.assertEqual(post_resp.json["conflict-id"], action_ids[0]) post_resp = self.__do_post(ACTION_10) action_ids.append(self.__get_action_id(post_resp)) @@ -505,20 +540,16 @@ def test_post_duplicate(self): for i in action_ids: self.__do_delete(i) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_include_files(self): # Verify initial state - pack_db = Pack.get_by_ref(ACTION_12['pack']) - self.assertNotIn('actions/filea.txt', pack_db.files) + pack_db = Pack.get_by_ref(ACTION_12["pack"]) + self.assertNotIn("actions/filea.txt", pack_db.files) action = copy.deepcopy(ACTION_12) - action['data_files'] = [ - { - 'file_path': 'filea.txt', - 'content': 'test content' - } - ] + action["data_files"] = [{"file_path": "filea.txt", "content": "test content"}] post_resp = self.__do_post(action) # Verify file has been written on disk @@ -526,29 +557,30 @@ def test_post_include_files(self): self.assertTrue(os.path.exists(file_path)) # Verify PackDB.files has been updated - pack_db = Pack.get_by_ref(ACTION_12['pack']) - self.assertIn('actions/filea.txt', pack_db.files) + pack_db = Pack.get_by_ref(ACTION_12["pack"]) + self.assertIn("actions/filea.txt", pack_db.files) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_put_delete(self): action = copy.copy(ACTION_1) post_resp = self.__do_post(action) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'id', post_resp.body) + self.assertIn(b"id", post_resp.body) body = json.loads(post_resp.body) - action['id'] = body['id'] - action['description'] = 'some other test description' - pack = action['pack'] - del action['pack'] - self.assertNotIn('pack', action) - put_resp = self.__do_put(action['id'], action) + action["id"] = body["id"] + action["description"] = "some other test description" + pack = action["pack"] + del action["pack"] + self.assertNotIn("pack", action) + put_resp = self.__do_put(action["id"], action) self.assertEqual(put_resp.status_int, 200) - self.assertIn(b'description', put_resp.body) + self.assertIn(b"description", put_resp.body) body = json.loads(put_resp.body) - self.assertEqual(body['description'], action['description']) - self.assertEqual(body['pack'], pack) + self.assertEqual(body["description"], action["description"]) + self.assertEqual(body["pack"], pack) delete_resp = self.__do_delete(self.__get_action_id(post_resp)) self.assertEqual(delete_resp.status_int, 204) @@ -559,94 +591,107 @@ def test_post_invalid_runner_type(self): def test_post_override_runner_param_not_allowed(self): post_resp = self.__do_post(ACTION_14, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - expected = ('The attribute "type" for the runner parameter "sudo" ' - 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.') - self.assertEqual(post_resp.json.get('faultstring'), expected) + expected = ( + 'The attribute "type" for the runner parameter "sudo" ' + 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.' + ) + self.assertEqual(post_resp.json.get("faultstring"), expected) def test_post_override_runner_param_allowed(self): post_resp = self.__do_post(ACTION_15) self.assertEqual(post_resp.status_int, 201) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_delete(self): post_resp = self.__do_post(ACTION_1) del_resp = self.__do_delete(self.__get_action_id(post_resp)) self.assertEqual(del_resp.status_int, 204) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_action_with_tags(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertEqual(get_resp.json['tags'], ACTION_1['tags']) + self.assertEqual(get_resp.json["tags"], ACTION_1["tags"]) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_action_with_notify_update(self): post_resp = self.__do_post(ACTION_WITH_NOTIFY) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertIsNotNone(get_resp.json['notify']['on-complete']) + self.assertIsNotNone(get_resp.json["notify"]["on-complete"]) # Now post the same action with no notify ACTION_WITHOUT_NOTIFY = copy.copy(ACTION_WITH_NOTIFY) - del ACTION_WITHOUT_NOTIFY['notify'] + del ACTION_WITHOUT_NOTIFY["notify"] self.__do_put(action_id, ACTION_WITHOUT_NOTIFY) # Validate that notify section has vanished get_resp = self.__do_get_one(action_id) - self.assertEqual(get_resp.json['notify'], {}) + self.assertEqual(get_resp.json["notify"], {}) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_name_parameter(self): action_id, action_name = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_1), 'name') - get_resp = self.__do_get_actions_by_url_parameter('name', action_name) + self.__do_post(ACTION_1), "name" + ) + get_resp = self.__do_get_actions_by_url_parameter("name", action_name) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['name'], action_name) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["name"], action_name) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_pack_parameter(self): action_id, action_pack = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_10), 'pack') - get_resp = self.__do_get_actions_by_url_parameter('pack', action_pack) + self.__do_post(ACTION_10), "pack" + ) + get_resp = self.__do_get_actions_by_url_parameter("pack", action_pack) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['pack'], action_pack) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["pack"], action_pack) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_tag_parameter(self): action_id, action_tags = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_1), 'tags') - get_resp = self.__do_get_actions_by_url_parameter('tags', action_tags[0]['name']) + self.__do_post(ACTION_1), "tags" + ) + get_resp = self.__do_get_actions_by_url_parameter( + "tags", action_tags[0]["name"] + ) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['tags'], action_tags) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["tags"], action_tags) self.__do_delete(action_id) # TODO: Re-enable those tests after we ensure DB is flushed in setUp # and each test starts in a clean state - @unittest2.skip('Skip because of test polution') + @unittest2.skip("Skip because of test polution") def test_update_action_belonging_to_system_pack(self): post_resp = self.__do_post(ACTION_11) action_id = self.__get_action_id(post_resp) put_resp = self.__do_put(action_id, ACTION_11, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - @unittest2.skip('Skip because of test polution') + @unittest2.skip("Skip because of test polution") def test_delete_action_belonging_to_system_pack(self): post_resp = self.__do_post(ACTION_11) action_id = self.__get_action_id(post_resp) @@ -664,31 +709,37 @@ def _do_delete(self, action_id, expect_errors=False): @staticmethod def __get_action_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def __get_action_name(resp): - return resp.json['name'] + return resp.json["name"] @staticmethod def __get_action_tags(resp): - return resp.json['tags'] + return resp.json["tags"] @staticmethod def __get_action_id_and_additional_attribute(resp, attribute): - return resp.json['id'], resp.json[attribute] + return resp.json["id"], resp.json[attribute] def __do_get_one(self, action_id, expect_errors=False): - return self.app.get('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.get("/v1/actions/%s" % action_id, expect_errors=expect_errors) def __do_get_actions_by_url_parameter(self, filter, value, expect_errors=False): - return self.app.get('/v1/actions?%s=%s' % (filter, value), expect_errors=expect_errors) + return self.app.get( + "/v1/actions?%s=%s" % (filter, value), expect_errors=expect_errors + ) def __do_post(self, action, expect_errors=False): - return self.app.post_json('/v1/actions', action, expect_errors=expect_errors) + return self.app.post_json("/v1/actions", action, expect_errors=expect_errors) def __do_put(self, action_id, action, expect_errors=False): - return self.app.put_json('/v1/actions/%s' % action_id, action, expect_errors=expect_errors) + return self.app.put_json( + "/v1/actions/%s" % action_id, action, expect_errors=expect_errors + ) def __do_delete(self, action_id, expect_errors=False): - return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actions/%s" % action_id, expect_errors=expect_errors + ) diff --git a/st2api/tests/unit/controllers/v1/test_alias_execution.py b/st2api/tests/unit/controllers/v1/test_alias_execution.py index e7b1827f31..9806a46864 100644 --- a/st2api/tests/unit/controllers/v1/test_alias_execution.py +++ b/st2api/tests/unit/controllers/v1/test_alias_execution.py @@ -24,29 +24,32 @@ from st2tests.fixturesloader import FixturesLoader from st2tests.api import FunctionalTest -FIXTURES_PACK = 'aliases' +FIXTURES_PACK = "aliases" TEST_MODELS = { - 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml', - 'alias_with_immutable_list_param.yaml', - 'alias_with_immutable_list_param_str_cast.yaml', - 'alias4.yaml', 'alias5.yaml', 'alias_fixes1.yaml', 'alias_fixes2.yaml', - 'alias_match_multiple.yaml'], - 'actions': ['action1.yaml', 'action2.yaml', 'action3.yaml', 'action4.yaml'], - 'runners': ['runner1.yaml'] + "aliases": [ + "alias1.yaml", + "alias2.yaml", + "alias_with_undefined_jinja_in_ack_format.yaml", + "alias_with_immutable_list_param.yaml", + "alias_with_immutable_list_param_str_cast.yaml", + "alias4.yaml", + "alias5.yaml", + "alias_fixes1.yaml", + "alias_fixes2.yaml", + "alias_match_multiple.yaml", + ], + "actions": ["action1.yaml", "action2.yaml", "action3.yaml", "action4.yaml"], + "runners": ["runner1.yaml"], } -TEST_LOAD_MODELS = { - 'aliases': ['alias3.yaml'] -} +TEST_LOAD_MODELS = {"aliases": ["alias3.yaml"]} -EXECUTION = ActionExecutionDB(id='54e657d60640fd16887d6855', - status=LIVEACTION_STATUS_SUCCEEDED, - result='') +EXECUTION = ActionExecutionDB( + id="54e657d60640fd16887d6855", status=LIVEACTION_STATUS_SUCCEEDED, result="" +) -__all__ = [ - 'AliasExecutionTestCase' -] +__all__ = ["AliasExecutionTestCase"] class AliasExecutionTestCase(FunctionalTest): @@ -59,193 +62,217 @@ class AliasExecutionTestCase(FunctionalTest): @classmethod def setUpClass(cls): super(AliasExecutionTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - - cls.runner1 = cls.models['runners']['runner1.yaml'] - cls.action1 = cls.models['actions']['action1.yaml'] - cls.alias1 = cls.models['aliases']['alias1.yaml'] - cls.alias2 = cls.models['aliases']['alias2.yaml'] - cls.alias4 = cls.models['aliases']['alias4.yaml'] - cls.alias5 = cls.models['aliases']['alias5.yaml'] - cls.alias_with_undefined_jinja_in_ack_format = \ - cls.models['aliases']['alias_with_undefined_jinja_in_ack_format.yaml'] - - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + + cls.runner1 = cls.models["runners"]["runner1.yaml"] + cls.action1 = cls.models["actions"]["action1.yaml"] + cls.alias1 = cls.models["aliases"]["alias1.yaml"] + cls.alias2 = cls.models["aliases"]["alias2.yaml"] + cls.alias4 = cls.models["aliases"]["alias4.yaml"] + cls.alias5 = cls.models["aliases"]["alias5.yaml"] + cls.alias_with_undefined_jinja_in_ack_format = cls.models["aliases"][ + "alias_with_undefined_jinja_in_ack_format.yaml" + ] + + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_basic_execution(self, request): command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.' post_resp = self._do_post(alias_execution=self.alias1, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param2': 'value2 value3'} + expected_parameters = {"param1": "value1", "param2": "value2 value3"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_basic_execution_with_immutable_parameters(self, request): - command = 'lorem ipsum' + command = "lorem ipsum" post_resp = self._do_post(alias_execution=self.alias5, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param2': 'value2'} + expected_parameters = {"param1": "value1", "param2": "value2"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_invalid_format_string_referenced_in_request(self, request): command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.' - format_str = 'some invalid not supported string' - post_resp = self._do_post(alias_execution=self.alias1, command=command, - format_str=format_str, expect_errors=True) + format_str = "some invalid not supported string" + post_resp = self._do_post( + alias_execution=self.alias1, + command=command, + format_str=format_str, + expect_errors=True, + ) self.assertEqual(post_resp.status_int, 400) - expected_msg = ('Format string "some invalid not supported string" is ' - 'not available on the alias "alias1"') - self.assertIn(expected_msg, post_resp.json['faultstring']) + expected_msg = ( + 'Format string "some invalid not supported string" is ' + 'not available on the alias "alias1"' + ) + self.assertIn(expected_msg, post_resp.json["faultstring"]) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_execution_with_array_type_single_value(self, request): - command = 'Lorem ipsum value1 dolor sit value2 amet.' + command = "Lorem ipsum value1 dolor sit value2 amet." self._do_post(alias_execution=self.alias2, command=command) - expected_parameters = {'param1': 'value1', 'param3': ['value2']} + expected_parameters = {"param1": "value1", "param3": ["value2"]} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_execution_with_array_type_multi_value(self, request): command = 'Lorem ipsum value1 dolor sit "value2, value3" amet.' post_resp = self._do_post(alias_execution=self.alias2, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param3': ['value2', 'value3']} + expected_parameters = {"param1": "value1", "param3": ["value2", "value3"]} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_invalid_jinja_var_in_ack_format(self, request): - command = 'run date on localhost' + command = "run date on localhost" # print(self.alias_with_undefined_jinja_in_ack_format) post_resp = self._do_post( alias_execution=self.alias_with_undefined_jinja_in_ack_format, command=command, - expect_errors=False + expect_errors=False, ) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'cmd': 'date', 'hosts': 'localhost'} + expected_parameters = {"cmd": "date", "hosts": "localhost"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) self.assertEqual( - post_resp.json['message'], - 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined' + post_resp.json["message"], + 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined', ) - @mock.patch.object(action_service, 'request') + @mock.patch.object(action_service, "request") def test_execution_secret_parameter(self, request): - execution = ActionExecutionDB(id='54e657d60640fd16887d6855', - status=LIVEACTION_STATUS_SUCCEEDED, - action={'parameters': self.action1.parameters}, - runner={'runner_parameters': self.runner1.runner_parameters}, - parameters={ - 'param4': SUPER_SECRET_PARAMETER - }, - result='') + execution = ActionExecutionDB( + id="54e657d60640fd16887d6855", + status=LIVEACTION_STATUS_SUCCEEDED, + action={"parameters": self.action1.parameters}, + runner={"runner_parameters": self.runner1.runner_parameters}, + parameters={"param4": SUPER_SECRET_PARAMETER}, + result="", + ) request.return_value = (None, execution) - command = 'Lorem ipsum value1 dolor sit ' + SUPER_SECRET_PARAMETER + ' amet.' + command = "Lorem ipsum value1 dolor sit " + SUPER_SECRET_PARAMETER + " amet." post_resp = self._do_post(alias_execution=self.alias4, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param4': SUPER_SECRET_PARAMETER} + expected_parameters = {"param1": "value1", "param4": SUPER_SECRET_PARAMETER} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - post_resp = self._do_post(alias_execution=self.alias4, command=command, show_secrets=True, - expect_errors=True) + post_resp = self._do_post( + alias_execution=self.alias4, + command=command, + show_secrets=True, + expect_errors=True, + ) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['execution']['parameters']['param4'], - SUPER_SECRET_PARAMETER) + self.assertEqual( + post_resp.json["execution"]["parameters"]["param4"], SUPER_SECRET_PARAMETER + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_doesnt_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command doesnt match any patterns data = copy.deepcopy(base_data) - data['command'] = 'hello donny' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "hello donny" + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'hello donny' matched no patterns") + self.assertEqual( + str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns" + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_many(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches more than one pattern data = copy.deepcopy(base_data) - data['command'] = 'Lorem ipsum banana dolor sit pineapple amet.' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "Lorem ipsum banana dolor sit pineapple amet." + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'Lorem ipsum banana dolor sit pineapple amet.' " - "matched more than 1 pattern") + self.assertEqual( + str(resp.json["faultstring"]), + "Command 'Lorem ipsum banana dolor sit pineapple amet.' " + "matched more than 1 pattern", + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_one(self, mock_request): base_data = { - 'source_channel': 'chat-channel', - 'notification_route': 'hubot', - 'user': 'chat-user', + "source_channel": "chat-channel", + "notification_route": "hubot", + "user": "chat-user", } # Command matches - should result in action execution data = copy.deepcopy(base_data) - data['command'] = 'run date on localhost' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data) + data["command"] = "run date on localhost" + resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) - self.assertEqual(len(resp.json['results']), 1) - self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status']) + self.assertEqual(len(resp.json["results"]), 1) + self.assertEqual( + resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][0]["execution"]["status"], EXECUTION["status"] + ) - expected_parameters = {'cmd': 'date', 'hosts': 'localhost'} + expected_parameters = {"cmd": "date", "hosts": "localhost"} self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters) # Also check for source_channel - see # https://github.com/StackStorm/st2/issues/4650 actual_context = mock_request.call_args[0][0].context - self.assertIn('source_channel', mock_request.call_args[0][0].context.keys()) - self.assertEqual(actual_context['source_channel'], 'chat-channel') - self.assertEqual(actual_context['api_user'], 'chat-user') - self.assertEqual(actual_context['user'], 'stanley') + self.assertIn("source_channel", mock_request.call_args[0][0].context.keys()) + self.assertEqual(actual_context["source_channel"], "chat-channel") + self.assertEqual(actual_context["api_user"], "chat-user") + self.assertEqual(actual_context["user"], "stanley") - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_one_multiple_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches multiple times - should result in multiple action execution data = copy.deepcopy(base_data) - data['command'] = ('JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which ' - 'is a duplicate of DRSEUSS-12') - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data) + data["command"] = ( + "JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which " + "is a duplicate of DRSEUSS-12" + ) + resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) - self.assertEqual(len(resp.json['results']), 2) - self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status']) - self.assertEqual(resp.json['results'][1]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][1]['execution']['status'], EXECUTION['status']) + self.assertEqual(len(resp.json["results"]), 2) + self.assertEqual( + resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][0]["execution"]["status"], EXECUTION["status"] + ) + self.assertEqual( + resp.json["results"][1]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][1]["execution"]["status"], EXECUTION["status"] + ) # The mock object only stores the parameters of the _last_ time it was called, so that's # what we assert on. Luckily re.finditer() processes groups in order, so if this was the @@ -255,34 +282,39 @@ def test_match_and_execute_matches_one_multiple_match(self, mock_request): # # We've also already checked the results array # - expected_parameters = {'issue_key': 'DRSEUSS-12'} + expected_parameters = {"issue_key": "DRSEUSS-12"} self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_many_multiple_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches multiple times - should result in multiple action execution data = copy.deepcopy(base_data) - data['command'] = 'JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12" + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command '{command}' " - "matched more than 1 (multi) pattern".format(command=data['command'])) + self.assertEqual( + str(resp.json["faultstring"]), + "Command '{command}' " + "matched more than 1 (multi) pattern".format(command=data["command"]), + ) def test_match_and_execute_list_action_param_str_cast_to_list(self): data = { - 'command': 'test alias list param str cast', - 'source_channel': 'hubot', - 'user': 'foo', + "command": "test alias list param str cast", + "source_channel": "hubot", + "user": "foo", } - resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True) + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) # Param is a comma delimited string - our custom cast function should cast it to a list. # I assume that was done to make specifying complex params in chat easier. @@ -300,15 +332,19 @@ def test_match_and_execute_list_action_param_str_cast_to_list(self): self.assertEqual(live_action["parameters"]["array_param"][1], "two") self.assertEqual(live_action["parameters"]["array_param"][2], "three") self.assertEqual(live_action["parameters"]["array_param"][3], "four") - self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], str)) + self.assertTrue( + isinstance(action_alias["immutable_parameters"]["array_param"], str) + ) def test_match_and_execute_list_action_param_already_a_list(self): data = { - 'command': 'test alias foo', - 'source_channel': 'hubot', - 'user': 'foo', + "command": "test alias foo", + "source_channel": "hubot", + "user": "foo", } - resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True) + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) # immutable_param is already a list - verify no casting is performed self.assertEqual(resp.status_int, 201) @@ -323,37 +359,53 @@ def test_match_and_execute_list_action_param_already_a_list(self): self.assertEqual(live_action["parameters"]["array_param"][0]["key2"], "two") self.assertEqual(live_action["parameters"]["array_param"][1]["key3"], "three") self.assertEqual(live_action["parameters"]["array_param"][1]["key4"], "four") - self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], list)) + self.assertTrue( + isinstance(action_alias["immutable_parameters"]["array_param"], list) + ) def test_match_and_execute_success(self): data = { - 'command': 'run whoami on localhost1', - 'source_channel': 'hubot', - 'user': "user", + "command": "run whoami on localhost1", + "source_channel": "hubot", + "user": "user", } resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) self.assertEqual(len(resp.json["results"]), 1) - self.assertTrue(resp.json["results"][0]["actionalias"]["ref"], - "aliases.alias_with_undefined_jinja_in_ack_format") - - def _do_post(self, alias_execution, command, format_str=None, expect_errors=False, - show_secrets=False): - if (isinstance(alias_execution.formats[0], dict) and - alias_execution.formats[0].get('representation')): - representation = alias_execution.formats[0].get('representation')[0] + self.assertTrue( + resp.json["results"][0]["actionalias"]["ref"], + "aliases.alias_with_undefined_jinja_in_ack_format", + ) + + def _do_post( + self, + alias_execution, + command, + format_str=None, + expect_errors=False, + show_secrets=False, + ): + if isinstance(alias_execution.formats[0], dict) and alias_execution.formats[ + 0 + ].get("representation"): + representation = alias_execution.formats[0].get("representation")[0] else: representation = alias_execution.formats[0] if not format_str: format_str = representation - execution = {'name': alias_execution.name, - 'format': format_str, - 'command': command, - 'user': 'stanley', - 'source_channel': 'test', - 'notification_route': 'test'} - url = show_secrets and '/v1/aliasexecution?show_secrets=true' or '/v1/aliasexecution' - return self.app.post_json(url, execution, - expect_errors=expect_errors) + execution = { + "name": alias_execution.name, + "format": format_str, + "command": command, + "user": "stanley", + "source_channel": "test", + "notification_route": "test", + } + url = ( + show_secrets + and "/v1/aliasexecution?show_secrets=true" + or "/v1/aliasexecution" + ) + return self.app.post_json(url, execution, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_auth.py b/st2api/tests/unit/controllers/v1/test_auth.py index fb5a203929..d6f3602c3c 100644 --- a/st2api/tests/unit/controllers/v1/test_auth.py +++ b/st2api/tests/unit/controllers/v1/test_auth.py @@ -27,7 +27,7 @@ from st2tests.fixturesloader import FixturesLoader OBJ_ID = bson.ObjectId() -USER = 'stanley' +USER = "stanley" USER_DB = UserDB(name=USER) TOKEN = uuid.uuid4().hex NOW = date_utils.get_datetime_utc_now() @@ -40,67 +40,84 @@ class TestTokenBasedAuth(FunctionalTest): enable_auth = True @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_headers(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_query_params(self): - response = self.app.get('/v1/actions?x-auth-token=%s' % (TOKEN), expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions?x-auth-token=%s" % (TOKEN), expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_cookies(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - with mock.patch.object(self.app.cookiejar, 'clear', return_value=None): - response = self.app.get('/v1/actions', expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + with mock.patch.object(self.app.cookiejar, "clear", return_value=None): + response = self.app.get("/v1/actions", expect_errors=False) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST))) + Token, + "get", + mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST)), + ) def test_token_expired(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - @mock.patch.object( - Token, 'get', mock.MagicMock(side_effect=TokenNotFoundError())) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=TokenNotFoundError())) def test_token_not_found(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) def test_token_not_provided(self): - response = self.app.get('/v1/actions', expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get("/v1/actions", expect_errors=True) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -TEST_MODELS = { - 'apikeys': ['apikey1.yaml', 'apikey_disabled.yaml'] -} +TEST_MODELS = {"apikeys": ["apikey1.yaml", "apikey_disabled.yaml"]} # Hardcoded keys matching the fixtures. Lazy way to workound one-way hash and still use fixtures. KEY1_KEY = "1234" @@ -117,62 +134,83 @@ class TestApiKeyBasedAuth(FunctionalTest): @classmethod def setUpClass(cls): super(TestApiKeyBasedAuth, cls).setUpClass() - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.apikey1 = models['apikeys']['apikey1.yaml'] - cls.apikey_disabled = models['apikeys']['apikey_disabled.yaml'] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.apikey1 = models["apikeys"]["apikey1.yaml"] + cls.apikey_disabled = models["apikeys"]["apikey_disabled.yaml"] - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_headers(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_query_params(self): - response = self.app.get('/v1/actions?st2-api-key=%s' % (KEY1_KEY), expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions?st2-api-key=%s" % (KEY1_KEY), expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_cookies(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - with mock.patch.object(self.app.cookiejar, 'clear', return_value=None): - response = self.app.get('/v1/actions', expect_errors=True) + with mock.patch.object(self.app.cookiejar, "clear", return_value=None): + response = self.app.get("/v1/actions", expect_errors=True) self.assertEqual(response.status_int, 401) - self.assertEqual(response.json_body['faultstring'], - 'Unauthorized - One of Token or API key required.') + self.assertEqual( + response.json_body["faultstring"], + "Unauthorized - One of Token or API key required.", + ) def test_apikey_disabled(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': DISABLED_KEY}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": DISABLED_KEY}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - self.assertEqual(response.json_body['faultstring'], 'Unauthorized - API key is disabled.') + self.assertEqual( + response.json_body["faultstring"], "Unauthorized - API key is disabled." + ) def test_apikey_not_found(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': 'UNKNOWN'}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": "UNKNOWN"}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - self.assertRegexpMatches(response.json_body['faultstring'], - '^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$') + self.assertRegexpMatches( + response.json_body["faultstring"], + "^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$", + ) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) @mock.patch.object( - ApiKey, 'get', - mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + ApiKey, + "get", + mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True)), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_multiple_auth_sources(self): - response = self.app.get('/v1/actions', - headers={'X-Auth-Token': TOKEN, 'St2-Api-key': KEY1_KEY}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", + headers={"X-Auth-Token": TOKEN, "St2-Api-key": KEY1_KEY}, + expect_errors=True, + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) diff --git a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py index c172b22445..bf76d41276 100644 --- a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py +++ b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py @@ -22,11 +22,16 @@ from st2tests.fixturesloader import FixturesLoader from st2tests.api import FunctionalTest -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'apikeys': ['apikey1.yaml', 'apikey2.yaml', 'apikey3.yaml', 'apikey_disabled.yaml', - 'apikey_malformed.yaml'] + "apikeys": [ + "apikey1.yaml", + "apikey2.yaml", + "apikey3.yaml", + "apikey_disabled.yaml", + "apikey_malformed.yaml", + ] } # Hardcoded keys matching the fixtures. Lazy way to workound one-way hash and still use fixtures. @@ -45,205 +50,239 @@ class TestApiKeyController(FunctionalTest): def setUpClass(cls): super(TestApiKeyController, cls).setUpClass() - cfg.CONF.set_override(name='mask_secrets', override=True, group='api') - cfg.CONF.set_override(name='mask_secrets', override=True, group='log') + cfg.CONF.set_override(name="mask_secrets", override=True, group="api") + cfg.CONF.set_override(name="mask_secrets", override=True, group="log") - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.apikey1 = models['apikeys']['apikey1.yaml'] - cls.apikey2 = models['apikeys']['apikey2.yaml'] - cls.apikey3 = models['apikeys']['apikey3.yaml'] - cls.apikey4 = models['apikeys']['apikey_disabled.yaml'] - cls.apikey5 = models['apikeys']['apikey_malformed.yaml'] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.apikey1 = models["apikeys"]["apikey1.yaml"] + cls.apikey2 = models["apikeys"]["apikey2.yaml"] + cls.apikey3 = models["apikeys"]["apikey3.yaml"] + cls.apikey4 = models["apikeys"]["apikey_disabled.yaml"] + cls.apikey5 = models["apikeys"]["apikey_malformed.yaml"] def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/apikeys') + resp = self.app.get("/v1/apikeys") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "50") - self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.') - - retrieved_ids = [apikey['id'] for apikey in resp.json] - self.assertEqual(retrieved_ids, - [str(self.apikey1.id), str(self.apikey2.id), str(self.apikey3.id), - str(self.apikey4.id), str(self.apikey5.id)], - 'Incorrect api keys retrieved.') - - resp = self.app.get('/v1/apikeys/?limit=-1') + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "50") + self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.") + + retrieved_ids = [apikey["id"] for apikey in resp.json] + self.assertEqual( + retrieved_ids, + [ + str(self.apikey1.id), + str(self.apikey2.id), + str(self.apikey3.id), + str(self.apikey4.id), + str(self.apikey5.id), + ], + "Incorrect api keys retrieved.", + ) + + resp = self.app.get("/v1/apikeys/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.') + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.") def test_get_all_with_pagnination_with_offset_and_limit(self): - resp = self.app.get('/v1/apikeys?offset=2&limit=1') + resp = self.app.get("/v1/apikeys?offset=2&limit=1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "1") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "1") self.assertEqual(len(resp.json), 1) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey3.id)]) def test_get_all_with_pagnination_with_only_offset(self): - resp = self.app.get('/v1/apikeys?offset=3') + resp = self.app.get("/v1/apikeys?offset=3") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "50") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "50") self.assertEqual(len(resp.json), 2) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey4.id), str(self.apikey5.id)]) def test_get_all_with_pagnination_with_only_limit(self): - resp = self.app.get('/v1/apikeys?limit=2') + resp = self.app.get("/v1/apikeys?limit=2") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "2") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "2") self.assertEqual(len(resp.json), 2) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey1.id), str(self.apikey2.id)]) - @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin', - mock.Mock(return_value=False)) + @mock.patch( + "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin", + mock.Mock(return_value=False), + ) def test_get_all_invalid_limit_too_large_none_admin(self): # limit > max_page_size, but user is not admin - resp = self.app.get('/v1/apikeys?offset=2&limit=1000', expect_errors=True) + resp = self.app.get("/v1/apikeys?offset=2&limit=1000", expect_errors=True) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], - 'Limit "1000" specified, maximum value is "100"') + self.assertEqual( + resp.json["faultstring"], 'Limit "1000" specified, maximum value is "100"' + ) def test_get_all_invalid_limit_negative_integer(self): - resp = self.app.get('/v1/apikeys?offset=2&limit=-22', expect_errors=True) + resp = self.app.get("/v1/apikeys?offset=2&limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_invalid_offset_too_large(self): - offset = '2141564789454123457895412237483648' - resp = self.app.get('/v1/apikeys?offset=%s&limit=1' % (offset), expect_errors=True) + offset = "2141564789454123457895412237483648" + resp = self.app.get( + "/v1/apikeys?offset=%s&limit=1" % (offset), expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Offset "%s" specified is more than 32 bit int' % (offset)) + self.assertEqual( + resp.json["faultstring"], + 'Offset "%s" specified is more than 32 bit int' % (offset), + ) def test_get_one_by_id(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey1.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) def test_get_one_by_key(self): # key1 - resp = self.app.get('/v1/apikeys/%s' % KEY1_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY1_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey1.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) # key2 - resp = self.app.get('/v1/apikeys/%s' % KEY2_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY2_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey2.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey2.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) # key3 - resp = self.app.get('/v1/apikeys/%s' % KEY3_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY3_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey3.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey3.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) def test_get_show_secrets(self): - resp = self.app.get('/v1/apikeys?show_secrets=True', expect_errors=True) + resp = self.app.get("/v1/apikeys?show_secrets=True", expect_errors=True) self.assertEqual(resp.status_int, 200) for key in resp.json: - self.assertNotEqual(key['key_hash'], MASKED_ATTRIBUTE_VALUE) - self.assertNotEqual(key['uid'], MASKED_ATTRIBUTE_VALUE) + self.assertNotEqual(key["key_hash"], MASKED_ATTRIBUTE_VALUE) + self.assertNotEqual(key["uid"], MASKED_ATTRIBUTE_VALUE) def test_post_delete_key(self): - api_key = { - 'user': 'herge' - } - resp1 = self.app.post_json('/v1/apikeys', api_key) + api_key = {"user": "herge"} + resp1 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp1.status_int, 201) - self.assertTrue(resp1.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp1.json['key'], MASKED_ATTRIBUTE_VALUE, - 'Key should not be masked.') + self.assertTrue(resp1.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp1.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) # should lead to creation of another key - resp2 = self.app.post_json('/v1/apikeys', api_key) + resp2 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp2.status_int, 201) - self.assertTrue(resp2.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp2.json['key'], MASKED_ATTRIBUTE_VALUE, 'Key should not be masked.') - self.assertNotEqual(resp1.json['key'], resp2.json['key'], 'Should be different') + self.assertTrue(resp2.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp2.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) + self.assertNotEqual(resp1.json["key"], resp2.json["key"], "Should be different") - resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id']) + resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"]) self.assertEqual(resp.status_int, 204) - resp = self.app.delete('/v1/apikeys/%s' % resp2.json['key']) + resp = self.app.delete("/v1/apikeys/%s" % resp2.json["key"]) self.assertEqual(resp.status_int, 204) # With auth disabled, use system_user - resp3 = self.app.post_json('/v1/apikeys', {}) + resp3 = self.app.post_json("/v1/apikeys", {}) self.assertEqual(resp3.status_int, 201) - self.assertTrue(resp3.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp3.json['key'], MASKED_ATTRIBUTE_VALUE, - 'Key should not be masked.') - self.assertTrue(resp3.json['user'], cfg.CONF.system_user.user) + self.assertTrue(resp3.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp3.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) + self.assertTrue(resp3.json["user"], cfg.CONF.system_user.user) def test_post_delete_same_key_hash(self): api_key = { - 'id': '5c5dbb576cb8de06a2d79a4d', - 'user': 'herge', - 'key_hash': 'ABCDE' + "id": "5c5dbb576cb8de06a2d79a4d", + "user": "herge", + "key_hash": "ABCDE", } - resp1 = self.app.post_json('/v1/apikeys', api_key) + resp1 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp1.status_int, 201) - self.assertEqual(resp1.json['key'], None, 'Key should be None.') + self.assertEqual(resp1.json["key"], None, "Key should be None.") # drop into the DB since API will be masking this value. - api_key_db = ApiKey.get_by_id(resp1.json['id']) + api_key_db = ApiKey.get_by_id(resp1.json["id"]) - self.assertEqual(resp1.json['id'], api_key['id'], 'PK ID of created API should match.') - self.assertEqual(api_key_db.key_hash, api_key['key_hash'], 'Key_hash should match.') - self.assertEqual(api_key_db.user, api_key['user'], 'User should match.') + self.assertEqual( + resp1.json["id"], api_key["id"], "PK ID of created API should match." + ) + self.assertEqual( + api_key_db.key_hash, api_key["key_hash"], "Key_hash should match." + ) + self.assertEqual(api_key_db.user, api_key["user"], "User should match.") - resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id']) + resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"]) self.assertEqual(resp.status_int, 204) def test_put_api_key(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) update_input = resp.json - update_input['enabled'] = not update_input['enabled'] - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["enabled"] = not update_input["enabled"] + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['enabled'], not resp.json['enabled']) + self.assertEqual(put_resp.json["enabled"], not resp.json["enabled"]) update_input = put_resp.json - update_input['enabled'] = not update_input['enabled'] - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["enabled"] = not update_input["enabled"] + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['enabled'], resp.json['enabled']) + self.assertEqual(put_resp.json["enabled"], resp.json["enabled"]) def test_put_api_key_fail(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) update_input = resp.json - update_input['key_hash'] = '1' - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["key_hash"] = "1" + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 400) - self.assertTrue(put_resp.json['faultstring']) + self.assertTrue(put_resp.json["faultstring"]) def test_post_no_user_fail(self): - self.app.post_json('/v1/apikeys', {}, expect_errors=True) + self.app.post_json("/v1/apikeys", {}, expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_base.py b/st2api/tests/unit/controllers/v1/test_base.py index fa8b4f1c92..cbfe3e54c2 100644 --- a/st2api/tests/unit/controllers/v1/test_base.py +++ b/st2api/tests/unit/controllers/v1/test_base.py @@ -19,77 +19,79 @@ class TestBase(FunctionalTest): def test_defaults(self): - response = self.app.get('/') + response = self.app.get("/") self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://127.0.0.1:3000') - self.assertEqual(response.headers['Access-Control-Allow-Methods'], - 'GET,POST,PUT,DELETE,OPTIONS') - self.assertEqual(response.headers['Access-Control-Allow-Headers'], - 'Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID') - self.assertEqual(response.headers['Access-Control-Expose-Headers'], - 'Content-Type,X-Limit,X-Total-Count,X-Request-ID') + self.assertEqual( + response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000" + ) + self.assertEqual( + response.headers["Access-Control-Allow-Methods"], + "GET,POST,PUT,DELETE,OPTIONS", + ) + self.assertEqual( + response.headers["Access-Control-Allow-Headers"], + "Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID", + ) + self.assertEqual( + response.headers["Access-Control-Expose-Headers"], + "Content-Type,X-Limit,X-Total-Count,X-Request-ID", + ) def test_origin(self): - response = self.app.get('/', headers={ - 'origin': 'http://127.0.0.1:3000' - }) + response = self.app.get("/", headers={"origin": "http://127.0.0.1:3000"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000" + ) def test_additional_origin(self): - response = self.app.get('/', headers={ - 'origin': 'http://dev' - }) + response = self.app.get("/", headers={"origin": "http://dev"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://dev') + self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://dev") def test_wrong_origin(self): # Invalid origin (not specified in the config), we return first allowed origin specified # in the config - response = self.app.get('/', headers={ - 'origin': 'http://xss' - }) + response = self.app.get("/", headers={"origin": "http://xss"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers.get("Access-Control-Allow-Origin"), "http://127.0.0.1:3000" + ) invalid_origins = [ - 'http://', - 'https://', - 'https://www.example.com', - 'null', - '*' + "http://", + "https://", + "https://www.example.com", + "null", + "*", ] for origin in invalid_origins: - response = self.app.get('/', headers={ - 'origin': origin - }) + response = self.app.get("/", headers={"origin": origin}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers.get("Access-Control-Allow-Origin"), + "http://127.0.0.1:3000", + ) def test_wildcard_origin(self): try: - cfg.CONF.set_override('allow_origin', ['*'], 'api') - response = self.app.get('/', headers={ - 'origin': 'http://xss' - }) + cfg.CONF.set_override("allow_origin", ["*"], "api") + response = self.app.get("/", headers={"origin": "http://xss"}) finally: - cfg.CONF.clear_override('allow_origin', 'api') + cfg.CONF.clear_override("allow_origin", "api") self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://xss') + self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://xss") def test_valid_status_code_is_returned_on_invalid_path(self): # TypeError: get_all() takes exactly 1 argument (2 given) - resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run', expect_errors=True) + resp = self.app.get( + "/v1/executions/577f775b0640fd1451f2030b/re_run", expect_errors=True + ) self.assertEqual(resp.status_int, 404) # get_one() takes exactly 2 arguments (4 given) - resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run/a/b', - expect_errors=True) + resp = self.app.get( + "/v1/executions/577f775b0640fd1451f2030b/re_run/a/b", expect_errors=True + ) self.assertEqual(resp.status_int, 404) diff --git a/st2api/tests/unit/controllers/v1/test_executions.py b/st2api/tests/unit/controllers/v1/test_executions.py index 57dad1f9f3..5a59f6aab5 100644 --- a/st2api/tests/unit/controllers/v1/test_executions.py +++ b/st2api/tests/unit/controllers/v1/test_executions.py @@ -55,324 +55,286 @@ from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase __all__ = [ - 'ActionExecutionControllerTestCase', - 'ActionExecutionOutputControllerTestCase' + "ActionExecutionControllerTestCase", + "ActionExecutionOutputControllerTestCase", ] ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "remote-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.sh', - 'pack': 'familypack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'c': { - 'type': 'object', - 'properties': { - 'c1': { - 'type': 'string' - } - } - }, - 'd': { - 'type': 'boolean', - 'default': False - } - } + "name": "st2.dummy.action2", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/action2.sh", + "pack": "familypack", + "runner_type": "remote-shell-cmd", + "parameters": { + "c": {"type": "object", "properties": {"c1": {"type": "string"}}}, + "d": {"type": "boolean", "default": False}, + }, } ACTION_3 = { - 'name': 'st2.dummy.action3', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/action3.sh', - 'pack': 'wolfpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'e': {}, - 'f': {} - } + "name": "st2.dummy.action3", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/action3.sh", + "pack": "wolfpack", + "runner_type": "remote-shell-cmd", + "parameters": {"e": {}, "f": {}}, } ACTION_4 = { - 'name': 'st2.dummy.action4', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/workflows/action4.yaml', - 'pack': 'starterpack', - 'runner_type': 'orquesta', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - } - } + "name": "st2.dummy.action4", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/workflows/action4.yaml", + "pack": "starterpack", + "runner_type": "orquesta", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + }, } ACTION_INQUIRY = { - 'name': 'st2.dummy.ask', - 'description': 'another test description', - 'enabled': True, - 'pack': 'wolfpack', - 'runner_type': 'inquirer', + "name": "st2.dummy.ask", + "description": "another test description", + "enabled": True, + "pack": "wolfpack", + "runner_type": "inquirer", } ACTION_DEFAULT_TEMPLATE = { - 'name': 'st2.dummy.default_template', - 'description': 'An action that uses a jinja template as a default value for a parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'intparam': { - 'type': 'integer', - 'default': '{{ st2kv.system.test_int | int }}' - } - } + "name": "st2.dummy.default_template", + "description": "An action that uses a jinja template as a default value for a parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "intparam": {"type": "integer", "default": "{{ st2kv.system.test_int | int }}"} + }, } ACTION_DEFAULT_ENCRYPT = { - 'name': 'st2.dummy.default_encrypted_value', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}' + "name": "st2.dummy.default_encrypted_value", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}' - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + }, + }, } ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS = { - 'name': 'st2.dummy.default_encrypted_value_secret_param', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}', - 'secret': True + "name": "st2.dummy.default_encrypted_value_secret_param", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", + "secret": True, }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}', - 'secret': True - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + "secret": True, + }, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } LIVE_ACTION_2 = { - 'action': 'familypack.st2.dummy.action2', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'ls -l' - } + "action": "familypack.st2.dummy.action2", + "parameters": {"hosts": "localhost", "cmd": "ls -l"}, } LIVE_ACTION_3 = { - 'action': 'wolfpack.st2.dummy.action3', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'ls -l', - 'e': 'abcde', - 'f': 12345 - } + "action": "wolfpack.st2.dummy.action3", + "parameters": {"hosts": "localhost", "cmd": "ls -l", "e": "abcde", "f": 12345}, } LIVE_ACTION_4 = { - 'action': 'starterpack.st2.dummy.action4', + "action": "starterpack.st2.dummy.action4", } LIVE_ACTION_DELAY = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, }, - 'delay': 100 + "delay": 100, } LIVE_ACTION_INQUIRY = { - 'parameters': { - 'route': 'developers', - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': u'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "parameters": { + "route": "developers", + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } - }, - 'action': 'wolfpack.st2.dummy.ask', - 'result': { - 'users': [], - 'roles': [], - 'route': 'developers', - 'ttl': 1440, - 'response': { - 'secondfactor': 'supersecretvalue' + }, }, - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': 'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + }, + "action": "wolfpack.st2.dummy.ask", + "result": { + "users": [], + "roles": [], + "route": "developers", + "ttl": 1440, + "response": {"secondfactor": "supersecretvalue"}, + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } - } + }, + }, + }, } LIVE_ACTION_WITH_SECRET_PARAM = { - 'parameters': { + "parameters": { # action params - 'a': 'param a', - 'd': 'secretpassword1', - + "a": "param a", + "d": "secretpassword1", # runner params - 'password': 'secretpassword2', - 'hosts': 'localhost' + "password": "secretpassword2", + "hosts": "localhost", }, - 'action': 'sixpack.st2.dummy.action1' + "action": "sixpack.st2.dummy.action1", } # Do not add parameters to this. There are tests that will test first without params, # then make a copy with params. LIVE_ACTION_DEFAULT_TEMPLATE = { - 'action': 'starterpack.st2.dummy.default_template', + "action": "starterpack.st2.dummy.default_template", } LIVE_ACTION_DEFAULT_ENCRYPT = { - 'action': 'starterpack.st2.dummy.default_encrypted_value', + "action": "starterpack.st2.dummy.default_encrypted_value", } LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM = { - 'action': 'starterpack.st2.dummy.default_encrypted_value_secret_param', + "action": "starterpack.st2.dummy.default_encrypted_value_secret_param", } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml', 'local.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml", "local.yaml"], } -@mock.patch.object(content_utils, 'get_pack_base_path', mock.MagicMock(return_value='/tmp/test')) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class ActionExecutionControllerTestCase(BaseActionExecutionControllerTestCase, FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/executions' +@mock.patch.object( + content_utils, "get_pack_base_path", mock.MagicMock(return_value="/tmp/test") +) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class ActionExecutionControllerTestCase( + BaseActionExecutionControllerTestCase, + FunctionalTest, + APIControllerWithIncludeAndExcludeFilterTestCase, +): + get_all_path = "/v1/executions" controller_cls = ActionExecutionsController - include_attribute_field_name = 'status' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "status" + exclude_attribute_field_name = "status" test_exact_object_count = False @classmethod - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUpClass(cls): super(BaseActionExecutionControllerTestCase, cls).setUpClass() cls.action1 = copy.deepcopy(ACTION_1) - post_resp = cls.app.post_json('/v1/actions', cls.action1) - cls.action1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action1) + cls.action1["id"] = post_resp.json["id"] cls.action2 = copy.deepcopy(ACTION_2) - post_resp = cls.app.post_json('/v1/actions', cls.action2) - cls.action2['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action2) + cls.action2["id"] = post_resp.json["id"] cls.action3 = copy.deepcopy(ACTION_3) - post_resp = cls.app.post_json('/v1/actions', cls.action3) - cls.action3['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action3) + cls.action3["id"] = post_resp.json["id"] cls.action4 = copy.deepcopy(ACTION_4) - post_resp = cls.app.post_json('/v1/actions', cls.action4) - cls.action4['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action4) + cls.action4["id"] = post_resp.json["id"] cls.action_inquiry = copy.deepcopy(ACTION_INQUIRY) - post_resp = cls.app.post_json('/v1/actions', cls.action_inquiry) - cls.action_inquiry['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_inquiry) + cls.action_inquiry["id"] = post_resp.json["id"] cls.action_template = copy.deepcopy(ACTION_DEFAULT_TEMPLATE) - post_resp = cls.app.post_json('/v1/actions', cls.action_template) - cls.action_template['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_template) + cls.action_template["id"] = post_resp.json["id"] cls.action_decrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT) - post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt) - cls.action_decrypt['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt) + cls.action_decrypt["id"] = post_resp.json["id"] - cls.action_decrypt_secret_param = copy.deepcopy(ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS) - post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt_secret_param) - cls.action_decrypt_secret_param['id'] = post_resp.json['id'] + cls.action_decrypt_secret_param = copy.deepcopy( + ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS + ) + post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt_secret_param) + cls.action_decrypt_secret_param["id"] = post_resp.json["id"] @classmethod def tearDownClass(cls): - cls.app.delete('/v1/actions/%s' % cls.action1['id']) - cls.app.delete('/v1/actions/%s' % cls.action2['id']) - cls.app.delete('/v1/actions/%s' % cls.action3['id']) - cls.app.delete('/v1/actions/%s' % cls.action4['id']) - cls.app.delete('/v1/actions/%s' % cls.action_inquiry['id']) - cls.app.delete('/v1/actions/%s' % cls.action_template['id']) - cls.app.delete('/v1/actions/%s' % cls.action_decrypt['id']) + cls.app.delete("/v1/actions/%s" % cls.action1["id"]) + cls.app.delete("/v1/actions/%s" % cls.action2["id"]) + cls.app.delete("/v1/actions/%s" % cls.action3["id"]) + cls.app.delete("/v1/actions/%s" % cls.action4["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_inquiry["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_template["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_decrypt["id"]) super(BaseActionExecutionControllerTestCase, cls).tearDownClass() def test_get_one(self): @@ -381,11 +343,11 @@ def test_get_one(self): get_resp = self._do_get_one(actionexecution_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) - self.assertIn('web_url', get_resp) - if 'end_timestamp' in get_resp: - self.assertIn('elapsed_seconds', get_resp) + self.assertIn("web_url", get_resp) + if "end_timestamp" in get_resp: + self.assertIn("elapsed_seconds", get_resp) - get_resp = self._do_get_one('last') + get_resp = self._do_get_one("last") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) @@ -396,13 +358,15 @@ def test_get_all_id_query_param_filtering_success(self): self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) - resp = self.app.get('/v1/executions?id=%s' % (actionexecution_id), expect_errors=False) + resp = self.app.get( + "/v1/executions?id=%s" % (actionexecution_id), expect_errors=False + ) self.assertEqual(resp.status_int, 200) def test_get_all_id_query_param_filtering_invalid_id(self): - resp = self.app.get('/v1/executions?id=invalidid', expect_errors=True) + resp = self.app.get("/v1/executions?id=invalidid", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertIn('not a valid ObjectId', resp.json['faultstring']) + self.assertIn("not a valid ObjectId", resp.json["faultstring"]) def test_get_all_id_query_param_filtering_multiple_ids_provided(self): post_resp = self._do_post(LIVE_ACTION_1) @@ -413,94 +377,118 @@ def test_get_all_id_query_param_filtering_multiple_ids_provided(self): self.assertEqual(post_resp.status_int, 201) id_2 = self._get_actionexecution_id(post_resp) - resp = self.app.get('/v1/executions?id=%s,%s' % (id_1, id_2), expect_errors=False) + resp = self.app.get( + "/v1/executions?id=%s,%s" % (id_1, id_2), expect_errors=False + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 2) def test_get_all(self): self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) self._get_actionexecution_id(self._do_post(LIVE_ACTION_2)) - resp = self.app.get('/v1/executions') + resp = self.app.get("/v1/executions") body = resp.json self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "2") - self.assertEqual(len(resp.json), 2, - '/v1/executions did not return all ' - 'actionexecutions.') + self.assertEqual(resp.headers["X-Total-Count"], "2") + self.assertEqual( + len(resp.json), 2, "/v1/executions did not return all " "actionexecutions." + ) # Assert liveactions are sorted by timestamp. for i in range(len(body) - 1): - self.assertTrue(isotime.parse(body[i]['start_timestamp']) >= - isotime.parse(body[i + 1]['start_timestamp'])) - self.assertIn('web_url', body[i]) - if 'end_timestamp' in body[i]: - self.assertIn('elapsed_seconds', body[i]) + self.assertTrue( + isotime.parse(body[i]["start_timestamp"]) + >= isotime.parse(body[i + 1]["start_timestamp"]) + ) + self.assertIn("web_url", body[i]) + if "end_timestamp" in body[i]: + self.assertIn("elapsed_seconds", body[i]) def test_get_all_invalid_offset_too_large(self): - offset = '2141564789454123457895412237483648' - resp = self.app.get('/v1/executions?offset=%s&limit=1' % (offset), expect_errors=True) + offset = "2141564789454123457895412237483648" + resp = self.app.get( + "/v1/executions?offset=%s&limit=1" % (offset), expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Offset "%s" specified is more than 32-bit int' % (offset)) + self.assertEqual( + resp.json["faultstring"], + 'Offset "%s" specified is more than 32-bit int' % (offset), + ) def test_get_query(self): - actionexecution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) + actionexecution_1_id = self._get_actionexecution_id( + self._do_post(LIVE_ACTION_1) + ) - resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action']) + resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"]) self.assertEqual(resp.status_int, 200) - matching_execution = filter(lambda ae: ae['id'] == actionexecution_1_id, resp.json) - self.assertEqual(len(list(matching_execution)), 1, - '/v1/executions did not return correct liveaction.') + matching_execution = filter( + lambda ae: ae["id"] == actionexecution_1_id, resp.json + ) + self.assertEqual( + len(list(matching_execution)), + 1, + "/v1/executions did not return correct liveaction.", + ) def test_get_query_with_limit_and_offset(self): self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) - resp = self.app.get('/v1/executions') + resp = self.app.get("/v1/executions") self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 0) - resp = self.app.get('/v1/executions?limit=1') + resp = self.app.get("/v1/executions?limit=1") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - resp = self.app.get('/v1/executions?limit=0', expect_errors=True) + resp = self.app.get("/v1/executions?limit=0", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue(resp.json['faultstring'], - u'Limit, "0" specified, must be a positive number or -1 for full \ - result set.') + self.assertTrue( + resp.json["faultstring"], + 'Limit, "0" specified, must be a positive number or -1 for full \ + result set.', + ) - resp = self.app.get('/v1/executions?limit=-1') + resp = self.app.get("/v1/executions?limit=-1") self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 1) - resp = self.app.get('/v1/executions?limit=-22', expect_errors=True) + resp = self.app.get("/v1/executions?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) - resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action']) + resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"]) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 1) - resp = self.app.get('/v1/executions?action=%s&limit=0' % - LIVE_ACTION_1['action'], expect_errors=True) + resp = self.app.get( + "/v1/executions?action=%s&limit=0" % LIVE_ACTION_1["action"], + expect_errors=True, + ) self.assertEqual(resp.status_int, 400) - self.assertTrue(resp.json['faultstring'], - u'Limit, "0" specified, must be a positive number or -1 for full \ - result set.') - - resp = self.app.get('/v1/executions?action=%s&limit=1' % - LIVE_ACTION_1['action']) + self.assertTrue( + resp.json["faultstring"], + 'Limit, "0" specified, must be a positive number or -1 for full \ + result set.', + ) + + resp = self.app.get( + "/v1/executions?action=%s&limit=1" % LIVE_ACTION_1["action"] + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - total_count = resp.headers['X-Total-Count'] + total_count = resp.headers["X-Total-Count"] - resp = self.app.get('/v1/executions?offset=%s&limit=1' % total_count) + resp = self.app.get("/v1/executions?offset=%s&limit=1" % total_count) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json), 0) def test_get_one_fail(self): - resp = self.app.get('/v1/executions/100', expect_errors=True) + resp = self.app.get("/v1/executions/100", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_post_delete(self): @@ -508,13 +496,13 @@ def test_post_delete(self): self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') - expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'} - self.assertDictEqual(delete_resp.json['result'], expected_result) + self.assertEqual(delete_resp.json["status"], "canceled") + expected_result = {"message": "Action canceled by user.", "user": "stanley"} + self.assertDictEqual(delete_resp.json["result"], expected_result) def test_post_delete_duplicate(self): """Cancels an execution twice, to ensure that a full execution object - is returned instead of an error message + is returned instead of an error message """ post_resp = self._do_post(LIVE_ACTION_1) @@ -524,59 +512,65 @@ def test_post_delete_duplicate(self): for i in range(2): delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') - expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'} - self.assertDictEqual(delete_resp.json['result'], expected_result) + self.assertEqual(delete_resp.json["status"], "canceled") + expected_result = {"message": "Action canceled by user.", "user": "stanley"} + self.assertDictEqual(delete_resp.json["result"], expected_result) def test_post_delete_trace(self): LIVE_ACTION_TRACE = copy.copy(LIVE_ACTION_1) - LIVE_ACTION_TRACE['context'] = {'trace_context': {'trace_tag': 'balleilaka'}} + LIVE_ACTION_TRACE["context"] = {"trace_context": {"trace_tag": "balleilaka"}} post_resp = self._do_post(LIVE_ACTION_TRACE) self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') + self.assertEqual(delete_resp.json["status"], "canceled") trace_id = str(Trace.get_all()[0].id) - LIVE_ACTION_TRACE['context'] = {'trace_context': {'id_': trace_id}} + LIVE_ACTION_TRACE["context"] = {"trace_context": {"id_": trace_id}} post_resp = self._do_post(LIVE_ACTION_TRACE) self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') + self.assertEqual(delete_resp.json["status"], "canceled") def test_post_nonexistent_action(self): live_action = copy.deepcopy(LIVE_ACTION_1) - live_action['action'] = 'mock.foobar' + live_action["action"] = "mock.foobar" post_resp = self._do_post(live_action, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - expected_error = 'Action "%s" cannot be found.' % live_action['action'] - self.assertEqual(expected_error, post_resp.json['faultstring']) + expected_error = 'Action "%s" cannot be found.' % live_action["action"] + self.assertEqual(expected_error, post_resp.json["faultstring"]) def test_post_parameter_validation_failed(self): execution = copy.deepcopy(LIVE_ACTION_1) # Runner type does not expects additional properties. - execution['parameters']['foo'] = 'bar' + execution["parameters"]["foo"] = "bar" post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], - "Additional properties are not allowed ('foo' was unexpected)") + self.assertEqual( + post_resp.json["faultstring"], + "Additional properties are not allowed ('foo' was unexpected)", + ) # Runner type expects parameter "hosts". - execution['parameters'] = {} + execution["parameters"] = {} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], "'hosts' is a required property") + self.assertEqual( + post_resp.json["faultstring"], "'hosts' is a required property" + ) # Runner type expects parameters "cmd" to be str. - execution['parameters'] = {"hosts": "127.0.0.1", "cmd": 1000} + execution["parameters"] = {"hosts": "127.0.0.1", "cmd": 1000} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertIn('Value "1000" must either be a string or None. Got "int"', - post_resp.json['faultstring']) + self.assertIn( + 'Value "1000" must either be a string or None. Got "int"', + post_resp.json["faultstring"], + ) # Runner type expects parameters "cmd" to be str. - execution['parameters'] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1} + execution["parameters"] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) @@ -589,53 +583,55 @@ def test_post_parameter_render_failed(self): execution = copy.deepcopy(LIVE_ACTION_1) # Runner type does not expects additional properties. - execution['parameters']['hosts'] = '{{ABSENT}}' + execution["parameters"]["hosts"] = "{{ABSENT}}" post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], - 'Dependency unsatisfied in variable "ABSENT"') + self.assertEqual( + post_resp.json["faultstring"], 'Dependency unsatisfied in variable "ABSENT"' + ) def test_post_parameter_validation_explicit_none(self): execution = copy.deepcopy(LIVE_ACTION_1) - execution['parameters']['a'] = None + execution["parameters"]["a"] = None post_resp = self._do_post(execution) self.assertEqual(post_resp.status_int, 201) def test_post_with_st2_context_in_headers(self): resp = self._do_post(copy.deepcopy(LIVE_ACTION_1)) self.assertEqual(resp.status_int, 201) - parent_user = resp.json['context']['user'] - parent_exec_id = str(resp.json['id']) + parent_user = resp.json["context"]["user"] + parent_exec_id = str(resp.json["id"]) context = { - 'parent': { - 'execution_id': parent_exec_id, - 'user': parent_user - }, - 'user': None, - 'other': {'k1': 'v1'} + "parent": {"execution_id": parent_exec_id, "user": parent_user}, + "user": None, + "other": {"k1": "v1"}, + } + headers = { + "content-type": "application/json", + "st2-context": json.dumps(context), } - headers = {'content-type': 'application/json', 'st2-context': json.dumps(context)} resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], parent_user, 'Should use parent\'s user.') + self.assertEqual( + resp.json["context"]["user"], parent_user, "Should use parent's user." + ) expected = { - 'parent': { - 'execution_id': parent_exec_id, - 'user': parent_user - }, - 'user': parent_user, - 'pack': 'sixpack', - 'other': {'k1': 'v1'} + "parent": {"execution_id": parent_exec_id, "user": parent_user}, + "user": parent_user, + "pack": "sixpack", + "other": {"k1": "v1"}, } - self.assertDictEqual(resp.json['context'], expected) + self.assertDictEqual(resp.json["context"], expected) def test_post_with_st2_context_in_headers_failed(self): resp = self._do_post(copy.deepcopy(LIVE_ACTION_1)) self.assertEqual(resp.status_int, 201) - headers = {'content-type': 'application/json', 'st2-context': 'foobar'} - resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True) + headers = {"content-type": "application/json", "st2-context": "foobar"} + resp = self._do_post( + copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertIn('Unable to convert st2-context', resp.json['faultstring']) + self.assertIn("Unable to convert st2-context", resp.json["faultstring"]) def test_re_run_success(self): # Create a new execution @@ -645,12 +641,16 @@ def test_re_run_success(self): # Re-run created execution (no parameters overrides) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) # Re-run created execution (with parameters overrides) - data = {'parameters': {'a': 'val1'}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"parameters": {"a": "val1"}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) def test_re_run_with_delay(self): @@ -659,21 +659,24 @@ def test_re_run_with_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 100 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) resp = json.loads(re_run_resp.body) - self.assertEqual(resp['delay'], delay_time) + self.assertEqual(resp["delay"], delay_time) def test_re_run_with_incorrect_delay(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - delay_time = 'sudo apt -y upgrade winson' - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + delay_time = "sudo apt -y upgrade winson" + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) def test_re_run_with_very_large_delay(self): @@ -682,8 +685,10 @@ def test_re_run_with_very_large_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 10 ** 10 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) def test_re_run_delayed_aciton_with_no_delay(self): @@ -692,11 +697,13 @@ def test_re_run_delayed_aciton_with_no_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 0 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) resp = json.loads(re_run_resp.body) - self.assertNotIn('delay', resp.keys()) + self.assertNotIn("delay", resp.keys()) def test_re_run_failure_execution_doesnt_exist(self): # Create a new execution @@ -705,8 +712,9 @@ def test_re_run_failure_execution_doesnt_exist(self): # Re-run created execution (override parameter with an invalid value) data = {} - re_run_resp = self.app.post_json('/v1/executions/doesntexist/re_run', - data, expect_errors=True) + re_run_resp = self.app.post_json( + "/v1/executions/doesntexist/re_run", data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 404) def test_re_run_failure_parameter_override_invalid_type(self): @@ -716,12 +724,15 @@ def test_re_run_failure_parameter_override_invalid_type(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter and task together) - data = {'parameters': {'a': 1000}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": 1000}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('Value "1000" must either be a string or None. Got "int"', - re_run_resp.json['faultstring']) + self.assertIn( + 'Value "1000" must either be a string or None. Got "int"', + re_run_resp.json["faultstring"], + ) def test_template_param(self): @@ -731,31 +742,46 @@ def test_template_param(self): # Assert that the template in the parameter default value # was rendered and st2kv was used - self.assertEqual(post_resp.json['parameters']['intparam'], 0) + self.assertEqual(post_resp.json["parameters"]["intparam"], 0) # Test with live param live_int_param = 3 livaction_with_params = copy.deepcopy(LIVE_ACTION_DEFAULT_TEMPLATE) - livaction_with_params['parameters'] = { - "intparam": live_int_param - } + livaction_with_params["parameters"] = {"intparam": live_int_param} post_resp = self._do_post(livaction_with_params) self.assertEqual(post_resp.status_int, 201) # Assert that the template in the parameter default value # was not rendered, and the provided parameter was used - self.assertEqual(post_resp.json['parameters']['intparam'], live_int_param) + self.assertEqual(post_resp.json["parameters"]["intparam"], live_int_param) def test_template_encrypted_params(self): # register datastore values which are used in this test case KeyValuePairAPI._setup_crypto() register_items = [ - {'name': 'secret', 'secret': True, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')}, - {'name': 'stanley:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')}, - {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'baz')}, + { + "name": "secret", + "secret": True, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "foo" + ), + }, + { + "name": "stanley:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "bar" + ), + }, + { + "name": "user1:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "baz" + ), + }, ] kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items] @@ -763,43 +789,53 @@ def test_template_encrypted_params(self): # 1. parameters are not marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'stanley') - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["context"]["user"], "stanley") + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") # 2. parameters are marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'stanley') - self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(resp.json["context"]["user"], "stanley") + self.assertEqual( + resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual( + resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE + ) # After switching to the 'user1', that value will be read from switched user's scope - self.use_user(UserDB(name='user1')) + self.use_user(UserDB(name="user1")) # 1. parameters are not marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'user1') - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'baz') + self.assertEqual(resp.json["context"]["user"], "user1") + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "baz") # 2. parameters are marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'user1') - self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(resp.json["context"]["user"], "user1") + self.assertEqual( + resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual( + resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE + ) # This switches to the 'user2', there is no value in that user's scope. When a request # that tries to evaluate Jinja expression to decrypt empty value is sent, a HTTP response # which has 4xx status code will be returned. - self.use_user(UserDB(name='user2')) + self.use_user(UserDB(name="user2")) resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Failed to render parameter "encrypted_user_param": Referenced datastore ' - 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string') + self.assertEqual( + resp.json["faultstring"], + 'Failed to render parameter "encrypted_user_param": Referenced datastore ' + 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string', + ) # clean-up values that are registered at first for kvp in kvps: @@ -808,7 +844,9 @@ def test_template_encrypted_params(self): def test_template_encrypted_params_without_registering(self): resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'].index('Failed to render parameter'), 0) + self.assertEqual( + resp.json["faultstring"].index("Failed to render parameter"), 0 + ) def test_re_run_workflow_success(self): # Create a new execution @@ -818,26 +856,25 @@ def test_re_run_workflow_success(self): # Re-run created execution (tasks option for non workflow) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'user': 'stanley', - 'pack': 'starterpack', - 're-run': { - 'ref': execution_id - }, - 'trace_context': { - 'id_': str(trace.id) - } + "user": "stanley", + "pack": "starterpack", + "re-run": {"ref": execution_id}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_task_success(self): # Create a new execution @@ -846,28 +883,26 @@ def test_re_run_workflow_task_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'] - }, - 'trace_context': { - 'id_': str(trace.id) - } + "pack": "starterpack", + "user": "stanley", + "re-run": {"ref": execution_id, "tasks": data["tasks"]}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_tasks_success(self): # Create a new execution @@ -876,28 +911,26 @@ def test_re_run_workflow_tasks_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x', 'y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x", "y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'] - }, - 'trace_context': { - 'id_': str(trace.id) - } + "pack": "starterpack", + "user": "stanley", + "re-run": {"ref": execution_id, "tasks": data["tasks"]}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_tasks_reset_success(self): # Create a new execution @@ -906,29 +939,30 @@ def test_re_run_workflow_tasks_reset_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x', 'y'], 'reset': ['y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x", "y"], "reset": ["y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'], - 'reset': data['reset'] + "pack": "starterpack", + "user": "stanley", + "re-run": { + "ref": execution_id, + "tasks": data["tasks"], + "reset": data["reset"], }, - 'trace_context': { - 'id_': str(trace.id) - } + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_failure_tasks_option_for_non_workflow(self): # Create a new execution @@ -937,14 +971,15 @@ def test_re_run_failure_tasks_option_for_non_workflow(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - expected_substring = 'only supported for Orquesta workflows' - self.assertIn(expected_substring, re_run_resp.json['faultstring']) + expected_substring = "only supported for Orquesta workflows" + self.assertIn(expected_substring, re_run_resp.json["faultstring"]) def test_re_run_workflow_failure_given_both_params_and_tasks(self): # Create a new execution @@ -953,13 +988,16 @@ def test_re_run_workflow_failure_given_both_params_and_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'parameters': {'a': 'xyz'}, 'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": "xyz"}, "tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('not supported when re-running task(s) for a workflow', - re_run_resp.json['faultstring']) + self.assertIn( + "not supported when re-running task(s) for a workflow", + re_run_resp.json["faultstring"], + ) def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self): # Create a new execution @@ -968,13 +1006,16 @@ def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'parameters': {'a': 'xyz'}, 'reset': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": "xyz"}, "reset": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('not supported when re-running task(s) for a workflow', - re_run_resp.json['faultstring']) + self.assertIn( + "not supported when re-running task(s) for a workflow", + re_run_resp.json["faultstring"], + ) def test_re_run_workflow_failure_invalid_reset_tasks(self): # Create a new execution @@ -983,13 +1024,16 @@ def test_re_run_workflow_failure_invalid_reset_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'tasks': ['x'], 'reset': ['y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"], "reset": ["y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('tasks to reset does not match the tasks to rerun', - re_run_resp.json['faultstring']) + self.assertIn( + "tasks to reset does not match the tasks to rerun", + re_run_resp.json["faultstring"], + ) def test_re_run_secret_parameter(self): # Create a new execution @@ -999,96 +1043,100 @@ def test_re_run_secret_parameter(self): # Re-run created execution (no parameters overrides) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) execution_id = self._get_actionexecution_id(re_run_resp) - re_run_result = self._do_get_one(execution_id, - params={'show_secrets': True}, - expect_errors=True) - self.assertEqual(re_run_result.json['parameters'], LIVE_ACTION_1['parameters']) + re_run_result = self._do_get_one( + execution_id, params={"show_secrets": True}, expect_errors=True + ) + self.assertEqual(re_run_result.json["parameters"], LIVE_ACTION_1["parameters"]) # Re-run created execution (with parameters overrides) - data = {'parameters': {'a': 'val1', 'd': ANOTHER_SUPER_SECRET_PARAMETER}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"parameters": {"a": "val1", "d": ANOTHER_SUPER_SECRET_PARAMETER}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) execution_id = self._get_actionexecution_id(re_run_resp) - re_run_result = self._do_get_one(execution_id, - params={'show_secrets': True}, - expect_errors=True) - self.assertEqual(re_run_result.json['parameters']['d'], data['parameters']['d']) + re_run_result = self._do_get_one( + execution_id, params={"show_secrets": True}, expect_errors=True + ) + self.assertEqual(re_run_result.json["parameters"]["d"], data["parameters"]["d"]) def test_put_status_and_result(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'succeeded') - self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(put_resp.json["status"], "succeeded") + self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"}) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'succeeded') - self.assertDictEqual(get_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(get_resp.json["status"], "succeeded") + self.assertDictEqual(get_resp.json["result"], {"stdout": "foobar"}) def test_put_bad_state(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'married'} + updates = {"status": "married"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('\'married\' is not one of', put_resp.json['faultstring']) + self.assertIn("'married' is not one of", put_resp.json["faultstring"]) def test_put_bad_result(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'result': 'foobar'} + updates = {"result": "foobar"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('is not of type \'object\'', put_resp.json['faultstring']) + self.assertIn("is not of type 'object'", put_resp.json["faultstring"]) def test_put_bad_property(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'abandoned', 'foo': 'bar'} + updates = {"status": "abandoned", "foo": "bar"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('Additional properties are not allowed', put_resp.json['faultstring']) + self.assertIn( + "Additional properties are not allowed", put_resp.json["faultstring"] + ) def test_put_status_to_completed_execution(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'succeeded') - self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(put_resp.json["status"], "succeeded") + self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"}) - updates = {'status': 'abandoned'} + updates = {"status": "abandoned"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - @mock.patch.object( - LiveAction, 'get_by_id', - mock.MagicMock(return_value=None)) + @mock.patch.object(LiveAction, "get_by_id", mock.MagicMock(return_value=None)) def test_put_execution_missing_liveaction(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 500) @@ -1098,19 +1146,19 @@ def test_put_pause_unsupported(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) - updates = {'status': 'paused'} + updates = {"status": "paused"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) def test_put_pause(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1118,50 +1166,50 @@ def test_put_pause(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_pause_not_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['status'], 'requested') + self.assertEqual(post_resp.json["status"], "requested") execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('is not in a running state', put_resp.json['faultstring']) + self.assertIn("is not in a running state", put_resp.json["faultstring"]) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'requested') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "requested") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_pause_already_pausing(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1169,44 +1217,46 @@ def test_put_pause_already_pausing(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: - updates = {'status': 'pausing'} + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') + self.assertEqual(put_resp.json["status"], "pausing") mocked.assert_not_called() get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_unsupported(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) def test_put_resume(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1214,44 +1264,46 @@ def test_put_resume(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) # Manually change the status to paused because only the runner pause method should # set the paused status directly to the liveaction and execution database objects. liveaction_id = self._get_liveaction_id(post_resp) liveaction = action_db_util.get_liveaction_by_id(liveaction_id) - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'paused') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "paused") + self.assertIsNone(get_resp.json.get("result")) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'resuming') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "resuming") + self.assertIsNone(put_resp.json.get("result")) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'resuming') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "resuming") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_not_paused(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1259,33 +1311,35 @@ def test_put_resume_not_paused(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - expected_error_message = 'it is in "pausing" state and not in "paused" state' - self.assertIn(expected_error_message, put_resp.json['faultstring']) + expected_error_message = ( + 'it is in "pausing" state and not in "paused" state' + ) + self.assertIn(expected_error_message, put_resp.json["faultstring"]) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_already_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1293,24 +1347,26 @@ def test_put_resume_already_running(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: - updates = {'status': 'resuming'} + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") mocked.assert_not_called() get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'running') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "running") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_get_inquiry_mask(self): """Ensure Inquiry responses are masked when retrieved via ActionExecution GET @@ -1327,194 +1383,213 @@ def test_get_inquiry_mask(self): self.assertEqual(get_resp.status_int, 200) resp = json.loads(get_resp.body) - self.assertEqual(resp['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE + ) post_resp = self._do_post(LIVE_ACTION_INQUIRY) actionexecution_id = self._get_actionexecution_id(post_resp) - get_resp = self._do_get_one(actionexecution_id, params={'show_secrets': True}) + get_resp = self._do_get_one(actionexecution_id, params={"show_secrets": True}) self.assertEqual(get_resp.status_int, 200) resp = json.loads(get_resp.body) - self.assertEqual(resp['result']['response']['secondfactor'], "supersecretvalue") + self.assertEqual(resp["result"]["response"]["secondfactor"], "supersecretvalue") def test_get_include_attributes_and_secret_parameters(self): # Verify that secret parameters are correctly masked when using ?include_attributes filter self._do_post(LIVE_ACTION_WITH_SECRET_PARAM) urls = [ - '/v1/actionexecutions?include_attributes=parameters', - '/v1/actionexecutions?include_attributes=parameters,action', - '/v1/actionexecutions?include_attributes=parameters,runner', - '/v1/actionexecutions?include_attributes=parameters,action,runner' + "/v1/actionexecutions?include_attributes=parameters", + "/v1/actionexecutions?include_attributes=parameters,action", + "/v1/actionexecutions?include_attributes=parameters,runner", + "/v1/actionexecutions?include_attributes=parameters,action,runner", ] for url in urls: - resp = self.app.get(url + '&limit=1') + resp = self.app.get(url + "&limit=1") - self.assertIn('parameters', resp.json[0]) - self.assertEqual(resp.json[0]['parameters']['a'], 'param a') - self.assertEqual(resp.json[0]['parameters']['d'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json[0]['parameters']['password'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json[0]) + self.assertEqual(resp.json[0]["parameters"]["a"], "param a") + self.assertEqual(resp.json[0]["parameters"]["d"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp.json[0]["parameters"]["password"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost") # With ?show_secrets=True urls = [ - ('/v1/actionexecutions?&include_attributes=parameters'), - ('/v1/actionexecutions?include_attributes=parameters,action'), - ('/v1/actionexecutions?include_attributes=parameters,runner'), - ('/v1/actionexecutions?include_attributes=parameters,action,runner') + ("/v1/actionexecutions?&include_attributes=parameters"), + ("/v1/actionexecutions?include_attributes=parameters,action"), + ("/v1/actionexecutions?include_attributes=parameters,runner"), + ("/v1/actionexecutions?include_attributes=parameters,action,runner"), ] for url in urls: - resp = self.app.get(url + '&limit=1&show_secrets=True') + resp = self.app.get(url + "&limit=1&show_secrets=True") - self.assertIn('parameters', resp.json[0]) - self.assertEqual(resp.json[0]['parameters']['a'], 'param a') - self.assertEqual(resp.json[0]['parameters']['d'], 'secretpassword1') - self.assertEqual(resp.json[0]['parameters']['password'], 'secretpassword2') - self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json[0]) + self.assertEqual(resp.json[0]["parameters"]["a"], "param a") + self.assertEqual(resp.json[0]["parameters"]["d"], "secretpassword1") + self.assertEqual(resp.json[0]["parameters"]["password"], "secretpassword2") + self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost") # NOTE: We don't allow exclusion of attributes such as "action" and "runner" because # that would break secrets masking urls = [ - '/v1/actionexecutions?limit=1&exclude_attributes=action', - '/v1/actionexecutions?limit=1&exclude_attributes=runner', - '/v1/actionexecutions?limit=1&exclude_attributes=action,runner', + "/v1/actionexecutions?limit=1&exclude_attributes=action", + "/v1/actionexecutions?limit=1&exclude_attributes=runner", + "/v1/actionexecutions?limit=1&exclude_attributes=action,runner", ] for url in urls: - resp = self.app.get(url + '&limit=1', expect_errors=True) + resp = self.app.get(url + "&limit=1", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid or unsupported exclude attribute specified:' in - resp.json['faultstring']) + self.assertTrue( + "Invalid or unsupported exclude attribute specified:" + in resp.json["faultstring"] + ) def test_get_single_attribute_success(self): - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] - resp = self.app.get('/v1/executions/%s/attribute/status' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/status" % (exec_id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, 'requested') + self.assertEqual(resp.json, "requested") - resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id)) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, None) - resp = self.app.get('/v1/executions/%s/attribute/trigger_instance' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/trigger_instance" % (exec_id)) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, None) data = {} - data['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - data['result'] = {'foo': 'bar'} + data["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + data["result"] = {"foo": "bar"} - resp = self.app.put_json('/v1/executions/%s' % (exec_id), data) + resp = self.app.put_json("/v1/executions/%s" % (exec_id), data) self.assertEqual(resp.status_int, 200) - resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, data['result']) + self.assertEqual(resp.json, data["result"]) def test_get_single_attribute_failure_invalid_attribute(self): - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] - resp = self.app.get('/v1/executions/%s/attribute/start_timestamp' % (exec_id), - expect_errors=True) + resp = self.app.get( + "/v1/executions/%s/attribute/start_timestamp" % (exec_id), + expect_errors=True, + ) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid attribute "start_timestamp" specified.' in - resp.json['faultstring']) + self.assertTrue( + 'Invalid attribute "start_timestamp" specified.' in resp.json["faultstring"] + ) def test_get_single_include_attributes_and_secret_parameters(self): # Verify that secret parameters are correctly masked when using ?include_attributes filter self._do_post(LIVE_ACTION_WITH_SECRET_PARAM) - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] # FYI, the response always contains the 'id' parameter urls = [ { - 'url': '/v1/executions/%s?include_attributes=parameters' % (exec_id), - 'expected_parameters': ['id', 'parameters'], + "url": "/v1/executions/%s?include_attributes=parameters" % (exec_id), + "expected_parameters": ["id", "parameters"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action'], + "url": "/v1/executions/%s?include_attributes=parameters,action" + % (exec_id), + "expected_parameters": ["id", "parameters", "action"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "runner"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action', 'runner'], - } + "url": "/v1/executions/%s?include_attributes=parameters,action,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "action", "runner"], + }, ] for item in urls: - url = item['url'] + url = item["url"] resp = self.app.get(url) - self.assertIn('parameters', resp.json) - self.assertEqual(resp.json['parameters']['a'], 'param a') - self.assertEqual(resp.json['parameters']['d'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['password'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json) + self.assertEqual(resp.json["parameters"]["a"], "param a") + self.assertEqual(resp.json["parameters"]["d"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp.json["parameters"]["password"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(resp.json["parameters"]["hosts"], "localhost") # ensure that the response has only the keys we epect, no more, no less resp_keys = set(resp.json.keys()) - expected_params = set(item['expected_parameters']) + expected_params = set(item["expected_parameters"]) diff = resp_keys.symmetric_difference(expected_params) self.assertEqual(diff, set()) # With ?show_secrets=True urls = [ { - 'url': '/v1/executions/%s?&include_attributes=parameters' % (exec_id), - 'expected_parameters': ['id', 'parameters'], + "url": "/v1/executions/%s?&include_attributes=parameters" % (exec_id), + "expected_parameters": ["id", "parameters"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action'], + "url": "/v1/executions/%s?include_attributes=parameters,action" + % (exec_id), + "expected_parameters": ["id", "parameters", "action"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "runner"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,action,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "action", "runner"], }, ] for item in urls: - url = item['url'] - resp = self.app.get(url + '&show_secrets=True') + url = item["url"] + resp = self.app.get(url + "&show_secrets=True") - self.assertIn('parameters', resp.json) - self.assertEqual(resp.json['parameters']['a'], 'param a') - self.assertEqual(resp.json['parameters']['d'], 'secretpassword1') - self.assertEqual(resp.json['parameters']['password'], 'secretpassword2') - self.assertEqual(resp.json['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json) + self.assertEqual(resp.json["parameters"]["a"], "param a") + self.assertEqual(resp.json["parameters"]["d"], "secretpassword1") + self.assertEqual(resp.json["parameters"]["password"], "secretpassword2") + self.assertEqual(resp.json["parameters"]["hosts"], "localhost") # ensure that the response has only the keys we epect, no more, no less resp_keys = set(resp.json.keys()) - expected_params = set(item['expected_parameters']) + expected_params = set(item["expected_parameters"]) diff = resp_keys.symmetric_difference(expected_params) self.assertEqual(diff, set()) # NOTE: We don't allow exclusion of attributes such as "action" and "runner" because # that would break secrets masking urls = [ - '/v1/executions/%s?limit=1&exclude_attributes=action', - '/v1/executions/%s?limit=1&exclude_attributes=runner', - '/v1/executions/%s?limit=1&exclude_attributes=action,runner', + "/v1/executions/%s?limit=1&exclude_attributes=action", + "/v1/executions/%s?limit=1&exclude_attributes=runner", + "/v1/executions/%s?limit=1&exclude_attributes=action,runner", ] for url in urls: resp = self.app.get(url, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid or unsupported exclude attribute specified:' in - resp.json['faultstring']) + self.assertTrue( + "Invalid or unsupported exclude attribute specified:" + in resp.json["faultstring"] + ) def _insert_mock_models(self): execution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) @@ -1522,37 +1597,44 @@ def _insert_mock_models(self): return [execution_1_id, execution_2_id] -class ActionExecutionOutputControllerTestCase(BaseActionExecutionControllerTestCase, - FunctionalTest): +class ActionExecutionOutputControllerTestCase( + BaseActionExecutionControllerTestCase, FunctionalTest +): def test_get_output_id_last_no_executions_in_the_database(self): ActionExecution.query().delete() - resp = self.app.get('/v1/executions/last/output', expect_errors=True) + resp = self.app.get("/v1/executions/last/output", expect_errors=True) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(resp.json['faultstring'], 'No executions found in the database') + self.assertEqual( + resp.json["faultstring"], "No executions found in the database" + ) def test_get_output_running_execution(self): # Only the output produced so far should be returned # Test the execution output API endpoint for execution which is running (blocking) status = action_constants.LIVEACTION_STATUS_RUNNING timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) - output_params = dict(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout before start\n') + output_params = dict( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout before start\n", + ) def insert_mock_data(data): - output_params['data'] = data + output_params["data"] = data output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -1561,45 +1643,51 @@ def insert_mock_data(data): ActionExecutionOutput.add_or_update(output_db, publish=False) # Retrieve data while execution is running - data produced so far should be retrieved - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 1) - self.assertEqual(lines[0], 'stdout before start') + self.assertEqual(lines[0], "stdout before start") # Insert more data - insert_mock_data('stdout mid 1\n') + insert_mock_data("stdout mid 1\n") # Retrieve data while execution is running - data produced so far should be retrieved - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 2) - self.assertEqual(lines[0], 'stdout before start') - self.assertEqual(lines[1], 'stdout mid 1') + self.assertEqual(lines[0], "stdout before start") + self.assertEqual(lines[1], "stdout mid 1") # Insert more data - insert_mock_data('stdout pre finish 1\n') + insert_mock_data("stdout pre finish 1\n") # Transition execution to completed state action_execution_db.status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_execution_db = ActionExecution.add_or_update(action_execution_db) # Execution has finished - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 3) - self.assertEqual(lines[0], 'stdout before start') - self.assertEqual(lines[1], 'stdout mid 1') - self.assertEqual(lines[2], 'stdout pre finish 1') + self.assertEqual(lines[0], "stdout before start") + self.assertEqual(lines[1], "stdout mid 1") + self.assertEqual(lines[2], "stdout pre finish 1") def test_get_output_finished_execution(self): # Test the execution output API endpoint for execution which has finished @@ -1607,42 +1695,50 @@ def test_get_output_finished_execution(self): # Insert mock execution and output objects status = action_constants.LIVEACTION_STATUS_SUCCEEDED timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) for i in range(1, 6): - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout %s\n' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) for i in range(10, 15): - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr %s\n' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") self.assertEqual(len(lines), 10) - self.assertEqual(lines[0], 'stdout 1') - self.assertEqual(lines[9], 'stderr 14') + self.assertEqual(lines[0], "stdout 1") + self.assertEqual(lines[9], "stderr 14") # Verify "last" short-hand id works - resp = self.app.get('/v1/executions/last/output', expect_errors=False) + resp = self.app.get("/v1/executions/last/output", expect_errors=False) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") self.assertEqual(len(lines), 10) diff --git a/st2api/tests/unit/controllers/v1/test_executions_auth.py b/st2api/tests/unit/controllers/v1/test_executions_auth.py index e408d053dc..f1045a7d54 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_auth.py +++ b/st2api/tests/unit/controllers/v1/test_executions_auth.py @@ -44,61 +44,48 @@ ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "remote-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } ACTION_DEFAULT_ENCRYPT = { - 'name': 'st2.dummy.default_encrypted_value', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}' + "name": "st2.dummy.default_encrypted_value", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}' - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + }, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } LIVE_ACTION_DEFAULT_ENCRYPT = { - 'action': 'starterpack.st2.dummy.default_encrypted_value', + "action": "starterpack.st2.dummy.default_encrypted_value", } # NOTE: We use a longer expiry time because this variable is initialized on module import (aka @@ -107,19 +94,23 @@ # by that time and the tests would fail. NOW = date_utils.get_datetime_utc_now() EXPIRY = NOW + datetime.timedelta(seconds=1000) -SYS_TOKEN = TokenDB(id=bson.ObjectId(), user='system', token=uuid.uuid4().hex, expiry=EXPIRY) -USR_TOKEN = TokenDB(id=bson.ObjectId(), user='tokenuser', token=uuid.uuid4().hex, expiry=EXPIRY) +SYS_TOKEN = TokenDB( + id=bson.ObjectId(), user="system", token=uuid.uuid4().hex, expiry=EXPIRY +) +USR_TOKEN = TokenDB( + id=bson.ObjectId(), user="tokenuser", token=uuid.uuid4().hex, expiry=EXPIRY +) -FIXTURES_PACK = 'generic' -FIXTURES = { - 'users': ['system_user.yaml', 'token_user.yaml'] -} +FIXTURES_PACK = "generic" +FIXTURES = {"users": ["system_user.yaml", "token_user.yaml"]} # These parameters are used for the tests of getting value from datastore and decrypting it at # Jinja expression in a action metadata definition. -TEST_USER = UserDB(name='user1') -TEST_TOKEN = TokenDB(id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY) -TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash='secret_key', enabled=True) +TEST_USER = UserDB(name="user1") +TEST_TOKEN = TokenDB( + id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY +) +TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash="secret_key", enabled=True) def mock_get_token(*args, **kwargs): @@ -128,50 +119,69 @@ def mock_get_token(*args, **kwargs): return USR_TOKEN -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionExecutionControllerTestCaseAuthEnabled(FunctionalTest): enable_auth = True @classmethod + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=UserDB)) @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) - @mock.patch.object(User, 'get_by_name', mock.MagicMock(side_effect=UserDB)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUpClass(cls): super(ActionExecutionControllerTestCaseAuthEnabled, cls).setUpClass() cls.action = copy.deepcopy(ACTION_1) - headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)} - post_resp = cls.app.post_json('/v1/actions', cls.action, headers=headers) - cls.action['id'] = post_resp.json['id'] + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + } + post_resp = cls.app.post_json("/v1/actions", cls.action, headers=headers) + cls.action["id"] = post_resp.json["id"] cls.action_encrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT) - post_resp = cls.app.post_json('/v1/actions', cls.action_encrypt, headers=headers) - cls.action_encrypt['id'] = post_resp.json['id'] + post_resp = cls.app.post_json( + "/v1/actions", cls.action_encrypt, headers=headers + ) + cls.action_encrypt["id"] = post_resp.json["id"] - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=FIXTURES) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=FIXTURES + ) # register datastore values which are used in this tests KeyValuePairAPI._setup_crypto() register_items = [ - {'name': 'secret', 'secret': True, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')}, - {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')}, + { + "name": "secret", + "secret": True, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "foo" + ), + }, + { + "name": "user1:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "bar" + ), + }, + ] + cls.kvps = [ + KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items ] - cls.kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items] @classmethod - @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) def tearDownClass(cls): - headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)} - cls.app.delete('/v1/actions/%s' % cls.action['id'], headers=headers) - cls.app.delete('/v1/actions/%s' % cls.action_encrypt['id'], headers=headers) + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + } + cls.app.delete("/v1/actions/%s" % cls.action["id"], headers=headers) + cls.app.delete("/v1/actions/%s" % cls.action_encrypt["id"], headers=headers) # unregister key-value pairs for tests [KeyValuePair.delete(x) for x in cls.kvps] @@ -179,49 +189,53 @@ def tearDownClass(cls): super(ActionExecutionControllerTestCaseAuthEnabled, cls).tearDownClass() def _do_post(self, liveaction, *args, **kwargs): - return self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + return self.app.post_json("/v1/executions", liveaction, *args, **kwargs) - @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) def test_post_with_st2_context_in_headers(self): - headers = {'content-type': 'application/json', 'X-Auth-Token': str(USR_TOKEN.token)} + headers = { + "content-type": "application/json", + "X-Auth-Token": str(USR_TOKEN.token), + } resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - token_user = resp.json['context']['user'] - self.assertEqual(token_user, 'tokenuser') - context = {'parent': {'execution_id': str(resp.json['id']), 'user': token_user}} - headers = {'content-type': 'application/json', - 'X-Auth-Token': str(SYS_TOKEN.token), - 'st2-context': json.dumps(context)} + token_user = resp.json["context"]["user"] + self.assertEqual(token_user, "tokenuser") + context = {"parent": {"execution_id": str(resp.json["id"]), "user": token_user}} + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + "st2-context": json.dumps(context), + } resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'tokenuser') - self.assertEqual(resp.json['context']['parent'], context['parent']) + self.assertEqual(resp.json["context"]["user"], "tokenuser") + self.assertEqual(resp.json["context"]["parent"], context["parent"]) - @mock.patch.object(ApiKey, 'get', mock.Mock(return_value=TEST_APIKEY)) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER)) + @mock.patch.object(ApiKey, "get", mock.Mock(return_value=TEST_APIKEY)) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER)) def test_template_encrypted_params_with_apikey(self): - resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={ - 'St2-Api-key': 'secret_key' - }) + resp = self._do_post( + LIVE_ACTION_DEFAULT_ENCRYPT, headers={"St2-Api-key": "secret_key"} + ) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") - @mock.patch.object(Token, 'get', mock.Mock(return_value=TEST_TOKEN)) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER)) + @mock.patch.object(Token, "get", mock.Mock(return_value=TEST_TOKEN)) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER)) def test_template_encrypted_params_with_access_token(self): - resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={ - 'X-Auth-Token': str(TEST_TOKEN.token) - }) + resp = self._do_post( + LIVE_ACTION_DEFAULT_ENCRYPT, headers={"X-Auth-Token": str(TEST_TOKEN.token)} + ) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") def test_template_encrypted_params_without_auth(self): resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 401) - self.assertEqual(resp.json['faultstring'], - 'Unauthorized - One of Token or API key required.') + self.assertEqual( + resp.json["faultstring"], "Unauthorized - One of Token or API key required." + ) diff --git a/st2api/tests/unit/controllers/v1/test_executions_descendants.py b/st2api/tests/unit/controllers/v1/test_executions_descendants.py index 1afbcdde2f..945e03feeb 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_descendants.py +++ b/st2api/tests/unit/controllers/v1/test_executions_descendants.py @@ -19,64 +19,85 @@ from st2tests.api import FunctionalTest -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class ActionExecutionControllerTestCaseDescendantsTest(FunctionalTest): - @classmethod def setUpClass(cls): super(ActionExecutionControllerTestCaseDescendantsTest, cls).setUpClass() - cls.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + cls.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_get_all_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get("/v1/executions/%s/children" % str(root_execution.id)) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_all_descendants_depth_neg_1(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children?depth=-1' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get( + "/v1/executions/%s/children?depth=-1" % str(root_execution.id) + ) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_1_level_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children?depth=1' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get( + "/v1/executions/%s/children?depth=1" % str(root_execution.id) + ) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # All children of root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.parent == str(root_execution.id)] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.parent == str(root_execution.id) + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) diff --git a/st2api/tests/unit/controllers/v1/test_executions_filters.py b/st2api/tests/unit/controllers/v1/test_executions_filters.py index e33e8bf87d..af451ca519 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_filters.py +++ b/st2api/tests/unit/controllers/v1/test_executions_filters.py @@ -22,6 +22,7 @@ from six.moves import http_client import st2tests.config as tests_config + tests_config.parse_args() from st2tests.api import FunctionalTest @@ -36,7 +37,6 @@ class TestActionExecutionFilters(FunctionalTest): - @classmethod def testDownClass(cls): pass @@ -52,29 +52,33 @@ def setUpClass(cls): cls.start_timestamps = [] cls.fake_types = [ { - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']), - 'context': copy.deepcopy(fixture.ARTIFACTS['context']), - 'children': [] + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger_instance": copy.deepcopy( + fixture.ARTIFACTS["trigger_instance"] + ), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]), + "liveaction": copy.deepcopy( + fixture.ARTIFACTS["liveactions"]["workflow"] + ), + "context": copy.deepcopy(fixture.ARTIFACTS["context"]), + "children": [], }, { - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1']) - } + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]), + }, ] def assign_parent(child): - candidates = [v for k, v in cls.refs.items() if v.action['name'] == 'chain'] + candidates = [v for k, v in cls.refs.items() if v.action["name"] == "chain"] if candidates: parent = random.choice(candidates) - child['parent'] = str(parent.id) - parent.children.append(child['id']) + child["parent"] = str(parent.id) + parent.children.append(child["id"]) cls.refs[str(parent.id)] = ActionExecution.add_or_update(parent) for i in range(cls.num_records): @@ -82,12 +86,12 @@ def assign_parent(child): timestamp = cls.dt_base + datetime.timedelta(seconds=i) fake_type = random.choice(cls.fake_types) data = copy.deepcopy(fake_type) - data['id'] = obj_id - data['start_timestamp'] = isotime.format(timestamp, offset=False) - data['end_timestamp'] = isotime.format(timestamp, offset=False) - data['status'] = data['liveaction']['status'] - data['result'] = data['liveaction']['result'] - if fake_type['action']['name'] == 'local' and random.choice([True, False]): + data["id"] = obj_id + data["start_timestamp"] = isotime.format(timestamp, offset=False) + data["end_timestamp"] = isotime.format(timestamp, offset=False) + data["status"] = data["liveaction"]["status"] + data["result"] = data["liveaction"]["result"] + if fake_type["action"]["name"] == "local" and random.choice([True, False]): assign_parent(data) wb_obj = ActionExecutionAPI(**data) db_obj = ActionExecutionAPI.to_model(wb_obj) @@ -97,154 +101,185 @@ def assign_parent(child): cls.start_timestamps = sorted(cls.start_timestamps) def test_get_all(self): - response = self.app.get('/v1/executions') + response = self.app.get("/v1/executions") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), self.num_records) - self.assertEqual(response.headers['X-Total-Count'], str(self.num_records)) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(self.num_records)) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(self.refs.keys())) def test_get_all_exclude_attributes(self): # No attributes excluded - response = self.app.get('/v1/executions?action=executions.local&limit=1') + response = self.app.get("/v1/executions?action=executions.local&limit=1") self.assertEqual(response.status_int, 200) - self.assertIn('result', response.json[0]) + self.assertIn("result", response.json[0]) # Exclude "result" attribute - path = '/v1/executions?action=executions.local&limit=1&exclude_attributes=result' + path = ( + "/v1/executions?action=executions.local&limit=1&exclude_attributes=result" + ) response = self.app.get(path) self.assertEqual(response.status_int, 200) - self.assertNotIn('result', response.json[0]) + self.assertNotIn("result", response.json[0]) def test_get_one(self): obj_id = random.choice(list(self.refs.keys())) - response = self.app.get('/v1/executions/%s' % obj_id) + response = self.app.get("/v1/executions/%s" % obj_id) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) record = response.json fake_record = ActionExecutionAPI.from_model(self.refs[obj_id]) - self.assertEqual(record['id'], obj_id) - self.assertDictEqual(record['action'], fake_record.action) - self.assertDictEqual(record['runner'], fake_record.runner) - self.assertDictEqual(record['liveaction'], fake_record.liveaction) + self.assertEqual(record["id"], obj_id) + self.assertDictEqual(record["action"], fake_record.action) + self.assertDictEqual(record["runner"], fake_record.runner) + self.assertDictEqual(record["liveaction"], fake_record.liveaction) def test_get_one_failed(self): - response = self.app.get('/v1/executions/%s' % str(bson.ObjectId()), - expect_errors=True) + response = self.app.get( + "/v1/executions/%s" % str(bson.ObjectId()), expect_errors=True + ) self.assertEqual(response.status_int, http_client.NOT_FOUND) def test_limit(self): limit = 10 - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % - limit) + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), limit) - self.assertEqual(response.headers['X-Limit'], str(limit)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Limit"], str(limit)) + self.assertEqual( + response.headers["X-Total-Count"], str(len(refs)), response.json + ) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(refs)), []) def test_limit_minus_one(self): limit = -1 - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit) + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json) - ids = [item['id'] for item in response.json] + self.assertEqual( + response.headers["X-Total-Count"], str(len(refs)), response.json + ) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(refs)), []) def test_limit_negative(self): limit = -22 - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit, - expect_errors=True) + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit, + expect_errors=True, + ) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + response.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_query(self): - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain') + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get("/v1/executions?action=executions.chain") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(refs))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(refs)) def test_filters(self): - excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt', - 'timestamp_lt', 'status'] + excludes = [ + "parent", + "timestamp", + "action", + "liveaction", + "timestamp_gt", + "timestamp_lt", + "status", + ] for param, field in six.iteritems(ActionExecutionsController.supported_filters): if param in excludes: continue value = self.fake_types[0] - for item in field.split('.'): + for item in field.split("."): value = value[item] - response = self.app.get('/v1/executions?%s=%s' % (param, value)) + response = self.app.get("/v1/executions?%s=%s" % (param, value)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertGreater(len(response.json), 0) - self.assertGreater(int(response.headers['X-Total-Count']), 0) + self.assertGreater(int(response.headers["X-Total-Count"]), 0) def test_advanced_filters(self): - excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt', - 'timestamp_lt', 'status'] + excludes = [ + "parent", + "timestamp", + "action", + "liveaction", + "timestamp_gt", + "timestamp_lt", + "status", + ] for param, field in six.iteritems(ActionExecutionsController.supported_filters): if param in excludes: continue value = self.fake_types[0] - for item in field.split('.'): + for item in field.split("."): value = value[item] - response = self.app.get('/v1/executions?filter=%s:%s' % (field, value)) + response = self.app.get("/v1/executions?filter=%s:%s" % (field, value)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertGreater(len(response.json), 0) - self.assertGreater(int(response.headers['X-Total-Count']), 0) + self.assertGreater(int(response.headers["X-Total-Count"]), 0) def test_advanced_filters_malformed(self): - response = self.app.get('/v1/executions?filter=a:b,c:d', expect_errors=True) + response = self.app.get("/v1/executions?filter=a:b,c:d", expect_errors=True) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json, { - "faultstring": "Cannot resolve field \"a\"" - }) - response = self.app.get('/v1/executions?filter=action.ref', expect_errors=True) + self.assertEqual(response.json, {"faultstring": 'Cannot resolve field "a"'}) + response = self.app.get("/v1/executions?filter=action.ref", expect_errors=True) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json, { - "faultstring": "invalid format for filter \"action.ref\"" - }) + self.assertEqual( + response.json, {"faultstring": 'invalid format for filter "action.ref"'} + ) def test_parent(self): - refs = [v for k, v in six.iteritems(self.refs) - if v.action['name'] == 'chain' and v.children] + refs = [ + v + for k, v in six.iteritems(self.refs) + if v.action["name"] == "chain" and v.children + ] self.assertTrue(refs) ref = random.choice(refs) - response = self.app.get('/v1/executions?parent=%s' % str(ref.id)) + response = self.app.get("/v1/executions?parent=%s" % str(ref.id)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(ref.children)) - self.assertEqual(response.headers['X-Total-Count'], str(len(ref.children))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(ref.children))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(ref.children)) def test_parentless(self): - refs = {k: v for k, v in six.iteritems(self.refs) if not getattr(v, 'parent', None)} + refs = { + k: v for k, v in six.iteritems(self.refs) if not getattr(v, "parent", None) + } self.assertTrue(refs) self.assertNotEqual(len(refs), self.num_records) - response = self.app.get('/v1/executions?parent=null') + response = self.app.get("/v1/executions?parent=null") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(refs))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(refs.keys())) def test_pagination(self): @@ -253,14 +288,15 @@ def test_pagination(self): page_count = int(self.num_records / page_size) for i in range(page_count): offset = i * page_size - response = self.app.get('/v1/executions?offset=%s&limit=%s' % ( - offset, page_size)) + response = self.app.get( + "/v1/executions?offset=%s&limit=%s" % (offset, page_size) + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), page_size) - self.assertEqual(response.headers['X-Limit'], str(page_size)) - self.assertEqual(response.headers['X-Total-Count'], str(self.num_records)) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Limit"], str(page_size)) + self.assertEqual(response.headers["X-Total-Count"], str(self.num_records)) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(self.refs.keys())), []) self.assertListEqual(sorted(list(set(ids) - set(retrieved))), sorted(ids)) retrieved += ids @@ -270,60 +306,62 @@ def test_ui_history_query(self): # In this test we only care about making sure this exact query works. This query is used # by the webui for the history page so it is special and breaking this is bad. limit = 50 - history_query = '/v1/executions?limit={}&parent=null&exclude_attributes=' \ - 'result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&' \ - 'offset=0'.format(limit) + history_query = ( + "/v1/executions?limit={}&parent=null&exclude_attributes=" + "result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&" + "offset=0".format(limit) + ) response = self.app.get(history_query) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), limit) - self.assertTrue(int(response.headers['X-Total-Count']) > limit) + self.assertTrue(int(response.headers["X-Total-Count"]) > limit) def test_datetime_range(self): - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' - response = self.app.get('/v1/executions?timestamp=%s' % dt_range) + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" + response = self.app.get("/v1/executions?timestamp=%s" % dt_range) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), 10) - self.assertEqual(response.headers['X-Total-Count'], '10') + self.assertEqual(response.headers["X-Total-Count"], "10") - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[9]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[9]["start_timestamp"] self.assertLess(isotime.parse(dt1), isotime.parse(dt2)) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' - response = self.app.get('/v1/executions?timestamp=%s' % dt_range) + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" + response = self.app.get("/v1/executions?timestamp=%s" % dt_range) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), 10) - self.assertEqual(response.headers['X-Total-Count'], '10') - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[9]['start_timestamp'] + self.assertEqual(response.headers["X-Total-Count"], "10") + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[9]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_default_sort(self): - response = self.app.get('/v1/executions') + response = self.app.get("/v1/executions") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_ascending_sort(self): - response = self.app.get('/v1/executions?sort_asc=True') + response = self.app.get("/v1/executions?sort_asc=True") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt1), isotime.parse(dt2)) def test_descending_sort(self): - response = self.app.get('/v1/executions?sort_desc=True') + response = self.app.get("/v1/executions?sort_desc=True") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_timestamp_lt_and_gt_filter(self): @@ -335,57 +373,81 @@ def isoformat(timestamp): # Last (largest) timestamp, there are no executions with a greater timestamp timestamp = self.start_timestamps[-1] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 0) # First (smallest) timestamp, there are no executions with a smaller timestamp timestamp = self.start_timestamps[0] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 0) # Second last, there should be one timestamp greater than it timestamp = self.start_timestamps[-2] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 1) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp) # Second one, there should be one timestamp smaller than it timestamp = self.start_timestamps[1] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 1) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp) # Half of the timestamps should be smaller index = (len(self.start_timestamps) - 1) // 2 timestamp = self.start_timestamps[index] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), index) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp) # Half of the timestamps should be greater index = (len(self.start_timestamps) - 1) // 2 timestamp = self.start_timestamps[-index] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), (index - 1)) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp) # Both, lt and gt filters, should return exactly two results timestamp_gt = self.start_timestamps[10] timestamp_lt = self.start_timestamps[13] - response = self.app.get('/v1/executions?timestamp_gt=%s×tamp_lt=%s' % - (isoformat(timestamp_gt), isoformat(timestamp_lt))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s×tamp_lt=%s" + % (isoformat(timestamp_gt), isoformat(timestamp_lt)) + ) self.assertEqual(len(response.json), 2) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp_gt) - self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) > timestamp_gt) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp_lt) - self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) < timestamp_lt) + self.assertTrue( + isotime.parse(response.json[0]["start_timestamp"]) > timestamp_gt + ) + self.assertTrue( + isotime.parse(response.json[1]["start_timestamp"]) > timestamp_gt + ) + self.assertTrue( + isotime.parse(response.json[0]["start_timestamp"]) < timestamp_lt + ) + self.assertTrue( + isotime.parse(response.json[1]["start_timestamp"]) < timestamp_lt + ) def test_filters_view(self): - response = self.app.get('/v1/executions/views/filters') + response = self.app.get("/v1/executions/views/filters") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) - self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['default'])) - for key, value in six.iteritems(history_views.ARTIFACTS['filters']['default']): + self.assertEqual( + len(response.json), len(history_views.ARTIFACTS["filters"]["default"]) + ) + for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["default"]): filter_values = response.json[key] # Verify empty (None / null) filters are excluded @@ -399,9 +461,13 @@ def test_filters_view(self): self.assertEqual(set(filter_values), set(value)) def test_filters_view_specific_types(self): - response = self.app.get('/v1/executions/views/filters?types=action,user,nonexistent') + response = self.app.get( + "/v1/executions/views/filters?types=action,user,nonexistent" + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) - self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['specific'])) - for key, value in six.iteritems(history_views.ARTIFACTS['filters']['specific']): + self.assertEqual( + len(response.json), len(history_views.ARTIFACTS["filters"]["specific"]) + ) + for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["specific"]): self.assertEqual(set(response.json[key]), set(value)) diff --git a/st2api/tests/unit/controllers/v1/test_inquiries.py b/st2api/tests/unit/controllers/v1/test_inquiries.py index 469b1c8816..173acbe405 100644 --- a/st2api/tests/unit/controllers/v1/test_inquiries.py +++ b/st2api/tests/unit/controllers/v1/test_inquiries.py @@ -36,58 +36,50 @@ ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'testpack', - 'runner_type': 'local-shell-cmd', + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "testpack", + "runner_type": "local-shell-cmd", } LIVE_ACTION_1 = { - 'action': 'testpack.st2.dummy.action1', - 'parameters': { - 'cmd': 'uname -a' - } + "action": "testpack.st2.dummy.action1", + "parameters": {"cmd": "uname -a"}, } INQUIRY_ACTION = { - 'name': 'st2.dummy.ask', - 'description': 'test description', - 'enabled': True, - 'pack': 'testpack', - 'runner_type': 'inquirer', + "name": "st2.dummy.ask", + "description": "test description", + "enabled": True, + "pack": "testpack", + "runner_type": "inquirer", } INQUIRY_1 = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'pending', - 'parameters': {}, - 'context': { - 'parent': { - 'user': 'testu', - 'execution_id': '59b845e132ed350d396a798f', - 'pack': 'examples' + "action": "testpack.st2.dummy.ask", + "status": "pending", + "parameters": {}, + "context": { + "parent": { + "user": "testu", + "execution_id": "59b845e132ed350d396a798f", + "pack": "examples", }, - 'trace_context': {'trace_tag': 'balleilaka'} - } + "trace_context": {"trace_tag": "balleilaka"}, + }, } INQUIRY_2 = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'pending', - 'parameters': { - 'route': 'superlative', - 'users': ['foo', 'bar'] - } + "action": "testpack.st2.dummy.ask", + "status": "pending", + "parameters": {"route": "superlative", "users": ["foo", "bar"]}, } INQUIRY_TIMEOUT = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'timeout', - 'parameters': { - 'route': 'superlative', - 'users': ['foo', 'bar'] - } + "action": "testpack.st2.dummy.ask", + "status": "timeout", + "parameters": {"route": "superlative", "users": ["foo", "bar"]}, } SCHEMA_DEFAULT = { @@ -97,7 +89,7 @@ "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, } @@ -109,18 +101,18 @@ "name": { "type": "string", "description": "What is your name?", - "required": True + "required": True, }, "pin": { "type": "integer", "description": "What is your PIN?", - "required": True + "required": True, }, "paradox": { "type": "boolean", "description": "This statement is False.", - "required": True - } + "required": True, + }, }, } @@ -132,7 +124,7 @@ "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } RESULT_2 = { @@ -140,7 +132,7 @@ "roles": [], "users": ["foo", "bar"], "route": "superlative", - "ttl": 1440 + "ttl": 1440, } RESULT_MULTIPLE = { @@ -148,58 +140,51 @@ "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } -RESPONSE_MULTIPLE = { - "name": "matt", - "pin": 1234, - "paradox": True -} +RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True} ROOT_LIVEACTION_DB = lv_db_models.LiveActionDB( - id=uuid.uuid4().hex, - status=action_constants.LIVEACTION_STATUS_PAUSED + id=uuid.uuid4().hex, status=action_constants.LIVEACTION_STATUS_PAUSED ) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class InquiryControllerTestCase(BaseInquiryControllerTestCase, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/inquiries' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class InquiryControllerTestCase( + BaseInquiryControllerTestCase, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/inquiries" controller_cls = InquiriesController - include_attribute_field_name = 'ttl' - exclude_attribute_field_name = 'ttl' + include_attribute_field_name = "ttl" + exclude_attribute_field_name = "ttl" @mock.patch.object( - action_validator, - 'validate_action', - mock.MagicMock(return_value=True)) + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUp(cls): super(BaseInquiryControllerTestCase, cls).setUpClass() cls.inquiry1 = copy.deepcopy(INQUIRY_ACTION) - post_resp = cls.app.post_json('/v1/actions', cls.inquiry1) - cls.inquiry1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.inquiry1) + cls.inquiry1["id"] = post_resp.json["id"] cls.action1 = copy.deepcopy(ACTION_1) - post_resp = cls.app.post_json('/v1/actions', cls.action1) - cls.action1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action1) + cls.action1["id"] = post_resp.json["id"] def test_get_all(self): - """Test retrieval of a list of Inquiries - """ + """Test retrieval of a list of Inquiries""" inquiry_count = 5 for i in range(inquiry_count): self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) get_all_resp = self._do_get_all() inquiries = get_all_resp.json - self.assertEqual(get_all_resp.headers['X-Total-Count'], str(len(inquiries))) + self.assertEqual(get_all_resp.headers["X-Total-Count"], str(len(inquiries))) self.assertIsInstance(inquiries, list) self.assertEqual(len(inquiries), inquiry_count) def test_get_all_empty(self): - """Test retrieval of a list of Inquiries when there are none - """ + """Test retrieval of a list of Inquiries when there are none""" inquiry_count = 0 get_all_resp = self._do_get_all() inquiries = get_all_resp.json @@ -207,8 +192,7 @@ def test_get_all_empty(self): self.assertEqual(len(inquiries), inquiry_count) def test_get_all_decrease_after_respond(self): - """Test that the inquiry list decreases when we respond to one of them - """ + """Test that the inquiry list decreases when we respond to one of them""" # Create inquiries inquiry_count = 5 @@ -221,7 +205,7 @@ def test_get_all_decrease_after_respond(self): # Respond to one of them response = {"continue": True} - self._do_respond(inquiries[0].get('id'), response) + self._do_respond(inquiries[0].get("id"), response) # Ensure the list is one smaller get_all_resp = self._do_get_all() @@ -230,8 +214,7 @@ def test_get_all_decrease_after_respond(self): self.assertEqual(len(inquiries), inquiry_count - 1) def test_get_all_limit(self): - """Test that the limit parameter works correctly - """ + """Test that the limit parameter works correctly""" # Create inquiries inquiry_count = 5 @@ -241,12 +224,11 @@ def test_get_all_limit(self): get_all_resp = self._do_get_all(limit=limit) inquiries = get_all_resp.json self.assertIsInstance(inquiries, list) - self.assertEqual(inquiry_count, int(get_all_resp.headers['X-Total-Count'])) + self.assertEqual(inquiry_count, int(get_all_resp.headers["X-Total-Count"])) self.assertEqual(len(inquiries), limit) def test_get_one(self): - """Test retrieval of a single Inquiry - """ + """Test retrieval of a single Inquiry""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) get_resp = self._do_get_one(inquiry_id) @@ -254,24 +236,21 @@ def test_get_one(self): self.assertEqual(self._get_inquiry_id(get_resp), inquiry_id) def test_get_one_failed(self): - """Test failed retrieval of an Inquiry - """ - inquiry_id = 'asdfeoijasdf' + """Test failed retrieval of an Inquiry""" + inquiry_id = "asdfeoijasdf" get_resp = self._do_get_one(inquiry_id, expect_errors=True) self.assertEqual(get_resp.status_int, http_client.NOT_FOUND) - self.assertIn('resource could not be found', get_resp.json['faultstring']) + self.assertIn("resource could not be found", get_resp.json["faultstring"]) def test_get_one_not_an_inquiry(self): - """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails - """ - test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body) - get_resp = self._do_get_one(test_exec.get('id'), expect_errors=True) + """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails""" + test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body) + get_resp = self._do_get_one(test_exec.get("id"), expect_errors=True) self.assertEqual(get_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('is not an inquiry', get_resp.json['faultstring']) + self.assertIn("is not an inquiry", get_resp.json["faultstring"]) def test_get_one_nondefault_params(self): - """Ensure an Inquiry with custom parameters contains those in result - """ + """Ensure an Inquiry with custom parameters contains those in result""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_2) inquiry_id = self._get_inquiry_id(post_resp) get_resp = self._do_get_one(inquiry_id) @@ -282,14 +261,15 @@ def test_get_one_nondefault_params(self): self.assertEqual(get_resp.json.get(param), RESULT_2.get(param)) @mock.patch.object( - action_service, 'get_root_liveaction', - mock.MagicMock(return_value=ROOT_LIVEACTION_DB)) + action_service, + "get_root_liveaction", + mock.MagicMock(return_value=ROOT_LIVEACTION_DB), + ) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond(self): - """Test that a correct response is successful - """ + """Test that a correct response is successful""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -300,21 +280,22 @@ def test_respond(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # This Inquiry is in a workflow, so has a parent. Assert that the resume # was requested for this parent. action_service.request_resume.assert_called_once() @mock.patch.object( - action_service, 'get_root_liveaction', - mock.MagicMock(return_value=ROOT_LIVEACTION_DB)) + action_service, + "get_root_liveaction", + mock.MagicMock(return_value=ROOT_LIVEACTION_DB), + ) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond_multiple(self): - """Test that a more complicated response is successful - """ + """Test that a more complicated response is successful""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_MULTIPLE) inquiry_id = self._get_inquiry_id(post_resp) @@ -324,38 +305,35 @@ def test_respond_multiple(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # This Inquiry is in a workflow, so has a parent. Assert that the resume # was requested for this parent. action_service.request_resume.assert_called_once() def test_respond_fail(self): - """Test that an incorrect response is unsuccessful - """ + """Test that an incorrect response is unsuccessful""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) response = {"continue": 123} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('did not pass schema validation', put_resp.json['faultstring']) + self.assertIn("did not pass schema validation", put_resp.json["faultstring"]) def test_respond_not_an_inquiry(self): - """Test that attempts to respond to an execution ID that isn't an Inquiry fails - """ - test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body) + """Test that attempts to respond to an execution ID that isn't an Inquiry fails""" + test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body) response = {"continue": 123} - put_resp = self._do_respond(test_exec.get('id'), response, expect_errors=True) + put_resp = self._do_respond(test_exec.get("id"), response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('is not an inquiry', put_resp.json['faultstring']) + self.assertIn("is not an inquiry", put_resp.json["faultstring"]) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond_no_parent(self): - """Test that a resume was not requested for an Inquiry without a parent - """ + """Test that a resume was not requested for an Inquiry without a parent""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -365,8 +343,7 @@ def test_respond_no_parent(self): action_service.request_resume.assert_not_called() def test_respond_duplicate_rejected(self): - """Test that responding to an already-responded Inquiry fails - """ + """Test that responding to an already-responded Inquiry fails""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -377,28 +354,30 @@ def test_respond_duplicate_rejected(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # A second, equivalent response attempt should not succeed, since the Inquiry # has already been successfully responded to put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('has already been responded to', put_resp.json['faultstring']) + self.assertIn("has already been responded to", put_resp.json["faultstring"]) def test_respond_timeout_rejected(self): - """Test that responding to a timed-out Inquiry fails - """ + """Test that responding to a timed-out Inquiry fails""" - post_resp = self._do_create_inquiry(INQUIRY_TIMEOUT, RESULT_DEFAULT, status='timeout') + post_resp = self._do_create_inquiry( + INQUIRY_TIMEOUT, RESULT_DEFAULT, status="timeout" + ) inquiry_id = self._get_inquiry_id(post_resp) response = {"continue": True} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('timed out and cannot be responded to', put_resp.json['faultstring']) + self.assertIn( + "timed out and cannot be responded to", put_resp.json["faultstring"] + ) def test_respond_restrict_users(self): - """Test that Inquiries can reject responses from users not in a list - """ + """Test that Inquiries can reject responses from users not in a list""" # Default user for tests is "stanley", which is not in the 'users' list # Should be rejected @@ -407,7 +386,9 @@ def test_respond_restrict_users(self): response = {"continue": True} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.FORBIDDEN) - self.assertIn('does not have permission to respond', put_resp.json['faultstring']) + self.assertIn( + "does not have permission to respond", put_resp.json["faultstring"] + ) # Responding as a use in the list should be accepted old_user = cfg.CONF.system_user.user @@ -425,8 +406,8 @@ def test_get_all_invalid_exclude_and_include_parameter(self): pass def _insert_mock_models(self): - id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id'] - id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id'] + id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"] + id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"] return [id_1, id_2] diff --git a/st2api/tests/unit/controllers/v1/test_kvps.py b/st2api/tests/unit/controllers/v1/test_kvps.py index 06103134bd..61a903a3ad 100644 --- a/st2api/tests/unit/controllers/v1/test_kvps.py +++ b/st2api/tests/unit/controllers/v1/test_kvps.py @@ -21,83 +21,66 @@ from six.moves import http_client -__all__ = [ - 'KeyValuePairControllerTestCase' -] +__all__ = ["KeyValuePairControllerTestCase"] -KVP = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.0.1:5000/v3' -} +KVP = {"name": "keystone_endpoint", "value": "http://127.0.0.1:5000/v3"} -KVP_2 = { - 'name': 'keystone_version', - 'value': 'v3' -} +KVP_2 = {"name": "keystone_version", "value": "v3"} -KVP_2_USER = { - 'name': 'keystone_version', - 'value': 'user_v3', - 'scope': 'st2kv.user' -} +KVP_2_USER = {"name": "keystone_version", "value": "user_v3", "scope": "st2kv.user"} -KVP_2_USER_LEGACY = { - 'name': 'keystone_version', - 'value': 'user_v3', - 'scope': 'user' -} +KVP_2_USER_LEGACY = {"name": "keystone_version", "value": "user_v3", "scope": "user"} KVP_3_USER = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.1.1:5000/v3', - 'scope': 'st2kv.user' + "name": "keystone_endpoint", + "value": "http://127.0.1.1:5000/v3", + "scope": "st2kv.user", } KVP_4_USER = { - 'name': 'customer_ssn', - 'value': '123-456-7890', - 'secret': True, - 'scope': 'st2kv.user' + "name": "customer_ssn", + "value": "123-456-7890", + "secret": True, + "scope": "st2kv.user", } KVP_WITH_TTL = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.0.1:5000/v3', - 'ttl': 10 + "name": "keystone_endpoint", + "value": "http://127.0.0.1:5000/v3", + "ttl": 10, } -SECRET_KVP = { - 'name': 'secret_key1', - 'value': 'secret_value1', - 'secret': True -} +SECRET_KVP = {"name": "secret_key1", "value": "secret_value1", "secret": True} # value = S3cret!Value # encrypted with st2tests/conf/st2_kvstore_tests.crypto.key.json ENCRYPTED_KVP = { - 'name': 'secret_key1', - 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E' - 'B30170DACF79498F30520236A629912C3584847098D'), - 'encrypted': True + "name": "secret_key1", + "value": ( + "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E" + "B30170DACF79498F30520236A629912C3584847098D" + ), + "encrypted": True, } ENCRYPTED_KVP_SECRET_FALSE = { - 'name': 'secret_key2', - 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E' - 'B30170DACF79498F30520236A629912C3584847098D'), - 'secret': True, - 'encrypted': True + "name": "secret_key2", + "value": ( + "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E" + "B30170DACF79498F30520236A629912C3584847098D" + ), + "secret": True, + "encrypted": True, } class KeyValuePairControllerTestCase(FunctionalTest): - def test_get_all(self): - resp = self.app.get('/v1/keys') + resp = self.app.get("/v1/keys") self.assertEqual(resp.status_int, 200) def test_get_one(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) kvp_id = self.__get_kvp_id(put_resp) get_resp = self.__do_get_one(kvp_id) self.assertEqual(get_resp.status_int, 200) @@ -107,484 +90,534 @@ def test_get_one(self): def test_get_all_all_scope(self): # Test which cases various scenarios which ensure non-admin users can't read / view keys # from other users - user_db_1 = UserDB(name='user1') - user_db_2 = UserDB(name='user2') - user_db_3 = UserDB(name='user3') + user_db_1 = UserDB(name="user1") + user_db_2 = UserDB(name="user2") + user_db_3 = UserDB(name="user3") # Insert some mock data # System scoped keys - put_resp = self.__do_put('system1', {'name': 'system1', 'value': 'val1', - 'scope': 'st2kv.system'}) + put_resp = self.__do_put( + "system1", {"name": "system1", "value": "val1", "scope": "st2kv.system"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'system1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') + self.assertEqual(put_resp.json["name"], "system1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") - put_resp = self.__do_put('system2', {'name': 'system2', 'value': 'val2', - 'scope': 'st2kv.system'}) + put_resp = self.__do_put( + "system2", {"name": "system2", "value": "val2", "scope": "st2kv.system"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'system2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') + self.assertEqual(put_resp.json["name"], "system2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") # user1 scoped keys self.use_user(user_db_1) - put_resp = self.__do_put('user1', {'name': 'user1', 'value': 'user1', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user1", {"name": "user1", "value": "user1", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user1') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user1') + self.assertEqual(put_resp.json["name"], "user1") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user1") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user1', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user1", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user1') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user1") # user2 scoped keys self.use_user(user_db_2) - put_resp = self.__do_put('user2', {'name': 'user2', 'value': 'user2', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user2", {"name": "user2", "value": "user2", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user2') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user2') + self.assertEqual(put_resp.json["name"], "user2") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user2") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user2', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user2", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user2') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user2") # user3 scoped keys self.use_user(user_db_3) - put_resp = self.__do_put('user3', {'name': 'user3', 'value': 'user3', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user3", {"name": "user3", "value": "user3", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user3') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user3') + self.assertEqual(put_resp.json["name"], "user3") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user3") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user3', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user3", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user3') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user3") # 1. "all" scope as user1 - should only be able to view system + current user items self.use_user(user_db_1) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user1') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user1') + self.assertEqual(resp.json[2]["name"], "user1") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user1") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user1') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user1") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user2:') + resp = self.app.get("/v1/keys?scope=all&prefix=user2:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user1') + self.assertEqual(resp.json[0]["name"], "user1") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user1") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user1') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user1") # 2. "all" scope user user2 - should only be able to view system + current user items self.use_user(user_db_2) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user2') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user2') + self.assertEqual(resp.json[2]["name"], "user2") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user2") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user2') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user2") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user1:') + resp = self.app.get("/v1/keys?scope=all&prefix=user1:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user2') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user2') + self.assertEqual(resp.json[0]["name"], "user2") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user2") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user2') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user2") # Verify non-admon user can't retrieve key for an arbitrary users - resp = self.app.get('/v1/keys?scope=user&user=user1', expect_errors=True) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + resp = self.app.get("/v1/keys?scope=user&user=user1", expect_errors=True) + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) # 3. "all" scope user user3 - should only be able to view system + current user items self.use_user(user_db_3) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user3') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user3') + self.assertEqual(resp.json[2]["name"], "user3") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user3") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user3') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user3") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user1:') + resp = self.app.get("/v1/keys?scope=all&prefix=user1:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user3') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user3') + self.assertEqual(resp.json[0]["name"], "user3") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user3") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user3') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user3") # Clean up - self.__do_delete('system1') - self.__do_delete('system2') + self.__do_delete("system1") + self.__do_delete("system2") self.use_user(user_db_1) - self.__do_delete('user1?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user1?scope=user") + self.__do_delete("userkey?scope=user") self.use_user(user_db_2) - self.__do_delete('user2?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user2?scope=user") + self.__do_delete("userkey?scope=user") self.use_user(user_db_3) - self.__do_delete('user3?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user3?scope=user") + self.__do_delete("userkey?scope=user") def test_get_all_user_query_param_can_only_be_used_with_rbac(self): - resp = self.app.get('/v1/keys?user=foousera', expect_errors=True) + resp = self.app.get("/v1/keys?user=foousera", expect_errors=True) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) def test_get_one_user_query_param_can_only_be_used_with_rbac(self): - resp = self.app.get('/v1/keys/keystone_endpoint?user=foousera', expect_errors=True) + resp = self.app.get( + "/v1/keys/keystone_endpoint?user=foousera", expect_errors=True + ) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) def test_get_all_prefix_filtering(self): - put_resp1 = self.__do_put(KVP['name'], KVP) - put_resp2 = self.__do_put(KVP_2['name'], KVP_2) + put_resp1 = self.__do_put(KVP["name"], KVP) + put_resp2 = self.__do_put(KVP_2["name"], KVP_2) self.assertEqual(put_resp1.status_int, 200) self.assertEqual(put_resp2.status_int, 200) # No keys with that prefix - resp = self.app.get('/v1/keys?prefix=something') + resp = self.app.get("/v1/keys?prefix=something") self.assertEqual(resp.json, []) # Two keys with the provided prefix - resp = self.app.get('/v1/keys?prefix=keystone') + resp = self.app.get("/v1/keys?prefix=keystone") self.assertEqual(len(resp.json), 2) # One key with the provided prefix - resp = self.app.get('/v1/keys?prefix=keystone_endpoint') + resp = self.app.get("/v1/keys?prefix=keystone_endpoint") self.assertEqual(len(resp.json), 1) self.__do_delete(self.__get_kvp_id(put_resp1)) self.__do_delete(self.__get_kvp_id(put_resp2)) def test_get_one_fail(self): - resp = self.app.get('/v1/keys/1', expect_errors=True) + resp = self.app.get("/v1/keys/1", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_put(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) update_input = put_resp.json - update_input['value'] = 'http://127.0.0.1:35357/v3' + update_input["value"] = "http://127.0.0.1:35357/v3" put_resp = self.__do_put(self.__get_kvp_id(put_resp), update_input) self.assertEqual(put_resp.status_int, 200) self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_with_scope(self): - self.app.put_json('/v1/keys/%s' % 'keystone_endpoint', KVP, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - - get_resp_1 = self.app.get('/v1/keys/keystone_endpoint') + self.app.put_json("/v1/keys/%s" % "keystone_endpoint", KVP, expect_errors=False) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + + get_resp_1 = self.app.get("/v1/keys/keystone_endpoint") self.assertTrue(get_resp_1.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_1), 'keystone_endpoint') - get_resp_2 = self.app.get('/v1/keys/keystone_version?scope=st2kv.system') + self.assertEqual(self.__get_kvp_id(get_resp_1), "keystone_endpoint") + get_resp_2 = self.app.get("/v1/keys/keystone_version?scope=st2kv.system") self.assertTrue(get_resp_2.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_2), 'keystone_version') - get_resp_3 = self.app.get('/v1/keys/keystone_version') + self.assertEqual(self.__get_kvp_id(get_resp_2), "keystone_version") + get_resp_3 = self.app.get("/v1/keys/keystone_version") self.assertTrue(get_resp_3.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_3), 'keystone_version') - self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') + self.assertEqual(self.__get_kvp_id(get_resp_3), "keystone_version") + self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") def test_put_user_scope_and_system_scope_dont_overlap(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) - get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.system') - self.assertEqual(get_resp.json['value'], KVP_2['value']) - - get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.user') - self.assertEqual(get_resp.json['value'], KVP_2_USER['value']) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) + get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.system") + self.assertEqual(get_resp.json["value"], KVP_2["value"]) + + get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.user") + self.assertEqual(get_resp.json["value"], KVP_2_USER["value"]) + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") def test_put_invalid_scope(self): - put_resp = self.app.put_json('/v1/keys/keystone_version?scope=st2', KVP_2, - expect_errors=True) + put_resp = self.app.put_json( + "/v1/keys/keystone_version?scope=st2", KVP_2, expect_errors=True + ) self.assertTrue(put_resp.status_int, 400) def test_get_all_with_scope(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) # Note that the following two calls overwrite st2sytem and st2kv.user scoped variables with # same name. - self.app.put_json('/v1/keys/%s?scope=system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=user' % 'keystone_version', KVP_2_USER_LEGACY, - expect_errors=False) - - get_resp_all = self.app.get('/v1/keys?scope=all') + self.app.put_json( + "/v1/keys/%s?scope=system" % "keystone_version", KVP_2, expect_errors=False + ) + self.app.put_json( + "/v1/keys/%s?scope=user" % "keystone_version", + KVP_2_USER_LEGACY, + expect_errors=False, + ) + + get_resp_all = self.app.get("/v1/keys?scope=all") self.assertTrue(len(get_resp_all.json), 2) - get_resp_sys = self.app.get('/v1/keys?scope=st2kv.system') + get_resp_sys = self.app.get("/v1/keys?scope=st2kv.system") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=system') + get_resp_sys = self.app.get("/v1/keys?scope=system") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=st2kv.user') + get_resp_sys = self.app.get("/v1/keys?scope=st2kv.user") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=user') + get_resp_sys = self.app.get("/v1/keys?scope=user") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"]) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") def test_get_all_with_scope_and_prefix_filtering(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_endpoint', KVP_3_USER, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'customer_ssn', KVP_4_USER, - expect_errors=False) - get_prefix = self.app.get('/v1/keys?scope=st2kv.user&prefix=keystone') + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_endpoint", + KVP_3_USER, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "customer_ssn", + KVP_4_USER, + expect_errors=False, + ) + get_prefix = self.app.get("/v1/keys?scope=st2kv.user&prefix=keystone") self.assertEqual(len(get_prefix.json), 2) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') - self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.user') - self.app.delete('/v1/keys/customer_ssn?scope=st2kv.user') + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") + self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.user") + self.app.delete("/v1/keys/customer_ssn?scope=st2kv.user") def test_put_with_ttl(self): - put_resp = self.__do_put('key_with_ttl', KVP_WITH_TTL) + put_resp = self.__do_put("key_with_ttl", KVP_WITH_TTL) self.assertEqual(put_resp.status_int, 200) - get_resp = self.app.get('/v1/keys') - self.assertTrue(get_resp.json[0]['expire_timestamp']) + get_resp = self.app.get("/v1/keys") + self.assertTrue(get_resp.json[0]["expire_timestamp"]) self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_secret(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) get_resp = self.__do_get_one(kvp_id) - self.assertTrue(get_resp.json['encrypted']) - crypto_val = get_resp.json['value'] - self.assertNotEqual(SECRET_KVP['value'], crypto_val) + self.assertTrue(get_resp.json["encrypted"]) + crypto_val = get_resp.json["value"] + self.assertNotEqual(SECRET_KVP["value"], crypto_val) self.__do_delete(self.__get_kvp_id(put_resp)) def test_get_one_secret_no_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) - get_resp = self.app.get('/v1/keys/secret_key1') + get_resp = self.app.get("/v1/keys/secret_key1") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_kvp_id(get_resp), kvp_id) - self.assertTrue(get_resp.json['secret']) - self.assertTrue(get_resp.json['encrypted']) + self.assertTrue(get_resp.json["secret"]) + self.assertTrue(get_resp.json["encrypted"]) self.__do_delete(kvp_id) def test_get_one_secret_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) - get_resp = self.app.get('/v1/keys/secret_key1?decrypt=true') + get_resp = self.app.get("/v1/keys/secret_key1?decrypt=true") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_kvp_id(get_resp), kvp_id) - self.assertTrue(get_resp.json['secret']) - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], SECRET_KVP['value']) + self.assertTrue(get_resp.json["secret"]) + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], SECRET_KVP["value"]) self.__do_delete(kvp_id) def test_get_all_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id_1 = self.__get_kvp_id(put_resp) - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) kvp_id_2 = self.__get_kvp_id(put_resp) - kvps = {'key1': KVP, 'secret_key1': SECRET_KVP} - stored_kvps = self.app.get('/v1/keys?decrypt=true').json + kvps = {"key1": KVP, "secret_key1": SECRET_KVP} + stored_kvps = self.app.get("/v1/keys?decrypt=true").json self.assertTrue(len(stored_kvps), 2) for stored_kvp in stored_kvps: - self.assertFalse(stored_kvp['encrypted']) - exp_kvp = kvps.get(stored_kvp['name']) + self.assertFalse(stored_kvp["encrypted"]) + exp_kvp = kvps.get(stored_kvp["name"]) self.assertIsNotNone(exp_kvp) - self.assertEqual(exp_kvp['value'], stored_kvp['value']) + self.assertEqual(exp_kvp["value"], stored_kvp["value"]) self.__do_delete(kvp_id_1) self.__do_delete(kvp_id_2) def test_put_encrypted_value(self): # 1. encrypted=True, secret=True - put_resp = self.__do_put('secret_key1', ENCRYPTED_KVP) + put_resp = self.__do_put("secret_key1", ENCRYPTED_KVP) kvp_id = self.__get_kvp_id(put_resp) # Verify there is no secrets leakage self.assertEqual(put_resp.status_code, 200) - self.assertEqual(put_resp.json['name'], 'secret_key1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) - self.assertTrue(put_resp.json['value'] != 'S3cret!Value') - self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2) - - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertEqual(put_resp.json['name'], 'secret_key1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) + self.assertEqual(put_resp.json["name"], "secret_key1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) + self.assertTrue(put_resp.json["value"] != "S3cret!Value") + self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2) + + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertEqual(put_resp.json["name"], "secret_key1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) # Verify data integrity post decryption - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], 'S3cret!Value') + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], "S3cret!Value") self.__do_delete(self.__get_kvp_id(put_resp)) # 2. encrypted=True, secret=False # encrypted should always imply secret=True - put_resp = self.__do_put('secret_key2', ENCRYPTED_KVP_SECRET_FALSE) + put_resp = self.__do_put("secret_key2", ENCRYPTED_KVP_SECRET_FALSE) kvp_id = self.__get_kvp_id(put_resp) # Verify there is no secrets leakage self.assertEqual(put_resp.status_code, 200) - self.assertEqual(put_resp.json['name'], 'secret_key2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) - self.assertTrue(put_resp.json['value'] != 'S3cret!Value') - self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2) - - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertEqual(put_resp.json['name'], 'secret_key2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) + self.assertEqual(put_resp.json["name"], "secret_key2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) + self.assertTrue(put_resp.json["value"] != "S3cret!Value") + self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2) + + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertEqual(put_resp.json["name"], "secret_key2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) # Verify data integrity post decryption - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], 'S3cret!Value') + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], "S3cret!Value") self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_encrypted_value_integrity_check_failed(self): data = copy.deepcopy(ENCRYPTED_KVP) - data['value'] = 'corrupted' - put_resp = self.__do_put('secret_key1', data, expect_errors=True) + data["value"] = "corrupted" + put_resp = self.__do_put("secret_key1", data, expect_errors=True) self.assertEqual(put_resp.status_code, 400) - expected_error = ('Failed to verify the integrity of the provided value for key ' - '"secret_key1".') - self.assertIn(expected_error, put_resp.json['faultstring']) + expected_error = ( + "Failed to verify the integrity of the provided value for key " + '"secret_key1".' + ) + self.assertIn(expected_error, put_resp.json["faultstring"]) data = copy.deepcopy(ENCRYPTED_KVP) - data['value'] = str(data['value'][:-2]) - put_resp = self.__do_put('secret_key1', data, expect_errors=True) + data["value"] = str(data["value"][:-2]) + put_resp = self.__do_put("secret_key1", data, expect_errors=True) self.assertEqual(put_resp.status_code, 400) - expected_error = ('Failed to verify the integrity of the provided value for key ' - '"secret_key1".') - self.assertIn(expected_error, put_resp.json['faultstring']) + expected_error = ( + "Failed to verify the integrity of the provided value for key " + '"secret_key1".' + ) + self.assertIn(expected_error, put_resp.json["faultstring"]) def test_put_delete(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) self.assertEqual(put_resp.status_int, 200) self.__do_delete(self.__get_kvp_id(put_resp)) def test_delete(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) del_resp = self.__do_delete(self.__get_kvp_id(put_resp)) self.assertEqual(del_resp.status_int, 204) def test_delete_fail(self): - resp = self.__do_delete('inexistentkey', expect_errors=True) + resp = self.__do_delete("inexistentkey", expect_errors=True) self.assertEqual(resp.status_int, 404) @staticmethod def __get_kvp_id(resp): - return resp.json['name'] + return resp.json["name"] def __do_get_one(self, kvp_id, expect_errors=False): - return self.app.get('/v1/keys/%s' % kvp_id, expect_errors=expect_errors) + return self.app.get("/v1/keys/%s" % kvp_id, expect_errors=expect_errors) def __do_put(self, kvp_id, kvp, expect_errors=False): - return self.app.put_json('/v1/keys/%s' % kvp_id, kvp, expect_errors=expect_errors) + return self.app.put_json( + "/v1/keys/%s" % kvp_id, kvp, expect_errors=expect_errors + ) def __do_delete(self, kvp_id, expect_errors=False): - return self.app.delete('/v1/keys/%s' % kvp_id, expect_errors=expect_errors) + return self.app.delete("/v1/keys/%s" % kvp_id, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py index bff5935e38..a38c278f07 100644 --- a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py +++ b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py @@ -19,12 +19,10 @@ from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PackConfigSchemasControllerTestCase' -] +__all__ = ["PackConfigSchemasControllerTestCase"] PACKS_PATH = get_fixtures_packs_base_path() -CONFIG_SCHEMA_COUNT = len(glob.glob('%s/*/config.schema.yaml' % (PACKS_PATH))) +CONFIG_SCHEMA_COUNT = len(glob.glob("%s/*/config.schema.yaml" % (PACKS_PATH))) assert CONFIG_SCHEMA_COUNT > 1 @@ -32,29 +30,34 @@ class PackConfigSchemasControllerTestCase(FunctionalTest): register_packs = True def test_get_all(self): - resp = self.app.get('/v1/config_schemas') + resp = self.app.get("/v1/config_schemas") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), CONFIG_SCHEMA_COUNT, - '/v1/config_schemas did not return all schemas.') + self.assertEqual( + len(resp.json), + CONFIG_SCHEMA_COUNT, + "/v1/config_schemas did not return all schemas.", + ) def test_get_one_success(self): - resp = self.app.get('/v1/config_schemas/dummy_pack_1') + resp = self.app.get("/v1/config_schemas/dummy_pack_1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['pack'], 'dummy_pack_1') - self.assertIn('api_key', resp.json['attributes']) + self.assertEqual(resp.json["pack"], "dummy_pack_1") + self.assertIn("api_key", resp.json["attributes"]) def test_get_one_doesnt_exist(self): # Pack exists, schema doesnt - resp = self.app.get('/v1/config_schemas/dummy_pack_2', - expect_errors=True) + resp = self.app.get("/v1/config_schemas/dummy_pack_2", expect_errors=True) self.assertEqual(resp.status_int, 404) - self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref ", resp.json["faultstring"] + ) # Pack doesn't exist - ref_or_id = 'pack_doesnt_exist' - resp = self.app.get('/v1/config_schemas/%s' % ref_or_id, - expect_errors=True) + ref_or_id = "pack_doesnt_exist" + resp = self.app.get("/v1/config_schemas/%s" % ref_or_id, expect_errors=True) self.assertEqual(resp.status_int, 404) # Changed from: 'Unable to find the PackDB instance' - self.assertTrue('Resource with a ref or id "%s" not found' % ref_or_id in - resp.json['faultstring']) + self.assertTrue( + 'Resource with a ref or id "%s" not found' % ref_or_id + in resp.json["faultstring"] + ) diff --git a/st2api/tests/unit/controllers/v1/test_pack_configs.py b/st2api/tests/unit/controllers/v1/test_pack_configs.py index 6e789c413a..5a87719eaa 100644 --- a/st2api/tests/unit/controllers/v1/test_pack_configs.py +++ b/st2api/tests/unit/controllers/v1/test_pack_configs.py @@ -21,12 +21,10 @@ from st2api.controllers.v1.pack_configs import PackConfigsController from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PackConfigsControllerTestCase' -] +__all__ = ["PackConfigsControllerTestCase"] PACKS_PATH = get_fixtures_packs_base_path() -CONFIGS_COUNT = len(glob.glob('%s/configs/*.yaml' % (PACKS_PATH))) +CONFIGS_COUNT = len(glob.glob("%s/configs/*.yaml" % (PACKS_PATH))) assert CONFIGS_COUNT > 1 @@ -35,60 +33,80 @@ class PackConfigsControllerTestCase(FunctionalTest): register_pack_configs = True def test_get_all(self): - resp = self.app.get('/v1/configs') + resp = self.app.get("/v1/configs") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), CONFIGS_COUNT, '/v1/configs did not return all configs.') + self.assertEqual( + len(resp.json), CONFIGS_COUNT, "/v1/configs did not return all configs." + ) def test_get_one_success(self): - resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True}, - expect_errors=True) + resp = self.app.get( + "/v1/configs/dummy_pack_1", + params={"show_secrets": True}, + expect_errors=True, + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['pack'], 'dummy_pack_1') - self.assertEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}') - self.assertEqual(resp.json['values']['region'], 'us-west-1') + self.assertEqual(resp.json["pack"], "dummy_pack_1") + self.assertEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}") + self.assertEqual(resp.json["values"]["region"], "us-west-1") def test_get_one_mask_secret(self): - resp = self.app.get('/v1/configs/dummy_pack_1') + resp = self.app.get("/v1/configs/dummy_pack_1") self.assertEqual(resp.status_int, 200) - self.assertNotEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}') + self.assertNotEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}") def test_get_one_pack_config_doesnt_exist(self): # Pack exists, config doesnt - resp = self.app.get('/v1/configs/dummy_pack_2', - expect_errors=True) + resp = self.app.get("/v1/configs/dummy_pack_2", expect_errors=True) self.assertEqual(resp.status_int, 404) - self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref ", resp.json["faultstring"] + ) # Pack doesn't exist - resp = self.app.get('/v1/configs/pack_doesnt_exist', - expect_errors=True) + resp = self.app.get("/v1/configs/pack_doesnt_exist", expect_errors=True) self.assertEqual(resp.status_int, 404) # Changed from : 'Unable to find the PackDB instance.' - self.assertIn('Unable to identify resource with pack_ref', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref", resp.json["faultstring"] + ) - @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock()) + @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock()) def test_put_pack_config(self): - get_resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True}, - expect_errors=True) - config = copy.copy(get_resp.json['values']) - config['region'] = 'us-west-2' + get_resp = self.app.get( + "/v1/configs/dummy_pack_1", + params={"show_secrets": True}, + expect_errors=True, + ) + config = copy.copy(get_resp.json["values"]) + config["region"] = "us-west-2" - put_resp = self.app.put_json('/v1/configs/dummy_pack_1', config) + put_resp = self.app.put_json("/v1/configs/dummy_pack_1", config) self.assertEqual(put_resp.status_int, 200) - put_resp_undo = self.app.put_json('/v1/configs/dummy_pack_1?show_secrets=true', - get_resp.json['values'], expect_errors=True) + put_resp_undo = self.app.put_json( + "/v1/configs/dummy_pack_1?show_secrets=true", + get_resp.json["values"], + expect_errors=True, + ) self.assertEqual(put_resp.status_int, 200) self.assertEqual(get_resp.json, put_resp_undo.json) - @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock()) + @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock()) def test_put_invalid_pack_config(self): - get_resp = self.app.get('/v1/configs/dummy_pack_11', params={'show_secrets': True}, - expect_errors=True) - config = copy.copy(get_resp.json['values']) - put_resp = self.app.put_json('/v1/configs/dummy_pack_11', config, expect_errors=True) + get_resp = self.app.get( + "/v1/configs/dummy_pack_11", + params={"show_secrets": True}, + expect_errors=True, + ) + config = copy.copy(get_resp.json["values"]) + put_resp = self.app.put_json( + "/v1/configs/dummy_pack_11", config, expect_errors=True + ) self.assertEqual(put_resp.status_int, 400) - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - self.assertIn(expected_msg, put_resp.json['faultstring']) + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + self.assertIn(expected_msg, put_resp.json["faultstring"]) diff --git a/st2api/tests/unit/controllers/v1/test_packs.py b/st2api/tests/unit/controllers/v1/test_packs.py index 9406a50af0..07cacd0be8 100644 --- a/st2api/tests/unit/controllers/v1/test_packs.py +++ b/st2api/tests/unit/controllers/v1/test_packs.py @@ -33,9 +33,7 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'PacksControllerTestCase' -] +__all__ = ["PacksControllerTestCase"] PACK_INDEX = { "test": { @@ -45,7 +43,7 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", }, "test2": { "version": "0.5.0", @@ -54,13 +52,13 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" - } + "description": "another st2 pack to test package management pipeline", + }, } PACK_INDEXES = { - 'http://main.example.com': PACK_INDEX, - 'http://fallback.example.com': { + "http://main.example.com": PACK_INDEX, + "http://fallback.example.com": { "test": { "version": "0.1.0", "name": "test", @@ -68,10 +66,10 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", } }, - 'http://override.example.com': { + "http://override.example.com": { "test2": { "version": "1.0.0", "name": "test2", @@ -79,10 +77,12 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", } }, - 'http://broken.example.com': requests.exceptions.RequestException('index is broken') + "http://broken.example.com": requests.exceptions.RequestException( + "index is broken" + ), } @@ -93,10 +93,7 @@ def mock_index_get(url, *args, **kwargs): raise index status = 200 - content = { - 'metadata': {}, - 'packs': index - } + content = {"metadata": {}, "packs": index} # Return mock response object @@ -104,311 +101,371 @@ def mock_index_get(url, *args, **kwargs): mock_resp.raise_for_status = mock.Mock() mock_resp.status_code = status mock_resp.content = content - mock_resp.json = mock.Mock( - return_value=content - ) + mock_resp.json = mock.Mock(return_value=content) return mock_resp -class PacksControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/packs' +class PacksControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/packs" controller_cls = PacksController - include_attribute_field_name = 'version' - exclude_attribute_field_name = 'author' + include_attribute_field_name = "version" + exclude_attribute_field_name = "author" @classmethod def setUpClass(cls): super(PacksControllerTestCase, cls).setUpClass() - cls.pack_db_1 = PackDB(name='pack1', description='foo', version='0.1.0', author='foo', - email='test@example.com', ref='pack1') - cls.pack_db_2 = PackDB(name='pack2', description='foo', version='0.1.0', author='foo', - email='test@example.com', ref='pack2') - cls.pack_db_3 = PackDB(name='pack3-name', ref='pack3-ref', description='foo', - version='0.1.0', author='foo', - email='test@example.com') + cls.pack_db_1 = PackDB( + name="pack1", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ref="pack1", + ) + cls.pack_db_2 = PackDB( + name="pack2", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ref="pack2", + ) + cls.pack_db_3 = PackDB( + name="pack3-name", + ref="pack3-ref", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ) Pack.add_or_update(cls.pack_db_1) Pack.add_or_update(cls.pack_db_2) Pack.add_or_update(cls.pack_db_3) def test_get_all(self): - resp = self.app.get('/v1/packs') + resp = self.app.get("/v1/packs") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/actionalias did not return all packs.') + self.assertEqual(len(resp.json), 3, "/v1/actionalias did not return all packs.") def test_get_one(self): # Get by id - resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.id)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['name'], self.pack_db_1.name) + self.assertEqual(resp.json["name"], self.pack_db_1.name) # Get by name - resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.ref)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.ref)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['ref'], self.pack_db_1.ref) - self.assertEqual(resp.json['name'], self.pack_db_1.name) + self.assertEqual(resp.json["ref"], self.pack_db_1.ref) + self.assertEqual(resp.json["name"], self.pack_db_1.name) # Get by ref (ref != name) - resp = self.app.get('/v1/packs/%s' % (self.pack_db_3.ref)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_3.ref)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['ref'], self.pack_db_3.ref) + self.assertEqual(resp.json["ref"], self.pack_db_3.ref) def test_get_one_doesnt_exist(self): - resp = self.app.get('/v1/packs/doesntexistfoo', expect_errors=True) + resp = self.app.get("/v1/packs/doesntexistfoo", expect_errors=True) self.assertEqual(resp.status_int, 404) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some']} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"]} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install_with_force_parameter(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some'], 'force': True} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"], "force": True} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install_with_skip_dependencies_parameter(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some'], 'skip_dependencies': True} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"], "skip_dependencies": True} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_uninstall(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some']} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"]} - resp = self.app.post_json('/v1/packs/uninstall', payload) + resp = self.app.post_json("/v1/packs/uninstall", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(pack_service, 'fetch_pack_index', - mock.MagicMock(return_value=(PACK_INDEX, {}))) + @mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) + ) def test_search_with_query(self): test_scenarios = [ { - 'input': {'query': 'test'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']] + "input": {"query": "test"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]], }, { - 'input': {'query': 'stanley'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "stanley"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'special'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "special"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'TEST'}, # Search should be case insensitive by default - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']] + "input": { + "query": "TEST" + }, # Search should be case insensitive by default + "expected_code": 200, + "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]], }, { - 'input': {'query': 'SPECIAL'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "SPECIAL"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'sPeCiAL'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "sPeCiAL"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'st2-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "st2-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, { - 'input': {'query': 'ST2-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "ST2-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, { - 'input': {'query': '-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, - { - 'input': {'query': 'core'}, - 'expected_code': 200, - 'expected_result': [] - } + {"input": {"query": "core"}, "expected_code": 200, "expected_result": []}, ] for scenario in test_scenarios: - resp = self.app.post_json('/v1/packs/index/search', scenario['input']) - self.assertEqual(resp.status_int, scenario['expected_code']) - self.assertEqual(resp.json, scenario['expected_result']) - - @mock.patch.object(pack_service, 'get_pack_from_index', - mock.MagicMock(return_value=PACK_INDEX['test'])) + resp = self.app.post_json("/v1/packs/index/search", scenario["input"]) + self.assertEqual(resp.status_int, scenario["expected_code"]) + self.assertEqual(resp.json, scenario["expected_result"]) + + @mock.patch.object( + pack_service, + "get_pack_from_index", + mock.MagicMock(return_value=PACK_INDEX["test"]), + ) def test_search_with_pack_has_result(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'st2-dev'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "st2-dev"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test']) + self.assertEqual(resp.json, PACK_INDEX["test"]) - @mock.patch.object(pack_service, 'get_pack_from_index', - mock.MagicMock(return_value=None)) + @mock.patch.object( + pack_service, "get_pack_from_index", mock.MagicMock(return_value=None) + ) def test_search_with_pack_no_result(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'not-found'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "not-found"}) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, []) - @mock.patch.object(pack_service, 'fetch_pack_index', - mock.MagicMock(return_value=(PACK_INDEX, {}))) + @mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) + ) def test_show(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "test"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test']) + self.assertEqual(resp.json, PACK_INDEX["test"]) - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test2'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "test2"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test2']) + self.assertEqual(resp.json, PACK_INDEX["test2"]) - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock(return_value=["http://main.example.com"]), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_health(self): - resp = self.app.get('/v1/packs/index/health') + resp = self.app.get("/v1/packs/index/health") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'packs': { - 'count': 2 + self.assertEqual( + resp.json, + { + "packs": {"count": 2}, + "indexes": { + "count": 1, + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + } + ], + "valid": 1, + "errors": {}, + "invalid": 0, + }, }, - 'indexes': { - 'count': 1, - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'valid': 1, - 'errors': {}, - 'invalid': 0 - } - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com', - 'http://broken.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://main.example.com", "http://broken.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_health_broken(self): - resp = self.app.get('/v1/packs/index/health') + resp = self.app.get("/v1/packs/index/health") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'packs': { - 'count': 2 - }, - 'indexes': { - 'count': 2, - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }, { - 'url': 'http://broken.example.com', - 'message': "RequestException('index is broken',)", - 'packs': 0, - 'error': 'unresponsive' - }], - 'valid': 1, - 'errors': { - 'unresponsive': 1 + self.assertEqual( + resp.json, + { + "packs": {"count": 2}, + "indexes": { + "count": 2, + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + { + "url": "http://broken.example.com", + "message": "RequestException('index is broken',)", + "packs": 0, + "error": "unresponsive", + }, + ], + "valid": 1, + "errors": {"unresponsive": 1}, + "invalid": 1, }, - 'invalid': 1 - } - }) + }, + ) - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock(return_value=["http://main.example.com"]), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'index': PACK_INDEX - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://fallback.example.com', - 'http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + } + ], + "index": PACK_INDEX, + }, + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://fallback.example.com", "http://main.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_fallback(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://fallback.example.com', - 'message': 'Success.', - 'packs': 1, - 'error': None - }, { - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'index': PACK_INDEX - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com', - 'http://override.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://fallback.example.com", + "message": "Success.", + "packs": 1, + "error": None, + }, + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + ], + "index": PACK_INDEX, + }, + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://main.example.com", "http://override.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_override(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }, { - 'url': 'http://override.example.com', - 'message': 'Success.', - 'packs': 1, - 'error': None - }], - 'index': { - 'test': PACK_INDEX['test'], - 'test2': PACK_INDEXES['http://override.example.com']['test2'] - } - }) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + { + "url": "http://override.example.com", + "message": "Success.", + "packs": 1, + "error": None, + }, + ], + "index": { + "test": PACK_INDEX["test"], + "test2": PACK_INDEXES["http://override.example.com"]["test2"], + }, + }, + ) def test_packs_register_endpoint_resource_register_order(self): # Verify that resources are registered in the same order as they are inside @@ -416,17 +473,17 @@ def test_packs_register_endpoint_resource_register_order(self): # Note: Sadly there is no easier / better way to test this resource_types = list(ENTITIES.keys()) expected_order = [ - 'trigger', - 'sensor', - 'action', - 'rule', - 'alias', - 'policy', - 'config' + "trigger", + "sensor", + "action", + "rule", + "alias", + "policy", + "config", ] self.assertEqual(resource_types, expected_order) - @mock.patch.object(ContentPackLoader, 'get_packs') + @mock.patch.object(ContentPackLoader, "get_packs") def test_packs_register_endpoint(self, mock_get_packs): # Register resources from all packs - make sure the count values are correctly added # together @@ -434,12 +491,12 @@ def test_packs_register_endpoint(self, mock_get_packs): # Note: We only register a couple of packs and not all on disk to speed # things up. Registering all the packs takes a long time. fixtures_base_path = get_fixtures_base_path() - packs_base_path = os.path.join(fixtures_base_path, 'packs') + packs_base_path = os.path.join(fixtures_base_path, "packs") pack_names = [ - 'dummy_pack_1', - 'dummy_pack_2', - 'dummy_pack_3', - 'dummy_pack_10', + "dummy_pack_1", + "dummy_pack_2", + "dummy_pack_3", + "dummy_pack_10", ] mock_return_value = {} for pack_name in pack_names: @@ -447,160 +504,180 @@ def test_packs_register_endpoint(self, mock_get_packs): mock_get_packs.return_value = mock_return_value - resp = self.app.post_json('/v1/packs/register', {'fail_on_failure': False}) + resp = self.app.post_json("/v1/packs/register", {"fail_on_failure": False}) self.assertEqual(resp.status_int, 200) - self.assertIn('runners', resp.json) - self.assertIn('actions', resp.json) - self.assertIn('triggers', resp.json) - self.assertIn('sensors', resp.json) - self.assertIn('rules', resp.json) - self.assertIn('rule_types', resp.json) - self.assertIn('aliases', resp.json) - self.assertIn('policy_types', resp.json) - self.assertIn('policies', resp.json) - self.assertIn('configs', resp.json) - - self.assertTrue(resp.json['actions'] >= 3) - self.assertTrue(resp.json['configs'] >= 1) + self.assertIn("runners", resp.json) + self.assertIn("actions", resp.json) + self.assertIn("triggers", resp.json) + self.assertIn("sensors", resp.json) + self.assertIn("rules", resp.json) + self.assertIn("rule_types", resp.json) + self.assertIn("aliases", resp.json) + self.assertIn("policy_types", resp.json) + self.assertIn("policies", resp.json) + self.assertIn("configs", resp.json) + + self.assertTrue(resp.json["actions"] >= 3) + self.assertTrue(resp.json["configs"] >= 1) # Register resources from a specific pack - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['actions'] >= 1) - self.assertTrue(resp.json['sensors'] >= 1) - self.assertTrue(resp.json['configs'] >= 1) + self.assertTrue(resp.json["actions"] >= 1) + self.assertTrue(resp.json["sensors"] >= 1) + self.assertTrue(resp.json["configs"] >= 1) # Verify metadata_file attribute is set - action_dbs = Action.query(pack='dummy_pack_1') - self.assertEqual(action_dbs[0].metadata_file, 'actions/my_action.yaml') + action_dbs = Action.query(pack="dummy_pack_1") + self.assertEqual(action_dbs[0].metadata_file, "actions/my_action.yaml") # Register 'all' resource types should try include any possible content for the pack - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['all']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["all"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertIn('runners', resp.json) - self.assertIn('actions', resp.json) - self.assertIn('triggers', resp.json) - self.assertIn('sensors', resp.json) - self.assertIn('rules', resp.json) - self.assertIn('rule_types', resp.json) - self.assertIn('aliases', resp.json) - self.assertIn('policy_types', resp.json) - self.assertIn('policies', resp.json) - self.assertIn('configs', resp.json) + self.assertIn("runners", resp.json) + self.assertIn("actions", resp.json) + self.assertIn("triggers", resp.json) + self.assertIn("sensors", resp.json) + self.assertIn("rules", resp.json) + self.assertIn("rule_types", resp.json) + self.assertIn("aliases", resp.json) + self.assertIn("policy_types", resp.json) + self.assertIn("policies", resp.json) + self.assertIn("configs", resp.json) # Registering single resource type should also cause dependent resources # to be registered # * actions -> runners # * rules -> rule types # * policies -> policy types - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['actions']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["actions"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['runners'] >= 1) - self.assertTrue(resp.json['actions'] >= 1) + self.assertTrue(resp.json["runners"] >= 1) + self.assertTrue(resp.json["actions"] >= 1) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['rules']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["rules"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['rule_types'] >= 1) - self.assertTrue(resp.json['rules'] >= 1) + self.assertTrue(resp.json["rule_types"] >= 1) + self.assertTrue(resp.json["rules"] >= 1) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_2'], - 'fail_on_failure': False, - 'types': ['policies']}) + resp = self.app.post_json( + "/v1/packs/register", + { + "packs": ["dummy_pack_2"], + "fail_on_failure": False, + "types": ["policies"], + }, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['policy_types'] >= 1) - self.assertTrue(resp.json['policies'] >= 0) + self.assertTrue(resp.json["policy_types"] >= 1) + self.assertTrue(resp.json["policies"] >= 0) # Register specific type for all packs - resp = self.app.post_json('/v1/packs/register', {'types': ['sensor'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"types": ["sensor"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, {'sensors': 3}) + self.assertEqual(resp.json, {"sensors": 3}) # Verify that plural name form also works - resp = self.app.post_json('/v1/packs/register', {'types': ['sensors'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"types": ["sensors"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) # Register specific type for a single packs - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1'], 'types': ['action']}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["action"]} + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Verify that plural name form also works - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1'], 'types': ['actions']}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["actions"]} + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Register single resource from a single pack specified multiple times - verify that # resources from the same pack are only registered once - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1', 'dummy_pack_1', 'dummy_pack_1'], - 'types': ['actions'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", + { + "packs": ["dummy_pack_1", "dummy_pack_1", "dummy_pack_1"], + "types": ["actions"], + "fail_on_failure": False, + }, + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Register resources from a single (non-existent pack) - resp = self.app.post_json('/v1/packs/register', {'packs': ['doesntexist']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["doesntexist"]}, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertIn('Pack "doesntexist" not found on disk:', resp.json['faultstring']) + self.assertIn('Pack "doesntexist" not found on disk:', resp.json["faultstring"]) # Fail on failure is enabled by default - resp = self.app.post_json('/v1/packs/register', expect_errors=True) + resp = self.app.post_json("/v1/packs/register", expect_errors=True) expected_msg = 'Failed to register pack "dummy_pack_10":' self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) # Fail on failure (broken pack metadata) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"]}, expect_errors=True + ) expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist' self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) # Fail on failure (broken action metadata) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_15']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_15"]}, expect_errors=True + ) - expected_msg = 'Failed to register action' + expected_msg = "Failed to register action" self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) - expected_msg = '\'stringa\' is not valid under any of the given schemas' + expected_msg = "'stringa' is not valid under any of the given schemas" self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) def test_get_all_invalid_exclude_and_include_parameter(self): pass def _insert_mock_models(self): - return [self.pack_db_1['id'], self.pack_db_2['id'], self.pack_db_3['id']] + return [self.pack_db_1["id"], self.pack_db_2["id"], self.pack_db_3["id"]] def _do_delete(self, object_ids): pass diff --git a/st2api/tests/unit/controllers/v1/test_packs_views.py b/st2api/tests/unit/controllers/v1/test_packs_views.py index 5535a6e22b..a1b96a4aea 100644 --- a/st2api/tests/unit/controllers/v1/test_packs_views.py +++ b/st2api/tests/unit/controllers/v1/test_packs_views.py @@ -21,7 +21,7 @@ from st2tests.api import FunctionalTest -@mock.patch('st2common.bootstrap.base.REGISTERED_PACKS_CACHE', {}) +@mock.patch("st2common.bootstrap.base.REGISTERED_PACKS_CACHE", {}) class PacksViewsControllerTestCase(FunctionalTest): @classmethod def setUpClass(cls): @@ -31,32 +31,34 @@ def setUpClass(cls): actions_registrar.register_actions(use_pack_cache=False) def test_get_pack_files_success(self): - resp = self.app.get('/v1/packs/views/files/dummy_pack_1') + resp = self.app.get("/v1/packs/views/files/dummy_pack_1") self.assertEqual(resp.status_int, http_client.OK) self.assertTrue(len(resp.json) > 1) - item = [_item for _item in resp.json if _item['file_path'] == 'pack.yaml'][0] - self.assertEqual(item['file_path'], 'pack.yaml') - item = [_item for _item in resp.json if _item['file_path'] == 'actions/my_action.py'][0] - self.assertEqual(item['file_path'], 'actions/my_action.py') + item = [_item for _item in resp.json if _item["file_path"] == "pack.yaml"][0] + self.assertEqual(item["file_path"], "pack.yaml") + item = [ + _item for _item in resp.json if _item["file_path"] == "actions/my_action.py" + ][0] + self.assertEqual(item["file_path"], "actions/my_action.py") def test_get_pack_files_pack_doesnt_exist(self): - resp = self.app.get('/v1/packs/views/files/doesntexist', expect_errors=True) + resp = self.app.get("/v1/packs/views/files/doesntexist", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_get_pack_files_binary_files_are_excluded(self): binary_files = [ - 'icon.png', - 'etc/permissions.png', - 'etc/travisci.png', - 'etc/generate_new_token.png' + "icon.png", + "etc/permissions.png", + "etc/travisci.png", + "etc/generate_new_token.png", ] - pack_db = Pack.get_by_ref('dummy_pack_1') + pack_db = Pack.get_by_ref("dummy_pack_1") all_files_count = len(pack_db.files) non_binary_files_count = all_files_count - len(binary_files) - resp = self.app.get('/v1/packs/views/files/dummy_pack_1') + resp = self.app.get("/v1/packs/views/files/dummy_pack_1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), non_binary_files_count) @@ -65,63 +67,75 @@ def test_get_pack_files_binary_files_are_excluded(self): # But not in files controller response for file_path in binary_files: - item = [item for item in resp.json if item['file_path'] == file_path] + item = [item for item in resp.json if item["file_path"] == file_path] self.assertFalse(item) def test_get_pack_file_success(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) def test_get_pack_file_pack_doesnt_exist(self): - resp = self.app.get('/v1/packs/views/files/doesntexist/pack.yaml', expect_errors=True) + resp = self.app.get( + "/v1/packs/views/files/doesntexist/pack.yaml", expect_errors=True + ) self.assertEqual(resp.status_int, http_client.NOT_FOUND) - @mock.patch('st2api.controllers.v1.pack_views.MAX_FILE_SIZE', 1) + @mock.patch("st2api.controllers.v1.pack_views.MAX_FILE_SIZE", 1) def test_pack_file_file_larger_then_maximum_size(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', expect_errors=True) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", expect_errors=True + ) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertIn('File pack.yaml exceeds maximum allowed file size', resp) + self.assertIn("File pack.yaml exceeds maximum allowed file size", resp) def test_headers_get_pack_file(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) - self.assertIsNotNone(resp.headers['ETag']) - self.assertIsNotNone(resp.headers['Last-Modified']) + self.assertIn(b"name : dummy_pack_1", resp.body) + self.assertIsNotNone(resp.headers["ETag"]) + self.assertIsNotNone(resp.headers["Last-Modified"]) def test_no_change_get_pack_file(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) # Confirm NOT_MODIFIED - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-None-Match': resp.headers['ETag']}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-None-Match": resp.headers["ETag"]}, + ) self.assertEqual(resp.status_code, http_client.NOT_MODIFIED) - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-Modified-Since': resp.headers['Last-Modified']}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-Modified-Since": resp.headers["Last-Modified"]}, + ) self.assertEqual(resp.status_code, http_client.NOT_MODIFIED) # Confirm value is returned if header do not match - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-None-Match': 'ETAG'}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-None-Match": "ETAG"}, + ) self.assertEqual(resp.status_code, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-Modified-Since': 'Last-Modified'}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-Modified-Since": "Last-Modified"}, + ) self.assertEqual(resp.status_code, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) def test_get_pack_files_and_pack_file_ref_doesnt_equal_pack_name(self): # Ref is not equal to the name, controller should still work - resp = self.app.get('/v1/packs/views/files/dummy_pack_16') + resp = self.app.get("/v1/packs/views/files/dummy_pack_16") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['file_path'], 'pack.yaml') + self.assertEqual(resp.json[0]["file_path"], "pack.yaml") - resp = self.app.get('/v1/packs/views/file/dummy_pack_16/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_16/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'ref: dummy_pack_16', resp.body) + self.assertIn(b"ref: dummy_pack_16", resp.body) diff --git a/st2api/tests/unit/controllers/v1/test_policies.py b/st2api/tests/unit/controllers/v1/test_policies.py index a26c3dea24..3127b3aeb7 100644 --- a/st2api/tests/unit/controllers/v1/test_policies.py +++ b/st2api/tests/unit/controllers/v1/test_policies.py @@ -27,36 +27,28 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'PolicyTypeControllerTestCase', - 'PolicyControllerTestCase' -] +__all__ = ["PolicyTypeControllerTestCase", "PolicyControllerTestCase"] TEST_FIXTURES = { - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) -class PolicyTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/policytypes' +class PolicyTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/policytypes" controller_cls = PolicyTypeController - include_attribute_field_name = 'module' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "module" + exclude_attribute_field_name = "parameters" - base_url = '/v1/policytypes' + base_url = "/v1/policytypes" @classmethod def setUpClass(cls): @@ -64,7 +56,7 @@ def setUpClass(cls): cls.policy_type_dbs = [] - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) policy_type_db = PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) cls.policy_type_dbs.append(policy_type_db) @@ -80,23 +72,25 @@ def test_policy_type_filter(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_all(filter='resource_type=%s&name=%s' % - (selected['resource_type'], selected['name'])) + resp = self.__do_get_all( + filter="resource_type=%s&name=%s" + % (selected["resource_type"], selected["name"]) + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='name=%s' % selected['name']) + resp = self.__do_get_all(filter="name=%s" % selected["name"]) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='resource_type=%s' % selected['resource_type']) + resp = self.__do_get_all(filter="resource_type=%s" % selected["resource_type"]) self.assertEqual(resp.status_int, 200) self.assertGreater(len(resp.json), 1) def test_policy_type_filter_empty(self): - resp = self.__do_get_all(filter='resource_type=yo&name=whatever') + resp = self.__do_get_all(filter="resource_type=yo&name=whatever") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) @@ -106,16 +100,16 @@ def test_policy_type_get_one(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_one(selected['id']) + resp = self.__do_get_one(selected["id"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) - resp = self.__do_get_one(selected['ref']) + resp = self.__do_get_one(selected["ref"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) def test_policy_type_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, 404) def _insert_mock_models(self): @@ -130,36 +124,37 @@ def _delete_mock_models(self, object_ids): @staticmethod def __get_obj_id(resp, idx=-1): - return resp.json['id'] if idx < 0 else resp.json[idx]['id'] + return resp.json["id"] if idx < 0 else resp.json[idx]["id"] def __do_get_all(self, filter=None): - url = '%s?%s' % (self.base_url, filter) if filter else self.base_url + url = "%s?%s" % (self.base_url, filter) if filter else self.base_url return self.app.get(url, expect_errors=True) def __do_get_one(self, id): - return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True) -class PolicyControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/policies' +class PolicyControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/policies" controller_cls = PolicyController - include_attribute_field_name = 'policy_type' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "policy_type" + exclude_attribute_field_name = "parameters" - base_url = '/v1/policies' + base_url = "/v1/policies" @classmethod def setUpClass(cls): super(PolicyControllerTestCase, cls).setUpClass() - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) cls.policy_dbs = [] - for _, fixture in six.iteritems(FIXTURES['policies']): + for _, fixture in six.iteritems(FIXTURES["policies"]): instance = PolicyAPI(**fixture) policy_db = Policy.add_or_update(PolicyAPI.to_model(instance)) cls.policy_dbs.append(policy_db) @@ -175,22 +170,24 @@ def test_filter(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_all(filter='pack=%s&name=%s' % (selected['pack'], selected['name'])) + resp = self.__do_get_all( + filter="pack=%s&name=%s" % (selected["pack"], selected["name"]) + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='name=%s' % selected['name']) + resp = self.__do_get_all(filter="name=%s" % selected["name"]) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='pack=%s' % selected['pack']) + resp = self.__do_get_all(filter="pack=%s" % selected["pack"]) self.assertEqual(resp.status_int, 200) self.assertGreater(len(resp.json), 1) def test_filter_empty(self): - resp = self.__do_get_all(filter='pack=yo&name=whatever') + resp = self.__do_get_all(filter="pack=yo&name=whatever") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) @@ -200,16 +197,16 @@ def test_get_one(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_one(selected['id']) + resp = self.__do_get_one(selected["id"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) - resp = self.__do_get_one(selected['ref']) + resp = self.__do_get_one(selected["ref"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) def test_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, 404) def test_crud(self): @@ -221,10 +218,10 @@ def test_crud(self): self.assertEqual(get_resp.status_int, http_client.OK) updated_input = get_resp.json - updated_input['enabled'] = not updated_input['enabled'] + updated_input["enabled"] = not updated_input["enabled"] put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['enabled'], updated_input['enabled']) + self.assertEqual(put_resp.json["enabled"], updated_input["enabled"]) del_resp = self.__do_delete(self.__get_obj_id(post_resp)) self.assertEqual(del_resp.status_int, http_client.NO_CONTENT) @@ -243,41 +240,45 @@ def test_post_duplicate(self): def test_put_not_found(self): updated_input = self.__create_instance() - put_resp = self.__do_put('12345', updated_input) + put_resp = self.__do_put("12345", updated_input) self.assertEqual(put_resp.status_int, http_client.NOT_FOUND) def test_put_sys_pack(self): instance = self.__create_instance() - instance['pack'] = 'core' + instance["pack"] = "core" post_resp = self.__do_post(instance) self.assertEqual(post_resp.status_int, http_client.CREATED) updated_input = post_resp.json - updated_input['enabled'] = not updated_input['enabled'] + updated_input["enabled"] = not updated_input["enabled"] put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(put_resp.json['faultstring'], - "Resources belonging to system level packs can't be manipulated") + self.assertEqual( + put_resp.json["faultstring"], + "Resources belonging to system level packs can't be manipulated", + ) # Clean up manually since API won't delete object in sys pack. Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp))) def test_delete_not_found(self): - del_resp = self.__do_delete('12345') + del_resp = self.__do_delete("12345") self.assertEqual(del_resp.status_int, http_client.NOT_FOUND) def test_delete_sys_pack(self): instance = self.__create_instance() - instance['pack'] = 'core' + instance["pack"] = "core" post_resp = self.__do_post(instance) self.assertEqual(post_resp.status_int, http_client.CREATED) del_resp = self.__do_delete(self.__get_obj_id(post_resp)) self.assertEqual(del_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(del_resp.json['faultstring'], - "Resources belonging to system level packs can't be manipulated") + self.assertEqual( + del_resp.json["faultstring"], + "Resources belonging to system level packs can't be manipulated", + ) # Clean up manually since API won't delete object in sys pack. Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp))) @@ -295,34 +296,34 @@ def _delete_mock_models(self, object_ids): @staticmethod def __create_instance(): return { - 'name': 'myaction.mypolicy', - 'pack': 'mypack', - 'resource_ref': 'mypack.myaction', - 'policy_type': 'action.mock_policy_error', - 'parameters': { - 'k1': 'v1' - } + "name": "myaction.mypolicy", + "pack": "mypack", + "resource_ref": "mypack.myaction", + "policy_type": "action.mock_policy_error", + "parameters": {"k1": "v1"}, } @staticmethod def __get_obj_id(resp, idx=-1): - return resp.json['id'] if idx < 0 else resp.json[idx]['id'] + return resp.json["id"] if idx < 0 else resp.json[idx]["id"] def __do_get_all(self, filter=None): - url = '%s?%s' % (self.base_url, filter) if filter else self.base_url + url = "%s?%s" % (self.base_url, filter) if filter else self.base_url return self.app.get(url, expect_errors=True) def __do_get_one(self, id): - return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_post(self, instance): return self.app.post_json(self.base_url, instance, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_put(self, id, instance): - return self.app.put_json('%s/%s' % (self.base_url, id), instance, expect_errors=True) + return self.app.put_json( + "%s/%s" % (self.base_url, id), instance, expect_errors=True + ) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_delete(self, id): - return self.app.delete('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.delete("%s/%s" % (self.base_url, id), expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py index 84c2a66b4a..0a7a104d35 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py +++ b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py @@ -21,87 +21,109 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'RuleEnforcementViewsControllerTestCase' -] +__all__ = ["RuleEnforcementViewsControllerTestCase"] http_client = six.moves.http_client TEST_FIXTURES = { - 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml'], - 'executions': ['execution1.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml'] + "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"], + "executions": ["execution1.yaml"], + "triggerinstances": ["trigger_instance_1.yaml"], } -FIXTURES_PACK = 'rule_enforcements' +FIXTURES_PACK = "rule_enforcements" -class RuleEnforcementViewsControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/ruleenforcements/views' +class RuleEnforcementViewsControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/ruleenforcements/views" controller_cls = RuleEnforcementViewController - include_attribute_field_name = 'enforced_at' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "enforced_at" + exclude_attribute_field_name = "status" fixtures_loader = FixturesLoader() @classmethod def setUpClass(cls): super(RuleEnforcementViewsControllerTestCase, cls).setUpClass() - cls.models = RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES, - use_object_ids=True) - cls.ENFORCEMENT_1 = cls.models['enforcements']['enforcement1.yaml'] + cls.models = ( + RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, + fixtures_dict=TEST_FIXTURES, + use_object_ids=True, + ) + ) + cls.ENFORCEMENT_1 = cls.models["enforcements"]["enforcement1.yaml"] def test_get_all(self): - resp = self.app.get('/v1/ruleenforcements/views') + resp = self.app.get("/v1/ruleenforcements/views") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) # Verify it includes corresponding execution and trigger instance object - self.assertEqual(resp.json[0]['trigger_instance']['id'], '565e15ce32ed350857dfa623') - self.assertEqual(resp.json[0]['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'}) - - self.assertEqual(resp.json[0]['execution']['action']['ref'], 'core.local') - self.assertEqual(resp.json[0]['execution']['action']['parameters'], - {'sudo': {'immutable': True}}) - self.assertEqual(resp.json[0]['execution']['runner']['name'], 'action-chain') - self.assertEqual(resp.json[0]['execution']['runner']['runner_parameters'], - {'foo': {'type': 'string'}}) - self.assertEqual(resp.json[0]['execution']['parameters'], {'cmd': 'echo bar'}) - self.assertEqual(resp.json[0]['execution']['status'], 'scheduled') - - self.assertEqual(resp.json[1]['trigger_instance'], {}) - self.assertEqual(resp.json[1]['execution'], {}) - - self.assertEqual(resp.json[2]['trigger_instance'], {}) - self.assertEqual(resp.json[2]['execution'], {}) + self.assertEqual( + resp.json[0]["trigger_instance"]["id"], "565e15ce32ed350857dfa623" + ) + self.assertEqual( + resp.json[0]["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"} + ) + + self.assertEqual(resp.json[0]["execution"]["action"]["ref"], "core.local") + self.assertEqual( + resp.json[0]["execution"]["action"]["parameters"], + {"sudo": {"immutable": True}}, + ) + self.assertEqual(resp.json[0]["execution"]["runner"]["name"], "action-chain") + self.assertEqual( + resp.json[0]["execution"]["runner"]["runner_parameters"], + {"foo": {"type": "string"}}, + ) + self.assertEqual(resp.json[0]["execution"]["parameters"], {"cmd": "echo bar"}) + self.assertEqual(resp.json[0]["execution"]["status"], "scheduled") + + self.assertEqual(resp.json[1]["trigger_instance"], {}) + self.assertEqual(resp.json[1]["execution"], {}) + + self.assertEqual(resp.json[2]["trigger_instance"], {}) + self.assertEqual(resp.json[2]["execution"], {}) def test_filter_by_rule_ref(self): - resp = self.app.get('/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule') + resp = self.app.get("/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['rule']['ref'], 'wolfpack.golden_rule') + self.assertEqual(resp.json[0]["rule"]["ref"], "wolfpack.golden_rule") def test_get_one_success(self): - resp = self.app.get('/v1/ruleenforcements/views/%s' % (str(self.ENFORCEMENT_1.id))) - self.assertEqual(resp.json['id'], str(self.ENFORCEMENT_1.id)) - - self.assertEqual(resp.json['trigger_instance']['id'], '565e15ce32ed350857dfa623') - self.assertEqual(resp.json['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'}) - - self.assertEqual(resp.json['execution']['action']['ref'], 'core.local') - self.assertEqual(resp.json['execution']['action']['parameters'], - {'sudo': {'immutable': True}}) - self.assertEqual(resp.json['execution']['runner']['name'], 'action-chain') - self.assertEqual(resp.json['execution']['runner']['runner_parameters'], - {'foo': {'type': 'string'}}) - self.assertEqual(resp.json['execution']['parameters'], {'cmd': 'echo bar'}) - self.assertEqual(resp.json['execution']['status'], 'scheduled') + resp = self.app.get( + "/v1/ruleenforcements/views/%s" % (str(self.ENFORCEMENT_1.id)) + ) + self.assertEqual(resp.json["id"], str(self.ENFORCEMENT_1.id)) + + self.assertEqual( + resp.json["trigger_instance"]["id"], "565e15ce32ed350857dfa623" + ) + self.assertEqual( + resp.json["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"} + ) + + self.assertEqual(resp.json["execution"]["action"]["ref"], "core.local") + self.assertEqual( + resp.json["execution"]["action"]["parameters"], + {"sudo": {"immutable": True}}, + ) + self.assertEqual(resp.json["execution"]["runner"]["name"], "action-chain") + self.assertEqual( + resp.json["execution"]["runner"]["runner_parameters"], + {"foo": {"type": "string"}}, + ) + self.assertEqual(resp.json["execution"]["parameters"], {"cmd": "echo bar"}) + self.assertEqual(resp.json["execution"]["status"], "scheduled") def _insert_mock_models(self): - enfrocement_ids = [enforcement['id'] for enforcement in - self.models['enforcements'].values()] + enfrocement_ids = [ + enforcement["id"] for enforcement in self.models["enforcements"].values() + ] return enfrocement_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py index 172b186098..f2de1e2b2a 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py +++ b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py @@ -24,92 +24,106 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml'] + "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"] } -FIXTURES_PACK = 'rule_enforcements' +FIXTURES_PACK = "rule_enforcements" -class RuleEnforcementControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/ruleenforcements' +class RuleEnforcementControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/ruleenforcements" controller_cls = RuleEnforcementController - include_attribute_field_name = 'enforced_at' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "enforced_at" + exclude_attribute_field_name = "status" fixtures_loader = FixturesLoader() @classmethod def setUpClass(cls): super(RuleEnforcementControllerTestCase, cls).setUpClass() - cls.models = RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RuleEnforcementControllerTestCase.ENFORCEMENT_1 = \ - cls.models['enforcements']['enforcement1.yaml'] + cls.models = ( + RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + ) + RuleEnforcementControllerTestCase.ENFORCEMENT_1 = cls.models["enforcements"][ + "enforcement1.yaml" + ] def test_get_all(self): - resp = self.app.get('/v1/ruleenforcements') + resp = self.app.get("/v1/ruleenforcements") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) def test_get_all_minus_one(self): - resp = self.app.get('/v1/ruleenforcements/?limit=-1') + resp = self.app.get("/v1/ruleenforcements/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) def test_get_all_limit(self): - resp = self.app.get('/v1/ruleenforcements/?limit=1') + resp = self.app.get("/v1/ruleenforcements/?limit=1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/ruleenforcements?limit=-22', expect_errors=True) + resp = self.app.get("/v1/ruleenforcements?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_one_by_id(self): e_id = str(RuleEnforcementControllerTestCase.ENFORCEMENT_1.id) - resp = self.app.get('/v1/ruleenforcements/%s' % e_id) + resp = self.app.get("/v1/ruleenforcements/%s" % e_id) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(resp.json['id'], e_id) + self.assertEqual(resp.json["id"], e_id) def test_get_one_fail(self): - resp = self.app.get('/v1/ruleenforcements/1', expect_errors=True) + resp = self.app.get("/v1/ruleenforcements/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_filter_by_rule_ref(self): - resp = self.app.get('/v1/ruleenforcements?rule_ref=wolfpack.golden_rule') + resp = self.app.get("/v1/ruleenforcements?rule_ref=wolfpack.golden_rule") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_rule_id(self): - resp = self.app.get('/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331') + resp = self.app.get("/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) def test_filter_by_execution_id(self): - resp = self.app.get('/v1/ruleenforcements?execution=565e15cd32ed350857dfa620') + resp = self.app.get("/v1/ruleenforcements?execution=565e15cd32ed350857dfa620") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_trigger_instance_id(self): - resp = self.app.get('/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623') + resp = self.app.get( + "/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_enforced_at(self): - resp = self.app.get('/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z') + resp = self.app.get( + "/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) - resp = self.app.get('/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z') + resp = self.app.get( + "/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def _insert_mock_models(self): - enfrocement_ids = [enforcement['id'] for enforcement in - self.models['enforcements'].values()] + enfrocement_ids = [ + enforcement["id"] for enforcement in self.models["enforcements"].values() + ] return enfrocement_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_rule_views.py b/st2api/tests/unit/controllers/v1/test_rule_views.py index f8a25e5d3d..95839c3110 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_views.py +++ b/st2api/tests/unit/controllers/v1/test_rule_views.py @@ -25,25 +25,24 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml', 'action2.yaml'], - 'triggers': ['trigger1.yaml'], - 'triggertypes': ['triggertype1.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml", "action2.yaml"], + "triggers": ["trigger1.yaml"], + "triggertypes": ["triggertype1.yaml"], } -TEST_FIXTURES_RULES = { - 'rules': ['rule1.yaml', 'rule4.yaml', 'rule5.yaml'] -} +TEST_FIXTURES_RULES = {"rules": ["rule1.yaml", "rule4.yaml", "rule5.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -class RuleViewControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/rules/views' +class RuleViewControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/rules/views" controller_cls = RuleViewController - include_attribute_field_name = 'criteria' - exclude_attribute_field_name = 'enabled' + include_attribute_field_name = "criteria" + exclude_attribute_field_name = "enabled" fixtures_loader = FixturesLoader() @@ -51,17 +50,21 @@ class RuleViewControllerTestCase(FunctionalTest, def setUpClass(cls): super(RuleViewControllerTestCase, cls).setUpClass() models = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RuleViewControllerTestCase.ACTION_1 = models['actions']['action1.yaml'] - RuleViewControllerTestCase.TRIGGER_TYPE_1 = models['triggertypes']['triggertype1.yaml'] - - file_name = 'rule1.yaml' + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RuleViewControllerTestCase.ACTION_1 = models["actions"]["action1.yaml"] + RuleViewControllerTestCase.TRIGGER_TYPE_1 = models["triggertypes"][ + "triggertype1.yaml" + ] + + file_name = "rule1.yaml" cls.rules = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES)['rules'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES + )["rules"] RuleViewControllerTestCase.RULE_1 = cls.rules[file_name] def test_get_all(self): - resp = self.app.get('/v1/rules/views') + resp = self.app.get("/v1/rules/views") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) @@ -70,25 +73,29 @@ def test_get_one_by_id(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - self.assertEqual(get_resp.json['action']['description'], - RuleViewControllerTestCase.ACTION_1.description) - self.assertEqual(get_resp.json['trigger']['description'], - RuleViewControllerTestCase.TRIGGER_TYPE_1.description) + self.assertEqual( + get_resp.json["action"]["description"], + RuleViewControllerTestCase.ACTION_1.description, + ) + self.assertEqual( + get_resp.json["trigger"]["description"], + RuleViewControllerTestCase.TRIGGER_TYPE_1.description, + ) def test_get_one_by_ref(self): rule_name = RuleViewControllerTestCase.RULE_1.name rule_pack = RuleViewControllerTestCase.RULE_1.pack ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) def test_get_one_fail(self): - resp = self.app.get('/v1/rules/1', expect_errors=True) + resp = self.app.get("/v1/rules/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def _insert_mock_models(self): - rule_ids = [rule['id'] for rule in self.rules.values()] + rule_ids = [rule["id"] for rule in self.rules.values()] return rule_ids def _delete_mock_models(self, object_ids): @@ -96,7 +103,7 @@ def _delete_mock_models(self, object_ids): @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/views/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/views/%s" % rule_id, expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_rules.py b/st2api/tests/unit/controllers/v1/test_rules.py index f52b4294ca..daf6845bcb 100644 --- a/st2api/tests/unit/controllers/v1/test_rules.py +++ b/st2api/tests/unit/controllers/v1/test_rules.py @@ -34,21 +34,23 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml'], - 'triggers': ['trigger1.yaml'], - 'triggertypes': ['triggertype1.yaml', 'triggertype_with_parameters_2.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "triggers": ["trigger1.yaml"], + "triggertypes": ["triggertype1.yaml", "triggertype_with_parameters_2.yaml"], } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class RulesControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/rules' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class RulesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/rules" controller_cls = RuleController - include_attribute_field_name = 'criteria' - exclude_attribute_field_name = 'enabled' + include_attribute_field_name = "criteria" + exclude_attribute_field_name = "enabled" VALIDATE_TRIGGER_PAYLOAD = None @@ -64,71 +66,96 @@ def setUpClass(cls): cls.VALIDATE_TRIGGER_PAYLOAD = cfg.CONF.system.validate_trigger_parameters models = RulesControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RulesControllerTestCase.RUNNER_TYPE = models['runners']['testrunner1.yaml'] - RulesControllerTestCase.ACTION = models['actions']['action1.yaml'] - RulesControllerTestCase.TRIGGER = models['triggers']['trigger1.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RulesControllerTestCase.RUNNER_TYPE = models["runners"]["testrunner1.yaml"] + RulesControllerTestCase.ACTION = models["actions"]["action1.yaml"] + RulesControllerTestCase.TRIGGER = models["triggers"]["trigger1.yaml"] # Don't load rule into DB as that is what is being tested. - file_name = 'rule1.yaml' - RulesControllerTestCase.RULE_1 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters.yaml' - RulesControllerTestCase.RULE_2 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_no_enabled_attribute.yaml' - RulesControllerTestCase.RULE_3 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'backstop_rule.yaml' - RulesControllerTestCase.RULE_4 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'date_timer_rule_invalid_parameters.yaml' - RulesControllerTestCase.RULE_5 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_1.yaml' - RulesControllerTestCase.RULE_6 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_2.yaml' - RulesControllerTestCase.RULE_7 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_3.yaml' - RulesControllerTestCase.RULE_8 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_invalid_trigger_parameter_type.yaml' - RulesControllerTestCase.RULE_9 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_trigger_with_no_parameters.yaml' - RulesControllerTestCase.RULE_10 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_invalid_trigger_parameter_type_default_cfg.yaml' - RulesControllerTestCase.RULE_11 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule space.yaml' - RulesControllerTestCase.RULE_SPACE = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] + file_name = "rule1.yaml" + RulesControllerTestCase.RULE_1 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters.yaml" + RulesControllerTestCase.RULE_2 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_no_enabled_attribute.yaml" + RulesControllerTestCase.RULE_3 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "backstop_rule.yaml" + RulesControllerTestCase.RULE_4 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "date_timer_rule_invalid_parameters.yaml" + RulesControllerTestCase.RULE_5 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_1.yaml" + RulesControllerTestCase.RULE_6 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_2.yaml" + RulesControllerTestCase.RULE_7 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_3.yaml" + RulesControllerTestCase.RULE_8 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_invalid_trigger_parameter_type.yaml" + RulesControllerTestCase.RULE_9 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_trigger_with_no_parameters.yaml" + RulesControllerTestCase.RULE_10 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_invalid_trigger_parameter_type_default_cfg.yaml" + RulesControllerTestCase.RULE_11 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule space.yaml" + RulesControllerTestCase.RULE_SPACE = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) @classmethod def tearDownClass(cls): @@ -136,18 +163,19 @@ def tearDownClass(cls): cfg.CONF.system.validate_trigger_payload = cls.VALIDATE_TRIGGER_PAYLOAD RulesControllerTestCase.fixtures_loader.delete_fixtures_from_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) super(RulesControllerTestCase, cls).setUpClass() def test_get_all_and_minus_one(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) - resp = self.app.get('/v1/rules') + resp = self.app.get("/v1/rules") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) - resp = self.app.get('/v1/rules/?limit=-1') + resp = self.app.get("/v1/rules/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) @@ -158,10 +186,12 @@ def test_get_all_limit_negative_number(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) - resp = self.app.get('/v1/rules?limit=-22', expect_errors=True) + resp = self.app.get("/v1/rules?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) self.__do_delete(self.__get_rule_id(post_resp_rule_3)) @@ -171,18 +201,18 @@ def test_get_all_enabled(self): post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) # enabled=True - resp = self.app.get('/v1/rules?enabled=True') + resp = self.app.get("/v1/rules?enabled=True") self.assertEqual(resp.status_int, http_client.OK) rule = resp.json[0] - self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule['id']) - self.assertEqual(rule['enabled'], True) + self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule["id"]) + self.assertEqual(rule["enabled"], True) # enabled=False - resp = self.app.get('/v1/rules?enabled=False') + resp = self.app.get("/v1/rules?enabled=False") self.assertEqual(resp.status_int, http_client.OK) rule = resp.json[0] - self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule['id']) - self.assertEqual(rule['enabled'], False) + self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule["id"]) + self.assertEqual(rule["enabled"], False) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) self.__do_delete(self.__get_rule_id(post_resp_rule_3)) @@ -191,37 +221,45 @@ def test_get_all_action_parameters_secrets_masking(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], + MASKED_ATTRIBUTE_VALUE, + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret') + resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], "secret" + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) def test_get_all_parameters_mask_with_exclude_parameters(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) - resp = self.app.get('/v1/rules?exclude_attributes=action') - self.assertEqual('action' in resp.json[0], False) + resp = self.app.get("/v1/rules?exclude_attributes=action") + self.assertEqual("action" in resp.json[0], False) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) def test_get_all_parameters_mask_with_include_parameters(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules?include_attributes=action') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules?include_attributes=action") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], + MASKED_ATTRIBUTE_VALUE, + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret') + resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], "secret" + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) @@ -229,13 +267,16 @@ def test_get_one_action_parameters_secrets_masking(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules/%s' % (post_resp_rule_1.json['id'])) - self.assertEqual(resp.json['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules/%s" % (post_resp_rule_1.json["id"])) + self.assertEqual( + resp.json["action"]["parameters"]["action_secret"], MASKED_ATTRIBUTE_VALUE + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules/%s?show_secrets=true' % (post_resp_rule_1.json['id'])) - self.assertEqual(resp.json['action']['parameters']['action_secret'], 'secret') + resp = self.app.get( + "/v1/rules/%s?show_secrets=true" % (post_resp_rule_1.json["id"]) + ) + self.assertEqual(resp.json["action"]["parameters"]["action_secret"], "secret") self.__do_delete(self.__get_rule_id(post_resp_rule_1)) @@ -249,27 +290,27 @@ def test_get_one_by_id(self): def test_get_one_by_ref(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) - rule_name = post_resp.json['name'] - rule_pack = post_resp.json['pack'] + rule_name = post_resp.json["name"] + rule_pack = post_resp.json["pack"] ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) - rule_id = post_resp.json['id'] + rule_id = post_resp.json["id"] get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) self.__do_delete(rule_id) post_resp = self.__do_post(RulesControllerTestCase.RULE_SPACE) - rule_name = post_resp.json['name'] - rule_pack = post_resp.json['pack'] + rule_name = post_resp.json["name"] + rule_pack = post_resp.json["pack"] ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) - rule_id = post_resp.json['id'] + rule_id = post_resp.json["id"] get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) self.__do_delete(rule_id) def test_get_one_fail(self): - resp = self.app.get('/v1/rules/1', expect_errors=True) + resp = self.app.get("/v1/rules/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -283,38 +324,44 @@ def test_post_duplicate(self): self.assertEqual(post_resp.status_int, http_client.CREATED) post_resp_2 = self.__do_post(RulesControllerTestCase.RULE_1) self.assertEqual(post_resp_2.status_int, http_client.CONFLICT) - self.assertEqual(post_resp_2.json['conflict-id'], org_id) + self.assertEqual(post_resp_2.json["conflict-id"], org_id) self.__do_delete(org_id) def test_post_invalid_rule_data(self): - post_resp = self.__do_post({'name': 'rule'}) + post_resp = self.__do_post({"name": "rule"}) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) expected_msg = "'trigger' is a required property" - self.assertEqual(post_resp.json['faultstring'], expected_msg) + self.assertEqual(post_resp.json["faultstring"], expected_msg) def test_post_trigger_parameter_schema_validation_fails(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_2) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) if six.PY3: - expected_msg = b'Additional properties are not allowed (\'minutex\' was unexpected)' + expected_msg = ( + b"Additional properties are not allowed ('minutex' was unexpected)" + ) else: - expected_msg = b'Additional properties are not allowed (u\'minutex\' was unexpected)' + expected_msg = ( + b"Additional properties are not allowed (u'minutex' was unexpected)" + ) self.assertIn(expected_msg, post_resp.body) - def test_post_trigger_parameter_schema_validation_fails_missing_required_param(self): + def test_post_trigger_parameter_schema_validation_fails_missing_required_param( + self, + ): post_resp = self.__do_post(RulesControllerTestCase.RULE_5) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = b'\'date\' is a required property' + expected_msg = b"'date' is a required property" self.assertIn(expected_msg, post_resp.body) def test_post_invalid_crontimer_trigger_parameters(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_6) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = b'1000 is greater than the maximum of 6' + expected_msg = b"1000 is greater than the maximum of 6" self.assertIn(expected_msg, post_resp.body) post_resp = self.__do_post(RulesControllerTestCase.RULE_7) @@ -329,7 +376,9 @@ def test_post_invalid_crontimer_trigger_parameters(self): expected_msg = b'Invalid weekday name \\"a\\"' self.assertIn(expected_msg, post_resp.body) - def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled(self): + def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled( + self, + ): # Invalid custom trigger parameter (invalid type) and non-system trigger parameter # validation is enabled - trigger creation should fail cfg.CONF.system.validate_trigger_parameters = True @@ -338,16 +387,22 @@ def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled( self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) if six.PY3: - expected_msg_1 = "Failed validating 'type' in schema['properties']['param1']:" - expected_msg_2 = '12345 is not of type \'string\'' + expected_msg_1 = ( + "Failed validating 'type' in schema['properties']['param1']:" + ) + expected_msg_2 = "12345 is not of type 'string'" else: - expected_msg_1 = "Failed validating u'type' in schema[u'properties'][u'param1']:" - expected_msg_2 = '12345 is not of type u\'string\'' + expected_msg_1 = ( + "Failed validating u'type' in schema[u'properties'][u'param1']:" + ) + expected_msg_2 = "12345 is not of type u'string'" - self.assertIn(expected_msg_1, post_resp.json['faultstring']) - self.assertIn(expected_msg_2, post_resp.json['faultstring']) + self.assertIn(expected_msg_1, post_resp.json["faultstring"]) + self.assertIn(expected_msg_2, post_resp.json["faultstring"]) - def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled(self): + def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled( + self, + ): # Invalid custom trigger parameter (invalid type) and non-system trigger parameter # validation is disabled - trigger creation should succeed cfg.CONF.system.validate_trigger_parameters = False @@ -368,33 +423,33 @@ def test_post_invalid_custom_trigger_parameter_trigger_no_parameters_schema(self def test_post_no_enabled_attribute_disabled_by_default(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_3) self.assertEqual(post_resp.status_int, http_client.CREATED) - self.assertFalse(post_resp.json['enabled']) + self.assertFalse(post_resp.json["enabled"]) self.__do_delete(self.__get_rule_id(post_resp)) def test_put(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) update_input = post_resp.json - update_input['enabled'] = not update_input['enabled'] + update_input["enabled"] = not update_input["enabled"] put_resp = self.__do_put(self.__get_rule_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_rule_id(put_resp)) def test_post_no_pack_info(self): rule = copy.deepcopy(RulesControllerTestCase.RULE_1) - del rule['pack'] + del rule["pack"] post_resp = self.__do_post(rule) - self.assertEqual(post_resp.json['pack'], DEFAULT_PACK_NAME) + self.assertEqual(post_resp.json["pack"], DEFAULT_PACK_NAME) self.assertEqual(post_resp.status_int, http_client.CREATED) self.__do_delete(self.__get_rule_id(post_resp)) def test_put_no_pack_info(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) test_rule = post_resp.json - if 'pack' in test_rule: - del test_rule['pack'] - self.assertNotIn('pack', test_rule) + if "pack" in test_rule: + del test_rule["pack"] + self.assertNotIn("pack", test_rule) put_resp = self.__do_put(self.__get_rule_id(post_resp), test_rule) - self.assertEqual(put_resp.json['pack'], DEFAULT_PACK_NAME) + self.assertEqual(put_resp.json["pack"], DEFAULT_PACK_NAME) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_rule_id(put_resp)) @@ -417,7 +472,7 @@ def test_rule_with_tags(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - self.assertEqual(get_resp.json['tags'], RulesControllerTestCase.RULE_1['tags']) + self.assertEqual(get_resp.json["tags"], RulesControllerTestCase.RULE_1["tags"]) self.__do_delete(rule_id) def test_rule_without_type(self): @@ -426,10 +481,13 @@ def test_rule_without_type(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - assigned_rule_type = get_resp.json['type'] - self.assertTrue(assigned_rule_type, 'rule_type should be assigned') - self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_STANDARD, - 'rule_type should be standard') + assigned_rule_type = get_resp.json["type"] + self.assertTrue(assigned_rule_type, "rule_type should be assigned") + self.assertEqual( + assigned_rule_type["ref"], + RULE_TYPE_STANDARD, + "rule_type should be standard", + ) self.__do_delete(rule_id) def test_rule_with_type(self): @@ -438,10 +496,13 @@ def test_rule_with_type(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - assigned_rule_type = get_resp.json['type'] - self.assertTrue(assigned_rule_type, 'rule_type should be assigned') - self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_BACKSTOP, - 'rule_type should be backstop') + assigned_rule_type = get_resp.json["type"] + self.assertTrue(assigned_rule_type, "rule_type should be assigned") + self.assertEqual( + assigned_rule_type["ref"], + RULE_TYPE_BACKSTOP, + "rule_type should be backstop", + ) self.__do_delete(rule_id) def test_update_rule_no_data(self): @@ -451,7 +512,7 @@ def test_update_rule_no_data(self): put_resp = self.__do_put(rule_1_id, {}) expected_msg = "'name' is a required property" self.assertEqual(put_resp.status_code, http_client.BAD_REQUEST) - self.assertEqual(put_resp.json['faultstring'], expected_msg) + self.assertEqual(put_resp.json["faultstring"], expected_msg) self.__do_delete(rule_1_id) @@ -460,16 +521,16 @@ def test_update_rule_missing_id_in_body(self): rule_1_id = self.__get_rule_id(post_resp) rule_without_id = copy.deepcopy(self.RULE_1) - rule_without_id.pop('id', None) + rule_without_id.pop("id", None) put_resp = self.__do_put(rule_1_id, rule_without_id) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['id'], rule_1_id) + self.assertEqual(put_resp.json["id"], rule_1_id) self.__do_delete(rule_1_id) def _insert_mock_models(self): rule = copy.deepcopy(RulesControllerTestCase.RULE_1) - rule['name'] += '-253' + rule["name"] += "-253" post_resp = self.__do_post(rule) rule_1_id = self.__get_rule_id(post_resp) return [rule_1_id] @@ -479,32 +540,32 @@ def _do_delete(self, rule_id): @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_post(self, rule): - return self.app.post_json('/v1/rules', rule, expect_errors=True) + return self.app.post_json("/v1/rules", rule, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_put(self, rule_id, rule): - return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True) + return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_delete(self, rule_id): - return self.app.delete('/v1/rules/%s' % rule_id) + return self.app.delete("/v1/rules/%s" % rule_id) TEST_FIXTURES_2 = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml'], - 'triggertypes': ['triggertype_with_parameter.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "triggertypes": ["triggertype_with_parameter.yaml"], } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RulesControllerTestCaseTriggerCreator(FunctionalTest): fixtures_loader = FixturesLoader() @@ -513,32 +574,33 @@ class RulesControllerTestCaseTriggerCreator(FunctionalTest): def setUpClass(cls): super(RulesControllerTestCaseTriggerCreator, cls).setUpClass() cls.models = cls.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2 + ) # Don't load rule into DB as that is what is being tested. - file_name = 'rule_trigger_params.yaml' + file_name = "rule_trigger_params.yaml" cls.RULE_1 = cls.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] def test_ref_count_trigger_increment(self): post_resp = self.__do_post(self.RULE_1) rule_1_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) # ref_count is not served over API. Likely a choice that will prove unwise. - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1") # different rule same params rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 1") self.__do_delete(rule_1_id) self.__do_delete(rule_2_id) @@ -549,16 +611,16 @@ def test_ref_count_trigger_decrement(self): self.assertEqual(post_resp.status_int, http_client.CREATED) rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) # validate decrement self.__do_delete(rule_1_id) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1") self.__do_delete(rule_2_id) def test_trigger_cleanup(self): @@ -567,34 +629,34 @@ def test_trigger_cleanup(self): self.assertEqual(post_resp.status_int, http_client.CREATED) rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 1") self.__do_delete(rule_1_id) self.__do_delete(rule_2_id) # validate cleanup - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 0, 'Exactly 1 should exist') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 0, "Exactly 1 should exist") @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True) def __do_post(self, rule): - return self.app.post_json('/v1/rules', rule, expect_errors=True) + return self.app.post_json("/v1/rules", rule, expect_errors=True) def __do_put(self, rule_id, rule): - return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True) + return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True) def __do_delete(self, rule_id): - return self.app.delete('/v1/rules/%s' % rule_id) + return self.app.delete("/v1/rules/%s" % rule_id) diff --git a/st2api/tests/unit/controllers/v1/test_ruletypes.py b/st2api/tests/unit/controllers/v1/test_ruletypes.py index 5cba961409..87b1c4c584 100644 --- a/st2api/tests/unit/controllers/v1/test_ruletypes.py +++ b/st2api/tests/unit/controllers/v1/test_ruletypes.py @@ -26,20 +26,26 @@ def setUpClass(cls): ruletypes_registrar.register_rule_types() def test_get_one(self): - list_resp = self.app.get('/v1/ruletypes') + list_resp = self.app.get("/v1/ruletypes") self.assertEqual(list_resp.status_int, 200) - self.assertTrue(len(list_resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.') - ruletype_id = list_resp.json[0]['id'] - get_resp = self.app.get('/v1/ruletypes/%s' % ruletype_id) - retrieved_id = get_resp.json['id'] + self.assertTrue( + len(list_resp.json) > 0, "/v1/ruletypes did not return correct ruletypes." + ) + ruletype_id = list_resp.json[0]["id"] + get_resp = self.app.get("/v1/ruletypes/%s" % ruletype_id) + retrieved_id = get_resp.json["id"] self.assertEqual(get_resp.status_int, 200) - self.assertEqual(retrieved_id, ruletype_id, '/v1/ruletypes returned incorrect ruletype.') + self.assertEqual( + retrieved_id, ruletype_id, "/v1/ruletypes returned incorrect ruletype." + ) def test_get_all(self): - resp = self.app.get('/v1/ruletypes') + resp = self.app.get("/v1/ruletypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/ruletypes did not return correct ruletypes." + ) def test_get_one_fail_doesnt_exist(self): - resp = self.app.get('/v1/ruletypes/1', expect_errors=True) + resp = self.app.get("/v1/ruletypes/1", expect_errors=True) self.assertEqual(resp.status_int, 404) diff --git a/st2api/tests/unit/controllers/v1/test_runnertypes.py b/st2api/tests/unit/controllers/v1/test_runnertypes.py index edaacdf6dd..34c243c545 100644 --- a/st2api/tests/unit/controllers/v1/test_runnertypes.py +++ b/st2api/tests/unit/controllers/v1/test_runnertypes.py @@ -18,67 +18,76 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'RunnerTypesControllerTestCase' -] +__all__ = ["RunnerTypesControllerTestCase"] -class RunnerTypesControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/runnertypes' +class RunnerTypesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/runnertypes" controller_cls = RunnerTypesController - include_attribute_field_name = 'runner_package' - exclude_attribute_field_name = 'runner_module' - test_exact_object_count = False # runners are registered dynamically in base test class + include_attribute_field_name = "runner_package" + exclude_attribute_field_name = "runner_module" + test_exact_object_count = ( + False # runners are registered dynamically in base test class + ) def test_get_one(self): - resp = self.app.get('/v1/runnertypes') + resp = self.app.get("/v1/runnertypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes." + ) runnertype_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json[0]) - resp = self.app.get('/v1/runnertypes/%s' % runnertype_id) + resp = self.app.get("/v1/runnertypes/%s" % runnertype_id) retrieved_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json) self.assertEqual(resp.status_int, 200) - self.assertEqual(retrieved_id, runnertype_id, - '/v1/runnertypes returned incorrect runnertype.') + self.assertEqual( + retrieved_id, + runnertype_id, + "/v1/runnertypes returned incorrect runnertype.", + ) def test_get_all(self): - resp = self.app.get('/v1/runnertypes') + resp = self.app.get("/v1/runnertypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes." + ) def test_get_one_fail_doesnt_exist(self): - resp = self.app.get('/v1/runnertypes/1', expect_errors=True) + resp = self.app.get("/v1/runnertypes/1", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_put_disable_runner(self): - runnertype_id = 'action-chain' - resp = self.app.get('/v1/runnertypes/%s' % runnertype_id) - self.assertTrue(resp.json['enabled']) + runnertype_id = "action-chain" + resp = self.app.get("/v1/runnertypes/%s" % runnertype_id) + self.assertTrue(resp.json["enabled"]) # Disable the runner update_input = resp.json - update_input['enabled'] = False - update_input['name'] = 'foobar' + update_input["enabled"] = False + update_input["name"] = "foobar" put_resp = self.__do_put(runnertype_id, update_input) - self.assertFalse(put_resp.json['enabled']) + self.assertFalse(put_resp.json["enabled"]) # Verify that the name hasn't been updated - we only allow updating # enabled attribute on the runner - self.assertEqual(put_resp.json['name'], 'action-chain') + self.assertEqual(put_resp.json["name"], "action-chain") # Enable the runner update_input = resp.json - update_input['enabled'] = True + update_input["enabled"] = True put_resp = self.__do_put(runnertype_id, update_input) - self.assertTrue(put_resp.json['enabled']) + self.assertTrue(put_resp.json["enabled"]) def __do_put(self, runner_type_id, runner_type): - return self.app.put_json('/v1/runnertypes/%s' % runner_type_id, runner_type, - expect_errors=True) + return self.app.put_json( + "/v1/runnertypes/%s" % runner_type_id, runner_type, expect_errors=True + ) @staticmethod def __get_runnertype_id(resp_json): - return resp_json['id'] + return resp_json["id"] diff --git a/st2api/tests/unit/controllers/v1/test_sensortypes.py b/st2api/tests/unit/controllers/v1/test_sensortypes.py index 8e66cdfb40..c59a1c28e2 100644 --- a/st2api/tests/unit/controllers/v1/test_sensortypes.py +++ b/st2api/tests/unit/controllers/v1/test_sensortypes.py @@ -25,17 +25,16 @@ http_client = six.moves.http_client -__all__ = [ - 'SensorTypeControllerTestCase' -] +__all__ = ["SensorTypeControllerTestCase"] -class SensorTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/sensortypes' +class SensorTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/sensortypes" controller_cls = SensorTypeController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'artifact_uri' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "artifact_uri" test_exact_object_count = False @classmethod @@ -46,106 +45,108 @@ def setUpClass(cls): sensors_registrar.register_sensors(use_pack_cache=False) def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/sensortypes') + resp = self.app.get("/v1/sensortypes") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) - self.assertEqual(resp.json[0]['name'], 'SampleSensor') + self.assertEqual(resp.json[0]["name"], "SampleSensor") - resp = self.app.get('/v1/sensortypes/?limit=-1') + resp = self.app.get("/v1/sensortypes/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) - self.assertEqual(resp.json[0]['name'], 'SampleSensor') + self.assertEqual(resp.json[0]["name"], "SampleSensor") def test_get_all_negative_limit(self): - resp = self.app.get('/v1/sensortypes/?limit=-22', expect_errors=True) + resp = self.app.get("/v1/sensortypes/?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_filters(self): - resp = self.app.get('/v1/sensortypes') + resp = self.app.get("/v1/sensortypes") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) # ?name filter - resp = self.app.get('/v1/sensortypes?name=foobar') + resp = self.app.get("/v1/sensortypes?name=foobar") self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/sensortypes?name=SampleSensor2') + resp = self.app.get("/v1/sensortypes?name=SampleSensor2") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'SampleSensor2') - self.assertEqual(resp.json[0]['ref'], 'dummy_pack_1.SampleSensor2') + self.assertEqual(resp.json[0]["name"], "SampleSensor2") + self.assertEqual(resp.json[0]["ref"], "dummy_pack_1.SampleSensor2") - resp = self.app.get('/v1/sensortypes?name=SampleSensor3') + resp = self.app.get("/v1/sensortypes?name=SampleSensor3") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'SampleSensor3') + self.assertEqual(resp.json[0]["name"], "SampleSensor3") # ?pack filter - resp = self.app.get('/v1/sensortypes?pack=foobar') + resp = self.app.get("/v1/sensortypes?pack=foobar") self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/sensortypes?pack=dummy_pack_1') + resp = self.app.get("/v1/sensortypes?pack=dummy_pack_1") self.assertEqual(len(resp.json), 3) # ?enabled filter - resp = self.app.get('/v1/sensortypes?enabled=False') + resp = self.app.get("/v1/sensortypes?enabled=False") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['enabled'], False) + self.assertEqual(resp.json[0]["enabled"], False) - resp = self.app.get('/v1/sensortypes?enabled=True') + resp = self.app.get("/v1/sensortypes?enabled=True") self.assertEqual(len(resp.json), 2) - self.assertEqual(resp.json[0]['enabled'], True) - self.assertEqual(resp.json[1]['enabled'], True) + self.assertEqual(resp.json[0]["enabled"], True) + self.assertEqual(resp.json[1]["enabled"], True) # ?trigger filter - resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event3') + resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event3") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event3']) + self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event3"]) - resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event') + resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event") self.assertEqual(len(resp.json), 2) - self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event']) - self.assertEqual(resp.json[1]['trigger_types'], ['dummy_pack_1.event']) + self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event"]) + self.assertEqual(resp.json[1]["trigger_types"], ["dummy_pack_1.event"]) def test_get_one_success(self): - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(resp.json['name'], 'SampleSensor') - self.assertEqual(resp.json['ref'], 'dummy_pack_1.SampleSensor') + self.assertEqual(resp.json["name"], "SampleSensor") + self.assertEqual(resp.json["ref"], "dummy_pack_1.SampleSensor") def test_get_one_doesnt_exist(self): - resp = self.app.get('/v1/sensortypes/1', expect_errors=True) + resp = self.app.get("/v1/sensortypes/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_disable_and_enable_sensor(self): # Verify initial state - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertTrue(resp.json['enabled']) + self.assertTrue(resp.json["enabled"]) sensor_data = resp.json # Disable sensor data = copy.deepcopy(sensor_data) - data['enabled'] = False - put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data) + data["enabled"] = False + put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['ref'], 'dummy_pack_1.SampleSensor') - self.assertFalse(put_resp.json['enabled']) + self.assertEqual(put_resp.json["ref"], "dummy_pack_1.SampleSensor") + self.assertFalse(put_resp.json["enabled"]) # Verify sensor has been disabled - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertFalse(resp.json['enabled']) + self.assertFalse(resp.json["enabled"]) # Enable sensor data = copy.deepcopy(sensor_data) - data['enabled'] = True - put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data) + data["enabled"] = True + put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertTrue(put_resp.json['enabled']) + self.assertTrue(put_resp.json["enabled"]) # Verify sensor has been enabled - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertTrue(resp.json['enabled']) + self.assertTrue(resp.json["enabled"]) diff --git a/st2api/tests/unit/controllers/v1/test_service_registry.py b/st2api/tests/unit/controllers/v1/test_service_registry.py index efeb7d432a..d195c2361e 100644 --- a/st2api/tests/unit/controllers/v1/test_service_registry.py +++ b/st2api/tests/unit/controllers/v1/test_service_registry.py @@ -22,9 +22,7 @@ from st2tests.api import FunctionalTest -__all__ = [ - 'ServiceyRegistryControllerTestCase' -] +__all__ = ["ServiceyRegistryControllerTestCase"] class ServiceyRegistryControllerTestCase(FunctionalTest): @@ -41,10 +39,11 @@ def setUpClass(cls): # NOTE: We mock call common_setup to emulate service being registered in the service # registry during bootstrap phase - register_service_in_service_registry(service='mock_service', - capabilities={'key1': 'value1', - 'name': 'mock_service'}, - start_heart=True) + register_service_in_service_registry( + service="mock_service", + capabilities={"key1": "value1", "name": "mock_service"}, + start_heart=True, + ) @classmethod def tearDownClass(cls): @@ -53,33 +52,40 @@ def tearDownClass(cls): coordination.coordinator_teardown(cls.coordinator) def test_get_groups(self): - list_resp = self.app.get('/v1/service_registry/groups') + list_resp = self.app.get("/v1/service_registry/groups") self.assertEqual(list_resp.status_int, 200) - self.assertEqual(list_resp.json, {'groups': ['mock_service']}) + self.assertEqual(list_resp.json, {"groups": ["mock_service"]}) def test_get_group_members(self): proc_info = system_info.get_process_info() member_id = get_member_id() # 1. Group doesn't exist - resp = self.app.get('/v1/service_registry/groups/doesnt-exist/members', expect_errors=True) + resp = self.app.get( + "/v1/service_registry/groups/doesnt-exist/members", expect_errors=True + ) self.assertEqual(resp.status_int, 404) - self.assertEqual(resp.json['faultstring'], 'Group with ID "doesnt-exist" not found.') + self.assertEqual( + resp.json["faultstring"], 'Group with ID "doesnt-exist" not found.' + ) # 2. Group exists and has a single member - resp = self.app.get('/v1/service_registry/groups/mock_service/members') + resp = self.app.get("/v1/service_registry/groups/mock_service/members") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'members': [ - { - 'group_id': 'mock_service', - 'member_id': member_id.decode('utf-8'), - 'capabilities': { - 'key1': 'value1', - 'name': 'mock_service', - 'hostname': proc_info['hostname'], - 'pid': proc_info['pid'] + self.assertEqual( + resp.json, + { + "members": [ + { + "group_id": "mock_service", + "member_id": member_id.decode("utf-8"), + "capabilities": { + "key1": "value1", + "name": "mock_service", + "hostname": proc_info["hostname"], + "pid": proc_info["pid"], + }, } - } - ] - }) + ] + }, + ) diff --git a/st2api/tests/unit/controllers/v1/test_timers.py b/st2api/tests/unit/controllers/v1/test_timers.py index 492c231b10..cb57844539 100644 --- a/st2api/tests/unit/controllers/v1/test_timers.py +++ b/st2api/tests/unit/controllers/v1/test_timers.py @@ -17,20 +17,29 @@ import st2common.services.triggers as trigger_service -with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()): +with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()): from st2api.controllers.v1.timers import TimersHolder from st2common.models.system.common import ResourceReference from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader -from st2common.constants.triggers import INTERVAL_TIMER_TRIGGER_REF, DATE_TIMER_TRIGGER_REF +from st2common.constants.triggers import ( + INTERVAL_TIMER_TRIGGER_REF, + DATE_TIMER_TRIGGER_REF, +) from st2common.constants.triggers import CRON_TIMER_TRIGGER_REF from st2tests.api import FunctionalTest -PACK = 'timers' +PACK = "timers" FIXTURES = { - 'triggers': ['cron1.yaml', 'date1.yaml', 'interval1.yaml', 'interval2.yaml', 'interval3.yaml'] + "triggers": [ + "cron1.yaml", + "date1.yaml", + "interval1.yaml", + "interval2.yaml", + "interval3.yaml", + ] } @@ -43,23 +52,28 @@ def setUpClass(cls): loader = FixturesLoader() TestTimersHolder.MODELS = loader.load_fixtures( - fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers'] + fixtures_pack=PACK, fixtures_dict=FIXTURES + )["triggers"] loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=FIXTURES) def test_add_trigger(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder._timers), 5) def test_remove_trigger(self): holder = TimersHolder() - model = TestTimersHolder.MODELS.get('cron1.yaml', None) + model = TestTimersHolder.MODELS.get("cron1.yaml", None) self.assertIsNotNone(model) - ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name']) + ref = ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ) holder.add_trigger(ref, model) self.assertEqual(len(holder._timers), 1) holder.remove_trigger(ref, model) @@ -69,8 +83,10 @@ def test_get_all(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder.get_all()), 5) @@ -78,8 +94,10 @@ def test_get_all_filters_filter_by_type(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder.get_all(timer_type=INTERVAL_TIMER_TRIGGER_REF)), 3) self.assertEqual(len(holder.get_all(timer_type=DATE_TIMER_TRIGGER_REF)), 1) @@ -95,20 +113,23 @@ def setUpClass(cls): loader = FixturesLoader() TestTimersController.MODELS = loader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers'] + fixtures_pack=PACK, fixtures_dict=FIXTURES + )["triggers"] def test_timerscontroller_get_one_with_id(self): - model = TestTimersController.MODELS['interval1.yaml'] + model = TestTimersController.MODELS["interval1.yaml"] get_resp = self._do_get_one(model.id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['parameters'], model['parameters']) + self.assertEqual(get_resp.json["parameters"], model["parameters"]) def test_timerscontroller_get_one_with_ref(self): - model = TestTimersController.MODELS['interval1.yaml'] - ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name']) + model = TestTimersController.MODELS["interval1.yaml"] + ref = ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ) get_resp = self._do_get_one(ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['parameters'], model['parameters']) + self.assertEqual(get_resp.json["parameters"], model["parameters"]) def _do_get_one(self, timer_id, expect_errors=False): - return self.app.get('/v1/timers/%s' % timer_id, expect_errors=expect_errors) + return self.app.get("/v1/timers/%s" % timer_id, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_traces.py b/st2api/tests/unit/controllers/v1/test_traces.py index 0ce16a2a29..79bbdad6ae 100644 --- a/st2api/tests/unit/controllers/v1/test_traces.py +++ b/st2api/tests/unit/controllers/v1/test_traces.py @@ -19,23 +19,24 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -FIXTURES_PACK = 'traces' +FIXTURES_PACK = "traces" TEST_MODELS = { - 'traces': [ - 'trace_empty.yaml', - 'trace_one_each.yaml', - 'trace_multiple_components.yaml' + "traces": [ + "trace_empty.yaml", + "trace_one_each.yaml", + "trace_multiple_components.yaml", ] } -class TracesControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/traces' +class TracesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/traces" controller_cls = TracesController - include_attribute_field_name = 'trace_tag' - exclude_attribute_field_name = 'start_timestamp' + include_attribute_field_name = "trace_tag" + exclude_attribute_field_name = "start_timestamp" models = None trace1 = None @@ -45,112 +46,145 @@ class TracesControllerTestCase(FunctionalTest, @classmethod def setUpClass(cls): super(TracesControllerTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.trace1 = cls.models['traces']['trace_empty.yaml'] - cls.trace2 = cls.models['traces']['trace_one_each.yaml'] - cls.trace3 = cls.models['traces']['trace_multiple_components.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.trace1 = cls.models["traces"]["trace_empty.yaml"] + cls.trace2 = cls.models["traces"]["trace_one_each.yaml"] + cls.trace3 = cls.models["traces"]["trace_multiple_components.yaml"] def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/traces') + resp = self.app.get("/v1/traces") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") # Note: traces are returned sorted by start_timestamp in descending order by default - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') - - resp = self.app.get('/v1/traces/?limit=-1') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) + + resp = self.app.get("/v1/traces/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") # Note: traces are returned sorted by start_timestamp in descending order by default - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) def test_get_all_ascending_and_descending(self): - resp = self.app.get('/v1/traces?sort_asc=True') + resp = self.app.get("/v1/traces?sort_asc=True") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag], + "Incorrect traces retrieved.", + ) - resp = self.app.get('/v1/traces?sort_desc=True') + resp = self.app.get("/v1/traces?sort_desc=True") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) def test_get_all_limit(self): - resp = self.app.get('/v1/traces?limit=1') + resp = self.app.get("/v1/traces?limit=1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 1, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag], 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, [self.trace3.trace_tag], "Incorrect traces retrieved." + ) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/traces?limit=-22', expect_errors=True) + resp = self.app.get("/v1/traces?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_by_id(self): - resp = self.app.get('/v1/traces/%s' % self.trace1.id) + resp = self.app.get("/v1/traces/%s" % self.trace1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.trace1.id), - 'Incorrect trace retrieved.') + self.assertEqual( + resp.json["id"], str(self.trace1.id), "Incorrect trace retrieved." + ) def test_query_by_trace_tag(self): - resp = self.app.get('/v1/traces?trace_tag=test-trace-1') + resp = self.app.get("/v1/traces?trace_tag=test-trace-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces?trace_tag=x did not return correct trace.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?trace_tag=x did not return correct trace." + ) - self.assertEqual(resp.json[0]['trace_tag'], self.trace1['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace1["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_action_execution(self): - execution_id = self.trace3['action_executions'][0].object_id - resp = self.app.get('/v1/traces?execution=%s' % execution_id) + execution_id = self.trace3["action_executions"][0].object_id + resp = self.app.get("/v1/traces?execution=%s" % execution_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, - '/v1/traces?execution=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?execution=x did not return correct trace." + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_rule(self): - rule_id = self.trace3['rules'][0].object_id - resp = self.app.get('/v1/traces?rule=%s' % rule_id) + rule_id = self.trace3["rules"][0].object_id + resp = self.app.get("/v1/traces?rule=%s" % rule_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces?rule=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?rule=x did not return correct trace." + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_trigger_instance(self): - trigger_instance_id = self.trace3['trigger_instances'][0].object_id - resp = self.app.get('/v1/traces?trigger_instance=%s' % trigger_instance_id) + trigger_instance_id = self.trace3["trigger_instances"][0].object_id + resp = self.app.get("/v1/traces?trigger_instance=%s" % trigger_instance_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, - '/v1/traces?trigger_instance=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), + 1, + "/v1/traces?trigger_instance=x did not return correct trace.", + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def _insert_mock_models(self): - trace_ids = [trace['id'] for trace in self.models['traces'].values()] + trace_ids = [trace["id"] for trace in self.models["traces"].values()] return trace_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_triggerinstances.py b/st2api/tests/unit/controllers/v1/test_triggerinstances.py index 0d81de723d..2a4149707c 100644 --- a/st2api/tests/unit/controllers/v1/test_triggerinstances.py +++ b/st2api/tests/unit/controllers/v1/test_triggerinstances.py @@ -31,13 +31,14 @@ http_client = six.moves.http_client -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class TriggerInstanceTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/triggerinstances' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class TriggerInstanceTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/triggerinstances" controller_cls = TriggerInstanceController - include_attribute_field_name = 'trigger' - exclude_attribute_field_name = 'payload' + include_attribute_field_name = "trigger" + exclude_attribute_field_name = "payload" @classmethod def setUpClass(cls): @@ -47,74 +48,84 @@ def setUpClass(cls): cls._setupTriggerInstance() def test_get_all(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), self.triggerinstance_count, 'Get all failure.') + self.assertEqual(len(resp.json), self.triggerinstance_count, "Get all failure.") def test_get_all_limit(self): limit = 1 - resp = self.app.get('/v1/triggerinstances?limit=%d' % limit) + resp = self.app.get("/v1/triggerinstances?limit=%d" % limit) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), limit, 'Get all failure. Length doesn\'t match limit.') + self.assertEqual( + len(resp.json), limit, "Get all failure. Length doesn't match limit." + ) def test_get_all_limit_negative_number(self): limit = -22 - resp = self.app.get('/v1/triggerinstances?limit=%d' % limit, expect_errors=True) + resp = self.app.get("/v1/triggerinstances?limit=%d" % limit, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_filter_by_trigger(self): - trigger = 'dummy_pack_1.st2.test.trigger0' - resp = self.app.get('/v1/triggerinstances?trigger=%s' % trigger) + trigger = "dummy_pack_1.st2.test.trigger0" + resp = self.app.get("/v1/triggerinstances?trigger=%s" % trigger) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), 1, 'Get all failure. Must get only one such instance.') + self.assertEqual( + len(resp.json), 1, "Get all failure. Must get only one such instance." + ) def test_get_all_filter_by_timestamp(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - timestamp_largest = resp.json[0]['occurrence_time'] - timestamp_middle = resp.json[1]['occurrence_time'] + timestamp_largest = resp.json[0]["occurrence_time"] + timestamp_middle = resp.json[1]["occurrence_time"] dt = isotime.parse(timestamp_largest) dt = dt + datetime.timedelta(seconds=1) timestamp_largest = isotime.format(dt, offset=False) - resp = self.app.get('/v1/triggerinstances?timestamp_gt=%s' % timestamp_largest) + resp = self.app.get("/v1/triggerinstances?timestamp_gt=%s" % timestamp_largest) # Since we sort trigger instances by time (latest first), the previous # get should return no trigger instances. self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/triggerinstances?timestamp_lt=%s' % (timestamp_middle)) + resp = self.app.get("/v1/triggerinstances?timestamp_lt=%s" % (timestamp_middle)) self.assertEqual(len(resp.json), 1) def test_get_all_trigger_type_ref_filtering(self): # 1. Invalid / inexistent trigger type ref - resp = self.app.get('/v1/triggerinstances?trigger_type=foo.bar.invalid') + resp = self.app.get("/v1/triggerinstances?trigger_type=foo.bar.invalid") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) # 2. Valid trigger type ref with corresponding trigger instances - resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0') + resp = self.app.get( + "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) # 3. Valid trigger type ref with no corresponding trigger instances - resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3') + resp = self.app.get( + "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) def test_reemit_trigger_instance(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - instance_id = resp.json[0]['id'] - resp = self.app.post('/v1/triggerinstances/%s/re_emit' % instance_id) + instance_id = resp.json[0]["id"] + resp = self.app.post("/v1/triggerinstances/%s/re_emit" % instance_id) self.assertEqual(resp.status_int, http_client.OK) - resent_message = resp.json['message'] - resent_payload = resp.json['payload'] + resent_message = resp.json["message"] + resent_payload = resp.json["payload"] self.assertIn(instance_id, resent_message) - self.assertIn('__context', resent_payload) - self.assertEqual(resent_payload['__context']['original_id'], instance_id) + self.assertIn("__context", resent_payload) + self.assertEqual(resent_payload["__context"]["original_id"], instance_id) def test_get_one(self): triggerinstance_id = str(self.triggerinstance_1.id) @@ -133,79 +144,78 @@ def test_get_one(self): self.assertEqual(self._get_id(resp), triggerinstance_id) def test_get_one_fail(self): - resp = self._do_get_one('1') + resp = self._do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) @classmethod def _setupTriggerTypes(cls): TRIGGERTYPE_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGERTYPE_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGERTYPE_2 = { - 'name': 'st2.test.triggertype2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype2", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } TRIGGERTYPE_3 = { - 'name': 'st2.test.triggertype3', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype3", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_3, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_3, expect_errors=False) @classmethod def _setupTriggers(cls): TRIGGER_0 = { - 'name': 'st2.test.trigger0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype0', - 'parameters': {} + "name": "st2.test.trigger0", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype0", + "parameters": {}, } TRIGGER_1 = { - 'name': 'st2.test.trigger1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype1', - 'parameters': {} + "name": "st2.test.trigger1", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype1", + "parameters": {}, } TRIGGER_2 = { - 'name': 'st2.test.trigger2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype2', - 'parameters': { - 'param1': { - 'foo': 'bar' - } - } + "name": "st2.test.trigger2", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype2", + "parameters": {"param1": {"foo": "bar"}}, } - cls.app.post_json('/v1/triggers', TRIGGER_0, expect_errors=False) - cls.app.post_json('/v1/triggers', TRIGGER_1, expect_errors=False) - cls.app.post_json('/v1/triggers', TRIGGER_2, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_0, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_1, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_2, expect_errors=False) def _insert_mock_models(self): - return [self.triggerinstance_1['id'], self.triggerinstance_2['id'], - self.triggerinstance_3['id']] + return [ + self.triggerinstance_1["id"], + self.triggerinstance_2["id"], + self.triggerinstance_3["id"], + ] def _delete_mock_models(self, object_ids): return None @@ -214,17 +224,20 @@ def _delete_mock_models(self, object_ids): def _setupTriggerInstance(cls): cls.triggerinstance_count = 0 cls.triggerinstance_1 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger0', - payload={'tp1': 1, 'tp2': 2, 'tp3': 3}, - seconds=1) + trigger_ref="dummy_pack_1.st2.test.trigger0", + payload={"tp1": 1, "tp2": 2, "tp3": 3}, + seconds=1, + ) cls.triggerinstance_2 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger1', - payload={'tp1': 'a', 'tp2': 'b', 'tp3': 'c'}, - seconds=2) + trigger_ref="dummy_pack_1.st2.test.trigger1", + payload={"tp1": "a", "tp2": "b", "tp3": "c"}, + seconds=2, + ) cls.triggerinstance_3 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger2', - payload={'tp1': None, 'tp2': None, 'tp3': None}, - seconds=3) + trigger_ref="dummy_pack_1.st2.test.trigger2", + payload={"tp1": None, "tp2": None, "tp3": None}, + seconds=3, + ) @classmethod def _create_trigger_instance(cls, trigger_ref, payload, seconds): @@ -244,7 +257,9 @@ def _create_trigger_instance(cls, trigger_ref, payload, seconds): @staticmethod def _get_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_one(self, triggerinstance_id): - return self.app.get('/v1/triggerinstances/%s' % triggerinstance_id, expect_errors=True) + return self.app.get( + "/v1/triggerinstances/%s" % triggerinstance_id, expect_errors=True + ) diff --git a/st2api/tests/unit/controllers/v1/test_triggers.py b/st2api/tests/unit/controllers/v1/test_triggers.py index d3526e624a..5067c7674f 100644 --- a/st2api/tests/unit/controllers/v1/test_triggers.py +++ b/st2api/tests/unit/controllers/v1/test_triggers.py @@ -22,57 +22,52 @@ http_client = six.moves.http_client TRIGGER_0 = { - 'name': 'st2.test.trigger0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype0', - 'parameters': {} + "name": "st2.test.trigger0", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype0", + "parameters": {}, } TRIGGER_1 = { - 'name': 'st2.test.trigger1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype1', - 'parameters': {} + "name": "st2.test.trigger1", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype1", + "parameters": {}, } TRIGGER_2 = { - 'name': 'st2.test.trigger2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype2', - 'parameters': { - 'param1': { - 'foo': 'bar' - } - } + "name": "st2.test.trigger2", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype2", + "parameters": {"param1": {"foo": "bar"}}, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class TestTriggerController(FunctionalTest): - @classmethod def setUpClass(cls): super(TestTriggerController, cls).setUpClass() cls._setupTriggerTypes() def test_get_all(self): - resp = self.app.get('/v1/triggers') + resp = self.app.get("/v1/triggers") self.assertEqual(resp.status_int, http_client.OK) # TriggerType without parameters will register a trigger # with same name. - self.assertEqual(len(resp.json), 2, 'Get all failure. %s' % resp.json) + self.assertEqual(len(resp.json), 2, "Get all failure. %s" % resp.json) post_resp = self._do_post(TRIGGER_0) trigger_id_0 = self._get_trigger_id(post_resp) post_resp = self._do_post(TRIGGER_1) trigger_id_1 = self._get_trigger_id(post_resp) - resp = self.app.get('/v1/triggers') + resp = self.app.get("/v1/triggers") self.assertEqual(resp.status_int, http_client.OK) # TriggerType without parameters will register a trigger # with same name. So here we see 4 instead of 2. - self.assertEqual(len(resp.json), 4, 'Get all failure.') + self.assertEqual(len(resp.json), 4, "Get all failure.") self._do_delete(trigger_id_0) self._do_delete(trigger_id_1) @@ -85,7 +80,7 @@ def test_get_one(self): self._do_delete(trigger_id) def test_get_one_fail(self): - resp = self._do_get_one('1') + resp = self._do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -106,13 +101,15 @@ def test_post_duplicate(self): # id is same in both cases. post_resp_2 = self._do_post(TRIGGER_1) self.assertEqual(post_resp_2.status_int, http_client.CREATED) - self.assertEqual(self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2)) + self.assertEqual( + self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2) + ) self._do_delete(self._get_trigger_id(post_resp)) def test_put(self): post_resp = self._do_post(TRIGGER_1) update_input = post_resp.json - update_input['description'] = 'updated description.' + update_input["description"] = "updated description." put_resp = self._do_put(self._get_trigger_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self._do_delete(self._get_trigger_id(put_resp)) @@ -133,41 +130,43 @@ def test_delete(self): @classmethod def _setupTriggerTypes(cls): TRIGGERTYPE_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGERTYPE_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGERTYPE_2 = { - 'name': 'st2.test.triggertype2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype2", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False) @staticmethod def _get_trigger_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_one(self, trigger_id): - return self.app.get('/v1/triggers/%s' % trigger_id, expect_errors=True) + return self.app.get("/v1/triggers/%s" % trigger_id, expect_errors=True) def _do_post(self, trigger): - return self.app.post_json('/v1/triggers', trigger, expect_errors=True) + return self.app.post_json("/v1/triggers", trigger, expect_errors=True) def _do_put(self, trigger_id, trigger): - return self.app.put_json('/v1/triggers/%s' % trigger_id, trigger, expect_errors=True) + return self.app.put_json( + "/v1/triggers/%s" % trigger_id, trigger, expect_errors=True + ) def _do_delete(self, trigger_id): - return self.app.delete('/v1/triggers/%s' % trigger_id) + return self.app.delete("/v1/triggers/%s" % trigger_id) diff --git a/st2api/tests/unit/controllers/v1/test_triggertypes.py b/st2api/tests/unit/controllers/v1/test_triggertypes.py index c7848f5c2d..414fc34360 100644 --- a/st2api/tests/unit/controllers/v1/test_triggertypes.py +++ b/st2api/tests/unit/controllers/v1/test_triggertypes.py @@ -23,33 +23,34 @@ http_client = six.moves.http_client TRIGGER_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGER_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_2', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_2", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGER_2 = { - 'name': 'st2.test.triggertype3', - 'pack': 'dummy_pack_3', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype3", + "pack": "dummy_pack_3", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } -class TriggerTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/triggertypes' +class TriggerTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/triggertypes" controller_cls = TriggerTypeController - include_attribute_field_name = 'payload_schema' - exclude_attribute_field_name = 'parameters_schema' + include_attribute_field_name = "payload_schema" + exclude_attribute_field_name = "parameters_schema" @classmethod def setUpClass(cls): @@ -71,19 +72,19 @@ def test_get_all(self): trigger_id_0 = self.__get_trigger_id(post_resp) post_resp = self.__do_post(TRIGGER_1) trigger_id_1 = self.__get_trigger_id(post_resp) - resp = self.app.get('/v1/triggertypes') + resp = self.app.get("/v1/triggertypes") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), 2, 'Get all failure.') + self.assertEqual(len(resp.json), 2, "Get all failure.") # ?pack query filter - resp = self.app.get('/v1/triggertypes?pack=doesnt-exist-invalid') + resp = self.app.get("/v1/triggertypes?pack=doesnt-exist-invalid") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/triggertypes?pack=%s' % (TRIGGER_0['pack'])) + resp = self.app.get("/v1/triggertypes?pack=%s" % (TRIGGER_0["pack"])) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['pack'], TRIGGER_0['pack']) + self.assertEqual(resp.json[0]["pack"], TRIGGER_0["pack"]) self.__do_delete(trigger_id_0) self.__do_delete(trigger_id_1) @@ -97,7 +98,7 @@ def test_get_one(self): self.__do_delete(trigger_id) def test_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -116,13 +117,13 @@ def test_post_duplicate(self): self.assertEqual(post_resp.status_int, http_client.CREATED) post_resp_2 = self.__do_post(TRIGGER_1) self.assertEqual(post_resp_2.status_int, http_client.CONFLICT) - self.assertEqual(post_resp_2.json['conflict-id'], org_id) + self.assertEqual(post_resp_2.json["conflict-id"], org_id) self.__do_delete(org_id) def test_put(self): post_resp = self.__do_post(TRIGGER_1) update_input = post_resp.json - update_input['description'] = 'updated description.' + update_input["description"] = "updated description." put_resp = self.__do_put(self.__get_trigger_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_trigger_id(put_resp)) @@ -151,16 +152,18 @@ def _do_delete(self, trigger_id): @staticmethod def __get_trigger_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, trigger_id): - return self.app.get('/v1/triggertypes/%s' % trigger_id, expect_errors=True) + return self.app.get("/v1/triggertypes/%s" % trigger_id, expect_errors=True) def __do_post(self, trigger): - return self.app.post_json('/v1/triggertypes', trigger, expect_errors=True) + return self.app.post_json("/v1/triggertypes", trigger, expect_errors=True) def __do_put(self, trigger_id, trigger): - return self.app.put_json('/v1/triggertypes/%s' % trigger_id, trigger, expect_errors=True) + return self.app.put_json( + "/v1/triggertypes/%s" % trigger_id, trigger, expect_errors=True + ) def __do_delete(self, trigger_id): - return self.app.delete('/v1/triggertypes/%s' % trigger_id) + return self.app.delete("/v1/triggertypes/%s" % trigger_id) diff --git a/st2api/tests/unit/controllers/v1/test_webhooks.py b/st2api/tests/unit/controllers/v1/test_webhooks.py index 487830a092..e8fedc673c 100644 --- a/st2api/tests/unit/controllers/v1/test_webhooks.py +++ b/st2api/tests/unit/controllers/v1/test_webhooks.py @@ -21,7 +21,7 @@ import st2common.services.triggers as trigger_service -with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()): +with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()): from st2api.controllers.v1.webhooks import WebhooksController, HooksHolder from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES @@ -34,28 +34,20 @@ http_client = six.moves.http_client -WEBHOOK_1 = { - 'action': 'closed', - 'pull_request': { - 'merged': True - } -} +WEBHOOK_1 = {"action": "closed", "pull_request": {"merged": True}} ST2_WEBHOOK = { - 'trigger': 'git.pr-merged', - 'payload': { - 'value_str': 'string!', - 'value_int': 12345 - } + "trigger": "git.pr-merged", + "payload": {"value_str": "string!", "value_int": 12345}, } WEBHOOK_DATA = { - 'value_str': 'test string 1', - 'value_int': 987654, + "value_str": "test string 1", + "value_int": 987654, } # 1. Trigger which references a system webhook trigger type -DUMMY_TRIGGER_DB = TriggerDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_DB = TriggerDB(name="pr-merged", pack="git") DUMMY_TRIGGER_DB.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0] @@ -63,34 +55,24 @@ DUMMY_TRIGGER_DICT = vars(DUMMY_TRIGGER_API) # 2. Custom TriggerType object -DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name="pr-merged", pack="git") DUMMY_TRIGGER_TYPE_DB.payload_schema = { - 'type': 'object', - 'properties': { - 'body': { - 'properties': { - 'value_str': { - 'type': 'string', - 'required': True - }, - 'value_int': { - 'type': 'integer', - 'required': True - } + "type": "object", + "properties": { + "body": { + "properties": { + "value_str": {"type": "string", "required": True}, + "value_int": {"type": "integer", "required": True}, } } - } + }, } # 2. Custom TriggerType object -DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name="pr-merged", pack="git") DUMMY_TRIGGER_TYPE_DB_2.payload_schema = { - 'type': 'object', - 'properties': { - 'body': { - 'type': 'array' - } - } + "type": "object", + "properties": {"body": {"type": "array"}}, } @@ -100,190 +82,244 @@ def setUp(self): cfg.CONF.system.validate_trigger_payload = True - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_all', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, "get_all", mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]) + ) def test_get_all(self): - get_resp = self.app.get('/v1/webhooks', expect_errors=False) + get_resp = self.app.get("/v1/webhooks", expect_errors=False) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(get_resp.json, [DUMMY_TRIGGER_DICT]) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post(self, dispatch_mock): - post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False) + post_resp = self.__do_post("git", WEBHOOK_1, expect_errors=False) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_with_trace(self, dispatch_mock): - post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False, - headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post( + "git", WEBHOOK_1, expect_errors=False, headers={"St2-Trace-Tag": "tag1"} + ) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) def test_post_hook_not_registered(self): - post_resp = self.__do_post('foo', WEBHOOK_1, expect_errors=True) + post_resp = self.__do_post("foo", WEBHOOK_1, expect_errors=True) self.assertEqual(post_resp.status_int, http_client.NOT_FOUND) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_success(self, dispatch_mock): - post_resp = self.__do_post('st2', ST2_WEBHOOK) + post_resp = self.__do_post("st2", ST2_WEBHOOK) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag) + self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag) - post_resp = self.__do_post('st2/', ST2_WEBHOOK) + post_resp = self.__do_post("st2/", ST2_WEBHOOK) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_failure_payload_validation_failed(self, dispatch_mock): - data = { - 'trigger': 'git.pr-merged', - 'payload': 'invalid' - } - post_resp = self.__do_post('st2', data, expect_errors=True) + data = {"trigger": "git.pr-merged", "payload": "invalid"} + post_resp = self.__do_post("st2", data, expect_errors=True) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = 'Trigger payload validation failed' - self.assertIn(expected_msg, post_resp.json['faultstring']) + expected_msg = "Trigger payload validation failed" + self.assertIn(expected_msg, post_resp.json["faultstring"]) expected_msg = "'invalid' is not of type 'object'" - self.assertIn(expected_msg, post_resp.json['faultstring']) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn(expected_msg, post_resp.json["faultstring"]) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_with_trace(self, dispatch_mock): - post_resp = self.__do_post('st2', ST2_WEBHOOK, headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post( + "st2", ST2_WEBHOOK, headers={"St2-Trace-Tag": "tag1"} + ) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) def test_st2_webhook_body_missing_trigger(self): - post_resp = self.__do_post('st2', {'payload': {}}, expect_errors=True) - self.assertIn('Trigger not specified.', post_resp) + post_resp = self.__do_post("st2", {"payload": {}}, expect_errors=True) + self.assertIn("Trigger not specified.", post_resp) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_json_request_body(self, dispatch_mock): # 1. Send JSON using application/json content type data = WEBHOOK_1 - post_resp = self.__do_post('git', data, - headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post("git", data, headers={"St2-Trace-Tag": "tag1"}) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/json') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/json", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") # 2. Send JSON using application/json + charset content type data = WEBHOOK_1 - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json; charset=utf-8'} - post_resp = self.__do_post('git', data, - headers=headers) + headers = { + "St2-Trace-Tag": "tag1", + "Content-Type": "application/json; charset=utf-8", + } + post_resp = self.__do_post("git", data, headers=headers) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/json; charset=utf-8') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/json; charset=utf-8", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") # 3. JSON content type, invalid JSON body - data = 'invalid' - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json'} - post_resp = self.app.post('/v1/webhooks/git', data, headers=headers, - expect_errors=True) + data = "invalid" + headers = {"St2-Trace-Tag": "tag1", "Content-Type": "application/json"} + post_resp = self.app.post( + "/v1/webhooks/git", data, headers=headers, expect_errors=True + ) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('Failed to parse request body', post_resp) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn("Failed to parse request body", post_resp) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_form_encoded_request_body(self, dispatch_mock): # Send request body as form urlencoded data if six.PY3: - data = {b'form': [b'test']} + data = {b"form": [b"test"]} else: - data = {'form': ['test']} + data = {"form": ["test"]} headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'St2-Trace-Tag': 'tag1' + "Content-Type": "application/x-www-form-urlencoded", + "St2-Trace-Tag": "tag1", } - self.app.post('/v1/webhooks/git', data, headers=headers) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/x-www-form-urlencoded') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.app.post("/v1/webhooks/git", data, headers=headers) + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/x-www-form-urlencoded", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") def test_unsupported_content_type(self): # Invalid / unsupported content type - should throw data = WEBHOOK_1 - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'foo/invalid'} - post_resp = self.app.post('/v1/webhooks/git', json.dumps(data), headers=headers, - expect_errors=True) + headers = {"St2-Trace-Tag": "tag1", "Content-Type": "foo/invalid"} + post_resp = self.app.post( + "/v1/webhooks/git", json.dumps(data), headers=headers, expect_errors=True + ) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('Failed to parse request body', post_resp) - self.assertIn('Unsupported Content-Type', post_resp) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB_2)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn("Failed to parse request body", post_resp) + self.assertIn("Unsupported Content-Type", post_resp) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB_2), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_custom_webhook_array_input_type(self, _): - post_resp = self.__do_post('sample', [{'foo': 'bar'}]) + post_resp = self.__do_post("sample", [{"foo": "bar"}]) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(post_resp.json, [{'foo': 'bar'}]) + self.assertEqual(post_resp.json, [{"foo": "bar"}]) def test_st2_webhook_array_webhook_array_input_type_not_valid(self): - post_resp = self.__do_post('st2', [{'foo': 'bar'}], expect_errors=True) + post_resp = self.__do_post("st2", [{"foo": "bar"}], expect_errors=True) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(post_resp.json['faultstring'], - 'Webhook body needs to be an object, got: array') + self.assertEqual( + post_resp.json["faultstring"], + "Webhook body needs to be an object, got: array", + ) def test_leading_trailing_slashes(self): # Ideally the test should setup fixtures in DB. However, the triggerwatcher @@ -296,52 +332,65 @@ def test_leading_trailing_slashes(self): # require hacking into the test app and force dependency on pecan internals. # TLDR; sorry for the ghetto test. Not sure how else to test this as a unit test. def get_webhook_trigger(name, url): - trigger = TriggerDB(name=name, pack='test') + trigger = TriggerDB(name=name, pack="test") trigger.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0] - trigger.parameters = {'url': url} + trigger.parameters = {"url": url} return trigger test_triggers = [ - get_webhook_trigger('no_slash', 'no_slash'), - get_webhook_trigger('with_leading_slash', '/with_leading_slash'), - get_webhook_trigger('with_trailing_slash', '/with_trailing_slash/'), - get_webhook_trigger('with_leading_trailing_slash', '/with_leading_trailing_slash/'), - get_webhook_trigger('with_mixed_slash', '/with/mixed/slash/') + get_webhook_trigger("no_slash", "no_slash"), + get_webhook_trigger("with_leading_slash", "/with_leading_slash"), + get_webhook_trigger("with_trailing_slash", "/with_trailing_slash/"), + get_webhook_trigger( + "with_leading_trailing_slash", "/with_leading_trailing_slash/" + ), + get_webhook_trigger("with_mixed_slash", "/with/mixed/slash/"), ] controller = WebhooksController() for trigger in test_triggers: controller.add_trigger(trigger) - self.assertTrue(controller._is_valid_hook('no_slash')) - self.assertFalse(controller._is_valid_hook('/no_slash')) - self.assertTrue(controller._is_valid_hook('with_leading_slash')) - self.assertTrue(controller._is_valid_hook('with_trailing_slash')) - self.assertTrue(controller._is_valid_hook('with_leading_trailing_slash')) - self.assertTrue(controller._is_valid_hook('with/mixed/slash')) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertTrue(controller._is_valid_hook("no_slash")) + self.assertFalse(controller._is_valid_hook("/no_slash")) + self.assertTrue(controller._is_valid_hook("with_leading_slash")) + self.assertTrue(controller._is_valid_hook("with_trailing_slash")) + self.assertTrue(controller._is_valid_hook("with_leading_trailing_slash")) + self.assertTrue(controller._is_valid_hook("with/mixed/slash")) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_authentication_headers_should_be_removed(self, dispatch_mock): headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'St2-Api-Key': 'foobar', - 'X-Auth-Token': 'deadbeaf', - 'Cookie': 'foo=bar' + "Content-Type": "application/x-www-form-urlencoded", + "St2-Api-Key": "foobar", + "X-Auth-Token": "deadbeaf", + "Cookie": "foo=bar", } - self.app.post('/v1/webhooks/git', WEBHOOK_1, headers=headers) - self.assertNotIn('St2-Api-Key', dispatch_mock.call_args[1]['payload']['headers']) - self.assertNotIn('X-Auth-Token', dispatch_mock.call_args[1]['payload']['headers']) - self.assertNotIn('Cookie', dispatch_mock.call_args[1]['payload']['headers']) + self.app.post("/v1/webhooks/git", WEBHOOK_1, headers=headers) + self.assertNotIn( + "St2-Api-Key", dispatch_mock.call_args[1]["payload"]["headers"] + ) + self.assertNotIn( + "X-Auth-Token", dispatch_mock.call_args[1]["payload"]["headers"] + ) + self.assertNotIn("Cookie", dispatch_mock.call_args[1]["payload"]["headers"]) def __do_post(self, hook, webhook, expect_errors=False, headers=None): - return self.app.post_json('/v1/webhooks/' + hook, - params=webhook, - expect_errors=expect_errors, - headers=headers) + return self.app.post_json( + "/v1/webhooks/" + hook, + params=webhook, + expect_errors=expect_errors, + headers=headers, + ) diff --git a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py index 3b45421d79..91e251fe9d 100644 --- a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py +++ b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py @@ -22,13 +22,17 @@ from st2tests.api import FunctionalTest -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK -PACKS = [TEST_PACK_PATH, st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'] +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) +PACKS = [ + TEST_PACK_PATH, + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", +] class WorkflowInspectionControllerTest(FunctionalTest, st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowInspectionControllerTest, cls).setUpClass() @@ -39,8 +43,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -48,14 +51,14 @@ def setUpClass(cls): def _do_post(self, wf_def, expect_errors=False): return self.app.post( - '/v1/workflows/inspect', + "/v1/workflows/inspect", wf_def, expect_errors=expect_errors, - content_type='text/plain' + content_type="text/plain", ) def test_inspection(self): - wf_file = 'sequential.yaml' + wf_file = "sequential.yaml" wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) @@ -65,48 +68,48 @@ def test_inspection(self): self.assertListEqual(response.json, expected_errors) def test_inspection_return_errors(self): - wf_file = 'fail-inspection.yaml' + wf_file = "fail-inspection.yaml" wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] response = self._do_post(wf_def, expect_errors=False) diff --git a/st2api/tests/unit/test_validation_utils.py b/st2api/tests/unit/test_validation_utils.py index eaf1cd75a5..bad17b22a5 100644 --- a/st2api/tests/unit/test_validation_utils.py +++ b/st2api/tests/unit/test_validation_utils.py @@ -19,9 +19,7 @@ from st2api.validation import validate_rbac_is_correctly_configured from st2tests import config as tests_config -__all__ = [ - 'ValidationUtilsTestCase' -] +__all__ = ["ValidationUtilsTestCase"] class ValidationUtilsTestCase(unittest2.TestCase): @@ -34,26 +32,34 @@ def test_validate_rbac_is_correctly_configured_succcess(self): self.assertTrue(result) def test_validate_rbac_is_correctly_configured_auth_not_enabled(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='auth', name='enable', override=False) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="auth", name="enable", override=False) - expected_msg = ('Authentication is not enabled. RBAC only works when authentication is ' - 'enabled. You can either enable authentication or disable RBAC.') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_rbac_is_correctly_configured) + expected_msg = ( + "Authentication is not enabled. RBAC only works when authentication is " + "enabled. You can either enable authentication or disable RBAC." + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_rbac_is_correctly_configured + ) def test_validate_rbac_is_correctly_configured_non_default_backend_set(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='backend', override='invalid') - cfg.CONF.set_override(group='auth', name='enable', override=True) - - expected_msg = ('You have enabled RBAC, but RBAC backend is not set to "default".') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_rbac_is_correctly_configured) - - def test_validate_rbac_is_correctly_configured_default_backend_available_success(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='backend', override='default') - cfg.CONF.set_override(group='auth', name='enable', override=True) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="backend", override="invalid") + cfg.CONF.set_override(group="auth", name="enable", override=True) + + expected_msg = ( + 'You have enabled RBAC, but RBAC backend is not set to "default".' + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_rbac_is_correctly_configured + ) + + def test_validate_rbac_is_correctly_configured_default_backend_available_success( + self, + ): + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="backend", override="default") + cfg.CONF.set_override(group="auth", name="enable", override=True) result = validate_rbac_is_correctly_configured() self.assertTrue(result) diff --git a/st2auth/dist_utils.py b/st2auth/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2auth/dist_utils.py +++ b/st2auth/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2auth/setup.py b/st2auth/setup.py index f77ee72f03..c6e266472b 100644 --- a/st2auth/setup.py +++ b/st2auth/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2auth import __version__ -ST2_COMPONENT = 'st2auth' +ST2_COMPONENT = "st2auth" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,23 +33,21 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2auth' - ], + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2auth"], entry_points={ - 'st2auth.sso.backends': [ - 'noop = st2auth.sso.noop:NoOpSingleSignOnBackend' - ] - } + "st2auth.sso.backends": ["noop = st2auth.sso.noop:NoOpSingleSignOnBackend"] + }, ) diff --git a/st2auth/st2auth/__init__.py b/st2auth/st2auth/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2auth/st2auth/__init__.py +++ b/st2auth/st2auth/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2auth/st2auth/app.py b/st2auth/st2auth/app.py index 3398104c6c..b9b8f7d595 100644 --- a/st2auth/st2auth/app.py +++ b/st2auth/st2auth/app.py @@ -36,34 +36,38 @@ def setup_app(config=None): config = config or {} - LOG.info('Creating st2auth: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2auth: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # NOTE: We only want to perform this logic in the WSGI worker st2auth_config.register_opts() capabilities = { - 'name': 'auth', - 'listen_host': cfg.CONF.auth.host, - 'listen_port': cfg.CONF.auth.port, - 'listen_ssl': cfg.CONF.auth.use_ssl, - 'type': 'active' + "name": "auth", + "listen_host": cfg.CONF.auth.host, + "listen_port": cfg.CONF.auth.port, + "listen_ssl": cfg.CONF.auth.use_ssl, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='auth', config=st2auth_config, setup_db=True, - register_mq_exchanges=False, - register_signal_handlers=True, - register_internal_trigger_types=False, - run_migrations=False, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="auth", + config=st2auth_config, + setup_db=True, + register_mq_exchanges=False, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) # pysaml2 uses subprocess communicate which calls communicate_with_poll - if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == 'saml2': + if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == "saml2": use_select_poll_workaround(nose_only=False) # Additional pre-run time checks @@ -71,10 +75,8 @@ def setup_app(config=None): router = Router(debug=cfg.CONF.auth.debug, is_gunicorn=is_gunicorn) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') - transforms = { - '^/auth/v1/': ['/', '/v1/'] - } + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") + transforms = {"^/auth/v1/": ["/", "/v1/"]} router.add_spec(spec, transforms=transforms) app = router.as_wsgi @@ -83,8 +85,8 @@ def setup_app(config=None): app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='auth') + app = ResponseInstrumentationMiddleware(app, router, service_name="auth") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='auth') + app = RequestInstrumentationMiddleware(app, router, service_name="auth") return app diff --git a/st2auth/st2auth/backends/__init__.py b/st2auth/st2auth/backends/__init__.py index 64d3275af5..a626f0d082 100644 --- a/st2auth/st2auth/backends/__init__.py +++ b/st2auth/st2auth/backends/__init__.py @@ -22,14 +22,11 @@ from st2common import log as logging from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance' -] +__all__ = ["get_available_backends", "get_backend_instance"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2auth.backends.backend' +BACKENDS_NAMESPACE = "st2auth.backends.backend" def get_available_backends(): @@ -43,8 +40,10 @@ def get_backend_instance(name): try: kwargs = json.loads(backend_kwargs) except ValueError as e: - raise ValueError('Failed to JSON parse backend settings for backend "%s": %s' % - (name, six.text_type(e))) + raise ValueError( + 'Failed to JSON parse backend settings for backend "%s": %s' + % (name, six.text_type(e)) + ) else: kwargs = {} @@ -55,9 +54,11 @@ def get_backend_instance(name): except Exception as e: tb_msg = traceback.format_exc() class_name = cls.__name__ - msg = ('Failed to instantiate auth backend "%s" (class %s) with backend settings ' - '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to instantiate auth backend "%s" (class %s) with backend settings ' + '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) diff --git a/st2auth/st2auth/backends/base.py b/st2auth/st2auth/backends/base.py index 0246729c1a..4d32e51860 100644 --- a/st2auth/st2auth/backends/base.py +++ b/st2auth/st2auth/backends/base.py @@ -19,9 +19,7 @@ from st2auth.backends.constants import AuthBackendCapability -__all__ = [ - 'BaseAuthenticationBackend' -] +__all__ = ["BaseAuthenticationBackend"] @six.add_metaclass(abc.ABCMeta) @@ -31,9 +29,7 @@ class BaseAuthenticationBackend(object): """ # Capabilities offered by the auth backend - CAPABILITIES = ( - AuthBackendCapability.CAN_AUTHENTICATE_USER - ) + CAPABILITIES = AuthBackendCapability.CAN_AUTHENTICATE_USER @abc.abstractmethod def authenticate(self, username, password): @@ -47,7 +43,7 @@ def get_user(self, username): :rtype: ``dict`` """ - raise NotImplementedError('get_user() not implemented for this backend') + raise NotImplementedError("get_user() not implemented for this backend") def get_user_groups(self, username): """ @@ -57,4 +53,4 @@ def get_user_groups(self, username): :rtype: ``list`` of ``str`` """ - raise NotImplementedError('get_groups() not implemented for this backend') + raise NotImplementedError("get_groups() not implemented for this backend") diff --git a/st2auth/st2auth/backends/constants.py b/st2auth/st2auth/backends/constants.py index 6cb990c64d..b50625e745 100644 --- a/st2auth/st2auth/backends/constants.py +++ b/st2auth/st2auth/backends/constants.py @@ -19,17 +19,15 @@ from st2common.util.enum import Enum -__all__ = [ - 'AuthBackendCapability' -] +__all__ = ["AuthBackendCapability"] class AuthBackendCapability(Enum): # This auth backend can authenticate a user. - CAN_AUTHENTICATE_USER = 'can_authenticate_user' + CAN_AUTHENTICATE_USER = "can_authenticate_user" # Auth backend can provide additional information about a particular user. - HAS_USER_INFORMATION = 'has_user_info' + HAS_USER_INFORMATION = "has_user_info" # Auth backend can provide a group membership information for a particular user. - HAS_GROUP_INFORMATION = 'has_groups_info' + HAS_GROUP_INFORMATION = "has_groups_info" diff --git a/st2auth/st2auth/cmd/api.py b/st2auth/st2auth/cmd/api.py index d1fd7605bd..4c52f2649c 100644 --- a/st2auth/st2auth/cmd/api.py +++ b/st2auth/st2auth/cmd/api.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import eventlet @@ -27,15 +28,14 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown from st2auth import config + config.register_opts() from st2auth import app from st2auth.validation import validate_auth_backend_is_correctly_configured -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -43,15 +43,23 @@ def _setup(): capabilities = { - 'name': 'auth', - 'listen_host': cfg.CONF.auth.host, - 'listen_port': cfg.CONF.auth.port, - 'listen_ssl': cfg.CONF.auth.use_ssl, - 'type': 'active' + "name": "auth", + "listen_host": cfg.CONF.auth.host, + "listen_port": cfg.CONF.auth.port, + "listen_ssl": cfg.CONF.auth.use_ssl, + "type": "active", } - common_setup(service='auth', config=config, setup_db=True, register_mq_exchanges=False, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=False, service_registry=True, capabilities=capabilities) + common_setup( + service="auth", + config=config, + setup_db=True, + register_mq_exchanges=False, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + ) # Additional pre-run time checks validate_auth_backend_is_correctly_configured() @@ -74,14 +82,18 @@ def _run_server(): socket = eventlet.listen((host, port)) if use_ssl: - socket = eventlet.wrap_ssl(socket, - certfile=cert_file_path, - keyfile=key_file_path, - server_side=True) + socket = eventlet.wrap_ssl( + socket, certfile=cert_file_path, keyfile=key_file_path, server_side=True + ) LOG.info('ST2 Auth API running in "%s" auth mode', cfg.CONF.auth.mode) - LOG.info('(PID=%s) ST2 Auth API is serving on %s://%s:%s.', os.getpid(), - 'https' if use_ssl else 'http', host, port) + LOG.info( + "(PID=%s) ST2 Auth API is serving on %s://%s:%s.", + os.getpid(), + "https" if use_ssl else "http", + host, + port, + ) wsgi.server(socket, app.setup_app(), log=LOG, log_output=False) return 0 @@ -98,7 +110,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) ST2 Auth API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 Auth API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2auth/st2auth/config.py b/st2auth/st2auth/config.py index 00cfa2aca7..dee0d2d064 100644 --- a/st2auth/st2auth/config.py +++ b/st2auth/st2auth/config.py @@ -28,8 +28,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -50,47 +53,61 @@ def _register_app_opts(): auth_opts = [ cfg.StrOpt( - 'host', default='127.0.0.1', - help='Host on which the service should listen on.'), + "host", + default="127.0.0.1", + help="Host on which the service should listen on.", + ), cfg.IntOpt( - 'port', default=9100, - help='Port on which the service should listen on.'), - cfg.BoolOpt( - 'use_ssl', default=False, - help='Specify to enable SSL / TLS mode'), + "port", default=9100, help="Port on which the service should listen on." + ), + cfg.BoolOpt("use_ssl", default=False, help="Specify to enable SSL / TLS mode"), cfg.StrOpt( - 'cert', default='/etc/apache2/ssl/mycert.crt', - help='Path to the SSL certificate file. Only used when "use_ssl" is specified.'), + "cert", + default="/etc/apache2/ssl/mycert.crt", + help='Path to the SSL certificate file. Only used when "use_ssl" is specified.', + ), cfg.StrOpt( - 'key', default='/etc/apache2/ssl/mycert.key', - help='Path to the SSL private key file. Only used when "use_ssl" is specified.'), + "key", + default="/etc/apache2/ssl/mycert.key", + help='Path to the SSL private key file. Only used when "use_ssl" is specified.', + ), cfg.StrOpt( - 'logging', default='/etc/st2/logging.auth.conf', - help='Path to the logging config.'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "logging", + default="/etc/st2/logging.auth.conf", + help="Path to the logging config.", + ), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), cfg.StrOpt( - 'mode', default=DEFAULT_MODE, - help='Authentication mode (%s)' % (','.join(VALID_MODES))), + "mode", + default=DEFAULT_MODE, + help="Authentication mode (%s)" % (",".join(VALID_MODES)), + ), cfg.StrOpt( - 'backend', default=DEFAULT_BACKEND, - help='Authentication backend to use in a standalone mode. Available ' - 'backends: %s.' % (', '.join(available_backends))), + "backend", + default=DEFAULT_BACKEND, + help="Authentication backend to use in a standalone mode. Available " + "backends: %s." % (", ".join(available_backends)), + ), cfg.StrOpt( - 'backend_kwargs', default=None, - help='JSON serialized arguments which are passed to the authentication ' - 'backend in a standalone mode.'), + "backend_kwargs", + default=None, + help="JSON serialized arguments which are passed to the authentication " + "backend in a standalone mode.", + ), cfg.BoolOpt( - 'sso', default=False, - help='Enable Single Sign On for GUI if true.'), + "sso", default=False, help="Enable Single Sign On for GUI if true." + ), cfg.StrOpt( - 'sso_backend', default=DEFAULT_SSO_BACKEND, - help='Single Sign On backend to use when SSO is enabled. Available ' - 'backends: noop, saml2.'), + "sso_backend", + default=DEFAULT_SSO_BACKEND, + help="Single Sign On backend to use when SSO is enabled. Available " + "backends: noop, saml2.", + ), cfg.StrOpt( - 'sso_backend_kwargs', default=None, - help='JSON serialized arguments which are passed to the SSO backend.') + "sso_backend_kwargs", + default=None, + help="JSON serialized arguments which are passed to the SSO backend.", + ), ] - cfg.CONF.register_cli_opts(auth_opts, group='auth') + cfg.CONF.register_cli_opts(auth_opts, group="auth") diff --git a/st2auth/st2auth/controllers/v1/auth.py b/st2auth/st2auth/controllers/v1/auth.py index f0042632e9..c77546141f 100644 --- a/st2auth/st2auth/controllers/v1/auth.py +++ b/st2auth/st2auth/controllers/v1/auth.py @@ -29,8 +29,8 @@ HANDLER_MAPPINGS = { - 'proxy': handlers.ProxyAuthHandler, - 'standalone': handlers.StandaloneAuthHandler + "proxy": handlers.ProxyAuthHandler, + "standalone": handlers.StandaloneAuthHandler, } LOG = logging.getLogger(__name__) @@ -38,17 +38,17 @@ class TokenValidationController(object): def post(self, request): - token = getattr(request, 'token', None) + token = getattr(request, "token", None) if not token: - raise exc.HTTPBadRequest('Token is not provided.') + raise exc.HTTPBadRequest("Token is not provided.") try: - return {'valid': auth_utils.validate_token(token) is not None} + return {"valid": auth_utils.validate_token(token) is not None} except (TokenNotFoundError, TokenExpiredError): - return {'valid': False} + return {"valid": False} except Exception: - msg = 'Unexpected error occurred while verifying token.' + msg = "Unexpected error occurred while verifying token." LOG.exception(msg) raise exc.HTTPInternalServerError(msg) @@ -60,30 +60,32 @@ def __init__(self): try: self.handler = HANDLER_MAPPINGS[cfg.CONF.auth.mode]() except KeyError: - raise ParamException("%s is not a valid auth mode" % - cfg.CONF.auth.mode) + raise ParamException("%s is not a valid auth mode" % cfg.CONF.auth.mode) def post(self, request, **kwargs): headers = {} - if 'x-forwarded-for' in kwargs: - headers['x-forwarded-for'] = kwargs.pop('x-forwarded-for') + if "x-forwarded-for" in kwargs: + headers["x-forwarded-for"] = kwargs.pop("x-forwarded-for") - authorization = kwargs.pop('authorization', None) + authorization = kwargs.pop("authorization", None) if authorization: - authorization = tuple(authorization.split(' ')) - - token = self.handler.handle_auth(request=request, headers=headers, - remote_addr=kwargs.pop('remote_addr', None), - remote_user=kwargs.pop('remote_user', None), - authorization=authorization, - **kwargs) + authorization = tuple(authorization.split(" ")) + + token = self.handler.handle_auth( + request=request, + headers=headers, + remote_addr=kwargs.pop("remote_addr", None), + remote_user=kwargs.pop("remote_user", None), + authorization=authorization, + **kwargs, + ) return process_successful_response(token=token) def process_successful_response(token): resp = Response(json=token, status=http_client.CREATED) # NOTE: gunicon fails and throws an error if header value is not a string (e.g. if it's None) - resp.headers['X-API-URL'] = api_utils.get_base_public_api_url() + resp.headers["X-API-URL"] = api_utils.get_base_public_api_url() return resp diff --git a/st2auth/st2auth/controllers/v1/sso.py b/st2auth/st2auth/controllers/v1/sso.py index f25effe681..ef1096462c 100644 --- a/st2auth/st2auth/controllers/v1/sso.py +++ b/st2auth/st2auth/controllers/v1/sso.py @@ -32,7 +32,6 @@ class IdentityProviderCallbackController(object): - def __init__(self): self.st2_auth_handler = handlers.ProxyAuthHandler() @@ -40,16 +39,21 @@ def post(self, response, **kwargs): try: verified_user = SSO_BACKEND.verify_response(response) - st2_auth_token_create_request = {'user': verified_user['username'], 'ttl': None} + st2_auth_token_create_request = { + "user": verified_user["username"], + "ttl": None, + } st2_auth_token = self.st2_auth_handler.handle_auth( request=st2_auth_token_create_request, - remote_addr=verified_user['referer'], - remote_user=verified_user['username'], - headers={} + remote_addr=verified_user["referer"], + remote_user=verified_user["username"], + headers={}, ) - return process_successful_authn_response(verified_user['referer'], st2_auth_token) + return process_successful_authn_response( + verified_user["referer"], st2_auth_token + ) except NotImplementedError as e: return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e) except auth_exc.SSOVerificationError as e: @@ -59,7 +63,6 @@ def post(self, response, **kwargs): class SingleSignOnRequestController(object): - def get(self, referer): try: response = router.Response(status=http_client.TEMPORARY_REDIRECT) @@ -76,15 +79,15 @@ class SingleSignOnController(object): callback = IdentityProviderCallbackController() def _get_sso_enabled_config(self): - return {'enabled': cfg.CONF.auth.sso} + return {"enabled": cfg.CONF.auth.sso} def get(self): try: result = self._get_sso_enabled_config() return process_successful_response(http_client.OK, result) except Exception: - LOG.exception('Error encountered while getting SSO configuration.') - result = {'enabled': False} + LOG.exception("Error encountered while getting SSO configuration.") + result = {"enabled": False} return process_successful_response(http_client.OK, result) @@ -107,23 +110,23 @@ def get(self): def process_successful_authn_response(referer, token): token_json = { - 'id': str(token.id), - 'user': token.user, - 'token': token.token, - 'expiry': str(token.expiry), - 'service': False, - 'metadata': {} + "id": str(token.id), + "user": token.user, + "token": token.token, + "expiry": str(token.expiry), + "service": False, + "metadata": {}, } body = CALLBACK_SUCCESS_RESPONSE_BODY % referer resp = router.Response(body=body) - resp.headers['Content-Type'] = 'text/html' + resp.headers["Content-Type"] = "text/html" resp.set_cookie( - 'st2-auth-token', + "st2-auth-token", value=urllib.parse.quote(json.dumps(token_json)), expires=datetime.timedelta(seconds=60), - overwrite=True + overwrite=True, ) return resp @@ -135,7 +138,7 @@ def process_successful_response(status_code, json_body): def process_failure_response(status_code, exception): LOG.error(str(exception)) - json_body = {'faultstring': str(exception)} + json_body = {"faultstring": str(exception)} return router.Response(status_code=status_code, json_body=json_body) diff --git a/st2auth/st2auth/handlers.py b/st2auth/st2auth/handlers.py index 59d74085cf..f6540bcda7 100644 --- a/st2auth/st2auth/handlers.py +++ b/st2auth/st2auth/handlers.py @@ -35,13 +35,22 @@ LOG = logging.getLogger(__name__) -def abort_request(status_code=http_client.UNAUTHORIZED, message='Invalid or missing credentials'): +def abort_request( + status_code=http_client.UNAUTHORIZED, message="Invalid or missing credentials" +): return abort(status_code, message) class AuthHandlerBase(object): - def handle_auth(self, request, headers=None, remote_addr=None, - remote_user=None, authorization=None, **kwargs): + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): raise NotImplementedError() def _create_token_for_user(self, username, ttl=None): @@ -49,80 +58,90 @@ def _create_token_for_user(self, username, ttl=None): return TokenAPI.from_model(tokendb) def _get_username_for_request(self, username, request): - impersonate_user = getattr(request, 'user', None) + impersonate_user = getattr(request, "user", None) if impersonate_user is not None: # check this is a service account try: if not User.get_by_name(username).is_service: - message = "Current user is not a service and cannot " \ - "request impersonated tokens" - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Current user is not a service and cannot " + "request impersonated tokens" + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return username = impersonate_user except (UserNotFoundError, StackStormDBObjectNotFoundError): - message = "Could not locate user %s" % \ - (impersonate_user) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "Could not locate user %s" % (impersonate_user) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return else: - impersonate_user = getattr(request, 'impersonate_user', None) - nickname_origin = getattr(request, 'nickname_origin', None) + impersonate_user = getattr(request, "impersonate_user", None) + nickname_origin = getattr(request, "nickname_origin", None) if impersonate_user is not None: try: # check this is a service account if not User.get_by_name(username).is_service: raise NotServiceUserError() - username = User.get_by_nickname(impersonate_user, - nickname_origin).name + username = User.get_by_nickname( + impersonate_user, nickname_origin + ).name except NotServiceUserError: - message = "Current user is not a service and cannot " \ - "request impersonated tokens" - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Current user is not a service and cannot " + "request impersonated tokens" + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except (UserNotFoundError, StackStormDBObjectNotFoundError): - message = "Could not locate user %s@%s" % \ - (impersonate_user, nickname_origin) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "Could not locate user %s@%s" % ( + impersonate_user, + nickname_origin, + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except NoNicknameOriginProvidedError: - message = "Nickname origin is not provided for nickname '%s'" % \ - impersonate_user - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Nickname origin is not provided for nickname '%s'" + % impersonate_user + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except AmbiguousUserError: - message = "%s@%s matched more than one username" % \ - (impersonate_user, nickname_origin) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "%s@%s matched more than one username" % ( + impersonate_user, + nickname_origin, + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return return username class ProxyAuthHandler(AuthHandlerBase): - def handle_auth(self, request, headers=None, remote_addr=None, - remote_user=None, authorization=None, **kwargs): - remote_addr = headers.get('x-forwarded-for', - remote_addr) - extra = {'remote_addr': remote_addr} + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): + remote_addr = headers.get("x-forwarded-for", remote_addr) + extra = {"remote_addr": remote_addr} if remote_user: - ttl = getattr(request, 'ttl', None) + ttl = getattr(request, "ttl", None) username = self._get_username_for_request(remote_user, request) try: - token = self._create_token_for_user(username=username, - ttl=ttl) + token = self._create_token_for_user(username=username, ttl=ttl) except TTLTooLargeException as e: - abort_request(status_code=http_client.BAD_REQUEST, - message=six.text_type(e)) + abort_request( + status_code=http_client.BAD_REQUEST, message=six.text_type(e) + ) return token - LOG.audit('Access denied to anonymous user.', extra=extra) + LOG.audit("Access denied to anonymous user.", extra=extra) abort_request() @@ -131,77 +150,91 @@ def __init__(self, *args, **kwargs): self._auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend) super(StandaloneAuthHandler, self).__init__(*args, **kwargs) - def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None, - authorization=None, **kwargs): + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): auth_backend = self._auth_backend.__class__.__name__ - extra = {'auth_backend': auth_backend, 'remote_addr': remote_addr} + extra = {"auth_backend": auth_backend, "remote_addr": remote_addr} if not authorization: - LOG.audit('Authorization header not provided', extra=extra) + LOG.audit("Authorization header not provided", extra=extra) abort_request() return auth_type, auth_value = authorization - if auth_type.lower() not in ['basic']: - extra['auth_type'] = auth_type - LOG.audit('Unsupported authorization type: %s' % (auth_type), extra=extra) + if auth_type.lower() not in ["basic"]: + extra["auth_type"] = auth_type + LOG.audit("Unsupported authorization type: %s" % (auth_type), extra=extra) abort_request() return try: auth_value = base64.b64decode(auth_value) except Exception: - LOG.audit('Invalid authorization header', extra=extra) + LOG.audit("Invalid authorization header", extra=extra) abort_request() return - split = auth_value.split(b':', 1) + split = auth_value.split(b":", 1) if len(split) != 2: - LOG.audit('Invalid authorization header', extra=extra) + LOG.audit("Invalid authorization header", extra=extra) abort_request() return username, password = split if six.PY3 and isinstance(username, six.binary_type): - username = username.decode('utf-8') + username = username.decode("utf-8") if six.PY3 and isinstance(password, six.binary_type): - password = password.decode('utf-8') + password = password.decode("utf-8") result = self._auth_backend.authenticate(username=username, password=password) if result is True: - ttl = getattr(request, 'ttl', None) + ttl = getattr(request, "ttl", None) username = self._get_username_for_request(username, request) try: token = self._create_token_for_user(username=username, ttl=ttl) except TTLTooLargeException as e: - abort_request(status_code=http_client.BAD_REQUEST, - message=six.text_type(e)) + abort_request( + status_code=http_client.BAD_REQUEST, message=six.text_type(e) + ) return # If remote group sync is enabled, sync the remote groups with local StackStorm roles - if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != 'noop': - LOG.debug('Retrieving auth backend groups for user "%s"' % (username), - extra=extra) + if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != "noop": + LOG.debug( + 'Retrieving auth backend groups for user "%s"' % (username), + extra=extra, + ) try: user_groups = self._auth_backend.get_user_groups(username=username) except (NotImplementedError, AttributeError): - LOG.debug('Configured auth backend doesn\'t expose user group membership ' - 'information, skipping sync...') + LOG.debug( + "Configured auth backend doesn't expose user group membership " + "information, skipping sync..." + ) return token if not user_groups: # No groups, return early return token - extra['username'] = username - extra['user_groups'] = user_groups + extra["username"] = username + extra["user_groups"] = user_groups - LOG.debug('Found "%s" groups for user "%s"' % (len(user_groups), username), - extra=extra) + LOG.debug( + 'Found "%s" groups for user "%s"' % (len(user_groups), username), + extra=extra, + ) user_db = UserDB(name=username) @@ -212,14 +245,19 @@ def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None, syncer.sync(user_db=user_db, groups=user_groups) except Exception: # Note: Failed sync is not fatal - LOG.exception('Failed to synchronize remote groups for user "%s"' % (username), - extra=extra) + LOG.exception( + 'Failed to synchronize remote groups for user "%s"' + % (username), + extra=extra, + ) else: - LOG.debug('Successfully synchronized groups for user "%s"' % (username), - extra=extra) + LOG.debug( + 'Successfully synchronized groups for user "%s"' % (username), + extra=extra, + ) return token return token - LOG.audit('Invalid credentials provided', extra=extra) + LOG.audit("Invalid credentials provided", extra=extra) abort_request() diff --git a/st2auth/st2auth/sso/__init__.py b/st2auth/st2auth/sso/__init__.py index 5839059ed9..b6d0df930a 100644 --- a/st2auth/st2auth/sso/__init__.py +++ b/st2auth/st2auth/sso/__init__.py @@ -25,15 +25,11 @@ from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance', - 'get_sso_backend' -] +__all__ = ["get_available_backends", "get_backend_instance", "get_sso_backend"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2auth.sso.backends' +BACKENDS_NAMESPACE = "st2auth.sso.backends" def get_available_backends(): @@ -41,7 +37,9 @@ def get_available_backends(): def get_backend_instance(name): - sso_backend_cls = driver_loader.get_backend_driver(namespace=BACKENDS_NAMESPACE, name=name) + sso_backend_cls = driver_loader.get_backend_driver( + namespace=BACKENDS_NAMESPACE, name=name + ) kwargs = {} sso_backend_kwargs = cfg.CONF.auth.sso_backend_kwargs @@ -51,8 +49,8 @@ def get_backend_instance(name): kwargs = json.loads(sso_backend_kwargs) except ValueError as e: raise ValueError( - 'Failed to JSON parse backend settings for backend "%s": %s' % - (name, six.text_type(e)) + 'Failed to JSON parse backend settings for backend "%s": %s' + % (name, six.text_type(e)) ) try: @@ -60,9 +58,11 @@ def get_backend_instance(name): except Exception as e: tb_msg = traceback.format_exc() class_name = sso_backend_cls.__name__ - msg = ('Failed to instantiate SSO backend "%s" (class %s) with backend settings ' - '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to instantiate SSO backend "%s" (class %s) with backend settings ' + '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) diff --git a/st2auth/st2auth/sso/base.py b/st2auth/st2auth/sso/base.py index c96782aba2..5e11199818 100644 --- a/st2auth/st2auth/sso/base.py +++ b/st2auth/st2auth/sso/base.py @@ -16,9 +16,7 @@ import six -__all__ = [ - 'BaseSingleSignOnBackend' -] +__all__ = ["BaseSingleSignOnBackend"] @six.add_metaclass(abc.ABCMeta) @@ -32,5 +30,7 @@ def get_request_redirect_url(self, referer): raise NotImplementedError(msg) def verify_response(self, response): - msg = 'The function "verify_response" is not implemented in the base SSO backend.' + msg = ( + 'The function "verify_response" is not implemented in the base SSO backend.' + ) raise NotImplementedError(msg) diff --git a/st2auth/st2auth/sso/noop.py b/st2auth/st2auth/sso/noop.py index 6cacb5e7e9..6699e084f3 100644 --- a/st2auth/st2auth/sso/noop.py +++ b/st2auth/st2auth/sso/noop.py @@ -17,13 +17,11 @@ from st2auth.sso.base import BaseSingleSignOnBackend -__all__ = [ - 'NoOpSingleSignOnBackend' -] +__all__ = ["NoOpSingleSignOnBackend"] NOT_IMPLEMENTED_MESSAGE = ( 'The default "noop" SSO backend is not a proper implementation. ' - 'Please refer to the enterprise version for configuring SSO.' + "Please refer to the enterprise version for configuring SSO." ) diff --git a/st2auth/st2auth/validation.py b/st2auth/st2auth/validation.py index 924ad390f2..ccea906062 100644 --- a/st2auth/st2auth/validation.py +++ b/st2auth/st2auth/validation.py @@ -19,26 +19,28 @@ from st2auth.backends import get_backend_instance as get_auth_backend_instance from st2auth.backends.constants import AuthBackendCapability -__all__ = [ - 'validate_auth_backend_is_correctly_configured' -] +__all__ = ["validate_auth_backend_is_correctly_configured"] def validate_auth_backend_is_correctly_configured(): # 1. Verify correct mode is specified if cfg.CONF.auth.mode not in VALID_MODES: - msg = ('Invalid auth mode "%s" specified in the config. Valid modes are: %s' % - (cfg.CONF.auth.mode, ', '.join(VALID_MODES))) + msg = 'Invalid auth mode "%s" specified in the config. Valid modes are: %s' % ( + cfg.CONF.auth.mode, + ", ".join(VALID_MODES), + ) raise ValueError(msg) # 2. Verify that auth backend used by the user exposes group information if cfg.CONF.rbac.enable and cfg.CONF.rbac.sync_remote_groups: auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend) - capabilies = getattr(auth_backend, 'CAPABILITIES', ()) + capabilies = getattr(auth_backend, "CAPABILITIES", ()) if AuthBackendCapability.HAS_GROUP_INFORMATION not in capabilies: - msg = ('Configured auth backend doesn\'t expose user group information. Disable ' - 'remote group synchronization or use a different backend which exposes ' - 'user group membership information.') + msg = ( + "Configured auth backend doesn't expose user group information. Disable " + "remote group synchronization or use a different backend which exposes " + "user group membership information." + ) raise ValueError(msg) return True diff --git a/st2auth/st2auth/wsgi.py b/st2auth/st2auth/wsgi.py index 2fb9bee07a..16a44e64f3 100644 --- a/st2auth/st2auth/wsgi.py +++ b/st2auth/st2auth/wsgi.py @@ -16,6 +16,7 @@ import os from st2common.util.monkey_patch import monkey_patch + # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things # including shutdown @@ -28,8 +29,11 @@ from st2auth import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2auth/tests/base.py b/st2auth/tests/base.py index e3bc2e1a05..dc63c1094e 100644 --- a/st2auth/tests/base.py +++ b/st2auth/tests/base.py @@ -20,7 +20,6 @@ class FunctionalTest(DbTestCase): - @classmethod def setUpClass(cls, **kwargs): super(FunctionalTest, cls).setUpClass() diff --git a/st2auth/tests/unit/controllers/v1/test_sso.py b/st2auth/tests/unit/controllers/v1/test_sso.py index 81d9dcea1d..2b6edb1f83 100644 --- a/st2auth/tests/unit/controllers/v1/test_sso.py +++ b/st2auth/tests/unit/controllers/v1/test_sso.py @@ -13,6 +13,7 @@ # limitations under the License. import st2tests.config as tests_config + tests_config.parse_args() import json @@ -28,110 +29,125 @@ from tests.base import FunctionalTest -SSO_V1_PATH = '/v1/sso' -SSO_REQUEST_V1_PATH = SSO_V1_PATH + '/request' -SSO_CALLBACK_V1_PATH = SSO_V1_PATH + '/callback' -MOCK_REFERER = 'https://127.0.0.1' -MOCK_USER = 'stanley' +SSO_V1_PATH = "/v1/sso" +SSO_REQUEST_V1_PATH = SSO_V1_PATH + "/request" +SSO_CALLBACK_V1_PATH = SSO_V1_PATH + "/callback" +MOCK_REFERER = "https://127.0.0.1" +MOCK_USER = "stanley" class TestSingleSignOnController(FunctionalTest): - def test_sso_enabled(self): - cfg.CONF.set_override(group='auth', name='sso', override=True) + cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': True}) + self.assertDictEqual(response.json, {"enabled": True}) def test_sso_disabled(self): - cfg.CONF.set_override(group='auth', name='sso', override=False) + cfg.CONF.set_override(group="auth", name="sso", override=False) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': False}) + self.assertDictEqual(response.json, {"enabled": False}) @mock.patch.object( sso_api_controller.SingleSignOnController, - '_get_sso_enabled_config', - mock.MagicMock(side_effect=KeyError('foobar'))) + "_get_sso_enabled_config", + mock.MagicMock(side_effect=KeyError("foobar")), + ) def test_unknown_exception(self): - cfg.CONF.set_override(group='auth', name='sso', override=True) + cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': False}) - self.assertTrue(sso_api_controller.SingleSignOnController._get_sso_enabled_config.called) + self.assertDictEqual(response.json, {"enabled": False}) + self.assertTrue( + sso_api_controller.SingleSignOnController._get_sso_enabled_config.called + ) class TestSingleSignOnRequestController(FunctionalTest): - @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'get_request_redirect_url', - mock.MagicMock(side_effect=Exception('fooobar'))) + "get_request_redirect_url", + mock.MagicMock(side_effect=Exception("fooobar")), + ) def test_default_backend_unknown_exception(self): - expected_error = {'faultstring': 'Internal Server Error'} + expected_error = {"faultstring": "Internal Server Error"} response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) def test_default_backend_not_implemented(self): - expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE} + expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'get_request_redirect_url', - mock.MagicMock(return_value='https://127.0.0.1')) + "get_request_redirect_url", + mock.MagicMock(return_value="https://127.0.0.1"), + ) def test_idp_redirect(self): response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.TEMPORARY_REDIRECT) - self.assertEqual(response.location, 'https://127.0.0.1') + self.assertEqual(response.location, "https://127.0.0.1") class TestIdentityProviderCallbackController(FunctionalTest): - @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(side_effect=Exception('fooobar'))) + "verify_response", + mock.MagicMock(side_effect=Exception("fooobar")), + ) def test_default_backend_unknown_exception(self): - expected_error = {'faultstring': 'Internal Server Error'} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": "Internal Server Error"} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) def test_default_backend_not_implemented(self): - expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(return_value={'referer': MOCK_REFERER, 'username': MOCK_USER})) + "verify_response", + mock.MagicMock(return_value={"referer": MOCK_REFERER, "username": MOCK_USER}), + ) def test_idp_callback(self): expected_body = sso_api_controller.CALLBACK_SUCCESS_RESPONSE_BODY % MOCK_REFERER - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=False) + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=False + ) self.assertTrue(response.status_code, http_client.OK) - self.assertEqual(expected_body, response.body.decode('utf-8')) + self.assertEqual(expected_body, response.body.decode("utf-8")) - set_cookies_list = [h for h in response.headerlist if h[0] == 'Set-Cookie'] + set_cookies_list = [h for h in response.headerlist if h[0] == "Set-Cookie"] self.assertEqual(len(set_cookies_list), 1) - self.assertIn('st2-auth-token', set_cookies_list[0][1]) + self.assertIn("st2-auth-token", set_cookies_list[0][1]) - cookie = urllib.parse.unquote(set_cookies_list[0][1]).split('=') - st2_auth_token = json.loads(cookie[1].split(';')[0]) - self.assertIn('token', st2_auth_token) - self.assertEqual(st2_auth_token['user'], MOCK_USER) + cookie = urllib.parse.unquote(set_cookies_list[0][1]).split("=") + st2_auth_token = json.loads(cookie[1].split(";")[0]) + self.assertIn("token", st2_auth_token) + self.assertEqual(st2_auth_token["user"], MOCK_USER) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(side_effect=auth_exc.SSOVerificationError('Verification Failed'))) + "verify_response", + mock.MagicMock( + side_effect=auth_exc.SSOVerificationError("Verification Failed") + ), + ) def test_idp_callback_verification_failed(self): - expected_error = {'faultstring': 'Verification Failed'} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": "Verification Failed"} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.UNAUTHORIZED) self.assertDictEqual(response.json, expected_error) diff --git a/st2auth/tests/unit/controllers/v1/test_token.py b/st2auth/tests/unit/controllers/v1/test_token.py index ab5f12342b..cd90a6cef1 100644 --- a/st2auth/tests/unit/controllers/v1/test_token.py +++ b/st2auth/tests/unit/controllers/v1/test_token.py @@ -29,25 +29,25 @@ from st2common.persistence.auth import User, Token, ApiKey -USERNAME = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) -TOKEN_DEFAULT_PATH = '/tokens' -TOKEN_V1_PATH = '/v1/tokens' -TOKEN_VERIFY_PATH = '/v1/tokens/validate' +USERNAME = "".join(random.choice(string.ascii_lowercase) for i in range(10)) +TOKEN_DEFAULT_PATH = "/tokens" +TOKEN_V1_PATH = "/v1/tokens" +TOKEN_VERIFY_PATH = "/v1/tokens/validate" class TestTokenController(FunctionalTest): - @classmethod def setUpClass(cls, **kwargs): - kwargs['extra_environ'] = { - 'REMOTE_USER': USERNAME - } + kwargs["extra_environ"] = {"REMOTE_USER": USERNAME} super(TestTokenController, cls).setUpClass(**kwargs) def test_token_model(self): dt = date_utils.get_datetime_utc_now() - tk1 = TokenAPI(user='stanley', token=uuid.uuid4().hex, - expiry=isotime.format(dt, offset=False)) + tk1 = TokenAPI( + user="stanley", + token=uuid.uuid4().hex, + expiry=isotime.format(dt, offset=False), + ) tkdb1 = TokenAPI.to_model(tk1) self.assertIsNotNone(tkdb1) self.assertIsInstance(tkdb1, TokenDB) @@ -64,7 +64,7 @@ def test_token_model(self): def test_token_model_null_token(self): dt = date_utils.get_datetime_utc_now() - tk = TokenAPI(user='stanley', token=None, expiry=isotime.format(dt)) + tk = TokenAPI(user="stanley", token=None, expiry=isotime.format(dt)) self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def test_token_model_null_user(self): @@ -73,191 +73,215 @@ def test_token_model_null_user(self): self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def test_token_model_null_expiry(self): - tk = TokenAPI(user='stanley', token=uuid.uuid4().hex, expiry=None) + tk = TokenAPI(user="stanley", token=uuid.uuid4().hex, expiry=None) self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def _test_token_post(self, path=TOKEN_V1_PATH): ttl = cfg.CONF.auth.token_ttl timestamp = date_utils.get_datetime_utc_now() response = self.app.post_json(path, {}, expect_errors=False) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertEqual(response.status_int, 201) - self.assertIsNotNone(response.json['token']) - self.assertEqual(response.json['user'], USERNAME) - actual_expiry = isotime.parse(response.json['expiry']) + self.assertIsNotNone(response.json["token"]) + self.assertEqual(response.json["user"], USERNAME) + actual_expiry = isotime.parse(response.json["expiry"]) self.assertLess(timestamp, actual_expiry) self.assertLess(actual_expiry, expected_expiry) return response def test_token_post_unauthorized(self): - response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={ - 'REMOTE_USER': '' - }) + response = self.app.post_json( + TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={"REMOTE_USER": ""} + ) self.assertEqual(response.status_int, 401) + @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=Exception())) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(side_effect=Exception())) - @mock.patch.object( - User, 'add_or_update', - mock.Mock(return_value=UserDB(name=USERNAME))) + User, "add_or_update", mock.Mock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_new_user(self): self._test_token_post() @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_existing_user(self): self._test_token_post() @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_success_x_api_url_header_value(self): # auth.api_url option is explicitly set - cfg.CONF.set_override('api_url', override='https://example.com', group='auth') + cfg.CONF.set_override("api_url", override="https://example.com", group="auth") resp = self._test_token_post() - self.assertEqual(resp.headers['X-API-URL'], 'https://example.com') + self.assertEqual(resp.headers["X-API-URL"], "https://example.com") # auth.api_url option is not set, url is inferred from listen host and port - cfg.CONF.set_override('api_url', override=None, group='auth') + cfg.CONF.set_override("api_url", override=None, group="auth") resp = self._test_token_post() - self.assertEqual(resp.headers['X-API-URL'], 'http://127.0.0.1:9101') + self.assertEqual(resp.headers["X-API-URL"], "http://127.0.0.1:9101") @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_default_url_path(self): self._test_token_post(path=TOKEN_DEFAULT_PATH) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_ttl(self): timestamp = date_utils.add_utc_tz(date_utils.get_datetime_utc_now()) - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 60}, expect_errors=False) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=60) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 60}, expect_errors=False) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=60 + ) self.assertEqual(response.status_int, 201) - actual_expiry = isotime.parse(response.json['expiry']) + actual_expiry = isotime.parse(response.json["expiry"]) self.assertLess(timestamp, actual_expiry) self.assertLess(actual_expiry, expected_expiry) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_no_data_in_body_text_plain_context_type_used(self): - response = self.app.post(TOKEN_V1_PATH, expect_errors=False, content_type='text/plain') + response = self.app.post( + TOKEN_V1_PATH, expect_errors=False, content_type="text/plain" + ) self.assertEqual(response.status_int, 201) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_ttl_over_policy(self): ttl = cfg.CONF.auth.token_ttl - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': ttl + 60}, expect_errors=True) - self.assertEqual(response.status_int, 400) - message = 'TTL specified %s is greater than max allowed %s.' % ( - ttl + 60, ttl + response = self.app.post_json( + TOKEN_V1_PATH, {"ttl": ttl + 60}, expect_errors=True ) - self.assertEqual(response.json['faultstring'], message) + self.assertEqual(response.status_int, 400) + message = "TTL specified %s is greater than max allowed %s." % (ttl + 60, ttl) + self.assertEqual(response.json["faultstring"], message) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_bad_ttl(self): - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': -1}, expect_errors=True) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": -1}, expect_errors=True) self.assertEqual(response.status_int, 400) - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 0}, expect_errors=True) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 0}, expect_errors=True) self.assertEqual(response.status_int, 400) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because an API key or token is not provided in header. - data = {'token': str(response.json['token'])} + data = {"token": str(response.json["token"])} response = self.app.post_json(TOKEN_VERIFY_PATH, data, expect_errors=True) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized_bad_api_key(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because the API key is bad. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"St2-Api-Key": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized_bad_token(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because the token is bad. - headers = {'X-Auth-Token': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"X-Auth-Token": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) @mock.patch.object( - ApiKey, 'get', - mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar'))) + ApiKey, + "get", + mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")), + ) def test_token_get_auth_with_api_key(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. Use an API key to authenticate with the st2 auth get token endpoint. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"St2-Api-Key": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 200) - self.assertTrue(response.json['valid']) + self.assertTrue(response.json["valid"]) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_auth_with_token(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=False) # Verify the token. Use a token to authenticate with the st2 auth get token endpoint. - headers = {'X-Auth-Token': str(response.json['token'])} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"X-Auth-Token": str(response.json["token"])} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 200) - self.assertTrue(response.json['valid']) + self.assertTrue(response.json["valid"]) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) @mock.patch.object( - ApiKey, 'get', - mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar'))) + ApiKey, + "get", + mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")), + ) @mock.patch.object( - Token, 'get', + Token, + "get", mock.MagicMock( return_value=TokenDB( - user=USERNAME, token='12345', - expiry=date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=1)))) + user=USERNAME, + token="12345", + expiry=date_utils.get_datetime_utc_now() + - datetime.timedelta(minutes=1), + ) + ), + ) def test_token_get_unauthorized_bad_ttl(self): # Verify the token. 400 is expected because the token has expired. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': '12345'} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False) + headers = {"St2-Api-Key": "foobar"} + data = {"token": "12345"} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False + ) self.assertEqual(response.status_int, 200) - self.assertFalse(response.json['valid']) + self.assertFalse(response.json["valid"]) diff --git a/st2auth/tests/unit/test_auth_backends.py b/st2auth/tests/unit/test_auth_backends.py index 96856e8a3e..b367e328da 100644 --- a/st2auth/tests/unit/test_auth_backends.py +++ b/st2auth/tests/unit/test_auth_backends.py @@ -25,4 +25,4 @@ class AuthenticationBackendsTestCase(unittest2.TestCase): def test_flat_file_backend_is_available_by_default(self): available_backends = get_available_backends() - self.assertIn('flat_file', available_backends) + self.assertIn("flat_file", available_backends) diff --git a/st2auth/tests/unit/test_handlers.py b/st2auth/tests/unit/test_handlers.py index a3627019d8..cf00e642a6 100644 --- a/st2auth/tests/unit/test_handlers.py +++ b/st2auth/tests/unit/test_handlers.py @@ -30,25 +30,23 @@ from st2tests.mocks.auth import MockRequest from st2tests.mocks.auth import get_mock_backend -__all__ = [ - 'AuthHandlerTestCase' -] +__all__ = ["AuthHandlerTestCase"] -@mock.patch('st2auth.handlers.get_auth_backend_instance', get_mock_backend) +@mock.patch("st2auth.handlers.get_auth_backend_instance", get_mock_backend) class AuthHandlerTestCase(CleanDbTestCase): def setUp(self): super(AuthHandlerTestCase, self).setUp() - cfg.CONF.auth.backend = 'mock' + cfg.CONF.auth.backend = "mock" def test_proxy_handler(self): h = handlers.ProxyAuthHandler() request = {} token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user='test_proxy_handler') - self.assertEqual(token.user, 'test_proxy_handler') + request, headers={}, remote_addr=None, remote_user="test_proxy_handler" + ) + self.assertEqual(token.user, "test_proxy_handler") def test_standalone_bad_auth_type(self): h = handlers.StandaloneAuthHandler() @@ -56,8 +54,12 @@ def test_standalone_bad_auth_type(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('complex', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("complex", DUMMY_CREDS), + ) def test_standalone_no_auth(self): h = handlers.StandaloneAuthHandler() @@ -65,8 +67,12 @@ def test_standalone_no_auth(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=None) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=None, + ) def test_standalone_bad_auth_value(self): h = handlers.StandaloneAuthHandler() @@ -74,109 +80,159 @@ def test_standalone_bad_auth_value(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', 'gobblegobble')) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", "gobblegobble"), + ) def test_standalone_handler(self): h = handlers.StandaloneAuthHandler() request = {} token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'auser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "auser") def test_standalone_handler_ttl(self): h = handlers.StandaloneAuthHandler() token1 = h.handle_auth( - MockRequest(23), headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + MockRequest(23), + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) token2 = h.handle_auth( - MockRequest(2300), headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token1.user, 'auser') + MockRequest(2300), + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token1.user, "auser") self.assertNotEqual(token1.expiry, token2.expiry) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser'))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name="auser")) + ) def test_standalone_for_user_not_service(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser', is_service=True))) + User, + "get_by_name", + mock.MagicMock(return_value=UserDB(name="auser", is_service=True)), + ) def test_standalone_for_user_service(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'anotheruser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "anotheruser") def test_standalone_for_user_not_found(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) def test_standalone_impersonate_user_not_found(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = 'anotheruser' + request.impersonate_user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser', is_service=True))) + User, + "get_by_name", + mock.MagicMock(return_value=UserDB(name="auser", is_service=True)), + ) @mock.patch.object( - User, 'get_by_nickname', - mock.MagicMock(return_value=UserDB(name='anotheruser', is_service=True))) + User, + "get_by_nickname", + mock.MagicMock(return_value=UserDB(name="anotheruser", is_service=True)), + ) def test_standalone_impersonate_user_with_nick_origin(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = 'anotheruser' - request.nickname_origin = 'slack' + request.impersonate_user = "anotheruser" + request.nickname_origin = "slack" token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'anotheruser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "anotheruser") def test_standalone_impersonate_user_no_origin(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = '@anotheruser' + request.impersonate_user = "@anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) def test_password_contains_colon(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - authorization = ('Basic', base64.b64encode(b'username:password:password')) + authorization = ("Basic", base64.b64encode(b"username:password:password")) token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=authorization) - self.assertEqual(token.user, 'username') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=authorization, + ) + self.assertEqual(token.user, "username") diff --git a/st2auth/tests/unit/test_validation_utils.py b/st2auth/tests/unit/test_validation_utils.py index 21ab5e26b5..213e106625 100644 --- a/st2auth/tests/unit/test_validation_utils.py +++ b/st2auth/tests/unit/test_validation_utils.py @@ -19,9 +19,7 @@ from st2auth.validation import validate_auth_backend_is_correctly_configured from st2tests import config as tests_config -__all__ = [ - 'ValidationUtilsTestCase' -] +__all__ = ["ValidationUtilsTestCase"] class ValidationUtilsTestCase(unittest2.TestCase): @@ -34,22 +32,31 @@ def test_validate_auth_backend_is_correctly_configured_success(self): self.assertTrue(result) def test_validate_auth_backend_is_correctly_configured_invalid_backend(self): - cfg.CONF.set_override(group='auth', name='mode', override='invalid') - expected_msg = ('Invalid auth mode "invalid" specified in the config. ' - 'Valid modes are: proxy, standalone') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_auth_backend_is_correctly_configured) - - def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups(self): + cfg.CONF.set_override(group="auth", name="mode", override="invalid") + expected_msg = ( + 'Invalid auth mode "invalid" specified in the config. ' + "Valid modes are: proxy, standalone" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_auth_backend_is_correctly_configured + ) + + def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups( + self, + ): # Flat file backend doesn't expose user group membership information aha provide # "has group info" capability - cfg.CONF.set_override(group='auth', name='backend', override='flat_file') - cfg.CONF.set_override(group='auth', name='backend_kwargs', - override='{"file_path": "dummy"}') - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='sync_remote_groups', override=True) - - expected_msg = ('Configured auth backend doesn\'t expose user group information. Disable ' - 'remote group synchronization or') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_auth_backend_is_correctly_configured) + cfg.CONF.set_override(group="auth", name="backend", override="flat_file") + cfg.CONF.set_override( + group="auth", name="backend_kwargs", override='{"file_path": "dummy"}' + ) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="sync_remote_groups", override=True) + + expected_msg = ( + "Configured auth backend doesn't expose user group information. Disable " + "remote group synchronization or" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_auth_backend_is_correctly_configured + ) diff --git a/st2client/dist_utils.py b/st2client/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2client/dist_utils.py +++ b/st2client/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2client/setup.py b/st2client/setup.py index 916b282301..b318aed359 100644 --- a/st2client/setup.py +++ b/st2client/setup.py @@ -26,10 +26,10 @@ check_pip_version() -ST2_COMPONENT = 'st2client' +ST2_COMPONENT = "st2client" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -README_FILE = os.path.join(BASE_DIR, 'README.rst') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +README_FILE = os.path.join(BASE_DIR, "README.rst") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() @@ -40,43 +40,41 @@ setup( name=ST2_COMPONENT, version=__version__, - description=('Python client library and CLI for the StackStorm (st2) event-driven ' - 'automation platform.'), + description=( + "Python client library and CLI for the StackStorm (st2) event-driven " + "automation platform." + ), long_description=readme, - author='StackStorm', - author_email='info@stackstorm.com', - url='https://stackstorm.com/', + author="StackStorm", + author_email="info@stackstorm.com", + url="https://stackstorm.com/", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Information Technology', - 'Intended Audience :: Developers', - 'Intended Audience :: System Administrators', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7' + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Information Technology", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", ], install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - entry_points={ - 'console_scripts': [ - 'st2 = st2client.shell:main' - ] - }, + packages=find_packages(exclude=["setuptools", "tests"]), + entry_points={"console_scripts": ["st2 = st2client.shell:main"]}, project_urls={ - 'Pack Exchange': 'https://exchange.stackstorm.org', - 'Repository': 'https://github.com/StackStorm/st2', - 'Documentation': 'https://docs.stackstorm.com', - 'Community': 'https://stackstorm.com/community-signup', - 'Questions': 'https://forum.stackstorm.com/', - 'Donate': 'https://funding.communitybridge.org/projects/stackstorm', - 'News/Blog': 'https://stackstorm.com/blog', - 'Security': 'https://docs.stackstorm.com/latest/security.html', - 'Bug Reports': 'https://github.com/StackStorm/st2/issues', - } + "Pack Exchange": "https://exchange.stackstorm.org", + "Repository": "https://github.com/StackStorm/st2", + "Documentation": "https://docs.stackstorm.com", + "Community": "https://stackstorm.com/community-signup", + "Questions": "https://forum.stackstorm.com/", + "Donate": "https://funding.communitybridge.org/projects/stackstorm", + "News/Blog": "https://stackstorm.com/blog", + "Security": "https://docs.stackstorm.com/latest/security.html", + "Bug Reports": "https://github.com/StackStorm/st2/issues", + }, ) diff --git a/st2client/st2client/__init__.py b/st2client/st2client/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2client/st2client/__init__.py +++ b/st2client/st2client/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2client/st2client/base.py b/st2client/st2client/base.py index e435540726..54b7a91b14 100644 --- a/st2client/st2client/base.py +++ b/st2client/st2client/base.py @@ -38,9 +38,7 @@ from st2client.utils.date import parse as parse_isotime from st2client.utils.misc import merge_dicts -__all__ = [ - 'BaseCLIApp' -] +__all__ = ["BaseCLIApp"] # Fix for "os.getlogin()) OSError: [Errno 2] No such file or directory" os.getlogin = lambda: pwd.getpwuid(os.getuid())[0] @@ -51,14 +49,14 @@ TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS = 15 CONFIG_OPTION_TO_CLIENT_KWARGS_MAP = { - 'base_url': ['general', 'base_url'], - 'auth_url': ['auth', 'url'], - 'stream_url': ['stream', 'url'], - 'api_url': ['api', 'url'], - 'api_version': ['general', 'api_version'], - 'api_key': ['credentials', 'api_key'], - 'cacert': ['general', 'cacert'], - 'debug': ['cli', 'debug'] + "base_url": ["general", "base_url"], + "auth_url": ["auth", "url"], + "stream_url": ["stream", "url"], + "api_url": ["api", "url"], + "api_version": ["general", "api_version"], + "api_key": ["credentials", "api_key"], + "cacert": ["general", "cacert"], + "debug": ["cli", "debug"], } @@ -74,7 +72,7 @@ class BaseCLIApp(object): SKIP_AUTH_CLASSES = [] def get_client(self, args, debug=False): - ST2_CLI_SKIP_CONFIG = os.environ.get('ST2_CLI_SKIP_CONFIG', 0) + ST2_CLI_SKIP_CONFIG = os.environ.get("ST2_CLI_SKIP_CONFIG", 0) ST2_CLI_SKIP_CONFIG = int(ST2_CLI_SKIP_CONFIG) skip_config = args.skip_config @@ -82,12 +80,19 @@ def get_client(self, args, debug=False): # Note: Options provided as the CLI argument have the highest precedence # Precedence order: cli arguments > environment variables > rc file variables - cli_options = ['base_url', 'auth_url', 'api_url', 'stream_url', 'api_version', 'cacert'] + cli_options = [ + "base_url", + "auth_url", + "api_url", + "stream_url", + "api_version", + "cacert", + ] cli_options = {opt: getattr(args, opt, None) for opt in cli_options} if cli_options.get("cacert", None) is not None: - if cli_options["cacert"].lower() in ['true', '1', 't', 'y', 'yes']: + if cli_options["cacert"].lower() in ["true", "1", "t", "y", "yes"]: cli_options["cacert"] = True - elif cli_options["cacert"].lower() in ['false', '0', 'f', 'no']: + elif cli_options["cacert"].lower() in ["false", "0", "f", "no"]: cli_options["cacert"] = False config_file_options = self._get_config_file_options(args=args) @@ -98,20 +103,22 @@ def get_client(self, args, debug=False): kwargs = merge_dicts(kwargs, config_file_options) kwargs = merge_dicts(kwargs, cli_options) - kwargs['debug'] = debug + kwargs["debug"] = debug client = Client(**kwargs) if skip_config: # Config parsing is skipped - self.LOG.info('Skipping parsing CLI config') + self.LOG.info("Skipping parsing CLI config") return client # Ok to use config at this point rc_config = get_config() # Silence SSL warnings - silence_ssl_warnings = rc_config.get('general', {}).get('silence_ssl_warnings', False) + silence_ssl_warnings = rc_config.get("general", {}).get( + "silence_ssl_warnings", False + ) if silence_ssl_warnings: # pylint: disable=no-member requests.packages.urllib3.disable_warnings(InsecureRequestWarning) @@ -127,34 +134,45 @@ def get_client(self, args, debug=False): # We also skip automatic authentication if token is provided via the environment variable # or as a command line argument - env_var_token = os.environ.get('ST2_AUTH_TOKEN', None) - cli_argument_token = getattr(args, 'token', None) - env_var_api_key = os.environ.get('ST2_API_KEY', None) - cli_argument_api_key = getattr(args, 'api_key', None) - if env_var_token or cli_argument_token or env_var_api_key or cli_argument_api_key: + env_var_token = os.environ.get("ST2_AUTH_TOKEN", None) + cli_argument_token = getattr(args, "token", None) + env_var_api_key = os.environ.get("ST2_API_KEY", None) + cli_argument_api_key = getattr(args, "api_key", None) + if ( + env_var_token + or cli_argument_token + or env_var_api_key + or cli_argument_api_key + ): return client # If credentials are provided in the CLI config use them and try to authenticate - credentials = rc_config.get('credentials', {}) - username = credentials.get('username', None) - password = credentials.get('password', None) - cache_token = rc_config.get('cli', {}).get('cache_token', False) + credentials = rc_config.get("credentials", {}) + username = credentials.get("username", None) + password = credentials.get("password", None) + cache_token = rc_config.get("cli", {}).get("cache_token", False) if username: # Credentials are provided, try to authenticate agaist the API try: - token = self._get_auth_token(client=client, username=username, password=password, - cache_token=cache_token) + token = self._get_auth_token( + client=client, + username=username, + password=password, + cache_token=cache_token, + ) except requests.exceptions.ConnectionError as e: - self.LOG.warn('Auth API server is not available, skipping authentication.') + self.LOG.warn( + "Auth API server is not available, skipping authentication." + ) self.LOG.exception(e) return client except Exception as e: - print('Failed to authenticate with credentials provided in the config.') + print("Failed to authenticate with credentials provided in the config.") raise e client.token = token # TODO: Hack, refactor when splitting out the client - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token return client @@ -166,9 +184,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False): :rtype: ``dict`` """ rc_options = self._parse_config_file( - args=args, validate_config_permissions=validate_config_permissions) + args=args, validate_config_permissions=validate_config_permissions + ) result = {} - for kwarg_name, (section, option) in six.iteritems(CONFIG_OPTION_TO_CLIENT_KWARGS_MAP): + for kwarg_name, (section, option) in six.iteritems( + CONFIG_OPTION_TO_CLIENT_KWARGS_MAP + ): result[kwarg_name] = rc_options.get(section, {}).get(option, None) return result @@ -176,10 +197,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False): def _parse_config_file(self, args, validate_config_permissions=False): config_file_path = self._get_config_file_path(args=args) - parser = CLIConfigParser(config_file_path=config_file_path, - validate_config_exists=False, - validate_config_permissions=validate_config_permissions, - log=self.LOG) + parser = CLIConfigParser( + config_file_path=config_file_path, + validate_config_exists=False, + validate_config_permissions=validate_config_permissions, + log=self.LOG, + ) result = parser.parse() return result @@ -189,7 +212,7 @@ def _get_config_file_path(self, args): :rtype: ``str`` """ - path = os.environ.get('ST2_CONFIG_FILE', ST2_CONFIG_PATH) + path = os.environ.get("ST2_CONFIG_FILE", ST2_CONFIG_PATH) if args.config_file: path = args.config_file @@ -212,15 +235,16 @@ def _get_auth_token(self, client, username, password, cache_token): :rtype: ``str`` """ if cache_token: - token = self._get_cached_auth_token(client=client, username=username, - password=password) + token = self._get_cached_auth_token( + client=client, username=username, password=password + ) else: token = None if not token: # Token is either expired or not available - token_obj = self._authenticate_and_retrieve_auth_token(client=client, - username=username, - password=password) + token_obj = self._authenticate_and_retrieve_auth_token( + client=client, username=username, password=password + ) self._cache_auth_token(token_obj=token_obj) token = token_obj.token @@ -243,10 +267,12 @@ def _get_cached_auth_token(self, client, username, password): if not os.access(ST2_CONFIG_DIRECTORY, os.R_OK): # We don't have read access to the file with a cached token - message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' - 'access to the parent directory). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' + "access to the parent directory). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -255,9 +281,11 @@ def _get_cached_auth_token(self, client, username, password): if not os.access(cached_token_path, os.R_OK): # We don't have read access to the file with a cached token - message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' - 'access to this file). Subsequent requests won\'t use a cached token ' - 'meaning they may be slower.' % (cached_token_path, os.getlogin())) + message = ( + 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' + "access to this file). Subsequent requests won't use a cached token " + "meaning they may be slower." % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -267,9 +295,11 @@ def _get_cached_auth_token(self, client, username, password): if others_st_mode >= 2: # Every user has access to this file which is dangerous - message = ('Permissions (%s) for cached token file "%s" are too permissive. Please ' - 'restrict the permissions and make sure only your own user can read ' - 'from or write to the file.' % (file_st_mode, cached_token_path)) + message = ( + 'Permissions (%s) for cached token file "%s" are too permissive. Please ' + "restrict the permissions and make sure only your own user can read " + "from or write to the file." % (file_st_mode, cached_token_path) + ) self.LOG.warn(message) with open(cached_token_path) as fp: @@ -278,16 +308,20 @@ def _get_cached_auth_token(self, client, username, password): try: data = json.loads(data) - token = data['token'] - expire_timestamp = data['expire_timestamp'] + token = data["token"] + expire_timestamp = data["expire_timestamp"] except Exception as e: - msg = ('File "%s" with cached token is corrupted or invalid (%s). Please delete ' - ' this file' % (cached_token_path, six.text_type(e))) + msg = ( + 'File "%s" with cached token is corrupted or invalid (%s). Please delete ' + " this file" % (cached_token_path, six.text_type(e)) + ) raise ValueError(msg) now = int(time.time()) if (expire_timestamp - TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS) < now: - self.LOG.debug('Cached token from file "%s" has expired' % (cached_token_path)) + self.LOG.debug( + 'Cached token from file "%s" has expired' % (cached_token_path) + ) # Token has expired return None @@ -312,19 +346,25 @@ def _cache_auth_token(self, token_obj): if not os.access(ST2_CONFIG_DIRECTORY, os.W_OK): # We don't have write access to the file with a cached token - message = ('Unable to write token to "%s" (user %s doesn\'t have write ' - 'access to the parent directory). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to write token to "%s" (user %s doesn\'t have write ' + "access to the parent directory). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None - if os.path.isfile(cached_token_path) and not os.access(cached_token_path, os.W_OK): + if os.path.isfile(cached_token_path) and not os.access( + cached_token_path, os.W_OK + ): # We don't have write access to the file with a cached token - message = ('Unable to write token to "%s" (user %s doesn\'t have write ' - 'access to this file). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to write token to "%s" (user %s doesn\'t have write ' + "access to this file). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -333,8 +373,8 @@ def _cache_auth_token(self, token_obj): expire_timestamp = calendar.timegm(expire_timestamp.timetuple()) data = {} - data['token'] = token - data['expire_timestamp'] = expire_timestamp + data["token"] = token + data["expire_timestamp"] = expire_timestamp data = json.dumps(data) # Note: We explictly use fdopen instead of open + chmod to avoid a security issue. @@ -342,7 +382,7 @@ def _cache_auth_token(self, token_obj): # open and chmod) when file can potentially be read by other users if the default # permissions used during create allow that. fd = os.open(cached_token_path, os.O_WRONLY | os.O_CREAT, 0o660) - with os.fdopen(fd, 'w') as fp: + with os.fdopen(fd, "w") as fp: fp.write(data) os.chmod(cached_token_path, 0o660) @@ -350,8 +390,12 @@ def _cache_auth_token(self, token_obj): return True def _authenticate_and_retrieve_auth_token(self, client, username, password): - manager = models.ResourceManager(models.Token, client.endpoints['auth'], - cacert=client.cacert, debug=client.debug) + manager = models.ResourceManager( + models.Token, + client.endpoints["auth"], + cacert=client.cacert, + debug=client.debug, + ) instance = models.Token() instance = manager.create(instance, auth=(username, password)) return instance @@ -360,7 +404,7 @@ def _get_cached_token_path_for_user(self, username): """ Retrieve cached token path for the provided username. """ - file_name = 'token-%s' % (username) + file_name = "token-%s" % (username) result = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, file_name)) return result @@ -368,10 +412,10 @@ def _print_config(self, args): config = self._parse_config_file(args=args, validate_config_permissions=False) for section, options in six.iteritems(config): - print('[%s]' % (section)) + print("[%s]" % (section)) for name, value in six.iteritems(options): - print('%s = %s' % (name, value)) + print("%s = %s" % (name, value)) def _print_debug_info(self, args): # Print client settings @@ -388,19 +432,19 @@ def _print_client_settings(self, args): config_file_path = self._get_config_file_path(args=args) - print('CLI settings:') - print('----------------') - print('Config file path: %s' % (config_file_path)) - print('Client settings:') - print('----------------') - print('ST2_BASE_URL: %s' % (client.endpoints['base'])) - print('ST2_AUTH_URL: %s' % (client.endpoints['auth'])) - print('ST2_API_URL: %s' % (client.endpoints['api'])) - print('ST2_STREAM_URL: %s' % (client.endpoints['stream'])) - print('ST2_AUTH_TOKEN: %s' % (os.environ.get('ST2_AUTH_TOKEN'))) - print('') - print('Proxy settings:') - print('---------------') - print('HTTP_PROXY: %s' % (os.environ.get('HTTP_PROXY', ''))) - print('HTTPS_PROXY: %s' % (os.environ.get('HTTPS_PROXY', ''))) - print('') + print("CLI settings:") + print("----------------") + print("Config file path: %s" % (config_file_path)) + print("Client settings:") + print("----------------") + print("ST2_BASE_URL: %s" % (client.endpoints["base"])) + print("ST2_AUTH_URL: %s" % (client.endpoints["auth"])) + print("ST2_API_URL: %s" % (client.endpoints["api"])) + print("ST2_STREAM_URL: %s" % (client.endpoints["stream"])) + print("ST2_AUTH_TOKEN: %s" % (os.environ.get("ST2_AUTH_TOKEN"))) + print("") + print("Proxy settings:") + print("---------------") + print("HTTP_PROXY: %s" % (os.environ.get("HTTP_PROXY", ""))) + print("HTTPS_PROXY: %s" % (os.environ.get("HTTPS_PROXY", ""))) + print("") diff --git a/st2client/st2client/client.py b/st2client/st2client/client.py index 6bda37942b..9772c825b7 100644 --- a/st2client/st2client/client.py +++ b/st2client/st2client/client.py @@ -47,144 +47,224 @@ DEFAULT_AUTH_PORT = 9100 DEFAULT_STREAM_PORT = 9102 -DEFAULT_BASE_URL = 'http://127.0.0.1' -DEFAULT_API_VERSION = 'v1' +DEFAULT_BASE_URL = "http://127.0.0.1" +DEFAULT_API_VERSION = "v1" class Client(object): - def __init__(self, base_url=None, auth_url=None, api_url=None, stream_url=None, - api_version=None, cacert=None, debug=False, token=None, api_key=None): + def __init__( + self, + base_url=None, + auth_url=None, + api_url=None, + stream_url=None, + api_version=None, + cacert=None, + debug=False, + token=None, + api_key=None, + ): # Get CLI options. If not given, then try to get it from the environment. self.endpoints = dict() # Populate the endpoints if base_url: - self.endpoints['base'] = base_url + self.endpoints["base"] = base_url else: - self.endpoints['base'] = os.environ.get('ST2_BASE_URL', DEFAULT_BASE_URL) + self.endpoints["base"] = os.environ.get("ST2_BASE_URL", DEFAULT_BASE_URL) - api_version = api_version or os.environ.get('ST2_API_VERSION', DEFAULT_API_VERSION) + api_version = api_version or os.environ.get( + "ST2_API_VERSION", DEFAULT_API_VERSION + ) - self.endpoints['exp'] = '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, 'exp') + self.endpoints["exp"] = "%s:%s/%s" % ( + self.endpoints["base"], + DEFAULT_API_PORT, + "exp", + ) if api_url: - self.endpoints['api'] = api_url + self.endpoints["api"] = api_url else: - self.endpoints['api'] = os.environ.get( - 'ST2_API_URL', '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, api_version)) + self.endpoints["api"] = os.environ.get( + "ST2_API_URL", + "%s:%s/%s" % (self.endpoints["base"], DEFAULT_API_PORT, api_version), + ) if auth_url: - self.endpoints['auth'] = auth_url + self.endpoints["auth"] = auth_url else: - self.endpoints['auth'] = os.environ.get( - 'ST2_AUTH_URL', '%s:%s' % (self.endpoints['base'], DEFAULT_AUTH_PORT)) + self.endpoints["auth"] = os.environ.get( + "ST2_AUTH_URL", "%s:%s" % (self.endpoints["base"], DEFAULT_AUTH_PORT) + ) if stream_url: - self.endpoints['stream'] = stream_url + self.endpoints["stream"] = stream_url else: - self.endpoints['stream'] = os.environ.get( - 'ST2_STREAM_URL', - '%s:%s/%s' % ( - self.endpoints['base'], - DEFAULT_STREAM_PORT, - api_version - ) + self.endpoints["stream"] = os.environ.get( + "ST2_STREAM_URL", + "%s:%s/%s" % (self.endpoints["base"], DEFAULT_STREAM_PORT, api_version), ) if cacert is not None: self.cacert = cacert else: - self.cacert = os.environ.get('ST2_CACERT', None) + self.cacert = os.environ.get("ST2_CACERT", None) # Note: boolean is also a valid value for "cacert" is_cacert_string = isinstance(self.cacert, six.string_types) - if (self.cacert and is_cacert_string and not os.path.isfile(self.cacert)): + if self.cacert and is_cacert_string and not os.path.isfile(self.cacert): raise ValueError('CA cert file "%s" does not exist.' % (self.cacert)) self.debug = debug # Note: This is a nasty hack for now, but we need to get rid of the decrator abuse if token: - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token self.token = token if api_key: - os.environ['ST2_API_KEY'] = api_key + os.environ["ST2_API_KEY"] = api_key self.api_key = api_key # Instantiate resource managers and assign appropriate API endpoint. self.managers = dict() - self.managers['Token'] = ResourceManager( - models.Token, self.endpoints['auth'], cacert=self.cacert, debug=self.debug) - self.managers['RunnerType'] = ResourceManager( - models.RunnerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Action'] = ActionResourceManager( - models.Action, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ActionAlias'] = ActionAliasResourceManager( - models.ActionAlias, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ActionAliasExecution'] = ActionAliasExecutionManager( - models.ActionAliasExecution, self.endpoints['api'], - cacert=self.cacert, debug=self.debug) - self.managers['ApiKey'] = ResourceManager( - models.ApiKey, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Config'] = ConfigManager( - models.Config, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ConfigSchema'] = ResourceManager( - models.ConfigSchema, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Execution'] = ExecutionResourceManager( - models.Execution, self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["Token"] = ResourceManager( + models.Token, self.endpoints["auth"], cacert=self.cacert, debug=self.debug + ) + self.managers["RunnerType"] = ResourceManager( + models.RunnerType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Action"] = ActionResourceManager( + models.Action, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["ActionAlias"] = ActionAliasResourceManager( + models.ActionAlias, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["ActionAliasExecution"] = ActionAliasExecutionManager( + models.ActionAliasExecution, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["ApiKey"] = ResourceManager( + models.ApiKey, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Config"] = ConfigManager( + models.Config, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["ConfigSchema"] = ResourceManager( + models.ConfigSchema, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Execution"] = ExecutionResourceManager( + models.Execution, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for # backward compatibility reasons until v3.2.0 - self.managers['LiveAction'] = self.managers['Execution'] - self.managers['Inquiry'] = InquiryResourceManager( - models.Inquiry, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Pack'] = PackResourceManager( - models.Pack, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Policy'] = ResourceManager( - models.Policy, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['PolicyType'] = ResourceManager( - models.PolicyType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Rule'] = ResourceManager( - models.Rule, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Sensor'] = ResourceManager( - models.Sensor, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['TriggerType'] = ResourceManager( - models.TriggerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Trigger'] = ResourceManager( - models.Trigger, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['TriggerInstance'] = TriggerInstanceResourceManager( - models.TriggerInstance, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['KeyValuePair'] = ResourceManager( - models.KeyValuePair, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Webhook'] = WebhookManager( - models.Webhook, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Timer'] = ResourceManager( - models.Timer, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Trace'] = ResourceManager( - models.Trace, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['RuleEnforcement'] = ResourceManager( - models.RuleEnforcement, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Stream'] = StreamManager( - self.endpoints['stream'], cacert=self.cacert, debug=self.debug) - self.managers['Workflow'] = WorkflowManager( - self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["LiveAction"] = self.managers["Execution"] + self.managers["Inquiry"] = InquiryResourceManager( + models.Inquiry, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Pack"] = PackResourceManager( + models.Pack, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Policy"] = ResourceManager( + models.Policy, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["PolicyType"] = ResourceManager( + models.PolicyType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Rule"] = ResourceManager( + models.Rule, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Sensor"] = ResourceManager( + models.Sensor, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["TriggerType"] = ResourceManager( + models.TriggerType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Trigger"] = ResourceManager( + models.Trigger, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["TriggerInstance"] = TriggerInstanceResourceManager( + models.TriggerInstance, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["KeyValuePair"] = ResourceManager( + models.KeyValuePair, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Webhook"] = WebhookManager( + models.Webhook, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Timer"] = ResourceManager( + models.Timer, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Trace"] = ResourceManager( + models.Trace, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["RuleEnforcement"] = ResourceManager( + models.RuleEnforcement, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Stream"] = StreamManager( + self.endpoints["stream"], cacert=self.cacert, debug=self.debug + ) + self.managers["Workflow"] = WorkflowManager( + self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) # Service Registry - self.managers['ServiceRegistryGroups'] = ServiceRegistryGroupsManager( - models.ServiceRegistryGroup, self.endpoints['api'], cacert=self.cacert, - debug=self.debug) - - self.managers['ServiceRegistryMembers'] = ServiceRegistryMembersManager( - models.ServiceRegistryMember, self.endpoints['api'], cacert=self.cacert, - debug=self.debug) + self.managers["ServiceRegistryGroups"] = ServiceRegistryGroupsManager( + models.ServiceRegistryGroup, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + + self.managers["ServiceRegistryMembers"] = ServiceRegistryMembersManager( + models.ServiceRegistryMember, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) # RBAC - self.managers['Role'] = ResourceManager( - models.Role, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['UserRoleAssignment'] = ResourceManager( - models.UserRoleAssignment, self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["Role"] = ResourceManager( + models.Role, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["UserRoleAssignment"] = ResourceManager( + models.UserRoleAssignment, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) @add_auth_token_to_kwargs_from_env def get_user_info(self, **kwargs): @@ -193,9 +273,10 @@ def get_user_info(self, **kwargs): :rtype: ``dict`` """ - url = '/user' - client = httpclient.HTTPClient(root=self.endpoints['api'], cacert=self.cacert, - debug=self.debug) + url = "/user" + client = httpclient.HTTPClient( + root=self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) response = client.get(url=url, **kwargs) if response.status_code != 200: @@ -205,80 +286,85 @@ def get_user_info(self, **kwargs): @property def actions(self): - return self.managers['Action'] + return self.managers["Action"] @property def apikeys(self): - return self.managers['ApiKey'] + return self.managers["ApiKey"] @property def keys(self): - return self.managers['KeyValuePair'] + return self.managers["KeyValuePair"] @property def executions(self): - return self.managers['Execution'] + return self.managers["Execution"] # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for # backward compatibility reasons until v3.2.0 @property def liveactions(self): - warnings.warn(('st2client.liveactions has been renamed to st2client.executions, please ' - 'update your code'), DeprecationWarning) + warnings.warn( + ( + "st2client.liveactions has been renamed to st2client.executions, please " + "update your code" + ), + DeprecationWarning, + ) return self.executions @property def inquiries(self): - return self.managers['Inquiry'] + return self.managers["Inquiry"] @property def packs(self): - return self.managers['Pack'] + return self.managers["Pack"] @property def policies(self): - return self.managers['Policy'] + return self.managers["Policy"] @property def policytypes(self): - return self.managers['PolicyType'] + return self.managers["PolicyType"] @property def rules(self): - return self.managers['Rule'] + return self.managers["Rule"] @property def runners(self): - return self.managers['RunnerType'] + return self.managers["RunnerType"] @property def sensors(self): - return self.managers['Sensor'] + return self.managers["Sensor"] @property def tokens(self): - return self.managers['Token'] + return self.managers["Token"] @property def triggertypes(self): - return self.managers['TriggerType'] + return self.managers["TriggerType"] @property def triggerinstances(self): - return self.managers['TriggerInstance'] + return self.managers["TriggerInstance"] @property def trace(self): - return self.managers['Trace'] + return self.managers["Trace"] @property def ruleenforcements(self): - return self.managers['RuleEnforcement'] + return self.managers["RuleEnforcement"] @property def webhooks(self): - return self.managers['Webhook'] + return self.managers["Webhook"] @property def workflows(self): - return self.managers['Workflow'] + return self.managers["Workflow"] diff --git a/st2client/st2client/commands/__init__.py b/st2client/st2client/commands/__init__.py index a9b9cee86b..995d3fd9d3 100644 --- a/st2client/st2client/commands/__init__.py +++ b/st2client/st2client/commands/__init__.py @@ -35,9 +35,9 @@ def __init__(self, name, description, app, subparsers, parent_parser=None): self.description = description self.app = app self.parent_parser = parent_parser - self.parser = subparsers.add_parser(self.name, - description=self.description, - help=self.description) + self.parser = subparsers.add_parser( + self.name, description=self.description, help=self.description + ) self.commands = dict() @@ -45,16 +45,19 @@ def __init__(self, name, description, app, subparsers, parent_parser=None): class Command(object): """Represents a commandlet in the command tree.""" - def __init__(self, name, description, app, subparsers, - parent_parser=None, add_help=True): + def __init__( + self, name, description, app, subparsers, parent_parser=None, add_help=True + ): self.name = name self.description = description self.app = app self.parent_parser = parent_parser - self.parser = subparsers.add_parser(self.name, - description=self.description, - help=self.description, - add_help=add_help) + self.parser = subparsers.add_parser( + self.name, + description=self.description, + help=self.description, + add_help=add_help, + ) self.parser.set_defaults(func=self.run_and_print) @abc.abstractmethod @@ -74,8 +77,8 @@ def run_and_print(self, args, **kwargs): raise NotImplementedError def format_output(self, subject, formatter, *args, **kwargs): - json = kwargs.get('json', False) - yaml = kwargs.get('yaml', False) + json = kwargs.get("json", False) + yaml = kwargs.get("yaml", False) if json: func = doc.JsonFormatter.format @@ -90,4 +93,4 @@ def print_output(self, subject, formatter, *args, **kwargs): output = self.format_output(subject, formatter, *args, **kwargs) print(output) else: - print('No matching items found') + print("No matching items found") diff --git a/st2client/st2client/commands/action.py b/st2client/st2client/commands/action.py index 7a41d9e2eb..dcf76c3a7d 100644 --- a/st2client/st2client/commands/action.py +++ b/st2client/st2client/commands/action.py @@ -44,60 +44,54 @@ LOG = logging.getLogger(__name__) -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' -LIVEACTION_STATUS_PAUSING = 'pausing' -LIVEACTION_STATUS_PAUSED = 'paused' -LIVEACTION_STATUS_RESUMING = 'resuming' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" +LIVEACTION_STATUS_PAUSING = "pausing" +LIVEACTION_STATUS_PAUSED = "paused" +LIVEACTION_STATUS_RESUMING = "resuming" LIVEACTION_COMPLETED_STATES = [ LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] # Who parameters should be masked when displaying action execution output -PARAMETERS_TO_MASK = [ - 'password', - 'private_key' -] +PARAMETERS_TO_MASK = ["password", "private_key"] # A list of environment variables which are never inherited when using run # --inherit-env flag ENV_VARS_BLACKLIST = [ - 'pwd', - 'mail', - 'username', - 'user', - 'path', - 'home', - 'ps1', - 'shell', - 'pythonpath', - 'ssh_tty', - 'ssh_connection', - 'lang', - 'ls_colors', - 'logname', - 'oldpwd', - 'term', - 'xdg_session_id' + "pwd", + "mail", + "username", + "user", + "path", + "home", + "ps1", + "shell", + "pythonpath", + "ssh_tty", + "ssh_connection", + "lang", + "ls_colors", + "logname", + "oldpwd", + "term", + "xdg_session_id", ] -WORKFLOW_RUNNER_TYPES = [ - 'action-chain', - 'orquesta' -] +WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"] def format_parameters(value): @@ -108,15 +102,15 @@ def format_parameters(value): for param_name, _ in value.items(): if param_name in PARAMETERS_TO_MASK: - value[param_name] = '********' + value[param_name] = "********" return value # String for indenting etc. -WF_PREFIX = '+ ' -NON_WF_PREFIX = ' ' -INDENT_CHAR = ' ' +WF_PREFIX = "+ " +NON_WF_PREFIX = " " +INDENT_CHAR = " " def format_wf_instances(instances): @@ -127,7 +121,7 @@ def format_wf_instances(instances): # only add extr chars if there are workflows. has_wf = False for instance in instances: - if not getattr(instance, 'children', None): + if not getattr(instance, "children", None): continue else: has_wf = True @@ -136,7 +130,7 @@ def format_wf_instances(instances): return instances # Prepend wf and non_wf prefixes. for instance in instances: - if getattr(instance, 'children', None): + if getattr(instance, "children", None): instance.id = WF_PREFIX + instance.id else: instance.id = NON_WF_PREFIX + instance.id @@ -158,59 +152,75 @@ def format_execution_status(instance): executions which are in running state and execution total run time for all the executions which have finished. """ - status = getattr(instance, 'status', None) - start_timestamp = getattr(instance, 'start_timestamp', None) - end_timestamp = getattr(instance, 'end_timestamp', None) + status = getattr(instance, "status", None) + start_timestamp = getattr(instance, "start_timestamp", None) + end_timestamp = getattr(instance, "end_timestamp", None) if status == LIVEACTION_STATUS_RUNNING and start_timestamp: start_timestamp = instance.start_timestamp start_timestamp = parse_isotime(start_timestamp) start_timestamp = calendar.timegm(start_timestamp.timetuple()) now = int(time.time()) - elapsed_seconds = (now - start_timestamp) - instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds) + elapsed_seconds = now - start_timestamp + instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds) elif status in LIVEACTION_COMPLETED_STATES and start_timestamp and end_timestamp: start_timestamp = parse_isotime(start_timestamp) start_timestamp = calendar.timegm(start_timestamp.timetuple()) end_timestamp = parse_isotime(end_timestamp) end_timestamp = calendar.timegm(end_timestamp.timetuple()) - elapsed_seconds = (end_timestamp - start_timestamp) - instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds) + elapsed_seconds = end_timestamp - start_timestamp + instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds) return instance class ActionBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ActionBranch, self).__init__( - models.Action, description, app, subparsers, + models.Action, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': ActionListCommand, - 'get': ActionGetCommand, - 'update': ActionUpdateCommand, - 'delete': ActionDeleteCommand - }) + "list": ActionListCommand, + "get": ActionGetCommand, + "update": ActionUpdateCommand, + "delete": ActionDeleteCommand, + }, + ) # Registers extended commands - self.commands['enable'] = ActionEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = ActionDisableCommand(self.resource, self.app, self.subparsers) - self.commands['execute'] = ActionRunCommand( - self.resource, self.app, self.subparsers, - add_help=False) + self.commands["enable"] = ActionEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = ActionDisableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["execute"] = ActionRunCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class ActionListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description'] + display_attributes = ["ref", "pack", "description"] class ActionGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -218,17 +228,33 @@ class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand): class ActionEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionDeleteCommand(resource.ContentPackResourceDeleteCommand): @@ -239,15 +265,32 @@ class ActionRunCommandMixin(object): """ Mixin class which contains utility functions related to action execution. """ - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] - attribute_display_order = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] + attribute_display_order = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone, - 'end_timestamp': format_isodate_for_user_timezone, - 'parameters': format_parameters, - 'status': format_status + "start_timestamp": format_isodate_for_user_timezone, + "end_timestamp": format_isodate_for_user_timezone, + "parameters": format_parameters, + "status": format_status, } poll_interval = 2 # how often to poll for execution completion when using sync mode @@ -262,14 +305,19 @@ def run_and_print(self, args, **kwargs): execution = self.run(args, **kwargs) if args.action_async: - self.print_output('To get the results, execute:\n st2 execution get %s' % - (execution.id), six.text_type) - self.print_output('\nTo view output in real-time, execute:\n st2 execution ' - 'tail %s' % (execution.id), six.text_type) + self.print_output( + "To get the results, execute:\n st2 execution get %s" % (execution.id), + six.text_type, + ) + self.print_output( + "\nTo view output in real-time, execute:\n st2 execution " + "tail %s" % (execution.id), + six.text_type, + ) else: self._print_execution_details(execution=execution, args=args, **kwargs) - if execution.status == 'failed': + if execution.status == "failed": # Exit with non zero if the action has failed sys.exit(1) @@ -278,52 +326,99 @@ def _add_common_options(self): # Display options task_list_arg_grp = root_arg_grp.add_argument_group() - task_list_arg_grp.add_argument('--with-schema', - default=False, action='store_true', - help=('Show schema_ouput suggestion with action.')) - - task_list_arg_grp.add_argument('--raw', action='store_true', - help='Raw output, don\'t show sub-tasks for workflows.') - task_list_arg_grp.add_argument('--show-tasks', action='store_true', - help='Whether to show sub-tasks of an execution.') - task_list_arg_grp.add_argument('--depth', type=int, default=-1, - help='Depth to which to show sub-tasks. \ - By default all are shown.') - task_list_arg_grp.add_argument('-w', '--width', nargs='+', type=int, default=None, - help='Set the width of columns in output.') + task_list_arg_grp.add_argument( + "--with-schema", + default=False, + action="store_true", + help=("Show schema_ouput suggestion with action."), + ) + + task_list_arg_grp.add_argument( + "--raw", + action="store_true", + help="Raw output, don't show sub-tasks for workflows.", + ) + task_list_arg_grp.add_argument( + "--show-tasks", + action="store_true", + help="Whether to show sub-tasks of an execution.", + ) + task_list_arg_grp.add_argument( + "--depth", + type=int, + default=-1, + help="Depth to which to show sub-tasks. \ + By default all are shown.", + ) + task_list_arg_grp.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help="Set the width of columns in output.", + ) execution_details_arg_grp = root_arg_grp.add_mutually_exclusive_group() detail_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group() - detail_arg_grp.add_argument('--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) - detail_arg_grp.add_argument('-d', '--detail', action='store_true', - help='Display full detail of the execution in table format.') + detail_arg_grp.add_argument( + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) + detail_arg_grp.add_argument( + "-d", + "--detail", + action="store_true", + help="Display full detail of the execution in table format.", + ) result_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group() - result_arg_grp.add_argument('-k', '--key', - help=('If result is type of JSON, then print specific ' - 'key-value pair; dot notation for nested JSON is ' - 'supported.')) - result_arg_grp.add_argument('--delay', type=int, default=None, - help=('How long (in milliseconds) to delay the ' - 'execution before scheduling.')) + result_arg_grp.add_argument( + "-k", + "--key", + help=( + "If result is type of JSON, then print specific " + "key-value pair; dot notation for nested JSON is " + "supported." + ), + ) + result_arg_grp.add_argument( + "--delay", + type=int, + default=None, + help=( + "How long (in milliseconds) to delay the " + "execution before scheduling." + ), + ) # Other options - detail_arg_grp.add_argument('--tail', action='store_true', - help='Automatically start tailing new execution.') + detail_arg_grp.add_argument( + "--tail", + action="store_true", + help="Automatically start tailing new execution.", + ) # Flag to opt-in to functionality introduced in PR #3670. More robust parsing # of complex datatypes is planned for 2.6, so this flag will be deprecated soon - detail_arg_grp.add_argument('--auto-dict', action='store_true', dest='auto_dict', - default=False, help='Automatically convert list items to ' - 'dictionaries when colons are detected. ' - '(NOTE - this parameter and its functionality will be ' - 'deprecated in the next release in favor of a more ' - 'robust conversion method)') + detail_arg_grp.add_argument( + "--auto-dict", + action="store_true", + dest="auto_dict", + default=False, + help="Automatically convert list items to " + "dictionaries when colons are detected. " + "(NOTE - this parameter and its functionality will be " + "deprecated in the next release in favor of a more " + "robust conversion method)", + ) return root_arg_grp @@ -334,20 +429,24 @@ def _print_execution_details(self, execution, args, **kwargs): This method takes into account if an executed action was workflow or not and formats the output accordingly. """ - runner_type = execution.action.get('runner_type', 'unknown') + runner_type = execution.action.get("runner_type", "unknown") is_workflow_action = runner_type in WORKFLOW_RUNNER_TYPES - show_tasks = getattr(args, 'show_tasks', False) - raw = getattr(args, 'raw', False) - detail = getattr(args, 'detail', False) - key = getattr(args, 'key', None) - attr = getattr(args, 'attr', []) + show_tasks = getattr(args, "show_tasks", False) + raw = getattr(args, "raw", False) + detail = getattr(args, "detail", False) + key = getattr(args, "key", None) + attr = getattr(args, "attr", []) if show_tasks and not is_workflow_action: - raise ValueError('--show-tasks option can only be used with workflow actions') + raise ValueError( + "--show-tasks option can only be used with workflow actions" + ) if not raw and not detail and (show_tasks or is_workflow_action): - self._run_and_print_child_task_list(execution=execution, args=args, **kwargs) + self._run_and_print_child_task_list( + execution=execution, args=args, **kwargs + ) else: instance = execution @@ -357,47 +456,61 @@ def _print_execution_details(self, execution, args, **kwargs): formatter = execution_formatter.ExecutionResult if detail: - options = {'attributes': copy.copy(self.display_attributes)} + options = {"attributes": copy.copy(self.display_attributes)} elif key: - options = {'attributes': ['result.%s' % (key)], 'key': key} + options = {"attributes": ["result.%s" % (key)], "key": key} else: - options = {'attributes': attr} - - options['json'] = args.json - options['yaml'] = args.yaml - options['with_schema'] = args.with_schema - options['attribute_transform_functions'] = self.attribute_transform_functions + options = {"attributes": attr} + + options["json"] = args.json + options["yaml"] = args.yaml + options["with_schema"] = args.with_schema + options[ + "attribute_transform_functions" + ] = self.attribute_transform_functions self.print_output(instance, formatter, **options) def _run_and_print_child_task_list(self, execution, args, **kwargs): - action_exec_mgr = self.app.client.managers['Execution'] + action_exec_mgr = self.app.client.managers["Execution"] instance = execution - options = {'attributes': ['id', 'action.ref', 'parameters', 'status', 'start_timestamp', - 'end_timestamp']} - options['json'] = args.json - options['attribute_transform_functions'] = self.attribute_transform_functions + options = { + "attributes": [ + "id", + "action.ref", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + ] + } + options["json"] = args.json + options["attribute_transform_functions"] = self.attribute_transform_functions formatter = execution_formatter.ExecutionResult - kwargs['depth'] = args.depth - child_instances = action_exec_mgr.get_property(execution.id, 'children', **kwargs) + kwargs["depth"] = args.depth + child_instances = action_exec_mgr.get_property( + execution.id, "children", **kwargs + ) child_instances = self._format_child_instances(child_instances, execution.id) child_instances = format_execution_statuses(child_instances) if not child_instances: # No child error, there might be a global error, include result in the output - options['attributes'].append('result') + options["attributes"].append("result") - status_index = options['attributes'].index('status') + status_index = options["attributes"].index("status") - if hasattr(instance, 'result') and isinstance(instance.result, dict): - tasks = instance.result.get('tasks', []) + if hasattr(instance, "result") and isinstance(instance.result, dict): + tasks = instance.result.get("tasks", []) else: tasks = [] # On failure we also want to include error message and traceback at the top level - if instance.status == 'failed': - top_level_error, top_level_traceback = self._get_top_level_error(live_action=instance) + if instance.status == "failed": + top_level_error, top_level_traceback = self._get_top_level_error( + live_action=instance + ) if len(tasks) >= 1: task_error, task_traceback = self._get_task_error(task=tasks[-1]) @@ -408,18 +521,18 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs): # Top-level error instance.error = top_level_error instance.traceback = top_level_traceback - instance.result = 'See error and traceback.' - options['attributes'].insert(status_index + 1, 'error') - options['attributes'].insert(status_index + 2, 'traceback') + instance.result = "See error and traceback." + options["attributes"].insert(status_index + 1, "error") + options["attributes"].insert(status_index + 2, "traceback") elif task_error: # Task error instance.error = task_error instance.traceback = task_traceback - instance.result = 'See error and traceback.' - instance.failed_on = tasks[-1].get('name', 'unknown') - options['attributes'].insert(status_index + 1, 'error') - options['attributes'].insert(status_index + 2, 'traceback') - options['attributes'].insert(status_index + 3, 'failed_on') + instance.result = "See error and traceback." + instance.failed_on = tasks[-1].get("name", "unknown") + options["attributes"].insert(status_index + 1, "error") + options["attributes"].insert(status_index + 2, "traceback") + options["attributes"].insert(status_index + 3, "failed_on") # Include result on the top-level object so user doesn't need to issue another command to # see the result @@ -427,57 +540,63 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs): task_result = self._get_task_result(task=tasks[-1]) if task_result: - instance.result_task = tasks[-1].get('name', 'unknown') - options['attributes'].insert(status_index + 1, 'result_task') - options['attributes'].insert(status_index + 2, 'result') + instance.result_task = tasks[-1].get("name", "unknown") + options["attributes"].insert(status_index + 1, "result_task") + options["attributes"].insert(status_index + 2, "result") instance.result = task_result # Otherwise include the result of the workflow execution. else: - if 'result' not in options['attributes']: - options['attributes'].append('result') + if "result" not in options["attributes"]: + options["attributes"].append("result") # print root task self.print_output(instance, formatter, **options) # print child tasks if child_instances: - self.print_output(child_instances, table.MultiColumnTable, - attributes=['id', 'status', 'task', 'action', 'start_timestamp'], - widths=args.width, json=args.json, - yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + child_instances, + table.MultiColumnTable, + attributes=["id", "status", "task", "action", "start_timestamp"], + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) def _get_execution_result(self, execution, action_exec_mgr, args, **kwargs): pending_statuses = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING + LIVEACTION_STATUS_CANCELING, ] if args.tail: # Start tailing new execution print('Tailing execution "%s"' % (str(execution.id))) - execution_manager = self.app.client.managers['Execution'] - stream_manager = self.app.client.managers['Stream'] - ActionExecutionTailCommand.tail_execution(execution=execution, - execution_manager=execution_manager, - stream_manager=stream_manager, - **kwargs) + execution_manager = self.app.client.managers["Execution"] + stream_manager = self.app.client.managers["Stream"] + ActionExecutionTailCommand.tail_execution( + execution=execution, + execution_manager=execution_manager, + stream_manager=stream_manager, + **kwargs, + ) execution = action_exec_mgr.get_by_id(execution.id, **kwargs) - print('') + print("") return execution if not args.action_async: while execution.status in pending_statuses: time.sleep(self.poll_interval) if not args.json and not args.yaml: - sys.stdout.write('.') + sys.stdout.write(".") sys.stdout.flush() execution = action_exec_mgr.get_by_id(execution.id, **kwargs) - sys.stdout.write('\n') + sys.stdout.write("\n") if execution.status == LIVEACTION_STATUS_CANCELED: return execution @@ -491,8 +610,8 @@ def _get_top_level_error(self, live_action): :return: (error, traceback) """ if isinstance(live_action.result, dict): - error = live_action.result.get('error', None) - traceback = live_action.result.get('traceback', None) + error = live_action.result.get("error", None) + traceback = live_action.result.get("traceback", None) else: error = "See result" traceback = "See result" @@ -508,12 +627,12 @@ def _get_task_error(self, task): if not task: return None, None - result = task['result'] + result = task["result"] if isinstance(result, dict): - stderr = result.get('stderr', None) - error = result.get('error', None) - traceback = result.get('traceback', None) + stderr = result.get("stderr", None) + error = result.get("error", None) + traceback = result.get("traceback", None) error = error if error else stderr else: stderr = None @@ -526,7 +645,7 @@ def _get_task_result(self, task): if not task: return None - return task['result'] + return task["result"] def _get_action_parameters_from_args(self, action, runner, args): """ @@ -553,22 +672,22 @@ def read_file(file_path): if not os.path.isfile(file_path): raise ValueError('"%s" is not a file' % (file_path)) - with open(file_path, 'rb') as fp: + with open(file_path, "rb") as fp: content = fp.read() return content.decode("utf-8") def transform_object(value): # Also support simple key1=val1,key2=val2 syntax - if value.startswith('{'): + if value.startswith("{"): # Assume it's JSON result = value = json.loads(value) else: - pairs = value.split(',') + pairs = value.split(",") result = {} for pair in pairs: - split = pair.split('=', 1) + split = pair.split("=", 1) if len(split) != 2: continue @@ -605,18 +724,22 @@ def transform_array(value, action_params=None, auto_dict=False): try: result = json.loads(value) except ValueError: - result = [v.strip() for v in value.split(',')] + result = [v.strip() for v in value.split(",")] # When each values in this array represent dict type, this converts # the 'result' to the dict type value. - if all([isinstance(x, str) and ':' in x for x in result]) and auto_dict: + if all([isinstance(x, str) and ":" in x for x in result]) and auto_dict: result_dict = {} - for (k, v) in [x.split(':') for x in result]: + for (k, v) in [x.split(":") for x in result]: # To parse values using the 'transformer' according to the type which is # specified in the action metadata, calling 'normalize' method recursively. - if 'properties' in action_params and k in action_params['properties']: - result_dict[k] = normalize(k, v, action_params['properties'], - auto_dict=auto_dict) + if ( + "properties" in action_params + and k in action_params["properties"] + ): + result_dict[k] = normalize( + k, v, action_params["properties"], auto_dict=auto_dict + ) else: result_dict[k] = v return [result_dict] @@ -624,12 +747,12 @@ def transform_array(value, action_params=None, auto_dict=False): return result transformer = { - 'array': transform_array, - 'boolean': (lambda x: ast.literal_eval(x.capitalize())), - 'integer': int, - 'number': float, - 'object': transform_object, - 'string': str + "array": transform_array, + "boolean": (lambda x: ast.literal_eval(x.capitalize())), + "integer": int, + "number": float, + "object": transform_object, + "string": str, } def get_param_type(key, action_params=None): @@ -642,13 +765,13 @@ def get_param_type(key, action_params=None): param = action_params[key] if param: - return param['type'] + return param["type"] return None def normalize(name, value, action_params=None, auto_dict=False): - """ The desired type is contained in the action meta-data, so we can look that up - and call the desired "caster" function listed in the "transformer" dict + """The desired type is contained in the action meta-data, so we can look that up + and call the desired "caster" function listed in the "transformer" dict """ action_params = action_params or action.parameters @@ -663,8 +786,10 @@ def normalize(name, value, action_params=None, auto_dict=False): # (items: type: int for example) and this information is available here so we could # also leverage that to cast each array item to the correct type. param_type = get_param_type(name, action_params) - if param_type == 'array' and name in action_params: - return transformer[param_type](value, action_params[name], auto_dict=auto_dict) + if param_type == "array" and name in action_params: + return transformer[param_type]( + value, action_params[name], auto_dict=auto_dict + ) elif param_type: return transformer[param_type](value) @@ -677,11 +802,11 @@ def normalize(name, value, action_params=None, auto_dict=False): for idx in range(len(args.parameters)): arg = args.parameters[idx] - if '=' in arg: - k, v = arg.split('=', 1) + if "=" in arg: + k, v = arg.split("=", 1) # Attribute for files are prefixed with "@" - if k.startswith('@'): + if k.startswith("@"): k = k[1:] is_file = True else: @@ -695,15 +820,15 @@ def normalize(name, value, action_params=None, auto_dict=False): file_name = os.path.basename(file_path) content = read_file(file_path=file_path) - if action_ref_or_id == 'core.http': + if action_ref_or_id == "core.http": # Special case for http runner - result['_file_name'] = file_name - result['file_content'] = content + result["_file_name"] = file_name + result["file_content"] = content else: result[k] = content else: # This permits multiple declarations of argument only in the array type. - if get_param_type(k) == 'array' and k in result: + if get_param_type(k) == "array" and k in result: result[k] += normalize(k, v, auto_dict=args.auto_dict) else: result[k] = normalize(k, v, auto_dict=args.auto_dict) @@ -711,42 +836,44 @@ def normalize(name, value, action_params=None, auto_dict=False): except Exception as e: # TODO: Move transformers in a separate module and handle # exceptions there - if 'malformed string' in six.text_type(e): - message = ('Invalid value for boolean parameter. ' - 'Valid values are: true, false') + if "malformed string" in six.text_type(e): + message = ( + "Invalid value for boolean parameter. " + "Valid values are: true, false" + ) raise ValueError(message) else: raise e else: - result['cmd'] = ' '.join(args.parameters[idx:]) + result["cmd"] = " ".join(args.parameters[idx:]) break # Special case for http runner - if 'file_content' in result: - if 'method' not in result: + if "file_content" in result: + if "method" not in result: # Default to POST if a method is not provided - result['method'] = 'POST' + result["method"] = "POST" - if 'file_name' not in result: + if "file_name" not in result: # File name not provided, use default file name - result['file_name'] = result['_file_name'] + result["file_name"] = result["_file_name"] - del result['_file_name'] + del result["_file_name"] if args.inherit_env: - result['env'] = self._get_inherited_env_vars() + result["env"] = self._get_inherited_env_vars() return result @add_auth_token_to_kwargs_from_cli def _print_help(self, args, **kwargs): # Print appropriate help message if the help option is given. - action_mgr = self.app.client.managers['Action'] - action_exec_mgr = self.app.client.managers['Execution'] + action_mgr = self.app.client.managers["Action"] + action_exec_mgr = self.app.client.managers["Execution"] if args.help: - action_ref_or_id = getattr(args, 'ref_or_id', None) - action_exec_id = getattr(args, 'id', None) + action_ref_or_id = getattr(args, "ref_or_id", None) + action_exec_id = getattr(args, "id", None) if action_exec_id and not action_ref_or_id: action_exec = action_exec_mgr.get_by_id(action_exec_id, **kwargs) @@ -756,34 +883,47 @@ def _print_help(self, args, **kwargs): try: action = action_mgr.get_by_ref_or_id(args.ref_or_id, **kwargs) if not action: - raise resource.ResourceNotFoundError('Action %s not found' % args.ref_or_id) - runner_mgr = self.app.client.managers['RunnerType'] + raise resource.ResourceNotFoundError( + "Action %s not found" % args.ref_or_id + ) + runner_mgr = self.app.client.managers["RunnerType"] runner = runner_mgr.get_by_name(action.runner_type, **kwargs) - parameters, required, optional, _ = self._get_params_types(runner, - action) - print('') + parameters, required, optional, _ = self._get_params_types( + runner, action + ) + print("") print(textwrap.fill(action.description)) - print('') + print("") if required: - required = self._sort_parameters(parameters=parameters, - names=required) - - print('Required Parameters:') - [self._print_param(name, parameters.get(name)) - for name in required] + required = self._sort_parameters( + parameters=parameters, names=required + ) + + print("Required Parameters:") + [ + self._print_param(name, parameters.get(name)) + for name in required + ] if optional: - optional = self._sort_parameters(parameters=parameters, - names=optional) - - print('Optional Parameters:') - [self._print_param(name, parameters.get(name)) - for name in optional] + optional = self._sort_parameters( + parameters=parameters, names=optional + ) + + print("Optional Parameters:") + [ + self._print_param(name, parameters.get(name)) + for name in optional + ] except resource.ResourceNotFoundError: - print(('Action "%s" is not found. ' % args.ref_or_id) + - 'Use "st2 action list" to see the list of available actions.') + print( + ('Action "%s" is not found. ' % args.ref_or_id) + + 'Use "st2 action list" to see the list of available actions.' + ) except Exception as e: - print('ERROR: Unable to print help for action "%s". %s' % - (args.ref_or_id, e)) + print( + 'ERROR: Unable to print help for action "%s". %s' + % (args.ref_or_id, e) + ) else: self.parser.print_help() return True @@ -795,20 +935,20 @@ def _print_param(name, schema): raise ValueError('Missing schema for parameter "%s"' % (name)) wrapper = textwrap.TextWrapper(width=78) - wrapper.initial_indent = ' ' * 4 + wrapper.initial_indent = " " * 4 wrapper.subsequent_indent = wrapper.initial_indent print(wrapper.fill(name)) - wrapper.initial_indent = ' ' * 8 + wrapper.initial_indent = " " * 8 wrapper.subsequent_indent = wrapper.initial_indent - if 'description' in schema and schema['description']: - print(wrapper.fill(schema['description'])) - if 'type' in schema and schema['type']: - print(wrapper.fill('Type: %s' % schema['type'])) - if 'enum' in schema and schema['enum']: - print(wrapper.fill('Enum: %s' % ', '.join(schema['enum']))) - if 'default' in schema and schema['default'] is not None: - print(wrapper.fill('Default: %s' % schema['default'])) - print('') + if "description" in schema and schema["description"]: + print(wrapper.fill(schema["description"])) + if "type" in schema and schema["type"]: + print(wrapper.fill("Type: %s" % schema["type"])) + if "enum" in schema and schema["enum"]: + print(wrapper.fill("Enum: %s" % ", ".join(schema["enum"]))) + if "default" in schema and schema["default"] is not None: + print(wrapper.fill("Default: %s" % schema["default"])) + print("") @staticmethod def _get_params_types(runner, action): @@ -816,19 +956,18 @@ def _get_params_types(runner, action): action_params = action.parameters parameters = copy.copy(runner_params) parameters.update(copy.copy(action_params)) - required = set([k for k, v in six.iteritems(parameters) if v.get('required')]) + required = set([k for k, v in six.iteritems(parameters) if v.get("required")]) def is_immutable(runner_param_meta, action_param_meta): # If runner sets a param as immutable, action cannot override that. - if runner_param_meta.get('immutable', False): + if runner_param_meta.get("immutable", False): return True else: - return action_param_meta.get('immutable', False) + return action_param_meta.get("immutable", False) immutable = set() for param in parameters.keys(): - if is_immutable(runner_params.get(param, {}), - action_params.get(param, {})): + if is_immutable(runner_params.get(param, {}), action_params.get(param, {})): immutable.add(param) required = required - immutable @@ -837,12 +976,12 @@ def is_immutable(runner_param_meta, action_param_meta): return parameters, required, optional, immutable def _format_child_instances(self, children, parent_id): - ''' + """ The goal of this method is to add an indent at every level. This way the WF is represented as a tree structure while in a list. For the right visuals representation the list must be a DF traversal else the idents will end up looking strange. - ''' + """ # apply basic WF formating first. children = format_wf_instances(children) # setup a depth lookup table @@ -856,7 +995,9 @@ def _format_child_instances(self, children, parent_id): parent = None for instance in children: if WF_PREFIX in instance.id: - instance_id = instance.id[instance.id.index(WF_PREFIX) + len(WF_PREFIX):] + instance_id = instance.id[ + instance.id.index(WF_PREFIX) + len(WF_PREFIX) : + ] else: instance_id = instance.id if instance_id == child.parent: @@ -871,26 +1012,28 @@ def _format_child_instances(self, children, parent_id): return result def _format_for_common_representation(self, task): - ''' + """ Formats a task for common representation for action-chain. - ''' + """ # This really needs to be better handled on the back-end but that would be a bigger # change so handling in cli. - context = getattr(task, 'context', None) - if context and 'chain' in context: - task_name_key = 'context.chain.name' - elif context and 'orquesta' in context: - task_name_key = 'context.orquesta.task_name' + context = getattr(task, "context", None) + if context and "chain" in context: + task_name_key = "context.chain.name" + elif context and "orquesta" in context: + task_name_key = "context.orquesta.task_name" # Use Execution as the object so that the formatter lookup does not change. # AKA HACK! - return models.action.Execution(**{ - 'id': task.id, - 'status': task.status, - 'task': jsutil.get_value(vars(task), task_name_key), - 'action': task.action.get('ref', None), - 'start_timestamp': task.start_timestamp, - 'end_timestamp': getattr(task, 'end_timestamp', None) - }) + return models.action.Execution( + **{ + "id": task.id, + "status": task.status, + "task": jsutil.get_value(vars(task), task_name_key), + "action": task.action.get("ref", None), + "start_timestamp": task.start_timestamp, + "end_timestamp": getattr(task, "end_timestamp", None), + } + ) def _sort_parameters(self, parameters, names): """ @@ -899,10 +1042,12 @@ def _sort_parameters(self, parameters, names): :type parameters: ``list`` :type names: ``list`` or ``set`` """ - sorted_parameters = sorted(names, key=lambda name: - self._get_parameter_sort_value( - parameters=parameters, - name=name)) + sorted_parameters = sorted( + names, + key=lambda name: self._get_parameter_sort_value( + parameters=parameters, name=name + ), + ) return sorted_parameters @@ -919,7 +1064,7 @@ def _get_parameter_sort_value(self, parameters, name): if not parameter: return None - sort_value = parameter.get('position', name) + sort_value = parameter.get("position", name) return sort_value def _get_inherited_env_vars(self): @@ -938,44 +1083,76 @@ class ActionRunCommand(ActionRunCommandMixin, resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(ActionRunCommand, self).__init__( - resource, kwargs.pop('name', 'execute'), - 'Invoke an action manually.', - *args, **kwargs) - - self.parser.add_argument('ref_or_id', nargs='?', - metavar='ref-or-id', - help='Action reference (pack.action_name) ' + - 'or ID of the action.') - self.parser.add_argument('parameters', nargs='*', - help='List of keyword args, positional args, ' - 'and optional args for the action.') - - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "execute"), + "Invoke an action manually.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "ref_or_id", + nargs="?", + metavar="ref-or-id", + help="Action reference (pack.action_name) " + "or ID of the action.", + ) + self.parser.add_argument( + "parameters", + nargs="*", + help="List of keyword args, positional args, " + "and optional args for the action.", + ) + + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) self._add_common_options() - if self.name in ['run', 'execute']: - self.parser.add_argument('--trace-tag', '--trace_tag', - help='A trace tag string to track execution later.', - dest='trace_tag', required=False) - self.parser.add_argument('--trace-id', - help='Existing trace id for this execution.', - dest='trace_id', required=False) - self.parser.add_argument('-a', '--async', - action='store_true', dest='action_async', - help='Do not wait for action to finish.') - self.parser.add_argument('-e', '--inherit-env', - action='store_true', dest='inherit_env', - help='Pass all the environment variables ' - 'which are accessible to the CLI as "env" ' - 'parameter to the action. Note: Only works ' - 'with python, local and remote runners.') - self.parser.add_argument('-u', '--user', type=str, default=None, - help='User under which to run the action (admins only).') - - if self.name == 'run': + if self.name in ["run", "execute"]: + self.parser.add_argument( + "--trace-tag", + "--trace_tag", + help="A trace tag string to track execution later.", + dest="trace_tag", + required=False, + ) + self.parser.add_argument( + "--trace-id", + help="Existing trace id for this execution.", + dest="trace_id", + required=False, + ) + self.parser.add_argument( + "-a", + "--async", + action="store_true", + dest="action_async", + help="Do not wait for action to finish.", + ) + self.parser.add_argument( + "-e", + "--inherit-env", + action="store_true", + dest="inherit_env", + help="Pass all the environment variables " + 'which are accessible to the CLI as "env" ' + "parameter to the action. Note: Only works " + "with python, local and remote runners.", + ) + self.parser.add_argument( + "-u", + "--user", + type=str, + default=None, + help="User under which to run the action (admins only).", + ) + + if self.name == "run": self.parser.set_defaults(action_async=False) else: self.parser.set_defaults(action_async=True) @@ -983,22 +1160,27 @@ def __init__(self, resource, *args, **kwargs): @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if not args.ref_or_id: - self.parser.error('Missing action reference or id') + self.parser.error("Missing action reference or id") action = self.get_resource(args.ref_or_id, **kwargs) if not action: - raise resource.ResourceNotFoundError('Action "%s" cannot be found.' - % (args.ref_or_id)) + raise resource.ResourceNotFoundError( + 'Action "%s" cannot be found.' % (args.ref_or_id) + ) - runner_mgr = self.app.client.managers['RunnerType'] + runner_mgr = self.app.client.managers["RunnerType"] runner = runner_mgr.get_by_name(action.runner_type, **kwargs) if not runner: - raise resource.ResourceNotFoundError('Runner type "%s" for action "%s" cannot be \ - found.' % (action.runner_type, action.name)) + raise resource.ResourceNotFoundError( + 'Runner type "%s" for action "%s" cannot be \ + found.' + % (action.runner_type, action.name) + ) - action_ref = '.'.join([action.pack, action.name]) - action_parameters = self._get_action_parameters_from_args(action=action, runner=runner, - args=args) + action_ref = ".".join([action.pack, action.name]) + action_parameters = self._get_action_parameters_from_args( + action=action, runner=runner, args=args + ) execution = models.Execution() execution.action = action_ref @@ -1009,56 +1191,79 @@ def run(self, args, **kwargs): execution.delay = args.delay if not args.trace_id and args.trace_tag: - execution.context = {'trace_context': {'trace_tag': args.trace_tag}} + execution.context = {"trace_context": {"trace_tag": args.trace_tag}} if args.trace_id: - execution.context = {'trace_context': {'id_': args.trace_id}} + execution.context = {"trace_context": {"id_": args.trace_id}} - action_exec_mgr = self.app.client.managers['Execution'] + action_exec_mgr = self.app.client.managers["Execution"] execution = action_exec_mgr.create(execution, **kwargs) - execution = self._get_execution_result(execution=execution, - action_exec_mgr=action_exec_mgr, - args=args, **kwargs) + execution = self._get_execution_result( + execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs + ) return execution class ActionExecutionBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ActionExecutionBranch, self).__init__( - models.Execution, description, app, subparsers, - parent_parser=parent_parser, read_only=True, - commands={'list': ActionExecutionListCommand, - 'get': ActionExecutionGetCommand}) + models.Execution, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, + commands={ + "list": ActionExecutionListCommand, + "get": ActionExecutionGetCommand, + }, + ) # Register extended commands - self.commands['re-run'] = ActionExecutionReRunCommand( - self.resource, self.app, self.subparsers, add_help=False) - self.commands['cancel'] = ActionExecutionCancelCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['pause'] = ActionExecutionPauseCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['resume'] = ActionExecutionResumeCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['tail'] = ActionExecutionTailCommand(self.resource, self.app, - self.subparsers, - add_help=True) - - -POSSIBLE_ACTION_STATUS_VALUES = ('succeeded', 'running', 'scheduled', 'paused', 'failed', - 'canceling', 'canceled') + self.commands["re-run"] = ActionExecutionReRunCommand( + self.resource, self.app, self.subparsers, add_help=False + ) + self.commands["cancel"] = ActionExecutionCancelCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["pause"] = ActionExecutionPauseCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["resume"] = ActionExecutionResumeCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["tail"] = ActionExecutionTailCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + + +POSSIBLE_ACTION_STATUS_VALUES = ( + "succeeded", + "running", + "scheduled", + "paused", + "failed", + "canceling", + "canceled", +) class ActionExecutionListCommand(ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'status', 'start_timestamp', - 'end_timestamp'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "status", + "start_timestamp", + "end_timestamp", + ] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone, - 'end_timestamp': format_isodate_for_user_timezone, - 'parameters': format_parameters, - 'status': format_status + "start_timestamp": format_isodate_for_user_timezone, + "end_timestamp": format_isodate_for_user_timezone, + "parameters": format_parameters, + "status": format_status, } def __init__(self, resource, *args, **kwargs): @@ -1066,83 +1271,133 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(ActionExecutionListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('-s', '--sort', type=str, dest='sort_order', - default='descending', - help=('Sort %s by start timestamp, ' - 'asc|ascending (earliest first) ' - 'or desc|descending (latest first)' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "-s", + "--sort", + type=str, + dest="sort_order", + default="descending", + help=( + "Sort %s by start timestamp, " + "asc|ascending (earliest first) " + "or desc|descending (latest first)" % self.resource_name + ), + ) # Filter options - self.group.add_argument('--action', help='Action reference to filter the list.') - self.group.add_argument('--status', help=('Only return executions with the provided \ - status. Possible values are \'%s\', \'%s\', \ - \'%s\', \'%s\', \'%s\', \'%s\' or \'%s\'' - '.' % POSSIBLE_ACTION_STATUS_VALUES)) - self.group.add_argument('--user', - help='Only return executions created by the provided user.') - self.group.add_argument('--trigger_instance', - help='Trigger instance id to filter the list.') - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return executions with timestamp ' - 'greater than the one provided. ' - 'Use time in the format "2000-01-01T12:00:00.000Z".')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return executions with timestamp ' - 'lower than the one provided. ' - 'Use time in the format "2000-01-01T12:00:00.000Z".')) - self.parser.add_argument('-l', '--showall', action='store_true', - help='') + self.group.add_argument("--action", help="Action reference to filter the list.") + self.group.add_argument( + "--status", + help=( + "Only return executions with the provided \ + status. Possible values are '%s', '%s', \ + '%s', '%s', '%s', '%s' or '%s'" + "." % POSSIBLE_ACTION_STATUS_VALUES + ), + ) + self.group.add_argument( + "--user", help="Only return executions created by the provided user." + ) + self.group.add_argument( + "--trigger_instance", help="Trigger instance id to filter the list." + ) + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return executions with timestamp " + "greater than the one provided. " + 'Use time in the format "2000-01-01T12:00:00.000Z".' + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return executions with timestamp " + "lower than the one provided. " + 'Use time in the format "2000-01-01T12:00:00.000Z".' + ), + ) + self.parser.add_argument("-l", "--showall", action="store_true", help="") # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.action: - kwargs['action'] = args.action + kwargs["action"] = args.action if args.status: - kwargs['status'] = args.status + kwargs["status"] = args.status if args.user: - kwargs['user'] = args.user + kwargs["user"] = args.user if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if not args.showall: # null is the magic string that translates to does not exist. - kwargs['parent'] = 'null' + kwargs["parent"] = "null" if args.timestamp_gt: - kwargs['timestamp_gt'] = args.timestamp_gt + kwargs["timestamp_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['timestamp_lt'] = args.timestamp_lt + kwargs["timestamp_lt"] = args.timestamp_lt if args.sort_order: - if args.sort_order in ['asc', 'ascending']: - kwargs['sort_asc'] = True - elif args.sort_order in ['desc', 'descending']: - kwargs['sort_desc'] = True + if args.sort_order in ["asc", "ascending"]: + kwargs["sort_asc"] = True + elif args.sort_order in ["desc", "descending"]: + kwargs["sort_desc"] = True # We only retrieve attributes which are needed to speed things up include_attributes = self._get_include_attributes(args=args) if include_attributes: - kwargs['include_attributes'] = ','.join(include_attributes) + kwargs["include_attributes"] = ",".join(include_attributes) return self.manager.query_with_count(limit=args.last, **kwargs) @@ -1152,49 +1407,73 @@ def run_and_print(self, args, **kwargs): instances = format_wf_instances(result) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, - yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: # Include elapsed time for running executions instances = format_execution_statuses(instances) - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class ActionExecutionGetCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] - include_attributes = ['action.ref', 'action.runner_type', 'start_timestamp', - 'end_timestamp'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] + include_attributes = [ + "action.ref", + "action.runner_type", + "start_timestamp", + "end_timestamp", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionGetCommand, self).__init__( - resource, 'get', - 'Get individual %s.' % resource.get_display_name().lower(), - *args, **kwargs) + resource, + "get", + "Get individual %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) - self.parser.add_argument('id', - help=('ID of the %s.' % - resource.get_display_name().lower())) + self.parser.add_argument( + "id", help=("ID of the %s." % resource.get_display_name().lower()) + ) self._add_common_options() @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # We only retrieve attributes which are needed to speed things up - include_attributes = self._get_include_attributes(args=args, - extra_attributes=self.include_attributes) + include_attributes = self._get_include_attributes( + args=args, extra_attributes=self.include_attributes + ) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} execution = self.get_resource_by_id(id=args.id, **kwargs) return execution @@ -1209,22 +1488,25 @@ def run_and_print(self, args, **kwargs): execution = format_execution_status(execution) except resource.ResourceNotFoundError: self.print_not_found(args.id) - raise ResourceNotFoundError('Execution with id %s not found.' % (args.id)) + raise ResourceNotFoundError("Execution with id %s not found." % (args.id)) return self._print_execution_details(execution=execution, args=args, **kwargs) class ActionExecutionCancelCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ActionExecutionCancelCommand, self).__init__( - resource, 'cancel', 'Cancel %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help=('IDs of the %ss to cancel.' % - resource.get_display_name().lower())) + resource, + "cancel", + "Cancel %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", + nargs="+", + help=("IDs of the %ss to cancel." % resource.get_display_name().lower()), + ) def run(self, args, **kwargs): responses = [] @@ -1242,16 +1524,23 @@ def run_and_print(self, args, **kwargs): self._print_result(execution_id=execution_id, response=response) def _print_result(self, execution_id, response): - if response and 'faultstring' in response: - message = response.get('faultstring', 'Cancellation requested for %s with id %s.' % - (self.resource.get_display_name().lower(), execution_id)) + if response and "faultstring" in response: + message = response.get( + "faultstring", + "Cancellation requested for %s with id %s." + % (self.resource.get_display_name().lower(), execution_id), + ) elif response: - message = '%s with id %s canceled.' % (self.resource.get_display_name().lower(), - execution_id) + message = "%s with id %s canceled." % ( + self.resource.get_display_name().lower(), + execution_id, + ) else: - message = 'Cannot cancel %s with id %s.' % (self.resource.get_display_name().lower(), - execution_id) + message = "Cannot cancel %s with id %s." % ( + self.resource.get_display_name().lower(), + execution_id, + ) print(message) @@ -1259,35 +1548,58 @@ class ActionExecutionReRunCommand(ActionRunCommandMixin, resource.ResourceComman def __init__(self, resource, *args, **kwargs): super(ActionExecutionReRunCommand, self).__init__( - resource, kwargs.pop('name', 're-run'), - 'Re-run a particular action.', - *args, **kwargs) - - self.parser.add_argument('id', nargs='?', - metavar='id', - help='ID of action execution to re-run ') - self.parser.add_argument('parameters', nargs='*', - help='List of keyword args, positional args, ' - 'and optional args for the action.') - self.parser.add_argument('--tasks', nargs='*', - help='Name of the workflow tasks to re-run.') - self.parser.add_argument('--no-reset', dest='no_reset', nargs='*', - help='Name of the with-items tasks to not reset. This only ' - 'applies to Orquesta workflows. By default, all iterations ' - 'for with-items tasks is rerun. If no reset, only failed ' - ' iterations are rerun.') - self.parser.add_argument('-a', '--async', - action='store_true', dest='action_async', - help='Do not wait for action to finish.') - self.parser.add_argument('-e', '--inherit-env', - action='store_true', dest='inherit_env', - help='Pass all the environment variables ' - 'which are accessible to the CLI as "env" ' - 'parameter to the action. Note: Only works ' - 'with python, local and remote runners.') - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "re-run"), + "Re-run a particular action.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "id", nargs="?", metavar="id", help="ID of action execution to re-run " + ) + self.parser.add_argument( + "parameters", + nargs="*", + help="List of keyword args, positional args, " + "and optional args for the action.", + ) + self.parser.add_argument( + "--tasks", nargs="*", help="Name of the workflow tasks to re-run." + ) + self.parser.add_argument( + "--no-reset", + dest="no_reset", + nargs="*", + help="Name of the with-items tasks to not reset. This only " + "applies to Orquesta workflows. By default, all iterations " + "for with-items tasks is rerun. If no reset, only failed " + " iterations are rerun.", + ) + self.parser.add_argument( + "-a", + "--async", + action="store_true", + dest="action_async", + help="Do not wait for action to finish.", + ) + self.parser.add_argument( + "-e", + "--inherit-env", + action="store_true", + dest="inherit_env", + help="Pass all the environment variables " + 'which are accessible to the CLI as "env" ' + "parameter to the action. Note: Only works " + "with python, local and remote runners.", + ) + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) self._add_common_options() @add_auth_token_to_kwargs_from_cli @@ -1295,47 +1607,63 @@ def run(self, args, **kwargs): existing_execution = self.manager.get_by_id(args.id, **kwargs) if not existing_execution: - raise resource.ResourceNotFoundError('Action execution with id "%s" cannot be found.' % - (args.id)) + raise resource.ResourceNotFoundError( + 'Action execution with id "%s" cannot be found.' % (args.id) + ) - action_mgr = self.app.client.managers['Action'] - runner_mgr = self.app.client.managers['RunnerType'] - action_exec_mgr = self.app.client.managers['Execution'] + action_mgr = self.app.client.managers["Action"] + runner_mgr = self.app.client.managers["RunnerType"] + action_exec_mgr = self.app.client.managers["Execution"] - action_ref = existing_execution.action['ref'] + action_ref = existing_execution.action["ref"] action = action_mgr.get_by_ref_or_id(action_ref) runner = runner_mgr.get_by_name(action.runner_type) - action_parameters = self._get_action_parameters_from_args(action=action, runner=runner, - args=args) + action_parameters = self._get_action_parameters_from_args( + action=action, runner=runner, args=args + ) - execution = action_exec_mgr.re_run(execution_id=args.id, - parameters=action_parameters, - tasks=args.tasks, - no_reset=args.no_reset, - delay=args.delay if args.delay else 0, - **kwargs) + execution = action_exec_mgr.re_run( + execution_id=args.id, + parameters=action_parameters, + tasks=args.tasks, + no_reset=args.no_reset, + delay=args.delay if args.delay else 0, + **kwargs, + ) - execution = self._get_execution_result(execution=execution, - action_exec_mgr=action_exec_mgr, - args=args, **kwargs) + execution = self._get_execution_result( + execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs + ) return execution class ActionExecutionPauseCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionPauseCommand, self).__init__( - resource, 'pause', 'Pause %s (workflow executions only).' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help='ID of action execution to pause.') + resource, + "pause", + "Pause %s (workflow executions only)." + % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", nargs="+", help="ID of action execution to pause." + ) self._add_common_options() @@ -1348,7 +1676,9 @@ def run(self, args, **kwargs): responses.append([execution_id, response]) except resource.ResourceNotFoundError: self.print_not_found(args.ids) - raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id)) + raise ResourceNotFoundError( + "Execution with id %s not found." % (execution_id) + ) return responses @@ -1367,18 +1697,30 @@ def _print_result(self, args, execution_id, execution, **kwargs): class ActionExecutionResumeCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionResumeCommand, self).__init__( - resource, 'resume', 'Resume %s (workflow executions only).' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help='ID of action execution to resume.') + resource, + "resume", + "Resume %s (workflow executions only)." + % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", nargs="+", help="ID of action execution to resume." + ) self._add_common_options() @@ -1391,7 +1733,9 @@ def run(self, args, **kwargs): responses.append([execution_id, response]) except resource.ResourceNotFoundError: self.print_not_found(execution_id) - raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id)) + raise ResourceNotFoundError( + "Execution with id %s not found." % (execution_id) + ) return responses @@ -1412,22 +1756,33 @@ def _print_result(self, args, execution, **kwargs): class ActionExecutionTailCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(ActionExecutionTailCommand, self).__init__( - resource, kwargs.pop('name', 'tail'), - 'Tail output of a particular execution.', - *args, **kwargs) - - self.parser.add_argument('id', nargs='?', - metavar='id', - default='last', - help='ID of action execution to tail.') - self.parser.add_argument('--type', dest='output_type', action='store', - help=('Type of output to tail for. If not provided, ' - 'defaults to all.')) - self.parser.add_argument('--include-metadata', dest='include_metadata', - action='store_true', - default=False, - help=('Include metadata (timestamp, output type) with the ' - 'output.')) + resource, + kwargs.pop("name", "tail"), + "Tail output of a particular execution.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "id", + nargs="?", + metavar="id", + default="last", + help="ID of action execution to tail.", + ) + self.parser.add_argument( + "--type", + dest="output_type", + action="store", + help=("Type of output to tail for. If not provided, " "defaults to all."), + ) + self.parser.add_argument( + "--include-metadata", + dest="include_metadata", + action="store_true", + default=False, + help=("Include metadata (timestamp, output type) with the " "output."), + ) def run(self, args, **kwargs): pass @@ -1435,45 +1790,55 @@ def run(self, args, **kwargs): @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): execution_id = args.id - output_type = getattr(args, 'output_type', None) + output_type = getattr(args, "output_type", None) include_metadata = args.include_metadata # Special case for id "last" - if execution_id == 'last': + if execution_id == "last": executions = self.manager.query(limit=1) if executions: execution = executions[0] execution_id = execution.id else: - print('No executions found in db.') + print("No executions found in db.") return else: execution = self.manager.get_by_id(execution_id, **kwargs) if not execution: - raise ResourceNotFoundError('Execution with id %s not found.' % (args.id)) + raise ResourceNotFoundError("Execution with id %s not found." % (args.id)) execution_manager = self.manager - stream_manager = self.app.client.managers['Stream'] - ActionExecutionTailCommand.tail_execution(execution=execution, - execution_manager=execution_manager, - stream_manager=stream_manager, - output_type=output_type, - include_metadata=include_metadata, - **kwargs) + stream_manager = self.app.client.managers["Stream"] + ActionExecutionTailCommand.tail_execution( + execution=execution, + execution_manager=execution_manager, + stream_manager=stream_manager, + output_type=output_type, + include_metadata=include_metadata, + **kwargs, + ) @classmethod - def tail_execution(cls, execution_manager, stream_manager, execution, output_type=None, - include_metadata=False, **kwargs): + def tail_execution( + cls, + execution_manager, + stream_manager, + execution, + output_type=None, + include_metadata=False, + **kwargs, + ): execution_id = str(execution.id) # Indicates if the execution we are tailing is a child execution in a workflow context = cls.get_normalized_context_execution_task_event(execution.__dict__) - has_parent_attribute = bool(getattr(execution, 'parent', None)) - has_parent_execution_id = bool(context['parent_execution_id']) + has_parent_attribute = bool(getattr(execution, "parent", None)) + has_parent_execution_id = bool(context["parent_execution_id"]) - is_tailing_execution_child_execution = bool(has_parent_attribute or - has_parent_execution_id) + is_tailing_execution_child_execution = bool( + has_parent_attribute or has_parent_execution_id + ) # Note: For non-workflow actions child_execution_id always matches parent_execution_id so # we don't need to do any other checks to determine if executions represents a workflow @@ -1484,10 +1849,14 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ # NOTE: This doesn't recurse down into child executions if user is tailing a workflow # execution if execution.status in LIVEACTION_COMPLETED_STATES: - output = execution_manager.get_output(execution_id=execution_id, - output_type=output_type) + output = execution_manager.get_output( + execution_id=execution_id, output_type=output_type + ) print(output) - print('Execution %s has completed (status=%s).' % (execution_id, execution.status)) + print( + "Execution %s has completed (status=%s)." + % (execution_id, execution.status) + ) return # We keep track of all the workflow executions which could contain children. @@ -1497,29 +1866,27 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ # Retrieve parent execution object so we can keep track of any existing children # executions (only applies to already running executions). - filters = { - 'params': { - 'include_attributes': 'id,children' - } - } + filters = {"params": {"include_attributes": "id,children"}} execution = execution_manager.get_by_id(id=execution_id, **filters) - children_execution_ids = getattr(execution, 'children', []) + children_execution_ids = getattr(execution, "children", []) workflow_execution_ids.update(children_execution_ids) - events = ['st2.execution__update', 'st2.execution.output__create'] - for event in stream_manager.listen(events, - end_execution_id=execution_id, - end_event="st2.execution__update", - **kwargs): - status = event.get('status', None) + events = ["st2.execution__update", "st2.execution.output__create"] + for event in stream_manager.listen( + events, + end_execution_id=execution_id, + end_event="st2.execution__update", + **kwargs, + ): + status = event.get("status", None) is_execution_event = status is not None if is_execution_event: context = cls.get_normalized_context_execution_task_event(event) - task_execution_id = context['execution_id'] - task_name = context['task_name'] - task_parent_execution_id = context['parent_execution_id'] + task_execution_id = context["execution_id"] + task_name = context["task_name"] + task_parent_execution_id = context["parent_execution_id"] # An execution is considered a child execution if it has parent execution id is_child_execution = bool(task_parent_execution_id) @@ -1536,14 +1903,18 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ if is_child_execution: if status == LIVEACTION_STATUS_RUNNING: - print('Child execution (task=%s) %s has started.' % (task_name, - task_execution_id)) - print('') + print( + "Child execution (task=%s) %s has started." + % (task_name, task_execution_id) + ) + print("") continue elif status in LIVEACTION_COMPLETED_STATES: - print('') - print('Child execution (task=%s) %s has finished (status=%s).' % - (task_name, task_execution_id, status)) + print("") + print( + "Child execution (task=%s) %s has finished (status=%s)." + % (task_name, task_execution_id, status) + ) if is_tailing_execution_child_execution: # User is tailing a child execution inside a workflow, stop the command. @@ -1556,56 +1927,69 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ else: # NOTE: In some situations execution update event with "running" status is # dispatched twice so we ignore any duplicated events - if status == LIVEACTION_STATUS_RUNNING and not event.get('children', []): - print('Execution %s has started.' % (execution_id)) - print('') + if status == LIVEACTION_STATUS_RUNNING and not event.get( + "children", [] + ): + print("Execution %s has started." % (execution_id)) + print("") continue elif status in LIVEACTION_COMPLETED_STATES: # Bail out once parent execution has finished - print('') - print('Execution %s has completed (status=%s).' % (execution_id, status)) + print("") + print( + "Execution %s has completed (status=%s)." + % (execution_id, status) + ) break else: # We don't care about other execution events continue # Ignore events for executions which don't belong to the one we are tailing - event_execution_id = event['execution_id'] + event_execution_id = event["execution_id"] if event_execution_id not in workflow_execution_ids: continue # Filter on output_type if provided - event_output_type = event.get('output_type', None) - if output_type != 'all' and output_type and (event_output_type != output_type): + event_output_type = event.get("output_type", None) + if ( + output_type != "all" + and output_type + and (event_output_type != output_type) + ): continue if include_metadata: - sys.stdout.write('[%s][%s] %s' % (event['timestamp'], event['output_type'], - event['data'])) + sys.stdout.write( + "[%s][%s] %s" + % (event["timestamp"], event["output_type"], event["data"]) + ) else: - sys.stdout.write(event['data']) + sys.stdout.write(event["data"]) @classmethod def get_normalized_context_execution_task_event(cls, event): """ Return a dictionary with normalized context attributes for execution event or object. """ - context = event.get('context', {}) - - result = { - 'parent_execution_id': None, - 'execution_id': None, - 'task_name': None - } - - if 'orquesta' in context: - result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None) - result['execution_id'] = event['id'] - result['task_name'] = context.get('orquesta', {}).get('task_name', 'unknown') + context = event.get("context", {}) + + result = {"parent_execution_id": None, "execution_id": None, "task_name": None} + + if "orquesta" in context: + result["parent_execution_id"] = context.get("parent", {}).get( + "execution_id", None + ) + result["execution_id"] = event["id"] + result["task_name"] = context.get("orquesta", {}).get( + "task_name", "unknown" + ) else: # Action chain workflow - result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None) - result['execution_id'] = event['id'] - result['task_name'] = context.get('chain', {}).get('name', 'unknown') + result["parent_execution_id"] = context.get("parent", {}).get( + "execution_id", None + ) + result["execution_id"] = event["id"] + result["task_name"] = context.get("chain", {}).get("name", "unknown") return result diff --git a/st2client/st2client/commands/action_alias.py b/st2client/st2client/commands/action_alias.py index 32a65776cc..d6f5fbcfc1 100644 --- a/st2client/st2client/commands/action_alias.py +++ b/st2client/st2client/commands/action_alias.py @@ -22,63 +22,87 @@ from st2client.formatters import table -__all__ = [ - 'ActionAliasBranch', - 'ActionAliasMatchCommand', - 'ActionAliasExecuteCommand' -] +__all__ = ["ActionAliasBranch", "ActionAliasMatchCommand", "ActionAliasExecuteCommand"] class ActionAliasBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ActionAliasBranch, self).__init__( - ActionAlias, description, app, subparsers, - parent_parser=parent_parser, read_only=False, - commands={ - 'list': ActionAliasListCommand, - 'get': ActionAliasGetCommand - }) - - self.commands['match'] = ActionAliasMatchCommand( - self.resource, self.app, self.subparsers, - add_help=True) - self.commands['execute'] = ActionAliasExecuteCommand( - self.resource, self.app, self.subparsers, - add_help=True) + ActionAlias, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=False, + commands={"list": ActionAliasListCommand, "get": ActionAliasGetCommand}, + ) + + self.commands["match"] = ActionAliasMatchCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["execute"] = ActionAliasExecuteCommand( + self.resource, self.app, self.subparsers, add_help=True + ) class ActionAliasListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] class ActionAliasGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'action_ref', 'formats'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "action_ref", + "formats", + ] class ActionAliasMatchCommand(resource.ResourceCommand): - display_attributes = ['name', 'description'] + display_attributes = ["name", "description"] def __init__(self, resource, *args, **kwargs): super(ActionAliasMatchCommand, self).__init__( - resource, 'match', - 'Get the %s that match the command text.' % - resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('match_text', - metavar='command', - help=('Get the %s that match the command text.' % - resource.get_display_name().lower())) - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + resource, + "match", + "Get the %s that match the command text." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "match_text", + metavar="command", + help=( + "Get the %s that match the command text." + % resource.get_display_name().lower() + ), + ) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -90,40 +114,62 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ActionAliasExecuteCommand(resource.ResourceCommand): - display_attributes = ['name'] + display_attributes = ["name"] def __init__(self, resource, *args, **kwargs): super(ActionAliasExecuteCommand, self).__init__( - resource, 'execute', - ('Execute the command text by finding a matching %s.' % - resource.get_display_name().lower()), *args, **kwargs) - - self.parser.add_argument('command_text', - metavar='command', - help=('Execute the command text by finding a matching %s.' % - resource.get_display_name().lower())) - self.parser.add_argument('-u', '--user', type=str, default=None, - help='User under which to run the action (admins only).') + resource, + "execute", + ( + "Execute the command text by finding a matching %s." + % resource.get_display_name().lower() + ), + *args, + **kwargs, + ) + + self.parser.add_argument( + "command_text", + metavar="command", + help=( + "Execute the command text by finding a matching %s." + % resource.get_display_name().lower() + ), + ) + self.parser.add_argument( + "-u", + "--user", + type=str, + default=None, + help="User under which to run the action (admins only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): payload = core.Resource() payload.command = args.command_text payload.user = args.user or "" - payload.source_channel = 'cli' + payload.source_channel = "cli" - alias_execution_mgr = self.app.client.managers['ActionAliasExecution'] + alias_execution_mgr = self.app.client.managers["ActionAliasExecution"] execution = alias_execution_mgr.match_and_execute(payload) return execution def run_and_print(self, args, **kwargs): execution = self.run(args, **kwargs) - print("Matching Action-alias: '%s'" % execution.actionalias['ref']) - print("To get the results, execute:\n st2 execution get %s" % - (execution.execution['id'])) + print("Matching Action-alias: '%s'" % execution.actionalias["ref"]) + print( + "To get the results, execute:\n st2 execution get %s" + % (execution.execution["id"]) + ) diff --git a/st2client/st2client/commands/auth.py b/st2client/st2client/commands/auth.py index 40066d1a5e..5b0507f324 100644 --- a/st2client/st2client/commands/auth.py +++ b/st2client/st2client/commands/auth.py @@ -39,36 +39,54 @@ class TokenCreateCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(TokenCreateCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Authenticate user and acquire access token.', - *args, **kwargs) - - self.parser.add_argument('username', - help='Name of the user to authenticate.') - - self.parser.add_argument('-p', '--password', dest='password', - help='Password for the user. If password is not provided, ' - 'it will be prompted for.') - self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None, - help='The life span of the token in seconds. ' - 'Max TTL configured by the admin supersedes this.') - self.parser.add_argument('-t', '--only-token', action='store_true', dest='only_token', - default=False, - help='On successful authentication, print only token to the ' - 'console.') + resource, + kwargs.pop("name", "create"), + "Authenticate user and acquire access token.", + *args, + **kwargs, + ) + + self.parser.add_argument("username", help="Name of the user to authenticate.") + + self.parser.add_argument( + "-p", + "--password", + dest="password", + help="Password for the user. If password is not provided, " + "it will be prompted for.", + ) + self.parser.add_argument( + "-l", + "--ttl", + type=int, + dest="ttl", + default=None, + help="The life span of the token in seconds. " + "Max TTL configured by the admin supersedes this.", + ) + self.parser.add_argument( + "-t", + "--only-token", + action="store_true", + dest="only_token", + default=False, + help="On successful authentication, print only token to the " "console.", + ) def run(self, args, **kwargs): if not args.password: args.password = getpass.getpass() instance = self.resource(ttl=args.ttl) if args.ttl else self.resource() - return self.manager.create(instance, auth=(args.username, args.password), **kwargs) + return self.manager.create( + instance, auth=(args.username, args.password), **kwargs + ) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) @@ -76,35 +94,57 @@ def run_and_print(self, args, **kwargs): if args.only_token: print(instance.token) else: - self.print_output(instance, table.PropertyValueTable, - attributes=self.display_attributes, json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=self.display_attributes, + json=args.json, + yaml=args.yaml, + ) class LoginCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(LoginCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Authenticate user, acquire access token, and update CLI config directory', - *args, **kwargs) - - self.parser.add_argument('username', - help='Name of the user to authenticate.') - - self.parser.add_argument('-p', '--password', dest='password', - help='Password for the user. If password is not provided, ' - 'it will be prompted for.') - self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None, - help='The life span of the token in seconds. ' - 'Max TTL configured by the admin supersedes this.') - self.parser.add_argument('-w', '--write-password', action='store_true', default=False, - dest='write_password', - help='Write the password in plain text to the config file ' - '(default is to omit it)') + resource, + kwargs.pop("name", "create"), + "Authenticate user, acquire access token, and update CLI config directory", + *args, + **kwargs, + ) + + self.parser.add_argument("username", help="Name of the user to authenticate.") + + self.parser.add_argument( + "-p", + "--password", + dest="password", + help="Password for the user. If password is not provided, " + "it will be prompted for.", + ) + self.parser.add_argument( + "-l", + "--ttl", + type=int, + dest="ttl", + default=None, + help="The life span of the token in seconds. " + "Max TTL configured by the admin supersedes this.", + ) + self.parser.add_argument( + "-w", + "--write-password", + action="store_true", + default=False, + dest="write_password", + help="Write the password in plain text to the config file " + "(default is to omit it)", + ) def run(self, args, **kwargs): @@ -122,7 +162,9 @@ def run(self, args, **kwargs): config_file = config_parser.ST2_CONFIG_PATH # Retrieve token - manager = self.manager.create(instance, auth=(args.username, args.password), **kwargs) + manager = self.manager.create( + instance, auth=(args.username, args.password), **kwargs + ) cli._cache_auth_token(token_obj=manager) # Update existing configuration with new credentials @@ -130,18 +172,18 @@ def run(self, args, **kwargs): config.read(config_file) # Modify config (and optionally populate with password) - if not config.has_section('credentials'): - config.add_section('credentials') + if not config.has_section("credentials"): + config.add_section("credentials") - config.set('credentials', 'username', args.username) + config.set("credentials", "username", args.username) if args.write_password: - config.set('credentials', 'password', args.password) + config.set("credentials", "password", args.password) else: # Remove any existing password from config - config.remove_option('credentials', 'password') + config.remove_option("credentials", "password") config_existed = os.path.exists(config_file) - with open(config_file, 'w') as cfg_file_out: + with open(config_file, "w") as cfg_file_out: config.write(cfg_file_out) # If we created the config file, correct the permissions if not config_existed: @@ -156,35 +198,44 @@ def run_and_print(self, args, **kwargs): if self.app.client.debug: raise - raise Exception('Failed to log in as %s: %s' % (args.username, six.text_type(e))) + raise Exception( + "Failed to log in as %s: %s" % (args.username, six.text_type(e)) + ) - print('Logged in as %s' % (args.username)) + print("Logged in as %s" % (args.username)) if not args.write_password: # Note: Client can't depend and import from common so we need to hard-code this # default value token_expire_hours = 24 - print('') - print('Note: You didn\'t use --write-password option so the password hasn\'t been ' - 'stored in the client config and you will need to login again in %s hours when ' - 'the auth token expires.' % (token_expire_hours)) - print('As an alternative, you can run st2 login command with the "--write-password" ' - 'flag, but keep it mind this will cause it to store the password in plain-text ' - 'in the client config file (~/.st2/config).') + print("") + print( + "Note: You didn't use --write-password option so the password hasn't been " + "stored in the client config and you will need to login again in %s hours when " + "the auth token expires." % (token_expire_hours) + ) + print( + 'As an alternative, you can run st2 login command with the "--write-password" ' + "flag, but keep it mind this will cause it to store the password in plain-text " + "in the client config file (~/.st2/config)." + ) class WhoamiCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(WhoamiCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Display the currently authenticated user', - *args, **kwargs) + resource, + kwargs.pop("name", "create"), + "Display the currently authenticated user", + *args, + **kwargs, + ) def run(self, args, **kwargs): user_info = self.app.client.get_user_info(**kwargs) @@ -194,119 +245,157 @@ def run_and_print(self, args, **kwargs): try: user_info = self.run(args, **kwargs) except Exception as e: - response = getattr(e, 'response', None) - status_code = getattr(response, 'status_code', None) - is_unathorized_error = (status_code == http_client.UNAUTHORIZED) + response = getattr(e, "response", None) + status_code = getattr(response, "status_code", None) + is_unathorized_error = status_code == http_client.UNAUTHORIZED if response and is_unathorized_error: - print('Not authenticated') + print("Not authenticated") else: - print('Unable to retrieve currently logged-in user') + print("Unable to retrieve currently logged-in user") if self.app.client.debug: raise return - print('Currently logged in as "%s".' % (user_info['username'])) - print('') - print('Authentication method: %s' % (user_info['authentication']['method'])) + print('Currently logged in as "%s".' % (user_info["username"])) + print("") + print("Authentication method: %s" % (user_info["authentication"]["method"])) - if user_info['authentication']['method'] == 'authentication token': - print('Authentication token expire time: %s' % - (user_info['authentication']['token_expire'])) + if user_info["authentication"]["method"] == "authentication token": + print( + "Authentication token expire time: %s" + % (user_info["authentication"]["token_expire"]) + ) - print('') - print('RBAC:') - print(' - Enabled: %s' % (user_info['rbac']['enabled'])) - print(' - Roles: %s' % (', '.join(user_info['rbac']['roles']))) + print("") + print("RBAC:") + print(" - Enabled: %s" % (user_info["rbac"]["enabled"])) + print(" - Roles: %s" % (", ".join(user_info["rbac"]["roles"]))) class ApiKeyBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ApiKeyBranch, self).__init__( - models.ApiKey, description, app, subparsers, + models.ApiKey, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': ApiKeyListCommand, - 'get': ApiKeyGetCommand, - 'create': ApiKeyCreateCommand, - 'update': NoopCommand, - 'delete': ApiKeyDeleteCommand - }) - - self.commands['enable'] = ApiKeyEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = ApiKeyDisableCommand(self.resource, self.app, self.subparsers) - self.commands['load'] = ApiKeyLoadCommand(self.resource, self.app, self.subparsers) + "list": ApiKeyListCommand, + "get": ApiKeyGetCommand, + "create": ApiKeyCreateCommand, + "update": NoopCommand, + "delete": ApiKeyDeleteCommand, + }, + ) + + self.commands["enable"] = ApiKeyEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = ApiKeyDisableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["load"] = ApiKeyLoadCommand( + self.resource, self.app, self.subparsers + ) class ApiKeyListCommand(resource.ResourceListCommand): - detail_display_attributes = ['all'] - display_attributes = ['id', 'user', 'metadata'] + detail_display_attributes = ["all"] + display_attributes = ["id", "user", "metadata"] def __init__(self, resource, *args, **kwargs): super(ApiKeyListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-u', '--user', type=str, - help='Only return ApiKeys belonging to the provided user') - self.parser.add_argument('-d', '--detail', action='store_true', - help='Full list of attributes.') - self.parser.add_argument('--show-secrets', action='store_true', - help='Full list of attributes.') + self.parser.add_argument( + "-u", + "--user", + type=str, + help="Only return ApiKeys belonging to the provided user", + ) + self.parser.add_argument( + "-d", "--detail", action="store_true", help="Full list of attributes." + ) + self.parser.add_argument( + "--show-secrets", action="store_true", help="Full list of attributes." + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): filters = {} - filters['user'] = args.user + filters["user"] = args.user filters.update(**kwargs) # show_secrets is not a filter but a query param. There is some special # handling for filters in the get method which reuqires this odd hack. if args.show_secrets: - params = filters.get('params', {}) - params['show_secrets'] = True - filters['params'] = params + params = filters.get("params", {}) + params["show_secrets"] = True + filters["params"] = params return self.manager.get_all(**filters) def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) attr = self.detail_display_attributes if args.detail else args.attr - self.print_output(instances, table.MultiColumnTable, - attributes=attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ApiKeyGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'user', 'metadata'] + display_attributes = ["all"] + attribute_display_order = ["id", "user", "metadata"] - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyCreateCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ApiKeyCreateCommand, self).__init__( - resource, 'create', 'Create a new %s.' % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('-u', '--user', type=str, - help='User for which to create API Keys.', - default='') - self.parser.add_argument('-m', '--metadata', type=json.loads, - help='Optional metadata to associate with the API Keys.', - default={}) - self.parser.add_argument('-k', '--only-key', action='store_true', dest='only_key', - default=False, - help='Only print API Key to the console on creation.') + resource, + "create", + "Create a new %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "-u", + "--user", + type=str, + help="User for which to create API Keys.", + default="", + ) + self.parser.add_argument( + "-m", + "--metadata", + type=json.loads, + help="Optional metadata to associate with the API Keys.", + default={}, + ) + self.parser.add_argument( + "-k", + "--only-key", + action="store_true", + dest="only_key", + default=False, + help="Only print API Key to the console on creation.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): data = {} if args.user: - data['user'] = args.user + data["user"] = args.user if args.metadata: - data['metadata'] = args.metadata + data["metadata"] = args.metadata instance = self.resource.deserialize(data) return self.manager.create(instance, **kwargs) @@ -314,39 +403,59 @@ def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') + raise Exception("Server did not create instance.") except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) if args.only_key: print(instance.key) else: - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ApiKeyLoadCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ApiKeyLoadCommand, self).__init__( - resource, 'load', 'Load %s from a file.' % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s(s) to load.' - % resource.get_display_name().lower()), - default='') - - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + resource, + "load", + "Load %s from a file." % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s(s) to load." + % resource.get_display_name().lower() + ), + default="", + ) + + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resources = resource.load_meta_file(args.file) if not resources: - print('No %s found in %s.' % (self.resource.get_display_name().lower(), args.file)) + print( + "No %s found in %s." + % (self.resource.get_display_name().lower(), args.file) + ) return None if not isinstance(resources, list): resources = [resources] @@ -354,14 +463,14 @@ def run(self, args, **kwargs): for res in resources: # pick only the meaningful properties. data = { - 'user': res['user'], # required - 'key_hash': res['key_hash'], # required - 'metadata': res.get('metadata', {}), - 'enabled': res.get('enabled', False) + "user": res["user"], # required + "key_hash": res["key_hash"], # required + "metadata": res.get("metadata", {}), + "enabled": res.get("enabled", False), } - if 'id' in res: - data['id'] = res['id'] + if "id" in res: + data["id"] = res["id"] instance = self.resource.deserialize(data) @@ -381,19 +490,23 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) if instances: - self.print_output(instances, table.MultiColumnTable, - attributes=ApiKeyListCommand.display_attributes, - widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=ApiKeyListCommand.display_attributes, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ApiKeyDeleteCommand(resource.ResourceDeleteCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyEnableCommand(resource.ResourceEnableCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyDisableCommand(resource.ResourceDisableCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK diff --git a/st2client/st2client/commands/inquiry.py b/st2client/st2client/commands/inquiry.py index 250c86b3d7..d9395a5c54 100644 --- a/st2client/st2client/commands/inquiry.py +++ b/st2client/st2client/commands/inquiry.py @@ -25,60 +25,81 @@ LOG = logging.getLogger(__name__) -DEFAULT_SCOPE = 'system' +DEFAULT_SCOPE = "system" class InquiryBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(InquiryBranch, self).__init__( - Inquiry, description, app, subparsers, - parent_parser=parent_parser, read_only=True, - commands={'list': InquiryListCommand, - 'get': InquiryGetCommand}) + Inquiry, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, + commands={"list": InquiryListCommand, "get": InquiryGetCommand}, + ) # Register extended commands - self.commands['respond'] = InquiryRespondCommand( - self.resource, self.app, self.subparsers) + self.commands["respond"] = InquiryRespondCommand( + self.resource, self.app, self.subparsers + ) class InquiryListCommand(resource.ResourceCommand): # Omitting "schema" and "response", as it doesn't really show up in a table well. # The user can drill into a specific Inquiry to get this - display_attributes = [ - 'id', - 'roles', - 'users', - 'route', - 'ttl' - ] + display_attributes = ["id", "roles", "users", "route", "ttl"] def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(InquiryListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -87,17 +108,21 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, - yaml=args.yaml) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class InquiryGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'id' - display_attributes = ['id', 'roles', 'users', 'route', 'ttl', 'schema'] + pk_argument_name = "id" + display_attributes = ["id", "roles", "users", "route", "ttl", "schema"] def __init__(self, kv_resource, *args, **kwargs): super(InquiryGetCommand, self).__init__(kv_resource, *args, **kwargs) @@ -109,22 +134,28 @@ def run(self, args, **kwargs): class InquiryRespondCommand(resource.ResourceCommand): - display_attributes = ['id', 'response'] + display_attributes = ["id", "response"] def __init__(self, resource, *args, **kwargs): super(InquiryRespondCommand, self).__init__( - resource, 'respond', - 'Respond to an %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "respond", + "Respond to an %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) - self.parser.add_argument('id', - metavar='id', - help='Inquiry ID') - self.parser.add_argument('-r', '--response', type=str, dest='response', - default=None, - help=('Entire response payload as JSON string ' - '(bypass interactive mode)')) + self.parser.add_argument("id", metavar="id", help="Inquiry ID") + self.parser.add_argument( + "-r", + "--response", + type=str, + dest="response", + default=None, + help=( + "Entire response payload as JSON string " "(bypass interactive mode)" + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -135,12 +166,13 @@ def run(self, args, **kwargs): instance.response = json.loads(args.response) else: response = InteractiveForm( - inquiry.schema.get('properties')).initiate_dialog() + inquiry.schema.get("properties") + ).initiate_dialog() instance.response = response - return self.manager.respond(inquiry_id=instance.id, - inquiry_response=instance.response, - **kwargs) + return self.manager.respond( + inquiry_id=instance.id, inquiry_response=instance.response, **kwargs + ) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) diff --git a/st2client/st2client/commands/keyvalue.py b/st2client/st2client/commands/keyvalue.py index 9c0d06f806..e87f6afa35 100644 --- a/st2client/st2client/commands/keyvalue.py +++ b/st2client/st2client/commands/keyvalue.py @@ -31,83 +31,125 @@ LOG = logging.getLogger(__name__) -DEFAULT_LIST_SCOPE = 'all' -DEFAULT_GET_SCOPE = 'system' -DEFAULT_CUD_SCOPE = 'system' +DEFAULT_LIST_SCOPE = "all" +DEFAULT_GET_SCOPE = "system" +DEFAULT_CUD_SCOPE = "system" class KeyValuePairBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(KeyValuePairBranch, self).__init__( - KeyValuePair, description, app, subparsers, + KeyValuePair, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': KeyValuePairListCommand, - 'get': KeyValuePairGetCommand, - 'delete': KeyValuePairDeleteCommand, - 'create': NoopCommand, - 'update': NoopCommand - }) + "list": KeyValuePairListCommand, + "get": KeyValuePairGetCommand, + "delete": KeyValuePairDeleteCommand, + "create": NoopCommand, + "update": NoopCommand, + }, + ) # Registers extended commands - self.commands['set'] = KeyValuePairSetCommand(self.resource, self.app, - self.subparsers) - self.commands['load'] = KeyValuePairLoadCommand( - self.resource, self.app, self.subparsers) - self.commands['delete_by_prefix'] = KeyValuePairDeleteByPrefixCommand( - self.resource, self.app, self.subparsers) + self.commands["set"] = KeyValuePairSetCommand( + self.resource, self.app, self.subparsers + ) + self.commands["load"] = KeyValuePairLoadCommand( + self.resource, self.app, self.subparsers + ) + self.commands["delete_by_prefix"] = KeyValuePairDeleteByPrefixCommand( + self.resource, self.app, self.subparsers + ) # Remove unsupported commands # TODO: Refactor parent class and make it nicer - del self.commands['create'] - del self.commands['update'] + del self.commands["create"] + del self.commands["update"] class KeyValuePairListCommand(resource.ResourceTableCommand): - display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'user', - 'expire_timestamp'] + display_attributes = [ + "name", + "value", + "secret", + "encrypted", + "scope", + "user", + "expire_timestamp", + ] attribute_transform_functions = { - 'expire_timestamp': format_isodate_for_user_timezone, + "expire_timestamp": format_isodate_for_user_timezone, } def __init__(self, resource, *args, **kwargs): self.default_limit = 50 - super(KeyValuePairListCommand, self).__init__(resource, 'list', - 'Get the list of the %s most recent %s.' % - (self.default_limit, - resource.get_plural_display_name().lower()), - *args, **kwargs) + super(KeyValuePairListCommand, self).__init__( + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() # Filter options - self.parser.add_argument('--prefix', help=('Only return values with names starting with ' - 'the provided prefix.')) - self.parser.add_argument('-d', '--decrypt', action='store_true', - help='Decrypt secrets and displays plain text.') - self.parser.add_argument('-s', '--scope', default=DEFAULT_LIST_SCOPE, dest='scope', - help='Scope item is under. Example: "user".') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "--prefix", + help=( + "Only return values with names starting with " "the provided prefix." + ), + ) + self.parser.add_argument( + "-d", + "--decrypt", + action="store_true", + help="Decrypt secrets and displays plain text.", + ) + self.parser.add_argument( + "-s", + "--scope", + default=DEFAULT_LIST_SCOPE, + dest="scope", + help='Scope item is under. Example: "user".', + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.prefix: - kwargs['prefix'] = args.prefix + kwargs["prefix"] = args.prefix - decrypt = getattr(args, 'decrypt', False) - kwargs['params'] = {'decrypt': str(decrypt).lower()} - scope = getattr(args, 'scope', DEFAULT_LIST_SCOPE) - kwargs['params']['scope'] = scope + decrypt = getattr(args, "decrypt", False) + kwargs["params"] = {"decrypt": str(decrypt).lower()} + scope = getattr(args, "scope", DEFAULT_LIST_SCOPE) + kwargs["params"]["scope"] = scope if args.user: - kwargs['params']['user'] = args.user - kwargs['params']['limit'] = args.last + kwargs["params"]["user"] = args.user + kwargs["params"]["limit"] = args.last return self.manager.query_with_count(**kwargs) @@ -115,73 +157,124 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class KeyValuePairGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'name' - display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'expire_timestamp'] + pk_argument_name = "name" + display_attributes = [ + "name", + "value", + "secret", + "encrypted", + "scope", + "expire_timestamp", + ] def __init__(self, kv_resource, *args, **kwargs): super(KeyValuePairGetCommand, self).__init__(kv_resource, *args, **kwargs) - self.parser.add_argument('-d', '--decrypt', action='store_true', - help='Decrypt secret if encrypted and show plain text.') - self.parser.add_argument('-s', '--scope', default=DEFAULT_GET_SCOPE, dest='scope', - help='Scope item is under. Example: "user".') + self.parser.add_argument( + "-d", + "--decrypt", + action="store_true", + help="Decrypt secret if encrypted and show plain text.", + ) + self.parser.add_argument( + "-s", + "--scope", + default=DEFAULT_GET_SCOPE, + dest="scope", + help='Scope item is under. Example: "user".', + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resource_name = getattr(args, self.pk_argument_name, None) - decrypt = getattr(args, 'decrypt', False) - scope = getattr(args, 'scope', DEFAULT_GET_SCOPE) - kwargs['params'] = {'decrypt': str(decrypt).lower()} - kwargs['params']['scope'] = scope + decrypt = getattr(args, "decrypt", False) + scope = getattr(args, "scope", DEFAULT_GET_SCOPE) + kwargs["params"] = {"decrypt": str(decrypt).lower()} + kwargs["params"]["scope"] = scope return self.get_resource_by_id(id=resource_name, **kwargs) class KeyValuePairSetCommand(resource.ResourceCommand): - display_attributes = ['name', 'value', 'scope', 'expire_timestamp'] + display_attributes = ["name", "value", "scope", "expire_timestamp"] def __init__(self, resource, *args, **kwargs): super(KeyValuePairSetCommand, self).__init__( - resource, 'set', - 'Set an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "set", + "Set an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) # --encrypt and --encrypted options are mutually exclusive. # --encrypt implies provided value is plain text and should be encrypted whereas # --encrypted implies value is already encrypted and should be treated as-is. encryption_group = self.parser.add_mutually_exclusive_group() - encryption_group.add_argument('-e', '--encrypt', dest='secret', - action='store_true', - help='Encrypt value before saving.') - encryption_group.add_argument('--encrypted', dest='encrypted', - action='store_true', - help=('Value provided is already encrypted with the ' - 'instance crypto key and should be stored as-is.')) - - self.parser.add_argument('name', - metavar='name', - help='Name of the key value pair.') - self.parser.add_argument('value', help='Value paired with the key.') - self.parser.add_argument('-l', '--ttl', dest='ttl', type=int, default=None, - help='TTL (in seconds) for this value.') - self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE, - help='Specify the scope under which you want ' + - 'to place the item.') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') + encryption_group.add_argument( + "-e", + "--encrypt", + dest="secret", + action="store_true", + help="Encrypt value before saving.", + ) + encryption_group.add_argument( + "--encrypted", + dest="encrypted", + action="store_true", + help=( + "Value provided is already encrypted with the " + "instance crypto key and should be stored as-is." + ), + ) + + self.parser.add_argument( + "name", metavar="name", help="Name of the key value pair." + ) + self.parser.add_argument("value", help="Value paired with the key.") + self.parser.add_argument( + "-l", + "--ttl", + dest="ttl", + type=int, + default=None, + help="TTL (in seconds) for this value.", + ) + self.parser.add_argument( + "-s", + "--scope", + dest="scope", + default=DEFAULT_CUD_SCOPE, + help="Specify the scope under which you want " + "to place the item.", + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -205,35 +298,49 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=self.display_attributes, json=args.json, - yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=self.display_attributes, + json=args.json, + yaml=args.yaml, + ) class KeyValuePairDeleteCommand(resource.ResourceDeleteCommand): - pk_argument_name = 'name' + pk_argument_name = "name" def __init__(self, resource, *args, **kwargs): super(KeyValuePairDeleteCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE, - help='Specify the scope under which you want ' + - 'to place the item.') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') + self.parser.add_argument( + "-s", + "--scope", + dest="scope", + default=DEFAULT_CUD_SCOPE, + help="Specify the scope under which you want " + "to place the item.", + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resource_id = getattr(args, self.pk_argument_name, None) - scope = getattr(args, 'scope', DEFAULT_CUD_SCOPE) - kwargs['params'] = {} - kwargs['params']['scope'] = scope - kwargs['params']['user'] = args.user + scope = getattr(args, "scope", DEFAULT_CUD_SCOPE) + kwargs["params"] = {} + kwargs["params"]["scope"] = scope + kwargs["params"]["user"] = args.user instance = self.get_resource(resource_id, **kwargs) if not instance: - raise resource.ResourceNotFoundError('KeyValuePair with id "%s" not found' - % resource_id) + raise resource.ResourceNotFoundError( + 'KeyValuePair with id "%s" not found' % resource_id + ) instance.id = resource_id # TODO: refactor and get rid of id self.manager.delete(instance, **kwargs) @@ -244,14 +351,23 @@ class KeyValuePairDeleteByPrefixCommand(resource.ResourceCommand): Commands which delete all the key value pairs which match the provided prefix. """ + def __init__(self, resource, *args, **kwargs): - super(KeyValuePairDeleteByPrefixCommand, self).__init__(resource, 'delete_by_prefix', - 'Delete KeyValue pairs which \ - match the provided prefix', - *args, **kwargs) + super(KeyValuePairDeleteByPrefixCommand, self).__init__( + resource, + "delete_by_prefix", + "Delete KeyValue pairs which \ + match the provided prefix", + *args, + **kwargs, + ) - self.parser.add_argument('-p', '--prefix', required=True, - help='Name prefix (e.g. twitter.TwitterSensor:)') + self.parser.add_argument( + "-p", + "--prefix", + required=True, + help="Name prefix (e.g. twitter.TwitterSensor:)", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -276,27 +392,39 @@ def run_and_print(self, args, **kwargs): deleted = self.run(args, **kwargs) key_ids = [key_pair.id for key_pair in deleted] - print('Deleted %s keys' % (len(deleted))) - print('Deleted key ids: %s' % (', '.join(key_ids))) + print("Deleted %s keys" % (len(deleted))) + print("Deleted key ids: %s" % (", ".join(key_ids))) class KeyValuePairLoadCommand(resource.ResourceCommand): - pk_argument_name = 'name' - display_attributes = ['name', 'value'] + pk_argument_name = "name" + display_attributes = ["name", "value"] def __init__(self, resource, *args, **kwargs): - help_text = ('Load a list of %s from file.' % - resource.get_plural_display_name().lower()) - super(KeyValuePairLoadCommand, self).__init__(resource, 'load', - help_text, *args, **kwargs) - - self.parser.add_argument('-c', '--convert', action='store_true', - help=('Convert non-string types (hash, array, boolean,' - ' int, float) to a JSON string before loading it' - ' into the datastore.')) + help_text = ( + "Load a list of %s from file." % resource.get_plural_display_name().lower() + ) + super(KeyValuePairLoadCommand, self).__init__( + resource, "load", help_text, *args, **kwargs + ) + + self.parser.add_argument( + "-c", + "--convert", + action="store_true", + help=( + "Convert non-string types (hash, array, boolean," + " int, float) to a JSON string before loading it" + " into the datastore." + ), + ) self.parser.add_argument( - 'file', help=('JSON/YAML file containing the %s(s) to load' - % resource.get_plural_display_name().lower())) + "file", + help=( + "JSON/YAML file containing the %s(s) to load" + % resource.get_plural_display_name().lower() + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -318,15 +446,15 @@ def run(self, args, **kwargs): for item in kvps: # parse required KeyValuePair properties - name = item['name'] - value = item['value'] + name = item["name"] + value = item["value"] # parse optional KeyValuePair properties - scope = item.get('scope', DEFAULT_CUD_SCOPE) - user = item.get('user', None) - encrypted = item.get('encrypted', False) - secret = item.get('secret', False) - ttl = item.get('ttl', None) + scope = item.get("scope", DEFAULT_CUD_SCOPE) + user = item.get("user", None) + encrypted = item.get("encrypted", False) + secret = item.get("secret", False) + ttl = item.get("ttl", None) # if the value is not a string, convert it to JSON # all keys in the datastore must strings @@ -334,10 +462,15 @@ def run(self, args, **kwargs): if args.convert: value = json.dumps(value) else: - raise ValueError(("Item '%s' has a value that is not a string." - " Either pass in the -c/--convert option to convert" - " non-string types to JSON strings automatically, or" - " convert the data to a string in the file") % name) + raise ValueError( + ( + "Item '%s' has a value that is not a string." + " Either pass in the -c/--convert option to convert" + " non-string types to JSON strings automatically, or" + " convert the data to a string in the file" + ) + % name + ) # create the KeyValuePair instance instance = KeyValuePair() @@ -368,7 +501,10 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=['name', 'value', 'secret', 'scope', 'user', 'ttl'], - json=args.json, - yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=["name", "value", "secret", "scope", "user", "ttl"], + json=args.json, + yaml=args.yaml, + ) diff --git a/st2client/st2client/commands/pack.py b/st2client/st2client/commands/pack.py index 827db663df..8d05fc88dc 100644 --- a/st2client/st2client/commands/pack.py +++ b/st2client/st2client/commands/pack.py @@ -34,43 +34,56 @@ from st2client.utils import interactive -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" LIVEACTION_COMPLETED_STATES = [ LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] class PackBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(PackBranch, self).__init__( - Pack, description, app, subparsers, + Pack, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': PackListCommand, - 'get': PackGetCommand - }) - - self.commands['show'] = PackShowCommand(self.resource, self.app, self.subparsers) - self.commands['search'] = PackSearchCommand(self.resource, self.app, self.subparsers) - self.commands['install'] = PackInstallCommand(self.resource, self.app, self.subparsers) - self.commands['remove'] = PackRemoveCommand(self.resource, self.app, self.subparsers) - self.commands['register'] = PackRegisterCommand(self.resource, self.app, self.subparsers) - self.commands['config'] = PackConfigCommand(self.resource, self.app, self.subparsers) + commands={"list": PackListCommand, "get": PackGetCommand}, + ) + + self.commands["show"] = PackShowCommand( + self.resource, self.app, self.subparsers + ) + self.commands["search"] = PackSearchCommand( + self.resource, self.app, self.subparsers + ) + self.commands["install"] = PackInstallCommand( + self.resource, self.app, self.subparsers + ) + self.commands["remove"] = PackRemoveCommand( + self.resource, self.app, self.subparsers + ) + self.commands["register"] = PackRegisterCommand( + self.resource, self.app, self.subparsers + ) + self.commands["config"] = PackConfigCommand( + self.resource, self.app, self.subparsers + ) class PackResourceCommand(resource.ResourceCommand): @@ -79,13 +92,18 @@ def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: raise resource.ResourceNotFoundError("No matching items found") - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except resource.ResourceNotFoundError: print("No matching items found") except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) @@ -93,48 +111,72 @@ class PackAsyncCommand(ActionRunCommandMixin, resource.ResourceCommand): def __init__(self, *args, **kwargs): super(PackAsyncCommand, self).__init__(*args, **kwargs) - self.parser.add_argument('-w', '--width', nargs='+', type=int, default=None, - help='Set the width of columns in output.') + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help="Set the width of columns in output.", + ) detail_arg_grp = self.parser.add_mutually_exclusive_group() - detail_arg_grp.add_argument('--attr', nargs='+', - default=['ref', 'name', 'description', 'version', 'author'], - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) - detail_arg_grp.add_argument('-d', '--detail', action='store_true', - help='Display full detail of the execution in table format.') + detail_arg_grp.add_argument( + "--attr", + nargs="+", + default=["ref", "name", "description", "version", "author"], + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) + detail_arg_grp.add_argument( + "-d", + "--detail", + action="store_true", + help="Display full detail of the execution in table format.", + ) @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') + raise Exception("Server did not create instance.") parent_id = instance.execution_id - stream_mgr = self.app.client.managers['Stream'] + stream_mgr = self.app.client.managers["Stream"] execution = None with term.TaskIndicator() as indicator: - events = ['st2.execution__create', 'st2.execution__update'] - for event in stream_mgr.listen(events, end_execution_id=parent_id, - end_event="st2.execution__update", **kwargs): + events = ["st2.execution__create", "st2.execution__update"] + for event in stream_mgr.listen( + events, + end_execution_id=parent_id, + end_event="st2.execution__update", + **kwargs, + ): execution = Execution(**event) - if execution.id == parent_id \ - and execution.status in LIVEACTION_COMPLETED_STATES: + if ( + execution.id == parent_id + and execution.status in LIVEACTION_COMPLETED_STATES + ): break # Suppress intermediate output in case output formatter is requested if args.json or args.yaml: continue - if getattr(execution, 'parent', None) == parent_id: + if getattr(execution, "parent", None) == parent_id: status = execution.status - name = execution.context['orquesta']['task_name'] \ - if 'orquesta' in execution.context else execution.context['chain']['name'] + name = ( + execution.context["orquesta"]["task_name"] + if "orquesta" in execution.context + else execution.context["chain"]["name"] + ) if status == LIVEACTION_STATUS_SCHEDULED: indicator.add_stage(status, name) @@ -148,31 +190,48 @@ def run_and_print(self, args, **kwargs): self._print_execution_details(execution=execution, args=args, **kwargs) sys.exit(1) - return self.app.client.managers['Execution'].get_by_id(parent_id, **kwargs) + return self.app.client.managers["Execution"].get_by_id(parent_id, **kwargs) class PackListCommand(resource.ResourceListCommand): - display_attributes = ['ref', 'name', 'description', 'version', 'author'] - attribute_display_order = ['ref', 'name', 'description', 'version', 'author'] + display_attributes = ["ref", "name", "description", "version", "author"] + attribute_display_order = ["ref", "name", "description", "version", "author"] class PackGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'ref' - display_attributes = ['name', 'version', 'author', 'email', 'keywords', 'description'] - attribute_display_order = ['name', 'version', 'author', 'email', 'keywords', 'description'] - help_string = 'Get information about an installed pack.' + pk_argument_name = "ref" + display_attributes = [ + "name", + "version", + "author", + "email", + "keywords", + "description", + ] + attribute_display_order = [ + "name", + "version", + "author", + "email", + "keywords", + "description", + ] + help_string = "Get information about an installed pack." class PackShowCommand(PackResourceCommand): def __init__(self, resource, *args, **kwargs): - help_string = ('Get information about an available %s from the index.' % - resource.get_display_name().lower()) - super(PackShowCommand, self).__init__(resource, 'show', help_string, - *args, **kwargs) - - self.parser.add_argument('pack', - help='Name of the %s to show.' % - resource.get_display_name().lower()) + help_string = ( + "Get information about an available %s from the index." + % resource.get_display_name().lower() + ) + super(PackShowCommand, self).__init__( + resource, "show", help_string, *args, **kwargs + ) + + self.parser.add_argument( + "pack", help="Name of the %s to show." % resource.get_display_name().lower() + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -181,27 +240,39 @@ def run(self, args, **kwargs): class PackInstallCommand(PackAsyncCommand): def __init__(self, resource, *args, **kwargs): - super(PackInstallCommand, self).__init__(resource, 'install', 'Install new %s.' - % resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='+', - metavar='pack', - help='Name of the %s in Exchange, or a git repo URL.' % - resource.get_plural_display_name().lower()) - self.parser.add_argument('--python3', - action='store_true', - default=False, - help='Use Python 3 binary for pack virtual environment.') - self.parser.add_argument('--force', - action='store_true', - default=False, - help='Force pack installation.') - self.parser.add_argument('--skip-dependencies', - action='store_true', - default=False, - help='Skip pack dependency installation.') + super(PackInstallCommand, self).__init__( + resource, + "install", + "Install new %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="+", + metavar="pack", + help="Name of the %s in Exchange, or a git repo URL." + % resource.get_plural_display_name().lower(), + ) + self.parser.add_argument( + "--python3", + action="store_true", + default=False, + help="Use Python 3 binary for pack virtual environment.", + ) + self.parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force pack installation.", + ) + self.parser.add_argument( + "--skip-dependencies", + action="store_true", + default=False, + help="Skip pack dependency installation.", + ) def run(self, args, **kwargs): is_structured_output = args.json or args.yaml @@ -212,30 +283,42 @@ def run(self, args, **kwargs): self._get_content_counts_for_pack(args, **kwargs) if args.python3: - warnings.warn('DEPRECATION WARNING: --python3 flag is ignored and will be removed ' - 'in v3.5.0 as StackStorm now runs with python3 only') - - return self.manager.install(args.packs, force=args.force, - skip_dependencies=args.skip_dependencies, **kwargs) + warnings.warn( + "DEPRECATION WARNING: --python3 flag is ignored and will be removed " + "in v3.5.0 as StackStorm now runs with python3 only" + ) + + return self.manager.install( + args.packs, + force=args.force, + skip_dependencies=args.skip_dependencies, + **kwargs, + ) def _get_content_counts_for_pack(self, args, **kwargs): # Global content list, excluding "tests" # Note: We skip this step for local packs - pack_content = {'actions': 0, 'rules': 0, 'sensors': 0, 'aliases': 0, 'triggers': 0} + pack_content = { + "actions": 0, + "rules": 0, + "sensors": 0, + "aliases": 0, + "triggers": 0, + } if len(args.packs) == 1: args.pack = args.packs[0] - if args.pack.startswith('file://'): + if args.pack.startswith("file://"): return pack_info = self.manager.search(args, ignore_errors=True, **kwargs) - content = getattr(pack_info, 'content', {}) + content = getattr(pack_info, "content", {}) if content: for entity in content.keys(): if entity in pack_content: - pack_content[entity] += content[entity]['count'] + pack_content[entity] += content[entity]["count"] self._print_pack_content(args.packs, pack_content) else: @@ -246,122 +329,165 @@ def _get_content_counts_for_pack(self, args, **kwargs): # args.pack required for search args.pack = pack - if args.pack.startswith('file://'): + if args.pack.startswith("file://"): return pack_info = self.manager.search(args, ignore_errors=True, **kwargs) - content = getattr(pack_info, 'content', {}) + content = getattr(pack_info, "content", {}) if content: for entity in content.keys(): if entity in pack_content: - pack_content[entity] += content[entity]['count'] + pack_content[entity] += content[entity]["count"] if content: self._print_pack_content(args.packs, pack_content) @staticmethod def _print_pack_content(pack_name, pack_content): - print('\nFor the "%s" %s, the following content will be registered:\n' - % (', '.join(pack_name), 'pack' if len(pack_name) == 1 else 'packs')) + print( + '\nFor the "%s" %s, the following content will be registered:\n' + % (", ".join(pack_name), "pack" if len(pack_name) == 1 else "packs") + ) for item, count in pack_content.items(): - print('%-10s| %s' % (item, count)) - print('\nInstallation may take a while for packs with many items.') + print("%-10s| %s" % (item, count)) + print("\nInstallation may take a while for packs with many items.") @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): instance = super(PackInstallCommand, self).run_and_print(args, **kwargs) # Hack to get a list of resolved references of installed packs - packs = instance.result['output']['packs_list'] + packs = instance.result["output"]["packs_list"] if len(packs) == 1: - pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs) - self.print_output(pack_instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id( + packs[0], **kwargs + ) + self.print_output( + pack_instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) else: - all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs) pack_instances = [] for pack in all_pack_instances: if pack.name in packs or pack.ref in packs: pack_instances.append(pack) - self.print_output(pack_instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + pack_instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) - warnings = instance.result['output']['warning_list'] + warnings = instance.result["output"]["warning_list"] for warning in warnings: print(warning) class PackRemoveCommand(PackAsyncCommand): def __init__(self, resource, *args, **kwargs): - super(PackRemoveCommand, self).__init__(resource, 'remove', 'Remove %s.' - % resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='+', - metavar='pack', - help='Name of the %s to remove.' % - resource.get_plural_display_name().lower()) + super(PackRemoveCommand, self).__init__( + resource, + "remove", + "Remove %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="+", + metavar="pack", + help="Name of the %s to remove." + % resource.get_plural_display_name().lower(), + ) def run(self, args, **kwargs): return self.manager.remove(args.packs, **kwargs) @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): - all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs) super(PackRemoveCommand, self).run_and_print(args, **kwargs) packs = args.packs if len(packs) == 1: - pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs) + pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id( + packs[0], **kwargs + ) if pack_instance: - raise OperationFailureException('Pack %s has not been removed properly' % packs[0]) - - removed_pack_instance = next((pack for pack in all_pack_instances - if pack.name == packs[0]), None) - - self.print_output(removed_pack_instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + raise OperationFailureException( + "Pack %s has not been removed properly" % packs[0] + ) + + removed_pack_instance = next( + (pack for pack in all_pack_instances if pack.name == packs[0]), None + ) + + self.print_output( + removed_pack_instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) else: - remaining_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + remaining_pack_instances = self.app.client.managers["Pack"].get_all( + **kwargs + ) pack_instances = [] for pack in all_pack_instances: if pack.name in packs or pack.ref in packs: pack_instances.append(pack) if pack in remaining_pack_instances: - raise OperationFailureException('Pack %s has not been removed properly' - % pack.name) + raise OperationFailureException( + "Pack %s has not been removed properly" % pack.name + ) - self.print_output(pack_instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + pack_instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class PackRegisterCommand(PackResourceCommand): def __init__(self, resource, *args, **kwargs): - super(PackRegisterCommand, self).__init__(resource, 'register', - 'Register %s(s): sync all file changes with DB.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='*', - metavar='pack', - help='Name of the %s(s) to register.' % - resource.get_display_name().lower()) - - self.parser.add_argument('--types', - nargs='+', - help='Types of content to register.') + super(PackRegisterCommand, self).__init__( + resource, + "register", + "Register %s(s): sync all file changes with DB." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="*", + metavar="pack", + help="Name of the %s(s) to register." % resource.get_display_name().lower(), + ) + + self.parser.add_argument( + "--types", nargs="+", help="Types of content to register." + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -369,18 +495,21 @@ def run(self, args, **kwargs): class PackSearchCommand(resource.ResourceTableCommand): - display_attributes = ['name', 'description', 'version', 'author'] - attribute_display_order = ['name', 'description', 'version', 'author'] + display_attributes = ["name", "description", "version", "author"] + attribute_display_order = ["name", "description", "version", "author"] def __init__(self, resource, *args, **kwargs): - super(PackSearchCommand, self).__init__(resource, 'search', - 'Search the index for a %s with any attribute \ - matching the query.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('query', - help='Search query.') + super(PackSearchCommand, self).__init__( + resource, + "search", + "Search the index for a %s with any attribute \ + matching the query." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument("query", help="Search query.") @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -389,31 +518,41 @@ def run(self, args, **kwargs): class PackConfigCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): - super(PackConfigCommand, self).__init__(resource, 'config', - 'Configure a %s based on config schema.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('name', - help='Name of the %s(s) to configure.' % - resource.get_display_name().lower()) + super(PackConfigCommand, self).__init__( + resource, + "config", + "Configure a %s based on config schema." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "name", + help="Name of the %s(s) to configure." + % resource.get_display_name().lower(), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - schema = self.app.client.managers['ConfigSchema'].get_by_ref_or_id(args.name, **kwargs) + schema = self.app.client.managers["ConfigSchema"].get_by_ref_or_id( + args.name, **kwargs + ) if not schema: - msg = '%s "%s" doesn\'t exist or doesn\'t have a config schema defined.' - raise resource.ResourceNotFoundError(msg % (self.resource.get_display_name(), - args.name)) + msg = "%s \"%s\" doesn't exist or doesn't have a config schema defined." + raise resource.ResourceNotFoundError( + msg % (self.resource.get_display_name(), args.name) + ) config = interactive.InteractiveForm(schema.attributes).initiate_dialog() - message = '---\nDo you want to preview the config in an editor before saving?' - description = 'Secrets will be shown in plain text.' - preview_dialog = interactive.Question(message, {'default': 'y', - 'description': description}) - if preview_dialog.read() == 'y': + message = "---\nDo you want to preview the config in an editor before saving?" + description = "Secrets will be shown in plain text." + preview_dialog = interactive.Question( + message, {"default": "y", "description": description} + ) + if preview_dialog.read() == "y": try: contents = yaml.safe_dump(config, indent=4, default_flow_style=False) modified = editor.edit(contents=contents) @@ -421,13 +560,13 @@ def run(self, args, **kwargs): except editor.EditorError as e: print(six.text_type(e)) - message = '---\nDo you want me to save it?' - save_dialog = interactive.Question(message, {'default': 'y'}) - if save_dialog.read() == 'n': - raise OperationFailureException('Interrupted') + message = "---\nDo you want me to save it?" + save_dialog = interactive.Question(message, {"default": "y"}) + if save_dialog.read() == "n": + raise OperationFailureException("Interrupted") config_item = Config(pack=args.name, values=config) - result = self.app.client.managers['Config'].update(config_item, **kwargs) + result = self.app.client.managers["Config"].update(config_item, **kwargs) return result @@ -436,14 +575,19 @@ def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: raise Exception("Configuration failed") - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except (KeyboardInterrupt, SystemExit): - raise OperationFailureException('Interrupted') + raise OperationFailureException("Interrupted") except Exception as e: if self.app.client.debug: raise message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) diff --git a/st2client/st2client/commands/policy.py b/st2client/st2client/commands/policy.py index de6c8ba997..cd891bc3a8 100644 --- a/st2client/st2client/commands/policy.py +++ b/st2client/st2client/commands/policy.py @@ -25,31 +25,36 @@ class PolicyTypeBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(PolicyTypeBranch, self).__init__( - models.PolicyType, description, app, subparsers, + models.PolicyType, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': PolicyTypeListCommand, - 'get': PolicyTypeGetCommand - }) + commands={"list": PolicyTypeListCommand, "get": PolicyTypeGetCommand}, + ) class PolicyTypeListCommand(resource.ResourceListCommand): - display_attributes = ['id', 'resource_type', 'name', 'description'] + display_attributes = ["id", "resource_type", "name", "description"] def __init__(self, resource, *args, **kwargs): super(PolicyTypeListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-r', '--resource-type', type=str, dest='resource_type', - help='Return policy types for the resource type.') + self.parser.add_argument( + "-r", + "--resource-type", + type=str, + dest="resource_type", + help="Return policy types for the resource type.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if args.resource_type: - filters = {'resource_type': args.resource_type} + filters = {"resource_type": args.resource_type} filters.update(**kwargs) instances = self.manager.query(**filters) return instances @@ -58,36 +63,49 @@ def run(self, args, **kwargs): class PolicyTypeGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def get_resource(self, ref_or_id, **kwargs): return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) class PolicyBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(PolicyBranch, self).__init__( - models.Policy, description, app, subparsers, + models.Policy, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': PolicyListCommand, - 'get': PolicyGetCommand, - 'update': PolicyUpdateCommand, - 'delete': PolicyDeleteCommand - }) + "list": PolicyListCommand, + "get": PolicyGetCommand, + "update": PolicyUpdateCommand, + "delete": PolicyDeleteCommand, + }, + ) class PolicyListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'resource_ref', 'policy_type', 'enabled'] + display_attributes = ["ref", "resource_ref", "policy_type", "enabled"] def __init__(self, resource, *args, **kwargs): super(PolicyListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-r', '--resource-ref', type=str, dest='resource_ref', - help='Return policies for the resource ref.') - self.parser.add_argument('-pt', '--policy-type', type=str, dest='policy_type', - help='Return policies of the policy type.') + self.parser.add_argument( + "-r", + "--resource-ref", + type=str, + dest="resource_ref", + help="Return policies for the resource ref.", + ) + self.parser.add_argument( + "-pt", + "--policy-type", + type=str, + dest="policy_type", + help="Return policies of the policy type.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -95,10 +113,10 @@ def run(self, args, **kwargs): filters = {} if args.resource_ref: - filters['resource_ref'] = args.resource_ref + filters["resource_ref"] = args.resource_ref if args.policy_type: - filters['policy_type'] = args.policy_type + filters["policy_type"] = args.policy_type filters.update(**kwargs) @@ -108,10 +126,18 @@ def run(self, args, **kwargs): class PolicyGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'resource_ref', 'policy_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "resource_ref", + "policy_type", + "parameters", + ] class PolicyUpdateCommand(resource.ContentPackResourceUpdateCommand): diff --git a/st2client/st2client/commands/rbac.py b/st2client/st2client/commands/rbac.py index 0d7ea7f400..9a9e8c274b 100644 --- a/st2client/st2client/commands/rbac.py +++ b/st2client/st2client/commands/rbac.py @@ -20,58 +20,77 @@ from st2client.models.rbac import Role from st2client.models.rbac import UserRoleAssignment -__all__ = [ - 'RoleBranch', - 'RoleAssignmentBranch' +__all__ = ["RoleBranch", "RoleAssignmentBranch"] + +ROLE_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "system", "permission_grants"] +ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = [ + "id", + "role", + "user", + "is_remote", + "description", ] -ROLE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'system', 'permission_grants'] -ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'role', 'user', 'is_remote', 'description'] - class RoleBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(RoleBranch, self).__init__( - Role, description, app, subparsers, + Role, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': RoleListCommand, - 'get': RoleGetCommand - }) + commands={"list": RoleListCommand, "get": RoleGetCommand}, + ) class RoleListCommand(resource.ResourceCommand): - display_attributes = ['id', 'name', 'system', 'description'] + display_attributes = ["id", "name", "system", "description"] attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER def __init__(self, resource, *args, **kwargs): super(RoleListCommand, self).__init__( - resource, 'list', 'Get the list of the %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of the %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) self.group = self.parser.add_mutually_exclusive_group() # Filter options - self.group.add_argument('-s', '--system', action='store_true', - help='Only display system roles.') + self.group.add_argument( + "-s", "--system", action="store_true", help="Only display system roles." + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.system: - kwargs['system'] = args.system + kwargs["system"] = args.system if args.system: result = self.manager.query(**kwargs) @@ -82,67 +101,93 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class RoleGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER - pk_argument_name = 'id' + pk_argument_name = "id" class RoleAssignmentBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(RoleAssignmentBranch, self).__init__( - UserRoleAssignment, description, app, subparsers, + UserRoleAssignment, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, commands={ - 'list': RoleAssignmentListCommand, - 'get': RoleAssignmentGetCommand - }) + "list": RoleAssignmentListCommand, + "get": RoleAssignmentGetCommand, + }, + ) class RoleAssignmentListCommand(resource.ResourceCommand): - display_attributes = ['id', 'role', 'user', 'is_remote', 'source', 'description'] + display_attributes = ["id", "role", "user", "is_remote", "source", "description"] attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER def __init__(self, resource, *args, **kwargs): super(RoleAssignmentListCommand, self).__init__( - resource, 'list', 'Get the list of the %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of the %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) # Filter options - self.parser.add_argument('-r', '--role', help='Role to filter on.') - self.parser.add_argument('-u', '--user', help='User to filter on.') - self.parser.add_argument('-s', '--source', help='Source to filter on.') - self.parser.add_argument('--remote', action='store_true', - help='Only display remote role assignments.') + self.parser.add_argument("-r", "--role", help="Role to filter on.") + self.parser.add_argument("-u", "--user", help="User to filter on.") + self.parser.add_argument("-s", "--source", help="Source to filter on.") + self.parser.add_argument( + "--remote", + action="store_true", + help="Only display remote role assignments.", + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.role: - kwargs['role'] = args.role + kwargs["role"] = args.role if args.user: - kwargs['user'] = args.user + kwargs["user"] = args.user if args.source: - kwargs['source'] = args.source + kwargs["source"] = args.source if args.remote: - kwargs['remote'] = args.remote + kwargs["remote"] = args.remote if args.role or args.user or args.remote or args.source: result = self.manager.query(**kwargs) @@ -153,12 +198,17 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class RoleAssignmentGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER - pk_argument_name = 'id' + pk_argument_name = "id" diff --git a/st2client/st2client/commands/resource.py b/st2client/st2client/commands/resource.py index 15ca68bb09..da0fbc85e3 100644 --- a/st2client/st2client/commands/resource.py +++ b/st2client/st2client/commands/resource.py @@ -32,8 +32,8 @@ from st2client.formatters import table from st2client.utils.types import OrderedSet -ALLOWED_EXTS = ['.json', '.yaml', '.yml'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".json", ".yaml", ".yml"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} LOG = logging.getLogger(__name__) @@ -41,11 +41,12 @@ def add_auth_token_to_kwargs_from_cli(func): @wraps(func) def decorate(*args, **kwargs): ns = args[1] - if getattr(ns, 'token', None): - kwargs['token'] = ns.token - if getattr(ns, 'api_key', None): - kwargs['api_key'] = ns.api_key + if getattr(ns, "token", None): + kwargs["token"] = ns.token + if getattr(ns, "api_key", None): + kwargs["api_key"] = ns.api_key return func(*args, **kwargs) + return decorate @@ -58,20 +59,34 @@ class ResourceNotFoundError(Exception): class ResourceBranch(commands.Branch): - - def __init__(self, resource, description, app, subparsers, - parent_parser=None, read_only=False, commands=None, - has_disable=False): + def __init__( + self, + resource, + description, + app, + subparsers, + parent_parser=None, + read_only=False, + commands=None, + has_disable=False, + ): self.resource = resource super(ResourceBranch, self).__init__( - self.resource.get_alias().lower(), description, - app, subparsers, parent_parser=parent_parser) + self.resource.get_alias().lower(), + description, + app, + subparsers, + parent_parser=parent_parser, + ) # Registers subcommands for managing the resource type. self.subparsers = self.parser.add_subparsers( - help=('List of commands for managing %s.' % - self.resource.get_plural_display_name().lower())) + help=( + "List of commands for managing %s." + % self.resource.get_plural_display_name().lower() + ) + ) # Resolves if commands need to be overridden. commands = commands or {} @@ -82,7 +97,7 @@ def __init__(self, resource, description, app, subparsers, "update": ResourceUpdateCommand, "delete": ResourceDeleteCommand, "enable": ResourceEnableCommand, - "disable": ResourceDisableCommand + "disable": ResourceDisableCommand, } for cmd, cmd_class in cmd_map.items(): if cmd not in commands: @@ -90,17 +105,17 @@ def __init__(self, resource, description, app, subparsers, # Instantiate commands. args = [self.resource, self.app, self.subparsers] - self.commands['list'] = commands['list'](*args) - self.commands['get'] = commands['get'](*args) + self.commands["list"] = commands["list"](*args) + self.commands["get"] = commands["get"](*args) if not read_only: - self.commands['create'] = commands['create'](*args) - self.commands['update'] = commands['update'](*args) - self.commands['delete'] = commands['delete'](*args) + self.commands["create"] = commands["create"](*args) + self.commands["update"] = commands["update"](*args) + self.commands["delete"] = commands["delete"](*args) if has_disable: - self.commands['enable'] = commands['enable'](*args) - self.commands['disable'] = commands['disable'](*args) + self.commands["enable"] = commands["enable"](*args) + self.commands["disable"] = commands["disable"](*args) @six.add_metaclass(abc.ABCMeta) @@ -109,29 +124,44 @@ class ResourceCommand(commands.Command): def __init__(self, resource, *args, **kwargs): - has_token_opt = kwargs.pop('has_token_opt', True) + has_token_opt = kwargs.pop("has_token_opt", True) super(ResourceCommand, self).__init__(*args, **kwargs) self.resource = resource if has_token_opt: - self.parser.add_argument('-t', '--token', dest='token', - help='Access token for user authentication. ' - 'Get ST2_AUTH_TOKEN from the environment ' - 'variables by default.') - self.parser.add_argument('--api-key', dest='api_key', - help='Api Key for user authentication. ' - 'Get ST2_API_KEY from the environment ' - 'variables by default.') + self.parser.add_argument( + "-t", + "--token", + dest="token", + help="Access token for user authentication. " + "Get ST2_AUTH_TOKEN from the environment " + "variables by default.", + ) + self.parser.add_argument( + "--api-key", + dest="api_key", + help="Api Key for user authentication. " + "Get ST2_API_KEY from the environment " + "variables by default.", + ) # Formatter flags - self.parser.add_argument('-j', '--json', - action='store_true', dest='json', - help='Print output in JSON format.') - self.parser.add_argument('-y', '--yaml', - action='store_true', dest='yaml', - help='Print output in YAML format.') + self.parser.add_argument( + "-j", + "--json", + action="store_true", + dest="json", + help="Print output in JSON format.", + ) + self.parser.add_argument( + "-y", + "--yaml", + action="store_true", + dest="yaml", + help="Print output in YAML format.", + ) @property def manager(self): @@ -140,18 +170,17 @@ def manager(self): @property def arg_name_for_resource_id(self): resource_name = self.resource.get_display_name().lower() - return '%s-id' % resource_name.replace(' ', '-') + return "%s-id" % resource_name.replace(" ", "-") def print_not_found(self, name): - print('%s "%s" is not found.\n' % - (self.resource.get_display_name(), name)) + print('%s "%s" is not found.\n' % (self.resource.get_display_name(), name)) def get_resource(self, name_or_id, **kwargs): pk_argument_name = self.pk_argument_name - if pk_argument_name == 'name_or_id': + if pk_argument_name == "name_or_id": instance = self.get_resource_by_name_or_id(name_or_id=name_or_id, **kwargs) - elif pk_argument_name == 'ref_or_id': + elif pk_argument_name == "ref_or_id": instance = self.get_resource_by_ref_or_id(ref_or_id=name_or_id, **kwargs) else: instance = self.get_resource_by_pk(pk=name_or_id, **kwargs) @@ -167,8 +196,8 @@ def get_resource_by_pk(self, pk, **kwargs): except Exception as e: traceback.print_exc() # Hack for "Unauthorized" exceptions, we do want to propagate those - response = getattr(e, 'response', None) - status_code = getattr(response, 'status_code', None) + response = getattr(e, "response", None) + status_code = getattr(response, "status_code", None) if status_code and status_code == http_client.UNAUTHORIZED: raise e @@ -180,7 +209,7 @@ def get_resource_by_id(self, id, **kwargs): instance = self.get_resource_by_pk(pk=id, **kwargs) if not instance: - message = ('Resource with id "%s" doesn\'t exist.' % (id)) + message = 'Resource with id "%s" doesn\'t exist.' % (id) raise ResourceNotFoundError(message) return instance @@ -197,8 +226,7 @@ def get_resource_by_name_or_id(self, name_or_id, **kwargs): instance = self.get_resource_by_pk(pk=name_or_id, **kwargs) if not instance: - message = ('Resource with id or name "%s" doesn\'t exist.' % - (name_or_id)) + message = 'Resource with id or name "%s" doesn\'t exist.' % (name_or_id) raise ResourceNotFoundError(message) return instance @@ -206,8 +234,7 @@ def get_resource_by_ref_or_id(self, ref_or_id, **kwargs): instance = self.manager.get_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) if not instance: - message = ('Resource with id or reference "%s" doesn\'t exist.' % - (ref_or_id)) + message = 'Resource with id or reference "%s" doesn\'t exist.' % (ref_or_id) raise ResourceNotFoundError(message) return instance @@ -220,18 +247,18 @@ def run_and_print(self, args, **kwargs): raise NotImplementedError def _get_metavar_for_argument(self, argument): - return argument.replace('_', '-') + return argument.replace("_", "-") def _get_help_for_argument(self, resource, argument): argument_display_name = argument.title() resource_display_name = resource.get_display_name().lower() - if 'ref' in argument: - result = ('Reference or ID of the %s.' % (resource_display_name)) - elif 'name_or_id' in argument: - result = ('Name or ID of the %s.' % (resource_display_name)) + if "ref" in argument: + result = "Reference or ID of the %s." % (resource_display_name) + elif "name_or_id" in argument: + result = "Name or ID of the %s." % (resource_display_name) else: - result = ('%s of the %s.' % (argument_display_name, resource_display_name)) + result = "%s of the %s." % (argument_display_name, resource_display_name) return result @@ -263,7 +290,7 @@ def _get_include_attributes(cls, args, extra_attributes=None): # into account # Special case for "all" - if 'all' in args.attr: + if "all" in args.attr: return None for attr in args.attr: @@ -272,7 +299,7 @@ def _get_include_attributes(cls, args, extra_attributes=None): if include_attributes: return include_attributes - display_attributes = getattr(cls, 'display_attributes', []) + display_attributes = getattr(cls, "display_attributes", []) if display_attributes: include_attributes += display_attributes @@ -283,97 +310,129 @@ def _get_include_attributes(cls, args, extra_attributes=None): class ResourceTableCommand(ResourceViewCommand): - display_attributes = ['id', 'name', 'description'] + display_attributes = ["id", "name", "description"] def __init__(self, resource, name, description, *args, **kwargs): - super(ResourceTableCommand, self).__init__(resource, name, description, - *args, **kwargs) - - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + super(ResourceTableCommand, self).__init__( + resource, name, description, *args, **kwargs + ) + + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.get_all(**kwargs) def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ResourceListCommand(ResourceTableCommand): def __init__(self, resource, *args, **kwargs): super(ResourceListCommand, self).__init__( - resource, 'list', 'Get the list of %s.' % resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) class ContentPackResourceListCommand(ResourceListCommand): """ Base command class for use with resources which belong to a content pack. """ + def __init__(self, resource, *args, **kwargs): - super(ContentPackResourceListCommand, self).__init__(resource, - *args, **kwargs) + super(ContentPackResourceListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-p', '--pack', type=str, - help=('Only return resources belonging to the' - ' provided pack')) + self.parser.add_argument( + "-p", + "--pack", + type=str, + help=("Only return resources belonging to the" " provided pack"), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - filters = {'pack': args.pack} + filters = {"pack": args.pack} filters.update(**kwargs) include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - filters['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + filters["params"] = {"include_attributes": include_attributes} return self.manager.get_all(**filters) class ResourceGetCommand(ResourceViewCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'name', 'description'] + display_attributes = ["all"] + attribute_display_order = ["id", "name", "description"] - pk_argument_name = 'name_or_id' # name of the attribute which stores resource PK + pk_argument_name = "name_or_id" # name of the attribute which stores resource PK help_string = None def __init__(self, resource, *args, **kwargs): super(ResourceGetCommand, self).__init__( - resource, 'get', - self.help_string or 'Get individual %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "get", + self.help_string + or "Get individual %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) - - self.parser.add_argument(argument, - metavar=metavar, - help=help) - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) + + self.parser.add_argument(argument, metavar=metavar, help=help) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -383,13 +442,18 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + self.print_output( + instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) except ResourceNotFoundError: resource_id = getattr(args, self.pk_argument_name, None) self.print_not_found(resource_id) - raise OperationFailureException('Resource %s not found.' % resource_id) + raise OperationFailureException("Resource %s not found." % resource_id) class ContentPackResourceGetCommand(ResourceGetCommand): @@ -400,24 +464,31 @@ class ContentPackResourceGetCommand(ResourceGetCommand): retrieved by a reference or by an id. """ - attribute_display_order = ['id', 'pack', 'name', 'description'] + attribute_display_order = ["id", "pack", "name", "description"] - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def get_resource(self, ref_or_id, **kwargs): return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) class ResourceCreateCommand(ResourceCommand): - def __init__(self, resource, *args, **kwargs): - super(ResourceCreateCommand, self).__init__(resource, 'create', - 'Create a new %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceCreateCommand, self).__init__( + resource, + "create", + "Create a new %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s to create.' - % resource.get_display_name().lower())) + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s to create." + % resource.get_display_name().lower() + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -429,34 +500,46 @@ def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + raise Exception("Server did not create instance.") + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) class ResourceUpdateCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceUpdateCommand, self).__init__(resource, 'update', - 'Updating an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceUpdateCommand, self).__init__( + resource, + "update", + "Updating an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s to update.' - % resource.get_display_name().lower())) + self.parser.add_argument(argument, metavar=metavar, help=help) + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s to update." + % resource.get_display_name().lower() + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -465,46 +548,55 @@ def run(self, args, **kwargs): data = load_meta_file(args.file) modified_instance = self.resource.deserialize(data) - if not getattr(modified_instance, 'id', None): + if not getattr(modified_instance, "id", None): modified_instance.id = instance.id else: if modified_instance.id != instance.id: - raise Exception('The value for the %s id in the JSON/YAML file ' - 'does not match the ID provided in the ' - 'command line arguments.' % - self.resource.get_display_name().lower()) + raise Exception( + "The value for the %s id in the JSON/YAML file " + "does not match the ID provided in the " + "command line arguments." % self.resource.get_display_name().lower() + ) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) try: - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except Exception as e: - print('ERROR: %s' % (six.text_type(e))) + print("ERROR: %s" % (six.text_type(e))) raise OperationFailureException(six.text_type(e)) class ContentPackResourceUpdateCommand(ResourceUpdateCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceEnableCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceEnableCommand, self).__init__(resource, 'enable', - 'Enable an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceEnableCommand, self).__init__( + resource, + "enable", + "Enable an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -513,40 +605,48 @@ def run(self, args, **kwargs): data = instance.serialize() - if 'ref' in data: - del data['ref'] + if "ref" in data: + del data["ref"] - data['enabled'] = True + data["enabled"] = True modified_instance = self.resource.deserialize(data) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ContentPackResourceEnableCommand(ResourceEnableCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceDisableCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceDisableCommand, self).__init__(resource, 'disable', - 'Disable an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceDisableCommand, self).__init__( + resource, + "disable", + "Disable an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -555,40 +655,48 @@ def run(self, args, **kwargs): data = instance.serialize() - if 'ref' in data: - del data['ref'] + if "ref" in data: + del data["ref"] - data['enabled'] = False + data["enabled"] = False modified_instance = self.resource.deserialize(data) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ContentPackResourceDisableCommand(ResourceDisableCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceDeleteCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceDeleteCommand, self).__init__(resource, 'delete', - 'Delete an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceDeleteCommand, self).__init__( + resource, + "delete", + "Delete an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -601,10 +709,12 @@ def run_and_print(self, args, **kwargs): try: self.run(args, **kwargs) - print('Resource with id "%s" has been successfully deleted.' % (resource_id)) + print( + 'Resource with id "%s" has been successfully deleted.' % (resource_id) + ) except ResourceNotFoundError: self.print_not_found(resource_id) - raise OperationFailureException('Resource %s not found.' % resource_id) + raise OperationFailureException("Resource %s not found." % resource_id) class ContentPackResourceDeleteCommand(ResourceDeleteCommand): @@ -612,7 +722,7 @@ class ContentPackResourceDeleteCommand(ResourceDeleteCommand): Base command class for deleting a resource which belongs to a content pack. """ - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def load_meta_file(file_path): @@ -621,8 +731,10 @@ def load_meta_file(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return PARSER_FUNCS[file_ext](f) diff --git a/st2client/st2client/commands/rule.py b/st2client/st2client/commands/rule.py index 7f0f5e58db..cbab939e10 100644 --- a/st2client/st2client/commands/rule.py +++ b/st2client/st2client/commands/rule.py @@ -21,99 +21,143 @@ class RuleBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(RuleBranch, self).__init__( - models.Rule, description, app, subparsers, + models.Rule, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': RuleListCommand, - 'get': RuleGetCommand, - 'update': RuleUpdateCommand, - 'delete': RuleDeleteCommand - }) + "list": RuleListCommand, + "get": RuleGetCommand, + "update": RuleUpdateCommand, + "delete": RuleDeleteCommand, + }, + ) - self.commands['enable'] = RuleEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = RuleDisableCommand(self.resource, self.app, self.subparsers) + self.commands["enable"] = RuleEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = RuleDisableCommand( + self.resource, self.app, self.subparsers + ) class RuleListCommand(resource.ResourceTableCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] - display_attributes_iftt = ['ref', 'trigger.ref', 'action.ref', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] + display_attributes_iftt = ["ref", "trigger.ref", "action.ref", "enabled"] def __init__(self, resource, *args, **kwargs): self.default_limit = 50 - super(RuleListCommand, self).__init__(resource, 'list', - 'Get the list of the %s most recent %s.' % - (self.default_limit, - resource.get_plural_display_name().lower()), - *args, **kwargs) + super(RuleListCommand, self).__init__( + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('--iftt', action='store_true', - help='Show trigger and action in display list.') - self.parser.add_argument('-p', '--pack', type=str, - help=('Only return resources belonging to the' - ' provided pack')) - self.group.add_argument('-c', '--action', - help='Action reference to filter the list.') - self.group.add_argument('-g', '--trigger', - help='Trigger type reference to filter the list.') + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "--iftt", + action="store_true", + help="Show trigger and action in display list.", + ) + self.parser.add_argument( + "-p", + "--pack", + type=str, + help=("Only return resources belonging to the" " provided pack"), + ) + self.group.add_argument( + "-c", "--action", help="Action reference to filter the list." + ) + self.group.add_argument( + "-g", "--trigger", help="Trigger type reference to filter the list." + ) self.enabled_filter_group = self.parser.add_mutually_exclusive_group() - self.enabled_filter_group.add_argument('--enabled', action='store_true', - help='Show rules that are enabled.') - self.enabled_filter_group.add_argument('--disabled', action='store_true', - help='Show rules that are disabled.') + self.enabled_filter_group.add_argument( + "--enabled", action="store_true", help="Show rules that are enabled." + ) + self.enabled_filter_group.add_argument( + "--disabled", action="store_true", help="Show rules that are disabled." + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.pack: - kwargs['pack'] = args.pack + kwargs["pack"] = args.pack if args.action: - kwargs['action'] = args.action + kwargs["action"] = args.action if args.trigger: - kwargs['trigger'] = args.trigger + kwargs["trigger"] = args.trigger if args.enabled: - kwargs['enabled'] = True + kwargs["enabled"] = True if args.disabled: - kwargs['enabled'] = False + kwargs["enabled"] = False if args.iftt: # switch attr to display the trigger and action args.attr = self.display_attributes_iftt include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class RuleGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "description", + "enabled", + ] class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -121,15 +165,29 @@ class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand): class RuleEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "description", + "enabled", + ] class RuleDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "description", + "enabled", + ] class RuleDeleteCommand(resource.ContentPackResourceDeleteCommand): diff --git a/st2client/st2client/commands/rule_enforcement.py b/st2client/st2client/commands/rule_enforcement.py index ecebba2b07..dd624d4a72 100644 --- a/st2client/st2client/commands/rule_enforcement.py +++ b/st2client/st2client/commands/rule_enforcement.py @@ -22,24 +22,39 @@ class RuleEnforcementBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(RuleEnforcementBranch, self).__init__( - models.RuleEnforcement, description, app, subparsers, + models.RuleEnforcement, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': RuleEnforcementListCommand, - 'get': RuleEnforcementGetCommand, - }) + "list": RuleEnforcementListCommand, + "get": RuleEnforcementGetCommand, + }, + ) class RuleEnforcementGetCommand(resource.ResourceGetCommand): - display_attributes = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'failure_reason', 'enforced_at'] - attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'failure_reason', 'enforced_at'] - - pk_argument_name = 'id' + display_attributes = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "failure_reason", + "enforced_at", + ] + attribute_display_order = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "failure_reason", + "enforced_at", + ] + + pk_argument_name = "id" @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -48,84 +63,137 @@ def run(self, args, **kwargs): class RuleEnforcementListCommand(resource.ResourceCommand): - display_attributes = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'enforced_at'] - attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'enforced_at'] - - attribute_transform_functions = { - 'enforced_at': format_isodate_for_user_timezone - } + display_attributes = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "enforced_at", + ] + attribute_display_order = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "enforced_at", + ] + + attribute_transform_functions = {"enforced_at": format_isodate_for_user_timezone} def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(RuleEnforcementListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Filter options - self.group.add_argument('--trigger-instance', - help='Trigger instance id to filter the list.') - - self.group.add_argument('--execution', - help='Execution id to filter the list.') - self.group.add_argument('--rule', - help='Rule ref to filter the list.') - - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return enforcements with enforced_at ' - 'greater than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return enforcements with enforced_at ' - 'lower than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) + self.group.add_argument( + "--trigger-instance", help="Trigger instance id to filter the list." + ) + + self.group.add_argument("--execution", help="Execution id to filter the list.") + self.group.add_argument("--rule", help="Rule ref to filter the list.") + + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return enforcements with enforced_at " + "greater than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return enforcements with enforced_at " + "lower than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if args.execution: - kwargs['execution'] = args.execution + kwargs["execution"] = args.execution if args.rule: - kwargs['rule_ref'] = args.rule + kwargs["rule_ref"] = args.rule if args.timestamp_gt: - kwargs['enforced_at_gt'] = args.timestamp_gt + kwargs["enforced_at_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['enforced_at_lt'] = args.timestamp_lt + kwargs["enforced_at_lt"] = args.timestamp_lt return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) diff --git a/st2client/st2client/commands/sensor.py b/st2client/st2client/commands/sensor.py index 0d729c8c02..ca4cc33563 100644 --- a/st2client/st2client/commands/sensor.py +++ b/st2client/st2client/commands/sensor.py @@ -22,35 +22,67 @@ class SensorBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(SensorBranch, self).__init__( - Sensor, description, app, subparsers, + Sensor, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': SensorListCommand, - 'get': SensorGetCommand - }) + commands={"list": SensorListCommand, "get": SensorGetCommand}, + ) - self.commands['enable'] = SensorEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = SensorDisableCommand(self.resource, self.app, self.subparsers) + self.commands["enable"] = SensorEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = SensorDisableCommand( + self.resource, self.app, self.subparsers + ) class SensorListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] class SensorGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'enabled', 'entry_point', - 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "enabled", + "entry_point", + "artifact_uri", + "trigger_types", + ] class SensorEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval', - 'entry_point', 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "poll_interval", + "entry_point", + "artifact_uri", + "trigger_types", + ] class SensorDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval', - 'entry_point', 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "poll_interval", + "entry_point", + "artifact_uri", + "trigger_types", + ] diff --git a/st2client/st2client/commands/service_registry.py b/st2client/st2client/commands/service_registry.py index b609e051a9..6b9bff60b9 100644 --- a/st2client/st2client/commands/service_registry.py +++ b/st2client/st2client/commands/service_registry.py @@ -25,76 +25,87 @@ class ServiceRegistryBranch(commands.Branch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryBranch, self).__init__( - 'service-registry', description, - app, subparsers, parent_parser=parent_parser) + "service-registry", + description, + app, + subparsers, + parent_parser=parent_parser, + ) self.subparsers = self.parser.add_subparsers( - help=('List of commands for managing service registry.')) + help=("List of commands for managing service registry.") + ) # Instantiate commands - args_groups = ['Manage service registry groups', self.app, self.subparsers] - args_members = ['Manage service registry members', self.app, self.subparsers] + args_groups = ["Manage service registry groups", self.app, self.subparsers] + args_members = ["Manage service registry members", self.app, self.subparsers] - self.commands['groups'] = ServiceRegistryGroupsBranch(*args_groups) - self.commands['members'] = ServiceRegistryMembersBranch(*args_members) + self.commands["groups"] = ServiceRegistryGroupsBranch(*args_groups) + self.commands["members"] = ServiceRegistryMembersBranch(*args_members) class ServiceRegistryGroupsBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryGroupsBranch, self).__init__( - ServiceRegistryGroup, description, app, subparsers, + ServiceRegistryGroup, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': ServiceRegistryListGroupsCommand, - 'get': NoopCommand - }) + commands={"list": ServiceRegistryListGroupsCommand, "get": NoopCommand}, + ) - del self.commands['get'] + del self.commands["get"] class ServiceRegistryMembersBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryMembersBranch, self).__init__( - ServiceRegistryMember, description, app, subparsers, + ServiceRegistryMember, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': ServiceRegistryListMembersCommand, - 'get': NoopCommand - }) + commands={"list": ServiceRegistryListMembersCommand, "get": NoopCommand}, + ) - del self.commands['get'] + del self.commands["get"] class ServiceRegistryListGroupsCommand(resource.ResourceListCommand): - display_attributes = ['group_id'] - attribute_display_order = ['group_id'] + display_attributes = ["group_id"] + attribute_display_order = ["group_id"] @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - manager = self.app.client.managers['ServiceRegistryGroups'] + manager = self.app.client.managers["ServiceRegistryGroups"] groups = manager.list() return groups class ServiceRegistryListMembersCommand(resource.ResourceListCommand): - display_attributes = ['group_id', 'member_id', 'capabilities'] - attribute_display_order = ['group_id', 'member_id', 'capabilities'] + display_attributes = ["group_id", "member_id", "capabilities"] + attribute_display_order = ["group_id", "member_id", "capabilities"] def __init__(self, resource, *args, **kwargs): super(ServiceRegistryListMembersCommand, self).__init__( resource, *args, **kwargs ) - self.parser.add_argument('--group-id', dest='group_id', default=None, - help='If provided only retrieve members for the specified group.') + self.parser.add_argument( + "--group-id", + dest="group_id", + default=None, + help="If provided only retrieve members for the specified group.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - groups_manager = self.app.client.managers['ServiceRegistryGroups'] - members_manager = self.app.client.managers['ServiceRegistryMembers'] + groups_manager = self.app.client.managers["ServiceRegistryGroups"] + members_manager = self.app.client.managers["ServiceRegistryMembers"] # If group ID is provided only retrieve members for that group, otherwise retrieve members # for all groups diff --git a/st2client/st2client/commands/timer.py b/st2client/st2client/commands/timer.py index e3fc9e223f..c183367291 100644 --- a/st2client/st2client/commands/timer.py +++ b/st2client/st2client/commands/timer.py @@ -22,30 +22,39 @@ class TimerBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TimerBranch, self).__init__( - Timer, description, app, subparsers, + Timer, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': TimerListCommand, - 'get': TimerGetCommand - }) + commands={"list": TimerListCommand, "get": TimerGetCommand}, + ) class TimerListCommand(resource.ResourceListCommand): - display_attributes = ['id', 'uid', 'pack', 'name', 'type', 'parameters'] + display_attributes = ["id", "uid", "pack", "name", "type", "parameters"] def __init__(self, resource, *args, **kwargs): super(TimerListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-ty', '--timer-type', type=str, dest='timer_type', - help=("List %s type, example: 'core.st2.IntervalTimer', \ - 'core.st2.DateTimer', 'core.st2.CronTimer'." % - resource.get_plural_display_name().lower()), required=False) + self.parser.add_argument( + "-ty", + "--timer-type", + type=str, + dest="timer_type", + help=( + "List %s type, example: 'core.st2.IntervalTimer', \ + 'core.st2.DateTimer', 'core.st2.CronTimer'." + % resource.get_plural_display_name().lower() + ), + required=False, + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if args.timer_type: - kwargs['timer_type'] = args.timer_type + kwargs["timer_type"] = args.timer_type if kwargs: return self.manager.query(**kwargs) @@ -54,5 +63,5 @@ def run(self, args, **kwargs): class TimerGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['type', 'pack', 'name', 'description', 'parameters'] + display_attributes = ["all"] + attribute_display_order = ["type", "pack", "name", "description", "parameters"] diff --git a/st2client/st2client/commands/trace.py b/st2client/st2client/commands/trace.py index b5e59c2cf1..ac8de676c2 100644 --- a/st2client/st2client/commands/trace.py +++ b/st2client/st2client/commands/trace.py @@ -23,53 +23,62 @@ from st2client.utils.date import format_isodate_for_user_timezone -TRACE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'trace_tag', 'action_executions', 'rules', - 'trigger_instances', 'start_timestamp'] +TRACE_ATTRIBUTE_DISPLAY_ORDER = [ + "id", + "trace_tag", + "action_executions", + "rules", + "trigger_instances", + "start_timestamp", +] -TRACE_HEADER_DISPLAY_ORDER = ['id', 'trace_tag', 'start_timestamp'] +TRACE_HEADER_DISPLAY_ORDER = ["id", "trace_tag", "start_timestamp"] -TRACE_COMPONENT_DISPLAY_LABELS = ['id', 'type', 'ref', 'updated_at'] +TRACE_COMPONENT_DISPLAY_LABELS = ["id", "type", "ref", "updated_at"] -TRACE_DISPLAY_ATTRIBUTES = ['all'] +TRACE_DISPLAY_ATTRIBUTES = ["all"] TRIGGER_INSTANCE_DISPLAY_OPTIONS = [ - 'all', - 'trigger-instances', - 'trigger_instances', - 'triggerinstances', - 'triggers' + "all", + "trigger-instances", + "trigger_instances", + "triggerinstances", + "triggers", ] ACTION_EXECUTION_DISPLAY_OPTIONS = [ - 'all', - 'executions', - 'action-executions', - 'action_executions', - 'actionexecutions', - 'actions' + "all", + "executions", + "action-executions", + "action_executions", + "actionexecutions", + "actions", ] class TraceBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TraceBranch, self).__init__( - Trace, description, app, subparsers, + Trace, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': TraceListCommand, - 'get': TraceGetCommand - }) + commands={"list": TraceListCommand, "get": TraceGetCommand}, + ) class SingleTraceDisplayMixin(object): - def print_trace_details(self, trace, args, **kwargs): - options = {'attributes': TRACE_ATTRIBUTE_DISPLAY_ORDER if args.json else - TRACE_HEADER_DISPLAY_ORDER} - options['json'] = args.json - options['yaml'] = args.yaml - options['attribute_transform_functions'] = self.attribute_transform_functions + options = { + "attributes": TRACE_ATTRIBUTE_DISPLAY_ORDER + if args.json + else TRACE_HEADER_DISPLAY_ORDER + } + options["json"] = args.json + options["yaml"] = args.yaml + options["attribute_transform_functions"] = self.attribute_transform_functions formatter = execution_formatter.ExecutionResult @@ -81,35 +90,63 @@ def print_trace_details(self, trace, args, **kwargs): components = [] if any(attr in args.attr for attr in TRIGGER_INSTANCE_DISPLAY_OPTIONS): - components.extend([Resource(**{'id': trigger_instance['object_id'], - 'type': TriggerInstance._alias.lower(), - 'ref': trigger_instance['ref'], - 'updated_at': trigger_instance['updated_at']}) - for trigger_instance in trace.trigger_instances]) - if any(attr in args.attr for attr in ['all', 'rules']): - components.extend([Resource(**{'id': rule['object_id'], - 'type': Rule._alias.lower(), - 'ref': rule['ref'], - 'updated_at': rule['updated_at']}) - for rule in trace.rules]) + components.extend( + [ + Resource( + **{ + "id": trigger_instance["object_id"], + "type": TriggerInstance._alias.lower(), + "ref": trigger_instance["ref"], + "updated_at": trigger_instance["updated_at"], + } + ) + for trigger_instance in trace.trigger_instances + ] + ) + if any(attr in args.attr for attr in ["all", "rules"]): + components.extend( + [ + Resource( + **{ + "id": rule["object_id"], + "type": Rule._alias.lower(), + "ref": rule["ref"], + "updated_at": rule["updated_at"], + } + ) + for rule in trace.rules + ] + ) if any(attr in args.attr for attr in ACTION_EXECUTION_DISPLAY_OPTIONS): - components.extend([Resource(**{'id': execution['object_id'], - 'type': Execution._alias.lower(), - 'ref': execution['ref'], - 'updated_at': execution['updated_at']}) - for execution in trace.action_executions]) + components.extend( + [ + Resource( + **{ + "id": execution["object_id"], + "type": Execution._alias.lower(), + "ref": execution["ref"], + "updated_at": execution["updated_at"], + } + ) + for execution in trace.action_executions + ] + ) if components: components.sort(key=lambda resource: resource.updated_at) - self.print_output(components, table.MultiColumnTable, - attributes=TRACE_COMPONENT_DISPLAY_LABELS, - json=args.json, yaml=args.yaml) + self.print_output( + components, + table.MultiColumnTable, + attributes=TRACE_COMPONENT_DISPLAY_LABELS, + json=args.json, + yaml=args.yaml, + ) class TraceListCommand(resource.ResourceCommand, SingleTraceDisplayMixin): - display_attributes = ['id', 'uid', 'trace_tag', 'start_timestamp'] + display_attributes = ["id", "uid", "trace_tag", "start_timestamp"] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone + "start_timestamp": format_isodate_for_user_timezone } attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER @@ -119,55 +156,90 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(TraceListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_mutually_exclusive_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('-s', '--sort', type=str, dest='sort_order', - default='descending', - help=('Sort %s by start timestamp, ' - 'asc|ascending (earliest first) ' - 'or desc|descending (latest first)' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "-s", + "--sort", + type=str, + dest="sort_order", + default="descending", + help=( + "Sort %s by start timestamp, " + "asc|ascending (earliest first) " + "or desc|descending (latest first)" % self.resource_name + ), + ) # Filter options - self.group.add_argument('-c', '--trace-tag', help='Trace-tag to filter the list.') - self.group.add_argument('-e', '--execution', help='Execution to filter the list.') - self.group.add_argument('-r', '--rule', help='Rule to filter the list.') - self.group.add_argument('-g', '--trigger-instance', - help='TriggerInstance to filter the list.') + self.group.add_argument( + "-c", "--trace-tag", help="Trace-tag to filter the list." + ) + self.group.add_argument( + "-e", "--execution", help="Execution to filter the list." + ) + self.group.add_argument("-r", "--rule", help="Rule to filter the list.") + self.group.add_argument( + "-g", "--trigger-instance", help="TriggerInstance to filter the list." + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trace_tag: - kwargs['trace_tag'] = args.trace_tag + kwargs["trace_tag"] = args.trace_tag if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if args.execution: - kwargs['execution'] = args.execution + kwargs["execution"] = args.execution if args.rule: - kwargs['rule'] = args.rule + kwargs["rule"] = args.rule if args.sort_order: - if args.sort_order in ['asc', 'ascending']: - kwargs['sort_asc'] = True - elif args.sort_order in ['desc', 'descending']: - kwargs['sort_desc'] = True + if args.sort_order in ["asc", "ascending"]: + kwargs["sort_asc"] = True + elif args.sort_order in ["desc", "descending"]: + kwargs["sort_desc"] = True return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): @@ -177,7 +249,7 @@ def run_and_print(self, args, **kwargs): # For a single Trace we must include the components unless # user has overriden the attributes to display if args.attr == self.display_attributes: - args.attr = ['all'] + args.attr = ["all"] self.print_trace_details(trace=instances[0], args=args) if not args.json and not args.yaml: @@ -185,27 +257,36 @@ def run_and_print(self, args, **kwargs): table.SingleRowTable.note_box(self.resource_name, 1) else: if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class TraceGetCommand(resource.ResourceGetCommand, SingleTraceDisplayMixin): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone + "start_timestamp": format_isodate_for_user_timezone } - pk_argument_name = 'id' + pk_argument_name = "id" def __init__(self, resource, *args, **kwargs): super(TraceGetCommand, self).__init__(resource, *args, **kwargs) @@ -213,23 +294,36 @@ def __init__(self, resource, *args, **kwargs): # Causation chains self.causation_group = self.parser.add_mutually_exclusive_group() - self.causation_group.add_argument('-e', '--execution', - help='Execution to show causation chain.') - self.causation_group.add_argument('-r', '--rule', help='Rule to show causation chain.') - self.causation_group.add_argument('-g', '--trigger-instance', - help='TriggerInstance to show causation chain.') + self.causation_group.add_argument( + "-e", "--execution", help="Execution to show causation chain." + ) + self.causation_group.add_argument( + "-r", "--rule", help="Rule to show causation chain." + ) + self.causation_group.add_argument( + "-g", "--trigger-instance", help="TriggerInstance to show causation chain." + ) # display filter group self.display_filter_group = self.parser.add_argument_group() - self.display_filter_group.add_argument('--show-executions', action='store_true', - help='Only show executions.') - self.display_filter_group.add_argument('--show-rules', action='store_true', - help='Only show rules.') - self.display_filter_group.add_argument('--show-trigger-instances', action='store_true', - help='Only show trigger instances.') - self.display_filter_group.add_argument('-n', '--hide-noop-triggers', action='store_true', - help='Hide noop trigger instances.') + self.display_filter_group.add_argument( + "--show-executions", action="store_true", help="Only show executions." + ) + self.display_filter_group.add_argument( + "--show-rules", action="store_true", help="Only show rules." + ) + self.display_filter_group.add_argument( + "--show-trigger-instances", + action="store_true", + help="Only show trigger instances.", + ) + self.display_filter_group.add_argument( + "-n", + "--hide-noop-triggers", + action="store_true", + help="Hide noop trigger instances.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -243,7 +337,7 @@ def run_and_print(self, args, **kwargs): trace = self.run(args, **kwargs) except resource.ResourceNotFoundError: self.print_not_found(args.id) - raise OperationFailureException('Trace %s not found.' % (args.id)) + raise OperationFailureException("Trace %s not found." % (args.id)) # First filter for causation chains trace = self._filter_trace_components(trace=trace, args=args) # next filter for display purposes @@ -266,13 +360,13 @@ def _filter_trace_components(trace, args): # pick the right component type if args.execution: component_id = args.execution - component_type = 'action_execution' + component_type = "action_execution" elif args.rule: component_id = args.rule - component_type = 'rule' + component_type = "rule" elif args.trigger_instance: component_id = args.trigger_instance - component_type = 'trigger_instance' + component_type = "trigger_instance" # Initialize collection to use action_executions = [] @@ -284,13 +378,13 @@ def _filter_trace_components(trace, args): while search_target_found: components_list = [] - if component_type == 'action_execution': + if component_type == "action_execution": components_list = trace.action_executions to_update_list = action_executions - elif component_type == 'rule': + elif component_type == "rule": components_list = trace.rules to_update_list = rules - elif component_type == 'trigger_instance': + elif component_type == "trigger_instance": components_list = trace.trigger_instances to_update_list = trigger_instances # Look for search_target in the right collection and @@ -300,22 +394,25 @@ def _filter_trace_components(trace, args): # init to default value component_caused_by_id = None for component in components_list: - test_id = component['object_id'] + test_id = component["object_id"] if test_id == component_id: - caused_by = component.get('caused_by', {}) - component_id = caused_by.get('id', None) - component_type = caused_by.get('type', None) + caused_by = component.get("caused_by", {}) + component_id = caused_by.get("id", None) + component_type = caused_by.get("type", None) # If provided the component_caused_by_id must match as well. This is mostly # applicable for rules since the same rule may appear multiple times and can # only be distinguished by causing TriggerInstance. - if component_caused_by_id and component_caused_by_id != component_id: + if ( + component_caused_by_id + and component_caused_by_id != component_id + ): continue component_caused_by_id = None to_update_list.append(component) # In some cases the component_id and the causing component are combined to # provide the complete causation chain. Think rule + triggerinstance - if component_id and ':' in component_id: - component_id_split = component_id.split(':') + if component_id and ":" in component_id: + component_id_split = component_id.split(":") component_id = component_id_split[0] component_caused_by_id = component_id_split[1] search_target_found = True @@ -333,19 +430,21 @@ def _apply_display_filters(trace, args): should be displayed. """ # If all the filters are false nothing is to be filtered. - all_component_types = not(args.show_executions or - args.show_rules or - args.show_trigger_instances) + all_component_types = not ( + args.show_executions or args.show_rules or args.show_trigger_instances + ) # check if noop_triggers are to be hidden. This check applies whenever TriggerInstances # are to be shown. - if (all_component_types or args.show_trigger_instances) and args.hide_noop_triggers: + if ( + all_component_types or args.show_trigger_instances + ) and args.hide_noop_triggers: filtered_trigger_instances = [] for trigger_instance in trace.trigger_instances: is_noop_trigger_instance = True for rule in trace.rules: - caused_by_id = rule.get('caused_by', {}).get('id', None) - if caused_by_id == trigger_instance['object_id']: + caused_by_id = rule.get("caused_by", {}).get("id", None) + if caused_by_id == trigger_instance["object_id"]: is_noop_trigger_instance = False if not is_noop_trigger_instance: filtered_trigger_instances.append(trigger_instance) diff --git a/st2client/st2client/commands/trigger.py b/st2client/st2client/commands/trigger.py index 2fd966261c..3a960fddc8 100644 --- a/st2client/st2client/commands/trigger.py +++ b/st2client/st2client/commands/trigger.py @@ -23,29 +23,40 @@ class TriggerTypeBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TriggerTypeBranch, self).__init__( - TriggerType, description, app, subparsers, + TriggerType, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': TriggerTypeListCommand, - 'get': TriggerTypeGetCommand, - 'update': TriggerTypeUpdateCommand, - 'delete': TriggerTypeDeleteCommand - }) + "list": TriggerTypeListCommand, + "get": TriggerTypeGetCommand, + "update": TriggerTypeUpdateCommand, + "delete": TriggerTypeDeleteCommand, + }, + ) # Registers extended commands - self.commands['getspecs'] = TriggerTypeSubTriggerCommand( - self.resource, self.app, self.subparsers, - add_help=False) + self.commands["getspecs"] = TriggerTypeSubTriggerCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class TriggerTypeListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description'] + display_attributes = ["ref", "pack", "description"] class TriggerTypeGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'parameters_schema', 'payload_schema'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "parameters_schema", + "payload_schema", + ] class TriggerTypeUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -57,29 +68,45 @@ class TriggerTypeDeleteCommand(resource.ContentPackResourceDeleteCommand): class TriggerTypeSubTriggerCommand(resource.ResourceCommand): - attribute_display_order = ['id', 'ref', 'context', 'parameters', 'status', - 'start_timestamp', 'result'] + attribute_display_order = [ + "id", + "ref", + "context", + "parameters", + "status", + "start_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(TriggerTypeSubTriggerCommand, self).__init__( - resource, kwargs.pop('name', 'getspecs'), - 'Return Trigger Specifications of a Trigger.', - *args, **kwargs) - - self.parser.add_argument('ref', nargs='?', - metavar='ref', - help='Fully qualified name (pack.trigger_name) ' + - 'of the trigger.') - - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "getspecs"), + "Return Trigger Specifications of a Trigger.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "ref", + nargs="?", + metavar="ref", + help="Fully qualified name (pack.trigger_name) " + "of the trigger.", + ) + + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - trigger_mgr = self.app.client.managers['Trigger'] - return trigger_mgr.query(**{'type': args.ref}) + trigger_mgr = self.app.client.managers["Trigger"] + return trigger_mgr.query(**{"type": args.ref}) @resource.add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): @@ -87,5 +114,6 @@ def run_and_print(self, args, **kwargs): self.parser.print_help() return instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - json=args.json, yaml=args.yaml) + self.print_output( + instances, table.MultiColumnTable, json=args.json, yaml=args.yaml + ) diff --git a/st2client/st2client/commands/triggerinstance.py b/st2client/st2client/commands/triggerinstance.py index 2ac4da73da..12966fea92 100644 --- a/st2client/st2client/commands/triggerinstance.py +++ b/st2client/st2client/commands/triggerinstance.py @@ -25,17 +25,23 @@ class TriggerInstanceResendCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(TriggerInstanceResendCommand, self).__init__( - resource, kwargs.pop('name', 're-emit'), - 'Re-emit a particular trigger instance.', - *args, **kwargs) + resource, + kwargs.pop("name", "re-emit"), + "Re-emit a particular trigger instance.", + *args, + **kwargs, + ) - self.parser.add_argument('id', nargs='?', - metavar='id', - help='ID of trigger instance to re-emit.') self.parser.add_argument( - '-h', '--help', - action='store_true', dest='help', - help='Print usage for the given command.') + "id", nargs="?", metavar="id", help="ID of trigger instance to re-emit." + ) + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given command.", + ) def run(self, args, **kwargs): return self.manager.re_emit(args.id) @@ -43,29 +49,35 @@ def run(self, args, **kwargs): @resource.add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): ret = self.run(args, **kwargs) - if 'message' in ret: - print(ret['message']) + if "message" in ret: + print(ret["message"]) class TriggerInstanceBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TriggerInstanceBranch, self).__init__( - TriggerInstance, description, app, subparsers, - parent_parser=parent_parser, read_only=True, + TriggerInstance, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, commands={ - 'list': TriggerInstanceListCommand, - 'get': TriggerInstanceGetCommand - }) + "list": TriggerInstanceListCommand, + "get": TriggerInstanceGetCommand, + }, + ) - self.commands['re-emit'] = TriggerInstanceResendCommand(self.resource, self.app, - self.subparsers, add_help=False) + self.commands["re-emit"] = TriggerInstanceResendCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class TriggerInstanceListCommand(resource.ResourceViewCommand): - display_attributes = ['id', 'trigger', 'occurrence_time', 'status'] + display_attributes = ["id", "trigger", "occurrence_time", "status"] attribute_transform_functions = { - 'occurrence_time': format_isodate_for_user_timezone + "occurrence_time": format_isodate_for_user_timezone } def __init__(self, resource, *args, **kwargs): @@ -73,83 +85,133 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(TriggerInstanceListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Filter options - self.group.add_argument('--trigger', help='Trigger reference to filter the list.') - - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return trigger instances with occurrence_time ' - 'greater than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return trigger instances with timestamp ' - 'lower than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - - self.group.add_argument('--status', - help='Can be pending, processing, processed or processing_failed.') + self.group.add_argument( + "--trigger", help="Trigger reference to filter the list." + ) + + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return trigger instances with occurrence_time " + "greater than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return trigger instances with timestamp " + "lower than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + + self.group.add_argument( + "--status", + help="Can be pending, processing, processed or processing_failed.", + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trigger: - kwargs['trigger'] = args.trigger + kwargs["trigger"] = args.trigger if args.timestamp_gt: - kwargs['timestamp_gt'] = args.timestamp_gt + kwargs["timestamp_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['timestamp_lt'] = args.timestamp_lt + kwargs["timestamp_lt"] = args.timestamp_lt if args.status: - kwargs['status'] = args.status + kwargs["status"] = args.status include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class TriggerInstanceGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'trigger', 'occurrence_time', 'payload'] + display_attributes = ["all"] + attribute_display_order = ["id", "trigger", "occurrence_time", "payload"] - pk_argument_name = 'id' + pk_argument_name = "id" @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): diff --git a/st2client/st2client/commands/webhook.py b/st2client/st2client/commands/webhook.py index 3a48344500..4b555ac59f 100644 --- a/st2client/st2client/commands/webhook.py +++ b/st2client/st2client/commands/webhook.py @@ -23,37 +23,47 @@ class WebhookBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(WebhookBranch, self).__init__( - Webhook, description, app, subparsers, + Webhook, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': WebhookListCommand, - 'get': WebhookGetCommand - }) + commands={"list": WebhookListCommand, "get": WebhookGetCommand}, + ) class WebhookListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['url', 'type', 'description'] + display_attributes = ["url", "type", "description"] def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) for instance in instances: - instance.url = instance.parameters['url'] + instance.url = instance.parameters["url"] instances = sorted(instances, key=lambda k: k.url) if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) class WebhookGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['type', 'description'] + display_attributes = ["all"] + attribute_display_order = ["type", "description"] - pk_argument_name = 'url' + pk_argument_name = "url" diff --git a/st2client/st2client/commands/workflow.py b/st2client/st2client/commands/workflow.py index 5348f76706..57f9f52a5f 100644 --- a/st2client/st2client/commands/workflow.py +++ b/st2client/st2client/commands/workflow.py @@ -1,4 +1,3 @@ - # Copyright 2020 The StackStorm Authors. # Copyright 2019 Extreme Networks, Inc. # @@ -28,26 +27,25 @@ class WorkflowBranch(commands.Branch): - def __init__(self, description, app, subparsers, parent_parser=None): super(WorkflowBranch, self).__init__( - 'workflow', description, app, subparsers, - parent_parser=parent_parser + "workflow", description, app, subparsers, parent_parser=parent_parser ) # Add subparser to register subcommands for managing workflows. - help_message = 'List of commands for managing workflows.' + help_message = "List of commands for managing workflows." self.subparsers = self.parser.add_subparsers(help=help_message) # Register workflow commands. - self.commands['inspect'] = WorkflowInspectionCommand(self.app, self.subparsers) + self.commands["inspect"] = WorkflowInspectionCommand(self.app, self.subparsers) class WorkflowInspectionCommand(commands.Command): - def __init__(self, *args, **kwargs): - name = 'inspect' - description = 'Inspect workflow definition and return the list of errors if any.' + name = "inspect" + description = ( + "Inspect workflow definition and return the list of errors if any." + ) args = tuple([name, description] + list(args)) super(WorkflowInspectionCommand, self).__init__(*args, **kwargs) @@ -55,27 +53,25 @@ def __init__(self, *args, **kwargs): arg_group = self.parser.add_mutually_exclusive_group() arg_group.add_argument( - '--file', - dest='file', - help='Local file path to the workflow definition.' + "--file", dest="file", help="Local file path to the workflow definition." ) arg_group.add_argument( - '--action', - dest='action', - help='Reference name for the registered action. This option works only if the file ' - 'referenced by the entry point is installed locally under /opt/stackstorm/packs.' + "--action", + dest="action", + help="Reference name for the registered action. This option works only if the file " + "referenced by the entry point is installed locally under /opt/stackstorm/packs.", ) @property def manager(self): - return self.app.client.managers['Workflow'] + return self.app.client.managers["Workflow"] def get_file_content(self, file_path): if not os.path.isfile(file_path): raise Exception('File "%s" does not exist on local system.' % file_path) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: content = f.read() return content @@ -88,13 +84,18 @@ def run(self, args, **kwargs): # is executed locally where the content is stored. if not wf_def_file: action_ref = args.action - action_manager = self.app.client.managers['Action'] + action_manager = self.app.client.managers["Action"] action = action_manager.get_by_ref_or_id(ref_or_id=action_ref) if not action: raise Exception('Unable to identify action "%s".' % action_ref) - wf_def_file = '/opt/stackstorm/packs/' + action.pack + '/actions/' + action.entry_point + wf_def_file = ( + "/opt/stackstorm/packs/" + + action.pack + + "/actions/" + + action.entry_point + ) wf_def = self.get_file_content(wf_def_file) @@ -105,10 +106,10 @@ def run_and_print(self, args, **kwargs): errors = self.run(args, **kwargs) if not isinstance(errors, list): - raise TypeError('The inspection result is not type of list: %s' % errors) + raise TypeError("The inspection result is not type of list: %s" % errors) if not errors: - print('No errors found in workflow definition.') + print("No errors found in workflow definition.") return print(yaml.safe_dump(errors, default_flow_style=False, allow_unicode=True)) diff --git a/st2client/st2client/config.py b/st2client/st2client/config.py index 5de500aec2..c002c7f414 100644 --- a/st2client/st2client/config.py +++ b/st2client/st2client/config.py @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'get_config', - 'set_config' -] +__all__ = ["get_config", "set_config"] # Stores parsed config dictionary CONFIG = {} diff --git a/st2client/st2client/config_parser.py b/st2client/st2client/config_parser.py index e5095df3e2..ca88209f87 100644 --- a/st2client/st2client/config_parser.py +++ b/st2client/st2client/config_parser.py @@ -31,88 +31,38 @@ __all__ = [ - 'CLIConfigParser', - - 'ST2_CONFIG_DIRECTORY', - 'ST2_CONFIG_PATH', - - 'CONFIG_DEFAULT_VALUES' + "CLIConfigParser", + "ST2_CONFIG_DIRECTORY", + "ST2_CONFIG_PATH", + "CONFIG_DEFAULT_VALUES", ] -ST2_CONFIG_DIRECTORY = '~/.st2' +ST2_CONFIG_DIRECTORY = "~/.st2" ST2_CONFIG_DIRECTORY = os.path.abspath(os.path.expanduser(ST2_CONFIG_DIRECTORY)) -ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, 'config')) +ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, "config")) CONFIG_FILE_OPTIONS = { - 'general': { - 'base_url': { - 'type': 'string', - 'default': None - }, - 'api_version': { - 'type': 'string', - 'default': None - }, - 'cacert': { - 'type': 'string', - 'default': None - }, - 'silence_ssl_warnings': { - 'type': 'bool', - 'default': False - }, - 'silence_schema_output': { - 'type': 'bool', - 'default': True - } - }, - 'cli': { - 'debug': { - 'type': 'bool', - 'default': False - }, - 'cache_token': { - 'type': 'boolean', - 'default': True - }, - 'timezone': { - 'type': 'string', - 'default': 'UTC' - } - }, - 'credentials': { - 'username': { - 'type': 'string', - 'default': None - }, - 'password': { - 'type': 'string', - 'default': None - }, - 'api_key': { - 'type': 'string', - 'default': None - } + "general": { + "base_url": {"type": "string", "default": None}, + "api_version": {"type": "string", "default": None}, + "cacert": {"type": "string", "default": None}, + "silence_ssl_warnings": {"type": "bool", "default": False}, + "silence_schema_output": {"type": "bool", "default": True}, }, - 'api': { - 'url': { - 'type': 'string', - 'default': None - } + "cli": { + "debug": {"type": "bool", "default": False}, + "cache_token": {"type": "boolean", "default": True}, + "timezone": {"type": "string", "default": "UTC"}, }, - 'auth': { - 'url': { - 'type': 'string', - 'default': None - } + "credentials": { + "username": {"type": "string", "default": None}, + "password": {"type": "string", "default": None}, + "api_key": {"type": "string", "default": None}, }, - 'stream': { - 'url': { - 'type': 'string', - 'default': None - } - } + "api": {"url": {"type": "string", "default": None}}, + "auth": {"url": {"type": "string", "default": None}}, + "stream": {"url": {"type": "string", "default": None}}, } CONFIG_DEFAULT_VALUES = {} @@ -121,13 +71,18 @@ CONFIG_DEFAULT_VALUES[section] = {} for key, options in six.iteritems(keys): - default_value = options['default'] + default_value = options["default"] CONFIG_DEFAULT_VALUES[section][key] = default_value class CLIConfigParser(object): - def __init__(self, config_file_path, validate_config_exists=True, - validate_config_permissions=True, log=None): + def __init__( + self, + config_file_path, + validate_config_exists=True, + validate_config_permissions=True, + log=None, + ): if validate_config_exists and not os.path.isfile(config_file_path): raise ValueError('Config file "%s" doesn\'t exist') @@ -158,37 +113,40 @@ def parse(self): if bool(os.stat(config_dir_path).st_mode & 0o7): self.LOG.warn( "The StackStorm configuration directory permissions are " - "insecure (too permissive): others have access.") + "insecure (too permissive): others have access." + ) # Make sure the setgid bit is set on the directory if not bool(os.stat(config_dir_path).st_mode & 0o2000): self.LOG.info( "The SGID bit is not set on the StackStorm configuration " - "directory.") + "directory." + ) # Make sure the file permissions == 0o660 if bool(os.stat(self.config_file_path).st_mode & 0o7): self.LOG.warn( "The StackStorm configuration file permissions are " - "insecure: others have access.") + "insecure: others have access." + ) config = ConfigParser() - with io.open(self.config_file_path, 'r', encoding='utf8') as fp: + with io.open(self.config_file_path, "r", encoding="utf8") as fp: config.readfp(fp) for section, keys in six.iteritems(CONFIG_FILE_OPTIONS): for key, options in six.iteritems(keys): - key_type = options['type'] - key_default_value = options['default'] + key_type = options["type"] + key_default_value = options["default"] if config.has_option(section, key): - if key_type in ['str', 'string']: + if key_type in ["str", "string"]: get_func = config.get - elif key_type in ['int', 'integer']: + elif key_type in ["int", "integer"]: get_func = config.getint - elif key_type in ['float']: + elif key_type in ["float"]: get_func = config.getfloat - elif key_type in ['bool', 'boolean']: + elif key_type in ["bool", "boolean"]: get_func = config.getboolean else: msg = 'Invalid type "%s" for option "%s"' % (key_type, key) diff --git a/st2client/st2client/exceptions/base.py b/st2client/st2client/exceptions/base.py index f9cd343665..97c9bb8a09 100644 --- a/st2client/st2client/exceptions/base.py +++ b/st2client/st2client/exceptions/base.py @@ -16,7 +16,8 @@ class StackStormCLIBaseException(Exception): """ - The root of the exception class hierarchy for all - StackStorm CLI exceptions. + The root of the exception class hierarchy for all + StackStorm CLI exceptions. """ + pass diff --git a/st2client/st2client/formatters/__init__.py b/st2client/st2client/formatters/__init__.py index dcaaee3ee1..e0d8e5f718 100644 --- a/st2client/st2client/formatters/__init__.py +++ b/st2client/st2client/formatters/__init__.py @@ -25,10 +25,8 @@ class Formatter(six.with_metaclass(abc.ABCMeta, object)): - @classmethod @abc.abstractmethod def format(cls, subject, *args, **kwargs): - """Override this method to customize output format for the subject. - """ + """Override this method to customize output format for the subject.""" raise NotImplementedError diff --git a/st2client/st2client/formatters/doc.py b/st2client/st2client/formatters/doc.py index ea2218dec4..5f6ca96dce 100644 --- a/st2client/st2client/formatters/doc.py +++ b/st2client/st2client/formatters/doc.py @@ -23,10 +23,7 @@ from st2client import formatters from st2client.utils import jsutil -__all__ = [ - 'JsonFormatter', - 'YAMLFormatter' -] +__all__ = ["JsonFormatter", "YAMLFormatter"] LOG = logging.getLogger(__name__) @@ -34,25 +31,34 @@ class BaseFormatter(formatters.Formatter): @classmethod def format(self, subject, *args, **kwargs): - attributes = kwargs.get('attributes', None) + attributes = kwargs.get("attributes", None) if type(subject) is str: subject = json.loads(subject) - elif not isinstance(subject, (list, tuple)) and not hasattr(subject, '__iter__'): + elif not isinstance(subject, (list, tuple)) and not hasattr( + subject, "__iter__" + ): doc = subject if isinstance(subject, dict) else subject.__dict__ - keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes + keys = ( + list(doc.keys()) + if not attributes or "all" in attributes + else attributes + ) docs = jsutil.get_kvps(doc, keys) else: docs = [] for item in subject: doc = item if isinstance(item, dict) else item.__dict__ - keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes + keys = ( + list(doc.keys()) + if not attributes or "all" in attributes + else attributes + ) docs.append(jsutil.get_kvps(doc, keys)) return docs class JsonFormatter(BaseFormatter): - @classmethod def format(self, subject, *args, **kwargs): docs = BaseFormatter.format(subject, *args, **kwargs) @@ -60,7 +66,6 @@ def format(self, subject, *args, **kwargs): class YAMLFormatter(BaseFormatter): - @classmethod def format(self, subject, *args, **kwargs): docs = BaseFormatter.format(subject, *args, **kwargs) diff --git a/st2client/st2client/formatters/execution.py b/st2client/st2client/formatters/execution.py index 69da8cdb41..b52527d4de 100644 --- a/st2client/st2client/formatters/execution.py +++ b/st2client/st2client/formatters/execution.py @@ -32,32 +32,31 @@ LOG = logging.getLogger(__name__) -PLATFORM_MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 +PLATFORM_MAXINT = 2 ** (struct.Struct("i").size * 8 - 1) - 1 def _print_bordered(text): - lines = text.split('\n') + lines = text.split("\n") width = max(len(s) for s in lines) + 2 - res = ['\n+' + '-' * width + '+'] + res = ["\n+" + "-" * width + "+"] for s in lines: - res.append('| ' + (s + ' ' * width)[:width - 2] + ' |') - res.append('+' + '-' * width + '+') - return '\n'.join(res) + res.append("| " + (s + " " * width)[: width - 2] + " |") + res.append("+" + "-" * width + "+") + return "\n".join(res) class ExecutionResult(formatters.Formatter): - @classmethod def format(cls, entry, *args, **kwargs): - attrs = kwargs.get('attributes', []) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) - key = kwargs.get('key', None) + attrs = kwargs.get("attributes", []) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) + key = kwargs.get("key", None) if key: output = jsutil.get_value(entry.result, key) else: # drop entry to the dict so that jsutil can operate entry = vars(entry) - output = '' + output = "" for attr in attrs: value = jsutil.get_value(entry, attr) value = strutil.strip_carriage_returns(strutil.unescape(value)) @@ -65,8 +64,12 @@ def format(cls, entry, *args, **kwargs): # if the leading character is objectish start and last character is objectish # end but the string isn't supposed to be a object. Try/Except will catch # this for now, but this should be improved. - if (isinstance(value, six.string_types) and len(value) > 0 and - value[0] in ['{', '['] and value[len(value) - 1] in ['}', ']']): + if ( + isinstance(value, six.string_types) + and len(value) > 0 + and value[0] in ["{", "["] + and value[len(value) - 1] in ["}", "]"] + ): try: new_value = ast.literal_eval(value) except: @@ -79,31 +82,40 @@ def format(cls, entry, *args, **kwargs): # 2. Drop the trailing newline # 3. Set width to maxint so pyyaml does not split text. Anything longer # and likely we will see other issues like storage :P. - formatted_value = yaml.safe_dump({attr: value}, - default_flow_style=False, - width=PLATFORM_MAXINT, - indent=2)[len(attr) + 2:-1] - value = ('\n' if isinstance(value, dict) else '') + formatted_value + formatted_value = yaml.safe_dump( + {attr: value}, + default_flow_style=False, + width=PLATFORM_MAXINT, + indent=2, + )[len(attr) + 2 : -1] + value = ("\n" if isinstance(value, dict) else "") + formatted_value value = strutil.dedupe_newlines(value) # transform the value of our attribute so things like 'status' # and 'timestamp' are formatted nicely - transform_function = attribute_transform_functions.get(attr, - lambda value: value) + transform_function = attribute_transform_functions.get( + attr, lambda value: value + ) value = transform_function(value=value) - output += ('\n' if output else '') + '%s: %s' % \ - (DisplayColors.colorize(attr, DisplayColors.BLUE), value) + output += ("\n" if output else "") + "%s: %s" % ( + DisplayColors.colorize(attr, DisplayColors.BLUE), + value, + ) - output_schema = entry.get('action', {}).get('output_schema') - schema_check = get_config()['general']['silence_schema_output'] - if not output_schema and kwargs.get('with_schema'): + output_schema = entry.get("action", {}).get("output_schema") + schema_check = get_config()["general"]["silence_schema_output"] + if not output_schema and kwargs.get("with_schema"): rendered_schema = { - 'output_schema': schema.render_output_schema_from_output(entry['result']) + "output_schema": schema.render_output_schema_from_output( + entry["result"] + ) } - rendered_schema = yaml.safe_dump(rendered_schema, default_flow_style=False) - output += '\n' + rendered_schema = yaml.safe_dump( + rendered_schema, default_flow_style=False + ) + output += "\n" output += _print_bordered( "Based on the action output the following inferred schema was built:" "\n\n" @@ -120,7 +132,11 @@ def format(cls, entry, *args, **kwargs): else: # Assume Python 2 try: - result = strutil.unescape(str(output)).decode('unicode_escape').encode('utf-8') + result = ( + strutil.unescape(str(output)) + .decode("unicode_escape") + .encode("utf-8") + ) except UnicodeDecodeError: # String contains a value which is not an unicode escape sequence, ignore the error result = strutil.unescape(str(output)) diff --git a/st2client/st2client/formatters/table.py b/st2client/st2client/formatters/table.py index 404469ce0e..91cc59e009 100644 --- a/st2client/st2client/formatters/table.py +++ b/st2client/st2client/formatters/table.py @@ -40,40 +40,38 @@ MIN_COL_WIDTH = 5 # Default attribute display order to use if one is not provided -DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'pack', 'description'] +DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "pack", "description"] # Attributes which contain bash escape sequences - we can't split those across multiple lines # since things would break COLORIZED_ATTRIBUTES = { - 'status': { - 'col_width': 24 # Note: len('succeed' + ' (XXXX elapsed)') <= 24 - } + "status": {"col_width": 24} # Note: len('succeed' + ' (XXXX elapsed)') <= 24 } class MultiColumnTable(formatters.Formatter): - def __init__(self): self._table_width = 0 @classmethod def format(cls, entries, *args, **kwargs): - attributes = kwargs.get('attributes', []) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) - widths = kwargs.get('widths', []) + attributes = kwargs.get("attributes", []) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) + widths = kwargs.get("widths", []) widths = widths or [] if not widths and attributes: # Dynamically calculate column size based on the terminal size cols = get_terminal_size_columns() - if attributes[0] == 'id': + if attributes[0] == "id": # consume iterator and save as entries so collection is accessible later. entries = [e for e in entries] # first column contains id, make sure it's not broken up - first_col_width = cls._get_required_column_width(values=[e.id for e in entries], - minimum_width=MIN_ID_COL_WIDTH) - cols = (cols - first_col_width) + first_col_width = cls._get_required_column_width( + values=[e.id for e in entries], minimum_width=MIN_ID_COL_WIDTH + ) + cols = cols - first_col_width col_width = int(math.floor((cols / len(attributes)))) else: col_width = int(math.floor((cols / len(attributes)))) @@ -88,14 +86,16 @@ def format(cls, entries, *args, **kwargs): continue if attribute_name in COLORIZED_ATTRIBUTES: - current_col_width = COLORIZED_ATTRIBUTES[attribute_name]['col_width'] - subtract += (current_col_width - col_width) + current_col_width = COLORIZED_ATTRIBUTES[attribute_name][ + "col_width" + ] + subtract += current_col_width - col_width else: # Make sure we subtract the added width from the last column so we account # for the fixed width columns and make sure table is not wider than the # terminal width. if index == (len(attributes) - 1) and subtract: - current_col_width = (col_width - subtract) + current_col_width = col_width - subtract if current_col_width <= MIN_COL_WIDTH: # Make sure column width is always grater than MIN_COL_WIDTH @@ -105,12 +105,14 @@ def format(cls, entries, *args, **kwargs): widths.append(current_col_width) - if not attributes or 'all' in attributes: + if not attributes or "all" in attributes: entries = list(entries) if entries else [] if len(entries) >= 1: attributes = list(entries[0].__dict__.keys()) - attributes = sorted([attr for attr in attributes if not attr.startswith('_')]) + attributes = sorted( + [attr for attr in attributes if not attr.startswith("_")] + ) else: # There are no entries so we can't infer available attributes attributes = [] @@ -123,8 +125,7 @@ def format(cls, entries, *args, **kwargs): # If only 1 width value is provided then # apply it to all columns else fix at 28. width = widths[0] if len(widths) == 1 else 28 - columns = zip(attributes, - [width for i in range(0, len(attributes))]) + columns = zip(attributes, [width for i in range(0, len(attributes))]) # Format result to table. table = PrettyTable() @@ -132,14 +133,14 @@ def format(cls, entries, *args, **kwargs): table.field_names.append(column[0]) table.max_width[column[0]] = column[1] table.padding_width = 1 - table.align = 'l' - table.valign = 't' + table.align = "l" + table.valign = "t" for entry in entries: # TODO: Improve getting values of nested dict. values = [] for field_name in table.field_names: - if '.' in field_name: - field_names = field_name.split('.') + if "." in field_name: + field_names = field_name.split(".") value = getattr(entry, field_names.pop(0), {}) for name in field_names: value = cls._get_field_value(value, name) @@ -149,8 +150,9 @@ def format(cls, entries, *args, **kwargs): values.append(value) else: value = cls._get_simple_field_value(entry, field_name) - transform_function = attribute_transform_functions.get(field_name, - lambda value: value) + transform_function = attribute_transform_functions.get( + field_name, lambda value: value + ) value = transform_function(value=value) value = strutil.strip_carriage_returns(strutil.unescape(value)) values.append(value) @@ -177,14 +179,14 @@ def _get_simple_field_value(entry, field_name): """ Format a value for a simple field. """ - value = getattr(entry, field_name, '') + value = getattr(entry, field_name, "") if isinstance(value, (list, tuple)): if len(value) == 0: - value = '' + value = "" elif isinstance(value[0], (str, six.text_type)): # List contains simple string values, format it as comma # separated string - value = ', '.join(value) + value = ", ".join(value) return value @@ -192,10 +194,10 @@ def _get_simple_field_value(entry, field_name): def _get_field_value(value, field_name): r_val = value.get(field_name, None) if r_val is None: - return '' + return "" if isinstance(r_val, list) or isinstance(r_val, dict): - return r_val if len(r_val) > 0 else '' + return r_val if len(r_val) > 0 else "" return r_val @staticmethod @@ -203,7 +205,7 @@ def _get_friendly_column_name(name): if not name: return None - friendly_name = name.replace('_', ' ').replace('.', ' ').capitalize() + friendly_name = name.replace("_", " ").replace(".", " ").capitalize() return friendly_name @staticmethod @@ -213,33 +215,34 @@ def _get_required_column_width(values, minimum_width=0): class PropertyValueTable(formatters.Formatter): - @classmethod def format(cls, subject, *args, **kwargs): - attributes = kwargs.get('attributes', None) - attribute_display_order = kwargs.get('attribute_display_order', - DEFAULT_ATTRIBUTE_DISPLAY_ORDER) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) + attributes = kwargs.get("attributes", None) + attribute_display_order = kwargs.get( + "attribute_display_order", DEFAULT_ATTRIBUTE_DISPLAY_ORDER + ) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) - if not attributes or 'all' in attributes: - attributes = sorted([attr for attr in subject.__dict__ - if not attr.startswith('_')]) + if not attributes or "all" in attributes: + attributes = sorted( + [attr for attr in subject.__dict__ if not attr.startswith("_")] + ) for attr in attribute_display_order[::-1]: if attr in attributes: attributes.remove(attr) attributes = [attr] + attributes table = PrettyTable() - table.field_names = ['Property', 'Value'] - table.max_width['Property'] = 20 - table.max_width['Value'] = 60 + table.field_names = ["Property", "Value"] + table.max_width["Property"] = 20 + table.max_width["Value"] = 60 table.padding_width = 1 - table.align = 'l' - table.valign = 't' + table.align = "l" + table.valign = "t" for attribute in attributes: - if '.' in attribute: - field_names = attribute.split('.') + if "." in attribute: + field_names = attribute.split(".") value = cls._get_attribute_value(subject, field_names.pop(0)) for name in field_names: value = cls._get_attribute_value(value, name) @@ -248,8 +251,9 @@ def format(cls, subject, *args, **kwargs): else: value = cls._get_attribute_value(subject, attribute) - transform_function = attribute_transform_functions.get(attribute, - lambda value: value) + transform_function = attribute_transform_functions.get( + attribute, lambda value: value + ) value = transform_function(value=value) if type(value) is dict or type(value) is list: @@ -266,9 +270,9 @@ def _get_attribute_value(subject, attribute): else: r_val = getattr(subject, attribute, None) if r_val is None: - return '' + return "" if isinstance(r_val, list) or isinstance(r_val, dict): - return r_val if len(r_val) > 0 else '' + return r_val if len(r_val) > 0 else "" return r_val @@ -284,19 +288,25 @@ def note_box(entity, limit): else: entity = entity[:-1] - message = "Note: Only one %s is displayed. Use -n/--last flag for more results." \ + message = ( + "Note: Only one %s is displayed. Use -n/--last flag for more results." % entity + ) else: - message = "Note: Only first %s %s are displayed. Use -n/--last flag for more results."\ + message = ( + "Note: Only first %s %s are displayed. Use -n/--last flag for more results." % (limit, entity) + ) # adding default padding message_length = len(message) + 3 m = MultiColumnTable() if m.table_width > message_length: - note = PrettyTable([""], right_padding_width=(m.table_width - message_length)) + note = PrettyTable( + [""], right_padding_width=(m.table_width - message_length) + ) else: note = PrettyTable([""]) note.header = False note.add_row([message]) - sys.stderr.write((str(note) + '\n')) + sys.stderr.write((str(note) + "\n")) return diff --git a/st2client/st2client/models/__init__.py b/st2client/st2client/models/__init__.py index 2862f59d28..8e27a77050 100644 --- a/st2client/st2client/models/__init__.py +++ b/st2client/st2client/models/__init__.py @@ -15,19 +15,19 @@ from __future__ import absolute_import -from st2client.models.core import * # noqa -from st2client.models.auth import * # noqa -from st2client.models.action import * # noqa +from st2client.models.core import * # noqa +from st2client.models.auth import * # noqa +from st2client.models.action import * # noqa from st2client.models.action_alias import * # noqa from st2client.models.aliasexecution import * # noqa from st2client.models.config import * # noqa from st2client.models.inquiry import * # noqa -from st2client.models.keyvalue import * # noqa -from st2client.models.pack import * # noqa -from st2client.models.policy import * # noqa -from st2client.models.reactor import * # noqa -from st2client.models.trace import * # noqa -from st2client.models.webhook import * # noqa -from st2client.models.timer import * # noqa -from st2client.models.service_registry import * # noqa -from st2client.models.rbac import * # noqa +from st2client.models.keyvalue import * # noqa +from st2client.models.pack import * # noqa +from st2client.models.policy import * # noqa +from st2client.models.reactor import * # noqa +from st2client.models.trace import * # noqa +from st2client.models.webhook import * # noqa +from st2client.models.timer import * # noqa +from st2client.models.service_registry import * # noqa +from st2client.models.rbac import * # noqa diff --git a/st2client/st2client/models/action.py b/st2client/st2client/models/action.py index 10692d3dc4..d31b694f80 100644 --- a/st2client/st2client/models/action.py +++ b/st2client/st2client/models/action.py @@ -24,27 +24,33 @@ class RunnerType(core.Resource): - _alias = 'Runner' - _display_name = 'Runner' - _plural = 'RunnerTypes' - _plural_display_name = 'Runners' - _repr_attributes = ['name', 'enabled', 'description'] + _alias = "Runner" + _display_name = "Runner" + _plural = "RunnerTypes" + _plural_display_name = "Runners" + _repr_attributes = ["name", "enabled", "description"] class Action(core.Resource): - _plural = 'Actions' - _repr_attributes = ['name', 'pack', 'enabled', 'runner_type'] - _url_path = 'actions' + _plural = "Actions" + _repr_attributes = ["name", "pack", "enabled", "runner_type"] + _url_path = "actions" class Execution(core.Resource): - _alias = 'Execution' - _display_name = 'Action Execution' - _url_path = 'executions' - _plural = 'ActionExecutions' - _plural_display_name = 'Action executions' - _repr_attributes = ['status', 'action', 'start_timestamp', 'end_timestamp', 'parameters', - 'delay'] + _alias = "Execution" + _display_name = "Action Execution" + _url_path = "executions" + _plural = "ActionExecutions" + _plural_display_name = "Action executions" + _repr_attributes = [ + "status", + "action", + "start_timestamp", + "end_timestamp", + "parameters", + "delay", + ] # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for diff --git a/st2client/st2client/models/action_alias.py b/st2client/st2client/models/action_alias.py index 42162eae3b..1c1a696cff 100644 --- a/st2client/st2client/models/action_alias.py +++ b/st2client/st2client/models/action_alias.py @@ -17,25 +17,22 @@ from st2client.models import core -__all__ = [ - 'ActionAlias', - 'ActionAliasMatch' -] +__all__ = ["ActionAlias", "ActionAliasMatch"] class ActionAlias(core.Resource): - _alias = 'Action-Alias' - _display_name = 'Action Alias' - _plural = 'ActionAliases' - _plural_display_name = 'Action Aliases' - _url_path = 'actionalias' - _repr_attributes = ['name', 'pack', 'action_ref'] + _alias = "Action-Alias" + _display_name = "Action Alias" + _plural = "ActionAliases" + _plural_display_name = "Action Aliases" + _url_path = "actionalias" + _repr_attributes = ["name", "pack", "action_ref"] class ActionAliasMatch(core.Resource): - _alias = 'Action-Alias-Match' - _display_name = 'ActionAlias Match' - _plural = 'ActionAliasMatches' - _plural_display_name = 'Action Alias Matches' - _url_path = 'actionalias' - _repr_attributes = ['command'] + _alias = "Action-Alias-Match" + _display_name = "ActionAlias Match" + _plural = "ActionAliasMatches" + _plural_display_name = "Action Alias Matches" + _url_path = "actionalias" + _repr_attributes = ["command"] diff --git a/st2client/st2client/models/aliasexecution.py b/st2client/st2client/models/aliasexecution.py index 12cfc67cf5..a2d7e62a57 100644 --- a/st2client/st2client/models/aliasexecution.py +++ b/st2client/st2client/models/aliasexecution.py @@ -17,16 +17,21 @@ from st2client.models import core -__all__ = [ - 'ActionAliasExecution' -] +__all__ = ["ActionAliasExecution"] class ActionAliasExecution(core.Resource): - _alias = 'Action-Alias-Execution' - _display_name = 'ActionAlias Execution' - _plural = 'ActionAliasExecutions' - _plural_display_name = 'Runners' - _url_path = 'aliasexecution' - _repr_attributes = ['name', 'format', 'command', 'user', 'source_channel', - 'notification_channel', 'notification_route'] + _alias = "Action-Alias-Execution" + _display_name = "ActionAlias Execution" + _plural = "ActionAliasExecutions" + _plural_display_name = "Runners" + _url_path = "aliasexecution" + _repr_attributes = [ + "name", + "format", + "command", + "user", + "source_channel", + "notification_channel", + "notification_route", + ] diff --git a/st2client/st2client/models/auth.py b/st2client/st2client/models/auth.py index 9fa626a19a..7c909ea172 100644 --- a/st2client/st2client/models/auth.py +++ b/st2client/st2client/models/auth.py @@ -24,14 +24,14 @@ class Token(core.Resource): - _display_name = 'Access Token' - _plural = 'Tokens' - _plural_display_name = 'Access Tokens' - _repr_attributes = ['user', 'expiry', 'metadata'] + _display_name = "Access Token" + _plural = "Tokens" + _plural_display_name = "Access Tokens" + _repr_attributes = ["user", "expiry", "metadata"] class ApiKey(core.Resource): - _display_name = 'API Key' - _plural = 'ApiKeys' - _plural_display_name = 'API Keys' - _repr_attributes = ['id', 'user', 'metadata'] + _display_name = "API Key" + _plural = "ApiKeys" + _plural_display_name = "API Keys" + _repr_attributes = ["id", "user", "metadata"] diff --git a/st2client/st2client/models/config.py b/st2client/st2client/models/config.py index 247b4fcaf9..f9054751ed 100644 --- a/st2client/st2client/models/config.py +++ b/st2client/st2client/models/config.py @@ -19,14 +19,14 @@ class Config(core.Resource): - _display_name = 'Config' - _plural = 'Configs' - _plural_display_name = 'Configs' + _display_name = "Config" + _plural = "Configs" + _plural_display_name = "Configs" class ConfigSchema(core.Resource): - _display_name = 'Config Schema' - _plural = 'ConfigSchema' - _plural_display_name = 'Config Schemas' - _url_path = 'config_schemas' - _repr_attributes = ['id', 'pack', 'attributes'] + _display_name = "Config Schema" + _plural = "ConfigSchema" + _plural_display_name = "Config Schemas" + _url_path = "config_schemas" + _repr_attributes = ["id", "pack", "attributes"] diff --git a/st2client/st2client/models/core.py b/st2client/st2client/models/core.py index 255c91534f..d2a9b694f1 100644 --- a/st2client/st2client/models/core.py +++ b/st2client/st2client/models/core.py @@ -34,12 +34,13 @@ def add_auth_token_to_kwargs_from_env(func): @wraps(func) def decorate(*args, **kwargs): - if not kwargs.get('token') and os.environ.get('ST2_AUTH_TOKEN', None): - kwargs['token'] = os.environ.get('ST2_AUTH_TOKEN') - if not kwargs.get('api_key') and os.environ.get('ST2_API_KEY', None): - kwargs['api_key'] = os.environ.get('ST2_API_KEY') + if not kwargs.get("token") and os.environ.get("ST2_AUTH_TOKEN", None): + kwargs["token"] = os.environ.get("ST2_AUTH_TOKEN") + if not kwargs.get("api_key") and os.environ.get("ST2_API_KEY", None): + kwargs["api_key"] = os.environ.get("ST2_API_KEY") return func(*args, **kwargs) + return decorate @@ -81,8 +82,11 @@ def to_dict(self, exclude_attributes=None): exclude_attributes = exclude_attributes or [] attributes = list(self.__dict__.keys()) - attributes = [attr for attr in attributes if not attr.startswith('__') and - attr not in exclude_attributes] + attributes = [ + attr + for attr in attributes + if not attr.startswith("__") and attr not in exclude_attributes + ] result = {} for attribute in attributes: @@ -102,15 +106,15 @@ def get_display_name(cls): @classmethod def get_plural_name(cls): if not cls._plural: - raise Exception('The %s class is missing class attributes ' - 'in its definition.' % cls.__name__) + raise Exception( + "The %s class is missing class attributes " + "in its definition." % cls.__name__ + ) return cls._plural @classmethod def get_plural_display_name(cls): - return (cls._plural_display_name - if cls._plural_display_name - else cls._plural) + return cls._plural_display_name if cls._plural_display_name else cls._plural @classmethod def get_url_path_name(cls): @@ -120,9 +124,9 @@ def get_url_path_name(cls): return cls.get_plural_name().lower() def serialize(self): - return dict((k, v) - for k, v in six.iteritems(self.__dict__) - if not k.startswith('_')) + return dict( + (k, v) for k, v in six.iteritems(self.__dict__) if not k.startswith("_") + ) @classmethod def deserialize(cls, doc): @@ -140,16 +144,15 @@ def __repr__(self): attributes = [] for attribute in self._repr_attributes: value = getattr(self, attribute, None) - attributes.append('%s=%s' % (attribute, value)) + attributes.append("%s=%s" % (attribute, value)) - attributes = ','.join(attributes) + attributes = ",".join(attributes) class_name = self.__class__.__name__ - result = '<%s %s>' % (class_name, attributes) + result = "<%s %s>" % (class_name, attributes) return result class ResourceManager(object): - def __init__(self, resource, endpoint, cacert=None, debug=False): self.resource = resource self.debug = debug @@ -159,46 +162,47 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): def handle_error(response): try: content = response.json() - fault = content.get('faultstring', '') if content else '' + fault = content.get("faultstring", "") if content else "" if fault: - response.reason += '\nMESSAGE: %s' % fault + response.reason += "\nMESSAGE: %s" % fault except Exception as e: - response.reason += ('\nUnable to retrieve detailed message ' - 'from the HTTP response. %s\n' % six.text_type(e)) + response.reason += ( + "\nUnable to retrieve detailed message " + "from the HTTP response. %s\n" % six.text_type(e) + ) response.raise_for_status() @add_auth_token_to_kwargs_from_env def get_all(self, **kwargs): # TODO: This is ugly, stop abusing kwargs - url = '/%s' % self.resource.get_url_path_name() - limit = kwargs.pop('limit', None) - pack = kwargs.pop('pack', None) - prefix = kwargs.pop('prefix', None) - user = kwargs.pop('user', None) + url = "/%s" % self.resource.get_url_path_name() + limit = kwargs.pop("limit", None) + pack = kwargs.pop("pack", None) + prefix = kwargs.pop("prefix", None) + user = kwargs.pop("user", None) - params = kwargs.pop('params', {}) + params = kwargs.pop("params", {}) if limit: - params['limit'] = limit + params["limit"] = limit if pack: - params['pack'] = pack + params["pack"] = pack if prefix: - params['prefix'] = prefix + params["prefix"] = prefix if user: - params['user'] = user + params["user"] = user response = self.client.get(url=url, params=params, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) - return [self.resource.deserialize(item) - for item in response.json()] + return [self.resource.deserialize(item) for item in response.json()] @add_auth_token_to_kwargs_from_env def get_by_id(self, id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), id) + url = "/%s/%s" % (self.resource.get_url_path_name(), id) response = self.client.get(url, **kwargs) if response.status_code == http_client.NOT_FOUND: return None @@ -214,14 +218,18 @@ def get_property(self, id_, property_name, self_deserialize=True, **kwargs): property_name: Name of the property self_deserialize: #Implies use the deserialize method implemented by this resource. """ - token = kwargs.pop('token', None) - api_key = kwargs.pop('api_key', None) + token = kwargs.pop("token", None) + api_key = kwargs.pop("api_key", None) if kwargs: - url = '/%s/%s/%s/?%s' % (self.resource.get_url_path_name(), id_, property_name, - urllib.parse.urlencode(kwargs)) + url = "/%s/%s/%s/?%s" % ( + self.resource.get_url_path_name(), + id_, + property_name, + urllib.parse.urlencode(kwargs), + ) else: - url = '/%s/%s/%s/' % (self.resource.get_url_path_name(), id_, property_name) + url = "/%s/%s/%s/" % (self.resource.get_url_path_name(), id_, property_name) if token: response = self.client.get(url, token=token) @@ -246,19 +254,21 @@ def get_by_ref_or_id(self, ref_or_id, **kwargs): def _query_details(self, **kwargs): if not kwargs: - raise Exception('Query parameter is not provided.') + raise Exception("Query parameter is not provided.") - token = kwargs.get('token', None) - api_key = kwargs.get('api_key', None) - params = kwargs.get('params', {}) + token = kwargs.get("token", None) + api_key = kwargs.get("api_key", None) + params = kwargs.get("params", {}) for k, v in six.iteritems(kwargs): # Note: That's a special case to support api_key and token kwargs - if k not in ['token', 'api_key', 'params']: + if k not in ["token", "api_key", "params"]: params[k] = v - url = '/%s/?%s' % (self.resource.get_url_path_name(), - urllib.parse.urlencode(params)) + url = "/%s/?%s" % ( + self.resource.get_url_path_name(), + urllib.parse.urlencode(params), + ) if token: response = self.client.get(url, token=token) @@ -284,8 +294,8 @@ def query(self, **kwargs): @add_auth_token_to_kwargs_from_env def query_with_count(self, **kwargs): instances, response = self._query_details(**kwargs) - if response and 'X-Total-Count' in response.headers: - return (instances, int(response.headers['X-Total-Count'])) + if response and "X-Total-Count" in response.headers: + return (instances, int(response.headers["X-Total-Count"])) else: return (instances, None) @@ -296,13 +306,15 @@ def get_by_name(self, name, **kwargs): return None else: if len(instances) > 1: - raise Exception('More than one %s named "%s" are found.' % - (self.resource.__name__.lower(), name)) + raise Exception( + 'More than one %s named "%s" are found.' + % (self.resource.__name__.lower(), name) + ) return instances[0] @add_auth_token_to_kwargs_from_env def create(self, instance, **kwargs): - url = '/%s' % self.resource.get_url_path_name() + url = "/%s" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -311,7 +323,7 @@ def create(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def update(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id) response = self.client.put(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -320,12 +332,14 @@ def update(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def delete(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id) response = self.client.delete(url, **kwargs) - if response.status_code not in [http_client.OK, - http_client.NO_CONTENT, - http_client.NOT_FOUND]: + if response.status_code not in [ + http_client.OK, + http_client.NO_CONTENT, + http_client.NOT_FOUND, + ]: self.handle_error(response) return False @@ -333,11 +347,13 @@ def delete(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def delete_by_id(self, instance_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance_id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance_id) response = self.client.delete(url, **kwargs) - if response.status_code not in [http_client.OK, - http_client.NO_CONTENT, - http_client.NOT_FOUND]: + if response.status_code not in [ + http_client.OK, + http_client.NO_CONTENT, + http_client.NOT_FOUND, + ]: self.handle_error(response) return False try: @@ -357,18 +373,21 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): @add_auth_token_to_kwargs_from_env def match(self, instance, **kwargs): - url = '/%s/match' % self.resource.get_url_path_name() + url = "/%s/match" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) match = response.json() - return (self.resource.deserialize(match['actionalias']), match['representation']) + return ( + self.resource.deserialize(match["actionalias"]), + match["representation"], + ) class ActionAliasExecutionManager(ResourceManager): @add_auth_token_to_kwargs_from_env def match_and_execute(self, instance, **kwargs): - url = '/%s/match_and_execute' % self.resource.get_url_path_name() + url = "/%s/match_and_execute" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: @@ -380,7 +399,10 @@ def match_and_execute(self, instance, **kwargs): class ActionResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def get_entrypoint(self, ref_or_id, **kwargs): - url = '/%s/views/entry_point/%s' % (self.resource.get_url_path_name(), ref_or_id) + url = "/%s/views/entry_point/%s" % ( + self.resource.get_url_path_name(), + ref_or_id, + ) response = self.client.get(url, **kwargs) if response.status_code != http_client.OK: @@ -391,20 +413,30 @@ def get_entrypoint(self, ref_or_id, **kwargs): class ExecutionResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env - def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay=0, **kwargs): - url = '/%s/%s/re_run' % (self.resource.get_url_path_name(), execution_id) + def re_run( + self, + execution_id, + parameters=None, + tasks=None, + no_reset=None, + delay=0, + **kwargs, + ): + url = "/%s/%s/re_run" % (self.resource.get_url_path_name(), execution_id) tasks = tasks or [] no_reset = no_reset or [] if list(set(no_reset) - set(tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) data = { - 'parameters': parameters or {}, - 'tasks': tasks, - 'reset': list(set(tasks) - set(no_reset)), - 'delay': delay + "parameters": parameters or {}, + "tasks": tasks, + "reset": list(set(tasks) - set(no_reset)), + "delay": delay, } response = self.client.post(url, data, **kwargs) @@ -416,10 +448,10 @@ def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay @add_auth_token_to_kwargs_from_env def get_output(self, execution_id, output_type=None, **kwargs): - url = '/%s/%s/output' % (self.resource.get_url_path_name(), execution_id) + url = "/%s/%s/output" % (self.resource.get_url_path_name(), execution_id) if output_type: - url += '?' + urllib.parse.urlencode({'output_type': output_type}) + url += "?" + urllib.parse.urlencode({"output_type": output_type}) response = self.client.get(url, **kwargs) if response.status_code != http_client.OK: @@ -429,8 +461,8 @@ def get_output(self, execution_id, output_type=None, **kwargs): @add_auth_token_to_kwargs_from_env def pause(self, execution_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id) - data = {'status': 'pausing'} + url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id) + data = {"status": "pausing"} response = self.client.put(url, data, **kwargs) @@ -441,8 +473,8 @@ def pause(self, execution_id, **kwargs): @add_auth_token_to_kwargs_from_env def resume(self, execution_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id) - data = {'status': 'resuming'} + url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id) + data = {"status": "resuming"} response = self.client.put(url, data, **kwargs) @@ -453,14 +485,14 @@ def resume(self, execution_id, **kwargs): @add_auth_token_to_kwargs_from_env def get_children(self, execution_id, **kwargs): - url = '/%s/%s/children' % (self.resource.get_url_path_name(), execution_id) + url = "/%s/%s/children" % (self.resource.get_url_path_name(), execution_id) - depth = kwargs.pop('depth', -1) + depth = kwargs.pop("depth", -1) - params = kwargs.pop('params', {}) + params = kwargs.pop("params", {}) if depth: - params['depth'] = depth + params["depth"] = depth response = self.client.get(url=url, params=params, **kwargs) if response.status_code != http_client.OK: @@ -469,19 +501,15 @@ def get_children(self, execution_id, **kwargs): class InquiryResourceManager(ResourceManager): - @add_auth_token_to_kwargs_from_env def respond(self, inquiry_id, inquiry_response, **kwargs): """ Update st2.inquiry.respond action Update st2client respond command to use this? """ - url = '/%s/%s' % (self.resource.get_url_path_name(), inquiry_id) + url = "/%s/%s" % (self.resource.get_url_path_name(), inquiry_id) - payload = { - "id": inquiry_id, - "response": inquiry_response - } + payload = {"id": inquiry_id, "response": inquiry_response} response = self.client.put(url, payload, **kwargs) @@ -494,7 +522,10 @@ def respond(self, inquiry_id, inquiry_response, **kwargs): class TriggerInstanceResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def re_emit(self, trigger_instance_id, **kwargs): - url = '/%s/%s/re_emit' % (self.resource.get_url_path_name(), trigger_instance_id) + url = "/%s/%s/re_emit" % ( + self.resource.get_url_path_name(), + trigger_instance_id, + ) response = self.client.post(url, None, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -508,11 +539,11 @@ class AsyncRequest(Resource): class PackResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def install(self, packs, force=False, skip_dependencies=False, **kwargs): - url = '/%s/install' % (self.resource.get_url_path_name()) + url = "/%s/install" % (self.resource.get_url_path_name()) payload = { - 'packs': packs, - 'force': force, - 'skip_dependencies': skip_dependencies + "packs": packs, + "force": force, + "skip_dependencies": skip_dependencies, } response = self.client.post(url, payload, **kwargs) if response.status_code != http_client.OK: @@ -522,8 +553,8 @@ def install(self, packs, force=False, skip_dependencies=False, **kwargs): @add_auth_token_to_kwargs_from_env def remove(self, packs, **kwargs): - url = '/%s/uninstall' % (self.resource.get_url_path_name()) - response = self.client.post(url, {'packs': packs}, **kwargs) + url = "/%s/uninstall" % (self.resource.get_url_path_name()) + response = self.client.post(url, {"packs": packs}, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) instance = AsyncRequest.deserialize(response.json()) @@ -531,11 +562,11 @@ def remove(self, packs, **kwargs): @add_auth_token_to_kwargs_from_env def search(self, args, ignore_errors=False, **kwargs): - url = '/%s/index/search' % (self.resource.get_url_path_name()) - if 'query' in vars(args): - payload = {'query': args.query} + url = "/%s/index/search" % (self.resource.get_url_path_name()) + if "query" in vars(args): + payload = {"query": args.query} else: - payload = {'pack': args.pack} + payload = {"pack": args.pack} response = self.client.post(url, payload, **kwargs) @@ -552,12 +583,12 @@ def search(self, args, ignore_errors=False, **kwargs): @add_auth_token_to_kwargs_from_env def register(self, packs=None, types=None, **kwargs): - url = '/%s/register' % (self.resource.get_url_path_name()) + url = "/%s/register" % (self.resource.get_url_path_name()) payload = {} if types: - payload['types'] = types + payload["types"] = types if packs: - payload['packs'] = packs + payload["packs"] = packs response = self.client.post(url, payload, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -568,7 +599,7 @@ def register(self, packs=None, types=None, **kwargs): class ConfigManager(ResourceManager): @add_auth_token_to_kwargs_from_env def update(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.pack) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.pack) response = self.client.put(url, instance.values, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -584,16 +615,13 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): @add_auth_token_to_kwargs_from_env def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs): - url = '/webhooks/st2' + url = "/webhooks/st2" headers = {} - data = { - 'trigger': trigger, - 'payload': payload or {} - } + data = {"trigger": trigger, "payload": payload or {}} if trace_tag: - headers['St2-Trace-Tag'] = trace_tag + headers["St2-Trace-Tag"] = trace_tag response = self.client.post(url, data=data, headers=headers, **kwargs) @@ -604,17 +632,20 @@ def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs): @add_auth_token_to_kwargs_from_env def match(self, instance, **kwargs): - url = '/%s/match' % self.resource.get_url_path_name() + url = "/%s/match" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) match = response.json() - return (self.resource.deserialize(match['actionalias']), match['representation']) + return ( + self.resource.deserialize(match["actionalias"]), + match["representation"], + ) class StreamManager(object): def __init__(self, endpoint, cacert=None, debug=False): - self._url = httpclient.get_url_without_trailing_slash(endpoint) + '/stream' + self._url = httpclient.get_url_without_trailing_slash(endpoint) + "/stream" self.debug = debug self.cacert = cacert @@ -631,25 +662,25 @@ def listen(self, events=None, **kwargs): if events and isinstance(events, six.string_types): events = [events] - if 'token' in kwargs: - query_params['x-auth-token'] = kwargs.get('token') + if "token" in kwargs: + query_params["x-auth-token"] = kwargs.get("token") - if 'api_key' in kwargs: - query_params['st2-api-key'] = kwargs.get('api_key') + if "api_key" in kwargs: + query_params["st2-api-key"] = kwargs.get("api_key") - if 'end_event' in kwargs: - query_params['end_event'] = kwargs.get('end_event') + if "end_event" in kwargs: + query_params["end_event"] = kwargs.get("end_event") - if 'end_execution_id' in kwargs: - query_params['end_execution_id'] = kwargs.get('end_execution_id') + if "end_execution_id" in kwargs: + query_params["end_execution_id"] = kwargs.get("end_execution_id") if events: - query_params['events'] = ','.join(events) + query_params["events"] = ",".join(events) if self.cacert is not None: - request_params['verify'] = self.cacert + request_params["verify"] = self.cacert - query_string = '?' + urllib.parse.urlencode(query_params) + query_string = "?" + urllib.parse.urlencode(query_params) url = url + query_string response = requests.get(url, stream=True, **request_params) @@ -667,36 +698,38 @@ class WorkflowManager(object): def __init__(self, endpoint, cacert, debug): self.debug = debug self.cacert = cacert - self.endpoint = endpoint + '/workflows' - self.client = httpclient.HTTPClient(root=self.endpoint, cacert=cacert, debug=debug) + self.endpoint = endpoint + "/workflows" + self.client = httpclient.HTTPClient( + root=self.endpoint, cacert=cacert, debug=debug + ) @staticmethod def handle_error(response): try: content = response.json() - fault = content.get('faultstring', '') if content else '' + fault = content.get("faultstring", "") if content else "" if fault: - response.reason += '\nMESSAGE: %s' % fault + response.reason += "\nMESSAGE: %s" % fault except Exception as e: response.reason += ( - '\nUnable to retrieve detailed message ' - 'from the HTTP response. %s\n' % six.text_type(e) + "\nUnable to retrieve detailed message " + "from the HTTP response. %s\n" % six.text_type(e) ) response.raise_for_status() @add_auth_token_to_kwargs_from_env def inspect(self, definition, **kwargs): - url = '/inspect' + url = "/inspect" if not isinstance(definition, six.string_types): - raise TypeError('Workflow definition is not type of string.') + raise TypeError("Workflow definition is not type of string.") - if 'headers' not in kwargs: - kwargs['headers'] = {} + if "headers" not in kwargs: + kwargs["headers"] = {} - kwargs['headers']['content-type'] = 'text/plain' + kwargs["headers"]["content-type"] = "text/plain" response = self.client.post_raw(url, definition, **kwargs) @@ -709,7 +742,7 @@ def inspect(self, definition, **kwargs): class ServiceRegistryGroupsManager(ResourceManager): @add_auth_token_to_kwargs_from_env def list(self, **kwargs): - url = '/service_registry/groups' + url = "/service_registry/groups" headers = {} response = self.client.get(url, headers=headers, **kwargs) @@ -717,21 +750,20 @@ def list(self, **kwargs): if response.status_code != http_client.OK: self.handle_error(response) - groups = response.json()['groups'] + groups = response.json()["groups"] result = [] for group in groups: - item = self.resource.deserialize({'group_id': group}) + item = self.resource.deserialize({"group_id": group}) result.append(item) return result class ServiceRegistryMembersManager(ResourceManager): - @add_auth_token_to_kwargs_from_env def list(self, group_id, **kwargs): - url = '/service_registry/groups/%s/members' % (group_id) + url = "/service_registry/groups/%s/members" % (group_id) headers = {} response = self.client.get(url, headers=headers, **kwargs) @@ -739,14 +771,14 @@ def list(self, group_id, **kwargs): if response.status_code != http_client.OK: self.handle_error(response) - members = response.json()['members'] + members = response.json()["members"] result = [] for member in members: data = { - 'group_id': group_id, - 'member_id': member['member_id'], - 'capabilities': member['capabilities'] + "group_id": group_id, + "member_id": member["member_id"], + "capabilities": member["capabilities"], } item = self.resource.deserialize(data) result.append(item) diff --git a/st2client/st2client/models/inquiry.py b/st2client/st2client/models/inquiry.py index 5d1a1076f5..93161ee68f 100644 --- a/st2client/st2client/models/inquiry.py +++ b/st2client/st2client/models/inquiry.py @@ -24,15 +24,8 @@ class Inquiry(core.Resource): - _display_name = 'Inquiry' - _plural = 'Inquiries' - _plural_display_name = 'Inquiries' - _url_path = 'inquiries' - _repr_attributes = [ - 'id', - 'schema', - 'roles', - 'users', - 'route', - 'ttl' - ] + _display_name = "Inquiry" + _plural = "Inquiries" + _plural_display_name = "Inquiries" + _url_path = "inquiries" + _repr_attributes = ["id", "schema", "roles", "users", "route", "ttl"] diff --git a/st2client/st2client/models/keyvalue.py b/st2client/st2client/models/keyvalue.py index f7095a4b8f..5bcd1de8de 100644 --- a/st2client/st2client/models/keyvalue.py +++ b/st2client/st2client/models/keyvalue.py @@ -24,11 +24,11 @@ class KeyValuePair(core.Resource): - _alias = 'Key' - _display_name = 'Key Value Pair' - _plural = 'Keys' - _plural_display_name = 'Key Value Pairs' - _repr_attributes = ['name', 'value'] + _alias = "Key" + _display_name = "Key Value Pair" + _plural = "Keys" + _plural_display_name = "Key Value Pairs" + _repr_attributes = ["name", "value"] # Note: This is a temporary hack until we refactor client and make it support non id PKs def get_id(self): diff --git a/st2client/st2client/models/pack.py b/st2client/st2client/models/pack.py index 5d681266ad..7333c1a28e 100644 --- a/st2client/st2client/models/pack.py +++ b/st2client/st2client/models/pack.py @@ -19,8 +19,8 @@ class Pack(core.Resource): - _display_name = 'Pack' - _plural = 'Packs' - _plural_display_name = 'Packs' - _url_path = 'packs' - _repr_attributes = ['name', 'description', 'version', 'author'] + _display_name = "Pack" + _plural = "Packs" + _plural_display_name = "Packs" + _url_path = "packs" + _repr_attributes = ["name", "description", "version", "author"] diff --git a/st2client/st2client/models/policy.py b/st2client/st2client/models/policy.py index 851779d7fd..4b8bb0c813 100644 --- a/st2client/st2client/models/policy.py +++ b/st2client/st2client/models/policy.py @@ -24,13 +24,13 @@ class PolicyType(core.Resource): - _alias = 'Policy-Type' - _display_name = 'Policy type' - _plural = 'PolicyTypes' - _plural_display_name = 'Policy types' - _repr_attributes = ['ref', 'enabled', 'description'] + _alias = "Policy-Type" + _display_name = "Policy type" + _plural = "PolicyTypes" + _plural_display_name = "Policy types" + _repr_attributes = ["ref", "enabled", "description"] class Policy(core.Resource): - _plural = 'Policies' - _repr_attributes = ['name', 'pack', 'enabled', 'policy_type', 'resource_ref'] + _plural = "Policies" + _repr_attributes = ["name", "pack", "enabled", "policy_type", "resource_ref"] diff --git a/st2client/st2client/models/rbac.py b/st2client/st2client/models/rbac.py index 6df4aa4f94..94c765ddf3 100644 --- a/st2client/st2client/models/rbac.py +++ b/st2client/st2client/models/rbac.py @@ -17,25 +17,22 @@ from st2client.models import core -__all__ = [ - 'Role', - 'UserRoleAssignment' -] +__all__ = ["Role", "UserRoleAssignment"] class Role(core.Resource): - _alias = 'role' - _display_name = 'Role' - _plural = 'Roles' - _plural_display_name = 'Roles' - _repr_attributes = ['id', 'name', 'system'] - _url_path = 'rbac/roles' + _alias = "role" + _display_name = "Role" + _plural = "Roles" + _plural_display_name = "Roles" + _repr_attributes = ["id", "name", "system"] + _url_path = "rbac/roles" class UserRoleAssignment(core.Resource): - _alias = 'role-assignment' - _display_name = 'Role Assignment' - _plural = 'RoleAssignments' - _plural_display_name = 'Role Assignments' - _repr_attributes = ['id', 'role', 'user', 'is_remote'] - _url_path = 'rbac/role_assignments' + _alias = "role-assignment" + _display_name = "Role Assignment" + _plural = "RoleAssignments" + _plural_display_name = "Role Assignments" + _repr_attributes = ["id", "role", "user", "is_remote"] + _url_path = "rbac/role_assignments" diff --git a/st2client/st2client/models/reactor.py b/st2client/st2client/models/reactor.py index 140d1aaf50..ef4c054f69 100644 --- a/st2client/st2client/models/reactor.py +++ b/st2client/st2client/models/reactor.py @@ -24,43 +24,49 @@ class Sensor(core.Resource): - _plural = 'Sensortypes' - _repr_attributes = ['name', 'pack'] + _plural = "Sensortypes" + _repr_attributes = ["name", "pack"] class TriggerType(core.Resource): - _alias = 'Trigger' - _display_name = 'Trigger' - _plural = 'Triggertypes' - _plural_display_name = 'Triggers' - _repr_attributes = ['name', 'pack'] + _alias = "Trigger" + _display_name = "Trigger" + _plural = "Triggertypes" + _plural_display_name = "Triggers" + _repr_attributes = ["name", "pack"] class TriggerInstance(core.Resource): - _alias = 'Trigger-Instance' - _display_name = 'TriggerInstance' - _plural = 'Triggerinstances' - _plural_display_name = 'TriggerInstances' - _repr_attributes = ['id', 'trigger', 'occurrence_time', 'payload', 'status'] + _alias = "Trigger-Instance" + _display_name = "TriggerInstance" + _plural = "Triggerinstances" + _plural_display_name = "TriggerInstances" + _repr_attributes = ["id", "trigger", "occurrence_time", "payload", "status"] class Trigger(core.Resource): - _alias = 'TriggerSpecification' - _display_name = 'Trigger Specification' - _plural = 'Triggers' - _plural_display_name = 'Trigger Specifications' - _repr_attributes = ['name', 'pack'] + _alias = "TriggerSpecification" + _display_name = "Trigger Specification" + _plural = "Triggers" + _plural_display_name = "Trigger Specifications" + _repr_attributes = ["name", "pack"] class Rule(core.Resource): - _alias = 'Rule' - _plural = 'Rules' - _repr_attributes = ['name', 'pack', 'trigger', 'criteria', 'enabled'] + _alias = "Rule" + _plural = "Rules" + _repr_attributes = ["name", "pack", "trigger", "criteria", "enabled"] class RuleEnforcement(core.Resource): - _alias = 'Rule-Enforcement' - _plural = 'RuleEnforcements' - _display_name = 'Rule Enforcement' - _plural_display_name = 'Rule Enforcements' - _repr_attributes = ['id', 'trigger_instance_id', 'execution_id', 'rule.ref', 'enforced_at'] + _alias = "Rule-Enforcement" + _plural = "RuleEnforcements" + _display_name = "Rule Enforcement" + _plural_display_name = "Rule Enforcements" + _repr_attributes = [ + "id", + "trigger_instance_id", + "execution_id", + "rule.ref", + "enforced_at", + ] diff --git a/st2client/st2client/models/service_registry.py b/st2client/st2client/models/service_registry.py index 3b3057a3c3..ca95cd73cb 100644 --- a/st2client/st2client/models/service_registry.py +++ b/st2client/st2client/models/service_registry.py @@ -17,32 +17,27 @@ from st2client.models import core -__all__ = [ - 'ServiceRegistry', - - 'ServiceRegistryGroup', - 'ServiceRegistryMember' -] +__all__ = ["ServiceRegistry", "ServiceRegistryGroup", "ServiceRegistryMember"] class ServiceRegistry(core.Resource): - _alias = 'service-registry' - _display_name = 'Service Registry' - _plural = 'Service Registry' - _plural_display_name = 'Service Registry' + _alias = "service-registry" + _display_name = "Service Registry" + _plural = "Service Registry" + _plural_display_name = "Service Registry" class ServiceRegistryGroup(core.Resource): - _alias = 'group' - _display_name = 'Group' - _plural = 'Groups' - _plural_display_name = 'Groups' - _repr_attributes = ['group_id'] + _alias = "group" + _display_name = "Group" + _plural = "Groups" + _plural_display_name = "Groups" + _repr_attributes = ["group_id"] class ServiceRegistryMember(core.Resource): - _alias = 'member' - _display_name = 'Group Member' - _plural = 'Group Members' - _plural_display_name = 'Group Members' - _repr_attributes = ['group_id', 'member_id'] + _alias = "member" + _display_name = "Group Member" + _plural = "Group Members" + _plural_display_name = "Group Members" + _repr_attributes = ["group_id", "member_id"] diff --git a/st2client/st2client/models/timer.py b/st2client/st2client/models/timer.py index 4ba58547f3..fbfbd6cfcd 100644 --- a/st2client/st2client/models/timer.py +++ b/st2client/st2client/models/timer.py @@ -24,7 +24,7 @@ class Timer(core.Resource): - _alias = 'Timer' - _display_name = 'Timer' - _plural = 'Timers' - _plural_display_name = 'Timers' + _alias = "Timer" + _display_name = "Timer" + _plural = "Timers" + _plural_display_name = "Timers" diff --git a/st2client/st2client/models/trace.py b/st2client/st2client/models/trace.py index a03b4a8812..3b7bfe4449 100644 --- a/st2client/st2client/models/trace.py +++ b/st2client/st2client/models/trace.py @@ -19,8 +19,8 @@ class Trace(core.Resource): - _alias = 'Trace' - _display_name = 'Trace' - _plural = 'Traces' - _plural_display_name = 'Traces' - _repr_attributes = ['id', 'trace_tag'] + _alias = "Trace" + _display_name = "Trace" + _plural = "Traces" + _plural_display_name = "Traces" + _repr_attributes = ["id", "trace_tag"] diff --git a/st2client/st2client/models/webhook.py b/st2client/st2client/models/webhook.py index 83d939f061..161d1bdb4c 100644 --- a/st2client/st2client/models/webhook.py +++ b/st2client/st2client/models/webhook.py @@ -24,8 +24,8 @@ class Webhook(core.Resource): - _alias = 'Webhook' - _display_name = 'Webhook' - _plural = 'Webhooks' - _plural_display_name = 'Webhooks' - _repr_attributes = ['parameters', 'type', 'pack', 'name'] + _alias = "Webhook" + _display_name = "Webhook" + _plural = "Webhooks" + _plural_display_name = "Webhooks" + _repr_attributes = ["parameters", "type", "pack", "name"] diff --git a/st2client/st2client/shell.py b/st2client/st2client/shell.py index ac6108d796..7d3359c532 100755 --- a/st2client/st2client/shell.py +++ b/st2client/st2client/shell.py @@ -25,6 +25,7 @@ # Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7 import warnings from cryptography.utils import CryptographyDeprecationWarning + warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import os @@ -66,13 +67,13 @@ from st2client.commands.auth import LoginCommand -__all__ = [ - 'Shell' -] +__all__ = ["Shell"] LOGGER = logging.getLogger(__name__) -CLI_DESCRIPTION = 'CLI for StackStorm event-driven automation platform. https://stackstorm.com' +CLI_DESCRIPTION = ( + "CLI for StackStorm event-driven automation platform. https://stackstorm.com" +) USAGE_STRING = """ Usage: %(prog)s [options] [options] @@ -83,15 +84,19 @@ %(prog)s --debug run core.local cmd=date """.strip() -NON_UTF8_LOCALE = """ +NON_UTF8_LOCALE = ( + """ Locale %s with encoding %s which is not UTF-8 is used. This means that some functionality which relies on outputting unicode characters won't work. You are encouraged to use UTF-8 locale by setting LC_ALL environment variable to en_US.UTF-8 or similar. -""".strip().replace('\n', ' ').replace(' ', ' ') +""".strip() + .replace("\n", " ") + .replace(" ", " ") +) -PACKAGE_METADATA_FILE_PATH = '/opt/stackstorm/st2/package.meta' +PACKAGE_METADATA_FILE_PATH = "/opt/stackstorm/st2/package.meta" def get_stackstorm_version(): @@ -101,7 +106,7 @@ def get_stackstorm_version(): :rtype: ``str`` """ - if 'dev' in __version__: + if "dev" in __version__: version = __version__ if not os.path.isfile(PACKAGE_METADATA_FILE_PATH): @@ -115,11 +120,11 @@ def get_stackstorm_version(): return version try: - git_revision = config.get('server', 'git_sha') + git_revision = config.get("server", "git_sha") except Exception: return version - version = '%s (%s)' % (version, git_revision) + version = "%s (%s)" % (version, git_revision) else: version = __version__ @@ -143,214 +148,237 @@ def __init__(self): # Set up general program options. self.parser.add_argument( - '--version', - action='version', - version='%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}' - .format(version=get_stackstorm_version(), - python_major=sys.version_info.major, - python_minor=sys.version_info.minor, - python_patch=sys.version_info.micro)) + "--version", + action="version", + version="%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}".format( + version=get_stackstorm_version(), + python_major=sys.version_info.major, + python_minor=sys.version_info.minor, + python_patch=sys.version_info.micro, + ), + ) self.parser.add_argument( - '--url', - action='store', - dest='base_url', + "--url", + action="store", + dest="base_url", default=None, - help='Base URL for the API servers. Assumes all servers use the ' - 'same base URL and default ports are used. Get ST2_BASE_URL ' - 'from the environment variables by default.' + help="Base URL for the API servers. Assumes all servers use the " + "same base URL and default ports are used. Get ST2_BASE_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--auth-url', - action='store', - dest='auth_url', + "--auth-url", + action="store", + dest="auth_url", default=None, - help='URL for the authentication service. Get ST2_AUTH_URL ' - 'from the environment variables by default.' + help="URL for the authentication service. Get ST2_AUTH_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--api-url', - action='store', - dest='api_url', + "--api-url", + action="store", + dest="api_url", default=None, - help='URL for the API server. Get ST2_API_URL ' - 'from the environment variables by default.' + help="URL for the API server. Get ST2_API_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--stream-url', - action='store', - dest='stream_url', + "--stream-url", + action="store", + dest="stream_url", default=None, - help='URL for the stream endpoint. Get ST2_STREAM_URL' - 'from the environment variables by default.' + help="URL for the stream endpoint. Get ST2_STREAM_URL" + "from the environment variables by default.", ) self.parser.add_argument( - '--api-version', - action='store', - dest='api_version', + "--api-version", + action="store", + dest="api_version", default=None, - help='API version to use. Get ST2_API_VERSION ' - 'from the environment variables by default.' + help="API version to use. Get ST2_API_VERSION " + "from the environment variables by default.", ) self.parser.add_argument( - '--cacert', - action='store', - dest='cacert', + "--cacert", + action="store", + dest="cacert", default=None, - help='Path to the CA cert bundle for the SSL endpoints. ' - 'Get ST2_CACERT from the environment variables by default. ' - 'If this is not provided, then SSL cert will not be verified.' + help="Path to the CA cert bundle for the SSL endpoints. " + "Get ST2_CACERT from the environment variables by default. " + "If this is not provided, then SSL cert will not be verified.", ) self.parser.add_argument( - '--config-file', - action='store', - dest='config_file', + "--config-file", + action="store", + dest="config_file", default=None, - help='Path to the CLI config file' + help="Path to the CLI config file", ) self.parser.add_argument( - '--print-config', - action='store_true', - dest='print_config', + "--print-config", + action="store_true", + dest="print_config", default=False, - help='Parse the config file and print the values' + help="Parse the config file and print the values", ) self.parser.add_argument( - '--skip-config', - action='store_true', - dest='skip_config', + "--skip-config", + action="store_true", + dest="skip_config", default=False, - help='Don\'t parse and use the CLI config file' + help="Don't parse and use the CLI config file", ) self.parser.add_argument( - '--debug', - action='store_true', - dest='debug', + "--debug", + action="store_true", + dest="debug", default=False, - help='Enable debug mode' + help="Enable debug mode", ) # Set up list of commands and subcommands. - self.subparsers = self.parser.add_subparsers(dest='parser') + self.subparsers = self.parser.add_subparsers(dest="parser") self.subparsers.required = True self.commands = {} - self.commands['run'] = action.ActionRunCommand( - models.Action, self, self.subparsers, name='run', add_help=False) + self.commands["run"] = action.ActionRunCommand( + models.Action, self, self.subparsers, name="run", add_help=False + ) - self.commands['action'] = action.ActionBranch( - 'An activity that happens as a response to the external event.', - self, self.subparsers) + self.commands["action"] = action.ActionBranch( + "An activity that happens as a response to the external event.", + self, + self.subparsers, + ) - self.commands['action-alias'] = action_alias.ActionAliasBranch( - 'Action aliases.', - self, self.subparsers) + self.commands["action-alias"] = action_alias.ActionAliasBranch( + "Action aliases.", self, self.subparsers + ) - self.commands['auth'] = auth.TokenCreateCommand( - models.Token, self, self.subparsers, name='auth') + self.commands["auth"] = auth.TokenCreateCommand( + models.Token, self, self.subparsers, name="auth" + ) - self.commands['login'] = auth.LoginCommand( - models.Token, self, self.subparsers, name='login') + self.commands["login"] = auth.LoginCommand( + models.Token, self, self.subparsers, name="login" + ) - self.commands['whoami'] = auth.WhoamiCommand( - models.Token, self, self.subparsers, name='whoami') + self.commands["whoami"] = auth.WhoamiCommand( + models.Token, self, self.subparsers, name="whoami" + ) - self.commands['api-key'] = auth.ApiKeyBranch( - 'API Keys.', - self, self.subparsers) + self.commands["api-key"] = auth.ApiKeyBranch("API Keys.", self, self.subparsers) - self.commands['execution'] = action.ActionExecutionBranch( - 'An invocation of an action.', - self, self.subparsers) + self.commands["execution"] = action.ActionExecutionBranch( + "An invocation of an action.", self, self.subparsers + ) - self.commands['inquiry'] = inquiry.InquiryBranch( - 'Inquiries provide an opportunity to ask a question ' - 'and wait for a response in a workflow.', - self, self.subparsers) + self.commands["inquiry"] = inquiry.InquiryBranch( + "Inquiries provide an opportunity to ask a question " + "and wait for a response in a workflow.", + self, + self.subparsers, + ) - self.commands['key'] = keyvalue.KeyValuePairBranch( - 'Key value pair is used to store commonly used configuration ' - 'for reuse in sensors, actions, and rules.', - self, self.subparsers) + self.commands["key"] = keyvalue.KeyValuePairBranch( + "Key value pair is used to store commonly used configuration " + "for reuse in sensors, actions, and rules.", + self, + self.subparsers, + ) - self.commands['pack'] = pack.PackBranch( - 'A group of related integration resources: ' - 'actions, rules, and sensors.', - self, self.subparsers) + self.commands["pack"] = pack.PackBranch( + "A group of related integration resources: " "actions, rules, and sensors.", + self, + self.subparsers, + ) - self.commands['policy'] = policy.PolicyBranch( - 'Policy that is enforced on a resource.', - self, self.subparsers) + self.commands["policy"] = policy.PolicyBranch( + "Policy that is enforced on a resource.", self, self.subparsers + ) - self.commands['policy-type'] = policy.PolicyTypeBranch( - 'Type of policy that can be applied to resources.', - self, self.subparsers) + self.commands["policy-type"] = policy.PolicyTypeBranch( + "Type of policy that can be applied to resources.", self, self.subparsers + ) - self.commands['rule'] = rule.RuleBranch( + self.commands["rule"] = rule.RuleBranch( 'A specification to invoke an "action" on a "trigger" selectively ' - 'based on some criteria.', - self, self.subparsers) + "based on some criteria.", + self, + self.subparsers, + ) - self.commands['webhook'] = webhook.WebhookBranch( - 'Webhooks.', - self, self.subparsers) + self.commands["webhook"] = webhook.WebhookBranch( + "Webhooks.", self, self.subparsers + ) - self.commands['timer'] = timer.TimerBranch( - 'Timers.', - self, self.subparsers) + self.commands["timer"] = timer.TimerBranch("Timers.", self, self.subparsers) - self.commands['runner'] = resource.ResourceBranch( + self.commands["runner"] = resource.ResourceBranch( models.RunnerType, - 'Runner is a type of handler for a specific class of actions.', - self, self.subparsers, read_only=True, has_disable=True) + "Runner is a type of handler for a specific class of actions.", + self, + self.subparsers, + read_only=True, + has_disable=True, + ) - self.commands['sensor'] = sensor.SensorBranch( - 'An adapter which allows you to integrate StackStorm with external system.', - self, self.subparsers) + self.commands["sensor"] = sensor.SensorBranch( + "An adapter which allows you to integrate StackStorm with external system.", + self, + self.subparsers, + ) - self.commands['trace'] = trace.TraceBranch( - 'A group of executions, rules and triggerinstances that are related.', - self, self.subparsers) + self.commands["trace"] = trace.TraceBranch( + "A group of executions, rules and triggerinstances that are related.", + self, + self.subparsers, + ) - self.commands['trigger'] = trigger.TriggerTypeBranch( - 'An external event that is mapped to a st2 input. It is the ' - 'st2 invocation point.', - self, self.subparsers) + self.commands["trigger"] = trigger.TriggerTypeBranch( + "An external event that is mapped to a st2 input. It is the " + "st2 invocation point.", + self, + self.subparsers, + ) - self.commands['trigger-instance'] = triggerinstance.TriggerInstanceBranch( - 'Actual instances of triggers received by st2.', - self, self.subparsers) + self.commands["trigger-instance"] = triggerinstance.TriggerInstanceBranch( + "Actual instances of triggers received by st2.", self, self.subparsers + ) - self.commands['rule-enforcement'] = rule_enforcement.RuleEnforcementBranch( - 'Models that represent enforcement of rules.', - self, self.subparsers) + self.commands["rule-enforcement"] = rule_enforcement.RuleEnforcementBranch( + "Models that represent enforcement of rules.", self, self.subparsers + ) - self.commands['workflow'] = workflow.WorkflowBranch( - 'Commands for workflow authoring related operations. ' - 'Only orquesta workflows are supported.', - self, self.subparsers) + self.commands["workflow"] = workflow.WorkflowBranch( + "Commands for workflow authoring related operations. " + "Only orquesta workflows are supported.", + self, + self.subparsers, + ) # Service Registry - self.commands['service-registry'] = service_registry.ServiceRegistryBranch( - 'Service registry group and membership related commands.', - self, self.subparsers) + self.commands["service-registry"] = service_registry.ServiceRegistryBranch( + "Service registry group and membership related commands.", + self, + self.subparsers, + ) # RBAC - self.commands['role'] = rbac.RoleBranch( - 'RBAC roles.', - self, self.subparsers) - self.commands['role-assignment'] = rbac.RoleAssignmentBranch( - 'RBAC role assignments.', - self, self.subparsers) + self.commands["role"] = rbac.RoleBranch("RBAC roles.", self, self.subparsers) + self.commands["role-assignment"] = rbac.RoleAssignmentBranch( + "RBAC role assignments.", self, self.subparsers + ) def run(self, argv): debug = False @@ -369,9 +397,9 @@ def run(self, argv): # Provide autocomplete for shell argcomplete.autocomplete(self.parser) - if '--print-config' in argv: + if "--print-config" in argv: # Hack because --print-config requires no command to be specified - argv = argv + ['action', 'list'] + argv = argv + ["action", "list"] # Parse command line arguments. args = self.parser.parse_args(args=argv) @@ -389,7 +417,7 @@ def run(self, argv): # Setup client and run the command try: - debug = getattr(args, 'debug', False) + debug = getattr(args, "debug", False) if debug: set_log_level_for_all_loggers(level=logging.DEBUG) @@ -399,7 +427,7 @@ def run(self, argv): # TODO: This is not so nice work-around for Python 3 because of a breaking change in # Python 3 - https://bugs.python.org/issue16308 try: - func = getattr(args, 'func') + func = getattr(args, "func") except AttributeError: parser.print_help() sys.exit(2) @@ -414,9 +442,9 @@ def run(self, argv): return 2 except Exception as e: # We allow exception to define custom exit codes - exit_code = getattr(e, 'exit_code', 1) + exit_code = getattr(e, "exit_code", 1) - print('ERROR: %s\n' % e) + print("ERROR: %s\n" % e) if debug: self._print_debug_info(args=args) @@ -426,10 +454,10 @@ def _print_config(self, args): config = self._parse_config_file(args=args) for section, options in six.iteritems(config): - print('[%s]' % (section)) + print("[%s]" % (section)) for name, value in six.iteritems(options): - print('%s = %s' % (name, value)) + print("%s = %s" % (name, value)) def _check_locale_and_print_warning(self): """ @@ -440,23 +468,23 @@ def _check_locale_and_print_warning(self): preferred_encoding = locale.getpreferredencoding() except ValueError: # Ignore unknown locale errors for now - default_locale = 'unknown' - preferred_encoding = 'unknown' + default_locale = "unknown" + preferred_encoding = "unknown" - if preferred_encoding and preferred_encoding.lower() != 'utf-8': - msg = NON_UTF8_LOCALE % (default_locale or 'unknown', preferred_encoding) + if preferred_encoding and preferred_encoding.lower() != "utf-8": + msg = NON_UTF8_LOCALE % (default_locale or "unknown", preferred_encoding) LOGGER.warn(msg) def setup_logging(argv): - debug = '--debug' in argv + debug = "--debug" in argv root = LOGGER root.setLevel(logging.WARNING) handler = logging.StreamHandler(sys.stderr) handler.setLevel(logging.WARNING) - formatter = logging.Formatter('%(asctime)s %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s %(levelname)s - %(message)s") handler.setFormatter(formatter) if not debug: @@ -470,5 +498,5 @@ def main(argv=sys.argv[1:]): return Shell().run(argv) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/st2client/st2client/utils/color.py b/st2client/st2client/utils/color.py index 8b18402136..f1106851e2 100644 --- a/st2client/st2client/utils/color.py +++ b/st2client/st2client/utils/color.py @@ -16,40 +16,36 @@ from __future__ import absolute_import import os -__all__ = [ - 'DisplayColors', - - 'format_status' -] +__all__ = ["DisplayColors", "format_status"] TERMINAL_SUPPORTS_ANSI_CODES = [ - 'xterm', - 'xterm-color', - 'screen', - 'vt100', - 'vt100-color', - 'xterm-256color' + "xterm", + "xterm-color", + "screen", + "vt100", + "vt100-color", + "xterm-256color", ] -DISABLED = os.environ.get('ST2_COLORIZE', '') +DISABLED = os.environ.get("ST2_COLORIZE", "") class DisplayColors(object): - RED = '\033[91m' - PURPLE = '\033[35m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - BLUE = '\033[94m' - BROWN = '\033[33m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + RED = "\033[91m" + PURPLE = "\033[35m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + BROWN = "\033[33m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" @staticmethod - def colorize(value, color=''): + def colorize(value, color=""): # TODO: use list of supported terminals - term = os.environ.get('TERM', None) + term = os.environ.get("TERM", None) if term is None or term.lower() not in TERMINAL_SUPPORTS_ANSI_CODES: # Terminal doesn't support colors @@ -58,33 +54,33 @@ def colorize(value, color=''): if DISABLED or not color: return value - return '%s%s%s' % (color, value, DisplayColors.ENDC) + return "%s%s%s" % (color, value, DisplayColors.ENDC) # Lookup table STATUS_LOOKUP = { - 'succeeded': DisplayColors.GREEN, - 'delayed': DisplayColors.BLUE, - 'failed': DisplayColors.RED, - 'timeout': DisplayColors.BROWN, - 'running': DisplayColors.YELLOW + "succeeded": DisplayColors.GREEN, + "delayed": DisplayColors.BLUE, + "failed": DisplayColors.RED, + "timeout": DisplayColors.BROWN, + "running": DisplayColors.YELLOW, } def format_status(value): # Support status values with elapsed info - split = value.split('(', 1) + split = value.split("(", 1) if len(split) == 2: status = split[0].strip() - remainder = '(' + split[1] + remainder = "(" + split[1] else: status = value - remainder = '' + remainder = "" color = STATUS_LOOKUP.get(status, DisplayColors.YELLOW) result = DisplayColors.colorize(status, color) if remainder: - result = result + ' ' + remainder + result = result + " " + remainder return result diff --git a/st2client/st2client/utils/date.py b/st2client/st2client/utils/date.py index b19e27f3ec..3a76a44c81 100644 --- a/st2client/st2client/utils/date.py +++ b/st2client/st2client/utils/date.py @@ -20,10 +20,7 @@ from st2client.config import get_config -__all__ = [ - 'parse', - 'format_isodate' -] +__all__ = ["parse", "format_isodate"] def add_utc_tz(dt): @@ -39,7 +36,7 @@ def format_dt(dt): """ Format datetime object for human friendly representation. """ - value = dt.strftime('%a, %d %b %Y %H:%M:%S %Z') + value = dt.strftime("%a, %d %b %Y %H:%M:%S %Z") return value @@ -52,7 +49,7 @@ def format_isodate(value, timezone=None): :rtype: ``str`` """ if not value: - return '' + return "" # For some reason pylint thinks it returns a tuple but it returns a datetime object dt = dateutil.parser.parse(str(value)) @@ -70,6 +67,6 @@ def format_isodate_for_user_timezone(value): specific in the config. """ config = get_config() - timezone = config.get('cli', {}).get('timezone', 'UTC') + timezone = config.get("cli", {}).get("timezone", "UTC") result = format_isodate(value=value, timezone=timezone) return result diff --git a/st2client/st2client/utils/httpclient.py b/st2client/st2client/utils/httpclient.py index 089f6b88d6..6af6595ec5 100644 --- a/st2client/st2client/utils/httpclient.py +++ b/st2client/st2client/utils/httpclient.py @@ -27,38 +27,41 @@ def add_ssl_verify_to_kwargs(func): def decorate(*args, **kwargs): - if isinstance(args[0], HTTPClient) and 'https' in getattr(args[0], 'root', ''): - cacert = getattr(args[0], 'cacert', None) - kwargs['verify'] = cacert if cacert is not None else False + if isinstance(args[0], HTTPClient) and "https" in getattr(args[0], "root", ""): + cacert = getattr(args[0], "cacert", None) + kwargs["verify"] = cacert if cacert is not None else False return func(*args, **kwargs) + return decorate def add_auth_token_to_headers(func): def decorate(*args, **kwargs): - headers = kwargs.get('headers', dict()) + headers = kwargs.get("headers", dict()) - token = kwargs.pop('token', None) + token = kwargs.pop("token", None) if token: - headers['X-Auth-Token'] = str(token) - kwargs['headers'] = headers + headers["X-Auth-Token"] = str(token) + kwargs["headers"] = headers - api_key = kwargs.pop('api_key', None) + api_key = kwargs.pop("api_key", None) if api_key: - headers['St2-Api-Key'] = str(api_key) - kwargs['headers'] = headers + headers["St2-Api-Key"] = str(api_key) + kwargs["headers"] = headers return func(*args, **kwargs) + return decorate def add_json_content_type_to_headers(func): def decorate(*args, **kwargs): - headers = kwargs.get('headers', dict()) - content_type = headers.get('content-type', 'application/json') - headers['content-type'] = content_type - kwargs['headers'] = headers + headers = kwargs.get("headers", dict()) + content_type = headers.get("content-type", "application/json") + headers["content-type"] = content_type + kwargs["headers"] = headers return func(*args, **kwargs) + return decorate @@ -71,12 +74,11 @@ def get_url_without_trailing_slash(value): :rtype: ``str`` """ - result = value[:-1] if value.endswith('/') else value + result = value[:-1] if value.endswith("/") else value return result class HTTPClient(object): - def __init__(self, root, cacert=None, debug=False): self.root = get_url_without_trailing_slash(root) self.cacert = cacert @@ -136,30 +138,30 @@ def _response_hook(self, response): print("# -------- begin %d response ----------" % (id(self))) print(response.text) print("# -------- end %d response ------------" % (id(self))) - print('') + print("") return response def _get_curl_line_for_request(self, request): - parts = ['curl'] + parts = ["curl"] # method method = request.method.upper() - if method in ['HEAD']: - parts.extend(['--head']) + if method in ["HEAD"]: + parts.extend(["--head"]) else: - parts.extend(['-X', pquote(method)]) + parts.extend(["-X", pquote(method)]) # headers for key, value in request.headers.items(): - parts.extend(['-H ', pquote('%s: %s' % (key, value))]) + parts.extend(["-H ", pquote("%s: %s" % (key, value))]) # body if request.body: - parts.extend(['--data-binary', pquote(request.body)]) + parts.extend(["--data-binary", pquote(request.body)]) # URL parts.extend([pquote(request.url)]) - curl_line = ' '.join(parts) + curl_line = " ".join(parts) return curl_line diff --git a/st2client/st2client/utils/interactive.py b/st2client/st2client/utils/interactive.py index 35065e5d94..7e6f81b29b 100644 --- a/st2client/st2client/utils/interactive.py +++ b/st2client/st2client/utils/interactive.py @@ -28,8 +28,8 @@ from six.moves import range -POSITIVE_BOOLEAN = {'1', 'y', 'yes', 'true'} -NEGATIVE_BOOLEAN = {'0', 'n', 'no', 'nope', 'nah', 'false'} +POSITIVE_BOOLEAN = {"1", "y", "yes", "true"} +NEGATIVE_BOOLEAN = {"0", "n", "no", "nope", "nah", "false"} class ReaderNotImplemented(OperationFailureException): @@ -58,10 +58,8 @@ class StringReader(object): def __init__(self, name, spec, prefix=None, secret=False, **kw): self.name = name self.spec = spec - self.prefix = prefix or '' - self.options = { - 'is_password': secret - } + self.prefix = prefix or "" + self.options = {"is_password": secret} self._construct_description() self._construct_template() @@ -84,7 +82,7 @@ def read(self): message = self.template.format(self.prefix + self.name, **self.spec) response = prompt(message, **self.options) - result = self.spec.get('default', None) + result = self.spec.get("default", None) if response: result = self._transform_response(response) @@ -92,20 +90,21 @@ def read(self): return result def _construct_description(self): - if 'description' in self.spec: + if "description" in self.spec: + def get_bottom_toolbar_tokens(cli): - return [(token.Token.Toolbar, self.spec['description'])] + return [(token.Token.Toolbar, self.spec["description"])] - self.options['get_bottom_toolbar_tokens'] = get_bottom_toolbar_tokens + self.options["get_bottom_toolbar_tokens"] = get_bottom_toolbar_tokens def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - if 'default' in self.spec: - self.template = u'{0} [{default}]: ' + if "default" in self.spec: + self.template = "{0} [{default}]: " def _construct_validators(self): - self.options['validator'] = MuxValidator([self.validate], self.spec) + self.options["validator"] = MuxValidator([self.validate], self.spec) def _transform_response(self, response): return response @@ -114,25 +113,27 @@ def _transform_response(self, response): class BooleanReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'boolean' + return spec.get("type", None) == "boolean" @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return if input.lower() not in POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN: - raise validation.ValidationError(len(input), - 'Does not look like boolean. Pick from [%s]' - % ', '.join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN)) + raise validation.ValidationError( + len(input), + "Does not look like boolean. Pick from [%s]" + % ", ".join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN), + ) def _construct_template(self): - self.template = u'{0} (boolean)' + self.template = "{0} (boolean)" - if 'default' in self.spec: - self.template += u' [{}]: '.format(self.spec.get('default') and 'y' or 'n') + if "default" in self.spec: + self.template += " [{}]: ".format(self.spec.get("default") and "y" or "n") else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): if response.lower() in POSITIVE_BOOLEAN: @@ -141,14 +142,16 @@ def _transform_response(self, response): return False # Hopefully, it will never happen - raise OperationFailureException('Response neither positive no negative. ' - 'Value have not been properly validated.') + raise OperationFailureException( + "Response neither positive no negative. " + "Value have not been properly validated." + ) class NumberReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'number' + return spec.get("type", None) == "number" @staticmethod def validate(input, spec): @@ -161,12 +164,12 @@ def validate(input, spec): super(NumberReader, NumberReader).validate(input, spec) def _construct_template(self): - self.template = u'{0} (float)' + self.template = "{0} (float)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): return float(response) @@ -175,7 +178,7 @@ def _transform_response(self, response): class IntegerReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'integer' + return spec.get("type", None) == "integer" @staticmethod def validate(input, spec): @@ -188,12 +191,12 @@ def validate(input, spec): super(IntegerReader, IntegerReader).validate(input, spec) def _construct_template(self): - self.template = u'{0} (integer)' + self.template = "{0} (integer)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): return int(response) @@ -205,71 +208,71 @@ def __init__(self, *args, **kwargs): @staticmethod def condition(spec): - return spec.get('secret', None) + return spec.get("secret", None) def _construct_template(self): - self.template = u'{0} (secret)' + self.template = "{0} (secret)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " class EnumReader(StringReader): @staticmethod def condition(spec): - return spec.get('enum', None) + return spec.get("enum", None) @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return if not input.isdigit(): - raise validation.ValidationError(len(input), 'Not a number') + raise validation.ValidationError(len(input), "Not a number") - enum = spec.get('enum') + enum = spec.get("enum") try: enum[int(input)] except IndexError: - raise validation.ValidationError(len(input), 'Out of bounds') + raise validation.ValidationError(len(input), "Out of bounds") def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - enum = self.spec.get('enum') + enum = self.spec.get("enum") for index, value in enumerate(enum): - self.template += u'\n {} - {}'.format(index, value) + self.template += "\n {} - {}".format(index, value) num_options = len(enum) - more = '' + more = "" if num_options > 3: num_options = 3 - more = '...' + more = "..." options = [str(i) for i in range(0, num_options)] - self.template += u'\nChoose from {}{}'.format(', '.join(options), more) + self.template += "\nChoose from {}{}".format(", ".join(options), more) - if 'default' in self.spec: - self.template += u' [{}]: '.format(enum.index(self.spec.get('default'))) + if "default" in self.spec: + self.template += " [{}]: ".format(enum.index(self.spec.get("default"))) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): - return self.spec.get('enum')[int(response)] + return self.spec.get("enum")[int(response)] class ObjectReader(StringReader): - @staticmethod def condition(spec): - return spec.get('type', None) == 'object' + return spec.get("type", None) == "object" def read(self): - prefix = u'{}.'.format(self.name) + prefix = "{}.".format(self.name) - result = InteractiveForm(self.spec.get('properties', {}), - prefix=prefix, reraise=True).initiate_dialog() + result = InteractiveForm( + self.spec.get("properties", {}), prefix=prefix, reraise=True + ).initiate_dialog() return result @@ -277,25 +280,27 @@ def read(self): class ArrayReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'array' + return spec.get("type", None) == "array" @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return - for m in re.finditer(r'[^, ]+', input): + for m in re.finditer(r"[^, ]+", input): index, item = m.start(), m.group() try: - StringReader.validate(item, spec.get('items', {})) + StringReader.validate(item, spec.get("items", {})) except validation.ValidationError as e: raise validation.ValidationError(index, six.text_type(e)) def read(self): - item_type = self.spec.get('items', {}).get('type', 'string') + item_type = self.spec.get("items", {}).get("type", "string") - if item_type not in ['string', 'integer', 'number', 'boolean']: - message = 'Interactive mode does not support arrays of %s type yet' % item_type + if item_type not in ["string", "integer", "number", "boolean"]: + message = ( + "Interactive mode does not support arrays of %s type yet" % item_type + ) raise ReaderNotImplemented(message) result = super(ArrayReader, self).read() @@ -303,37 +308,46 @@ def read(self): return result def _construct_template(self): - self.template = u'{0} (comma-separated list)' + self.template = "{0} (comma-separated list)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=','.join(self.spec.get('default'))) + if "default" in self.spec: + self.template += " [{default}]: ".format( + default=",".join(self.spec.get("default")) + ) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): - return [item.strip() for item in response.split(',')] + return [item.strip() for item in response.split(",")] class ArrayObjectReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'array' and spec.get('items', {}).get('type') == 'object' + return ( + spec.get("type", None) == "array" + and spec.get("items", {}).get("type") == "object" + ) def read(self): results = [] - properties = self.spec.get('items', {}).get('properties', {}) - message = '~~~ Would you like to add another item to "%s" array / list?' % self.name + properties = self.spec.get("items", {}).get("properties", {}) + message = ( + '~~~ Would you like to add another item to "%s" array / list?' % self.name + ) is_continue = True index = 0 while is_continue: - prefix = u'{name}[{index}].'.format(name=self.name, index=index) - results.append(InteractiveForm(properties, - prefix=prefix, - reraise=True).initiate_dialog()) + prefix = "{name}[{index}].".format(name=self.name, index=index) + results.append( + InteractiveForm( + properties, prefix=prefix, reraise=True + ).initiate_dialog() + ) index += 1 - if Question(message, {'default': 'y'}).read() != 'y': + if Question(message, {"default": "y"}).read() != "y": is_continue = False return results @@ -341,53 +355,55 @@ def read(self): class ArrayEnumReader(EnumReader): def __init__(self, name, spec, prefix=None): - self.items = spec.get('items', {}) + self.items = spec.get("items", {}) super(ArrayEnumReader, self).__init__(name, spec, prefix) @staticmethod def condition(spec): - return spec.get('type', None) == 'array' and 'enum' in spec.get('items', {}) + return spec.get("type", None) == "array" and "enum" in spec.get("items", {}) @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return - for m in re.finditer(r'[^, ]+', input): + for m in re.finditer(r"[^, ]+", input): index, item = m.start(), m.group() try: - EnumReader.validate(item, spec.get('items', {})) + EnumReader.validate(item, spec.get("items", {})) except validation.ValidationError as e: raise validation.ValidationError(index, six.text_type(e)) def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - enum = self.items.get('enum') + enum = self.items.get("enum") for index, value in enumerate(enum): - self.template += u'\n {} - {}'.format(index, value) + self.template += "\n {} - {}".format(index, value) num_options = len(enum) - more = '' + more = "" if num_options > 3: num_options = 3 - more = '...' + more = "..." options = [str(i) for i in range(0, num_options)] - self.template += u'\nChoose from {}{}'.format(', '.join(options), more) + self.template += "\nChoose from {}{}".format(", ".join(options), more) - if 'default' in self.spec: - default_choises = [str(enum.index(item)) for item in self.spec.get('default')] - self.template += u' [{}]: '.format(', '.join(default_choises)) + if "default" in self.spec: + default_choises = [ + str(enum.index(item)) for item in self.spec.get("default") + ] + self.template += " [{}]: ".format(", ".join(default_choises)) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): result = [] - for i in (item.strip() for item in response.split(',')): + for i in (item.strip() for item in response.split(",")): if i: - result.append(self.items.get('enum')[int(i)]) + result.append(self.items.get("enum")[int(i)]) return result @@ -403,7 +419,7 @@ class InteractiveForm(object): ArrayObjectReader, ArrayReader, SecretStringReader, - StringReader + StringReader, ] def __init__(self, schema, prefix=None, reraise=False): @@ -419,11 +435,11 @@ def initiate_dialog(self): try: result[field] = self._read_field(field) except ReaderNotImplemented as e: - print('%s. Skipping...' % six.text_type(e)) + print("%s. Skipping..." % six.text_type(e)) except DialogInterrupted: if self.reraise: raise - print('Dialog interrupted.') + print("Dialog interrupted.") return result @@ -438,7 +454,7 @@ def _read_field(self, field): break if not reader: - raise ReaderNotImplemented('No reader for the field spec') + raise ReaderNotImplemented("No reader for the field spec") try: return reader.read() diff --git a/st2client/st2client/utils/jsutil.py b/st2client/st2client/utils/jsutil.py index 7aaf20dfe0..1d98ab8f46 100644 --- a/st2client/st2client/utils/jsutil.py +++ b/st2client/st2client/utils/jsutil.py @@ -48,7 +48,7 @@ def _get_value_simple(doc, key): Returns the extracted value from the key specified (if found) Returns None if the key can not be found """ - split_key = key.split('.') + split_key = key.split(".") if not split_key: return None @@ -82,8 +82,9 @@ def get_value(doc, key): raise ValueError("key is None or empty: '{}'".format(key)) if not isinstance(doc, dict): - raise ValueError("doc is not an instance of dict: type={} value='{}'".format(type(doc), - doc)) + raise ValueError( + "doc is not an instance of dict: type={} value='{}'".format(type(doc), doc) + ) # jsonpath_rw can be very slow when processing expressions. # In the case of a simple expression we've created a "fast path" that avoids # the complexity introduced by running jsonpath_rw code. @@ -113,12 +114,12 @@ def get_kvps(doc, keys): value = get_value(doc, key) if value is not None: nested = new_doc - while '.' in key: - attr = key[:key.index('.')] + while "." in key: + attr = key[: key.index(".")] if attr not in nested: nested[attr] = {} nested = nested[attr] - key = key[key.index('.') + 1:] + key = key[key.index(".") + 1 :] nested[key] = value return new_doc diff --git a/st2client/st2client/utils/logging.py b/st2client/st2client/utils/logging.py index dd8b8b9e44..8328a5c55e 100644 --- a/st2client/st2client/utils/logging.py +++ b/st2client/st2client/utils/logging.py @@ -18,9 +18,9 @@ import logging __all__ = [ - 'LogLevelFilter', - 'set_log_level_for_all_handlers', - 'set_log_level_for_all_loggers' + "LogLevelFilter", + "set_log_level_for_all_handlers", + "set_log_level_for_all_loggers", ] diff --git a/st2client/st2client/utils/misc.py b/st2client/st2client/utils/misc.py index e8623b3070..62c7b1a61f 100644 --- a/st2client/st2client/utils/misc.py +++ b/st2client/st2client/utils/misc.py @@ -18,9 +18,7 @@ import six -__all__ = [ - 'merge_dicts' -] +__all__ = ["merge_dicts"] def merge_dicts(d1, d2): diff --git a/st2client/st2client/utils/schema.py b/st2client/st2client/utils/schema.py index 33142daa71..2cf7d5b231 100644 --- a/st2client/st2client/utils/schema.py +++ b/st2client/st2client/utils/schema.py @@ -17,36 +17,30 @@ TYPE_TABLE = { - dict: 'object', - list: 'array', - int: 'integer', - str: 'string', - float: 'number', - bool: 'boolean', - type(None): 'null', + dict: "object", + list: "array", + int: "integer", + str: "string", + float: "number", + bool: "boolean", + type(None): "null", } if sys.version_info[0] < 3: - TYPE_TABLE[unicode] = 'string' # noqa # pylint: disable=E0602 + TYPE_TABLE[unicode] = "string" # noqa # pylint: disable=E0602 def _dict_to_schema(item): schema = {} for key, value in item.iteritems(): if isinstance(value, dict): - schema[key] = { - 'type': 'object', - 'parameters': _dict_to_schema(value) - } + schema[key] = {"type": "object", "parameters": _dict_to_schema(value)} else: - schema[key] = { - 'type': TYPE_TABLE[type(value)] - } + schema[key] = {"type": TYPE_TABLE[type(value)]} return schema def render_output_schema_from_output(output): - """Given an action output produce a reasonable schema to match. - """ + """Given an action output produce a reasonable schema to match.""" return _dict_to_schema(output) diff --git a/st2client/st2client/utils/strutil.py b/st2client/st2client/utils/strutil.py index d6bc23d9cc..0bb970ff3e 100644 --- a/st2client/st2client/utils/strutil.py +++ b/st2client/st2client/utils/strutil.py @@ -24,9 +24,9 @@ def unescape(s): This function unescapes those chars. """ if isinstance(s, six.string_types): - s = s.replace('\\n', '\n') - s = s.replace('\\r', '\r') - s = s.replace('\\"', '\"') + s = s.replace("\\n", "\n") + s = s.replace("\\r", "\r") + s = s.replace('\\"', '"') return s @@ -39,14 +39,14 @@ def dedupe_newlines(s): """ if isinstance(s, six.string_types): - s = s.replace('\n\n', '\n') + s = s.replace("\n\n", "\n") return s def strip_carriage_returns(s): if isinstance(s, six.string_types): - s = s.replace('\\r', '') - s = s.replace('\r', '') + s = s.replace("\\r", "") + s = s.replace("\r", "") return s diff --git a/st2client/st2client/utils/terminal.py b/st2client/st2client/utils/terminal.py index 555753fc95..6ce28a4d74 100644 --- a/st2client/st2client/utils/terminal.py +++ b/st2client/st2client/utils/terminal.py @@ -24,11 +24,7 @@ DEFAULT_TERMINAL_SIZE_COLUMNS = 150 -__all__ = [ - 'DEFAULT_TERMINAL_SIZE_COLUMNS', - - 'get_terminal_size_columns' -] +__all__ = ["DEFAULT_TERMINAL_SIZE_COLUMNS", "get_terminal_size_columns"] def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): @@ -48,7 +44,7 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): # This way it's consistent with upstream implementation. In the past, our implementation # checked those variables at the end as a fall back. try: - columns = os.environ['COLUMNS'] + columns = os.environ["COLUMNS"] return int(columns) except (KeyError, ValueError): pass @@ -56,8 +52,9 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): def ioctl_GWINSZ(fd): import fcntl import termios + # Return a tuple (lines, columns) - return struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234')) + return struct.unpack("hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, "1234")) # 2. try stdin, stdout, stderr for fd in (0, 1, 2): @@ -78,10 +75,12 @@ def ioctl_GWINSZ(fd): # 4. try `stty size` try: - process = subprocess.Popen(['stty', 'size'], - shell=False, - stdout=subprocess.PIPE, - stderr=open(os.devnull, 'w')) + process = subprocess.Popen( + ["stty", "size"], + shell=False, + stdout=subprocess.PIPE, + stderr=open(os.devnull, "w"), + ) result = process.communicate() if process.returncode == 0: return tuple(int(x) for x in result[0].split())[1] @@ -101,23 +100,23 @@ def __exit__(self, type, value, traceback): return self.close() def add_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name)) + self._write("\t[{:^20}] {}".format(format_status(status), name)) def update_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True) + self._write("\t[{:^20}] {}".format(format_status(status), name), override=True) def finish_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True) + self._write("\t[{:^20}] {}".format(format_status(status), name), override=True) def close(self): if self.dirty: - self._write('\n') + self._write("\n") def _write(self, string, override=False): if override: - sys.stdout.write('\r') + sys.stdout.write("\r") else: - sys.stdout.write('\n') + sys.stdout.write("\n") sys.stdout.write(string) sys.stdout.flush() diff --git a/st2client/st2client/utils/types.py b/st2client/st2client/utils/types.py index 5c25990a6e..ad70f078b9 100644 --- a/st2client/st2client/utils/types.py +++ b/st2client/st2client/utils/types.py @@ -20,17 +20,14 @@ from __future__ import absolute_import import collections -__all__ = [ - 'OrderedSet' -] +__all__ = ["OrderedSet"] class OrderedSet(collections.MutableSet): - def __init__(self, iterable=None): self.end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.map = {} # key --> [key, prev, next] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] if iterable is not None: self |= iterable @@ -68,15 +65,15 @@ def __reversed__(self): def pop(self, last=True): if not self: - raise KeyError('set is empty') + raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] self.discard(key) return key def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): diff --git a/st2client/tests/base.py b/st2client/tests/base.py index 307f00f74b..80c14efef6 100644 --- a/st2client/tests/base.py +++ b/st2client/tests/base.py @@ -27,26 +27,22 @@ LOG = logging.getLogger(__name__) -FAKE_ENDPOINT = 'http://127.0.0.1:8268' +FAKE_ENDPOINT = "http://127.0.0.1:8268" RESOURCES = [ { "id": "123", "name": "abc", }, - { - "id": "456", - "name": "def" - } + {"id": "456", "name": "def"}, ] class FakeResource(models.Resource): - _plural = 'FakeResources' + _plural = "FakeResources" class FakeResponse(object): - def __init__(self, text, status_code, reason, *args): self.text = text self.status_code = status_code @@ -64,8 +60,7 @@ def raise_for_status(self): class FakeClient(object): def __init__(self): self.managers = { - 'FakeResource': models.ResourceManager(FakeResource, - FAKE_ENDPOINT) + "FakeResource": models.ResourceManager(FakeResource, FAKE_ENDPOINT) } @@ -75,23 +70,32 @@ def __init__(self): class BaseCLITestCase(unittest2.TestCase): - capture_output = True # if True, stdout and stderr are saved to self.stdout and self.stderr + capture_output = ( + True # if True, stdout and stderr are saved to self.stdout and self.stderr + ) stdout = six.moves.StringIO() stderr = six.moves.StringIO() - DEFAULT_SKIP_CONFIG = '1' + DEFAULT_SKIP_CONFIG = "1" def setUp(self): super(BaseCLITestCase, self).setUp() # Setup environment - for var in ['ST2_BASE_URL', 'ST2_AUTH_URL', 'ST2_API_URL', 'ST2_STREAM_URL', - 'ST2_AUTH_TOKEN', 'ST2_CONFIG_FILE', 'ST2_API_KEY']: + for var in [ + "ST2_BASE_URL", + "ST2_AUTH_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_AUTH_TOKEN", + "ST2_CONFIG_FILE", + "ST2_API_KEY", + ]: if var in os.environ: del os.environ[var] - os.environ['ST2_CLI_SKIP_CONFIG'] = self.DEFAULT_SKIP_CONFIG + os.environ["ST2_CLI_SKIP_CONFIG"] = self.DEFAULT_SKIP_CONFIG if self.capture_output: # Make sure we reset it for each test class instance @@ -134,5 +138,5 @@ def _reset_output_streams(self): self.stderr.truncate() # Verify it has been reset correctly - self.assertEqual(self.stdout.getvalue(), '') - self.assertEqual(self.stderr.getvalue(), '') + self.assertEqual(self.stdout.getvalue(), "") + self.assertEqual(self.stderr.getvalue(), "") diff --git a/st2client/tests/fixtures/loader.py b/st2client/tests/fixtures/loader.py index a471d8e710..049a82b7a6 100644 --- a/st2client/tests/fixtures/loader.py +++ b/st2client/tests/fixtures/loader.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -24,8 +25,8 @@ import yaml -ALLOWED_EXTS = ['.json', '.yaml', '.yml', '.txt'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".json", ".yaml", ".yml", ".txt"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} def get_fixtures_base_path(): @@ -44,12 +45,14 @@ def load_content(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) parser_func = PARSER_FUNCS.get(file_ext, None) - with open(file_path, 'r') as fd: + with open(file_path, "r") as fd: return parser_func(fd) if parser_func else fd.read() @@ -75,7 +78,7 @@ def load_fixtures(fixtures_dict=None): for fixture_type, fixtures in six.iteritems(fixtures_dict): loaded_fixtures = {} for fixture in fixtures: - fixture_path = fixtures_base_path + '/' + fixture + fixture_path = fixtures_base_path + "/" + fixture fixture_dict = load_content(fixture_path) loaded_fixtures[fixture] = fixture_dict all_fixtures[fixture_type] = loaded_fixtures diff --git a/st2client/tests/unit/test_action.py b/st2client/tests/unit/test_action.py index e02c1ea1ca..1bb8be3810 100644 --- a/st2client/tests/unit/test_action.py +++ b/st2client/tests/unit/test_action.py @@ -34,7 +34,7 @@ "float": {"type": "number"}, "json": {"type": "object"}, "list": {"type": "array"}, - "str": {"type": "string"} + "str": {"type": "string"}, }, "name": "mock-runner1", } @@ -46,7 +46,7 @@ "parameters": {}, "enabled": True, "entry_point": "", - "pack": "mockety" + "pack": "mockety", } RUNNER2 = { @@ -65,475 +65,583 @@ "float": {"type": "number"}, "json": {"type": "object"}, "list": {"type": "array"}, - "str": {"type": "string"} + "str": {"type": "string"}, }, "enabled": True, "entry_point": "", - "pack": "mockety" + "pack": "mockety", } LIVE_ACTION = { - 'action': 'mockety.mock', - 'status': 'complete', - 'result': {'stdout': 'non-empty'} + "action": "mockety.mock", + "status": "complete", + "result": {"stdout": "non-empty"}, } def get_by_name(name, **kwargs): - if name == 'mock-runner1': + if name == "mock-runner1": return models.RunnerType(**RUNNER1) - if name == 'mock-runner2': + if name == "mock-runner2": return models.RunnerType(**RUNNER2) def get_by_ref(**kwargs): - ref = kwargs.get('ref_or_id', None) + ref = kwargs.get("ref_or_id", None) if not ref: raise Exception('Actions must be referred to by "ref".') - if ref == 'mockety.mock1': + if ref == "mockety.mock1": return models.Action(**ACTION1) - if ref == 'mockety.mock2': + if ref == "mockety.mock2": return models.Action(**ACTION2) class ActionCommandTestCase(base.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(ActionCommandTestCase, self).__init__(*args, **kwargs) self.shell = shell.Shell() @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_bool_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'bool=false']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'bool': False}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "bool=false"]) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"bool": False}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_integer_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'int=30']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'int': 30}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "int=30"]) + expected = {"action": "mockety.mock1", "user": None, "parameters": {"int": 30}} + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_float_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'float=3.01']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'float': 3.01}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "float=3.01"]) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"float": 3.01}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_json_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'json={"a":1}']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'json': {'a': 1}}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", 'json={"a":1}']) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"json": {"a": 1}}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_array_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'list=one,two,three']) + self.shell.run(["run", "mockety.mock1", "list=one,two,three"]) expected = { - 'action': 'mockety.mock1', - 'user': None, - 'parameters': { - 'list': [ - 'one', - 'two', - 'three' - ] - } + "action": "mockety.mock1", + "user": None, + "parameters": {"list": ["one", "two", "three"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_array_object_conversion(self): self.shell.run( [ - 'run', - 'mockety.mock1', - 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]' + "run", + "mockety.mock1", + 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]', ] ) expected = { - 'action': 'mockety.mock1', - 'user': None, - 'parameters': { - 'list': [ - { - 'foo': 1, - 'ponies': 'rainbows' - }, - { - 'pluto': False, - 'earth': True - } + "action": "mockety.mock1", + "user": None, + "parameters": { + "list": [ + {"foo": 1, "ponies": "rainbows"}, + {"pluto": False, "earth": True}, ] - } + }, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_bool_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'bool=false']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'bool': False}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "bool=false"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"bool": False}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_integer_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'int=30']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'int': 30}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "int=30"]) + expected = {"action": "mockety.mock2", "user": None, "parameters": {"int": 30}} + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_float_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'float=3.01']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'float': 3.01}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "float=3.01"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"float": 3.01}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_json_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'json={"a":1}']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'json': {'a': 1}}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", 'json={"a":1}']) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"json": {"a": 1}}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'list=one,two,three']) + self.shell.run(["run", "mockety.mock2", "list=one,two,three"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'one', - 'two', - 'three' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["one", "two", "three"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion_single_element_str(self): - self.shell.run(['run', 'mockety.mock2', 'list=one']) + self.shell.run(["run", "mockety.mock2", "list=one"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'one' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["one"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion_single_element_int(self): - self.shell.run(['run', 'mockety.mock2', 'list=1']) + self.shell.run(["run", "mockety.mock2", "list=1"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 1 - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": [1]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_object_conversion(self): self.shell.run( [ - 'run', - 'mockety.mock2', - 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]' + "run", + "mockety.mock2", + 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]', ] ) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - { - 'foo': 1, - 'ponies': 'rainbows' - }, - { - 'pluto': False, - 'earth': True - } + "action": "mockety.mock2", + "user": None, + "parameters": { + "list": [ + {"foo": 1, "ponies": "rainbows"}, + {"pluto": False, "earth": True}, ] - } + }, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_dict_conversion_flag(self): - """Ensure that the automatic conversion to dict based on colons only occurs with the flag - """ + """Ensure that the automatic conversion to dict based on colons only occurs with the flag""" self.shell.run( - [ - 'run', - 'mockety.mock2', - 'list=key1:value1,key2:value2', - '--auto-dict' - ] + ["run", "mockety.mock2", "list=key1:value1,key2:value2", "--auto-dict"] ) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - { - 'key1': 'value1', - 'key2': 'value2' - } - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": [{"key1": "value1", "key2": "value2"}]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) - self.shell.run( - [ - 'run', - 'mockety.mock2', - 'list=key1:value1,key2:value2' - ] - ) + self.shell.run(["run", "mockety.mock2", "list=key1:value1,key2:value2"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'key1:value1', - 'key2:value2' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["key1:value1", "key2:value2"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_value_with_equal_sign(self): - self.shell.run(['run', 'mockety.mock2', 'key=foo=bar&ponies=unicorns']) - expected = {'action': 'mockety.mock2', 'user': None, - 'parameters': {'key': 'foo=bar&ponies=unicorns'}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "key=foo=bar&ponies=unicorns"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"key": "foo=bar&ponies=unicorns"}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_cancel_single_execution(self): - self.shell.run(['execution', 'cancel', '123']) - httpclient.HTTPClient.delete.assert_called_with('/executions/123') + self.shell.run(["execution", "cancel", "123"]) + httpclient.HTTPClient.delete.assert_called_with("/executions/123") @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_cancel_multiple_executions(self): - self.shell.run(['execution', 'cancel', '123', '456', '789']) - calls = [mock.call('/executions/123'), - mock.call('/executions/456'), - mock.call('/executions/789')] + self.shell.run(["execution", "cancel", "123", "456", "789"]) + calls = [ + mock.call("/executions/123"), + mock.call("/executions/456"), + mock.call("/executions/789"), + ] httpclient.HTTPClient.delete.assert_has_calls(calls) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_pause_single_execution(self): - self.shell.run(['execution', 'pause', '123']) - expected = {'status': 'pausing'} - httpclient.HTTPClient.put.assert_called_with('/executions/123', expected) + self.shell.run(["execution", "pause", "123"]) + expected = {"status": "pausing"} + httpclient.HTTPClient.put.assert_called_with("/executions/123", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_pause_multiple_executions(self): - self.shell.run(['execution', 'pause', '123', '456', '789']) - expected = {'status': 'pausing'} - calls = [mock.call('/executions/123', expected), - mock.call('/executions/456', expected), - mock.call('/executions/789', expected)] + self.shell.run(["execution", "pause", "123", "456", "789"]) + expected = {"status": "pausing"} + calls = [ + mock.call("/executions/123", expected), + mock.call("/executions/456", expected), + mock.call("/executions/789", expected), + ] httpclient.HTTPClient.put.assert_has_calls(calls) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_resume_single_execution(self): - self.shell.run(['execution', 'resume', '123']) - expected = {'status': 'resuming'} - httpclient.HTTPClient.put.assert_called_with('/executions/123', expected) + self.shell.run(["execution", "resume", "123"]) + expected = {"status": "resuming"} + httpclient.HTTPClient.put.assert_called_with("/executions/123", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_resume_multiple_executions(self): - self.shell.run(['execution', 'resume', '123', '456', '789']) - expected = {'status': 'resuming'} - calls = [mock.call('/executions/123', expected), - mock.call('/executions/456', expected), - mock.call('/executions/789', expected)] + self.shell.run(["execution", "resume", "123", "456", "789"]) + expected = {"status": "resuming"} + calls = [ + mock.call("/executions/123", expected), + mock.call("/executions/456", expected), + mock.call("/executions/789", expected), + ] httpclient.HTTPClient.put.assert_has_calls(calls) diff --git a/st2client/tests/unit/test_action_alias.py b/st2client/tests/unit/test_action_alias.py index a360fd5139..753b4e71a8 100644 --- a/st2client/tests/unit/test_action_alias.py +++ b/st2client/tests/unit/test_action_alias.py @@ -29,9 +29,7 @@ "execution": { "id": "mock-id", }, - "actionalias": { - "ref": "mock-ref" - } + "actionalias": {"ref": "mock-ref"}, } ] } @@ -43,20 +41,26 @@ def __init__(self, *args, **kwargs): self.shell = shell.Shell() @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT), - 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT), 200, "OK" + ) + ), + ) def test_match_and_execute(self): - ret = self.shell.run(['action-alias', 'execute', "run whoami on localhost"]) + ret = self.shell.run(["action-alias", "execute", "run whoami on localhost"]) self.assertEqual(ret, 0) expected_args = { - 'command': 'run whoami on localhost', - 'user': '', - 'source_channel': 'cli' + "command": "run whoami on localhost", + "user": "", + "source_channel": "cli", } - httpclient.HTTPClient.post.assert_called_with('/aliasexecution/match_and_execute', - expected_args) + httpclient.HTTPClient.post.assert_called_with( + "/aliasexecution/match_and_execute", expected_args + ) mock_stdout = self.stdout.getvalue() diff --git a/st2client/tests/unit/test_app.py b/st2client/tests/unit/test_app.py index eb1a67242e..217d3875ad 100644 --- a/st2client/tests/unit/test_app.py +++ b/st2client/tests/unit/test_app.py @@ -26,33 +26,33 @@ class BaseCLIAppTestCase(unittest2.TestCase): - @mock.patch('os.path.isfile', mock.Mock()) + @mock.patch("os.path.isfile", mock.Mock()) def test_cli_config_file_path(self): app = BaseCLIApp() args = mock.Mock() # 1. Absolute path - args.config_file = '/tmp/full/abs/path/config.ini' + args.config_file = "/tmp/full/abs/path/config.ini" result = app._get_config_file_path(args=args) self.assertEqual(result, args.config_file) - args.config_file = '/home/user/st2/config.ini' + args.config_file = "/home/user/st2/config.ini" result = app._get_config_file_path(args=args) self.assertEqual(result, args.config_file) # 2. Path relative to user home directory, should get expanded - args.config_file = '~/.st2/config.ini' + args.config_file = "~/.st2/config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.path.expanduser('~' + USER), '.st2/config.ini') + expected = os.path.join(os.path.expanduser("~" + USER), ".st2/config.ini") self.assertEqual(result, expected) # 3. Relative path (should get converted to absolute one) - args.config_file = 'config.ini' + args.config_file = "config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.getcwd(), 'config.ini') + expected = os.path.join(os.getcwd(), "config.ini") self.assertEqual(result, expected) - args.config_file = '.st2/config.ini' + args.config_file = ".st2/config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.getcwd(), '.st2/config.ini') + expected = os.path.join(os.getcwd(), ".st2/config.ini") self.assertEqual(result, expected) diff --git a/st2client/tests/unit/test_auth.py b/st2client/tests/unit/test_auth.py index cd838712cb..e59b31dfaf 100644 --- a/st2client/tests/unit/test_auth.py +++ b/st2client/tests/unit/test_auth.py @@ -29,24 +29,27 @@ from st2client import shell from st2client.models.core import add_auth_token_to_kwargs_from_env from st2client.commands.resource import add_auth_token_to_kwargs_from_cli -from st2client.utils.httpclient import add_auth_token_to_headers, add_json_content_type_to_headers +from st2client.utils.httpclient import ( + add_auth_token_to_headers, + add_json_content_type_to_headers, +) LOG = logging.getLogger(__name__) if six.PY3: RULE = { - 'name': 'drule', - 'description': 'i am THE rule.', - 'pack': 'cli', - 'id': uuid.uuid4().hex + "name": "drule", + "description": "i am THE rule.", + "pack": "cli", + "id": uuid.uuid4().hex, } else: RULE = { - 'id': uuid.uuid4().hex, - 'description': 'i am THE rule.', - 'name': 'drule', - 'pack': 'cli', + "id": uuid.uuid4().hex, + "description": "i am THE rule.", + "name": "drule", + "pack": "cli", } @@ -59,9 +62,9 @@ class TestLoginBase(base.BaseCLITestCase): on duplicate code in each test class """ - DOTST2_PATH = os.path.expanduser('~/.st2/') - CONFIG_FILE_NAME = 'st2.conf' - PARENT_DIR = 'testconfig' + DOTST2_PATH = os.path.expanduser("~/.st2/") + CONFIG_FILE_NAME = "st2.conf" + PARENT_DIR = "testconfig" TMP_DIR = tempfile.mkdtemp() CONFIG_CONTENTS = """ [credentials] @@ -73,11 +76,11 @@ def __init__(self, *args, **kwargs): super(TestLoginBase, self).__init__(*args, **kwargs) # We're overriding the default behavior for CLI test cases here - self.DEFAULT_SKIP_CONFIG = '0' + self.DEFAULT_SKIP_CONFIG = "0" self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() self.CONFIG_DIR = os.path.join(self.TMP_DIR, self.PARENT_DIR) @@ -94,9 +97,9 @@ def setUp(self): if os.path.isfile(self.CONFIG_FILE): os.remove(self.CONFIG_FILE) - with open(self.CONFIG_FILE, 'w') as cfg: - for line in self.CONFIG_CONTENTS.split('\n'): - cfg.write('%s\n' % line.strip()) + with open(self.CONFIG_FILE, "w") as cfg: + for line in self.CONFIG_CONTENTS.split("\n"): + cfg.write("%s\n" % line.strip()) os.chmod(self.CONFIG_FILE, 0o660) @@ -107,7 +110,7 @@ def tearDown(self): os.remove(self.CONFIG_FILE) # Clean up tokens - for file in [f for f in os.listdir(self.DOTST2_PATH) if 'token-' in f]: + for file in [f for f in os.listdir(self.DOTST2_PATH) if "token-" in f]: os.remove(self.DOTST2_PATH + file) # Clean up config directory @@ -116,181 +119,208 @@ def tearDown(self): class TestLoginPasswordAndConfig(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) def runTest(self): - '''Test 'st2 login' functionality by specifying a password and a configuration file - ''' - - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password', - 'Password1!'] + """Test 'st2 login' functionality by specifying a password and a configuration file""" + + expected_username = self.TOKEN["user"] + args = [ + "--config", + self.CONFIG_FILE, + "login", + expected_username, + "--password", + "Password1!", + ] self.shell.run(args) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('password', line) - self.assertNotIn('olduser', line) + self.assertNotIn("password", line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) class TestLoginIntPwdAndConfig(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' functionality with interactive password entry - ''' + """Test 'st2 login' functionality with interactive password entry""" - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username] + expected_username = self.TOKEN["user"] + args = ["--config", self.CONFIG_FILE, "login", expected_username] - mock_gp.getpass.return_value = 'Password1!' + mock_gp.getpass.return_value = "Password1!" self.shell.run(args) expected_kwargs = { - 'headers': {'content-type': 'application/json'}, - 'auth': ('st2admin', 'Password1!') + "headers": {"content-type": "application/json"}, + "auth": ("st2admin", "Password1!"), } - requests.post.assert_called_with('http://127.0.0.1:9100/tokens', '{}', **expected_kwargs) + requests.post.assert_called_with( + "http://127.0.0.1:9100/tokens", "{}", **expected_kwargs + ) # Check file permissions self.assertEqual(os.stat(self.CONFIG_FILE).st_mode & 0o777, 0o660) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('password', line) - self.assertNotIn('olduser', line) + self.assertNotIn("password", line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) # Validate token is sent on subsequent requests to st2 API - args = ['--config', self.CONFIG_FILE, 'pack', 'list'] + args = ["--config", self.CONFIG_FILE, "pack", "list"] self.shell.run(args) expected_kwargs = { - 'headers': { - 'X-Auth-Token': self.TOKEN['token'] - }, - 'params': { - 'include_attributes': 'ref,name,description,version,author' - } + "headers": {"X-Auth-Token": self.TOKEN["token"]}, + "params": {"include_attributes": "ref,name,description,version,author"}, } - requests.get.assert_called_with('http://127.0.0.1:9101/v1/packs', **expected_kwargs) + requests.get.assert_called_with( + "http://127.0.0.1:9101/v1/packs", **expected_kwargs + ) class TestLoginWritePwdOkay(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' functionality with --write-password flag set - ''' - - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password', - 'Password1!', '--write-password'] + """Test 'st2 login' functionality with --write-password flag set""" + + expected_username = self.TOKEN["user"] + args = [ + "--config", + self.CONFIG_FILE, + "login", + expected_username, + "--password", + "Password1!", + "--write-password", + ] self.shell.run(args) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('olduser', line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) class TestLoginUncaughtException(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' ability to detect unhandled exceptions - ''' + """Test 'st2 login' ability to detect unhandled exceptions""" - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username] + expected_username = self.TOKEN["user"] + args = ["--config", self.CONFIG_FILE, "login", expected_username] mock_gp.getpass = mock.MagicMock(side_effect=Exception) self.shell.run(args) retcode = self.shell.run(args) - self.assertIn('Failed to log in as %s' % expected_username, self.stdout.getvalue()) - self.assertNotIn('Logged in as', self.stdout.getvalue()) + self.assertIn( + "Failed to log in as %s" % expected_username, self.stdout.getvalue() + ) + self.assertNotIn("Logged in as", self.stdout.getvalue()) self.assertEqual(retcode, 1) @@ -301,26 +331,26 @@ class TestAuthToken(base.BaseCLITestCase): def __init__(self, *args, **kwargs): super(TestAuthToken, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): super(TestAuthToken, self).setUp() # Setup environment. - os.environ['ST2_BASE_URL'] = 'http://127.0.0.1' + os.environ["ST2_BASE_URL"] = "http://127.0.0.1" def tearDown(self): super(TestAuthToken, self).tearDown() # Clean up environment. - if 'ST2_AUTH_TOKEN' in os.environ: - del os.environ['ST2_AUTH_TOKEN'] - if 'ST2_API_KEY' in os.environ: - del os.environ['ST2_API_KEY'] - if 'ST2_BASE_URL' in os.environ: - del os.environ['ST2_BASE_URL'] + if "ST2_AUTH_TOKEN" in os.environ: + del os.environ["ST2_AUTH_TOKEN"] + if "ST2_API_KEY" in os.environ: + del os.environ["ST2_API_KEY"] + if "ST2_BASE_URL" in os.environ: + del os.environ["ST2_BASE_URL"] @add_auth_token_to_kwargs_from_cli @add_auth_token_to_kwargs_from_env @@ -329,27 +359,27 @@ def _mock_run(self, args, **kwargs): def test_decorate_auth_token_by_cli(self): token = uuid.uuid4().hex - args = self.parser.parse_args(args=['-t', token]) - self.assertDictEqual(self._mock_run(args), {'token': token}) - args = self.parser.parse_args(args=['--token', token]) - self.assertDictEqual(self._mock_run(args), {'token': token}) + args = self.parser.parse_args(args=["-t", token]) + self.assertDictEqual(self._mock_run(args), {"token": token}) + args = self.parser.parse_args(args=["--token", token]) + self.assertDictEqual(self._mock_run(args), {"token": token}) def test_decorate_api_key_by_cli(self): token = uuid.uuid4().hex - args = self.parser.parse_args(args=['--api-key', token]) - self.assertDictEqual(self._mock_run(args), {'api_key': token}) + args = self.parser.parse_args(args=["--api-key", token]) + self.assertDictEqual(self._mock_run(args), {"api_key": token}) def test_decorate_auth_token_by_env(self): token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token args = self.parser.parse_args(args=[]) - self.assertDictEqual(self._mock_run(args), {'token': token}) + self.assertDictEqual(self._mock_run(args), {"token": token}) def test_decorate_api_key_by_env(self): token = uuid.uuid4().hex - os.environ['ST2_API_KEY'] = token + os.environ["ST2_API_KEY"] = token args = self.parser.parse_args(args=[]) - self.assertDictEqual(self._mock_run(args), {'api_key': token}) + self.assertDictEqual(self._mock_run(args), {"api_key": token}) def test_decorate_without_auth_token(self): args = self.parser.parse_args(args=[]) @@ -362,187 +392,215 @@ def _mock_http(self, url, **kwargs): def test_decorate_auth_token_to_http_headers(self): token = uuid.uuid4().hex - kwargs = self._mock_http('/', token=token) - expected = {'content-type': 'application/json', 'X-Auth-Token': token} - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", token=token) + expected = {"content-type": "application/json", "X-Auth-Token": token} + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) def test_decorate_api_key_to_http_headers(self): token = uuid.uuid4().hex - kwargs = self._mock_http('/', api_key=token) - expected = {'content-type': 'application/json', 'St2-Api-Key': token} - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", api_key=token) + expected = {"content-type": "application/json", "St2-Api-Key": token} + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) def test_decorate_without_auth_token_to_http_headers(self): - kwargs = self._mock_http('/', auth=('stanley', 'stanley')) - expected = {'content-type': 'application/json'} - self.assertIn('auth', kwargs) - self.assertEqual(kwargs['auth'], ('stanley', 'stanley')) - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", auth=("stanley", "stanley")) + expected = {"content-type": "application/json"} + self.assertIn("auth", kwargs) + self.assertEqual(kwargs["auth"], ("stanley", "stanley")) + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_resource_list(self): - url = ('http://127.0.0.1:9101/v1/rules/' - '?include_attributes=ref,pack,description,enabled&limit=50') - url = url.replace(',', '%2C') + url = ( + "http://127.0.0.1:9101/v1/rules/" + "?include_attributes=ref,pack,description,enabled&limit=50" + ) + url = url.replace(",", "%2C") # Test without token. - self.shell.run(['rule', 'list']) + self.shell.run(["rule", "list"]) kwargs = {} requests.get.assert_called_with(url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'list', '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "list", "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'list']) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "list"]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_get(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref # Test without token. - self.shell.run(['rule', 'get', rule_ref]) + self.shell.run(["rule", "get", rule_ref]) kwargs = {} requests.get.assert_called_with(url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'get', rule_ref, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "get", rule_ref, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'get', rule_ref]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "get", rule_ref]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_post(self): - url = 'http://127.0.0.1:9101/v1/rules' - data = {'name': RULE['name'], 'description': RULE['description']} + url = "http://127.0.0.1:9101/v1/rules" + data = {"name": RULE["name"], "description": RULE["description"]} - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(data, indent=4)) # Test without token. - self.shell.run(['rule', 'create', path]) - kwargs = {'headers': {'content-type': 'application/json'}} + self.shell.run(["rule", "create", path]) + kwargs = {"headers": {"content-type": "application/json"}} requests.post.assert_called_with(url, json.dumps(data), **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'create', path, '-t', token]) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + self.shell.run(["rule", "create", path, "-t", token]) + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.post.assert_called_with(url, json.dumps(data), **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'create', path]) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "create", path]) + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.post.assert_called_with(url, json.dumps(data), **kwargs) finally: os.close(fd) os.unlink(path) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_put(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - - get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref - put_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id'] - data = {'name': RULE['name'], 'description': RULE['description'], 'pack': RULE['pack']} + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + + get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref + put_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"] + data = { + "name": RULE["name"], + "description": RULE["description"], + "pack": RULE["pack"], + } - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(data, indent=4)) # Test without token. - self.shell.run(['rule', 'update', rule_ref, path]) + self.shell.run(["rule", "update", rule_ref, path]) kwargs = {} requests.get.assert_called_with(get_url, **kwargs) - kwargs = {'headers': {'content-type': 'application/json'}} + kwargs = {"headers": {"content-type": "application/json"}} requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'update', rule_ref, path, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "update", rule_ref, path, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'update', rule_ref, path]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "update", rule_ref, path]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) # Note: We parse the payload because data might not be in the same # order as the fixture - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) finally: os.close(fd) os.unlink(path) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) @mock.patch.object( - requests, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'OK'))) + requests, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "OK")), + ) def test_decorate_resource_delete(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref - del_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id'] + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref + del_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"] # Test without token. - self.shell.run(['rule', 'delete', rule_ref]) + self.shell.run(["rule", "delete", rule_ref]) kwargs = {} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'delete', rule_ref, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "delete", rule_ref, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'delete', rule_ref]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "delete", rule_ref]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) diff --git a/st2client/tests/unit/test_client.py b/st2client/tests/unit/test_client.py index 2e9fd95095..2d0a380ab5 100644 --- a/st2client/tests/unit/test_client.py +++ b/st2client/tests/unit/test_client.py @@ -25,25 +25,25 @@ LOG = logging.getLogger(__name__) -NONRESOURCES = ['workflows'] +NONRESOURCES = ["workflows"] class TestClientEndpoints(unittest2.TestCase): - def tearDown(self): for var in [ - 'ST2_BASE_URL', - 'ST2_API_URL', - 'ST2_STREAM_URL', - 'ST2_DATASTORE_URL', - 'ST2_AUTH_TOKEN' + "ST2_BASE_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_DATASTORE_URL", + "ST2_AUTH_TOKEN", ]: if var in os.environ: del os.environ[var] def test_managers(self): - property_names = [k for k, v in six.iteritems(Client.__dict__) - if isinstance(v, property)] + property_names = [ + k for k, v in six.iteritems(Client.__dict__) if isinstance(v, property) + ] client = Client() @@ -55,96 +55,109 @@ def test_managers(self): self.assertIsInstance(manager, models.ResourceManager) def test_default(self): - base_url = 'http://127.0.0.1' - api_url = 'http://127.0.0.1:9101/v1' - stream_url = 'http://127.0.0.1:9102/v1' + base_url = "http://127.0.0.1" + api_url = "http://127.0.0.1:9101/v1" + stream_url = "http://127.0.0.1:9102/v1" client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_env(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - os.environ['ST2_BASE_URL'] = base_url - os.environ['ST2_API_URL'] = api_url - os.environ['ST2_STREAM_URL'] = stream_url - self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url) - self.assertEqual(os.environ.get('ST2_API_URL'), api_url) - self.assertEqual(os.environ.get('ST2_STREAM_URL'), stream_url) + os.environ["ST2_BASE_URL"] = base_url + os.environ["ST2_API_URL"] = api_url + os.environ["ST2_STREAM_URL"] = stream_url + self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url) + self.assertEqual(os.environ.get("ST2_API_URL"), api_url) + self.assertEqual(os.environ.get("ST2_STREAM_URL"), stream_url) client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_env_base_only(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.stackstorm.com:9101/v1' - stream_url = 'http://www.stackstorm.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.stackstorm.com:9101/v1" + stream_url = "http://www.stackstorm.com:9102/v1" - os.environ['ST2_BASE_URL'] = base_url - self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url) - self.assertEqual(os.environ.get('ST2_API_URL'), None) - self.assertEqual(os.environ.get('ST2_STREAM_URL'), None) + os.environ["ST2_BASE_URL"] = base_url + self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url) + self.assertEqual(os.environ.get("ST2_API_URL"), None) + self.assertEqual(os.environ.get("ST2_STREAM_URL"), None) client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_args(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url) endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_cacert_arg(self): # Valid value, boolean True - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True + ) self.assertEqual(client.cacert, True) # Valid value, boolean False - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False + ) self.assertEqual(client.cacert, False) # Valid value, existing path to a CA bundle cacert = os.path.abspath(__file__) - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert + ) self.assertEqual(client.cacert, cacert) # Invalid value, path to the bundle doesn't exist cacert = os.path.abspath(__file__) expected_msg = 'CA cert file "doesntexist" does not exist' - self.assertRaisesRegexp(ValueError, expected_msg, Client, base_url=base_url, - api_url=api_url, stream_url=stream_url, cacert='doesntexist') + self.assertRaisesRegexp( + ValueError, + expected_msg, + Client, + base_url=base_url, + api_url=api_url, + stream_url=stream_url, + cacert="doesntexist", + ) def test_args_base_only(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.stackstorm.com:9101/v1' - stream_url = 'http://www.stackstorm.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.stackstorm.com:9101/v1" + stream_url = "http://www.stackstorm.com:9102/v1" client = Client(base_url=base_url) endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) diff --git a/st2client/tests/unit/test_client_actions.py b/st2client/tests/unit/test_client_actions.py index 82b12b788d..141e7c8ece 100644 --- a/st2client/tests/unit/test_client_actions.py +++ b/st2client/tests/unit/test_client_actions.py @@ -31,22 +31,17 @@ EXECUTION = { "id": 12345, - "action": { - "ref": "mock.foobar" - }, + "action": {"ref": "mock.foobar"}, "status": "failed", - "result": "non-empty" + "result": "non-empty", } ENTRYPOINT = ( "version: 1.0" - "description: A basic workflow that runs an arbitrary linux command." - "input:" " - cmd" " - timeout" - "tasks:" " task1:" " action: core.local cmd=<% ctx(cmd) %> timeout=<% ctx(timeout) %>" @@ -55,51 +50,63 @@ " publish:" " - stdout: <% result().stdout %>" " - stderr: <% result().stderr %>" - "output:" " - stdout: <% ctx(stdout) %>" ) class TestActionResourceManager(unittest2.TestCase): - @classmethod def setUpClass(cls): super(TestActionResourceManager, cls).setUpClass() cls.client = client.Client() @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK") + ), + ) def test_get_action_entry_point_by_ref(self): - actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['action']['ref']) + actual_entrypoint = self.client.actions.get_entrypoint( + EXECUTION["action"]["ref"] + ) actual_entrypoint = json.loads(actual_entrypoint) - endpoint = '/actions/views/entry_point/%s' % EXECUTION['action']['ref'] + endpoint = "/actions/views/entry_point/%s" % EXECUTION["action"]["ref"] httpclient.HTTPClient.get.assert_called_with(endpoint) self.assertEqual(ENTRYPOINT, actual_entrypoint) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK") + ), + ) def test_get_action_entry_point_by_id(self): - actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['id']) + actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION["id"]) actual_entrypoint = json.loads(actual_entrypoint) - endpoint = '/actions/views/entry_point/%s' % EXECUTION['id'] + endpoint = "/actions/views/entry_point/%s" % EXECUTION["id"] httpclient.HTTPClient.get.assert_called_with(endpoint) self.assertEqual(ENTRYPOINT, actual_entrypoint) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse( - json.dumps({}), 404, '404 Client Error: Not Found' - ))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps({}), 404, "404 Client Error: Not Found" + ) + ), + ) def test_get_non_existent_action_entry_point(self): - with self.assertRaisesRegexp(Exception, '404 Client Error: Not Found'): - self.client.actions.get_entrypoint('nonexistentpack.nonexistentaction') + with self.assertRaisesRegexp(Exception, "404 Client Error: Not Found"): + self.client.actions.get_entrypoint("nonexistentpack.nonexistentaction") - endpoint = '/actions/views/entry_point/%s' % 'nonexistentpack.nonexistentaction' + endpoint = "/actions/views/entry_point/%s" % "nonexistentpack.nonexistentaction" httpclient.HTTPClient.get.assert_called_with(endpoint) diff --git a/st2client/tests/unit/test_client_executions.py b/st2client/tests/unit/test_client_executions.py index 0470347ee2..a9dc19e2c3 100644 --- a/st2client/tests/unit/test_client_executions.py +++ b/st2client/tests/unit/test_client_executions.py @@ -34,9 +34,7 @@ RUNNER = { "enabled": True, "name": "marathon", - "runner_parameters": { - "var1": {"type": "string"} - } + "runner_parameters": {"var1": {"type": "string"}}, } ACTION = { @@ -46,185 +44,227 @@ "parameters": {}, "enabled": True, "entry_point": "", - "pack": "mocke" + "pack": "mocke", } EXECUTION = { "id": 12345, - "action": { - "ref": "mock.foobar" - }, + "action": {"ref": "mock.foobar"}, "status": "failed", - "result": "non-empty" + "result": "non-empty", } class TestExecutionResourceManager(unittest2.TestCase): - @classmethod def setUpClass(cls): super(TestExecutionResourceManager, cls).setUpClass() cls.client = client.Client() @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_no_params(self): - self.client.executions.re_run(EXECUTION['id'], tasks=['foobar']) + self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"]) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] - data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': {}, - 'delay': 0 - } + data = {"tasks": ["foobar"], "reset": ["foobar"], "parameters": {}, "delay": 0} httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_params(self): - params = { - 'var1': 'testing...' - } + params = {"var1": "testing..."} self.client.executions.re_run( - EXECUTION['id'], - tasks=['foobar'], - parameters=params + EXECUTION["id"], tasks=["foobar"], parameters=params ) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': params, - 'delay': 0 + "tasks": ["foobar"], + "reset": ["foobar"], + "parameters": params, + "delay": 0, } httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_delay(self): - self.client.executions.re_run(EXECUTION['id'], tasks=['foobar'], delay=100) + self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"], delay=100) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': {}, - 'delay': 100 + "tasks": ["foobar"], + "reset": ["foobar"], + "parameters": {}, + "delay": 100, } httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_pause(self): - self.client.executions.pause(EXECUTION['id']) + self.client.executions.pause(EXECUTION["id"]) - endpoint = '/executions/%s' % EXECUTION['id'] + endpoint = "/executions/%s" % EXECUTION["id"] - data = { - 'status': 'pausing' - } + data = {"status": "pausing"} httpclient.HTTPClient.put.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_resume(self): - self.client.executions.resume(EXECUTION['id']) + self.client.executions.resume(EXECUTION["id"]) - endpoint = '/executions/%s' % EXECUTION['id'] + endpoint = "/executions/%s" % EXECUTION["id"] - data = { - 'status': 'resuming' - } + data = {"status": "resuming"} httpclient.HTTPClient.put.assert_called_with(endpoint, data) @mock.patch.object( - models.core.Resource, 'get_url_path_name', - mock.MagicMock(return_value='executions')) + models.core.Resource, + "get_url_path_name", + mock.MagicMock(return_value="executions"), + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK") + ), + ) def test_get_children(self): - self.client.executions.get_children(EXECUTION['id']) + self.client.executions.get_children(EXECUTION["id"]) - endpoint = '/executions/%s/children' % EXECUTION['id'] + endpoint = "/executions/%s/children" % EXECUTION["id"] - data = { - 'depth': -1 - } + data = {"depth": -1} httpclient.HTTPClient.get.assert_called_with(url=endpoint, params=data) @mock.patch.object( - models.ResourceManager, 'get_all', - mock.MagicMock(return_value=[models.Execution(**EXECUTION)])) - @mock.patch.object(warnings, 'warn') - def test_st2client_liveactions_has_been_deprecated_and_emits_warning(self, mock_warn): + models.ResourceManager, + "get_all", + mock.MagicMock(return_value=[models.Execution(**EXECUTION)]), + ) + @mock.patch.object(warnings, "warn") + def test_st2client_liveactions_has_been_deprecated_and_emits_warning( + self, mock_warn + ): self.assertEqual(mock_warn.call_args, None) self.client.liveactions.get_all() - expected_msg = 'st2client.liveactions has been renamed' + expected_msg = "st2client.liveactions has been renamed" self.assertTrue(len(mock_warn.call_args_list) >= 1) self.assertIn(expected_msg, mock_warn.call_args_list[0][0][0]) self.assertEqual(mock_warn.call_args_list[0][0][1], DeprecationWarning) diff --git a/st2client/tests/unit/test_command_actionrun.py b/st2client/tests/unit/test_command_actionrun.py index 763ac649a6..1e312e0786 100644 --- a/st2client/tests/unit/test_command_actionrun.py +++ b/st2client/tests/unit/test_command_actionrun.py @@ -21,73 +21,79 @@ import mock from st2client.commands.action import ActionRunCommand -from st2client.models.action import (Action, RunnerType) +from st2client.models.action import Action, RunnerType class ActionRunCommandTest(unittest2.TestCase): - def test_get_params_types(self): runner = RunnerType() runner_params = { - 'foo': {'immutable': True, 'required': True}, - 'bar': {'description': 'Some param.', 'type': 'string'} + "foo": {"immutable": True, "required": True}, + "bar": {"description": "Some param.", "type": "string"}, } runner.runner_parameters = runner_params orig_runner_params = copy.deepcopy(runner.runner_parameters) action = Action() action.parameters = { - 'foo': {'immutable': False}, # Should not be allowed by API. - 'stuff': {'description': 'Some param.', 'type': 'string', 'required': True} + "foo": {"immutable": False}, # Should not be allowed by API. + "stuff": {"description": "Some param.", "type": "string", "required": True}, } orig_action_params = copy.deepcopy(action.parameters) params, rqd, opt, imm = ActionRunCommand._get_params_types(runner, action) self.assertEqual(len(list(params.keys())), 3) - self.assertIn('foo', imm, '"foo" param should be in immutable set.') - self.assertNotIn('foo', rqd, '"foo" param should not be in required set.') - self.assertNotIn('foo', opt, '"foo" param should not be in optional set.') + self.assertIn("foo", imm, '"foo" param should be in immutable set.') + self.assertNotIn("foo", rqd, '"foo" param should not be in required set.') + self.assertNotIn("foo", opt, '"foo" param should not be in optional set.') - self.assertIn('bar', opt, '"bar" param should be in optional set.') - self.assertNotIn('bar', rqd, '"bar" param should not be in required set.') - self.assertNotIn('bar', imm, '"bar" param should not be in immutable set.') + self.assertIn("bar", opt, '"bar" param should be in optional set.') + self.assertNotIn("bar", rqd, '"bar" param should not be in required set.') + self.assertNotIn("bar", imm, '"bar" param should not be in immutable set.') - self.assertIn('stuff', rqd, '"stuff" param should be in required set.') - self.assertNotIn('stuff', opt, '"stuff" param should not be in optional set.') - self.assertNotIn('stuff', imm, '"stuff" param should not be in immutable set.') - self.assertEqual(runner.runner_parameters, orig_runner_params, 'Runner params modified.') - self.assertEqual(action.parameters, orig_action_params, 'Action params modified.') + self.assertIn("stuff", rqd, '"stuff" param should be in required set.') + self.assertNotIn("stuff", opt, '"stuff" param should not be in optional set.') + self.assertNotIn("stuff", imm, '"stuff" param should not be in immutable set.') + self.assertEqual( + runner.runner_parameters, orig_runner_params, "Runner params modified." + ) + self.assertEqual( + action.parameters, orig_action_params, "Action params modified." + ) def test_opt_in_dict_auto_convert(self): - """Test ability for user to opt-in to dict convert functionality - """ + """Test ability for user to opt-in to dict convert functionality""" runner = RunnerType() runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_array': {'type': 'array'}, + "param_array": {"type": "array"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.parameters = [ - 'param_array=foo:bar,foo2:bar2', + "param_array=foo:bar,foo2:bar2", ] mockarg.auto_dict = False - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) - self.assertEqual(param['param_array'], ['foo:bar', 'foo2:bar2']) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) + self.assertEqual(param["param_array"], ["foo:bar", "foo2:bar2"]) mockarg.auto_dict = True - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) - self.assertEqual(param['param_array'], [{'foo': 'bar', 'foo2': 'bar2'}]) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) + self.assertEqual(param["param_array"], [{"foo": "bar", "foo2": "bar2"}]) # set auto_dict back to default mockarg.auto_dict = False @@ -104,60 +110,65 @@ def test_get_params_from_args(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_string': {'type': 'string'}, - 'param_integer': {'type': 'integer'}, - 'param_number': {'type': 'number'}, - 'param_object': {'type': 'object'}, - 'param_boolean': {'type': 'boolean'}, - 'param_array': {'type': 'array'}, - 'param_array_of_dicts': {'type': 'array', 'properties': { - 'foo': {'type': 'string'}, - 'bar': {'type': 'integer'}, - 'baz': {'type': 'number'}, - 'qux': {'type': 'object'}, - 'quux': {'type': 'boolean'}} + "param_string": {"type": "string"}, + "param_integer": {"type": "integer"}, + "param_number": {"type": "number"}, + "param_object": {"type": "object"}, + "param_boolean": {"type": "boolean"}, + "param_array": {"type": "array"}, + "param_array_of_dicts": { + "type": "array", + "properties": { + "foo": {"type": "string"}, + "bar": {"type": "integer"}, + "baz": {"type": "number"}, + "qux": {"type": "object"}, + "quux": {"type": "boolean"}, + }, }, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True mockarg.parameters = [ - 'param_string=hoge', - 'param_integer=123', - 'param_number=1.23', - 'param_object=hoge=1,fuga=2', - 'param_boolean=False', - 'param_array=foo,bar,baz', - 'param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True', - 'param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False' + "param_string=hoge", + "param_integer=123", + "param_number=1.23", + "param_object=hoge=1,fuga=2", + "param_boolean=False", + "param_array=foo,bar,baz", + "param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True", + "param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False", ] - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) self.assertIsInstance(param, dict) - self.assertEqual(param['param_string'], 'hoge') - self.assertEqual(param['param_integer'], 123) - self.assertEqual(param['param_number'], 1.23) - self.assertEqual(param['param_object'], {'hoge': '1', 'fuga': '2'}) - self.assertFalse(param['param_boolean']) - self.assertEqual(param['param_array'], ['foo', 'bar', 'baz']) + self.assertEqual(param["param_string"], "hoge") + self.assertEqual(param["param_integer"], 123) + self.assertEqual(param["param_number"], 1.23) + self.assertEqual(param["param_object"], {"hoge": "1", "fuga": "2"}) + self.assertFalse(param["param_boolean"]) + self.assertEqual(param["param_array"], ["foo", "bar", "baz"]) # checking the result of parsing for array of objects - self.assertIsInstance(param['param_array_of_dicts'], list) - self.assertEqual(len(param['param_array_of_dicts']), 2) - for param in param['param_array_of_dicts']: + self.assertIsInstance(param["param_array_of_dicts"], list) + self.assertEqual(len(param["param_array_of_dicts"]), 2) + for param in param["param_array_of_dicts"]: self.assertIsInstance(param, dict) - self.assertIsInstance(param['foo'], str) - self.assertIsInstance(param['bar'], int) - self.assertIsInstance(param['baz'], float) - self.assertIsInstance(param['qux'], dict) - self.assertIsInstance(param['quux'], bool) + self.assertIsInstance(param["foo"], str) + self.assertIsInstance(param["bar"], int) + self.assertIsInstance(param["baz"], float) + self.assertIsInstance(param["qux"], dict) + self.assertIsInstance(param["quux"], bool) # set auto_dict back to default mockarg.auto_dict = False @@ -167,36 +178,38 @@ def test_get_params_from_args_read_content_from_file(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_object': {'type': 'object'}, + "param_object": {"type": "object"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") # 1. File doesn't exist mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True - mockarg.parameters = [ - '@param_object=doesnt-exist.json' - ] + mockarg.parameters = ["@param_object=doesnt-exist.json"] - self.assertRaisesRegex(ValueError, "doesn't exist", - command._get_action_parameters_from_args, action=action, - runner=runner, args=mockarg) + self.assertRaisesRegex( + ValueError, + "doesn't exist", + command._get_action_parameters_from_args, + action=action, + runner=runner, + args=mockarg, + ) # 2. Valid file path (we simply read this file) mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True - mockarg.parameters = [ - '@param_string=%s' % (__file__) - ] + mockarg.parameters = ["@param_string=%s" % (__file__)] - params = command._get_action_parameters_from_args(action=action, - runner=runner, args=mockarg) + params = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) self.assertTrue(isinstance(params["param_string"], six.text_type)) self.assertTrue(params["param_string"].startswith("# Copyright")) @@ -212,37 +225,39 @@ def test_get_params_from_args_with_multiple_declarations(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_string': {'type': 'string'}, - 'param_array': {'type': 'array'}, - 'param_array_of_dicts': {'type': 'array'}, + "param_string": {"type": "string"}, + "param_array": {"type": "array"}, + "param_array_of_dicts": {"type": "array"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True mockarg.parameters = [ - 'param_string=hoge', # This value will be overwritten with the next declaration. - 'param_string=fuga', - 'param_array=foo', - 'param_array=bar', - 'param_array_of_dicts=foo:1,bar:2', - 'param_array_of_dicts=hoge:A,fuga:B' + "param_string=hoge", # This value will be overwritten with the next declaration. + "param_string=fuga", + "param_array=foo", + "param_array=bar", + "param_array_of_dicts=foo:1,bar:2", + "param_array_of_dicts=hoge:A,fuga:B", ] - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) # checks to accept multiple declaration only if the array type - self.assertEqual(param['param_string'], 'fuga') - self.assertEqual(param['param_array'], ['foo', 'bar']) - self.assertEqual(param['param_array_of_dicts'], [ - {'foo': '1', 'bar': '2'}, - {'hoge': 'A', 'fuga': 'B'} - ]) + self.assertEqual(param["param_string"], "fuga") + self.assertEqual(param["param_array"], ["foo", "bar"]) + self.assertEqual( + param["param_array_of_dicts"], + [{"foo": "1", "bar": "2"}, {"hoge": "A", "fuga": "B"}], + ) # set auto_dict back to default mockarg.auto_dict = False diff --git a/st2client/tests/unit/test_commands.py b/st2client/tests/unit/test_commands.py index 0748a4aeec..de84f7883f 100644 --- a/st2client/tests/unit/test_commands.py +++ b/st2client/tests/unit/test_commands.py @@ -32,97 +32,117 @@ from st2client.commands import resource from st2client.commands.resource import ResourceViewCommand -__all__ = [ - 'TestResourceCommand', - 'ResourceViewCommandTestCase' -] +__all__ = ["TestResourceCommand", "ResourceViewCommandTestCase"] LOG = logging.getLogger(__name__) class TestResourceCommand(unittest2.TestCase): - def __init__(self, *args, **kwargs): super(TestResourceCommand, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() self.subparsers = self.parser.add_subparsers() self.branch = resource.ResourceBranch( - base.FakeResource, 'Test Command', base.FakeApp(), self.subparsers) + base.FakeResource, "Test Command", base.FakeApp(), self.subparsers + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_command_list(self): - args = self.parser.parse_args(['fakeresource', 'list']) - self.assertEqual(args.func, self.branch.commands['list'].run_and_print) - instances = self.branch.commands['list'].run(args) + args = self.parser.parse_args(["fakeresource", "list"]) + self.assertEqual(args.func, self.branch.commands["list"].run_and_print) + instances = self.branch.commands["list"].run(args) actual = [instance.serialize() for instance in instances] expected = json.loads(json.dumps(base.RESOURCES)) self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_list_failed(self): - args = self.parser.parse_args(['fakeresource', 'list']) - self.assertRaises(Exception, self.branch.commands['list'].run, args) + args = self.parser.parse_args(["fakeresource", "list"]) + self.assertRaises(Exception, self.branch.commands["list"].run, args) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=None)) + models.ResourceManager, "get_by_name", mock.MagicMock(return_value=None) + ) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0]))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0])), + ) def test_command_get_by_id(self): - args = self.parser.parse_args(['fakeresource', 'get', '123']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - instance = self.branch.commands['get'].run(args) + args = self.parser.parse_args(["fakeresource", "get", "123"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + instance = self.branch.commands["get"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_get(self): - args = self.parser.parse_args(['fakeresource', 'get', 'abc']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - instance = self.branch.commands['get'].run(args) + args = self.parser.parse_args(["fakeresource", "get", "abc"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + instance = self.branch.commands["get"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_command_get_404(self): - args = self.parser.parse_args(['fakeresource', 'get', 'cba']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - self.assertRaises(resource.ResourceNotFoundError, - self.branch.commands['get'].run, - args) + args = self.parser.parse_args(["fakeresource", "get", "cba"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + self.assertRaises( + resource.ResourceNotFoundError, self.branch.commands["get"].run, args + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_get_failed(self): - args = self.parser.parse_args(['fakeresource', 'get', 'cba']) - self.assertRaises(Exception, self.branch.commands['get'].run, args) + args = self.parser.parse_args(["fakeresource", "get", "cba"]) + self.assertRaises(Exception, self.branch.commands["get"].run, args) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_create(self): - instance = base.FakeResource(name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args(['fakeresource', 'create', path]) - self.assertEqual(args.func, - self.branch.commands['create'].run_and_print) - instance = self.branch.commands['create'].run(args) + args = self.parser.parse_args(["fakeresource", "create", path]) + self.assertEqual(args.func, self.branch.commands["create"].run_and_print) + instance = self.branch.commands["create"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @@ -131,40 +151,49 @@ def test_command_create(self): os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_create_failed(self): - instance = base.FakeResource(name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args(['fakeresource', 'create', path]) - self.assertRaises(Exception, - self.branch.commands['create'].run, - args) + args = self.parser.parse_args(["fakeresource", "create", path]) + self.assertRaises(Exception, self.branch.commands["create"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_update(self): - instance = base.FakeResource(id='123', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="123", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertEqual(args.func, - self.branch.commands['update'].run_and_print) - instance = self.branch.commands['update'].run(args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertEqual(args.func, self.branch.commands["update"].run_and_print) + instance = self.branch.commands["update"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @@ -173,122 +202,142 @@ def test_command_update(self): os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_update_failed(self): - instance = base.FakeResource(id='123', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="123", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertRaises(Exception, - self.branch.commands['update'].run, - args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertRaises(Exception, self.branch.commands["update"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) def test_command_update_id_mismatch(self): - instance = base.FakeResource(id='789', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="789", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertRaises(Exception, - self.branch.commands['update'].run, - args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertRaises(Exception, self.branch.commands["update"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")), + ) def test_command_delete(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'abc']) - self.assertEqual(args.func, - self.branch.commands['delete'].run_and_print) - self.branch.commands['delete'].run(args) + args = self.parser.parse_args(["fakeresource", "delete", "abc"]) + self.assertEqual(args.func, self.branch.commands["delete"].run_and_print) + self.branch.commands["delete"].run(args) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_command_delete_404(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'cba']) - self.assertEqual(args.func, - self.branch.commands['delete'].run_and_print) - self.assertRaises(resource.ResourceNotFoundError, - self.branch.commands['delete'].run, - args) + args = self.parser.parse_args(["fakeresource", "delete", "cba"]) + self.assertEqual(args.func, self.branch.commands["delete"].run_and_print) + self.assertRaises( + resource.ResourceNotFoundError, self.branch.commands["delete"].run, args + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_delete_failed(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'cba']) - self.assertRaises(Exception, self.branch.commands['delete'].run, args) + args = self.parser.parse_args(["fakeresource", "delete", "cba"]) + self.assertRaises(Exception, self.branch.commands["delete"].run, args) class ResourceViewCommandTestCase(unittest2.TestCase): - def setUp(self): ResourceViewCommand.display_attributes = [] def test_get_include_attributes(self): - cls = namedtuple('Args', 'attr') + cls = namedtuple("Args", "attr") args = cls(attr=[]) result = ResourceViewCommand._get_include_attributes(args=args) self.assertEqual(result, []) - args = cls(attr=['result']) + args = cls(attr=["result"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result']) + self.assertEqual(result, ["result"]) - args = cls(attr=['result', 'trigger_instance']) + args = cls(attr=["result", "trigger_instance"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result', 'trigger_instance']) + self.assertEqual(result, ["result", "trigger_instance"]) - args = cls(attr=['result.stdout']) + args = cls(attr=["result.stdout"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout']) + self.assertEqual(result, ["result.stdout"]) - args = cls(attr=['result.stdout', 'result.stderr']) + args = cls(attr=["result.stdout", "result.stderr"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout', 'result.stderr']) + self.assertEqual(result, ["result.stdout", "result.stderr"]) - args = cls(attr=['result.stdout', 'trigger_instance.id']) + args = cls(attr=["result.stdout", "trigger_instance.id"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout', 'trigger_instance.id']) + self.assertEqual(result, ["result.stdout", "trigger_instance.id"]) - ResourceViewCommand.display_attributes = ['id', 'status'] + ResourceViewCommand.display_attributes = ["id", "status"] args = cls(attr=[]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(set(result), set(['id', 'status'])) + self.assertEqual(set(result), set(["id", "status"])) - args = cls(attr=['trigger_instance']) + args = cls(attr=["trigger_instance"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(set(result), set(['trigger_instance'])) + self.assertEqual(set(result), set(["trigger_instance"])) - args = cls(attr=['all']) + args = cls(attr=["all"]) result = ResourceViewCommand._get_include_attributes(args=args) self.assertEqual(result, None) @@ -303,20 +352,19 @@ class CommandsHelpStringTestCase(BaseCLITestCase): # TODO: Automatically iterate all the available commands COMMANDS = [ # action - ['action', 'list'], - ['action', 'get'], - ['action', 'create'], - ['action', 'update'], - ['action', 'delete'], - ['action', 'enable'], - ['action', 'disable'], - ['action', 'execute'], - + ["action", "list"], + ["action", "get"], + ["action", "create"], + ["action", "update"], + ["action", "delete"], + ["action", "enable"], + ["action", "disable"], + ["action", "execute"], # execution - ['execution', 'cancel'], - ['execution', 'pause'], - ['execution', 'resume'], - ['execution', 'tail'] + ["execution", "cancel"], + ["execution", "pause"], + ["execution", "resume"], + ["execution", "tail"], ] def test_help_command_line_arg_works_for_supported_commands(self): @@ -324,7 +372,7 @@ def test_help_command_line_arg_works_for_supported_commands(self): for command in self.COMMANDS: # First test longhang notation - argv = command + ['--help'] + argv = command + ["--help"] try: result = shell.run(argv) @@ -335,16 +383,16 @@ def test_help_command_line_arg_works_for_supported_commands(self): stdout = self.stdout.getvalue() - self.assertIn('usage:', stdout) - self.assertIn(' '.join(command), stdout) + self.assertIn("usage:", stdout) + self.assertIn(" ".join(command), stdout) # self.assertIn('positional arguments:', stdout) - self.assertIn('optional arguments:', stdout) + self.assertIn("optional arguments:", stdout) # Reset stdout and stderr after each iteration self._reset_output_streams() # Then shorthand notation - argv = command + ['-h'] + argv = command + ["-h"] try: result = shell.run(argv) @@ -355,14 +403,14 @@ def test_help_command_line_arg_works_for_supported_commands(self): stdout = self.stdout.getvalue() - self.assertIn('usage:', stdout) - self.assertIn(' '.join(command), stdout) + self.assertIn("usage:", stdout) + self.assertIn(" ".join(command), stdout) # self.assertIn('positional arguments:', stdout) - self.assertIn('optional arguments:', stdout) + self.assertIn("optional arguments:", stdout) # Verify that the actual help usage string was triggered and not the invalid # "too few arguments" which would indicate command doesn't actually correctly handle # --help flag - self.assertNotIn('too few arguments', stdout) + self.assertNotIn("too few arguments", stdout) self._reset_output_streams() diff --git a/st2client/tests/unit/test_config_parser.py b/st2client/tests/unit/test_config_parser.py index 35a125ebeb..9cea63ee5a 100644 --- a/st2client/tests/unit/test_config_parser.py +++ b/st2client/tests/unit/test_config_parser.py @@ -26,80 +26,77 @@ from st2client.config_parser import CONFIG_DEFAULT_VALUES BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini') -CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini') -CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, '../fixtures/test_unicode.ini') +CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini") +CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini") +CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, "../fixtures/test_unicode.ini") class CLIConfigParserTestCase(unittest2.TestCase): def test_constructor(self): - parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False) + parser = CLIConfigParser( + config_file_path="doesnotexist", validate_config_exists=False + ) self.assertTrue(parser) - self.assertRaises(ValueError, CLIConfigParser, config_file_path='doestnotexist', - validate_config_exists=True) + self.assertRaises( + ValueError, + CLIConfigParser, + config_file_path="doestnotexist", + validate_config_exists=True, + ) def test_parse(self): # File doesn't exist - parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False) + parser = CLIConfigParser( + config_file_path="doesnotexist", validate_config_exists=False + ) result = parser.parse() self.assertEqual(CONFIG_DEFAULT_VALUES, result) # File exists - all the options specified expected = { - 'general': { - 'base_url': 'http://127.0.0.1', - 'api_version': 'v1', - 'cacert': 'cacartpath', - 'silence_ssl_warnings': False, - 'silence_schema_output': True + "general": { + "base_url": "http://127.0.0.1", + "api_version": "v1", + "cacert": "cacartpath", + "silence_ssl_warnings": False, + "silence_schema_output": True, }, - 'cli': { - 'debug': True, - 'cache_token': False, - 'timezone': 'UTC' - }, - 'credentials': { - 'username': 'test1', - 'password': 'test1', - 'api_key': None - }, - 'api': { - 'url': 'http://127.0.0.1:9101/v1' - }, - 'auth': { - 'url': 'http://127.0.0.1:9100/' - }, - 'stream': { - 'url': 'http://127.0.0.1:9102/v1/stream' - } + "cli": {"debug": True, "cache_token": False, "timezone": "UTC"}, + "credentials": {"username": "test1", "password": "test1", "api_key": None}, + "api": {"url": "http://127.0.0.1:9101/v1"}, + "auth": {"url": "http://127.0.0.1:9100/"}, + "stream": {"url": "http://127.0.0.1:9102/v1/stream"}, } - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_FULL, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_FULL, validate_config_exists=False + ) result = parser.parse() self.assertEqual(expected, result) # File exists - missing options, test defaults - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_PARTIAL, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_PARTIAL, validate_config_exists=False + ) result = parser.parse() - self.assertTrue(result['cli']['cache_token'], True) + self.assertTrue(result["cli"]["cache_token"], True) def test_get_config_for_unicode_char(self): - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_UNICODE, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_UNICODE, validate_config_exists=False + ) config = parser.parse() if six.PY3: - self.assertEqual(config['credentials']['password'], '密码') + self.assertEqual(config["credentials"]["password"], "密码") else: - self.assertEqual(config['credentials']['password'], u'\u5bc6\u7801') + self.assertEqual(config["credentials"]["password"], "\u5bc6\u7801") class CLIConfigPermissionsTestCase(unittest2.TestCase): def setUp(self): - self.TEMP_FILE_PATH = os.path.join('st2config', '.st2', 'config') + self.TEMP_FILE_PATH = os.path.join("st2config", ".st2", "config") self.TEMP_CONFIG_DIR = os.path.dirname(self.TEMP_FILE_PATH) if os.path.exists(self.TEMP_FILE_PATH): @@ -135,7 +132,9 @@ def test_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o660) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -159,7 +158,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o640) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -175,7 +176,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o600) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -200,7 +203,9 @@ def test_warn_on_bad_config_permissions(self): self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -209,17 +214,20 @@ def test_warn_on_bad_config_permissions(self): self.assertEqual( "The SGID bit is not set on the StackStorm configuration directory.", - parser.LOG.info.call_args_list[0][0][0]) + parser.LOG.info.call_args_list[0][0][0], + ) self.assertEqual(parser.LOG.warn.call_count, 2) self.assertEqual( "The StackStorm configuration directory permissions are insecure " "(too permissive): others have access.", - parser.LOG.warn.call_args_list[0][0][0]) + parser.LOG.warn.call_args_list[0][0][0], + ) self.assertEqual( "The StackStorm configuration file permissions are insecure: others have access.", - parser.LOG.warn.call_args_list[1][0][0]) + parser.LOG.warn.call_args_list[1][0][0], + ) # Make sure we left the file alone self.assertTrue(os.path.exists(self.TEMP_FILE_PATH)) @@ -239,9 +247,11 @@ def test_disable_permissions_warnings(self): self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, - validate_config_exists=True, - validate_config_permissions=False) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, + validate_config_exists=True, + validate_config_permissions=False, + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 diff --git a/st2client/tests/unit/test_execution_tail_command.py b/st2client/tests/unit/test_execution_tail_command.py index 15500767f2..08957ddbf1 100644 --- a/st2client/tests/unit/test_execution_tail_command.py +++ b/st2client/tests/unit/test_execution_tail_command.py @@ -27,247 +27,180 @@ from st2client.commands.action import LIVEACTION_STATUS_TIMED_OUT from st2client.shell import Shell -__all__ = [ - 'ActionExecutionTailCommandTestCase' -] +__all__ = ["ActionExecutionTailCommandTestCase"] # Mock objects -MOCK_LIVEACTION_1_RUNNING = { - 'id': 'idfoo1', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_1_RUNNING = {"id": "idfoo1", "status": LIVEACTION_STATUS_RUNNING} -MOCK_LIVEACTION_1_SUCCEEDED = { - 'id': 'idfoo1', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_1_SUCCEEDED = {"id": "idfoo1", "status": LIVEACTION_STATUS_SUCCEEDED} -MOCK_LIVEACTION_2_FAILED = { - 'id': 'idfoo2', - 'status': LIVEACTION_STATUS_FAILED -} +MOCK_LIVEACTION_2_FAILED = {"id": "idfoo2", "status": LIVEACTION_STATUS_FAILED} # Mock liveaction objects for ActionChain workflow -MOCK_LIVEACTION_3_RUNNING = { - 'id': 'idfoo3', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_3_RUNNING = {"id": "idfoo3", "status": LIVEACTION_STATUS_RUNNING} MOCK_LIVEACTION_3_CHILD_1_RUNNING = { - 'id': 'idchild1', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_1' - } - }, - 'status': LIVEACTION_STATUS_RUNNING + "id": "idchild1", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}}, + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED = { - 'id': 'idchild1', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_1' - } - }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "id": "idchild1", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}}, + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1 = { - 'execution_id': 'idchild1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line ac 4\n' + "execution_id": "idchild1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line ac 4\n", } MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2 = { - 'execution_id': 'idchild1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line ac 5\n' + "execution_id": "idchild1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line ac 5\n", } MOCK_LIVEACTION_3_CHILD_2_RUNNING = { - 'id': 'idchild2', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_2' - } - }, - 'status': LIVEACTION_STATUS_RUNNING + "id": "idchild2", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}}, + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_3_CHILD_2_FAILED = { - 'id': 'idchild2', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_2' - } - }, - 'status': LIVEACTION_STATUS_FAILED + "id": "idchild2", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}}, + "status": LIVEACTION_STATUS_FAILED, } MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1 = { - 'execution_id': 'idchild2', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line ac 100\n' + "execution_id": "idchild2", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line ac 100\n", } -MOCK_LIVEACTION_3_SUCCEDED = { - 'id': 'idfoo3', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_3_SUCCEDED = {"id": "idfoo3", "status": LIVEACTION_STATUS_SUCCEEDED} # Mock objects for Orquesta workflow execution -MOCK_LIVEACTION_4_RUNNING = { - 'id': 'idfoo4', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_4_RUNNING = {"id": "idfoo4", "status": LIVEACTION_STATUS_RUNNING} MOCK_LIVEACTION_4_CHILD_1_RUNNING = { - 'id': 'idorquestachild1', - 'context': { - 'orquesta': { - 'task_name': 'task_1' - }, - 'parent': { - 'execution_id': 'idfoo4' - } + "id": "idorquestachild1", + "context": { + "orquesta": {"task_name": "task_1"}, + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_1_1_RUNNING = { - 'id': 'idorquestachild1_1', - 'context': { - 'orquesta': { - 'task_name': 'task_1' - }, - 'parent': { - 'execution_id': 'idorquestachild1' - } + "id": "idorquestachild1_1", + "context": { + "orquesta": {"task_name": "task_1"}, + "parent": {"execution_id": "idorquestachild1"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED = { - 'id': 'idorquestachild1', - 'context': { - 'orquesta': { - 'task_name': 'task_1', + "id": "idorquestachild1", + "context": { + "orquesta": { + "task_name": "task_1", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED = { - 'id': 'idorquestachild1_1', - 'context': { - 'orquesta': { - 'task_name': 'task_1', + "id": "idorquestachild1_1", + "context": { + "orquesta": { + "task_name": "task_1", }, - 'parent': { - 'execution_id': 'idorquestachild1' - } + "parent": {"execution_id": "idorquestachild1"}, }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1 = { - 'execution_id': 'idorquestachild1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 4\n' + "execution_id": "idorquestachild1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 4\n", } MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2 = { - 'execution_id': 'idorquestachild1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line orquesta 5\n' + "execution_id": "idorquestachild1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line orquesta 5\n", } MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1 = { - 'execution_id': 'idorquestachild1_1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 4\n' + "execution_id": "idorquestachild1_1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 4\n", } MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2 = { - 'execution_id': 'idorquestachild1_1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line orquesta 5\n' + "execution_id": "idorquestachild1_1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line orquesta 5\n", } MOCK_LIVEACTION_4_CHILD_2_RUNNING = { - 'id': 'idorquestachild2', - 'context': { - 'orquesta': { - 'task_name': 'task_2', + "id": "idorquestachild2", + "context": { + "orquesta": { + "task_name": "task_2", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT = { - 'id': 'idorquestachild2', - 'context': { - 'orquesta': { - 'task_name': 'task_2', + "id": "idorquestachild2", + "context": { + "orquesta": { + "task_name": "task_2", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_TIMED_OUT + "status": LIVEACTION_STATUS_TIMED_OUT, } MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1 = { - 'execution_id': 'idorquestachild2', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 100\n' + "execution_id": "idorquestachild2", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 100\n", } -MOCK_LIVEACTION_4_SUCCEDED = { - 'id': 'idfoo4', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_4_SUCCEDED = {"id": "idfoo4", "status": LIVEACTION_STATUS_SUCCEEDED} # Mock objects for simple actions MOCK_OUTPUT_1 = { - 'execution_id': 'idfoo3', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line 1\n' + "execution_id": "idfoo3", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line 1\n", } MOCK_OUTPUT_2 = { - 'execution_id': 'idfoo3', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line 2\n' + "execution_id": "idfoo3", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line 2\n", } @@ -279,42 +212,55 @@ def __init__(self, *args, **kwargs): self.shell = Shell() @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_SUCCEEDED), - 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_1_SUCCEEDED), 200, "OK" + ) + ), + ) def test_tail_simple_execution_already_finished_succeeded(self): - argv = ['execution', 'tail', 'idfoo1'] + argv = ["execution", "tail", "idfoo1"] self.assertEqual(self.shell.run(argv), 0) stdout = self.stdout.getvalue() stderr = self.stderr.getvalue() - self.assertIn('Execution idfoo1 has completed (status=succeeded)', stdout) - self.assertEqual(stderr, '') + self.assertIn("Execution idfoo1 has completed (status=succeeded)", stdout) + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_2_FAILED), - 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_2_FAILED), 200, "OK" + ) + ), + ) def test_tail_simple_execution_already_finished_failed(self): - argv = ['execution', 'tail', 'idfoo2'] + argv = ["execution", "tail", "idfoo2"] self.assertEqual(self.shell.run(argv), 0) stdout = self.stdout.getvalue() stderr = self.stderr.getvalue() - self.assertIn('Execution idfoo2 has completed (status=failed)', stdout) - self.assertEqual(stderr, '') + self.assertIn("Execution idfoo2 has completed (status=failed)", stdout) + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_1_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo1'] + argv = ["execution", "tail", "idfoo1"] - MOCK_EVENTS = [ - MOCK_LIVEACTION_1_SUCCEEDED - ] + MOCK_EVENTS = [MOCK_LIVEACTION_1_SUCCEEDED] mock_cls = mock.Mock() mock_cls.listen = mock.Mock() @@ -333,21 +279,26 @@ def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manage Execution idfoo1 has completed (status=succeeded). """ self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_simple_execution_running_with_data(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo3'] + argv = ["execution", "tail", "idfoo3"] MOCK_EVENTS = [ MOCK_LIVEACTION_3_RUNNING, MOCK_OUTPUT_1, MOCK_OUTPUT_2, - MOCK_LIVEACTION_3_SUCCEDED + MOCK_LIVEACTION_3_SUCCEDED, ] mock_cls = mock.Mock() @@ -372,41 +323,39 @@ def test_tail_simple_execution_running_with_data(self, mock_stream_manager): Execution idfoo3 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_action_chain_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo3'] + argv = ["execution", "tail", "idfoo3"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_3_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_3_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_3_CHILD_2_FAILED, - # Parent workflow task finished - MOCK_LIVEACTION_3_SUCCEDED + MOCK_LIVEACTION_3_SUCCEDED, ] mock_cls = mock.Mock() @@ -440,41 +389,39 @@ def test_tail_action_chain_workflow_execution(self, mock_stream_manager): Execution idfoo3 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_orquesta_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_4_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_4_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, - # Parent workflow task finished - MOCK_LIVEACTION_4_SUCCEDED + MOCK_LIVEACTION_4_SUCCEDED, ] mock_cls = mock.Mock() @@ -508,64 +455,55 @@ def test_tail_orquesta_workflow_execution(self, mock_stream_manager): Execution idfoo4 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_4_RUNNING, - # Child task 1 started running (sub workflow) MOCK_LIVEACTION_4_CHILD_1_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_4_CHILD_1_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1, MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2, - # Another execution has started, this output should not be included MOCK_LIVEACTION_3_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED, - # Parent workflow task finished MOCK_LIVEACTION_3_SUCCEDED, # End another execution - # Child task 1 has finished MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED, - # Child task 1 finished (sub workflow) MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, - # Parent workflow task finished - MOCK_LIVEACTION_4_SUCCEDED + MOCK_LIVEACTION_4_SUCCEDED, ] mock_cls = mock.Mock() @@ -604,32 +542,33 @@ def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manage """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_child_execution_directly(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Other executions should not interfere # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Child task 1 finished (sub workflow) MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 finished - MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT + MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, ] mock_cls = mock.Mock() @@ -654,4 +593,4 @@ def test_tail_child_execution_directly(self, mock_stream_manager): """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") diff --git a/st2client/tests/unit/test_formatters.py b/st2client/tests/unit/test_formatters.py index b3733faba5..fe0370aea1 100644 --- a/st2client/tests/unit/test_formatters.py +++ b/st2client/tests/unit/test_formatters.py @@ -39,38 +39,43 @@ LOG = logging.getLogger(__name__) FIXTURES_MANIFEST = { - 'executions': ['execution.json', - 'execution_result_has_carriage_return.json', - 'execution_unicode.json', - 'execution_double_backslash.json', - 'execution_with_stack_trace.json', - 'execution_with_schema.json'], - 'results': ['execution_get_default.txt', - 'execution_get_detail.txt', - 'execution_get_result_by_key.txt', - 'execution_result_has_carriage_return.txt', - 'execution_result_has_carriage_return_py3.txt', - 'execution_get_attributes.txt', - 'execution_list_attr_start_timestamp.txt', - 'execution_list_empty_response_start_timestamp_attr.txt', - 'execution_unescape_newline.txt', - 'execution_unicode.txt', - 'execution_double_backslash.txt', - 'execution_unicode_py3.txt', - 'execution_get_has_schema.txt'] + "executions": [ + "execution.json", + "execution_result_has_carriage_return.json", + "execution_unicode.json", + "execution_double_backslash.json", + "execution_with_stack_trace.json", + "execution_with_schema.json", + ], + "results": [ + "execution_get_default.txt", + "execution_get_detail.txt", + "execution_get_result_by_key.txt", + "execution_result_has_carriage_return.txt", + "execution_result_has_carriage_return_py3.txt", + "execution_get_attributes.txt", + "execution_list_attr_start_timestamp.txt", + "execution_list_empty_response_start_timestamp_attr.txt", + "execution_unescape_newline.txt", + "execution_unicode.txt", + "execution_double_backslash.txt", + "execution_unicode_py3.txt", + "execution_get_has_schema.txt", + ], } FIXTURES = loader.load_fixtures(fixtures_dict=FIXTURES_MANIFEST) -EXECUTION = FIXTURES['executions']['execution.json'] -UNICODE = FIXTURES['executions']['execution_unicode.json'] -DOUBLE_BACKSLASH = FIXTURES['executions']['execution_double_backslash.json'] -OUTPUT_SCHEMA = FIXTURES['executions']['execution_with_schema.json'] -NEWLINE = FIXTURES['executions']['execution_with_stack_trace.json'] -HAS_CARRIAGE_RETURN = FIXTURES['executions']['execution_result_has_carriage_return.json'] +EXECUTION = FIXTURES["executions"]["execution.json"] +UNICODE = FIXTURES["executions"]["execution_unicode.json"] +DOUBLE_BACKSLASH = FIXTURES["executions"]["execution_double_backslash.json"] +OUTPUT_SCHEMA = FIXTURES["executions"]["execution_with_schema.json"] +NEWLINE = FIXTURES["executions"]["execution_with_stack_trace.json"] +HAS_CARRIAGE_RETURN = FIXTURES["executions"][ + "execution_result_has_carriage_return.json" +] class TestExecutionResultFormatter(unittest2.TestCase): - def __init__(self, *args, **kwargs): super(TestExecutionResultFormatter, self).__init__(*args, **kwargs) self.shell = shell.Shell() @@ -88,212 +93,278 @@ def tearDown(self): os.unlink(self.path) def _redirect_console(self, path): - sys.stdout = open(path, 'w') - sys.stderr = open(path, 'w') + sys.stdout = open(path, "w") + sys.stderr = open(path, "w") def _undo_console_redirect(self): sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ def test_console_redirect(self): - message = 'Hello, World!' + message = "Hello, World!" print(message) self._undo_console_redirect() - with open(self.path, 'r') as fd: - content = fd.read().replace('\n', '') + with open(self.path, "r") as fd: + content = fd.read().replace("\n", "") self.assertEqual(content, message) def test_execution_get_default(self): - argv = ['execution', 'get', EXECUTION['id']] + argv = ["execution", "get", EXECUTION["id"]] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_default.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_default.txt"]) def test_execution_get_attributes(self): - argv = ['execution', 'get', EXECUTION['id'], '--attr', 'status', 'end_timestamp'] + argv = [ + "execution", + "get", + EXECUTION["id"], + "--attr", + "status", + "end_timestamp", + ] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_attributes.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_attributes.txt"]) def test_execution_get_default_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-j'] + argv = ["execution", "get", EXECUTION["id"], "-j"] content = self._get_execution(argv) - self.assertEqual(json.loads(content), - jsutil.get_kvps(EXECUTION, ['id', 'action.ref', 'context.user', - 'start_timestamp', 'end_timestamp', 'status', - 'parameters', 'result'])) + self.assertEqual( + json.loads(content), + jsutil.get_kvps( + EXECUTION, + [ + "id", + "action.ref", + "context.user", + "start_timestamp", + "end_timestamp", + "status", + "parameters", + "result", + ], + ), + ) def test_execution_get_detail(self): - argv = ['execution', 'get', EXECUTION['id'], '-d'] + argv = ["execution", "get", EXECUTION["id"], "-d"] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_detail.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_detail.txt"]) def test_execution_with_schema(self): - argv = ['execution', 'get', OUTPUT_SCHEMA['id']] + argv = ["execution", "get", OUTPUT_SCHEMA["id"]] content = self._get_schema_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_has_schema.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_has_schema.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(NEWLINE), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(NEWLINE), 200, "OK", {}) + ), + ) def test_execution_unescape_newline(self): - """Ensure client renders newline characters - """ + """Ensure client renders newline characters""" - argv = ['execution', 'get', NEWLINE['id']] + argv = ["execution", "get", NEWLINE["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() - self.assertEqual(content, FIXTURES['results']['execution_unescape_newline.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_unescape_newline.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(UNICODE), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(UNICODE), 200, "OK", {}) + ), + ) def test_execution_unicode(self): - """Ensure client renders unicode escape sequences - """ + """Ensure client renders unicode escape sequences""" - argv = ['execution', 'get', UNICODE['id']] + argv = ["execution", "get", UNICODE["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() if six.PY2: - self.assertEqual(content, FIXTURES['results']['execution_unicode.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_unicode.txt"]) else: - content = content.replace(r'\xE2\x80\xA1', r'\u2021') - self.assertEqual(content, FIXTURES['results']['execution_unicode_py3.txt']) + content = content.replace(r"\xE2\x80\xA1", r"\u2021") + self.assertEqual(content, FIXTURES["results"]["execution_unicode_py3.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, "OK", {}) + ), + ) def test_execution_double_backslash_not_unicode_escape_sequence(self): - argv = ['execution', 'get', DOUBLE_BACKSLASH['id']] + argv = ["execution", "get", DOUBLE_BACKSLASH["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() - self.assertEqual(content, FIXTURES['results']['execution_double_backslash.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_double_backslash.txt"]) def test_execution_get_detail_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-d', '-j'] + argv = ["execution", "get", EXECUTION["id"], "-d", "-j"] content = self._get_execution(argv) content_dict = json.loads(content) # Sufficient to check if output contains all expected keys. The entire result will not # match as content will contain characters which improve rendering. for k in six.iterkeys(EXECUTION): - if k in ['liveaction', 'callback']: + if k in ["liveaction", "callback"]: continue if k in content: continue - self.assertTrue(False, 'Missing key %s. %s != %s' % (k, EXECUTION, content_dict)) + self.assertTrue( + False, "Missing key %s. %s != %s" % (k, EXECUTION, content_dict) + ) def test_execution_get_result_by_key(self): - argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout'] + argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout"] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_result_by_key.txt']) + self.assertEqual( + content, FIXTURES["results"]["execution_get_result_by_key.txt"] + ) def test_execution_get_result_by_key_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout', '-j'] + argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout", "-j"] content = self._get_execution(argv) - self.assertDictEqual(json.loads(content), - jsutil.get_kvps(EXECUTION, ['result.localhost.stdout'])) + self.assertDictEqual( + json.loads(content), jsutil.get_kvps(EXECUTION, ["result.localhost.stdout"]) + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(HAS_CARRIAGE_RETURN), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(HAS_CARRIAGE_RETURN), 200, "OK", {} + ) + ), + ) def test_execution_get_detail_with_carriage_return(self): - argv = ['execution', 'get', HAS_CARRIAGE_RETURN['id'], '-d'] + argv = ["execution", "get", HAS_CARRIAGE_RETURN["id"], "-d"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() if six.PY2: self.assertEqual( - content, FIXTURES['results']['execution_result_has_carriage_return.txt']) + content, FIXTURES["results"]["execution_result_has_carriage_return.txt"] + ) else: self.assertEqual( content, - FIXTURES['results']['execution_result_has_carriage_return_py3.txt']) + FIXTURES["results"]["execution_result_has_carriage_return_py3.txt"], + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK", {}) + ), + ) def test_execution_list_attribute_provided(self): # Client shouldn't throw if "-a" flag is provided when listing executions - argv = ['execution', 'list', '-a', 'start_timestamp'] + argv = ["execution", "list", "-a", "start_timestamp"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() self.assertEqual( - content, FIXTURES['results']['execution_list_attr_start_timestamp.txt']) + content, FIXTURES["results"]["execution_list_attr_start_timestamp.txt"] + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK", {})), + ) def test_execution_list_attribute_provided_empty_response(self): # Client shouldn't throw if "-a" flag is provided, but there are no executions - argv = ['execution', 'list', '-a', 'start_timestamp'] + argv = ["execution", "list", "-a", "start_timestamp"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() self.assertEqual( - content, FIXTURES['results']['execution_list_empty_response_start_timestamp_attr.txt']) + content, + FIXTURES["results"][ + "execution_list_empty_response_start_timestamp_attr.txt" + ], + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK", {}) + ), + ) def _get_execution(self, argv): self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() return content @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, "OK", {}) + ), + ) def _get_schema_execution(self, argv): self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() return content def test_SinlgeRowTable_notebox_one(self): - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only one action execution is displayed. Use -n/--last flag for " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only one action execution is displayed. Use -n/--last flag for " "more results." + ) print(self.table.note_box("action executions", 1)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) def test_SinlgeRowTable_notebox_zero(self): - with mock.patch('sys.stderr', new=BytesIO()) as fackety_fake: - contents = (fackety_fake.getvalue()) - self.assertEqual(contents, b'') + with mock.patch("sys.stderr", new=BytesIO()) as fackety_fake: + contents = fackety_fake.getvalue() + self.assertEqual(contents, b"") def test_SinlgeRowTable_notebox_default(self): - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only first 50 action executions are displayed. Use -n/--last flag " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only first 50 action executions are displayed. Use -n/--last flag " "for more results." + ) print(self.table.note_box("action executions", 50)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only first 15 action executions are displayed. Use -n/--last flag " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only first 15 action executions are displayed. Use -n/--last flag " "for more results." + ) print(self.table.note_box("action executions", 15)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) diff --git a/st2client/tests/unit/test_inquiry.py b/st2client/tests/unit/test_inquiry.py index 138f1da899..4132fda0d1 100644 --- a/st2client/tests/unit/test_inquiry.py +++ b/st2client/tests/unit/test_inquiry.py @@ -31,12 +31,12 @@ def _randomize_inquiry_id(inquiry): newinquiry = copy.deepcopy(inquiry) - newinquiry['id'] = str(uuid.uuid4()) + newinquiry["id"] = str(uuid.uuid4()) # ID can't have '1440' in it, otherwise our `count()` fails # when inspecting the inquiry list output for test: # test_list_inquiries_limit() - while '1440' in newinquiry['id']: - newinquiry['id'] = str(uuid.uuid4()) + while "1440" in newinquiry["id"]: + newinquiry["id"] = str(uuid.uuid4()) return newinquiry @@ -45,8 +45,7 @@ def _generate_inquiries(count): class TestInquiryBase(base.BaseCLITestCase): - """Base class for "inquiry" CLI tests - """ + """Base class for "inquiry" CLI tests""" capture_output = True @@ -54,8 +53,8 @@ def __init__(self, *args, **kwargs): super(TestInquiryBase, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): @@ -72,14 +71,12 @@ def tearDown(self): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, } -RESPONSE_DEFAULT = { - "continue": True -} +RESPONSE_DEFAULT = {"continue": True} SCHEMA_MULTIPLE = { "title": "response_data", @@ -88,30 +85,24 @@ def tearDown(self): "name": { "type": "string", "description": "What is your name?", - "required": True + "required": True, }, "pin": { "type": "integer", "description": "What is your PIN?", - "required": True + "required": True, }, "paradox": { "type": "boolean", "description": "This statement is False.", - "required": True - } + "required": True, + }, }, } -RESPONSE_MULTIPLE = { - "name": "matt", - "pin": 1234, - "paradox": True -} +RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True} -RESPONSE_BAD = { - "foo": "bar" -} +RESPONSE_BAD = {"foo": "bar"} INQUIRY_1 = { "id": "abcdef", @@ -119,7 +110,7 @@ def tearDown(self): "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } INQUIRY_MULTIPLE = { @@ -128,145 +119,200 @@ def tearDown(self): "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } class TestInquirySubcommands(TestInquiryBase): - @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, 'OK'))) + requests, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK") + ), + ) def test_get_inquiry(self): - """Test retrieval of a single inquiry - """ - inquiry_id = 'abcdef' - args = ['inquiry', 'get', inquiry_id] + """Test retrieval of a single inquiry""" + inquiry_id = "abcdef" + args = ["inquiry", "get", inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 404, 'NOT FOUND'))) + requests, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps({}), 404, "NOT FOUND") + ), + ) def test_get_inquiry_not_found(self): - """Test retrieval of a inquiry that doesn't exist - """ - inquiry_id = 'asdbv' - args = ['inquiry', 'get', inquiry_id] + """Test retrieval of a inquiry that doesn't exist""" + inquiry_id = "asdbv" + args = ["inquiry", "get", inquiry_id] retcode = self.shell.run(args) - self.assertEqual('Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue()) + self.assertEqual( + 'Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue() + ) self.assertEqual(retcode, 2) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps([INQUIRY_1]), 200, 'OK', {'X-Total-Count': '1'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps([INQUIRY_1]), 200, "OK", {"X-Total-Count": "1"} + ) + ) + ), + ) def test_list_inquiries(self): - """Test retrieval of a list of Inquiries - """ - args = ['inquiry', 'list'] + """Test retrieval of a list of Inquiries""" + args = ["inquiry", "list"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) - self.assertEqual(self.stdout.getvalue().count('1440'), 1) + self.assertEqual(self.stdout.getvalue().count("1440"), 1) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(_generate_inquiries(50)), 200, 'OK', {'X-Total-Count': '55'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps(_generate_inquiries(50)), + 200, + "OK", + {"X-Total-Count": "55"}, + ) + ) + ), + ) def test_list_inquiries_limit(self): - """Test retrieval of a list of Inquiries while using the "limit" option - """ - args = ['inquiry', 'list', '-n', '50'] + """Test retrieval of a list of Inquiries while using the "limit" option""" + args = ["inquiry", "list", "-n", "50"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) - self.assertEqual(self.stdout.getvalue().count('1440'), 50) - self.assertIn('Note: Only first 50 inquiries are displayed.', self.stderr.getvalue()) + self.assertEqual(self.stdout.getvalue().count("1440"), 50) + self.assertIn( + "Note: Only first 50 inquiries are displayed.", self.stderr.getvalue() + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps([]), 200, 'OK', {'X-Total-Count': '0'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps([]), 200, "OK", {"X-Total-Count": "0"}) + ) + ), + ) def test_list_empty_inquiries(self): - """Test empty list of Inquiries - """ - args = ['inquiry', 'list'] + """Test empty list of Inquiries""" + args = ["inquiry", "list"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK' - )))) - @mock.patch('st2client.commands.inquiry.InteractiveForm') + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), + 200, + "OK", + ) + ) + ), + ) + @mock.patch("st2client.commands.inquiry.InteractiveForm") def test_respond(self, mock_form): - """Test interactive response - """ + """Test interactive response""" form_instance = mock_form.return_value form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT - args = ['inquiry', 'respond', 'abcdef'] + args = ["inquiry", "respond", "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK' - )))) + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), + 200, + "OK", + ) + ) + ), + ) def test_respond_response_flag(self): - """Test response without interactive mode - """ - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, 'abcdef'] + """Test response without interactive mode""" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({}), 400, '400 Client Error: Bad Request' - )))) + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps({}), 400, "400 Client Error: Bad Request") + ) + ), + ) def test_respond_invalid(self): - """Test invalid response - """ - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_BAD, 'abcdef'] + """Test invalid response""" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_BAD, "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: 400 Client Error: Bad Request', self.stdout.getvalue().strip()) + self.assertEqual( + "ERROR: 400 Client Error: Bad Request", self.stdout.getvalue().strip() + ) def test_respond_nonexistent_inquiry(self): - """Test responding to an inquiry that doesn't exist - """ - inquiry_id = '134234' - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, inquiry_id] + """Test responding to an inquiry that doesn't exist""" + inquiry_id = "134234" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, - self.stdout.getvalue().strip()) + self.assertEqual( + 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, + self.stdout.getvalue().strip(), + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({}), 404, '404 Client Error: Not Found' - )))) - @mock.patch('st2client.commands.inquiry.InteractiveForm') + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps({}), 404, "404 Client Error: Not Found") + ) + ), + ) + @mock.patch("st2client.commands.inquiry.InteractiveForm") def test_respond_nonexistent_inquiry_interactive(self, mock_form): """Test interactively responding to an inquiry that doesn't exist @@ -274,11 +320,13 @@ def test_respond_nonexistent_inquiry_interactive(self, mock_form): responding with PUT, in order to retrieve the desired schema for this inquiry. So, we want to test that interaction separately. """ - inquiry_id = '253432' + inquiry_id = "253432" form_instance = mock_form.return_value form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT - args = ['inquiry', 'respond', inquiry_id] + args = ["inquiry", "respond", inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, - self.stdout.getvalue().strip()) + self.assertEqual( + 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, + self.stdout.getvalue().strip(), + ) diff --git a/st2client/tests/unit/test_interactive.py b/st2client/tests/unit/test_interactive.py index 24f0080232..dce4c6748d 100644 --- a/st2client/tests/unit/test_interactive.py +++ b/st2client/tests/unit/test_interactive.py @@ -31,37 +31,32 @@ class TestInteractive(unittest2.TestCase): - def assertPromptMessage(self, prompt_mock, message, msg=None): self.assertEqual(prompt_mock.call_args[0], (message,), msg) def assertPromptDescription(self, prompt_mock, message, msg=None): - toolbar_factory = prompt_mock.call_args[1]['get_bottom_toolbar_tokens'] + toolbar_factory = prompt_mock.call_args[1]["get_bottom_toolbar_tokens"] self.assertEqual(toolbar_factory(None)[0][1], message, msg) def assertPromptValidate(self, prompt_mock, value): - validator = prompt_mock.call_args[1]['validator'] + validator = prompt_mock.call_args[1]["validator"] validator.validate(Document(text=six.text_type(value))) def assertPromptPassword(self, prompt_mock, value, msg=None): - self.assertEqual(prompt_mock.call_args[1]['is_password'], value, msg) + self.assertEqual(prompt_mock.call_args[1]["is_password"], value, msg) def test_interactive_form(self): reader = mock.MagicMock() Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=True) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - Reader.condition.assert_called_once_with(schema['string']) + Reader.condition.assert_called_once_with(schema["string"]) reader.read.assert_called_once_with() def test_interactive_form_no_match(self): @@ -69,35 +64,27 @@ def test_interactive_form_no_match(self): Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=False) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - Reader.condition.assert_called_once_with(schema['string']) + Reader.condition.assert_called_once_with(schema["string"]) reader.read.assert_not_called() - @mock.patch('sys.stdout', new_callable=StringIO) + @mock.patch("sys.stdout", new_callable=StringIO) def test_interactive_form_interrupted(self, stdout_mock): reader = mock.MagicMock() Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=True) reader.read = mock.MagicMock(side_effect=KeyboardInterrupt) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - self.assertEqual(stdout_mock.getvalue(), 'Dialog interrupted.\n') + self.assertEqual(stdout_mock.getvalue(), "Dialog interrupted.\n") def test_interactive_form_interrupted_reraised(self): reader = mock.MagicMock() @@ -105,285 +92,278 @@ def test_interactive_form_interrupted_reraised(self): Reader.condition = mock.MagicMock(return_value=True) reader.read = mock.MagicMock(side_effect=KeyboardInterrupt) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): - self.assertRaises(interactive.DialogInterrupted, - interactive.InteractiveForm(schema, reraise=True).initiate_dialog) + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): + self.assertRaises( + interactive.DialogInterrupted, + interactive.InteractiveForm(schema, reraise=True).initiate_dialog, + ) - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_stringreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 'hey' - } - Reader = interactive.StringReader('some', spec) + spec = {"description": "some description", "default": "hey"} + Reader = interactive.StringReader("some", spec) - prompt_mock.return_value = 'stuff' + prompt_mock.return_value = "stuff" result = Reader.read() - self.assertEqual(result, 'stuff') - self.assertPromptMessage(prompt_mock, 'some [hey]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'stuff') + self.assertEqual(result, "stuff") + self.assertPromptMessage(prompt_mock, "some [hey]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "stuff") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'hey') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "hey") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_booleanreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': False - } - Reader = interactive.BooleanReader('some', spec) + spec = {"description": "some description", "default": False} + Reader = interactive.BooleanReader("some", spec) - prompt_mock.return_value = 'y' + prompt_mock.return_value = "y" result = Reader.read() self.assertEqual(result, True) - self.assertPromptMessage(prompt_mock, 'some (boolean) [n]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'y') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (boolean) [n]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "y") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, False) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_numberreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 3.2 - } - Reader = interactive.NumberReader('some', spec) + spec = {"description": "some description", "default": 3.2} + Reader = interactive.NumberReader("some", spec) - prompt_mock.return_value = '5.3' + prompt_mock.return_value = "5.3" result = Reader.read() self.assertEqual(result, 5.3) - self.assertPromptMessage(prompt_mock, 'some (float) [3.2]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '5.3') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (float) [3.2]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "5.3") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, 3.2) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_integerreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 3 - } - Reader = interactive.IntegerReader('some', spec) + spec = {"description": "some description", "default": 3} + Reader = interactive.IntegerReader("some", spec) - prompt_mock.return_value = '5' + prompt_mock.return_value = "5" result = Reader.read() self.assertEqual(result, 5) - self.assertPromptMessage(prompt_mock, 'some (integer) [3]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '5') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, '5.3') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (integer) [3]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "5") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "5.3", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, 3) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_secretstringreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 'hey' - } - Reader = interactive.SecretStringReader('some', spec) + spec = {"description": "some description", "default": "hey"} + Reader = interactive.SecretStringReader("some", spec) - prompt_mock.return_value = 'stuff' + prompt_mock.return_value = "stuff" result = Reader.read() - self.assertEqual(result, 'stuff') - self.assertPromptMessage(prompt_mock, 'some (secret) [hey]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'stuff') + self.assertEqual(result, "stuff") + self.assertPromptMessage(prompt_mock, "some (secret) [hey]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "stuff") self.assertPromptPassword(prompt_mock, True) - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'hey') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "hey") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_enumreader(self, prompt_mock): spec = { - 'enum': ['some', 'thing', 'else'], - 'description': 'some description', - 'default': 'thing' + "enum": ["some", "thing", "else"], + "description": "some description", + "default": "thing", } - Reader = interactive.EnumReader('some', spec) + Reader = interactive.EnumReader("some", spec) - prompt_mock.return_value = '2' + prompt_mock.return_value = "2" result = Reader.read() - self.assertEqual(result, 'else') - message = 'some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: ' + self.assertEqual(result, "else") + message = "some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, '5') - - prompt_mock.return_value = '' + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "5", + ) + + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'thing') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "thing") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': ['a', 'b'] - } - Reader = interactive.ArrayReader('some', spec) + spec = {"description": "some description", "default": ["a", "b"]} + Reader = interactive.ArrayReader("some", spec) - prompt_mock.return_value = 'some,thing,else' + prompt_mock.return_value = "some,thing,else" result = Reader.read() - self.assertEqual(result, ['some', 'thing', 'else']) - self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'some,thing,else') + self.assertEqual(result, ["some", "thing", "else"]) + self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "some,thing,else") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, ['a', 'b']) - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, ["a", "b"]) + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayreader_ends_with_comma(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': ['a', 'b'] - } - Reader = interactive.ArrayReader('some', spec) + spec = {"description": "some description", "default": ["a", "b"]} + Reader = interactive.ArrayReader("some", spec) - prompt_mock.return_value = 'some,thing,else,' + prompt_mock.return_value = "some,thing,else," result = Reader.read() - self.assertEqual(result, ['some', 'thing', 'else', '']) - self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'some,thing,else,') + self.assertEqual(result, ["some", "thing", "else", ""]) + self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "some,thing,else,") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayenumreader(self, prompt_mock): spec = { - 'items': { - 'enum': ['a', 'b', 'c', 'd', 'e'] - }, - 'description': 'some description', - 'default': ['a', 'b'] + "items": {"enum": ["a", "b", "c", "d", "e"]}, + "description": "some description", + "default": ["a", "b"], } - Reader = interactive.ArrayEnumReader('some', spec) + Reader = interactive.ArrayEnumReader("some", spec) - prompt_mock.return_value = '0,2,4' + prompt_mock.return_value = "0,2,4" result = Reader.read() - self.assertEqual(result, ['a', 'c', 'e']) - message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: ' + self.assertEqual(result, ["a", "c", "e"]) + message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0,2,4') + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0,2,4") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, ['a', 'b']) - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, ["a", "b"]) + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayenumreader_ends_with_comma(self, prompt_mock): spec = { - 'items': { - 'enum': ['a', 'b', 'c', 'd', 'e'] - }, - 'description': 'some description', - 'default': ['a', 'b'] + "items": {"enum": ["a", "b", "c", "d", "e"]}, + "description": "some description", + "default": ["a", "b"], } - Reader = interactive.ArrayEnumReader('some', spec) + Reader = interactive.ArrayEnumReader("some", spec) - prompt_mock.return_value = '0,2,4,' + prompt_mock.return_value = "0,2,4," result = Reader.read() - self.assertEqual(result, ['a', 'c', 'e']) - message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: ' + self.assertEqual(result, ["a", "c", "e"]) + message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0,2,4,') + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0,2,4,") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayobjectreader(self, prompt_mock): spec = { - 'items': { - 'type': 'object', - 'properties': { - 'foo': { - 'type': 'string', - 'description': 'some description', + "items": { + "type": "object", + "properties": { + "foo": { + "type": "string", + "description": "some description", + }, + "bar": { + "type": "string", + "description": "some description", }, - 'bar': { - 'type': 'string', - 'description': 'some description', - } - } + }, }, - 'description': 'some description', + "description": "some description", } - Reader = interactive.ArrayObjectReader('some', spec) + Reader = interactive.ArrayObjectReader("some", spec) # To emulate continuing setting, this flag variable is needed self.is_continued = False def side_effect(msg, **kwargs): - if re.match(r'^~~~ Would you like to add another item to.*', msg): + if re.match(r"^~~~ Would you like to add another item to.*", msg): # prompt requires the input to judge continuing setting, or not if not self.is_continued: # continuing the configuration only once self.is_continued = True - return '' + return "" else: # finishing to configuration - return 'n' + return "n" else: # prompt requires the input of property value in the object - return 'value' + return "value" prompt_mock.side_effect = side_effect results = Reader.read() self.assertEqual(len(results), 2) self.assertTrue(all([len(list(x.keys())) == 2 for x in results])) - self.assertTrue(all(['foo' in x and 'bar' in x for x in results])) + self.assertTrue(all(["foo" in x and "bar" in x for x in results])) diff --git a/st2client/tests/unit/test_keyvalue.py b/st2client/tests/unit/test_keyvalue.py index bb5bf09d60..52c240a052 100644 --- a/st2client/tests/unit/test_keyvalue.py +++ b/st2client/tests/unit/test_keyvalue.py @@ -29,77 +29,70 @@ LOG = logging.getLogger(__name__) KEYVALUE = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system' + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", } KEYVALUE_USER = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'user': 'stanley' + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "user": "stanley", } KEYVALUE_SECRET = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'secret': True + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "secret": True, } KEYVALUE_PRE_ENCRYPTED = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'AAABBBCCC1234', - 'scope': 'system', - 'encrypted': True, - 'secret': True + "id": "kv_name", + "name": "kv_name.", + "value": "AAABBBCCC1234", + "scope": "system", + "encrypted": True, + "secret": True, } KEYVALUE_TTL = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'ttl': 100 + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "ttl": 100, } KEYVALUE_OBJECT = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': {'obj': [1, True, 23.4, 'abc']}, - 'scope': 'system', + "id": "kv_name", + "name": "kv_name.", + "value": {"obj": [1, True, 23.4, "abc"]}, + "scope": "system", } KEYVALUE_ALL = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'AAAAABBBBBCCCCCCDDDDD11122345', - 'scope': 'system', - 'user': 'stanley', - 'secret': True, - 'encrypted': True, - 'ttl': 100 + "id": "kv_name", + "name": "kv_name.", + "value": "AAAAABBBBBCCCCCCDDDDD11122345", + "scope": "system", + "user": "stanley", + "secret": True, + "encrypted": True, + "ttl": 100, } -KEYVALUE_MISSING_NAME = { - 'id': 'kv_name', - 'value': 'super cool value' -} +KEYVALUE_MISSING_NAME = {"id": "kv_name", "value": "super cool value"} -KEYVALUE_MISSING_VALUE = { - 'id': 'kv_name', - 'name': 'kv_name.' -} +KEYVALUE_MISSING_VALUE = {"id": "kv_name", "name": "kv_name."} class TestKeyValueBase(base.BaseCLITestCase): - """Base class for "key" CLI tests - """ + """Base class for "key" CLI tests""" capture_output = True @@ -107,8 +100,8 @@ def __init__(self, *args, **kwargs): super(TestKeyValueBase, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): @@ -119,44 +112,49 @@ def tearDown(self): class TestKeyValueSet(TestKeyValueBase): - @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, - 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK" + ) + ), + ) def test_set_keyvalue(self): - """Test setting key/value pair with optional pre_encrypted field - """ - args = ['key', 'set', '--encrypted', 'kv_name', 'AAABBBCCC1234'] + """Test setting key/value pair with optional pre_encrypted field""" + args = ["key", "set", "--encrypted", "kv_name", "AAABBBCCC1234"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) def test_encrypt_and_encrypted_flags_are_mutually_exclusive(self): - args = ['key', 'set', '--encrypt', '--encrypted', 'kv_name', 'AAABBBCCC1234'] + args = ["key", "set", "--encrypt", "--encrypted", "kv_name", "AAABBBCCC1234"] - self.assertRaisesRegexp(SystemExit, '2', self.shell.run, args) + self.assertRaisesRegexp(SystemExit, "2", self.shell.run, args) self.stderr.seek(0) stderr = self.stderr.read() - expected_msg = ('error: argument --encrypted: not allowed with argument -e/--encrypt') + expected_msg = ( + "error: argument --encrypted: not allowed with argument -e/--encrypt" + ) self.assertIn(expected_msg, stderr) class TestKeyValueLoad(TestKeyValueBase): - @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")), + ) def test_load_keyvalue_json(self): - """Test loading of key/value pair in JSON format - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair in JSON format""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -164,17 +162,18 @@ def test_load_keyvalue_json(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")), + ) def test_load_keyvalue_yaml(self): - """Test loading of key/value pair in YAML format - """ - fd, path = tempfile.mkstemp(suffix='.yaml') + """Test loading of key/value pair in YAML format""" + fd, path = tempfile.mkstemp(suffix=".yaml") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(yaml.safe_dump(KEYVALUE)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -182,17 +181,20 @@ def test_load_keyvalue_yaml(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, "OK") + ), + ) def test_load_keyvalue_user(self): - """Test loading of key/value pair with the optional user field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional user field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_USER, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -200,17 +202,20 @@ def test_load_keyvalue_user(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, "OK") + ), + ) def test_load_keyvalue_secret(self): - """Test loading of key/value pair with the optional secret field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional secret field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_SECRET, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -218,18 +223,22 @@ def test_load_keyvalue_secret(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, - 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK" + ) + ), + ) def test_load_keyvalue_already_encrypted(self): - """Test loading of key/value pair with the pre-encrypted value - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the pre-encrypted value""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_PRE_ENCRYPTED, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -237,17 +246,20 @@ def test_load_keyvalue_already_encrypted(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, "OK") + ), + ) def test_load_keyvalue_ttl(self): - """Test loading of key/value pair with the optional ttl field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional ttl field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_TTL, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -255,23 +267,26 @@ def test_load_keyvalue_ttl(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK") + ), + ) def test_load_keyvalue_object(self): - """Test loading of key/value pair where the value is an object - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair where the value is an object""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_OBJECT, indent=4)) # test converting with short option - args = ['key', 'load', '-c', path] + args = ["key", "load", "-c", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) # test converting with long option - args = ['key', 'load', '--convert', path] + args = ["key", "load", "--convert", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -279,19 +294,23 @@ def test_load_keyvalue_object(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK") + ), + ) def test_load_keyvalue_object_fail(self): """Test failure to load key/value pair where the value is an object - and the -c/--convert option is not passed + and the -c/--convert option is not passed """ - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_OBJECT, indent=4)) # test converting with short option - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertNotEqual(retcode, 0) finally: @@ -299,17 +318,20 @@ def test_load_keyvalue_object_fail(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK") + ), + ) def test_load_keyvalue_all(self): - """Test loading of key/value pair with all optional fields - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with all optional fields""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_ALL, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -317,21 +339,23 @@ def test_load_keyvalue_all(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), - 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK") + ), + ) def test_load_keyvalue_array(self): - """Test loading an array of key/value pairs - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading an array of key/value pairs""" + fd, path = tempfile.mkstemp(suffix=".json") try: array = [KEYVALUE, KEYVALUE_ALL] json_str = json.dumps(array, indent=4) LOG.info(json_str) - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json_str) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -339,14 +363,13 @@ def test_load_keyvalue_array(self): os.unlink(path) def test_load_keyvalue_missing_name(self): - """Test loading of a key/value pair with the required field 'name' missing - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of a key/value pair with the required field 'name' missing""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_MISSING_NAME, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -354,14 +377,13 @@ def test_load_keyvalue_missing_name(self): os.unlink(path) def test_load_keyvalue_missing_value(self): - """Test loading of a key/value pair with the required field 'value' missing - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of a key/value pair with the required field 'value' missing""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_MISSING_VALUE, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -369,19 +391,17 @@ def test_load_keyvalue_missing_value(self): os.unlink(path) def test_load_keyvalue_missing_file(self): - """Test loading of a key/value pair with a missing file - """ - path = '/some/file/that/doesnt/exist.json' - args = ['key', 'load', path] + """Test loading of a key/value pair with a missing file""" + path = "/some/file/that/doesnt/exist.json" + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) def test_load_keyvalue_bad_file_extension(self): - """Test loading of a key/value pair with a bad file extension - """ - fd, path = tempfile.mkstemp(suffix='.badext') + """Test loading of a key/value pair with a bad file extension""" + fd, path = tempfile.mkstemp(suffix=".badext") try: - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -392,11 +412,11 @@ def test_load_keyvalue_empty_file(self): """ Loading K/V from an empty file shouldn't throw an error """ - fd, path = tempfile.mkstemp(suffix='.yaml') + fd, path = tempfile.mkstemp(suffix=".yaml") try: - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) - self.assertIn('No matching items found', self.stdout.getvalue()) + self.assertIn("No matching items found", self.stdout.getvalue()) self.assertEqual(retcode, 0) finally: os.close(fd) diff --git a/st2client/tests/unit/test_models.py b/st2client/tests/unit/test_models.py index dd7f35d6b8..8a137afa13 100644 --- a/st2client/tests/unit/test_models.py +++ b/st2client/tests/unit/test_models.py @@ -29,22 +29,24 @@ class TestSerialization(unittest2.TestCase): - def test_resource_serialize(self): - instance = base.FakeResource(id='123', name='abc') + instance = base.FakeResource(id="123", name="abc") self.assertDictEqual(instance.serialize(), base.RESOURCES[0]) def test_resource_deserialize(self): instance = base.FakeResource.deserialize(base.RESOURCES[0]) - self.assertEqual(instance.id, '123') - self.assertEqual(instance.name, 'abc') + self.assertEqual(instance.id, "123") + self.assertEqual(instance.name, "abc") class TestResourceManager(unittest2.TestCase): - @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_all(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) resources = mgr.get_all() @@ -53,8 +55,12 @@ def test_resource_get_all(self): self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_all_with_limit(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) resources = mgr.get_all(limit=50) @@ -63,135 +69,197 @@ def test_resource_get_all_with_limit(self): self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_all_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_all) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_get_by_id(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_id('123') + resource = mgr.get_by_id("123") actual = resource.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_get_by_id_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_id('123') + resource = mgr.get_by_id("123") self.assertIsNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_by_id_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_by_id) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_query(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources = mgr.query(name='abc') + resources = mgr.query(name="abc") actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {'X-Total-Count': '50'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {"X-Total-Count": "50"} + ) + ), + ) def test_resource_query_with_count(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources, count = mgr.query_with_count(name='abc') + resources, count = mgr.query_with_count(name="abc") actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) self.assertEqual(count, 50) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_query_with_limit(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources = mgr.query(name='abc', limit=50) + resources = mgr.query(name="abc", limit=50) actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND', - {'X-Total-Count': '30'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + "", 404, "NOT FOUND", {"X-Total-Count": "30"} + ) + ), + ) def test_resource_query_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) # No X-Total-Count - resources = mgr.query(name='abc') + resources = mgr.query(name="abc") self.assertListEqual(resources, []) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND', - {'X-Total-Count': '30'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + "", 404, "NOT FOUND", {"X-Total-Count": "30"} + ) + ), + ) def test_resource_query_with_count_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources, count = mgr.query_with_count(name='abc') + resources, count = mgr.query_with_count(name="abc") self.assertListEqual(resources, []) self.assertIsNone(count) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_query_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - self.assertRaises(Exception, mgr.query, name='abc') + self.assertRaises(Exception, mgr.query, name="abc") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_get_by_name(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) # No X-Total-Count - resource = mgr.get_by_name('abc') + resource = mgr.get_by_name("abc") actual = resource.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_get_by_name_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_name('abc') + resource = mgr.get_by_name("abc") self.assertIsNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_by_name_ambiguous(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - self.assertRaises(Exception, mgr.get_by_name, 'abc') + self.assertRaises(Exception, mgr.get_by_name, "abc") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_by_name_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_by_name) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_create(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize('{"name": "abc"}') @@ -199,16 +267,24 @@ def test_resource_create(self): self.assertIsNotNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_create_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize('{"name": "abc"}') self.assertRaises(Exception, mgr.create, instance) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_update(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) text = '{"id": "123", "name": "cba"}' @@ -217,8 +293,12 @@ def test_resource_update(self): self.assertIsNotNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_update_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) text = '{"id": "123", "name": "cba"}' @@ -226,39 +306,57 @@ def test_resource_update_failed(self): self.assertRaises(Exception, mgr.update, instance) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")), + ) def test_resource_delete(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - instance = mgr.get_by_name('abc') + instance = mgr.get_by_name("abc") mgr.delete(instance) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_delete_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize(base.RESOURCES[0]) mgr.delete(instance) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_delete_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - instance = mgr.get_by_name('abc') + instance = mgr.get_by_name("abc") self.assertRaises(Exception, mgr.delete, instance) - @mock.patch('requests.get') - @mock.patch('sseclient.SSEClient') + @mock.patch("requests.get") + @mock.patch("sseclient.SSEClient") def test_stream_resource_listen(self, mock_sseclient, mock_requests): mock_msg = mock.Mock() mock_msg.data = json.dumps(base.RESOURCES) @@ -267,14 +365,16 @@ def test_stream_resource_listen(self, mock_sseclient, mock_requests): def side_effect_checking_verify_parameter_is(): return [mock_msg] - mock_sseclient.return_value.events.side_effect = side_effect_checking_verify_parameter_is - mgr = models.StreamManager('https://example.com', cacert='/path/ca.crt') + mock_sseclient.return_value.events.side_effect = ( + side_effect_checking_verify_parameter_is + ) + mgr = models.StreamManager("https://example.com", cacert="/path/ca.crt") - resp = mgr.listen(events=['foo', 'bar']) + resp = mgr.listen(events=["foo", "bar"]) self.assertEqual(list(resp), [base.RESOURCES]) - call_args = tuple(['https://example.com/stream?events=foo%2Cbar']) - call_kwargs = {'stream': True, 'verify': '/path/ca.crt'} + call_args = tuple(["https://example.com/stream?events=foo%2Cbar"]) + call_kwargs = {"stream": True, "verify": "/path/ca.crt"} self.assertEqual(mock_requests.call_args_list[0][0], call_args) self.assertEqual(mock_requests.call_args_list[0][1], call_kwargs) @@ -283,15 +383,16 @@ def side_effect_checking_verify_parameter_is(): def side_effect_checking_verify_parameter_is_not(): return [mock_msg] - mock_sseclient.return_value.events.side_effect = \ + mock_sseclient.return_value.events.side_effect = ( side_effect_checking_verify_parameter_is_not - mgr = models.StreamManager('https://example.com') + ) + mgr = models.StreamManager("https://example.com") resp = mgr.listen() self.assertEqual(list(resp), [base.RESOURCES]) - call_args = tuple(['https://example.com/stream?']) - call_kwargs = {'stream': True} + call_args = tuple(["https://example.com/stream?"]) + call_kwargs = {"stream": True} self.assertEqual(mock_requests.call_args_list[1][0], call_args) self.assertEqual(mock_requests.call_args_list[1][1], call_kwargs) diff --git a/st2client/tests/unit/test_shell.py b/st2client/tests/unit/test_shell.py index 8383526615..bce176b4ad 100644 --- a/st2client/tests/unit/test_shell.py +++ b/st2client/tests/unit/test_shell.py @@ -38,8 +38,8 @@ LOG = logging.getLogger(__name__) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini') -CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini') +CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini") +CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini") MOCK_CONFIG = """ [credentials] @@ -77,352 +77,383 @@ def test_commands_usage_and_help_strings(self): self.stderr.seek(0) stderr = self.stderr.read() - self.assertIn('Usage: ', stderr) - self.assertIn('For example:', stderr) - self.assertIn('CLI for StackStorm', stderr) - self.assertIn('positional arguments:', stderr) + self.assertIn("Usage: ", stderr) + self.assertIn("For example:", stderr) + self.assertIn("CLI for StackStorm", stderr) + self.assertIn("positional arguments:", stderr) self.stdout.truncate() self.stderr.truncate() # --help should result in the same output try: - self.assertEqual(self.shell.run(['--help']), 0) + self.assertEqual(self.shell.run(["--help"]), 0) except SystemExit as e: self.assertEqual(e.code, 0) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('Usage: ', stdout) - self.assertIn('For example:', stdout) - self.assertIn('CLI for StackStorm', stdout) - self.assertIn('positional arguments:', stdout) + self.assertIn("Usage: ", stdout) + self.assertIn("For example:", stdout) + self.assertIn("CLI for StackStorm", stdout) + self.assertIn("positional arguments:", stdout) self.stdout.truncate() self.stderr.truncate() # Sub command with no args try: - self.assertEqual(self.shell.run(['action']), 2) + self.assertEqual(self.shell.run(["action"]), 2) except SystemExit as e: self.assertEqual(e.code, 2) self.stderr.seek(0) stderr = self.stderr.read() - self.assertIn('usage', stderr) + self.assertIn("usage", stderr) if six.PY2: - self.assertIn('{list,get,create,update', stderr) - self.assertIn('error: too few arguments', stderr) + self.assertIn("{list,get,create,update", stderr) + self.assertIn("error: too few arguments", stderr) def test_endpoints_default(self): - base_url = 'http://127.0.0.1' - auth_url = 'http://127.0.0.1:9100' - api_url = 'http://127.0.0.1:9101/v1' - stream_url = 'http://127.0.0.1:9102/v1' - args = ['trigger', 'list'] + base_url = "http://127.0.0.1" + auth_url = "http://127.0.0.1:9100" + api_url = "http://127.0.0.1:9101/v1" + stream_url = "http://127.0.0.1:9102/v1" + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_base_url_from_cli(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:9100' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' - args = ['--url', base_url, 'trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:9100" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" + args = ["--url", base_url, "trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_base_url_from_env(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:9100' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' - os.environ['ST2_BASE_URL'] = base_url - args = ['trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:9100" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" + os.environ["ST2_BASE_URL"] = base_url + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_override_from_cli(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:8888' - api_url = 'http://www.stackstorm1.com:9101/v1' - stream_url = 'http://www.stackstorm1.com:9102/v1' - args = ['--url', base_url, - '--auth-url', auth_url, - '--api-url', api_url, - '--stream-url', stream_url, - 'trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:8888" + api_url = "http://www.stackstorm1.com:9101/v1" + stream_url = "http://www.stackstorm1.com:9102/v1" + args = [ + "--url", + base_url, + "--auth-url", + auth_url, + "--api-url", + api_url, + "--stream-url", + stream_url, + "trigger", + "list", + ] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_override_from_env(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:8888' - api_url = 'http://www.stackstorm1.com:9101/v1' - stream_url = 'http://www.stackstorm1.com:9102/v1' - os.environ['ST2_BASE_URL'] = base_url - os.environ['ST2_AUTH_URL'] = auth_url - os.environ['ST2_API_URL'] = api_url - os.environ['ST2_STREAM_URL'] = stream_url - args = ['trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:8888" + api_url = "http://www.stackstorm1.com:9101/v1" + stream_url = "http://www.stackstorm1.com:9102/v1" + os.environ["ST2_BASE_URL"] = base_url + os.environ["ST2_AUTH_URL"] = auth_url + os.environ["ST2_API_URL"] = api_url + os.environ["ST2_STREAM_URL"] = stream_url + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_exit_code_on_success(self): - argv = ['trigger', 'list'] + argv = ["trigger", "list"] self.assertEqual(self.shell.run(argv), 0) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(None, 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(None, 500, "INTERNAL SERVER ERROR") + ), + ) def test_exit_code_on_error(self): - argv = ['trigger', 'list'] + argv = ["trigger", "list"] self.assertEqual(self.shell.run(argv), 1) def _validate_parser(self, args_list, is_subcommand=True): for args in args_list: ns = self.shell.parser.parse_args(args) - func = (self.shell.commands[args[0]].run_and_print - if not is_subcommand - else self.shell.commands[args[0]].commands[args[1]].run_and_print) + func = ( + self.shell.commands[args[0]].run_and_print + if not is_subcommand + else self.shell.commands[args[0]].commands[args[1]].run_and_print + ) self.assertEqual(ns.func, func) def test_action(self): args_list = [ - ['action', 'list'], - ['action', 'get', 'abc'], - ['action', 'create', '/tmp/action.json'], - ['action', 'update', '123', '/tmp/action.json'], - ['action', 'delete', 'abc'], - ['action', 'execute', '-h'], - ['action', 'execute', 'remote', '-h'], - ['action', 'execute', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'], - ['action', 'execute', 'remote-fib', 'hosts=192.168.1.1', '3', '8'] + ["action", "list"], + ["action", "get", "abc"], + ["action", "create", "/tmp/action.json"], + ["action", "update", "123", "/tmp/action.json"], + ["action", "delete", "abc"], + ["action", "execute", "-h"], + ["action", "execute", "remote", "-h"], + [ + "action", + "execute", + "remote", + "hosts=192.168.1.1", + "user=st2", + 'cmd="ls -l"', + ], + ["action", "execute", "remote-fib", "hosts=192.168.1.1", "3", "8"], ] self._validate_parser(args_list) def test_action_execution(self): args_list = [ - ['execution', 'list'], - ['execution', 'list', '-a', 'all'], - ['execution', 'list', '--attr=all'], - ['execution', 'get', '123'], - ['execution', 'get', '123', '-d'], - ['execution', 'get', '123', '-k', 'localhost.stdout'], - ['execution', 're-run', '123'], - ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z'], - ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z', '--no-reset', 'x'], - ['execution', 're-run', '123', 'a=1', 'b=x', 'c=True'], - ['execution', 'cancel', '123'], - ['execution', 'cancel', '123', '456'], - ['execution', 'pause', '123'], - ['execution', 'pause', '123', '456'], - ['execution', 'resume', '123'], - ['execution', 'resume', '123', '456'] + ["execution", "list"], + ["execution", "list", "-a", "all"], + ["execution", "list", "--attr=all"], + ["execution", "get", "123"], + ["execution", "get", "123", "-d"], + ["execution", "get", "123", "-k", "localhost.stdout"], + ["execution", "re-run", "123"], + ["execution", "re-run", "123", "--tasks", "x", "y", "z"], + ["execution", "re-run", "123", "--tasks", "x", "y", "z", "--no-reset", "x"], + ["execution", "re-run", "123", "a=1", "b=x", "c=True"], + ["execution", "cancel", "123"], + ["execution", "cancel", "123", "456"], + ["execution", "pause", "123"], + ["execution", "pause", "123", "456"], + ["execution", "resume", "123"], + ["execution", "resume", "123", "456"], ] self._validate_parser(args_list) # Test mutually exclusive argument groups - self.assertRaises(SystemExit, self._validate_parser, - [['execution', 'get', '123', '-d', '-k', 'localhost.stdout']]) + self.assertRaises( + SystemExit, + self._validate_parser, + [["execution", "get", "123", "-d", "-k", "localhost.stdout"]], + ) def test_key(self): args_list = [ - ['key', 'list'], - ['key', 'list', '-n', '2'], - ['key', 'get', 'abc'], - ['key', 'set', 'abc', '123'], - ['key', 'delete', 'abc'], - ['key', 'load', '/tmp/keys.json'] + ["key", "list"], + ["key", "list", "-n", "2"], + ["key", "get", "abc"], + ["key", "set", "abc", "123"], + ["key", "delete", "abc"], + ["key", "load", "/tmp/keys.json"], ] self._validate_parser(args_list) def test_policy(self): args_list = [ - ['policy', 'list'], - ['policy', 'list', '-p', 'core'], - ['policy', 'list', '--pack', 'core'], - ['policy', 'list', '-r', 'core.local'], - ['policy', 'list', '--resource-ref', 'core.local'], - ['policy', 'list', '-pt', 'action.type1'], - ['policy', 'list', '--policy-type', 'action.type1'], - ['policy', 'list', '-r', 'core.local', '-pt', 'action.type1'], - ['policy', 'list', '--resource-ref', 'core.local', '--policy-type', 'action.type1'], - ['policy', 'get', 'abc'], - ['policy', 'create', '/tmp/policy.json'], - ['policy', 'update', '123', '/tmp/policy.json'], - ['policy', 'delete', 'abc'] + ["policy", "list"], + ["policy", "list", "-p", "core"], + ["policy", "list", "--pack", "core"], + ["policy", "list", "-r", "core.local"], + ["policy", "list", "--resource-ref", "core.local"], + ["policy", "list", "-pt", "action.type1"], + ["policy", "list", "--policy-type", "action.type1"], + ["policy", "list", "-r", "core.local", "-pt", "action.type1"], + [ + "policy", + "list", + "--resource-ref", + "core.local", + "--policy-type", + "action.type1", + ], + ["policy", "get", "abc"], + ["policy", "create", "/tmp/policy.json"], + ["policy", "update", "123", "/tmp/policy.json"], + ["policy", "delete", "abc"], ] self._validate_parser(args_list) def test_policy_type(self): args_list = [ - ['policy-type', 'list'], - ['policy-type', 'list', '-r', 'action'], - ['policy-type', 'list', '--resource-type', 'action'], - ['policy-type', 'get', 'abc'] + ["policy-type", "list"], + ["policy-type", "list", "-r", "action"], + ["policy-type", "list", "--resource-type", "action"], + ["policy-type", "get", "abc"], ] self._validate_parser(args_list) def test_pack(self): args_list = [ - ['pack', 'list'], - ['pack', 'get', 'abc'], - ['pack', 'search', 'abc'], - ['pack', 'show', 'abc'], - ['pack', 'remove', 'abc'], - ['pack', 'remove', 'abc', '--detail'], - ['pack', 'install', 'abc'], - ['pack', 'install', 'abc', '--force'], - ['pack', 'install', 'abc', '--detail'], - ['pack', 'config', 'abc'] + ["pack", "list"], + ["pack", "get", "abc"], + ["pack", "search", "abc"], + ["pack", "show", "abc"], + ["pack", "remove", "abc"], + ["pack", "remove", "abc", "--detail"], + ["pack", "install", "abc"], + ["pack", "install", "abc", "--force"], + ["pack", "install", "abc", "--detail"], + ["pack", "config", "abc"], ] self._validate_parser(args_list) - @mock.patch('st2client.base.ST2_CONFIG_PATH', '/home/does/not/exist') + @mock.patch("st2client.base.ST2_CONFIG_PATH", "/home/does/not/exist") def test_print_config_default_config_no_config(self): - os.environ['ST2_CONFIG_FILE'] = '/home/does/not/exist' - argv = ['--print-config'] + os.environ["ST2_CONFIG_FILE"] = "/home/does/not/exist" + argv = ["--print-config"] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = None', stdout) - self.assertIn('cache_token = True', stdout) + self.assertIn("username = None", stdout) + self.assertIn("cache_token = True", stdout) def test_print_config_custom_config_as_env_variable(self): - os.environ['ST2_CONFIG_FILE'] = CONFIG_FILE_PATH_FULL - argv = ['--print-config'] + os.environ["ST2_CONFIG_FILE"] = CONFIG_FILE_PATH_FULL + argv = ["--print-config"] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = test1', stdout) - self.assertIn('cache_token = False', stdout) + self.assertIn("username = test1", stdout) + self.assertIn("cache_token = False", stdout) def test_print_config_custom_config_as_command_line_argument(self): - argv = ['--print-config', '--config-file=%s' % (CONFIG_FILE_PATH_FULL)] + argv = ["--print-config", "--config-file=%s" % (CONFIG_FILE_PATH_FULL)] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = test1', stdout) - self.assertIn('cache_token = False', stdout) + self.assertIn("username = test1", stdout) + self.assertIn("cache_token = False", stdout) def test_run(self): args_list = [ - ['run', '-h'], - ['run', 'abc', '-h'], - ['run', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'], - ['run', 'remote-fib', 'hosts=192.168.1.1', '3', '8'] + ["run", "-h"], + ["run", "abc", "-h"], + ["run", "remote", "hosts=192.168.1.1", "user=st2", 'cmd="ls -l"'], + ["run", "remote-fib", "hosts=192.168.1.1", "3", "8"], ] self._validate_parser(args_list, is_subcommand=False) def test_runner(self): - args_list = [ - ['runner', 'list'], - ['runner', 'get', 'abc'] - ] + args_list = [["runner", "list"], ["runner", "get", "abc"]] self._validate_parser(args_list) def test_rule(self): args_list = [ - ['rule', 'list'], - ['rule', 'list', '-n', '1'], - ['rule', 'get', 'abc'], - ['rule', 'create', '/tmp/rule.json'], - ['rule', 'update', '123', '/tmp/rule.json'], - ['rule', 'delete', 'abc'] + ["rule", "list"], + ["rule", "list", "-n", "1"], + ["rule", "get", "abc"], + ["rule", "create", "/tmp/rule.json"], + ["rule", "update", "123", "/tmp/rule.json"], + ["rule", "delete", "abc"], ] self._validate_parser(args_list) def test_trigger(self): args_list = [ - ['trigger', 'list'], - ['trigger', 'get', 'abc'], - ['trigger', 'create', '/tmp/trigger.json'], - ['trigger', 'update', '123', '/tmp/trigger.json'], - ['trigger', 'delete', 'abc'] + ["trigger", "list"], + ["trigger", "get", "abc"], + ["trigger", "create", "/tmp/trigger.json"], + ["trigger", "update", "123", "/tmp/trigger.json"], + ["trigger", "delete", "abc"], ] self._validate_parser(args_list) def test_workflow(self): args_list = [ - ['workflow', 'inspect', '--file', '/path/to/workflow/definition'], - ['workflow', 'inspect', '--action', 'mock.foobar'] + ["workflow", "inspect", "--file", "/path/to/workflow/definition"], + ["workflow", "inspect", "--action", "mock.foobar"], ] self._validate_parser(args_list) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.8.0') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.8.0") def test_get_version_no_package_metadata_file_stable_version(self): # stable version, package metadata file doesn't exist on disk - no git revision should be # included shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.8.0, on Python', stderr) + self.assertIn("v2.8.0, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.8.0') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.8.0") def test_get_version_package_metadata_file_exists_stable_version(self): # stable version, package metadata file exists on disk - no git revision should be included package_metadata_path = self._write_mock_package_metadata_file() st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path shell = Shell() - shell.run(argv=['--version']) + shell.run(argv=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.8.0, on Python', stderr) + self.assertIn("v2.8.0, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.9dev') - @mock.patch('st2client.shell.PACKAGE_METADATA_FILE_PATH', '/tmp/doesnt/exist.1') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.9dev") + @mock.patch("st2client.shell.PACKAGE_METADATA_FILE_PATH", "/tmp/doesnt/exist.1") def test_get_version_no_package_metadata_file_dev_version(self): # dev version, package metadata file doesn't exist on disk - no git revision should be # included since package metadata file doesn't exist on disk shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.9dev, on Python', stderr) + self.assertIn("v2.9dev, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.9dev') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.9dev") def test_get_version_package_metadata_file_exists_dev_version(self): # dev version, package metadata file exists on disk - git revision should be included # since package metadata file exists on disk and contains server.git_sha attribute @@ -430,55 +461,67 @@ def test_get_version_package_metadata_file_exists_dev_version(self): st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.9dev (abcdefg), on Python', stderr) + self.assertIn("v2.9dev (abcdefg), on Python", stderr) - @mock.patch('locale.getdefaultlocale', mock.Mock(return_value=['en_US'])) - @mock.patch('locale.getpreferredencoding', mock.Mock(return_value='iso')) + @mock.patch("locale.getdefaultlocale", mock.Mock(return_value=["en_US"])) + @mock.patch("locale.getpreferredencoding", mock.Mock(return_value="iso")) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) - @mock.patch('st2client.shell.LOGGER') + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) + @mock.patch("st2client.shell.LOGGER") def test_non_unicode_encoding_locale_warning_is_printed(self, mock_logger): shell = Shell() - shell.run(argv=['trigger', 'list']) + shell.run(argv=["trigger", "list"]) call_args = mock_logger.warn.call_args[0][0] - self.assertIn('Locale en_US with encoding iso which is not UTF-8 is used.', call_args) + self.assertIn( + "Locale en_US with encoding iso which is not UTF-8 is used.", call_args + ) - @mock.patch('locale.getdefaultlocale', mock.Mock(side_effect=ValueError('bar'))) - @mock.patch('locale.getpreferredencoding', mock.Mock(side_effect=ValueError('bar'))) + @mock.patch("locale.getdefaultlocale", mock.Mock(side_effect=ValueError("bar"))) + @mock.patch("locale.getpreferredencoding", mock.Mock(side_effect=ValueError("bar"))) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) - @mock.patch('st2client.shell.LOGGER') + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) + @mock.patch("st2client.shell.LOGGER") def test_failed_to_get_locale_encoding_warning_is_printed(self, mock_logger): shell = Shell() - shell.run(argv=['trigger', 'list']) + shell.run(argv=["trigger", "list"]) call_args = mock_logger.warn.call_args[0][0] - self.assertTrue('Locale unknown with encoding unknown which is not UTF-8 is used.' in - call_args) + self.assertTrue( + "Locale unknown with encoding unknown which is not UTF-8 is used." + in call_args + ) def _write_mock_package_metadata_file(self): _, package_metadata_path = tempfile.mkstemp() - with open(package_metadata_path, 'w') as fp: + with open(package_metadata_path, "w") as fp: fp.write(MOCK_PACKAGE_METADATA) return package_metadata_path - @unittest2.skipIf(True, 'skipping until checks are re-enabled') + @unittest2.skipIf(True, "skipping until checks are re-enabled") @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse("{}", 200, 'OK'))) + requests, "get", mock.MagicMock(return_value=base.FakeResponse("{}", 200, "OK")) + ) def test_dont_warn_multiple_times(self): mock_temp_dir_path = tempfile.mkdtemp() - mock_config_dir_path = os.path.join(mock_temp_dir_path, 'testconfig') - mock_config_path = os.path.join(mock_config_dir_path, 'config') + mock_config_dir_path = os.path.join(mock_temp_dir_path, "testconfig") + mock_config_path = os.path.join(mock_config_dir_path, "config") # Make the temporary config directory os.makedirs(mock_config_dir_path) @@ -495,38 +538,46 @@ def test_dont_warn_multiple_times(self): shell.LOG = mock.Mock() # Test without token. - shell.run(['--config-file', mock_config_path, 'action', 'list']) + shell.run(["--config-file", mock_config_path, "action", "list"]) self.assertEqual(shell.LOG.warn.call_count, 2) self.assertEqual( shell.LOG.warn.call_args_list[0][0][0][:63], - 'The StackStorm configuration directory permissions are insecure') + "The StackStorm configuration directory permissions are insecure", + ) self.assertEqual( shell.LOG.warn.call_args_list[1][0][0][:58], - 'The StackStorm configuration file permissions are insecure') + "The StackStorm configuration file permissions are insecure", + ) self.assertEqual(shell.LOG.info.call_count, 2) self.assertEqual( - shell.LOG.info.call_args_list[0][0][0], "The SGID bit is not " - "set on the StackStorm configuration directory.") + shell.LOG.info.call_args_list[0][0][0], + "The SGID bit is not " "set on the StackStorm configuration directory.", + ) self.assertEqual( - shell.LOG.info.call_args_list[1][0][0], 'Skipping parsing CLI config') + shell.LOG.info.call_args_list[1][0][0], "Skipping parsing CLI config" + ) class CLITokenCachingTestCase(unittest2.TestCase): def setUp(self): super(CLITokenCachingTestCase, self).setUp() self._mock_temp_dir_path = tempfile.mkdtemp() - self._mock_config_directory_path = os.path.join(self._mock_temp_dir_path, 'testconfig') - self._mock_config_path = os.path.join(self._mock_config_directory_path, 'config') + self._mock_config_directory_path = os.path.join( + self._mock_temp_dir_path, "testconfig" + ) + self._mock_config_path = os.path.join( + self._mock_config_directory_path, "config" + ) os.makedirs(self._mock_config_directory_path) - self._p1 = mock.patch('st2client.base.ST2_CONFIG_DIRECTORY', - self._mock_config_directory_path) - self._p2 = mock.patch('st2client.base.ST2_CONFIG_PATH', - self._mock_config_path) + self._p1 = mock.patch( + "st2client.base.ST2_CONFIG_DIRECTORY", self._mock_config_directory_path + ) + self._p2 = mock.patch("st2client.base.ST2_CONFIG_PATH", self._mock_config_path) self._p1.start() self._p2.start() @@ -536,46 +587,46 @@ def tearDown(self): self._p2.stop() for var in [ - 'ST2_BASE_URL', - 'ST2_API_URL', - 'ST2_STREAM_URL', - 'ST2_DATASTORE_URL', - 'ST2_AUTH_TOKEN' + "ST2_BASE_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_DATASTORE_URL", + "ST2_AUTH_TOKEN", ]: if var in os.environ: del os.environ[var] def _write_mock_config(self): - with open(self._mock_config_path, 'w') as fp: + with open(self._mock_config_path, "w") as fp: fp.write(MOCK_CONFIG) def test_get_cached_auth_token_invalid_permissions(self): shell = Shell() client = Client() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) # 1. Current user doesn't have read access to the config directory os.chmod(self._mock_config_directory_path, 0o000) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to retrieve cached token from .*? read access to the parent ' - 'directory') + expected_msg = ( + "Unable to retrieve cached token from .*? read access to the parent " + "directory" + ) self.assertRegexpMatches(log_message, expected_msg) # 2. Read access on the directory, but not on the cached token file @@ -583,14 +634,17 @@ def test_get_cached_auth_token_invalid_permissions(self): os.chmod(cached_token_path, 0o000) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to retrieve cached token from .*? read access to this file') + expected_msg = ( + "Unable to retrieve cached token from .*? read access to this file" + ) self.assertRegexpMatches(log_message, expected_msg) # 3. Other users also have read access to the file @@ -598,31 +652,29 @@ def test_get_cached_auth_token_invalid_permissions(self): os.chmod(cached_token_path, 0o444) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'yayvalid') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "yayvalid") self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Permissions .*? for cached token file .*? are too permissive.*') + expected_msg = "Permissions .*? for cached token file .*? are too permissive.*" self.assertRegexpMatches(log_message, expected_msg) def test_cache_auth_token_invalid_permissions(self): shell = Shell() - username = 'testu' + username = "testu" cached_token_path = shell._get_cached_token_path_for_user(username=username) expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30) - token_db = TokenDB(user=username, token='fyeah', expiry=expiry) + token_db = TokenDB(user=username, token="fyeah", expiry=expiry) cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) # 1. Current user has no write access to the parent directory @@ -634,8 +686,10 @@ def test_cache_auth_token_invalid_permissions(self): self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to write token to .*? doesn\'t have write access to the parent ' - 'directory') + expected_msg = ( + "Unable to write token to .*? doesn't have write access to the parent " + "directory" + ) self.assertRegexpMatches(log_message, expected_msg) # 2. Current user has no write access to the cached token file @@ -648,86 +702,93 @@ def test_cache_auth_token_invalid_permissions(self): self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to write token to .*? doesn\'t have write access to this file') + expected_msg = ( + "Unable to write token to .*? doesn't have write access to this file" + ) self.assertRegexpMatches(log_message, expected_msg) def test_get_cached_auth_token_no_token_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) def test_get_cached_auth_token_corrupted_token_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - with open(cached_token_path, 'w') as fp: - fp.write('CORRRRRUPTED!') - - expected_msg = 'File (.+) with cached token is corrupted or invalid' - self.assertRaisesRegexp(ValueError, expected_msg, shell._get_cached_auth_token, - client=client, username=username, password=password) + with open(cached_token_path, "w") as fp: + fp.write("CORRRRRUPTED!") + + expected_msg = "File (.+) with cached token is corrupted or invalid" + self.assertRaisesRegexp( + ValueError, + expected_msg, + shell._get_cached_auth_token, + client=client, + username=username, + password=password, + ) def test_get_cached_auth_token_expired_token_in_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'expired', - 'expire_timestamp': (int(time.time()) - 10) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "expired", "expire_timestamp": (int(time.time()) - 10)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) def test_get_cached_auth_token_valid_token_in_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'yayvalid') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "yayvalid") def test_cache_auth_token_success(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) - token_db = TokenDB(user=username, token='fyeah', expiry=expiry) + token_db = TokenDB(user=username, token="fyeah", expiry=expiry) shell._cache_auth_token(token_obj=token_db) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'fyeah') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "fyeah") def test_automatic_auth_skipped_on_auth_command(self): self._write_mock_config() @@ -735,7 +796,7 @@ def test_automatic_auth_skipped_on_auth_command(self): shell = Shell() shell._get_auth_token = mock.Mock() - argv = ['auth', 'testu', '-p', 'testp'] + argv = ["auth", "testu", "-p", "testp"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) @@ -746,8 +807,8 @@ def test_automatic_auth_skipped_if_token_provided_as_env_variable(self): shell = Shell() shell._get_auth_token = mock.Mock() - os.environ['ST2_AUTH_TOKEN'] = 'fooo' - argv = ['action', 'list'] + os.environ["ST2_AUTH_TOKEN"] = "fooo" + argv = ["action", "list"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) @@ -758,12 +819,12 @@ def test_automatic_auth_skipped_if_token_provided_as_cli_argument(self): shell = Shell() shell._get_auth_token = mock.Mock() - argv = ['action', 'list', '--token=bar'] + argv = ["action", "list", "--token=bar"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) - argv = ['action', 'list', '-t', 'bar'] + argv = ["action", "list", "-t", "bar"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) diff --git a/st2client/tests/unit/test_ssl.py b/st2client/tests/unit/test_ssl.py index 5ed8bfbf28..5db836482b 100644 --- a/st2client/tests/unit/test_ssl.py +++ b/st2client/tests/unit/test_ssl.py @@ -27,17 +27,18 @@ LOG = logging.getLogger(__name__) -USERNAME = 'stanley' -PASSWORD = 'ShhhDontTell' -HEADERS = {'content-type': 'application/json'} -AUTH_URL = 'https://127.0.0.1:9100/tokens' -GET_RULES_URL = ('http://127.0.0.1:9101/v1/rules/' - '?include_attributes=ref,pack,description,enabled&limit=50') -GET_RULES_URL = GET_RULES_URL.replace(',', '%2C') +USERNAME = "stanley" +PASSWORD = "ShhhDontTell" +HEADERS = {"content-type": "application/json"} +AUTH_URL = "https://127.0.0.1:9100/tokens" +GET_RULES_URL = ( + "http://127.0.0.1:9101/v1/rules/" + "?include_attributes=ref,pack,description,enabled&limit=50" +) +GET_RULES_URL = GET_RULES_URL.replace(",", "%2C") class TestHttps(base.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(TestHttps, self).__init__(*args, **kwargs) self.shell = shell.Shell() @@ -46,11 +47,11 @@ def setUp(self): super(TestHttps, self).setUp() # Setup environment. - os.environ['ST2_BASE_URL'] = 'http://127.0.0.1' - os.environ['ST2_AUTH_URL'] = 'https://127.0.0.1:9100' + os.environ["ST2_BASE_URL"] = "http://127.0.0.1" + os.environ["ST2_AUTH_URL"] = "https://127.0.0.1:9100" - if 'ST2_CACERT' in os.environ: - del os.environ['ST2_CACERT'] + if "ST2_CACERT" in os.environ: + del os.environ["ST2_CACERT"] # Create a temp file to mock a cert file. self.cacert_fd, self.cacert_path = tempfile.mkstemp() @@ -59,58 +60,78 @@ def tearDown(self): super(TestHttps, self).tearDown() # Clean up environment. - if 'ST2_CACERT' in os.environ: - del os.environ['ST2_CACERT'] - if 'ST2_BASE_URL' in os.environ: - del os.environ['ST2_BASE_URL'] + if "ST2_CACERT" in os.environ: + del os.environ["ST2_CACERT"] + if "ST2_BASE_URL" in os.environ: + del os.environ["ST2_BASE_URL"] # Clean up temp files. os.close(self.cacert_fd) os.unlink(self.cacert_path) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_without_cacert(self): - self.shell.run(['auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': False, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + self.shell.run(["auth", USERNAME, "-p", PASSWORD]) + kwargs = {"verify": False, "headers": HEADERS, "auth": (USERNAME, PASSWORD)} requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_with_cacert_from_cli(self): - self.shell.run(['--cacert', self.cacert_path, 'auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + self.shell.run(["--cacert", self.cacert_path, "auth", USERNAME, "-p", PASSWORD]) + kwargs = { + "verify": self.cacert_path, + "headers": HEADERS, + "auth": (USERNAME, PASSWORD), + } requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_with_cacert_from_env(self): - os.environ['ST2_CACERT'] = self.cacert_path - self.shell.run(['auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + os.environ["ST2_CACERT"] = self.cacert_path + self.shell.run(["auth", USERNAME, "-p", PASSWORD]) + kwargs = { + "verify": self.cacert_path, + "headers": HEADERS, + "auth": (USERNAME, PASSWORD), + } requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK")), + ) def test_decorate_http_without_cacert(self): - self.shell.run(['rule', 'list']) + self.shell.run(["rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_http_with_cacert_from_cli(self): - self.shell.run(['--cacert', self.cacert_path, 'rule', 'list']) + self.shell.run(["--cacert", self.cacert_path, "rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_http_with_cacert_from_env(self): - os.environ['ST2_CACERT'] = self.cacert_path - self.shell.run(['rule', 'list']) + os.environ["ST2_CACERT"] = self.cacert_path + self.shell.run(["rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) diff --git a/st2client/tests/unit/test_trace_commands.py b/st2client/tests/unit/test_trace_commands.py index 99d60598a4..ea3b552d47 100644 --- a/st2client/tests/unit/test_trace_commands.py +++ b/st2client/tests/unit/test_trace_commands.py @@ -23,23 +23,38 @@ class TraceCommandTestCase(base.BaseCLITestCase): - def test_trace_get_filter_trace_components_executions(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', 'e1') - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", "e1") + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 1) @@ -48,22 +63,38 @@ def test_trace_get_filter_trace_components_executions(self): def test_trace_get_filter_trace_components_rules(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', None) - setattr(args, 'rule', 'r1') - setattr(args, 'trigger_instance', None) - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", None) + setattr(args, "rule", "r1") + setattr(args, "trigger_instance", None) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 0) @@ -72,22 +103,38 @@ def test_trace_get_filter_trace_components_rules(self): def test_trace_get_filter_trace_components_trigger_instances(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', None) - setattr(args, 'rule', None) - setattr(args, 'trigger_instance', 't1') - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", None) + setattr(args, "rule", None) + setattr(args, "trigger_instance", "t1") + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 0) @@ -96,15 +143,15 @@ def test_trace_get_filter_trace_components_trigger_instances(self): def test_trace_get_apply_display_filters_show_executions(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', True) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", True) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertTrue(trace.action_executions) @@ -113,15 +160,15 @@ def test_trace_get_apply_display_filters_show_executions(self): def test_trace_get_apply_display_filters_show_rules(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', True) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", True) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertFalse(trace.action_executions) @@ -130,15 +177,15 @@ def test_trace_get_apply_display_filters_show_rules(self): def test_trace_get_apply_display_filters_show_trigger_instances(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', True) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", True) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertFalse(trace.action_executions) @@ -147,15 +194,15 @@ def test_trace_get_apply_display_filters_show_trigger_instances(self): def test_trace_get_apply_display_filters_show_multiple(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', True) - setattr(args, 'show_rules', True) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", True) + setattr(args, "show_rules", True) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertTrue(trace.action_executions) @@ -164,15 +211,15 @@ def test_trace_get_apply_display_filters_show_multiple(self): def test_trace_get_apply_display_filters_show_all(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertEqual(len(trace.action_executions), 1) @@ -181,19 +228,35 @@ def test_trace_get_apply_display_filters_show_all(self): def test_trace_get_apply_display_filters_hide_noop(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', True) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", True) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertEqual(len(trace.action_executions), 1) diff --git a/st2client/tests/unit/test_util_date.py b/st2client/tests/unit/test_util_date.py index e29b840ed7..2cdeab95fc 100644 --- a/st2client/tests/unit/test_util_date.py +++ b/st2client/tests/unit/test_util_date.py @@ -30,31 +30,31 @@ def test_format_dt(self): dt = datetime.datetime(2015, 10, 20, 8, 0, 0) dt = add_utc_tz(dt) result = format_dt(dt) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") def test_format_isodate(self): # No timezone, defaults to UTC - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") # Timezone provided - value = 'Tue, 20 Oct 2015 08:00:00 UTC' - result = format_isodate(value=value, timezone='Europe/Ljubljana') - self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST') + value = "Tue, 20 Oct 2015 08:00:00 UTC" + result = format_isodate(value=value, timezone="Europe/Ljubljana") + self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST") - @mock.patch('st2client.utils.date.get_config') + @mock.patch("st2client.utils.date.get_config") def test_format_isodate_for_user_timezone(self, mock_get_config): # No timezone, defaults to UTC mock_get_config.return_value = {} - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate_for_user_timezone(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") # Timezone provided - mock_get_config.return_value = {'cli': {'timezone': 'Europe/Ljubljana'}} + mock_get_config.return_value = {"cli": {"timezone": "Europe/Ljubljana"}} - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate_for_user_timezone(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST') + self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST") diff --git a/st2client/tests/unit/test_util_json.py b/st2client/tests/unit/test_util_json.py index f44a4b9bf9..2333128c2e 100644 --- a/st2client/tests/unit/test_util_json.py +++ b/st2client/tests/unit/test_util_json.py @@ -25,76 +25,67 @@ LOG = logging.getLogger(__name__) DOC = { - 'a01': 1, - 'b01': 2, - 'c01': { - 'c11': 3, - 'd12': 4, - 'c13': { - 'c21': 5, - 'c22': 6 - }, - 'c14': [7, 8, 9] - } + "a01": 1, + "b01": 2, + "c01": {"c11": 3, "d12": 4, "c13": {"c21": 5, "c22": 6}, "c14": [7, 8, 9]}, } DOC_IP_ADDRESS = { - 'ips': { - "192.168.1.1": { - "hostname": "router.domain.tld" - }, - "192.168.1.10": { - "hostname": "server.domain.tld" - } + "ips": { + "192.168.1.1": {"hostname": "router.domain.tld"}, + "192.168.1.10": {"hostname": "server.domain.tld"}, } } class TestGetValue(unittest2.TestCase): - def test_dot_notation(self): - self.assertEqual(jsutil.get_value(DOC, 'a01'), 1) - self.assertEqual(jsutil.get_value(DOC, 'c01.c11'), 3) - self.assertEqual(jsutil.get_value(DOC, 'c01.c13.c22'), 6) - self.assertEqual(jsutil.get_value(DOC, 'c01.c13'), {'c21': 5, 'c22': 6}) - self.assertListEqual(jsutil.get_value(DOC, 'c01.c14'), [7, 8, 9]) + self.assertEqual(jsutil.get_value(DOC, "a01"), 1) + self.assertEqual(jsutil.get_value(DOC, "c01.c11"), 3) + self.assertEqual(jsutil.get_value(DOC, "c01.c13.c22"), 6) + self.assertEqual(jsutil.get_value(DOC, "c01.c13"), {"c21": 5, "c22": 6}) + self.assertListEqual(jsutil.get_value(DOC, "c01.c14"), [7, 8, 9]) def test_dot_notation_with_val_error(self): self.assertRaises(ValueError, jsutil.get_value, DOC, None) - self.assertRaises(ValueError, jsutil.get_value, DOC, '') - self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), 'a01') + self.assertRaises(ValueError, jsutil.get_value, DOC, "") + self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), "a01") def test_dot_notation_with_key_error(self): - self.assertIsNone(jsutil.get_value(DOC, 'd01')) - self.assertIsNone(jsutil.get_value(DOC, 'a01.a11')) - self.assertIsNone(jsutil.get_value(DOC, 'c01.c11.c21.c31')) - self.assertIsNone(jsutil.get_value(DOC, 'c01.c14.c31')) + self.assertIsNone(jsutil.get_value(DOC, "d01")) + self.assertIsNone(jsutil.get_value(DOC, "a01.a11")) + self.assertIsNone(jsutil.get_value(DOC, "c01.c11.c21.c31")) + self.assertIsNone(jsutil.get_value(DOC, "c01.c14.c31")) def test_ip_address(self): - self.assertEqual(jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'), - {"hostname": "router.domain.tld"}) + self.assertEqual( + jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'), + {"hostname": "router.domain.tld"}, + ) def test_chars_nums_dashes_underscores_calls_simple(self): - for char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_': + for char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_": with mock.patch("st2client.utils.jsutil._get_value_simple") as mock_simple: jsutil.get_value(DOC, char) mock_simple.assert_called_with(DOC, char) def test_symbols_calls_complex(self): - for char in '`~!@#$%^&&*()=+{}[]|\\;:\'"<>,./?': - with mock.patch("st2client.utils.jsutil._get_value_complex") as mock_complex: + for char in "`~!@#$%^&&*()=+{}[]|\\;:'\"<>,./?": + with mock.patch( + "st2client.utils.jsutil._get_value_complex" + ) as mock_complex: jsutil.get_value(DOC, char) mock_complex.assert_called_with(DOC, char) @mock.patch("st2client.utils.jsutil._get_value_simple") def test_single_key_calls_simple(self, mock__get_value_simple): - jsutil.get_value(DOC, 'a01') - mock__get_value_simple.assert_called_with(DOC, 'a01') + jsutil.get_value(DOC, "a01") + mock__get_value_simple.assert_called_with(DOC, "a01") @mock.patch("st2client.utils.jsutil._get_value_simple") def test_dot_notation_calls_simple(self, mock__get_value_simple): - jsutil.get_value(DOC, 'c01.c11') - mock__get_value_simple.assert_called_with(DOC, 'c01.c11') + jsutil.get_value(DOC, "c01.c11") + mock__get_value_simple.assert_called_with(DOC, "c01.c11") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_ip_address_calls_complex(self, mock__get_value_complex): @@ -103,54 +94,64 @@ def test_ip_address_calls_complex(self, mock__get_value_complex): @mock.patch("st2client.utils.jsutil._get_value_complex") def test_beginning_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, '.c01.c11') - mock__get_value_complex.assert_called_with(DOC, '.c01.c11') + jsutil.get_value(DOC, ".c01.c11") + mock__get_value_complex.assert_called_with(DOC, ".c01.c11") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_ending_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, 'c01.c11.') - mock__get_value_complex.assert_called_with(DOC, 'c01.c11.') + jsutil.get_value(DOC, "c01.c11.") + mock__get_value_complex.assert_called_with(DOC, "c01.c11.") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_double_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, 'c01..c11') - mock__get_value_complex.assert_called_with(DOC, 'c01..c11') + jsutil.get_value(DOC, "c01..c11") + mock__get_value_complex.assert_called_with(DOC, "c01..c11") class TestGetKeyValuePairs(unittest2.TestCase): - def test_select_kvps(self): - self.assertEqual(jsutil.get_kvps(DOC, ['a01']), - {'a01': 1}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11']), - {'c01': {'c11': 3}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13.c22']), - {'c01': {'c13': {'c22': 6}}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13']), - {'c01': {'c13': {'c21': 5, 'c22': 6}}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14']), - {'c01': {'c14': [7, 8, 9]}}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c21']), - {'a01': 1, 'c01': {'c11': 3, 'c13': {'c21': 5}}}) - self.assertEqual(jsutil.get_kvps(DOC_IP_ADDRESS, - ['ips."192.168.1.1"', - 'ips."192.168.1.10".hostname']), - {'ips': - {'"192': - {'168': - {'1': - {'1"': {'hostname': 'router.domain.tld'}, - '10"': {'hostname': 'server.domain.tld'}}}}}}) + self.assertEqual(jsutil.get_kvps(DOC, ["a01"]), {"a01": 1}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11"]), {"c01": {"c11": 3}}) + self.assertEqual( + jsutil.get_kvps(DOC, ["c01.c13.c22"]), {"c01": {"c13": {"c22": 6}}} + ) + self.assertEqual( + jsutil.get_kvps(DOC, ["c01.c13"]), {"c01": {"c13": {"c21": 5, "c22": 6}}} + ) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14"]), {"c01": {"c14": [7, 8, 9]}}) + self.assertEqual( + jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c21"]), + {"a01": 1, "c01": {"c11": 3, "c13": {"c21": 5}}}, + ) + self.assertEqual( + jsutil.get_kvps( + DOC_IP_ADDRESS, ['ips."192.168.1.1"', 'ips."192.168.1.10".hostname'] + ), + { + "ips": { + '"192': { + "168": { + "1": { + '1"': {"hostname": "router.domain.tld"}, + '10"': {"hostname": "server.domain.tld"}, + } + } + } + } + }, + ) def test_select_kvps_with_val_error(self): self.assertRaises(ValueError, jsutil.get_kvps, DOC, [None]) - self.assertRaises(ValueError, jsutil.get_kvps, DOC, ['']) - self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ['a01']) + self.assertRaises(ValueError, jsutil.get_kvps, DOC, [""]) + self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ["a01"]) def test_select_kvps_with_key_error(self): - self.assertEqual(jsutil.get_kvps(DOC, ['d01']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01.a11']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11.c21.c31']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14.c31']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c23']), - {'a01': 1, 'c01': {'c11': 3}}) + self.assertEqual(jsutil.get_kvps(DOC, ["d01"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["a01.a11"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11.c21.c31"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14.c31"]), {}) + self.assertEqual( + jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c23"]), + {"a01": 1, "c01": {"c11": 3}}, + ) diff --git a/st2client/tests/unit/test_util_misc.py b/st2client/tests/unit/test_util_misc.py index 6a2cf3a8fc..2e33156adc 100644 --- a/st2client/tests/unit/test_util_misc.py +++ b/st2client/tests/unit/test_util_misc.py @@ -21,37 +21,37 @@ class MiscUtilTestCase(unittest2.TestCase): def test_merge_dicts(self): - d1 = {'a': 1} - d2 = {'a': 2} - expected = {'a': 2} + d1 = {"a": 1} + d2 = {"a": 2} + expected = {"a": 2} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1} - d2 = {'b': 1} - expected = {'a': 1, 'b': 1} + d1 = {"a": 1} + d2 = {"b": 1} + expected = {"a": 1, "b": 1} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1} - d2 = {'a': 3, 'b': 1} - expected = {'a': 3, 'b': 1} + d1 = {"a": 1} + d2 = {"a": 3, "b": 1} + expected = {"a": 3, "b": 1} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1, 'm': None} - d2 = {'a': None, 'b': 1, 'c': None} - expected = {'a': 1, 'b': 1, 'c': None, 'm': None} + d1 = {"a": 1, "m": None} + d2 = {"a": None, "b": 1, "c": None} + expected = {"a": 1, "b": 1, "c": None, "m": None} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1, 'b': {'a': 1, 'b': 2, 'c': 3}} - d2 = {'b': {'b': 100}} - expected = {'a': 1, 'b': {'a': 1, 'b': 100, 'c': 3}} + d1 = {"a": 1, "b": {"a": 1, "b": 2, "c": 3}} + d2 = {"b": {"b": 100}} + expected = {"a": 1, "b": {"a": 1, "b": 100, "c": 3}} result = merge_dicts(d1, d2) self.assertEqual(result, expected) diff --git a/st2client/tests/unit/test_util_strutil.py b/st2client/tests/unit/test_util_strutil.py index 2d442013de..585e88c389 100644 --- a/st2client/tests/unit/test_util_strutil.py +++ b/st2client/tests/unit/test_util_strutil.py @@ -26,17 +26,17 @@ class StrUtilTestCase(unittest2.TestCase): def test_unescape(self): in_str = 'Action execution result double escape \\"stuffs\\".\\r\\n' - expected = 'Action execution result double escape \"stuffs\".\r\n' + expected = 'Action execution result double escape "stuffs".\r\n' out_str = strutil.unescape(in_str) self.assertEqual(out_str, expected) def test_unicode_string(self): - in_str = '\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55' + in_str = "\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55" out_str = strutil.unescape(in_str) self.assertEqual(out_str, in_str) def test_strip_carriage_returns(self): - in_str = 'Windows editors introduce\r\nlike a noob in 2017.' + in_str = "Windows editors introduce\r\nlike a noob in 2017." out_str = strutil.strip_carriage_returns(in_str) - exp_str = 'Windows editors introduce\nlike a noob in 2017.' + exp_str = "Windows editors introduce\nlike a noob in 2017." self.assertEqual(out_str, exp_str) diff --git a/st2client/tests/unit/test_util_terminal.py b/st2client/tests/unit/test_util_terminal.py index c9b6d82b27..29a8386b0b 100644 --- a/st2client/tests/unit/test_util_terminal.py +++ b/st2client/tests/unit/test_util_terminal.py @@ -23,20 +23,20 @@ from st2client.utils.terminal import DEFAULT_TERMINAL_SIZE_COLUMNS from st2client.utils.terminal import get_terminal_size_columns -__all__ = [ - 'TerminalUtilsTestCase' -] +__all__ = ["TerminalUtilsTestCase"] class TerminalUtilsTestCase(unittest2.TestCase): def setUp(self): super(TerminalUtilsTestCase, self).setUp() - if 'COLUMNS' in os.environ: - del os.environ['COLUMNS'] + if "COLUMNS" in os.environ: + del os.environ["COLUMNS"] - @mock.patch.dict(os.environ, {'LINES': '111', 'COLUMNS': '222'}) - def test_get_terminal_size_columns_columns_environment_variable_has_precedence(self): + @mock.patch.dict(os.environ, {"LINES": "111", "COLUMNS": "222"}) + def test_get_terminal_size_columns_columns_environment_variable_has_precedence( + self, + ): # Verify that COLUMNS environment variables has precedence over other approaches columns = get_terminal_size_columns() @@ -44,16 +44,16 @@ def test_get_terminal_size_columns_columns_environment_variable_has_precedence(s # make sure that os.environ['COLUMNS'] isn't set so it can't override/screw-up this test @mock.patch.dict(os.environ, {}) - @mock.patch('fcntl.ioctl', mock.Mock(return_value='dummy')) - @mock.patch('struct.unpack', mock.Mock(return_value=(333, 444))) + @mock.patch("fcntl.ioctl", mock.Mock(return_value="dummy")) + @mock.patch("struct.unpack", mock.Mock(return_value=(333, 444))) def test_get_terminal_size_columns_stdout_is_used(self): columns = get_terminal_size_columns() self.assertEqual(columns, 444) - @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a'))) - @mock.patch('subprocess.Popen') + @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a"))) + @mock.patch("subprocess.Popen") def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen): - mock_communicate = mock.Mock(return_value=['555 666']) + mock_communicate = mock.Mock(return_value=["555 666"]) mock_process = mock.Mock() mock_process.returncode = 0 @@ -64,8 +64,8 @@ def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen): columns = get_terminal_size_columns() self.assertEqual(columns, 666) - @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a'))) - @mock.patch('subprocess.Popen', mock.Mock(side_effect=Exception('b'))) + @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a"))) + @mock.patch("subprocess.Popen", mock.Mock(side_effect=Exception("b"))) def test_get_terminal_size_default_values_are_used(self): columns = get_terminal_size_columns() diff --git a/st2client/tests/unit/test_workflow.py b/st2client/tests/unit/test_workflow.py index 3896a27bc2..79d580f85d 100644 --- a/st2client/tests/unit/test_workflow.py +++ b/st2client/tests/unit/test_workflow.py @@ -31,13 +31,13 @@ LOG = logging.getLogger(__name__) MOCK_ACTION = { - 'ref': 'mock.foobar', - 'runner_type': 'mock-runner', - 'pack': 'mock', - 'name': 'foobar', - 'parameters': {}, - 'enabled': True, - 'entry_point': 'workflows/foobar.yaml' + "ref": "mock.foobar", + "runner_type": "mock-runner", + "pack": "mock", + "name": "foobar", + "parameters": {}, + "enabled": True, + "entry_point": "workflows/foobar.yaml", } MOCK_WF_DEF = """ @@ -56,73 +56,88 @@ def get_by_ref(**kwargs): class WorkflowCommandTestCase(st2cli_tests.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(WorkflowCommandTestCase, self).__init__(*args, **kwargs) self.shell = shell.Shell() @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_file(self): - fd, path = tempfile.mkstemp(suffix='.yaml') + fd, path = tempfile.mkstemp(suffix=".yaml") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(MOCK_WF_DEF) - retcode = self.shell.run(['workflow', 'inspect', '--file', path]) + retcode = self.shell.run(["workflow", "inspect", "--file", path]) self.assertEqual(retcode, 0) httpclient.HTTPClient.post_raw.assert_called_with( - '/inspect', - MOCK_WF_DEF, - headers={'content-type': 'text/plain'} + "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"} ) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_bad_file(self): - retcode = self.shell.run(['workflow', 'inspect', '--file', '/tmp/foobar']) + retcode = self.shell.run(["workflow", "inspect", "--file", "/tmp/foobar"]) self.assertEqual(retcode, 1) - self.assertIn('does not exist', self.stdout.getvalue()) + self.assertIn("does not exist", self.stdout.getvalue()) self.assertFalse(httpclient.HTTPClient.post_raw.called) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - workflow.WorkflowInspectionCommand, 'get_file_content', - mock.MagicMock(return_value=MOCK_WF_DEF)) + workflow.WorkflowInspectionCommand, + "get_file_content", + mock.MagicMock(return_value=MOCK_WF_DEF), + ) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_action(self): - retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar']) + retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"]) self.assertEqual(retcode, 0) httpclient.HTTPClient.post_raw.assert_called_with( - '/inspect', - MOCK_WF_DEF, - headers={'content-type': 'text/plain'} + "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"} ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=None)) + models.ResourceManager, "get_by_ref_or_id", mock.MagicMock(return_value=None) + ) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_bad_action(self): - retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar']) + retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"]) self.assertEqual(retcode, 1) - self.assertIn('Unable to identify action', self.stdout.getvalue()) + self.assertIn("Unable to identify action", self.stdout.getvalue()) self.assertFalse(httpclient.HTTPClient.post_raw.called) diff --git a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py index db5cbdcf67..b8de86a661 100755 --- a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py +++ b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py @@ -35,16 +35,20 @@ def migrate_datastore(): try: for kvp in key_value_items: - kvp_id = getattr(kvp, 'id', None) - secret = getattr(kvp, 'secret', False) - scope = getattr(kvp, 'scope', SYSTEM_SCOPE) - new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name, - expire_timestamp=kvp.expire_timestamp, - value=kvp.value, secret=secret, - scope=scope) + kvp_id = getattr(kvp, "id", None) + secret = getattr(kvp, "secret", False) + scope = getattr(kvp, "scope", SYSTEM_SCOPE) + new_kvp_db = KeyValuePairDB( + id=kvp_id, + name=kvp.name, + expire_timestamp=kvp.expire_timestamp, + value=kvp.value, + secret=secret, + scope=scope, + ) KeyValuePair.add_or_update(new_kvp_db) except: - print('ERROR: Failed migrating datastore item with name: %s' % kvp.name) + print("ERROR: Failed migrating datastore item with name: %s" % kvp.name) tb.print_exc() raise @@ -58,10 +62,10 @@ def main(): # Migrate rules. try: migrate_datastore() - print('SUCCESS: Datastore items migrated successfully.') + print("SUCCESS: Datastore items migrated successfully.") exit_code = 0 except: - print('ABORTED: Datastore migration aborted on first failure.') + print("ABORTED: Datastore migration aborted on first failure.") exit_code = 1 # Disconnect from db. @@ -69,5 +73,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py index 24275f80dc..a1a500ad96 100755 --- a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py +++ b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py @@ -32,9 +32,9 @@ def migrate_datastore(): try: for kvp in key_value_items: - kvp_id = getattr(kvp, 'id', None) - secret = getattr(kvp, 'secret', False) - scope = getattr(kvp, 'scope', SYSTEM_SCOPE) + kvp_id = getattr(kvp, "id", None) + secret = getattr(kvp, "secret", False) + scope = getattr(kvp, "scope", SYSTEM_SCOPE) if scope == USER_SCOPE: scope = FULL_USER_SCOPE @@ -42,13 +42,17 @@ def migrate_datastore(): if scope == SYSTEM_SCOPE: scope = FULL_SYSTEM_SCOPE - new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name, - expire_timestamp=kvp.expire_timestamp, - value=kvp.value, secret=secret, - scope=scope) + new_kvp_db = KeyValuePairDB( + id=kvp_id, + name=kvp.name, + expire_timestamp=kvp.expire_timestamp, + value=kvp.value, + secret=secret, + scope=scope, + ) KeyValuePair.add_or_update(new_kvp_db) except: - print('ERROR: Failed migrating datastore item with name: %s' % kvp.name) + print("ERROR: Failed migrating datastore item with name: %s" % kvp.name) tb.print_exc() raise @@ -62,10 +66,10 @@ def main(): # Migrate rules. try: migrate_datastore() - print('SUCCESS: Datastore items migrated successfully.') + print("SUCCESS: Datastore items migrated successfully.") exit_code = 0 except: - print('ABORTED: Datastore migration aborted on first failure.') + print("ABORTED: Datastore migration aborted on first failure.") exit_code = 1 # Disconnect from db. @@ -73,5 +77,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py index bb4ee666b9..9d09789413 100755 --- a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py +++ b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py @@ -39,12 +39,14 @@ def main(): try: handler = scheduler_handler.get_handler() handler._cleanup_policy_delayed() - LOG.info('SUCCESS: Completed clean up of executions with deprecated policy-delayed status.') + LOG.info( + "SUCCESS: Completed clean up of executions with deprecated policy-delayed status." + ) exit_code = 0 except Exception as e: LOG.error( - 'ABORTED: Clean up of executions with deprecated policy-delayed status aborted on ' - 'first failure. %s' % e.message + "ABORTED: Clean up of executions with deprecated policy-delayed status aborted on " + "first failure. %s" % e.message ) exit_code = 1 @@ -53,5 +55,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/paramiko_ssh_evenlets_tester.py b/st2common/bin/paramiko_ssh_evenlets_tester.py index 49a42545f8..af30196de1 100755 --- a/st2common/bin/paramiko_ssh_evenlets_tester.py +++ b/st2common/bin/paramiko_ssh_evenlets_tester.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import argparse @@ -34,49 +35,54 @@ def main(user, pkey, password, hosts_str, cmd, file_path, dir_path, delete_dir): if file_path: if not os.path.exists(file_path): - raise Exception('File not found.') - results = client.put(file_path, '/home/lakshmi/test_file', mode="0660") - pp.pprint('Copy results: \n%s' % results) - results = client.run('ls -rlth') - pp.pprint('ls results: \n%s' % results) + raise Exception("File not found.") + results = client.put(file_path, "/home/lakshmi/test_file", mode="0660") + pp.pprint("Copy results: \n%s" % results) + results = client.run("ls -rlth") + pp.pprint("ls results: \n%s" % results) if dir_path: if not os.path.exists(dir_path): - raise Exception('File not found.') - results = client.put(dir_path, '/home/lakshmi/', mode="0660") - pp.pprint('Copy results: \n%s' % results) - results = client.run('ls -rlth') - pp.pprint('ls results: \n%s' % results) + raise Exception("File not found.") + results = client.put(dir_path, "/home/lakshmi/", mode="0660") + pp.pprint("Copy results: \n%s" % results) + results = client.run("ls -rlth") + pp.pprint("ls results: \n%s" % results) if cmd: results = client.run(cmd) - pp.pprint('cmd results: \n%s' % results) + pp.pprint("cmd results: \n%s" % results) if delete_dir: results = client.delete_dir(delete_dir, force=True) - pp.pprint('Delete results: \n%s' % results) + pp.pprint("Delete results: \n%s" % results) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Parallel SSH tester.') - parser.add_argument('--hosts', required=True, - help='List of hosts to connect to') - parser.add_argument('--private-key', required=False, - help='Private key to use.') - parser.add_argument('--password', required=False, - help='Password.') - parser.add_argument('--user', required=True, - help='SSH user name.') - parser.add_argument('--cmd', required=False, - help='Command to run on host.') - parser.add_argument('--file', required=False, - help='Path of file to copy to remote host.') - parser.add_argument('--dir', required=False, - help='Path of dir to copy to remote host.') - parser.add_argument('--delete-dir', required=False, - help='Path of dir to delete on remote host.') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Parallel SSH tester.") + parser.add_argument("--hosts", required=True, help="List of hosts to connect to") + parser.add_argument("--private-key", required=False, help="Private key to use.") + parser.add_argument("--password", required=False, help="Password.") + parser.add_argument("--user", required=True, help="SSH user name.") + parser.add_argument("--cmd", required=False, help="Command to run on host.") + parser.add_argument( + "--file", required=False, help="Path of file to copy to remote host." + ) + parser.add_argument( + "--dir", required=False, help="Path of dir to copy to remote host." + ) + parser.add_argument( + "--delete-dir", required=False, help="Path of dir to delete on remote host." + ) args = parser.parse_args() - main(user=args.user, pkey=args.private_key, password=args.password, - hosts_str=args.hosts, cmd=args.cmd, - file_path=args.file, dir_path=args.dir, delete_dir=args.delete_dir) + main( + user=args.user, + pkey=args.private_key, + password=args.password, + hosts_str=args.hosts, + cmd=args.cmd, + file_path=args.file, + dir_path=args.dir, + delete_dir=args.delete_dir, + ) diff --git a/st2common/dist_utils.py b/st2common/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2common/dist_utils.py +++ b/st2common/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2common/setup.py b/st2common/setup.py index f68679af8c..908884260d 100644 --- a/st2common/setup.py +++ b/st2common/setup.py @@ -23,10 +23,10 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -ST2_COMPONENT = 'st2common' +ST2_COMPONENT = "st2common" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -INIT_FILE = os.path.join(BASE_DIR, 'st2common/__init__.py') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +INIT_FILE = os.path.join(BASE_DIR, "st2common/__init__.py") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -34,41 +34,43 @@ setup( name=ST2_COMPONENT, version=get_version_string(INIT_FILE), - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2-bootstrap-rmq', - 'bin/st2-cleanup-db', - 'bin/st2-register-content', - 'bin/st2-purge-executions', - 'bin/st2-purge-trigger-instances', - 'bin/st2-run-pack-tests', - 'bin/st2ctl', - 'bin/st2-generate-symmetric-crypto-key', - 'bin/st2-self-check', - 'bin/st2-track-result', - 'bin/st2-validate-pack-config', - 'bin/st2-pack-install', - 'bin/st2-pack-download', - 'bin/st2-pack-setup-virtualenv' + "bin/st2-bootstrap-rmq", + "bin/st2-cleanup-db", + "bin/st2-register-content", + "bin/st2-purge-executions", + "bin/st2-purge-trigger-instances", + "bin/st2-run-pack-tests", + "bin/st2ctl", + "bin/st2-generate-symmetric-crypto-key", + "bin/st2-self-check", + "bin/st2-track-result", + "bin/st2-validate-pack-config", + "bin/st2-pack-install", + "bin/st2-pack-download", + "bin/st2-pack-setup-virtualenv", ], entry_points={ - 'st2common.metrics.driver': [ - 'statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver', - 'noop = st2common.metrics.drivers.noop_driver:NoopDriver', - 'echo = st2common.metrics.drivers.echo_driver:EchoDriver' + "st2common.metrics.driver": [ + "statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver", + "noop = st2common.metrics.drivers.noop_driver:NoopDriver", + "echo = st2common.metrics.drivers.echo_driver:EchoDriver", ], - 'st2common.rbac.backend': [ - 'noop = st2common.rbac.backends.noop:NoOpRBACBackend' + "st2common.rbac.backend": [ + "noop = st2common.rbac.backends.noop:NoOpRBACBackend" ], - } + }, ) diff --git a/st2common/st2common/__init__.py b/st2common/st2common/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2common/st2common/__init__.py +++ b/st2common/st2common/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2common/st2common/bootstrap/actionsregistrar.py b/st2common/st2common/bootstrap/actionsregistrar.py index c21788fb14..f5265bac48 100644 --- a/st2common/st2common/bootstrap/actionsregistrar.py +++ b/st2common/st2common/bootstrap/actionsregistrar.py @@ -30,10 +30,7 @@ import st2common.util.action_db as action_utils import st2common.validators.api.action as action_validator -__all__ = [ - 'ActionsRegistrar', - 'register_actions' -] +__all__ = ["ActionsRegistrar", "register_actions"] LOG = logging.getLogger(__name__) @@ -53,15 +50,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='actions') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="actions" + ) for pack, actions_dir in six.iteritems(content): if not actions_dir: - LOG.debug('Pack %s does not contain actions.', pack) + LOG.debug("Pack %s does not contain actions.", pack) continue try: - LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir) + LOG.debug( + "Registering actions from pack %s:, dir: %s", pack, actions_dir + ) actions = self._get_actions_from_pack(actions_dir) count = self._register_actions_from_pack(pack=pack, actions=actions) registered_count += count @@ -69,7 +69,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all actions from pack: %s', actions_dir) + LOG.exception( + "Failed registering all actions from pack: %s", actions_dir + ) return registered_count @@ -80,10 +82,11 @@ def register_from_pack(self, pack_dir): :return: Number of actions registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - actions_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='actions') + actions_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="actions" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -92,16 +95,18 @@ def register_from_pack(self, pack_dir): if not actions_dir: return registered_count - LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir) + LOG.debug("Registering actions from pack %s:, dir: %s", pack, actions_dir) try: actions = self._get_actions_from_pack(actions_dir=actions_dir) - registered_count = self._register_actions_from_pack(pack=pack, actions=actions) + registered_count = self._register_actions_from_pack( + pack=pack, actions=actions + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all actions from pack: %s', actions_dir) + LOG.exception("Failed registering all actions from pack: %s", actions_dir) return registered_count @@ -109,29 +114,33 @@ def _get_actions_from_pack(self, actions_dir): actions = self.get_resources_from_pack(resources_dir=actions_dir) # Exclude global actions configuration file - config_files = ['actions/config' + ext for ext in self.ALLOWED_EXTENSIONS] + config_files = ["actions/config" + ext for ext in self.ALLOWED_EXTENSIONS] for config_file in config_files: - actions = [file_path for file_path in actions if config_file not in file_path] + actions = [ + file_path for file_path in actions if config_file not in file_path + ] return actions def _register_action(self, pack, action): content = self._meta_loader.load(action) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=action, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=action, use_pack_cache=True + ) + content["metadata_file"] = metadata_file action_api = ActionAPI(**content) @@ -141,25 +150,29 @@ def _register_action(self, pack, action): # We throw a more user-friendly exception on invalid parameter name msg = six.text_type(e) - is_invalid_parameter_name = 'does not match any of the regexes: ' in msg + is_invalid_parameter_name = "does not match any of the regexes: " in msg if is_invalid_parameter_name: - match = re.search('\'(.+?)\' does not match any of the regexes', msg) + match = re.search("'(.+?)' does not match any of the regexes", msg) if match: parameter_name = match.groups()[0] else: - parameter_name = 'unknown' + parameter_name = "unknown" - new_msg = ('Parameter name "%s" is invalid. Valid characters for parameter name ' - 'are [a-zA-Z0-0_].' % (parameter_name)) - new_msg += '\n\n' + msg + new_msg = ( + 'Parameter name "%s" is invalid. Valid characters for parameter name ' + "are [a-zA-Z0-0_]." % (parameter_name) + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) raise e # Use in-memory cached RunnerTypeDB objects to reduce load on the database if self._use_runners_cache: - runner_type_db = self._runner_type_db_cache.get(action_api.runner_type, None) + runner_type_db = self._runner_type_db_cache.get( + action_api.runner_type, None + ) if not runner_type_db: runner_type_db = action_validator.get_runner_model(action_api) @@ -170,36 +183,47 @@ def _register_action(self, pack, action): action_validator.validate_action(action_api, runner_type_db=runner_type_db) model = ActionAPI.to_model(action_api) - action_ref = ResourceReference.to_string_reference(pack=pack, name=str(content['name'])) + action_ref = ResourceReference.to_string_reference( + pack=pack, name=str(content["name"]) + ) existing = action_utils.get_action_by_ref(action_ref) if not existing: - LOG.debug('Action %s not found. Creating new one with: %s', action_ref, content) + LOG.debug( + "Action %s not found. Creating new one with: %s", action_ref, content + ) else: - LOG.debug('Action %s found. Will be updated from: %s to: %s', - action_ref, existing, model) + LOG.debug( + "Action %s found. Will be updated from: %s to: %s", + action_ref, + existing, + model, + ) model.id = existing.id try: model = Action.add_or_update(model) - extra = {'action_db': model} - LOG.audit('Action updated. Action %s from %s.', model, action, extra=extra) + extra = {"action_db": model} + LOG.audit("Action updated. Action %s from %s.", model, action, extra=extra) except Exception: - LOG.exception('Failed to write action to db %s.', model.name) + LOG.exception("Failed to write action to db %s.", model.name) raise def _register_actions_from_pack(self, pack, actions): registered_count = 0 for action in actions: try: - LOG.debug('Loading action from %s.', action) + LOG.debug("Loading action from %s.", action) self._register_action(pack=pack, action=action) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register action "%s" from pack "%s": %s' % (action, pack, - six.text_type(e))) + msg = 'Failed to register action "%s" from pack "%s": %s' % ( + action, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register action: %s', action) + LOG.exception("Unable to register action: %s", action) continue else: registered_count += 1 @@ -207,16 +231,18 @@ def _register_actions_from_pack(self, pack, actions): return registered_count -def register_actions(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_actions( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = ActionsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = ActionsRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/aliasesregistrar.py b/st2common/st2common/bootstrap/aliasesregistrar.py index dbc9c3b0fc..c9d4ef7017 100644 --- a/st2common/st2common/bootstrap/aliasesregistrar.py +++ b/st2common/st2common/bootstrap/aliasesregistrar.py @@ -27,10 +27,7 @@ from st2common.persistence.actionalias import ActionAlias from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'AliasesRegistrar', - 'register_aliases' -] +__all__ = ["AliasesRegistrar", "register_aliases"] LOG = logging.getLogger(__name__) @@ -50,15 +47,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='aliases') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="aliases" + ) for pack, aliases_dir in six.iteritems(content): if not aliases_dir: - LOG.debug('Pack %s does not contain aliases.', pack) + LOG.debug("Pack %s does not contain aliases.", pack) continue try: - LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir) + LOG.debug( + "Registering aliases from pack %s:, dir: %s", pack, aliases_dir + ) aliases = self._get_aliases_from_pack(aliases_dir) count = self._register_aliases_from_pack(pack=pack, aliases=aliases) registered_count += count @@ -66,7 +66,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all aliases from pack: %s', aliases_dir) + LOG.exception( + "Failed registering all aliases from pack: %s", aliases_dir + ) return registered_count @@ -77,10 +79,11 @@ def register_from_pack(self, pack_dir): :return: Number of aliases registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - aliases_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='aliases') + aliases_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="aliases" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -89,16 +92,18 @@ def register_from_pack(self, pack_dir): if not aliases_dir: return registered_count - LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir) + LOG.debug("Registering aliases from pack %s:, dir: %s", pack, aliases_dir) try: aliases = self._get_aliases_from_pack(aliases_dir=aliases_dir) - registered_count = self._register_aliases_from_pack(pack=pack, aliases=aliases) + registered_count = self._register_aliases_from_pack( + pack=pack, aliases=aliases + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all aliases from pack: %s', aliases_dir) + LOG.exception("Failed registering all aliases from pack: %s", aliases_dir) return registered_count return registered_count @@ -106,7 +111,9 @@ def register_from_pack(self, pack_dir): def _get_aliases_from_pack(self, aliases_dir): return self.get_resources_from_pack(resources_dir=aliases_dir) - def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=False): + def _get_action_alias_db( + self, pack, action_alias, ignore_metadata_file_error=False + ): """ Retrieve ActionAliasDB object. @@ -115,25 +122,27 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa :type ignore_metadata_file_error: ``bool`` """ content = self._meta_loader.load(action_alias) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory try: - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=action_alias, - use_pack_cache=True) + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=action_alias, use_pack_cache=True + ) except ValueError as e: if not ignore_metadata_file_error: raise e else: - content['metadata_file'] = metadata_file + content["metadata_file"] = metadata_file action_alias_api = ActionAliasAPI(**content) action_alias_api.validate() @@ -142,28 +151,35 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa return action_alias_db def _register_action_alias(self, pack, action_alias): - action_alias_db = self._get_action_alias_db(pack=pack, - action_alias=action_alias) + action_alias_db = self._get_action_alias_db( + pack=pack, action_alias=action_alias + ) try: action_alias_db.id = ActionAlias.get_by_name(action_alias_db.name).id except StackStormDBObjectNotFoundError: - LOG.debug('ActionAlias %s not found. Creating new one.', action_alias) + LOG.debug("ActionAlias %s not found. Creating new one.", action_alias) action_ref = action_alias_db.action_ref action_db = Action.get_by_ref(action_ref) if not action_db: - LOG.warning('Action %s not found in DB. Did you forget to register the action?', - action_ref) + LOG.warning( + "Action %s not found in DB. Did you forget to register the action?", + action_ref, + ) try: action_alias_db = ActionAlias.add_or_update(action_alias_db) - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias updated. Action alias %s from %s.', action_alias_db, - action_alias, extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias updated. Action alias %s from %s.", + action_alias_db, + action_alias, + extra=extra, + ) except Exception: - LOG.exception('Failed to create action alias %s.', action_alias_db.name) + LOG.exception("Failed to create action alias %s.", action_alias_db.name) raise def _register_aliases_from_pack(self, pack, aliases): @@ -171,15 +187,18 @@ def _register_aliases_from_pack(self, pack, aliases): for alias in aliases: try: - LOG.debug('Loading alias from %s.', alias) + LOG.debug("Loading alias from %s.", alias) self._register_action_alias(pack, alias) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register alias "%s" from pack "%s": %s' % (alias, pack, - six.text_type(e))) + msg = 'Failed to register alias "%s" from pack "%s": %s' % ( + alias, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register alias: %s', alias) + LOG.exception("Unable to register alias: %s", alias) continue else: registered_count += 1 @@ -187,8 +206,9 @@ def _register_aliases_from_pack(self, pack, aliases): return registered_count -def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_aliases( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) @@ -196,8 +216,9 @@ def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True, if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = AliasesRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = AliasesRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/base.py b/st2common/st2common/bootstrap/base.py index 1757a2fa8e..1070a3af38 100644 --- a/st2common/st2common/bootstrap/base.py +++ b/st2common/st2common/bootstrap/base.py @@ -32,9 +32,7 @@ from st2common.util.pack import get_pack_ref_from_metadata from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ResourceRegistrar' -] +__all__ = ["ResourceRegistrar"] LOG = logging.getLogger(__name__) @@ -44,16 +42,15 @@ # a long running process. REGISTERED_PACKS_CACHE = {} -EXCLUDE_FILE_PATTERNS = [ - '*.pyc', - '.git/*' -] +EXCLUDE_FILE_PATTERNS = ["*.pyc", ".git/*"] class ResourceRegistrar(object): ALLOWED_EXTENSIONS = [] - def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False): + def __init__( + self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False + ): """ :param use_pack_cache: True to cache which packs have been registered in memory and making sure packs are only registered once. @@ -81,10 +78,10 @@ def get_resources_from_pack(self, resources_dir): for ext in self.ALLOWED_EXTENSIONS: resources_glob = resources_dir - if resources_dir.endswith('/'): + if resources_dir.endswith("/"): resources_glob = resources_dir + ext else: - resources_glob = resources_dir + '/*' + ext + resources_glob = resources_dir + "/*" + ext resource_files = glob.glob(resources_glob) resources.extend(resource_files) @@ -121,7 +118,7 @@ def register_pack(self, pack_name, pack_dir): # This pack has already been registered during this register content run return - LOG.debug('Registering pack: %s' % (pack_name)) + LOG.debug("Registering pack: %s" % (pack_name)) REGISTERED_PACKS_CACHE[pack_name] = True try: @@ -148,19 +145,26 @@ def _register_pack(self, pack_name, pack_dir): # Display a warning if pack contains deprecated config.yaml file. Support for those files # will be fully removed in v2.4.0. - config_path = os.path.join(pack_dir, 'config.yaml') + config_path = os.path.join(pack_dir, "config.yaml") if os.path.isfile(config_path): - LOG.error('Pack "%s" contains a deprecated config.yaml file (%s). ' - 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 ' - 'in favor of config.schema.yaml config schema files and config files in ' - '/opt/stackstorm/configs/ directory. Support for config.yaml files has ' - 'been removed in the release (v2.4.0) so please migrate. For more ' - 'information please refer to %s ' % (pack_db.name, config_path, - 'https://docs.stackstorm.com/reference/pack_configs.html')) + LOG.error( + 'Pack "%s" contains a deprecated config.yaml file (%s). ' + 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 ' + "in favor of config.schema.yaml config schema files and config files in " + "/opt/stackstorm/configs/ directory. Support for config.yaml files has " + "been removed in the release (v2.4.0) so please migrate. For more " + "information please refer to %s " + % ( + pack_db.name, + config_path, + "https://docs.stackstorm.com/reference/pack_configs.html", + ) + ) # 2. Register corresponding pack config schema - config_schema_db = self._register_pack_config_schema_db(pack_name=pack_name, - pack_dir=pack_dir) + config_schema_db = self._register_pack_config_schema_db( + pack_name=pack_name, pack_dir=pack_dir + ) return pack_db, config_schema_db @@ -173,25 +177,28 @@ def _register_pack_db(self, pack_name, pack_dir): # 2hich are in sub-directories) # 2. If attribute is not available, but pack name is and pack name meets the valid name # criteria, we use that - content['ref'] = get_pack_ref_from_metadata(metadata=content, - pack_directory_name=pack_name) + content["ref"] = get_pack_ref_from_metadata( + metadata=content, pack_directory_name=pack_name + ) # Include a list of pack files - pack_file_list = get_file_list(directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS) - content['files'] = pack_file_list - content['path'] = pack_dir + pack_file_list = get_file_list( + directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS + ) + content["files"] = pack_file_list + content["path"] = pack_dir pack_api = PackAPI(**content) pack_api.validate() pack_db = PackAPI.to_model(pack_api) try: - pack_db.id = Pack.get_by_ref(content['ref']).id + pack_db.id = Pack.get_by_ref(content["ref"]).id except StackStormDBObjectNotFoundError: - LOG.debug('Pack %s not found. Creating new one.', pack_name) + LOG.debug("Pack %s not found. Creating new one.", pack_name) pack_db = Pack.add_or_update(pack_db) - LOG.debug('Pack %s registered.' % (pack_name)) + LOG.debug("Pack %s registered." % (pack_name)) return pack_db def _register_pack_config_schema_db(self, pack_name, pack_dir): @@ -204,11 +211,13 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir): values = self._meta_loader.load(config_schema_path) if not values: - raise ValueError('Config schema "%s" is empty and invalid.' % (config_schema_path)) + raise ValueError( + 'Config schema "%s" is empty and invalid.' % (config_schema_path) + ) content = {} - content['pack'] = pack_name - content['attributes'] = values + content["pack"] = pack_name + content["attributes"] = values config_schema_api = ConfigSchemaAPI(**content) config_schema_api = config_schema_api.validate() @@ -217,8 +226,10 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir): try: config_schema_db.id = ConfigSchema.get_by_pack(pack_name).id except StackStormDBObjectNotFoundError: - LOG.debug('Config schema for pack %s not found. Creating new one.', pack_name) + LOG.debug( + "Config schema for pack %s not found. Creating new one.", pack_name + ) config_schema_db = ConfigSchema.add_or_update(config_schema_db) - LOG.debug('Config schema for pack %s registered.' % (pack_name)) + LOG.debug("Config schema for pack %s registered." % (pack_name)) return config_schema_db diff --git a/st2common/st2common/bootstrap/configsregistrar.py b/st2common/st2common/bootstrap/configsregistrar.py index fc7e05eb98..3cbc5283fc 100644 --- a/st2common/st2common/bootstrap/configsregistrar.py +++ b/st2common/st2common/bootstrap/configsregistrar.py @@ -28,9 +28,7 @@ from st2common.persistence.pack import Config from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ConfigsRegistrar' -] +__all__ = ["ConfigsRegistrar"] LOG = logging.getLogger(__name__) @@ -44,11 +42,18 @@ class ConfigsRegistrar(ResourceRegistrar): ALLOWED_EXTENSIONS = ALLOWED_EXTS - def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False, - validate_configs=True): - super(ConfigsRegistrar, self).__init__(use_pack_cache=use_pack_cache, - use_runners_cache=use_runners_cache, - fail_on_failure=fail_on_failure) + def __init__( + self, + use_pack_cache=True, + use_runners_cache=False, + fail_on_failure=False, + validate_configs=True, + ): + super(ConfigsRegistrar, self).__init__( + use_pack_cache=use_pack_cache, + use_runners_cache=use_runners_cache, + fail_on_failure=fail_on_failure, + ) self._validate_configs = validate_configs @@ -68,21 +73,29 @@ def register_from_packs(self, base_dirs): if not os.path.isfile(config_path): # Config for that pack doesn't exist - LOG.debug('No config found for pack "%s" (file "%s" is not present).', pack_name, - config_path) + LOG.debug( + 'No config found for pack "%s" (file "%s" is not present).', + pack_name, + config_path, + ) continue try: self._register_config_for_pack(pack=pack_name, config_path=config_path) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register config "%s" for pack "%s": %s' % (config_path, - pack_name, - six.text_type(e))) + msg = 'Failed to register config "%s" for pack "%s": %s' % ( + config_path, + pack_name, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Failed to register config for pack "%s": %s', pack_name, - six.text_type(e)) + LOG.exception( + 'Failed to register config for pack "%s": %s', + pack_name, + six.text_type(e), + ) else: registered_count += 1 @@ -92,7 +105,7 @@ def register_from_pack(self, pack_dir): """ Register config for a provided pack. """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack_name = os.path.split(pack_dir) # Register pack first @@ -106,8 +119,8 @@ def register_from_pack(self, pack_dir): return 1 def _get_config_path_for_pack(self, pack_name): - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % (pack_name)) + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % (pack_name)) return config_path @@ -115,8 +128,8 @@ def _register_config_for_pack(self, pack, config_path): content = {} values = self._meta_loader.load(config_path) - content['pack'] = pack - content['values'] = values + content["pack"] = pack + content["values"] = values config_api = ConfigAPI(**content) config_api.validate(validate_against_schema=self._validate_configs) @@ -136,17 +149,22 @@ def save_model(config_api): try: config_db = Config.add_or_update(config_db) - extra = {'config_db': config_db} + extra = {"config_db": config_db} LOG.audit('Config for pack "%s" is updated.', config_db.pack, extra=extra) except Exception: - LOG.exception('Failed to save config for pack %s.', pack) + LOG.exception("Failed to save config for pack %s.", pack) raise return config_db -def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False, validate_configs=True): +def register_configs( + packs_base_paths=None, + pack_dir=None, + use_pack_cache=True, + fail_on_failure=False, + validate_configs=True, +): if packs_base_paths: assert isinstance(packs_base_paths, list) @@ -154,9 +172,11 @@ def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True, if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = ConfigsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure, - validate_configs=validate_configs) + registrar = ConfigsRegistrar( + use_pack_cache=use_pack_cache, + fail_on_failure=fail_on_failure, + validate_configs=validate_configs, + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/policiesregistrar.py b/st2common/st2common/bootstrap/policiesregistrar.py index b963eaf097..4f6f247694 100644 --- a/st2common/st2common/bootstrap/policiesregistrar.py +++ b/st2common/st2common/bootstrap/policiesregistrar.py @@ -30,11 +30,7 @@ from st2common.util import loader -__all__ = [ - 'PolicyRegistrar', - 'register_policy_types', - 'register_policies' -] +__all__ = ["PolicyRegistrar", "register_policy_types", "register_policies"] LOG = logging.getLogger(__name__) @@ -55,15 +51,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='policies') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="policies" + ) for pack, policies_dir in six.iteritems(content): if not policies_dir: - LOG.debug('Pack %s does not contain policies.', pack) + LOG.debug("Pack %s does not contain policies.", pack) continue try: - LOG.debug('Registering policies from pack %s:, dir: %s', pack, policies_dir) + LOG.debug( + "Registering policies from pack %s:, dir: %s", pack, policies_dir + ) policies = self._get_policies_from_pack(policies_dir) count = self._register_policies_from_pack(pack=pack, policies=policies) registered_count += count @@ -71,7 +70,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all policies from pack: %s', policies_dir) + LOG.exception( + "Failed registering all policies from pack: %s", policies_dir + ) return registered_count @@ -82,11 +83,12 @@ def register_from_pack(self, pack_dir): :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - policies_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='policies') + policies_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="policies" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -95,16 +97,18 @@ def register_from_pack(self, pack_dir): if not policies_dir: return registered_count - LOG.debug('Registering policies from pack %s, dir: %s', pack, policies_dir) + LOG.debug("Registering policies from pack %s, dir: %s", pack, policies_dir) try: policies = self._get_policies_from_pack(policies_dir=policies_dir) - registered_count = self._register_policies_from_pack(pack=pack, policies=policies) + registered_count = self._register_policies_from_pack( + pack=pack, policies=policies + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all policies from pack: %s', policies_dir) + LOG.exception("Failed registering all policies from pack: %s", policies_dir) return registered_count return registered_count @@ -117,15 +121,18 @@ def _register_policies_from_pack(self, pack, policies): for policy in policies: try: - LOG.debug('Loading policy from %s.', policy) + LOG.debug("Loading policy from %s.", policy) self._register_policy(pack=pack, policy=policy) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register policy "%s" from pack "%s": %s' % (policy, pack, - six.text_type(e))) + msg = 'Failed to register policy "%s" from pack "%s": %s' % ( + policy, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register policy: %s', policy) + LOG.exception("Unable to register policy: %s", policy) continue else: registered_count += 1 @@ -134,20 +141,22 @@ def _register_policies_from_pack(self, pack, policies): def _register_policy(self, pack, policy): content = self._meta_loader.load(policy) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=policy, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=policy, use_pack_cache=True + ) + content["metadata_file"] = metadata_file policy_api = PolicyAPI(**content) policy_api = policy_api.validate() @@ -160,21 +169,21 @@ def _register_policy(self, pack, policy): try: policy_db = Policy.add_or_update(policy_db) - extra = {'policy_db': policy_db} + extra = {"policy_db": policy_db} LOG.audit('Policy "%s" is updated.', policy_db.ref, extra=extra) except Exception: - LOG.exception('Failed to create policy %s.', policy_api.name) + LOG.exception("Failed to create policy %s.", policy_api.name) raise def register_policy_types(module): registered_count = 0 mod_path = os.path.dirname(os.path.realpath(sys.modules[module.__name__].__file__)) - path = os.path.join(mod_path, 'policies/meta') + path = os.path.join(mod_path, "policies/meta") files = [] for ext in ALLOWED_EXTS: - exp = '%s/*%s' % (path, ext) + exp = "%s/*%s" % (path, ext) files += glob.glob(exp) for f in files: @@ -189,11 +198,13 @@ def register_policy_types(module): if existing_entry: policy_type_db.id = existing_entry.id except StackStormDBObjectNotFoundError: - LOG.debug('Policy type "%s" is not found. Creating new entry.', - policy_type_db.ref) + LOG.debug( + 'Policy type "%s" is not found. Creating new entry.', + policy_type_db.ref, + ) policy_type_db = PolicyType.add_or_update(policy_type_db) - extra = {'policy_type_db': policy_type_db} + extra = {"policy_type_db": policy_type_db} LOG.audit('Policy type "%s" is updated.', policy_type_db.ref, extra=extra) registered_count += 1 @@ -203,16 +214,18 @@ def register_policy_types(module): return registered_count -def register_policies(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_policies( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = PolicyRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = PolicyRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/rulesregistrar.py b/st2common/st2common/bootstrap/rulesregistrar.py index c50b0d5eae..505f3e5337 100644 --- a/st2common/st2common/bootstrap/rulesregistrar.py +++ b/st2common/st2common/bootstrap/rulesregistrar.py @@ -25,14 +25,14 @@ from st2common.models.api.rule import RuleAPI from st2common.models.system.common import ResourceReference from st2common.persistence.rule import Rule -from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count +from st2common.services.triggers import ( + cleanup_trigger_db_for_rule, + increment_trigger_ref_count, +) from st2common.exceptions.db import StackStormDBObjectNotFoundError import st2common.content.utils as content_utils -__all__ = [ - 'RulesRegistrar', - 'register_rules' -] +__all__ = ["RulesRegistrar", "register_rules"] LOG = logging.getLogger(__name__) @@ -49,14 +49,15 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='rules') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="rules" + ) for pack, rules_dir in six.iteritems(content): if not rules_dir: - LOG.debug('Pack %s does not contain rules.', pack) + LOG.debug("Pack %s does not contain rules.", pack) continue try: - LOG.debug('Registering rules from pack: %s', pack) + LOG.debug("Registering rules from pack: %s", pack) rules = self._get_rules_from_pack(rules_dir) count = self._register_rules_from_pack(pack, rules) registered_count += count @@ -64,7 +65,7 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all rules from pack: %s', rules_dir) + LOG.exception("Failed registering all rules from pack: %s", rules_dir) return registered_count @@ -75,10 +76,11 @@ def register_from_pack(self, pack_dir): :return: Number of rules registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - rules_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='rules') + rules_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="rules" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -87,7 +89,7 @@ def register_from_pack(self, pack_dir): if not rules_dir: return registered_count - LOG.debug('Registering rules from pack %s:, dir: %s', pack, rules_dir) + LOG.debug("Registering rules from pack %s:, dir: %s", pack, rules_dir) try: rules = self._get_rules_from_pack(rules_dir=rules_dir) @@ -96,7 +98,7 @@ def register_from_pack(self, pack_dir): if self._fail_on_failure: raise e - LOG.exception('Failed registering all rules from pack: %s', rules_dir) + LOG.exception("Failed registering all rules from pack: %s", rules_dir) return registered_count @@ -108,21 +110,23 @@ def _register_rules_from_pack(self, pack, rules): # TODO: Refactor this monstrosity for rule in rules: - LOG.debug('Loading rule from %s.', rule) + LOG.debug("Loading rule from %s.", rule) try: content = self._meta_loader.load(rule) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=rule, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=rule, use_pack_cache=True + ) + content["metadata_file"] = metadata_file rule_api = RuleAPI(**content) rule_api.validate() @@ -134,35 +138,48 @@ def _register_rules_from_pack(self, pack, rules): # delete so we don't have duplicates. if pack_field != DEFAULT_PACK_NAME: try: - rule_ref = ResourceReference.to_string_reference(name=content['name'], - pack=DEFAULT_PACK_NAME) - LOG.debug('Looking for rule %s in pack %s', content['name'], - DEFAULT_PACK_NAME) + rule_ref = ResourceReference.to_string_reference( + name=content["name"], pack=DEFAULT_PACK_NAME + ) + LOG.debug( + "Looking for rule %s in pack %s", + content["name"], + DEFAULT_PACK_NAME, + ) existing = Rule.get_by_ref(rule_ref) - LOG.debug('Existing = %s', existing) + LOG.debug("Existing = %s", existing) if existing: - LOG.debug('Found rule in pack default: %s; Deleting.', rule_ref) + LOG.debug( + "Found rule in pack default: %s; Deleting.", rule_ref + ) Rule.delete(existing) except: - LOG.exception('Exception deleting rule from %s pack.', DEFAULT_PACK_NAME) + LOG.exception( + "Exception deleting rule from %s pack.", DEFAULT_PACK_NAME + ) try: - rule_ref = ResourceReference.to_string_reference(name=content['name'], - pack=content['pack']) + rule_ref = ResourceReference.to_string_reference( + name=content["name"], pack=content["pack"] + ) existing = Rule.get_by_ref(rule_ref) if existing: rule_db.id = existing.id - LOG.debug('Found existing rule: %s with id: %s', rule_ref, existing.id) + LOG.debug( + "Found existing rule: %s with id: %s", rule_ref, existing.id + ) except StackStormDBObjectNotFoundError: - LOG.debug('Rule %s not found. Creating new one.', rule) + LOG.debug("Rule %s not found. Creating new one.", rule) try: rule_db = Rule.add_or_update(rule_db) increment_trigger_ref_count(rule_api=rule_api) - extra = {'rule_db': rule_db} - LOG.audit('Rule updated. Rule %s from %s.', rule_db, rule, extra=extra) + extra = {"rule_db": rule_db} + LOG.audit( + "Rule updated. Rule %s from %s.", rule_db, rule, extra=extra + ) except Exception: - LOG.exception('Failed to create rule %s.', rule_api.name) + LOG.exception("Failed to create rule %s.", rule_api.name) # If there was an existing rule then the ref count was updated in # to_model so it needs to be adjusted down here. Also, update could @@ -171,27 +188,32 @@ def _register_rules_from_pack(self, pack, rules): cleanup_trigger_db_for_rule(existing) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register rule "%s" from pack "%s": %s' % (rule, pack, - six.text_type(e))) + msg = 'Failed to register rule "%s" from pack "%s": %s' % ( + rule, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Failed registering rule from %s.', rule) + LOG.exception("Failed registering rule from %s.", rule) else: registered_count += 1 return registered_count -def register_rules(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_rules( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = RulesRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = RulesRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/ruletypesregistrar.py b/st2common/st2common/bootstrap/ruletypesregistrar.py index 735294cd23..90d4018a40 100644 --- a/st2common/st2common/bootstrap/ruletypesregistrar.py +++ b/st2common/st2common/bootstrap/ruletypesregistrar.py @@ -22,41 +22,36 @@ from st2common.persistence.rule import RuleType from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'register_rule_types', - 'RULE_TYPES' -] +__all__ = ["register_rule_types", "RULE_TYPES"] LOG = logging.getLogger(__name__) RULE_TYPES = [ { - 'name': RULE_TYPE_STANDARD, - 'description': 'standard rule that is always applicable.', - 'enabled': True, - 'parameters': { - } + "name": RULE_TYPE_STANDARD, + "description": "standard rule that is always applicable.", + "enabled": True, + "parameters": {}, }, { - 'name': RULE_TYPE_BACKSTOP, - 'description': 'Rule that applies when no other rule has matched for a specific Trigger.', - 'enabled': True, - 'parameters': { - } + "name": RULE_TYPE_BACKSTOP, + "description": "Rule that applies when no other rule has matched for a specific Trigger.", + "enabled": True, + "parameters": {}, }, ] def register_rule_types(): - LOG.debug('Start : register default RuleTypes.') + LOG.debug("Start : register default RuleTypes.") registered_count = 0 for rule_type in RULE_TYPES: rule_type = copy.deepcopy(rule_type) try: - rule_type_db = RuleType.get_by_name(rule_type['name']) + rule_type_db = RuleType.get_by_name(rule_type["name"]) update = True except StackStormDBObjectNotFoundError: rule_type_db = None @@ -72,16 +67,16 @@ def register_rule_types(): try: rule_type_db = RuleType.add_or_update(rule_type_model) - extra = {'rule_type_db': rule_type_db} + extra = {"rule_type_db": rule_type_db} if update: - LOG.audit('RuleType updated. RuleType %s', rule_type_db, extra=extra) + LOG.audit("RuleType updated. RuleType %s", rule_type_db, extra=extra) else: - LOG.audit('RuleType created. RuleType %s', rule_type_db, extra=extra) + LOG.audit("RuleType created. RuleType %s", rule_type_db, extra=extra) except Exception: - LOG.exception('Unable to register RuleType %s.', rule_type['name']) + LOG.exception("Unable to register RuleType %s.", rule_type["name"]) else: registered_count += 1 - LOG.debug('End : register default RuleTypes.') + LOG.debug("End : register default RuleTypes.") return registered_count diff --git a/st2common/st2common/bootstrap/runnersregistrar.py b/st2common/st2common/bootstrap/runnersregistrar.py index 3aa93da9b1..bb99389433 100644 --- a/st2common/st2common/bootstrap/runnersregistrar.py +++ b/st2common/st2common/bootstrap/runnersregistrar.py @@ -26,7 +26,7 @@ from st2common.util.action_db import get_runnertype_by_name __all__ = [ - 'register_runner_types', + "register_runner_types", ] @@ -37,7 +37,7 @@ def register_runners(experimental=False, fail_on_failure=True): """ Register runners """ - LOG.debug('Start : register runners') + LOG.debug("Start : register runners") runner_count = 0 manager = ExtensionManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False) @@ -46,28 +46,30 @@ def register_runners(experimental=False, fail_on_failure=True): for name in extension_names: LOG.debug('Found runner "%s"' % (name)) - manager = DriverManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name) + manager = DriverManager( + namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name + ) runner_metadata = manager.driver.get_metadata() runner_count += register_runner(runner_metadata, experimental) - LOG.debug('End : register runners') + LOG.debug("End : register runners") return runner_count def register_runner(runner_type, experimental): # For backward compatibility reasons, we also register runners under the old names - runner_names = [runner_type['name']] + runner_type.get('aliases', []) + runner_names = [runner_type["name"]] + runner_type.get("aliases", []) for runner_name in runner_names: - runner_type['name'] = runner_name - runner_experimental = runner_type.get('experimental', False) + runner_type["name"] = runner_name + runner_experimental = runner_type.get("experimental", False) if runner_experimental and not experimental: LOG.debug('Skipping experimental runner "%s"' % (runner_name)) continue # Remove additional, non db-model attributes - non_db_attributes = ['experimental', 'aliases'] + non_db_attributes = ["experimental", "aliases"] for attribute in non_db_attributes: if attribute in runner_type: del runner_type[attribute] @@ -81,13 +83,13 @@ def register_runner(runner_type, experimental): # Note: We don't want to overwrite "enabled" attribute which is already in the database # (aka we don't want to re-enable runner which has been disabled by the user) - if runner_type_db and runner_type_db['enabled'] != runner_type['enabled']: - runner_type['enabled'] = runner_type_db['enabled'] + if runner_type_db and runner_type_db["enabled"] != runner_type["enabled"]: + runner_type["enabled"] = runner_type_db["enabled"] # If package is not provided, assume it's the same as module name for backward # compatibility reasons - if not runner_type.get('runner_package', None): - runner_type['runner_package'] = runner_type['runner_module'] + if not runner_type.get("runner_package", None): + runner_type["runner_package"] = runner_type["runner_module"] runner_type_api = RunnerTypeAPI(**runner_type) runner_type_api.validate() @@ -100,13 +102,17 @@ def register_runner(runner_type, experimental): runner_type_db = RunnerType.add_or_update(runner_type_model) - extra = {'runner_type_db': runner_type_db} + extra = {"runner_type_db": runner_type_db} if update: - LOG.audit('RunnerType updated. RunnerType %s', runner_type_db, extra=extra) + LOG.audit( + "RunnerType updated. RunnerType %s", runner_type_db, extra=extra + ) else: - LOG.audit('RunnerType created. RunnerType %s', runner_type_db, extra=extra) + LOG.audit( + "RunnerType created. RunnerType %s", runner_type_db, extra=extra + ) except Exception: - LOG.exception('Unable to register runner type %s.', runner_type['name']) + LOG.exception("Unable to register runner type %s.", runner_type["name"]) return 0 return 1 diff --git a/st2common/st2common/bootstrap/sensorsregistrar.py b/st2common/st2common/bootstrap/sensorsregistrar.py index 5181270d79..8a91e23eea 100644 --- a/st2common/st2common/bootstrap/sensorsregistrar.py +++ b/st2common/st2common/bootstrap/sensorsregistrar.py @@ -26,10 +26,7 @@ from st2common.models.api.sensor import SensorTypeAPI from st2common.persistence.sensor import SensorType -__all__ = [ - 'SensorsRegistrar', - 'register_sensors' -] +__all__ = ["SensorsRegistrar", "register_sensors"] LOG = logging.getLogger(__name__) @@ -51,15 +48,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='sensors') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="sensors" + ) for pack, sensors_dir in six.iteritems(content): if not sensors_dir: - LOG.debug('Pack %s does not contain sensors.', pack) + LOG.debug("Pack %s does not contain sensors.", pack) continue try: - LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir) + LOG.debug( + "Registering sensors from pack %s:, dir: %s", pack, sensors_dir + ) sensors = self._get_sensors_from_pack(sensors_dir) count = self._register_sensors_from_pack(pack=pack, sensors=sensors) registered_count += count @@ -67,8 +67,11 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all sensors from pack "%s": %s', + sensors_dir, + six.text_type(e), + ) return registered_count @@ -79,10 +82,11 @@ def register_from_pack(self, pack_dir): :return: Number of sensors registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - sensors_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='sensors') + sensors_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="sensors" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -91,17 +95,22 @@ def register_from_pack(self, pack_dir): if not sensors_dir: return registered_count - LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir) + LOG.debug("Registering sensors from pack %s:, dir: %s", pack, sensors_dir) try: sensors = self._get_sensors_from_pack(sensors_dir=sensors_dir) - registered_count = self._register_sensors_from_pack(pack=pack, sensors=sensors) + registered_count = self._register_sensors_from_pack( + pack=pack, sensors=sensors + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all sensors from pack "%s": %s', + sensors_dir, + six.text_type(e), + ) return registered_count @@ -115,11 +124,16 @@ def _register_sensors_from_pack(self, pack, sensors): self._register_sensor_from_pack(pack=pack, sensor=sensor) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register sensor "%s" from pack "%s": %s' % (sensor, pack, - six.text_type(e))) + msg = 'Failed to register sensor "%s" from pack "%s": %s' % ( + sensor, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.debug('Failed to register sensor "%s": %s', sensor, six.text_type(e)) + LOG.debug( + 'Failed to register sensor "%s": %s', sensor, six.text_type(e) + ) else: LOG.debug('Sensor "%s" successfully registered', sensor) registered_count += 1 @@ -129,33 +143,35 @@ def _register_sensors_from_pack(self, pack, sensors): def _register_sensor_from_pack(self, pack, sensor): sensor_metadata_file_path = sensor - LOG.debug('Loading sensor from %s.', sensor_metadata_file_path) + LOG.debug("Loading sensor from %s.", sensor_metadata_file_path) content = self._meta_loader.load(file_path=sensor_metadata_file_path) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) - entry_point = content.get('entry_point', None) + entry_point = content.get("entry_point", None) if not entry_point: - raise ValueError('Sensor definition missing entry_point') + raise ValueError("Sensor definition missing entry_point") # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=sensor, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=sensor, use_pack_cache=True + ) + content["metadata_file"] = metadata_file sensors_dir = os.path.dirname(sensor_metadata_file_path) sensor_file_path = os.path.join(sensors_dir, entry_point) - artifact_uri = 'file://%s' % (sensor_file_path) - content['artifact_uri'] = artifact_uri - content['entry_point'] = entry_point + artifact_uri = "file://%s" % (sensor_file_path) + content["artifact_uri"] = artifact_uri + content["entry_point"] = entry_point sensor_api = SensorTypeAPI(**content) sensor_model = SensorTypeAPI.to_model(sensor_api) @@ -163,28 +179,33 @@ def _register_sensor_from_pack(self, pack, sensor): sensor_types = SensorType.query(pack=sensor_model.pack, name=sensor_model.name) if len(sensor_types) >= 1: sensor_type = sensor_types[0] - LOG.debug('Found existing sensor id:%s with name:%s. Will update it.', - sensor_type.id, sensor_type.name) + LOG.debug( + "Found existing sensor id:%s with name:%s. Will update it.", + sensor_type.id, + sensor_type.name, + ) sensor_model.id = sensor_type.id try: sensor_model = SensorType.add_or_update(sensor_model) except: - LOG.exception('Failed creating sensor model for %s', sensor) + LOG.exception("Failed creating sensor model for %s", sensor) return sensor_model -def register_sensors(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_sensors( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = SensorsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = SensorsRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/triggersregistrar.py b/st2common/st2common/bootstrap/triggersregistrar.py index 4f95a6d0a3..180c9cb885 100644 --- a/st2common/st2common/bootstrap/triggersregistrar.py +++ b/st2common/st2common/bootstrap/triggersregistrar.py @@ -24,10 +24,7 @@ import st2common.content.utils as content_utils from st2common.models.utils import sensor_type_utils -__all__ = [ - 'TriggersRegistrar', - 'register_triggers' -] +__all__ = ["TriggersRegistrar", "register_triggers"] LOG = logging.getLogger(__name__) @@ -47,15 +44,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='triggers') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="triggers" + ) for pack, triggers_dir in six.iteritems(content): if not triggers_dir: - LOG.debug('Pack %s does not contain triggers.', pack) + LOG.debug("Pack %s does not contain triggers.", pack) continue try: - LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir) + LOG.debug( + "Registering triggers from pack %s:, dir: %s", pack, triggers_dir + ) triggers = self._get_triggers_from_pack(triggers_dir) count = self._register_triggers_from_pack(pack=pack, triggers=triggers) registered_count += count @@ -63,8 +63,11 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all triggers from pack "%s": %s', + triggers_dir, + six.text_type(e), + ) return registered_count @@ -75,10 +78,11 @@ def register_from_pack(self, pack_dir): :return: Number of triggers registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - triggers_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='triggers') + triggers_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="triggers" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -87,17 +91,22 @@ def register_from_pack(self, pack_dir): if not triggers_dir: return registered_count - LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir) + LOG.debug("Registering triggers from pack %s:, dir: %s", pack, triggers_dir) try: triggers = self._get_triggers_from_pack(triggers_dir=triggers_dir) - registered_count = self._register_triggers_from_pack(pack=pack, triggers=triggers) + registered_count = self._register_triggers_from_pack( + pack=pack, triggers=triggers + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all triggers from pack "%s": %s', + triggers_dir, + six.text_type(e), + ) return registered_count @@ -107,20 +116,27 @@ def _get_triggers_from_pack(self, triggers_dir): def _register_triggers_from_pack(self, pack, triggers): registered_count = 0 - pack_base_path = content_utils.get_pack_base_path(pack_name=pack, - include_trailing_slash=True) + pack_base_path = content_utils.get_pack_base_path( + pack_name=pack, include_trailing_slash=True + ) for trigger in triggers: try: - self._register_trigger_from_pack(pack_base_path=pack_base_path, pack=pack, - trigger=trigger) + self._register_trigger_from_pack( + pack_base_path=pack_base_path, pack=pack, trigger=trigger + ) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register trigger "%s" from pack "%s": %s' % (trigger, pack, - six.text_type(e))) + msg = 'Failed to register trigger "%s" from pack "%s": %s' % ( + trigger, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.debug('Failed to register trigger "%s": %s', trigger, six.text_type(e)) + LOG.debug( + 'Failed to register trigger "%s": %s', trigger, six.text_type(e) + ) else: LOG.debug('Trigger "%s" successfully registered', trigger) registered_count += 1 @@ -130,37 +146,41 @@ def _register_triggers_from_pack(self, pack, triggers): def _register_trigger_from_pack(self, pack_base_path, pack, trigger): trigger_metadata_file_path = trigger - LOG.debug('Loading trigger from %s.', trigger_metadata_file_path) + LOG.debug("Loading trigger from %s.", trigger_metadata_file_path) content = self._meta_loader.load(file_path=trigger_metadata_file_path) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = trigger.replace(pack_base_path, '') - content['metadata_file'] = metadata_file + metadata_file = trigger.replace(pack_base_path, "") + content["metadata_file"] = metadata_file trigger_types = [content] result = sensor_type_utils.create_trigger_types(trigger_types=trigger_types) return result[0] if result else None -def register_triggers(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_triggers( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = TriggersRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = TriggersRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/callback/base.py b/st2common/st2common/callback/base.py index ae1b55e501..a48fcbecb9 100644 --- a/st2common/st2common/callback/base.py +++ b/st2common/st2common/callback/base.py @@ -21,7 +21,7 @@ __all__ = [ - 'AsyncActionExecutionCallbackHandler', + "AsyncActionExecutionCallbackHandler", ] @@ -30,7 +30,6 @@ @six.add_metaclass(abc.ABCMeta) class AsyncActionExecutionCallbackHandler(object): - @staticmethod @abc.abstractmethod def callback(liveaction): diff --git a/st2common/st2common/cmd/download_pack.py b/st2common/st2common/cmd/download_pack.py index b22a3f0467..5ef0fb72b9 100644 --- a/st2common/st2common/cmd/download_pack.py +++ b/st2common/st2common/cmd/download_pack.py @@ -24,23 +24,34 @@ from st2common.util.pack_management import download_pack from st2common.util.pack_management import get_and_set_proxy_config -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to install (download).'), - cfg.BoolOpt('verify-ssl', default=True, - help=('Verify SSL certificate of the Git repo from which the pack is ' - 'installed.')), - cfg.BoolOpt('force', default=False, - help='True to force pack download and ignore download ' - 'lock file if it exists.'), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to install (download).", + ), + cfg.BoolOpt( + "verify-ssl", + default=True, + help=( + "Verify SSL certificate of the Git repo from which the pack is " + "installed." + ), + ), + cfg.BoolOpt( + "force", + default=False, + help="True to force pack download and ignore download " + "lock file if it exists.", + ), ] do_register_cli_opts(cli_opts) @@ -49,8 +60,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack verify_ssl = cfg.CONF.verify_ssl @@ -60,8 +75,13 @@ def main(argv): for pack in packs: LOG.info('Installing pack "%s"' % (pack)) - result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force, - proxy_config=proxy_config, force_permissions=True) + result = download_pack( + pack=pack, + verify_ssl=verify_ssl, + force=force, + proxy_config=proxy_config, + force_permissions=True, + ) # Raw pack name excluding the version pack_name = result[1] diff --git a/st2common/st2common/cmd/generate_api_spec.py b/st2common/st2common/cmd/generate_api_spec.py index 1b0a65ec8f..7ff7757b71 100644 --- a/st2common/st2common/cmd/generate_api_spec.py +++ b/st2common/st2common/cmd/generate_api_spec.py @@ -25,9 +25,7 @@ from st2common.script_setup import setup as common_setup from st2common.script_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -37,7 +35,7 @@ def setup(): def generate_spec(): - spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2') + spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2") print(spec_string) @@ -52,7 +50,7 @@ def main(): generate_spec() ret = 0 except Exception: - LOG.error('Failed to generate openapi.yaml file', exc_info=True) + LOG.error("Failed to generate openapi.yaml file", exc_info=True) ret = 1 finally: teartown() diff --git a/st2common/st2common/cmd/install_pack.py b/st2common/st2common/cmd/install_pack.py index 861d0d4041..42c2267012 100644 --- a/st2common/st2common/cmd/install_pack.py +++ b/st2common/st2common/cmd/install_pack.py @@ -25,23 +25,34 @@ from st2common.util.pack_management import get_and_set_proxy_config from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to install.'), - cfg.BoolOpt('verify-ssl', default=True, - help=('Verify SSL certificate of the Git repo from which the pack is ' - 'downloaded.')), - cfg.BoolOpt('force', default=False, - help='True to force pack installation and ignore install ' - 'lock file if it exists.'), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to install.", + ), + cfg.BoolOpt( + "verify-ssl", + default=True, + help=( + "Verify SSL certificate of the Git repo from which the pack is " + "downloaded." + ), + ), + cfg.BoolOpt( + "force", + default=False, + help="True to force pack installation and ignore install " + "lock file if it exists.", + ), ] do_register_cli_opts(cli_opts) @@ -50,8 +61,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack verify_ssl = cfg.CONF.verify_ssl @@ -62,8 +77,13 @@ def main(argv): for pack in packs: # 1. Download the pack LOG.info('Installing pack "%s"' % (pack)) - result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force, - proxy_config=proxy_config, force_permissions=True) + result = download_pack( + pack=pack, + verify_ssl=verify_ssl, + force=force, + proxy_config=proxy_config, + force_permissions=True, + ) # Raw pack name excluding the version pack_name = result[1] @@ -78,9 +98,13 @@ def main(argv): # 2. Setup pack virtual environment LOG.info('Setting up virtualenv for pack "%s"' % (pack_name)) - setup_pack_virtualenv(pack_name=pack_name, update=False, logger=LOG, - proxy_config=proxy_config, - no_download=True) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + logger=LOG, + proxy_config=proxy_config, + no_download=True, + ) LOG.info('Successfully set up virtualenv for pack "%s"' % (pack_name)) return 0 diff --git a/st2common/st2common/cmd/purge_executions.py b/st2common/st2common/cmd/purge_executions.py index dcf7b47b40..27225d661c 100755 --- a/st2common/st2common/cmd/purge_executions.py +++ b/st2common/st2common/cmd/purge_executions.py @@ -38,25 +38,30 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.garbage_collection.executions import purge_executions -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('timestamp', default=None, - help='Will delete execution and liveaction models older than ' + - 'this UTC timestamp. ' + - 'Example value: 2015-03-13T19:01:27.255542Z.'), - cfg.StrOpt('action-ref', default='', - help='action-ref to delete executions for.'), - cfg.BoolOpt('purge-incomplete', default=False, - help='Purge all models irrespective of their ``status``.' + - 'By default, only executions in completed states such as "succeeeded" ' + - ', "failed", "canceled" and "timed_out" are deleted.'), + cfg.StrOpt( + "timestamp", + default=None, + help="Will delete execution and liveaction models older than " + + "this UTC timestamp. " + + "Example value: 2015-03-13T19:01:27.255542Z.", + ), + cfg.StrOpt( + "action-ref", default="", help="action-ref to delete executions for." + ), + cfg.BoolOpt( + "purge-incomplete", + default=False, + help="Purge all models irrespective of their ``status``." + + 'By default, only executions in completed states such as "succeeeded" ' + + ', "failed", "canceled" and "timed_out" are deleted.', + ), ] do_register_cli_opts(cli_opts) @@ -71,15 +76,19 @@ def main(): purge_incomplete = cfg.CONF.purge_incomplete if not timestamp: - LOG.error('Please supply a timestamp for purging models. Aborting.') + LOG.error("Please supply a timestamp for purging models. Aborting.") return 1 else: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=pytz.UTC) try: - purge_executions(logger=LOG, timestamp=timestamp, action_ref=action_ref, - purge_incomplete=purge_incomplete) + purge_executions( + logger=LOG, + timestamp=timestamp, + action_ref=action_ref, + purge_incomplete=purge_incomplete, + ) except Exception as e: LOG.exception(six.text_type(e)) return FAILURE_EXIT_CODE diff --git a/st2common/st2common/cmd/purge_trigger_instances.py b/st2common/st2common/cmd/purge_trigger_instances.py index e0908e9f8d..529b786678 100755 --- a/st2common/st2common/cmd/purge_trigger_instances.py +++ b/st2common/st2common/cmd/purge_trigger_instances.py @@ -38,19 +38,20 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.garbage_collection.trigger_instances import purge_trigger_instances -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('timestamp', default=None, - help='Will delete trigger instances older than ' + - 'this UTC timestamp. ' + - 'Example value: 2015-03-13T19:01:27.255542Z') + cfg.StrOpt( + "timestamp", + default=None, + help="Will delete trigger instances older than " + + "this UTC timestamp. " + + "Example value: 2015-03-13T19:01:27.255542Z", + ) ] do_register_cli_opts(cli_opts) @@ -63,10 +64,10 @@ def main(): timestamp = cfg.CONF.timestamp if not timestamp: - LOG.error('Please supply a timestamp for purging models. Aborting.') + LOG.error("Please supply a timestamp for purging models. Aborting.") return 1 else: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=pytz.UTC) # Purge models. diff --git a/st2common/st2common/cmd/setup_pack_virtualenv.py b/st2common/st2common/cmd/setup_pack_virtualenv.py index 626bb389af..514b1cf2e0 100644 --- a/st2common/st2common/cmd/setup_pack_virtualenv.py +++ b/st2common/st2common/cmd/setup_pack_virtualenv.py @@ -22,23 +22,31 @@ from st2common.util.pack_management import get_and_set_proxy_config from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to setup the virtual environment for.'), - cfg.BoolOpt('update', default=False, - help=('Check this option if the virtual environment already exists and if you ' - 'only want to perform an update and installation of new dependencies. If ' - 'you don\'t check this option, the virtual environment will be destroyed ' - 'then re-created. If you check this and the virtual environment doesn\'t ' - 'exist, it will create it..')), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to setup the virtual environment for.", + ), + cfg.BoolOpt( + "update", + default=False, + help=( + "Check this option if the virtual environment already exists and if you " + "only want to perform an update and installation of new dependencies. If " + "you don't check this option, the virtual environment will be destroyed " + "then re-created. If you check this and the virtual environment doesn't " + "exist, it will create it.." + ), + ), ] do_register_cli_opts(cli_opts) @@ -47,8 +55,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack update = cfg.CONF.update @@ -58,9 +70,13 @@ def main(argv): for pack in packs: # Setup pack virtual environment LOG.info('Setting up virtualenv for pack "%s"' % (pack)) - setup_pack_virtualenv(pack_name=pack, update=update, logger=LOG, - proxy_config=proxy_config, - no_download=True) + setup_pack_virtualenv( + pack_name=pack, + update=update, + logger=LOG, + proxy_config=proxy_config, + no_download=True, + ) LOG.info('Successfully set up virtualenv for pack "%s"' % (pack)) return 0 diff --git a/st2common/st2common/cmd/validate_api_spec.py b/st2common/st2common/cmd/validate_api_spec.py index 743b3e467a..4f317db4a4 100644 --- a/st2common/st2common/cmd/validate_api_spec.py +++ b/st2common/st2common/cmd/validate_api_spec.py @@ -33,19 +33,20 @@ import six -__all__ = [ - 'main' -] +__all__ = ["main"] cfg.CONF.register_cli_opt( - cfg.StrOpt('spec-file', short='f', required=False, - default='st2common/st2common/openapi.yaml') + cfg.StrOpt( + "spec-file", + short="f", + required=False, + default="st2common/st2common/openapi.yaml", + ) ) cfg.CONF.register_cli_opt( - cfg.BoolOpt('generate', short='-c', required=False, - default=False) + cfg.BoolOpt("generate", short="-c", required=False, default=False) ) LOG = logging.getLogger(__name__) @@ -56,12 +57,12 @@ def setup(): def _validate_definitions(spec): - defs = spec.get('definitions', None) + defs = spec.get("definitions", None) error = False verbose = cfg.CONF.verbose for (model, definition) in six.iteritems(defs): - api_model = definition.get('x-api-model', None) + api_model = definition.get("x-api-model", None) if not api_model: msg = ( @@ -69,7 +70,7 @@ def _validate_definitions(spec): ) if verbose: - LOG.info('Supplied definition for model %s: \n\n%s.', model, definition) + LOG.info("Supplied definition for model %s: \n\n%s.", model, definition) error = True LOG.error(msg) @@ -82,18 +83,20 @@ def validate_spec(): generate_spec = cfg.CONF.generate if not os.path.exists(spec_file) and not generate_spec: - msg = ('No spec file found in location %s. ' % spec_file + - 'Provide a valid spec file or ' + - 'pass --generate-api-spec to genrate a spec.') + msg = ( + "No spec file found in location %s. " % spec_file + + "Provide a valid spec file or " + + "pass --generate-api-spec to genrate a spec." + ) raise Exception(msg) if generate_spec: if not spec_file: - raise Exception('Supply a path to write to spec file to.') + raise Exception("Supply a path to write to spec file to.") - spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2') + spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2") - with open(spec_file, 'w') as f: + with open(spec_file, "w") as f: f.write(spec_string) f.flush() @@ -112,13 +115,15 @@ def main(): try: # 1. Validate there are no duplicates keys in the YAML file - spec_loader.load_spec('st2common', 'openapi.yaml.j2', allow_duplicate_keys=False) + spec_loader.load_spec( + "st2common", "openapi.yaml.j2", allow_duplicate_keys=False + ) # 2. Validate schema (currently broken) # validate_spec() ret = 0 except Exception: - LOG.error('Failed to validate openapi.yaml file', exc_info=True) + LOG.error("Failed to validate openapi.yaml file", exc_info=True) ret = 1 finally: teartown() diff --git a/st2common/st2common/cmd/validate_config.py b/st2common/st2common/cmd/validate_config.py index 2bd5b58d0d..6b7bedd32f 100644 --- a/st2common/st2common/cmd/validate_config.py +++ b/st2common/st2common/cmd/validate_config.py @@ -31,9 +31,7 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.util.pack import validate_config_against_schema -__all__ = [ - 'main' -] +__all__ = ["main"] def _do_register_cli_opts(opts, ignore_errors=False): @@ -47,10 +45,18 @@ def _do_register_cli_opts(opts, ignore_errors=False): def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('schema-path', default=None, required=True, - help='Path to the config schema to use for validation.'), - cfg.StrOpt('config-path', default=None, required=True, - help='Path to the config file to validate.'), + cfg.StrOpt( + "schema-path", + default=None, + required=True, + help="Path to the config schema to use for validation.", + ), + cfg.StrOpt( + "config-path", + default=None, + required=True, + help="Path to the config file to validate.", + ), ] do_register_cli_opts(cli_opts) @@ -65,18 +71,24 @@ def main(): print('Validating config "%s" against schema in "%s"' % (config_path, schema_path)) - with open(schema_path, 'r') as fp: + with open(schema_path, "r") as fp: config_schema = yaml.safe_load(fp.read()) - with open(config_path, 'r') as fp: + with open(config_path, "r") as fp: config_object = yaml.safe_load(fp.read()) try: - validate_config_against_schema(config_schema=config_schema, config_object=config_object, - config_path=config_path) + validate_config_against_schema( + config_schema=config_schema, + config_object=config_object, + config_path=config_path, + ) except Exception as e: - print('Failed to validate pack config.\n%s' % six.text_type(e)) + print("Failed to validate pack config.\n%s" % six.text_type(e)) return FAILURE_EXIT_CODE - print('Config "%s" successfully validated against schema in %s.' % (config_path, schema_path)) + print( + 'Config "%s" successfully validated against schema in %s.' + % (config_path, schema_path) + ) return SUCCESS_EXIT_CODE diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py index 8ad77fa626..e7b30a9a7c 100644 --- a/st2common/st2common/config.py +++ b/st2common/st2common/config.py @@ -25,12 +25,7 @@ from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL from st2common.constants.action import LIVEACTION_COMPLETED_STATES -__all__ = [ - 'do_register_opts', - 'do_register_cli_opts', - - 'parse_args' -] +__all__ = ["do_register_opts", "do_register_cli_opts", "parse_args"] def do_register_opts(opts, group=None, ignore_errors=False): @@ -57,447 +52,550 @@ def do_register_cli_opts(opt, ignore_errors=False): def register_opts(ignore_errors=False): rbac_opts = [ + cfg.BoolOpt("enable", default=False, help="Enable RBAC."), + cfg.StrOpt("backend", default="noop", help="RBAC backend to use."), cfg.BoolOpt( - 'enable', default=False, - help='Enable RBAC.'), - cfg.StrOpt( - 'backend', default='noop', - help='RBAC backend to use.'), - cfg.BoolOpt( - 'sync_remote_groups', default=False, - help='True to synchronize remote groups returned by the auth backed for each ' - 'StackStorm user with local StackStorm roles based on the group to role ' - 'mapping definition files.'), + "sync_remote_groups", + default=False, + help="True to synchronize remote groups returned by the auth backed for each " + "StackStorm user with local StackStorm roles based on the group to role " + "mapping definition files.", + ), cfg.BoolOpt( - 'permission_isolation', default=False, - help='Isolate resources by user. For now, these resources only include rules and ' - 'executions. All resources can only be viewed or executed by the owning user ' - 'except the admin and system_user who can view or run everything.') + "permission_isolation", + default=False, + help="Isolate resources by user. For now, these resources only include rules and " + "executions. All resources can only be viewed or executed by the owning user " + "except the admin and system_user who can view or run everything.", + ), ] - do_register_opts(rbac_opts, 'rbac', ignore_errors) + do_register_opts(rbac_opts, "rbac", ignore_errors) system_user_opts = [ + cfg.StrOpt("user", default="stanley", help="Default system user."), cfg.StrOpt( - 'user', default='stanley', - help='Default system user.'), - cfg.StrOpt( - 'ssh_key_file', default='/home/stanley/.ssh/stanley_rsa', - help='SSH private key for the system user.') + "ssh_key_file", + default="/home/stanley/.ssh/stanley_rsa", + help="SSH private key for the system user.", + ), ] - do_register_opts(system_user_opts, 'system_user', ignore_errors) + do_register_opts(system_user_opts, "system_user", ignore_errors) schema_opts = [ - cfg.IntOpt( - 'version', default=4, - help='Version of JSON schema to use.'), + cfg.IntOpt("version", default=4, help="Version of JSON schema to use."), cfg.StrOpt( - 'draft', default='http://json-schema.org/draft-04/schema#', - help='URL to the JSON schema draft.') + "draft", + default="http://json-schema.org/draft-04/schema#", + help="URL to the JSON schema draft.", + ), ] - do_register_opts(schema_opts, 'schema', ignore_errors) + do_register_opts(schema_opts, "schema", ignore_errors) system_opts = [ - cfg.BoolOpt( - 'debug', default=False, - help='Enable debug mode.'), + cfg.BoolOpt("debug", default=False, help="Enable debug mode."), cfg.StrOpt( - 'base_path', default='/opt/stackstorm', - help='Base path to all st2 artifacts.'), + "base_path", + default="/opt/stackstorm", + help="Base path to all st2 artifacts.", + ), cfg.BoolOpt( - 'validate_trigger_parameters', default=True, - help='True to validate parameters for non-system trigger types when creating' - 'a rule. By default, only parameters for system triggers are validated.'), + "validate_trigger_parameters", + default=True, + help="True to validate parameters for non-system trigger types when creating" + "a rule. By default, only parameters for system triggers are validated.", + ), cfg.BoolOpt( - 'validate_trigger_payload', default=True, - help='True to validate payload for non-system trigger types when dispatching a trigger ' - 'inside the sensor. By default, only payload for system triggers is validated.'), + "validate_trigger_payload", + default=True, + help="True to validate payload for non-system trigger types when dispatching a trigger " + "inside the sensor. By default, only payload for system triggers is validated.", + ), cfg.BoolOpt( - 'validate_output_schema', default=False, - help='True to validate action and runner output against schema.') + "validate_output_schema", + default=False, + help="True to validate action and runner output against schema.", + ), ] - do_register_opts(system_opts, 'system', ignore_errors) + do_register_opts(system_opts, "system", ignore_errors) - system_packs_base_path = os.path.join(cfg.CONF.system.base_path, 'packs') - system_runners_base_path = os.path.join(cfg.CONF.system.base_path, 'runners') + system_packs_base_path = os.path.join(cfg.CONF.system.base_path, "packs") + system_runners_base_path = os.path.join(cfg.CONF.system.base_path, "runners") content_opts = [ cfg.StrOpt( - 'pack_group', default='st2packs', - help='User group that can write to packs directory.'), - cfg.StrOpt( - 'system_packs_base_path', default=system_packs_base_path, - help='Path to the directory which contains system packs.'), - cfg.StrOpt( - 'system_runners_base_path', default=system_runners_base_path, - help='Path to the directory which contains system runners. ' - 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'), - cfg.StrOpt( - 'packs_base_paths', default=None, - help='Paths which will be searched for integration packs.'), - cfg.StrOpt( - 'runners_base_paths', default=None, - help='Paths which will be searched for runners. ' - 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'), + "pack_group", + default="st2packs", + help="User group that can write to packs directory.", + ), + cfg.StrOpt( + "system_packs_base_path", + default=system_packs_base_path, + help="Path to the directory which contains system packs.", + ), + cfg.StrOpt( + "system_runners_base_path", + default=system_runners_base_path, + help="Path to the directory which contains system runners. " + "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0", + ), + cfg.StrOpt( + "packs_base_paths", + default=None, + help="Paths which will be searched for integration packs.", + ), + cfg.StrOpt( + "runners_base_paths", + default=None, + help="Paths which will be searched for runners. " + "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0", + ), cfg.ListOpt( - 'index_url', default=['https://index.stackstorm.org/v1/index.json'], - help='A URL pointing to the pack index. StackStorm Exchange is used by ' - 'default. Use a comma-separated list for multiple indexes if you ' - 'want to get other packs discovered with "st2 pack search".'), + "index_url", + default=["https://index.stackstorm.org/v1/index.json"], + help="A URL pointing to the pack index. StackStorm Exchange is used by " + "default. Use a comma-separated list for multiple indexes if you " + 'want to get other packs discovered with "st2 pack search".', + ), ] - do_register_opts(content_opts, 'content', ignore_errors) + do_register_opts(content_opts, "content", ignore_errors) webui_opts = [ cfg.StrOpt( - 'webui_base_url', default='https://%s' % socket.getfqdn(), - help='Base https URL to access st2 Web UI. This is used to construct history URLs ' - 'that are sent out when chatops is used to kick off executions.') + "webui_base_url", + default="https://%s" % socket.getfqdn(), + help="Base https URL to access st2 Web UI. This is used to construct history URLs " + "that are sent out when chatops is used to kick off executions.", + ) ] - do_register_opts(webui_opts, 'webui', ignore_errors) + do_register_opts(webui_opts, "webui", ignore_errors) db_opts = [ - cfg.StrOpt( - 'host', default='127.0.0.1', - help='host of db server'), + cfg.StrOpt("host", default="127.0.0.1", help="host of db server"), + cfg.IntOpt("port", default=27017, help="port of db server"), + cfg.StrOpt("db_name", default="st2", help="name of database"), + cfg.StrOpt("username", help="username for db login"), + cfg.StrOpt("password", help="password for db login"), cfg.IntOpt( - 'port', default=27017, - help='port of db server'), - cfg.StrOpt( - 'db_name', default='st2', - help='name of database'), - cfg.StrOpt( - 'username', - help='username for db login'), - cfg.StrOpt( - 'password', - help='password for db login'), + "connection_timeout", + default=3 * 1000, + help="Connection and server selection timeout (in ms).", + ), cfg.IntOpt( - 'connection_timeout', default=3 * 1000, - help='Connection and server selection timeout (in ms).'), + "connection_retry_max_delay_m", + default=3, + help="Connection retry total time (minutes).", + ), cfg.IntOpt( - 'connection_retry_max_delay_m', default=3, - help='Connection retry total time (minutes).'), + "connection_retry_backoff_max_s", + default=10, + help="Connection retry backoff max (seconds).", + ), cfg.IntOpt( - 'connection_retry_backoff_max_s', default=10, - help='Connection retry backoff max (seconds).'), - cfg.IntOpt( - 'connection_retry_backoff_mul', default=1, - help='Backoff multiplier (seconds).'), + "connection_retry_backoff_mul", + default=1, + help="Backoff multiplier (seconds).", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Create the connection to mongodb using SSL'), - cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against MongoDB.'), - cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the localconnection'), - cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided'), - cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from MongoDB.'), + "ssl", default=False, help="Create the connection to mongodb using SSL" + ), + cfg.StrOpt( + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against MongoDB.", + ), + cfg.StrOpt( + "ssl_certfile", + default=None, + help="Certificate file used to identify the localconnection", + ), + cfg.StrOpt( + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided", + ), + cfg.StrOpt( + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from MongoDB.", + ), cfg.BoolOpt( - 'ssl_match_hostname', default=True, - help='If True and `ssl_cert_reqs` is not None, enables hostname verification'), - cfg.StrOpt( - 'authentication_mechanism', default=None, - help='Specifies database authentication mechanisms. ' - 'By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, ' - 'MONGODB-CR (MongoDB Challenge Response protocol) for older servers.') + "ssl_match_hostname", + default=True, + help="If True and `ssl_cert_reqs` is not None, enables hostname verification", + ), + cfg.StrOpt( + "authentication_mechanism", + default=None, + help="Specifies database authentication mechanisms. " + "By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, " + "MONGODB-CR (MongoDB Challenge Response protocol) for older servers.", + ), ] - do_register_opts(db_opts, 'database', ignore_errors) + do_register_opts(db_opts, "database", ignore_errors) messaging_opts = [ # It would be nice to be able to deprecate url and completely switch to using # url. However, this will be a breaking change and will have impact so allowing both. cfg.StrOpt( - 'url', default='amqp://guest:guest@127.0.0.1:5672//', - help='URL of the messaging server.'), + "url", + default="amqp://guest:guest@127.0.0.1:5672//", + help="URL of the messaging server.", + ), cfg.ListOpt( - 'cluster_urls', default=[], - help='URL of all the nodes in a messaging service cluster.'), + "cluster_urls", + default=[], + help="URL of all the nodes in a messaging service cluster.", + ), cfg.IntOpt( - 'connection_retries', default=10, - help='How many times should we retry connection before failing.'), + "connection_retries", + default=10, + help="How many times should we retry connection before failing.", + ), cfg.IntOpt( - 'connection_retry_wait', default=10000, - help='How long should we wait between connection retries.'), + "connection_retry_wait", + default=10000, + help="How long should we wait between connection retries.", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Use SSL / TLS to connect to the messaging server. Same as ' - 'appending "?ssl=true" at the end of the connection URL string.'), - cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against RabbitMQ.'), - cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the local connection (client).'), - cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided.'), - cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from RabbitMQ.'), - cfg.StrOpt( - 'login_method', default=None, - help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).') + "ssl", + default=False, + help="Use SSL / TLS to connect to the messaging server. Same as " + 'appending "?ssl=true" at the end of the connection URL string.', + ), + cfg.StrOpt( + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against RabbitMQ.", + ), + cfg.StrOpt( + "ssl_certfile", + default=None, + help="Certificate file used to identify the local connection (client).", + ), + cfg.StrOpt( + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided.", + ), + cfg.StrOpt( + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from RabbitMQ.", + ), + cfg.StrOpt( + "login_method", + default=None, + help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).", + ), ] - do_register_opts(messaging_opts, 'messaging', ignore_errors) + do_register_opts(messaging_opts, "messaging", ignore_errors) syslog_opts = [ + cfg.StrOpt("host", default="127.0.0.1", help="Host for the syslog server."), + cfg.IntOpt("port", default=514, help="Port for the syslog server."), + cfg.StrOpt("facility", default="local7", help="Syslog facility level."), cfg.StrOpt( - 'host', default='127.0.0.1', - help='Host for the syslog server.'), - cfg.IntOpt( - 'port', default=514, - help='Port for the syslog server.'), - cfg.StrOpt( - 'facility', default='local7', - help='Syslog facility level.'), - cfg.StrOpt( - 'protocol', default='udp', - help='Transport protocol to use (udp / tcp).') + "protocol", default="udp", help="Transport protocol to use (udp / tcp)." + ), ] - do_register_opts(syslog_opts, 'syslog', ignore_errors) + do_register_opts(syslog_opts, "syslog", ignore_errors) log_opts = [ - cfg.ListOpt( - 'excludes', default='', - help='Exclusion list of loggers to omit.'), + cfg.ListOpt("excludes", default="", help="Exclusion list of loggers to omit."), cfg.BoolOpt( - 'redirect_stderr', default=False, - help='Controls if stderr should be redirected to the logs.'), + "redirect_stderr", + default=False, + help="Controls if stderr should be redirected to the logs.", + ), cfg.BoolOpt( - 'mask_secrets', default=True, - help='True to mask secrets in the log files.'), + "mask_secrets", default=True, help="True to mask secrets in the log files." + ), cfg.ListOpt( - 'mask_secrets_blacklist', default=[], - help='Blacklist of additional attribute names to mask in the log messages.') + "mask_secrets_blacklist", + default=[], + help="Blacklist of additional attribute names to mask in the log messages.", + ), ] - do_register_opts(log_opts, 'log', ignore_errors) + do_register_opts(log_opts, "log", ignore_errors) # Common API options api_opts = [ - cfg.StrOpt( - 'host', default='127.0.0.1', - help='StackStorm API server host'), - cfg.IntOpt( - 'port', default=9101, - help='StackStorm API server port'), + cfg.StrOpt("host", default="127.0.0.1", help="StackStorm API server host"), + cfg.IntOpt("port", default=9101, help="StackStorm API server port"), cfg.ListOpt( - 'allow_origin', default=['http://127.0.0.1:3000'], - help='List of origins allowed for api, auth and stream'), + "allow_origin", + default=["http://127.0.0.1:3000"], + help="List of origins allowed for api, auth and stream", + ), cfg.BoolOpt( - 'mask_secrets', default=True, - help='True to mask secrets in the API responses') + "mask_secrets", + default=True, + help="True to mask secrets in the API responses", + ), ] - do_register_opts(api_opts, 'api', ignore_errors) + do_register_opts(api_opts, "api", ignore_errors) # Key Value store options keyvalue_opts = [ cfg.BoolOpt( - 'enable_encryption', default=True, - help='Allow encryption of values in key value stored qualified as "secret".'), - cfg.StrOpt( - 'encryption_key_path', default='', - help='Location of the symmetric encryption key for encrypting values in kvstore. ' - 'This key should be in JSON and should\'ve been generated using ' - 'st2-generate-symmetric-crypto-key tool.') + "enable_encryption", + default=True, + help='Allow encryption of values in key value stored qualified as "secret".', + ), + cfg.StrOpt( + "encryption_key_path", + default="", + help="Location of the symmetric encryption key for encrypting values in kvstore. " + "This key should be in JSON and should've been generated using " + "st2-generate-symmetric-crypto-key tool.", + ), ] - do_register_opts(keyvalue_opts, group='keyvalue') + do_register_opts(keyvalue_opts, group="keyvalue") # Common auth options auth_opts = [ cfg.StrOpt( - 'api_url', default=None, - help='Base URL to the API endpoint excluding the version'), - cfg.BoolOpt( - 'enable', default=True, - help='Enable authentication middleware.'), + "api_url", + default=None, + help="Base URL to the API endpoint excluding the version", + ), + cfg.BoolOpt("enable", default=True, help="Enable authentication middleware."), cfg.IntOpt( - 'token_ttl', default=(24 * 60 * 60), - help='Access token ttl in seconds.'), + "token_ttl", default=(24 * 60 * 60), help="Access token ttl in seconds." + ), # This TTL is used for tokens which belong to StackStorm services cfg.IntOpt( - 'service_token_ttl', default=(24 * 60 * 60), - help='Service token ttl in seconds.') + "service_token_ttl", + default=(24 * 60 * 60), + help="Service token ttl in seconds.", + ), ] - do_register_opts(auth_opts, 'auth', ignore_errors) + do_register_opts(auth_opts, "auth", ignore_errors) # Runner options default_python_bin_path = sys.executable base_dir = os.path.dirname(os.path.realpath(default_python_bin_path)) - default_virtualenv_bin_path = os.path.join(base_dir, 'virtualenv') + default_virtualenv_bin_path = os.path.join(base_dir, "virtualenv") action_runner_opts = [ # Common runner options cfg.StrOpt( - 'logging', default='/etc/st2/logging.actionrunner.conf', - help='location of the logging.conf file'), - + "logging", + default="/etc/st2/logging.actionrunner.conf", + help="location of the logging.conf file", + ), # Python runner options cfg.StrOpt( - 'python_binary', default=default_python_bin_path, - help='Python binary which will be used by Python actions.'), - cfg.StrOpt( - 'virtualenv_binary', default=default_virtualenv_bin_path, - help='Virtualenv binary which should be used to create pack virtualenvs.'), - cfg.StrOpt( - 'python_runner_log_level', default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, - help='Default log level to use for Python runner actions. Can be overriden on ' - 'invocation basis using "log_level" runner parameter.'), + "python_binary", + default=default_python_bin_path, + help="Python binary which will be used by Python actions.", + ), + cfg.StrOpt( + "virtualenv_binary", + default=default_virtualenv_bin_path, + help="Virtualenv binary which should be used to create pack virtualenvs.", + ), + cfg.StrOpt( + "python_runner_log_level", + default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + help="Default log level to use for Python runner actions. Can be overriden on " + 'invocation basis using "log_level" runner parameter.', + ), cfg.ListOpt( - 'virtualenv_opts', default=['--system-site-packages'], + "virtualenv_opts", + default=["--system-site-packages"], help='List of virtualenv options to be passsed to "virtualenv" command that ' - 'creates pack virtualenv.'), + "creates pack virtualenv.", + ), cfg.ListOpt( - 'pip_opts', default=[], + "pip_opts", + default=[], help='List of pip options to be passed to "pip install" command when installing pack ' - 'dependencies into pack virtual environment.'), + "dependencies into pack virtual environment.", + ), cfg.BoolOpt( - 'stream_output', default=True, - help='True to store and stream action output (stdout and stderr) in real-time.'), + "stream_output", + default=True, + help="True to store and stream action output (stdout and stderr) in real-time.", + ), cfg.IntOpt( - 'stream_output_buffer_size', default=-1, - help=('Buffer size to use for real time action output streaming. 0 means unbuffered ' - '1 means line buffered, -1 means system default, which usually means fully ' - 'buffered and any other positive value means use a buffer of (approximately) ' - 'that size')) + "stream_output_buffer_size", + default=-1, + help=( + "Buffer size to use for real time action output streaming. 0 means unbuffered " + "1 means line buffered, -1 means system default, which usually means fully " + "buffered and any other positive value means use a buffer of (approximately) " + "that size" + ), + ), ] - do_register_opts(action_runner_opts, group='actionrunner') + do_register_opts(action_runner_opts, group="actionrunner") dispatcher_pool_opts = [ cfg.IntOpt( - 'workflows_pool_size', default=40, - help='Internal pool size for dispatcher used by workflow actions.'), + "workflows_pool_size", + default=40, + help="Internal pool size for dispatcher used by workflow actions.", + ), cfg.IntOpt( - 'actions_pool_size', default=60, - help='Internal pool size for dispatcher used by regular actions.') + "actions_pool_size", + default=60, + help="Internal pool size for dispatcher used by regular actions.", + ), ] - do_register_opts(dispatcher_pool_opts, group='actionrunner') + do_register_opts(dispatcher_pool_opts, group="actionrunner") ssh_runner_opts = [ cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.IntOpt( - 'max_parallel_actions', default=50, - help='Max number of parallel remote SSH actions that should be run. ' - 'Works only with Paramiko SSH runner.'), + "max_parallel_actions", + default=50, + help="Max number of parallel remote SSH actions that should be run. " + "Works only with Paramiko SSH runner.", + ), cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.'), - cfg.StrOpt( - 'ssh_config_file_path', default='~/.ssh/config', - help='Path to the ssh config file.'), + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), + cfg.StrOpt( + "ssh_config_file_path", + default="~/.ssh/config", + help="Path to the ssh config file.", + ), cfg.IntOpt( - 'ssh_connect_timeout', default=60, - help='Max time in seconds to establish the SSH connection.') + "ssh_connect_timeout", + default=60, + help="Max time in seconds to establish the SSH connection.", + ), ] - do_register_opts(ssh_runner_opts, group='ssh_runner') + do_register_opts(ssh_runner_opts, group="ssh_runner") # Common options (used by action runner and sensor container) action_sensor_opts = [ cfg.BoolOpt( - 'enable', default=True, - help='Whether to enable or disable the ability to post a trigger on action.'), + "enable", + default=True, + help="Whether to enable or disable the ability to post a trigger on action.", + ), cfg.ListOpt( - 'emit_when', default=LIVEACTION_COMPLETED_STATES, - help='List of execution statuses for which a trigger will be emitted. ') + "emit_when", + default=LIVEACTION_COMPLETED_STATES, + help="List of execution statuses for which a trigger will be emitted. ", + ), ] - do_register_opts(action_sensor_opts, group='action_sensor') + do_register_opts(action_sensor_opts, group="action_sensor") # Common options for content pack_lib_opts = [ cfg.BoolOpt( - 'enable_common_libs', default=False, - help='Enable/Disable support for pack common libs. ' - 'Setting this config to ``True`` would allow you to ' - 'place common library code for sensors and actions in lib/ folder ' - 'in packs and use them in python sensors and actions. ' - 'See https://docs.stackstorm.com/reference/' - 'sharing_code_sensors_actions.html ' - 'for details.') + "enable_common_libs", + default=False, + help="Enable/Disable support for pack common libs. " + "Setting this config to ``True`` would allow you to " + "place common library code for sensors and actions in lib/ folder " + "in packs and use them in python sensors and actions. " + "See https://docs.stackstorm.com/reference/" + "sharing_code_sensors_actions.html " + "for details.", + ) ] - do_register_opts(pack_lib_opts, group='packs') + do_register_opts(pack_lib_opts, group="packs") # Coordination options coord_opts = [ - cfg.StrOpt( - 'url', default=None, - help='Endpoint for the coordination server.'), + cfg.StrOpt("url", default=None, help="Endpoint for the coordination server."), cfg.IntOpt( - 'lock_timeout', default=60, - help='TTL for the lock if backend suports it.'), + "lock_timeout", default=60, help="TTL for the lock if backend suports it." + ), cfg.BoolOpt( - 'service_registry', default=False, - help='True to register StackStorm services in a service registry.'), + "service_registry", + default=False, + help="True to register StackStorm services in a service registry.", + ), ] - do_register_opts(coord_opts, 'coordination', ignore_errors) + do_register_opts(coord_opts, "coordination", ignore_errors) # XXX: This is required for us to support deprecated config group results_tracker query_opts = [ cfg.IntOpt( - 'thread_pool_size', - help='Number of threads to use to query external workflow systems.'), + "thread_pool_size", + help="Number of threads to use to query external workflow systems.", + ), cfg.FloatOpt( - 'query_interval', - help='Time interval between subsequent queries for a context ' - 'to external workflow system.') + "query_interval", + help="Time interval between subsequent queries for a context " + "to external workflow system.", + ), ] - do_register_opts(query_opts, group='results_tracker', ignore_errors=ignore_errors) + do_register_opts(query_opts, group="results_tracker", ignore_errors=ignore_errors) # Common stream options stream_opts = [ cfg.IntOpt( - 'heartbeat', default=25, - help='Send empty message every N seconds to keep connection open') + "heartbeat", + default=25, + help="Send empty message every N seconds to keep connection open", + ) ] - do_register_opts(stream_opts, group='stream', ignore_errors=ignore_errors) + do_register_opts(stream_opts, group="stream", ignore_errors=ignore_errors) # Common CLI options cli_opts = [ cfg.BoolOpt( - 'debug', default=False, - help='Enable debug mode. By default this will set all log levels to DEBUG.'), + "debug", + default=False, + help="Enable debug mode. By default this will set all log levels to DEBUG.", + ), cfg.BoolOpt( - 'profile', default=False, - help='Enable profile mode. In the profile mode all the MongoDB queries and ' - 'related profile data are logged.'), + "profile", + default=False, + help="Enable profile mode. In the profile mode all the MongoDB queries and " + "related profile data are logged.", + ), cfg.BoolOpt( - 'use-debugger', default=True, - help='Enables debugger. Note that using this option changes how the ' - 'eventlet library is used to support async IO. This could result in ' - 'failures that do not occur under normal operation.') + "use-debugger", + default=True, + help="Enables debugger. Note that using this option changes how the " + "eventlet library is used to support async IO. This could result in " + "failures that do not occur under normal operation.", + ), ] do_register_cli_opts(cli_opts, ignore_errors=ignore_errors) @@ -505,92 +603,121 @@ def register_opts(ignore_errors=False): # Metrics Options stream options metrics_opts = [ cfg.StrOpt( - 'driver', default='noop', - help='Driver type for metrics collection.'), + "driver", default="noop", help="Driver type for metrics collection." + ), cfg.StrOpt( - 'host', default='127.0.0.1', - help='Destination server to connect to if driver requires connection.'), + "host", + default="127.0.0.1", + help="Destination server to connect to if driver requires connection.", + ), cfg.IntOpt( - 'port', default=8125, - help='Destination port to connect to if driver requires connection.'), - cfg.StrOpt( - 'prefix', default=None, - help='Optional prefix which is prepended to all the metric names. Comes handy when ' - 'you want to submit metrics from various environment to the same metric ' - 'backend instance.'), + "port", + default=8125, + help="Destination port to connect to if driver requires connection.", + ), + cfg.StrOpt( + "prefix", + default=None, + help="Optional prefix which is prepended to all the metric names. Comes handy when " + "you want to submit metrics from various environment to the same metric " + "backend instance.", + ), cfg.FloatOpt( - 'sample_rate', default=1, - help='Randomly sample and only send metrics for X% of metric operations to the ' - 'backend. Default value of 1 means no sampling is done and all the metrics are ' - 'sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.') - + "sample_rate", + default=1, + help="Randomly sample and only send metrics for X% of metric operations to the " + "backend. Default value of 1 means no sampling is done and all the metrics are " + "sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.", + ), ] - do_register_opts(metrics_opts, group='metrics', ignore_errors=ignore_errors) + do_register_opts(metrics_opts, group="metrics", ignore_errors=ignore_errors) # Common timers engine options timer_logging_opts = [ cfg.StrOpt( - 'logging', default=None, - help='Location of the logging configuration file. ' - 'NOTE: Deprecated in favor of timersengine.logging'), + "logging", + default=None, + help="Location of the logging configuration file. " + "NOTE: Deprecated in favor of timersengine.logging", + ), ] timers_engine_logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.timersengine.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.timersengine.conf", + help="Location of the logging configuration file.", + ) ] - do_register_opts(timer_logging_opts, group='timer', ignore_errors=ignore_errors) - do_register_opts(timers_engine_logging_opts, group='timersengine', ignore_errors=ignore_errors) + do_register_opts(timer_logging_opts, group="timer", ignore_errors=ignore_errors) + do_register_opts( + timers_engine_logging_opts, group="timersengine", ignore_errors=ignore_errors + ) # NOTE: We default old style deprecated "timer" options to None so our code # works correclty and "timersengine" has precedence over "timers" # NOTE: "timer" section will be removed in v3.1 timer_opts = [ cfg.StrOpt( - 'local_timezone', default=None, - help='Timezone pertaining to the location where st2 is run. ' - 'NOTE: Deprecated in favor of timersengine.local_timezone'), + "local_timezone", + default=None, + help="Timezone pertaining to the location where st2 is run. " + "NOTE: Deprecated in favor of timersengine.local_timezone", + ), cfg.BoolOpt( - 'enable', default=None, - help='Specify to enable timer service. ' - 'NOTE: Deprecated in favor of timersengine.enable'), + "enable", + default=None, + help="Specify to enable timer service. " + "NOTE: Deprecated in favor of timersengine.enable", + ), ] timers_engine_opts = [ cfg.StrOpt( - 'local_timezone', default='America/Los_Angeles', - help='Timezone pertaining to the location where st2 is run.'), - cfg.BoolOpt( - 'enable', default=True, - help='Specify to enable timer service.') + "local_timezone", + default="America/Los_Angeles", + help="Timezone pertaining to the location where st2 is run.", + ), + cfg.BoolOpt("enable", default=True, help="Specify to enable timer service."), ] - do_register_opts(timer_opts, group='timer', ignore_errors=ignore_errors) - do_register_opts(timers_engine_opts, group='timersengine', ignore_errors=ignore_errors) + do_register_opts(timer_opts, group="timer", ignore_errors=ignore_errors) + do_register_opts( + timers_engine_opts, group="timersengine", ignore_errors=ignore_errors + ) # Workflow engine options workflow_engine_opts = [ cfg.IntOpt( - 'retry_stop_max_msec', default=60000, - help='Max time to stop retrying.'), + "retry_stop_max_msec", default=60000, help="Max time to stop retrying." + ), cfg.IntOpt( - 'retry_wait_fixed_msec', default=1000, - help='Interval inbetween retries.'), + "retry_wait_fixed_msec", default=1000, help="Interval inbetween retries." + ), cfg.FloatOpt( - 'retry_max_jitter_msec', default=1000, - help='Max jitter interval to smooth out retries.'), + "retry_max_jitter_msec", + default=1000, + help="Max jitter interval to smooth out retries.", + ), cfg.IntOpt( - 'gc_max_idle_sec', default=0, - help='Max seconds to allow workflow execution be idled before it is identified as ' - 'orphaned and cancelled by the garbage collector. A value of zero means the ' - 'feature is disabled. This is disabled by default.') + "gc_max_idle_sec", + default=0, + help="Max seconds to allow workflow execution be idled before it is identified as " + "orphaned and cancelled by the garbage collector. A value of zero means the " + "feature is disabled. This is disabled by default.", + ), ] - do_register_opts(workflow_engine_opts, group='workflow_engine', ignore_errors=ignore_errors) + do_register_opts( + workflow_engine_opts, group="workflow_engine", ignore_errors=ignore_errors + ) def parse_args(args=None): register_opts() - cfg.CONF(args=args, version=VERSION_STRING, default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) diff --git a/st2common/st2common/constants/action.py b/st2common/st2common/constants/action.py index c28725f225..5587b0be91 100644 --- a/st2common/st2common/constants/action.py +++ b/st2common/st2common/constants/action.py @@ -14,61 +14,56 @@ # limitations under the License. __all__ = [ - 'ACTION_NAME', - 'ACTION_ID', - - 'LIBS_DIR', - - 'LIVEACTION_STATUS_REQUESTED', - 'LIVEACTION_STATUS_SCHEDULED', - 'LIVEACTION_STATUS_DELAYED', - 'LIVEACTION_STATUS_RUNNING', - 'LIVEACTION_STATUS_SUCCEEDED', - 'LIVEACTION_STATUS_FAILED', - 'LIVEACTION_STATUS_TIMED_OUT', - 'LIVEACTION_STATUS_CANCELING', - 'LIVEACTION_STATUS_CANCELED', - 'LIVEACTION_STATUS_PENDING', - 'LIVEACTION_STATUS_PAUSING', - 'LIVEACTION_STATUS_PAUSED', - 'LIVEACTION_STATUS_RESUMING', - - 'LIVEACTION_STATUSES', - 'LIVEACTION_RUNNABLE_STATES', - 'LIVEACTION_DELAYED_STATES', - 'LIVEACTION_CANCELABLE_STATES', - 'LIVEACTION_FAILED_STATES', - 'LIVEACTION_COMPLETED_STATES', - - 'ACTION_OUTPUT_RESULT_DELIMITER', - 'ACTION_CONTEXT_KV_PREFIX', - 'ACTION_PARAMETERS_KV_PREFIX', - 'ACTION_RESULTS_KV_PREFIX', - - 'WORKFLOW_RUNNER_TYPES' + "ACTION_NAME", + "ACTION_ID", + "LIBS_DIR", + "LIVEACTION_STATUS_REQUESTED", + "LIVEACTION_STATUS_SCHEDULED", + "LIVEACTION_STATUS_DELAYED", + "LIVEACTION_STATUS_RUNNING", + "LIVEACTION_STATUS_SUCCEEDED", + "LIVEACTION_STATUS_FAILED", + "LIVEACTION_STATUS_TIMED_OUT", + "LIVEACTION_STATUS_CANCELING", + "LIVEACTION_STATUS_CANCELED", + "LIVEACTION_STATUS_PENDING", + "LIVEACTION_STATUS_PAUSING", + "LIVEACTION_STATUS_PAUSED", + "LIVEACTION_STATUS_RESUMING", + "LIVEACTION_STATUSES", + "LIVEACTION_RUNNABLE_STATES", + "LIVEACTION_DELAYED_STATES", + "LIVEACTION_CANCELABLE_STATES", + "LIVEACTION_FAILED_STATES", + "LIVEACTION_COMPLETED_STATES", + "ACTION_OUTPUT_RESULT_DELIMITER", + "ACTION_CONTEXT_KV_PREFIX", + "ACTION_PARAMETERS_KV_PREFIX", + "ACTION_RESULTS_KV_PREFIX", + "WORKFLOW_RUNNER_TYPES", ] -ACTION_NAME = 'name' -ACTION_ID = 'id' -ACTION_PACK = 'pack' +ACTION_NAME = "name" +ACTION_ID = "id" +ACTION_PACK = "pack" -LIBS_DIR = 'lib' +LIBS_DIR = "lib" -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' -LIVEACTION_STATUS_PENDING = 'pending' -LIVEACTION_STATUS_PAUSING = 'pausing' -LIVEACTION_STATUS_PAUSED = 'paused' -LIVEACTION_STATUS_RESUMING = 'resuming' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" +LIVEACTION_STATUS_PENDING = "pending" +LIVEACTION_STATUS_PAUSING = "pausing" +LIVEACTION_STATUS_PAUSED = "paused" +LIVEACTION_STATUS_RESUMING = "resuming" LIVEACTION_STATUSES = [ LIVEACTION_STATUS_REQUESTED, @@ -84,25 +79,23 @@ LIVEACTION_STATUS_PENDING, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] -ACTION_OUTPUT_RESULT_DELIMITER = '%%%%%~=~=~=************=~=~=~%%%%' -ACTION_CONTEXT_KV_PREFIX = 'action_context' -ACTION_PARAMETERS_KV_PREFIX = 'action_parameters' -ACTION_RESULTS_KV_PREFIX = 'action_results' +ACTION_OUTPUT_RESULT_DELIMITER = "%%%%%~=~=~=************=~=~=~%%%%" +ACTION_CONTEXT_KV_PREFIX = "action_context" +ACTION_PARAMETERS_KV_PREFIX = "action_parameters" +ACTION_RESULTS_KV_PREFIX = "action_results" LIVEACTION_RUNNABLE_STATES = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] -LIVEACTION_DELAYED_STATES = [ - LIVEACTION_STATUS_DELAYED -] +LIVEACTION_DELAYED_STATES = [LIVEACTION_STATUS_DELAYED] LIVEACTION_CANCELABLE_STATES = [ LIVEACTION_STATUS_REQUESTED, @@ -111,7 +104,7 @@ LIVEACTION_STATUS_RUNNING, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] LIVEACTION_COMPLETED_STATES = [ @@ -119,29 +112,20 @@ LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] LIVEACTION_FAILED_STATES = [ LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] -LIVEACTION_PAUSE_STATES = [ - LIVEACTION_STATUS_PAUSING, - LIVEACTION_STATUS_PAUSED -] +LIVEACTION_PAUSE_STATES = [LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED] -LIVEACTION_CANCEL_STATES = [ - LIVEACTION_STATUS_CANCELING, - LIVEACTION_STATUS_CANCELED -] +LIVEACTION_CANCEL_STATES = [LIVEACTION_STATUS_CANCELING, LIVEACTION_STATUS_CANCELED] -WORKFLOW_RUNNER_TYPES = [ - 'action-chain', - 'orquesta' -] +WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"] # Linux's limit for param size _LINUX_PARAM_LIMIT = 131072 diff --git a/st2common/st2common/constants/api.py b/st2common/st2common/constants/api.py index c1df81fb0d..2690133314 100644 --- a/st2common/st2common/constants/api.py +++ b/st2common/st2common/constants/api.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'DEFAULT_API_VERSION' -] +__all__ = ["DEFAULT_API_VERSION"] -DEFAULT_API_VERSION = 'v1' +DEFAULT_API_VERSION = "v1" -REQUEST_ID_HEADER = 'X-Request-ID' +REQUEST_ID_HEADER = "X-Request-ID" diff --git a/st2common/st2common/constants/auth.py b/st2common/st2common/constants/auth.py index f0664739ce..7b4003c0ef 100644 --- a/st2common/st2common/constants/auth.py +++ b/st2common/st2common/constants/auth.py @@ -14,26 +14,22 @@ # limitations under the License. __all__ = [ - 'VALID_MODES', - 'DEFAULT_MODE', - 'DEFAULT_BACKEND', - - 'HEADER_ATTRIBUTE_NAME', - 'QUERY_PARAM_ATTRIBUTE_NAME' + "VALID_MODES", + "DEFAULT_MODE", + "DEFAULT_BACKEND", + "HEADER_ATTRIBUTE_NAME", + "QUERY_PARAM_ATTRIBUTE_NAME", ] -VALID_MODES = [ - 'proxy', - 'standalone' -] +VALID_MODES = ["proxy", "standalone"] -HEADER_ATTRIBUTE_NAME = 'X-Auth-Token' -QUERY_PARAM_ATTRIBUTE_NAME = 'x-auth-token' +HEADER_ATTRIBUTE_NAME = "X-Auth-Token" +QUERY_PARAM_ATTRIBUTE_NAME = "x-auth-token" -HEADER_API_KEY_ATTRIBUTE_NAME = 'St2-Api-Key' -QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = 'st2-api-key' +HEADER_API_KEY_ATTRIBUTE_NAME = "St2-Api-Key" +QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = "st2-api-key" -DEFAULT_MODE = 'standalone' +DEFAULT_MODE = "standalone" -DEFAULT_BACKEND = 'flat_file' -DEFAULT_SSO_BACKEND = 'noop' +DEFAULT_BACKEND = "flat_file" +DEFAULT_SSO_BACKEND = "noop" diff --git a/st2common/st2common/constants/error_messages.py b/st2common/st2common/constants/error_messages.py index 7aa56c4025..7c70377721 100644 --- a/st2common/st2common/constants/error_messages.py +++ b/st2common/st2common/constants/error_messages.py @@ -13,21 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'PACK_VIRTUALENV_DOESNT_EXIST', - 'PYTHON2_DEPRECATION' -] +__all__ = ["PACK_VIRTUALENV_DOESNT_EXIST", "PYTHON2_DEPRECATION"] -PACK_VIRTUALENV_DOESNT_EXIST = ''' +PACK_VIRTUALENV_DOESNT_EXIST = """ The virtual environment (%(virtualenv_path)s) for pack "%(pack)s" does not exist. Normally this is created when you install a pack using "st2 pack install". If you installed your pack by some other means, you can create a new virtual environment using the command: "st2 run packs.setup_virtualenv packs=%(pack)s" -''' +""" -PYTHON2_DEPRECATION = 'DEPRECATION WARNING. Support for python 2 will be removed in future ' \ - 'StackStorm releases. Please ensure that all packs used are python ' \ - '3 compatible. Your StackStorm installation may be upgraded from ' \ - 'python 2 to python 3 in future platform releases. It is recommended ' \ - 'to plan the manual migration to a python 3 native platform, e.g. ' \ - 'Ubuntu 18.04 LTS or CentOS/RHEL 8.' +PYTHON2_DEPRECATION = ( + "DEPRECATION WARNING. Support for python 2 will be removed in future " + "StackStorm releases. Please ensure that all packs used are python " + "3 compatible. Your StackStorm installation may be upgraded from " + "python 2 to python 3 in future platform releases. It is recommended " + "to plan the manual migration to a python 3 native platform, e.g. " + "Ubuntu 18.04 LTS or CentOS/RHEL 8." +) diff --git a/st2common/st2common/constants/exit_codes.py b/st2common/st2common/constants/exit_codes.py index 8fd1efd9a7..1b32e89e26 100644 --- a/st2common/st2common/constants/exit_codes.py +++ b/st2common/st2common/constants/exit_codes.py @@ -14,10 +14,10 @@ # limitations under the License. __all__ = [ - 'SUCCESS_EXIT_CODE', - 'FAILURE_EXIT_CODE', - 'SIGKILL_EXIT_CODE', - 'SIGTERM_EXIT_CODE' + "SUCCESS_EXIT_CODE", + "FAILURE_EXIT_CODE", + "SIGKILL_EXIT_CODE", + "SIGTERM_EXIT_CODE", ] SUCCESS_EXIT_CODE = 0 diff --git a/st2common/st2common/constants/garbage_collection.py b/st2common/st2common/constants/garbage_collection.py index dad3121896..ac8a2aac5f 100644 --- a/st2common/st2common/constants/garbage_collection.py +++ b/st2common/st2common/constants/garbage_collection.py @@ -14,10 +14,10 @@ # limitations under the License. __all__ = [ - 'DEFAULT_COLLECTION_INTERVAL', - 'DEFAULT_SLEEP_DELAY', - 'MINIMUM_TTL_DAYS', - 'MINIMUM_TTL_DAYS_EXECUTION_OUTPUT' + "DEFAULT_COLLECTION_INTERVAL", + "DEFAULT_SLEEP_DELAY", + "MINIMUM_TTL_DAYS", + "MINIMUM_TTL_DAYS_EXECUTION_OUTPUT", ] diff --git a/st2common/st2common/constants/keyvalue.py b/st2common/st2common/constants/keyvalue.py index 2897f1e32d..7a21eab8ec 100644 --- a/st2common/st2common/constants/keyvalue.py +++ b/st2common/st2common/constants/keyvalue.py @@ -14,46 +14,49 @@ # limitations under the License. __all__ = [ - 'ALLOWED_SCOPES', - 'SYSTEM_SCOPE', - 'FULL_SYSTEM_SCOPE', - 'SYSTEM_SCOPES', - 'USER_SCOPE', - 'FULL_USER_SCOPE', - 'USER_SCOPES', - 'USER_SEPARATOR', - - 'DATASTORE_SCOPE_SEPARATOR', - 'DATASTORE_KEY_SEPARATOR' + "ALLOWED_SCOPES", + "SYSTEM_SCOPE", + "FULL_SYSTEM_SCOPE", + "SYSTEM_SCOPES", + "USER_SCOPE", + "FULL_USER_SCOPE", + "USER_SCOPES", + "USER_SEPARATOR", + "DATASTORE_SCOPE_SEPARATOR", + "DATASTORE_KEY_SEPARATOR", ] -ALL_SCOPE = 'all' +ALL_SCOPE = "all" # Parent namespace for all items in key-value store -DATASTORE_PARENT_SCOPE = 'st2kv' -DATASTORE_SCOPE_SEPARATOR = '.' # To separate scope from datastore namespace. E.g. st2kv.system +DATASTORE_PARENT_SCOPE = "st2kv" +DATASTORE_SCOPE_SEPARATOR = ( + "." # To separate scope from datastore namespace. E.g. st2kv.system +) # Namespace to contain all system/global scoped variables in key-value store. -SYSTEM_SCOPE = 'system' -FULL_SYSTEM_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, SYSTEM_SCOPE) +SYSTEM_SCOPE = "system" +FULL_SYSTEM_SCOPE = "%s%s%s" % ( + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, + SYSTEM_SCOPE, +) SYSTEM_SCOPES = [SYSTEM_SCOPE] # Namespace to contain all user scoped variables in key-value store. -USER_SCOPE = 'user' -FULL_USER_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, USER_SCOPE) +USER_SCOPE = "user" +FULL_USER_SCOPE = "%s%s%s" % ( + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, + USER_SCOPE, +) USER_SCOPES = [USER_SCOPE] -USER_SEPARATOR = ':' +USER_SEPARATOR = ":" # Separator for keys in the datastore -DATASTORE_KEY_SEPARATOR = ':' - -ALLOWED_SCOPES = [ - SYSTEM_SCOPE, - USER_SCOPE, +DATASTORE_KEY_SEPARATOR = ":" - FULL_SYSTEM_SCOPE, - FULL_USER_SCOPE -] +ALLOWED_SCOPES = [SYSTEM_SCOPE, USER_SCOPE, FULL_SYSTEM_SCOPE, FULL_USER_SCOPE] diff --git a/st2common/st2common/constants/logging.py b/st2common/st2common/constants/logging.py index b62a59bd00..0985a03947 100644 --- a/st2common/st2common/constants/logging.py +++ b/st2common/st2common/constants/logging.py @@ -16,11 +16,9 @@ from __future__ import absolute_import import os -__all__ = [ - 'DEFAULT_LOGGING_CONF_PATH' -] +__all__ = ["DEFAULT_LOGGING_CONF_PATH"] BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, '../conf/base.logging.conf') +DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, "../conf/base.logging.conf") DEFAULT_LOGGING_CONF_PATH = os.path.abspath(DEFAULT_LOGGING_CONF_PATH) diff --git a/st2common/st2common/constants/meta.py b/st2common/st2common/constants/meta.py index ac4859b5e1..acd348a355 100644 --- a/st2common/st2common/constants/meta.py +++ b/st2common/st2common/constants/meta.py @@ -16,10 +16,7 @@ from __future__ import absolute_import import yaml -__all__ = [ - 'ALLOWED_EXTS', - 'PARSER_FUNCS' -] +__all__ = ["ALLOWED_EXTS", "PARSER_FUNCS"] -ALLOWED_EXTS = ['.yaml', '.yml'] -PARSER_FUNCS = {'.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".yaml", ".yml"] +PARSER_FUNCS = {".yml": yaml.safe_load, ".yaml": yaml.safe_load} diff --git a/st2common/st2common/constants/pack.py b/st2common/st2common/constants/pack.py index 91ae5a5e2c..f782a6920c 100644 --- a/st2common/st2common/constants/pack.py +++ b/st2common/st2common/constants/pack.py @@ -14,81 +14,74 @@ # limitations under the License. __all__ = [ - 'PACKS_PACK_NAME', - 'PACK_REF_WHITELIST_REGEX', - 'PACK_RESERVED_CHARACTERS', - 'PACK_VERSION_SEPARATOR', - 'PACK_VERSION_REGEX', - 'ST2_VERSION_REGEX', - 'SYSTEM_PACK_NAME', - 'PACKS_PACK_NAME', - 'LINUX_PACK_NAME', - 'SYSTEM_PACK_NAMES', - 'CHATOPS_PACK_NAME', - 'USER_PACK_NAME_BLACKLIST', - 'BASE_PACK_REQUIREMENTS', - 'MANIFEST_FILE_NAME', - 'CONFIG_SCHEMA_FILE_NAME' + "PACKS_PACK_NAME", + "PACK_REF_WHITELIST_REGEX", + "PACK_RESERVED_CHARACTERS", + "PACK_VERSION_SEPARATOR", + "PACK_VERSION_REGEX", + "ST2_VERSION_REGEX", + "SYSTEM_PACK_NAME", + "PACKS_PACK_NAME", + "LINUX_PACK_NAME", + "SYSTEM_PACK_NAMES", + "CHATOPS_PACK_NAME", + "USER_PACK_NAME_BLACKLIST", + "BASE_PACK_REQUIREMENTS", + "MANIFEST_FILE_NAME", + "CONFIG_SCHEMA_FILE_NAME", ] # Prefix for render context w/ config -PACK_CONFIG_CONTEXT_KV_PREFIX = 'config_context' +PACK_CONFIG_CONTEXT_KV_PREFIX = "config_context" # A list of allowed characters for the pack name -PACK_REF_WHITELIST_REGEX = r'^[a-z0-9_]+$' +PACK_REF_WHITELIST_REGEX = r"^[a-z0-9_]+$" # Check for a valid semver string -PACK_VERSION_REGEX = r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$' # noqa +PACK_VERSION_REGEX = r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$" # noqa # Special characters which can't be used in pack names -PACK_RESERVED_CHARACTERS = [ - '.' -] +PACK_RESERVED_CHARACTERS = ["."] # Version sperator when version is supplied in pack name # Example: libcloud@1.0.1 -PACK_VERSION_SEPARATOR = '=' +PACK_VERSION_SEPARATOR = "=" # Check for st2 version in engines -ST2_VERSION_REGEX = r'^((>?>|>=|=|<=|?>|>=|=|<=|=1.9.0,<2.0' -] +BASE_PACK_REQUIREMENTS = ["six>=1.9.0,<2.0"] # Name of the pack manifest file -MANIFEST_FILE_NAME = 'pack.yaml' +MANIFEST_FILE_NAME = "pack.yaml" # File name for the config schema file -CONFIG_SCHEMA_FILE_NAME = 'config.schema.yaml' +CONFIG_SCHEMA_FILE_NAME = "config.schema.yaml" diff --git a/st2common/st2common/constants/policy.py b/st2common/st2common/constants/policy.py index e36ce8fc12..7ce7093ed5 100644 --- a/st2common/st2common/constants/policy.py +++ b/st2common/st2common/constants/policy.py @@ -13,13 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'POLICY_TYPES_REQUIRING_LOCK' -] +__all__ = ["POLICY_TYPES_REQUIRING_LOCK"] # Concurrency policies require scheduler to acquire a distributed lock to prevent race # in scheduling when there are multiple scheduler instances. -POLICY_TYPES_REQUIRING_LOCK = [ - 'action.concurrency', - 'action.concurrency.attr' -] +POLICY_TYPES_REQUIRING_LOCK = ["action.concurrency", "action.concurrency.attr"] diff --git a/st2common/st2common/constants/rule_enforcement.py b/st2common/st2common/constants/rule_enforcement.py index fced450304..ceece2d6e1 100644 --- a/st2common/st2common/constants/rule_enforcement.py +++ b/st2common/st2common/constants/rule_enforcement.py @@ -14,16 +14,15 @@ # limitations under the License. __all__ = [ - 'RULE_ENFORCEMENT_STATUS_SUCCEEDED', - 'RULE_ENFORCEMENT_STATUS_FAILED', - - 'RULE_ENFORCEMENT_STATUSES' + "RULE_ENFORCEMENT_STATUS_SUCCEEDED", + "RULE_ENFORCEMENT_STATUS_FAILED", + "RULE_ENFORCEMENT_STATUSES", ] -RULE_ENFORCEMENT_STATUS_SUCCEEDED = 'succeeded' -RULE_ENFORCEMENT_STATUS_FAILED = 'failed' +RULE_ENFORCEMENT_STATUS_SUCCEEDED = "succeeded" +RULE_ENFORCEMENT_STATUS_FAILED = "failed" RULE_ENFORCEMENT_STATUSES = [ RULE_ENFORCEMENT_STATUS_SUCCEEDED, - RULE_ENFORCEMENT_STATUS_FAILED + RULE_ENFORCEMENT_STATUS_FAILED, ] diff --git a/st2common/st2common/constants/rules.py b/st2common/st2common/constants/rules.py index 393e94aebb..929e4b5e92 100644 --- a/st2common/st2common/constants/rules.py +++ b/st2common/st2common/constants/rules.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -TRIGGER_PAYLOAD_PREFIX = 'trigger' -TRIGGER_ITEM_PAYLOAD_PREFIX = 'item' +TRIGGER_PAYLOAD_PREFIX = "trigger" +TRIGGER_ITEM_PAYLOAD_PREFIX = "item" -RULE_TYPE_STANDARD = 'standard' -RULE_TYPE_BACKSTOP = 'backstop' +RULE_TYPE_STANDARD = "standard" +RULE_TYPE_BACKSTOP = "backstop" -MATCH_CRITERIA = r'({{)\s*(.*)\s*(}})' +MATCH_CRITERIA = r"({{)\s*(.*)\s*(}})" diff --git a/st2common/st2common/constants/runners.py b/st2common/st2common/constants/runners.py index fe78a6497f..52ec738384 100644 --- a/st2common/st2common/constants/runners.py +++ b/st2common/st2common/constants/runners.py @@ -17,36 +17,28 @@ from oslo_config import cfg __all__ = [ - 'RUNNER_NAME_WHITELIST', - - 'MANIFEST_FILE_NAME', - - 'LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT', - - 'REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT', - 'REMOTE_RUNNER_DEFAULT_REMOTE_DIR', - 'REMOTE_RUNNER_PRIVATE_KEY_HEADER', - - 'PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT', - 'PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE', - - 'WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT', - - 'COMMON_ACTION_ENV_VARIABLE_PREFIX', - 'COMMON_ACTION_ENV_VARIABLES', - - 'DEFAULT_SSH_PORT', - - 'RUNNERS_NAMESPACE' + "RUNNER_NAME_WHITELIST", + "MANIFEST_FILE_NAME", + "LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT", + "REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT", + "REMOTE_RUNNER_DEFAULT_REMOTE_DIR", + "REMOTE_RUNNER_PRIVATE_KEY_HEADER", + "PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT", + "PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE", + "WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT", + "COMMON_ACTION_ENV_VARIABLE_PREFIX", + "COMMON_ACTION_ENV_VARIABLES", + "DEFAULT_SSH_PORT", + "RUNNERS_NAMESPACE", ] DEFAULT_SSH_PORT = 22 # A list of allowed characters for the pack name -RUNNER_NAME_WHITELIST = r'^[A-Za-z0-9_-]+' +RUNNER_NAME_WHITELIST = r"^[A-Za-z0-9_-]+" # Manifest file name for runners -MANIFEST_FILE_NAME = 'runner.yaml' +MANIFEST_FILE_NAME = "runner.yaml" # Local runner LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT = 60 @@ -57,9 +49,9 @@ try: REMOTE_RUNNER_DEFAULT_REMOTE_DIR = cfg.CONF.ssh_runner.remote_dir except: - REMOTE_RUNNER_DEFAULT_REMOTE_DIR = '/tmp' + REMOTE_RUNNER_DEFAULT_REMOTE_DIR = "/tmp" -REMOTE_RUNNER_PRIVATE_KEY_HEADER = 'PRIVATE KEY-----'.lower() +REMOTE_RUNNER_PRIVATE_KEY_HEADER = "PRIVATE KEY-----".lower() # Python runner # Default timeout (in seconds) for actions executed by Python runner @@ -69,20 +61,20 @@ # action returns invalid status from the run() method PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE = 220 -PYTHON_RUNNER_DEFAULT_LOG_LEVEL = 'DEBUG' +PYTHON_RUNNER_DEFAULT_LOG_LEVEL = "DEBUG" # Windows runner WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT = 10 * 60 # Prefix for common st2 environment variables which are available to the actions -COMMON_ACTION_ENV_VARIABLE_PREFIX = 'ST2_ACTION_' +COMMON_ACTION_ENV_VARIABLE_PREFIX = "ST2_ACTION_" # Common st2 environment variables which are available to the actions COMMON_ACTION_ENV_VARIABLES = [ - 'ST2_ACTION_PACK_NAME', - 'ST2_ACTION_EXECUTION_ID', - 'ST2_ACTION_API_URL', - 'ST2_ACTION_AUTH_TOKEN' + "ST2_ACTION_PACK_NAME", + "ST2_ACTION_EXECUTION_ID", + "ST2_ACTION_API_URL", + "ST2_ACTION_AUTH_TOKEN", ] # Namespaces for dynamically loaded runner modules diff --git a/st2common/st2common/constants/scheduler.py b/st2common/st2common/constants/scheduler.py index d825d2aed0..fb97971a3c 100644 --- a/st2common/st2common/constants/scheduler.py +++ b/st2common/st2common/constants/scheduler.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'SCHEDULER_ENABLED_LOG_LINE', - 'SCHEDULER_DISABLED_LOG_LINE' -] +__all__ = ["SCHEDULER_ENABLED_LOG_LINE", "SCHEDULER_DISABLED_LOG_LINE"] # Integration tests look for these loglines to validate scheduler enable/disable -SCHEDULER_ENABLED_LOG_LINE = 'Scheduler is enabled.' -SCHEDULER_DISABLED_LOG_LINE = 'Scheduler is disabled.' +SCHEDULER_ENABLED_LOG_LINE = "Scheduler is enabled." +SCHEDULER_DISABLED_LOG_LINE = "Scheduler is disabled." diff --git a/st2common/st2common/constants/secrets.py b/st2common/st2common/constants/secrets.py index d3f9e53b9e..ef9a02d5ee 100644 --- a/st2common/st2common/constants/secrets.py +++ b/st2common/st2common/constants/secrets.py @@ -13,22 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'MASKED_ATTRIBUTES_BLACKLIST', - 'MASKED_ATTRIBUTE_VALUE' -] +__all__ = ["MASKED_ATTRIBUTES_BLACKLIST", "MASKED_ATTRIBUTE_VALUE"] # A blacklist of attributes which should be masked in the log messages by default. # Note: If an attribute is an object or a dict, we try to recursively process it and mask the # values. MASKED_ATTRIBUTES_BLACKLIST = [ - 'password', - 'auth_token', - 'token', - 'secret', - 'credentials', - 'st2_auth_token' + "password", + "auth_token", + "token", + "secret", + "credentials", + "st2_auth_token", ] # Value with which the masked attribute values are replaced -MASKED_ATTRIBUTE_VALUE = '********' +MASKED_ATTRIBUTE_VALUE = "********" diff --git a/st2common/st2common/constants/sensors.py b/st2common/st2common/constants/sensors.py index 3ba4f9487d..a2d7903d18 100644 --- a/st2common/st2common/constants/sensors.py +++ b/st2common/st2common/constants/sensors.py @@ -17,7 +17,7 @@ MINIMUM_POLL_INTERVAL = 4 # keys for PARTITION loaders -DEFAULT_PARTITION_LOADER = 'default' -KVSTORE_PARTITION_LOADER = 'kvstore' -FILE_PARTITION_LOADER = 'file' -HASH_PARTITION_LOADER = 'hash' +DEFAULT_PARTITION_LOADER = "default" +KVSTORE_PARTITION_LOADER = "kvstore" +FILE_PARTITION_LOADER = "file" +HASH_PARTITION_LOADER = "hash" diff --git a/st2common/st2common/constants/system.py b/st2common/st2common/constants/system.py index dcb8ee699c..9736527171 100644 --- a/st2common/st2common/constants/system.py +++ b/st2common/st2common/constants/system.py @@ -20,15 +20,14 @@ from st2common import __version__ __all__ = [ - 'VERSION_STRING', - 'DEFAULT_CONFIG_FILE_PATH', - - 'API_URL_ENV_VARIABLE_NAME', - 'AUTH_TOKEN_ENV_VARIABLE_NAME', + "VERSION_STRING", + "DEFAULT_CONFIG_FILE_PATH", + "API_URL_ENV_VARIABLE_NAME", + "AUTH_TOKEN_ENV_VARIABLE_NAME", ] -VERSION_STRING = 'StackStorm v%s' % (__version__) -DEFAULT_CONFIG_FILE_PATH = os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf') +VERSION_STRING = "StackStorm v%s" % (__version__) +DEFAULT_CONFIG_FILE_PATH = os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf") -API_URL_ENV_VARIABLE_NAME = 'ST2_API_URL' -AUTH_TOKEN_ENV_VARIABLE_NAME = 'ST2_AUTH_TOKEN' +API_URL_ENV_VARIABLE_NAME = "ST2_API_URL" +AUTH_TOKEN_ENV_VARIABLE_NAME = "ST2_AUTH_TOKEN" diff --git a/st2common/st2common/constants/timer.py b/st2common/st2common/constants/timer.py index 0f191a8027..9772743792 100644 --- a/st2common/st2common/constants/timer.py +++ b/st2common/st2common/constants/timer.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'TIMER_ENABLED_LOG_LINE', - 'TIMER_DISABLED_LOG_LINE' -] +__all__ = ["TIMER_ENABLED_LOG_LINE", "TIMER_DISABLED_LOG_LINE"] # Integration tests look for these loglines to validate timer enable/disable -TIMER_ENABLED_LOG_LINE = 'Timer is enabled.' -TIMER_DISABLED_LOG_LINE = 'Timer is disabled.' +TIMER_ENABLED_LOG_LINE = "Timer is enabled." +TIMER_DISABLED_LOG_LINE = "Timer is disabled." diff --git a/st2common/st2common/constants/trace.py b/st2common/st2common/constants/trace.py index d900912c60..f7e4242da1 100644 --- a/st2common/st2common/constants/trace.py +++ b/st2common/st2common/constants/trace.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['TRACE_CONTEXT', 'TRACE_ID'] +__all__ = ["TRACE_CONTEXT", "TRACE_ID"] -TRACE_CONTEXT = 'trace_context' -TRACE_ID = 'trace_tag' +TRACE_CONTEXT = "trace_context" +TRACE_ID = "trace_tag" diff --git a/st2common/st2common/constants/triggers.py b/st2common/st2common/constants/triggers.py index 4a0ccc8e4e..14ab861fd5 100644 --- a/st2common/st2common/constants/triggers.py +++ b/st2common/st2common/constants/triggers.py @@ -18,244 +18,200 @@ from st2common.models.system.common import ResourceReference __all__ = [ - 'WEBHOOKS_PARAMETERS_SCHEMA', - 'WEBHOOKS_PAYLOAD_SCHEMA', - 'INTERVAL_PARAMETERS_SCHEMA', - 'DATE_PARAMETERS_SCHEMA', - 'CRON_PARAMETERS_SCHEMA', - 'TIMER_PAYLOAD_SCHEMA', - - 'ACTION_SENSOR_TRIGGER', - 'NOTIFY_TRIGGER', - 'ACTION_FILE_WRITTEN_TRIGGER', - 'INQUIRY_TRIGGER', - - 'TIMER_TRIGGER_TYPES', - 'WEBHOOK_TRIGGER_TYPES', - 'WEBHOOK_TRIGGER_TYPE', - 'INTERNAL_TRIGGER_TYPES', - 'SYSTEM_TRIGGER_TYPES', - - 'INTERVAL_TIMER_TRIGGER_REF', - 'DATE_TIMER_TRIGGER_REF', - 'CRON_TIMER_TRIGGER_REF', - - 'TRIGGER_INSTANCE_STATUSES', - 'TRIGGER_INSTANCE_PENDING', - 'TRIGGER_INSTANCE_PROCESSING', - 'TRIGGER_INSTANCE_PROCESSED', - 'TRIGGER_INSTANCE_PROCESSING_FAILED' + "WEBHOOKS_PARAMETERS_SCHEMA", + "WEBHOOKS_PAYLOAD_SCHEMA", + "INTERVAL_PARAMETERS_SCHEMA", + "DATE_PARAMETERS_SCHEMA", + "CRON_PARAMETERS_SCHEMA", + "TIMER_PAYLOAD_SCHEMA", + "ACTION_SENSOR_TRIGGER", + "NOTIFY_TRIGGER", + "ACTION_FILE_WRITTEN_TRIGGER", + "INQUIRY_TRIGGER", + "TIMER_TRIGGER_TYPES", + "WEBHOOK_TRIGGER_TYPES", + "WEBHOOK_TRIGGER_TYPE", + "INTERNAL_TRIGGER_TYPES", + "SYSTEM_TRIGGER_TYPES", + "INTERVAL_TIMER_TRIGGER_REF", + "DATE_TIMER_TRIGGER_REF", + "CRON_TIMER_TRIGGER_REF", + "TRIGGER_INSTANCE_STATUSES", + "TRIGGER_INSTANCE_PENDING", + "TRIGGER_INSTANCE_PROCESSING", + "TRIGGER_INSTANCE_PROCESSED", + "TRIGGER_INSTANCE_PROCESSING_FAILED", ] # Action resource triggers ACTION_SENSOR_TRIGGER = { - 'name': 'st2.generic.actiontrigger', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating the completion of an action execution.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'execution_id': {}, - 'status': {}, - 'start_timestamp': {}, - 'action_name': {}, - 'action_ref': {}, - 'runner_ref': {}, - 'parameters': {}, - 'result': {} - } - } + "name": "st2.generic.actiontrigger", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating the completion of an action execution.", + "payload_schema": { + "type": "object", + "properties": { + "execution_id": {}, + "status": {}, + "start_timestamp": {}, + "action_name": {}, + "action_ref": {}, + "runner_ref": {}, + "parameters": {}, + "result": {}, + }, + }, } ACTION_FILE_WRITTEN_TRIGGER = { - 'name': 'st2.action.file_written', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating action file being written on disk.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'ref': {}, - 'file_path': {}, - 'host_info': {} - } - } + "name": "st2.action.file_written", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating action file being written on disk.", + "payload_schema": { + "type": "object", + "properties": {"ref": {}, "file_path": {}, "host_info": {}}, + }, } NOTIFY_TRIGGER = { - 'name': 'st2.generic.notifytrigger', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Notification trigger.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'execution_id': {}, - 'status': {}, - 'start_timestamp': {}, - 'end_timestamp': {}, - 'action_ref': {}, - 'runner_ref': {}, - 'channel': {}, - 'route': {}, - 'message': {}, - 'data': {} - } - } + "name": "st2.generic.notifytrigger", + "pack": SYSTEM_PACK_NAME, + "description": "Notification trigger.", + "payload_schema": { + "type": "object", + "properties": { + "execution_id": {}, + "status": {}, + "start_timestamp": {}, + "end_timestamp": {}, + "action_ref": {}, + "runner_ref": {}, + "channel": {}, + "route": {}, + "message": {}, + "data": {}, + }, + }, } INQUIRY_TRIGGER = { - 'name': 'st2.generic.inquiry', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating a new "inquiry" has entered "pending" status', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'description': 'ID of the new inquiry.', - 'required': True + "name": "st2.generic.inquiry", + "pack": SYSTEM_PACK_NAME, + "description": 'Trigger indicating a new "inquiry" has entered "pending" status', + "payload_schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "ID of the new inquiry.", + "required": True, + }, + "route": { + "type": "string", + "description": "An arbitrary value for allowing rules " + "to route to proper notification channel.", + "required": True, }, - 'route': { - 'type': 'string', - 'description': 'An arbitrary value for allowing rules ' - 'to route to proper notification channel.', - 'required': True - } }, - "additionalProperties": False - } + "additionalProperties": False, + }, } # Sensor spawn/exit triggers. SENSOR_SPAWN_TRIGGER = { - 'name': 'st2.sensor.process_spawn', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating sensor process is started up.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.sensor.process_spawn", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger indicating sensor process is started up.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } SENSOR_EXIT_TRIGGER = { - 'name': 'st2.sensor.process_exit', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating sensor process is stopped.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.sensor.process_exit", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger indicating sensor process is stopped.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } # KeyValuePair resource triggers KEY_VALUE_PAIR_CREATE_TRIGGER = { - 'name': 'st2.key_value_pair.create', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore item creation.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.create", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore item creation.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } KEY_VALUE_PAIR_UPDATE_TRIGGER = { - 'name': 'st2.key_value_pair.update', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore set action.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.update", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore set action.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER = { - 'name': 'st2.key_value_pair.value_change', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating a change of datastore item value.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'old_object': {}, - 'new_object': {} - } - } + "name": "st2.key_value_pair.value_change", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating a change of datastore item value.", + "payload_schema": { + "type": "object", + "properties": {"old_object": {}, "new_object": {}}, + }, } KEY_VALUE_PAIR_DELETE_TRIGGER = { - 'name': 'st2.key_value_pair.delete', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore item deletion.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.delete", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore item deletion.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } # Internal system triggers which are available for each resource INTERNAL_TRIGGER_TYPES = { - 'action': [ + "action": [ ACTION_SENSOR_TRIGGER, NOTIFY_TRIGGER, ACTION_FILE_WRITTEN_TRIGGER, - INQUIRY_TRIGGER - ], - 'sensor': [ - SENSOR_SPAWN_TRIGGER, - SENSOR_EXIT_TRIGGER + INQUIRY_TRIGGER, ], - 'key_value_pair': [ + "sensor": [SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER], + "key_value_pair": [ KEY_VALUE_PAIR_CREATE_TRIGGER, KEY_VALUE_PAIR_UPDATE_TRIGGER, KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER, - KEY_VALUE_PAIR_DELETE_TRIGGER - ] + KEY_VALUE_PAIR_DELETE_TRIGGER, + ], } WEBHOOKS_PARAMETERS_SCHEMA = { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'required': True - } - }, - 'additionalProperties': False + "type": "object", + "properties": {"url": {"type": "string", "required": True}}, + "additionalProperties": False, } WEBHOOKS_PAYLOAD_SCHEMA = { - 'type': 'object', - 'properties': { - 'headers': { - 'type': 'object' - }, - 'body': { - 'anyOf': [ - {'type': 'array'}, - {'type': 'object'}, + "type": "object", + "properties": { + "headers": {"type": "object"}, + "body": { + "anyOf": [ + {"type": "array"}, + {"type": "object"}, ] - } - } + }, + }, } WEBHOOK_TRIGGER_TYPES = { - ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.webhook'): { - 'name': 'st2.webhook', - 'pack': SYSTEM_PACK_NAME, - 'description': ('Trigger type for registering webhooks that can consume' - ' arbitrary payload.'), - 'parameters_schema': WEBHOOKS_PARAMETERS_SCHEMA, - 'payload_schema': WEBHOOKS_PAYLOAD_SCHEMA + ResourceReference.to_string_reference(SYSTEM_PACK_NAME, "st2.webhook"): { + "name": "st2.webhook", + "pack": SYSTEM_PACK_NAME, + "description": ( + "Trigger type for registering webhooks that can consume" + " arbitrary payload." + ), + "parameters_schema": WEBHOOKS_PARAMETERS_SCHEMA, + "payload_schema": WEBHOOKS_PAYLOAD_SCHEMA, } } WEBHOOK_TRIGGER_TYPE = list(WEBHOOK_TRIGGER_TYPES.keys())[0] @@ -265,107 +221,69 @@ INTERVAL_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, + "timezone": {"type": "string"}, "unit": { "enum": ["weeks", "days", "hours", "minutes", "seconds"], - "required": True + "required": True, }, - "delta": { - "type": "integer", - "required": True - - } + "delta": {"type": "integer", "required": True}, }, - "additionalProperties": False + "additionalProperties": False, } DATE_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, - "date": { - "type": "string", - "format": "date-time", - "required": True - } + "timezone": {"type": "string"}, + "date": {"type": "string", "format": "date-time", "required": True}, }, - "additionalProperties": False + "additionalProperties": False, } CRON_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, + "timezone": {"type": "string"}, "year": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], }, "month": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 12 + "maximum": 12, }, "day": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 31 + "maximum": 31, }, "week": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 53 + "maximum": 53, }, "day_of_week": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 6 + "maximum": 6, }, "hour": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 23 + "maximum": 23, }, "minute": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 59 + "maximum": 59, }, "second": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 59 - } + "maximum": 59, + }, }, - "additionalProperties": False + "additionalProperties": False, } TIMER_PAYLOAD_SCHEMA = { @@ -374,61 +292,62 @@ "executed_at": { "type": "string", "format": "date-time", - "default": "2014-07-30 05:04:24.578325" + "default": "2014-07-30 05:04:24.578325", }, - "schedule": { - "type": "object", - "default": { - "delta": 30, - "units": "seconds" - } - } - } + "schedule": {"type": "object", "default": {"delta": 30, "units": "seconds"}}, + }, } -INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, - 'st2.IntervalTimer') -DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.DateTimer') -CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.CronTimer') +INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.IntervalTimer" +) +DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.DateTimer" +) +CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.CronTimer" +) TIMER_TRIGGER_TYPES = { INTERVAL_TIMER_TRIGGER_REF: { - 'name': 'st2.IntervalTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers on specified intervals. e.g. every 30s, 1week etc.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': INTERVAL_PARAMETERS_SCHEMA + "name": "st2.IntervalTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers on specified intervals. e.g. every 30s, 1week etc.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": INTERVAL_PARAMETERS_SCHEMA, }, DATE_TIMER_TRIGGER_REF: { - 'name': 'st2.DateTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers exactly once when the current time matches the specified time. ' - 'e.g. timezone:UTC date:2014-12-31 23:59:59.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': DATE_PARAMETERS_SCHEMA + "name": "st2.DateTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers exactly once when the current time matches the specified time. " + "e.g. timezone:UTC date:2014-12-31 23:59:59.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": DATE_PARAMETERS_SCHEMA, }, CRON_TIMER_TRIGGER_REF: { - 'name': 'st2.CronTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers whenever current time matches the specified time constaints like ' - 'a UNIX cron scheduler.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': CRON_PARAMETERS_SCHEMA - } + "name": "st2.CronTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers whenever current time matches the specified time constaints like " + "a UNIX cron scheduler.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": CRON_PARAMETERS_SCHEMA, + }, } -SYSTEM_TRIGGER_TYPES = dict(list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items())) +SYSTEM_TRIGGER_TYPES = dict( + list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items()) +) # various status to record lifecycle of a TriggerInstance -TRIGGER_INSTANCE_PENDING = 'pending' -TRIGGER_INSTANCE_PROCESSING = 'processing' -TRIGGER_INSTANCE_PROCESSED = 'processed' -TRIGGER_INSTANCE_PROCESSING_FAILED = 'processing_failed' +TRIGGER_INSTANCE_PENDING = "pending" +TRIGGER_INSTANCE_PROCESSING = "processing" +TRIGGER_INSTANCE_PROCESSED = "processed" +TRIGGER_INSTANCE_PROCESSING_FAILED = "processing_failed" TRIGGER_INSTANCE_STATUSES = [ TRIGGER_INSTANCE_PENDING, TRIGGER_INSTANCE_PROCESSING, TRIGGER_INSTANCE_PROCESSED, - TRIGGER_INSTANCE_PROCESSING_FAILED + TRIGGER_INSTANCE_PROCESSING_FAILED, ] diff --git a/st2common/st2common/constants/types.py b/st2common/st2common/constants/types.py index 7873d5b665..01ec79605f 100644 --- a/st2common/st2common/constants/types.py +++ b/st2common/st2common/constants/types.py @@ -16,9 +16,7 @@ from __future__ import absolute_import from st2common.util.enum import Enum -__all__ = [ - 'ResourceType' -] +__all__ = ["ResourceType"] class ResourceType(Enum): @@ -27,37 +25,37 @@ class ResourceType(Enum): """ # System resources - RUNNER_TYPE = 'runner_type' + RUNNER_TYPE = "runner_type" # Pack resources - PACK = 'pack' - ACTION = 'action' - ACTION_ALIAS = 'action_alias' - SENSOR_TYPE = 'sensor_type' - TRIGGER_TYPE = 'trigger_type' - TRIGGER = 'trigger' - TRIGGER_INSTANCE = 'trigger_instance' - RULE = 'rule' - RULE_ENFORCEMENT = 'rule_enforcement' + PACK = "pack" + ACTION = "action" + ACTION_ALIAS = "action_alias" + SENSOR_TYPE = "sensor_type" + TRIGGER_TYPE = "trigger_type" + TRIGGER = "trigger" + TRIGGER_INSTANCE = "trigger_instance" + RULE = "rule" + RULE_ENFORCEMENT = "rule_enforcement" # Note: Policy type is a global resource and policy belong to a pack - POLICY_TYPE = 'policy_type' - POLICY = 'policy' + POLICY_TYPE = "policy_type" + POLICY = "policy" # Other resources - EXECUTION = 'execution' - EXECUTION_REQUEST = 'execution_request' - KEY_VALUE_PAIR = 'key_value_pair' + EXECUTION = "execution" + EXECUTION_REQUEST = "execution_request" + KEY_VALUE_PAIR = "key_value_pair" - WEBHOOK = 'webhook' - TIMER = 'timer' - API_KEY = 'api_key' - TRACE = 'trace' - TIMER = 'timer' + WEBHOOK = "webhook" + TIMER = "timer" + API_KEY = "api_key" + TRACE = "trace" + TIMER = "timer" # Special resource type for stream related stuff - STREAM = 'stream' + STREAM = "stream" - INQUIRY = 'inquiry' + INQUIRY = "inquiry" - UNKNOWN = 'unknown' + UNKNOWN = "unknown" diff --git a/st2common/st2common/content/bootstrap.py b/st2common/st2common/content/bootstrap.py index 1072d35053..d38296bd5a 100644 --- a/st2common/st2common/content/bootstrap.py +++ b/st2common/st2common/content/bootstrap.py @@ -38,46 +38,55 @@ from st2common.metrics.base import Timer from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] -LOG = logging.getLogger('st2common.content.bootstrap') +LOG = logging.getLogger("st2common.content.bootstrap") -cfg.CONF.register_cli_opt(cfg.BoolOpt('experimental', default=False)) +cfg.CONF.register_cli_opt(cfg.BoolOpt("experimental", default=False)) def register_opts(): content_opts = [ - cfg.BoolOpt('all', default=False, help='Register sensors, actions and rules.'), - cfg.BoolOpt('triggers', default=False, help='Register triggers.'), - cfg.BoolOpt('sensors', default=False, help='Register sensors.'), - cfg.BoolOpt('actions', default=False, help='Register actions.'), - cfg.BoolOpt('runners', default=False, help='Register runners.'), - cfg.BoolOpt('rules', default=False, help='Register rules.'), - cfg.BoolOpt('aliases', default=False, help='Register aliases.'), - cfg.BoolOpt('policies', default=False, help='Register policies.'), - cfg.BoolOpt('configs', default=False, help='Register and load pack configs.'), - - cfg.StrOpt('pack', default=None, help='Directory to the pack to register content from.'), - cfg.StrOpt('runner-dir', default=None, help='Directory to load runners from.'), - cfg.BoolOpt('setup-virtualenvs', default=False, help=('Setup Python virtual environments ' - 'all the Python runner actions.')), - + cfg.BoolOpt("all", default=False, help="Register sensors, actions and rules."), + cfg.BoolOpt("triggers", default=False, help="Register triggers."), + cfg.BoolOpt("sensors", default=False, help="Register sensors."), + cfg.BoolOpt("actions", default=False, help="Register actions."), + cfg.BoolOpt("runners", default=False, help="Register runners."), + cfg.BoolOpt("rules", default=False, help="Register rules."), + cfg.BoolOpt("aliases", default=False, help="Register aliases."), + cfg.BoolOpt("policies", default=False, help="Register policies."), + cfg.BoolOpt("configs", default=False, help="Register and load pack configs."), + cfg.StrOpt( + "pack", default=None, help="Directory to the pack to register content from." + ), + cfg.StrOpt("runner-dir", default=None, help="Directory to load runners from."), + cfg.BoolOpt( + "setup-virtualenvs", + default=False, + help=( + "Setup Python virtual environments " "all the Python runner actions." + ), + ), # General options # Note: This value should default to False since we want fail on failure behavior by # default. - cfg.BoolOpt('no-fail-on-failure', default=False, - help=('Don\'t exit with non-zero if some resource registration fails.')), + cfg.BoolOpt( + "no-fail-on-failure", + default=False, + help=("Don't exit with non-zero if some resource registration fails."), + ), # Note: Fail on failure is now a default behavior. This flag is only left here for backward # compatibility reasons, but it's not actually used. - cfg.BoolOpt('fail-on-failure', default=True, - help=('Exit with non-zero if some resource registration fails.')) + cfg.BoolOpt( + "fail-on-failure", + default=True, + help=("Exit with non-zero if some resource registration fails."), + ), ] try: - cfg.CONF.register_cli_opts(content_opts, group='register') + cfg.CONF.register_cli_opts(content_opts, group="register") except: - sys.stderr.write('Failed registering opts.\n') + sys.stderr.write("Failed registering opts.\n") register_opts() @@ -88,9 +97,9 @@ def setup_virtualenvs(): Setup Python virtual environments for all the registered or the provided pack. """ - LOG.info('=========================================================') - LOG.info('########### Setting up virtual environments #############') - LOG.info('=========================================================') + LOG.info("=========================================================") + LOG.info("########### Setting up virtual environments #############") + LOG.info("=========================================================") pack_dir = cfg.CONF.register.pack fail_on_failure = not cfg.CONF.register.no_fail_on_failure @@ -116,15 +125,19 @@ def setup_virtualenvs(): setup_pack_virtualenv(pack_name=pack_name, update=True, logger=LOG) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to setup virtualenv for pack "%s": %s', pack_name, e, - exc_info=exc_info) + LOG.warning( + 'Failed to setup virtualenv for pack "%s": %s', + pack_name, + e, + exc_info=exc_info, + ) if fail_on_failure: raise e else: setup_count += 1 - LOG.info('Setup virtualenv for %s pack(s).' % (setup_count)) + LOG.info("Setup virtualenv for %s pack(s)." % (setup_count)) def register_triggers(): @@ -134,22 +147,21 @@ def register_triggers(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering triggers #####################') - LOG.info('=========================================================') - with Timer(key='st2.register.triggers'): + LOG.info("=========================================================") + LOG.info("############## Registering triggers #####################") + LOG.info("=========================================================") + with Timer(key="st2.register.triggers"): registered_count = triggers_registrar.register_triggers( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info) + LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s triggers.' % (registered_count)) + LOG.info("Registered %s triggers." % (registered_count)) def register_sensors(): @@ -159,22 +171,21 @@ def register_sensors(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering sensors ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.sensors'): + LOG.info("=========================================================") + LOG.info("############## Registering sensors ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.sensors"): registered_count = sensors_registrar.register_sensors( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info) + LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s sensors.' % (registered_count)) + LOG.info("Registered %s sensors." % (registered_count)) def register_runners(): @@ -184,24 +195,23 @@ def register_runners(): # 1. Register runner types try: - LOG.info('=========================================================') - LOG.info('############## Registering runners ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.runners'): + LOG.info("=========================================================") + LOG.info("############## Registering runners ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.runners"): registered_count = runners_registrar.register_runners( - fail_on_failure=fail_on_failure, - experimental=False + fail_on_failure=fail_on_failure, experimental=False ) except Exception as error: exc_info = not fail_on_failure # TODO: Narrow exception window - LOG.warning('Failed to register runners: %s', error, exc_info=exc_info) + LOG.warning("Failed to register runners: %s", error, exc_info=exc_info) if fail_on_failure: raise error - LOG.info('Registered %s runners.', registered_count) + LOG.info("Registered %s runners.", registered_count) def register_actions(): @@ -213,22 +223,21 @@ def register_actions(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering actions ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.actions'): + LOG.info("=========================================================") + LOG.info("############## Registering actions ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.actions"): registered_count = actions_registrar.register_actions( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register actions: %s', e, exc_info=exc_info) + LOG.warning("Failed to register actions: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s actions.' % (registered_count)) + LOG.info("Registered %s actions." % (registered_count)) def register_rules(): @@ -239,28 +248,27 @@ def register_rules(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering rules ########################') - LOG.info('=========================================================') + LOG.info("=========================================================") + LOG.info("############## Registering rules ########################") + LOG.info("=========================================================") rule_types_registrar.register_rule_types() except Exception as e: - LOG.warning('Failed to register rule types: %s', e, exc_info=True) + LOG.warning("Failed to register rule types: %s", e, exc_info=True) return try: - with Timer(key='st2.register.rules'): + with Timer(key="st2.register.rules"): registered_count = rules_registrar.register_rules( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register rules: %s', e, exc_info=exc_info) + LOG.warning("Failed to register rules: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s rules.', registered_count) + LOG.info("Registered %s rules.", registered_count) def register_aliases(): @@ -270,21 +278,20 @@ def register_aliases(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering aliases ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.aliases'): + LOG.info("=========================================================") + LOG.info("############## Registering aliases ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.aliases"): registered_count = aliases_registrar.register_aliases( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: if fail_on_failure: raise e - LOG.warning('Failed to register aliases.', exc_info=True) + LOG.warning("Failed to register aliases.", exc_info=True) - LOG.info('Registered %s aliases.', registered_count) + LOG.info("Registered %s aliases.", registered_count) def register_policies(): @@ -295,31 +302,32 @@ def register_policies(): registered_type_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering policy types #################') - LOG.info('=========================================================') - with Timer(key='st2.register.policies'): + LOG.info("=========================================================") + LOG.info("############## Registering policy types #################") + LOG.info("=========================================================") + with Timer(key="st2.register.policies"): registered_type_count = policies_registrar.register_policy_types(st2common) except Exception: - LOG.warning('Failed to register policy types.', exc_info=True) + LOG.warning("Failed to register policy types.", exc_info=True) - LOG.info('Registered %s policy types.', registered_type_count) + LOG.info("Registered %s policy types.", registered_type_count) registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering policies #####################') - LOG.info('=========================================================') - registered_count = policies_registrar.register_policies(pack_dir=pack_dir, - fail_on_failure=fail_on_failure) + LOG.info("=========================================================") + LOG.info("############## Registering policies #####################") + LOG.info("=========================================================") + registered_count = policies_registrar.register_policies( + pack_dir=pack_dir, fail_on_failure=fail_on_failure + ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register policies: %s', e, exc_info=exc_info) + LOG.warning("Failed to register policies: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s policies.', registered_count) + LOG.info("Registered %s policies.", registered_count) def register_configs(): @@ -329,23 +337,23 @@ def register_configs(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering configs ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.configs'): + LOG.info("=========================================================") + LOG.info("############## Registering configs ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.configs"): registered_count = configs_registrar.register_configs( pack_dir=pack_dir, fail_on_failure=fail_on_failure, - validate_configs=True + validate_configs=True, ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register configs: %s', e, exc_info=exc_info) + LOG.warning("Failed to register configs: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s configs.' % (registered_count)) + LOG.info("Registered %s configs." % (registered_count)) def register_content(): @@ -395,8 +403,12 @@ def register_content(): def setup(argv): - common_setup(config=config, setup_db=True, register_mq_exchanges=True, - register_internal_trigger_types=True) + common_setup( + config=config, + setup_db=True, + register_mq_exchanges=True, + register_internal_trigger_types=True, + ) def teardown(): @@ -410,5 +422,5 @@ def main(argv): # This script registers actions and rules from content-packs. -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/st2common/st2common/content/loader.py b/st2common/st2common/content/loader.py index 0dfae4c0b6..420323fd76 100644 --- a/st2common/st2common/content/loader.py +++ b/st2common/st2common/content/loader.py @@ -28,10 +28,7 @@ if six.PY2: from io import open -__all__ = [ - 'ContentPackLoader', - 'MetaLoader' -] +__all__ = ["ContentPackLoader", "MetaLoader"] LOG = logging.getLogger(__name__) @@ -45,12 +42,12 @@ class ContentPackLoader(object): # content - they just return a path ALLOWED_CONTENT_TYPES = [ - 'triggers', - 'sensors', - 'actions', - 'rules', - 'aliases', - 'policies' + "triggers", + "sensors", + "actions", + "rules", + "aliases", + "policies", ] def get_packs(self, base_dirs): @@ -91,7 +88,7 @@ def get_content(self, base_dirs, content_type): assert isinstance(base_dirs, list) if content_type not in self.ALLOWED_CONTENT_TYPES: - raise ValueError('Unsupported content_type: %s' % (content_type)) + raise ValueError("Unsupported content_type: %s" % (content_type)) content = {} pack_to_dir_map = {} @@ -99,14 +96,18 @@ def get_content(self, base_dirs, content_type): if not os.path.isdir(base_dir): raise ValueError('Directory "%s" doesn\'t exist' % (base_dir)) - dir_content = self._get_content_from_dir(base_dir=base_dir, content_type=content_type) + dir_content = self._get_content_from_dir( + base_dir=base_dir, content_type=content_type + ) # Check for duplicate packs for pack_name, pack_content in six.iteritems(dir_content): if pack_name in content: pack_dir = pack_to_dir_map[pack_name] - LOG.warning('Pack "%s" already found in "%s", ignoring content from "%s"' % - (pack_name, pack_dir, base_dir)) + LOG.warning( + 'Pack "%s" already found in "%s", ignoring content from "%s"' + % (pack_name, pack_dir, base_dir) + ) else: content[pack_name] = pack_content pack_to_dir_map[pack_name] = base_dir @@ -126,13 +127,14 @@ def get_content_from_pack(self, pack_dir, content_type): :rtype: ``str`` """ if content_type not in self.ALLOWED_CONTENT_TYPES: - raise ValueError('Unsupported content_type: %s' % (content_type)) + raise ValueError("Unsupported content_type: %s" % (content_type)) if not os.path.isdir(pack_dir): raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir)) - content = self._get_content_from_pack_dir(pack_dir=pack_dir, - content_type=content_type) + content = self._get_content_from_pack_dir( + pack_dir=pack_dir, content_type=content_type + ) return content def _get_packs_from_dir(self, base_dir): @@ -154,8 +156,9 @@ def _get_content_from_dir(self, base_dir, content_type): # Ignore missing or non directories try: - pack_content = self._get_content_from_pack_dir(pack_dir=pack_dir, - content_type=content_type) + pack_content = self._get_content_from_pack_dir( + pack_dir=pack_dir, content_type=content_type + ) except ValueError: continue else: @@ -170,13 +173,13 @@ def _get_content_from_pack_dir(self, pack_dir, content_type): actions=self._get_actions, rules=self._get_rules, aliases=self._get_aliases, - policies=self._get_policies + policies=self._get_policies, ) get_func = content_types.get(content_type) if get_func is None: - raise ValueError('Invalid content_type: %s' % (content_type)) + raise ValueError("Invalid content_type: %s" % (content_type)) if not os.path.isdir(pack_dir): raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir)) @@ -185,22 +188,22 @@ def _get_content_from_pack_dir(self, pack_dir, content_type): return pack_content def _get_triggers(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='triggers') + return self._get_folder(pack_dir=pack_dir, content_type="triggers") def _get_sensors(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='sensors') + return self._get_folder(pack_dir=pack_dir, content_type="sensors") def _get_actions(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='actions') + return self._get_folder(pack_dir=pack_dir, content_type="actions") def _get_rules(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='rules') + return self._get_folder(pack_dir=pack_dir, content_type="rules") def _get_aliases(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='aliases') + return self._get_folder(pack_dir=pack_dir, content_type="aliases") def _get_policies(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='policies') + return self._get_folder(pack_dir=pack_dir, content_type="policies") def _get_folder(self, pack_dir, content_type): path = os.path.join(pack_dir, content_type) @@ -233,8 +236,10 @@ def load(self, file_path, expected_type=None): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) result = self._load(PARSER_FUNCS[file_ext], file_path) @@ -246,12 +251,12 @@ def load(self, file_path, expected_type=None): return result def _load(self, parser_func, file_path): - with open(file_path, 'r', encoding='utf-8') as fd: + with open(file_path, "r", encoding="utf-8") as fd: try: return parser_func(fd) except ValueError: - LOG.exception('Failed loading content from %s.', file_path) + LOG.exception("Failed loading content from %s.", file_path) raise except ParserError: - LOG.exception('Failed loading content from %s.', file_path) + LOG.exception("Failed loading content from %s.", file_path) raise diff --git a/st2common/st2common/content/utils.py b/st2common/st2common/content/utils.py index 3bd5e2b12c..ad9386acf6 100644 --- a/st2common/st2common/content/utils.py +++ b/st2common/st2common/content/utils.py @@ -24,22 +24,24 @@ from st2common.util.shell import quote_unix __all__ = [ - 'get_pack_group', - 'get_system_packs_base_path', - 'get_packs_base_paths', - 'get_pack_base_path', - 'get_pack_directory', - 'get_pack_file_abs_path', - 'get_pack_resource_file_abs_path', - 'get_relative_path_to_pack_file', - 'check_pack_directory_exists', - 'check_pack_content_directory_exists' + "get_pack_group", + "get_system_packs_base_path", + "get_packs_base_paths", + "get_pack_base_path", + "get_pack_directory", + "get_pack_file_abs_path", + "get_pack_resource_file_abs_path", + "get_relative_path_to_pack_file", + "check_pack_directory_exists", + "check_pack_content_directory_exists", ] INVALID_FILE_PATH_ERROR = """ Invalid file path: "%s". File path needs to be relative to the pack%sdirectory (%s). For example "my_%s.py". -""".strip().replace('\n', ' ') +""".strip().replace( + "\n", " " +) # Cache which stores pack name -> pack base path mappings PACK_NAME_TO_BASE_PATH_CACHE = {} @@ -70,10 +72,10 @@ def get_packs_base_paths(): :rtype: ``list`` """ system_packs_base_path = get_system_packs_base_path() - packs_base_paths = cfg.CONF.content.packs_base_paths or '' + packs_base_paths = cfg.CONF.content.packs_base_paths or "" # Remove trailing colon (if present) - if packs_base_paths.endswith(':'): + if packs_base_paths.endswith(":"): packs_base_paths = packs_base_paths[:-1] result = [] @@ -81,7 +83,7 @@ def get_packs_base_paths(): if system_packs_base_path: result.append(system_packs_base_path) - packs_base_paths = packs_base_paths.split(':') + packs_base_paths = packs_base_paths.split(":") result = result + packs_base_paths result = [path for path in result if path] @@ -223,22 +225,28 @@ def get_entry_point_abs_path(pack=None, entry_point=None, use_pack_cache=False): return None if os.path.isabs(entry_point): - pack_base_path = get_pack_base_path(pack_name=pack, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack, use_pack_cache=use_pack_cache + ) common_prefix = os.path.commonprefix([pack_base_path, entry_point]) if common_prefix != pack_base_path: - raise ValueError('Entry point file "%s" is located outside of the pack directory' % - (entry_point)) + raise ValueError( + 'Entry point file "%s" is located outside of the pack directory' + % (entry_point) + ) return entry_point - entry_point_abs_path = get_pack_resource_file_abs_path(pack_ref=pack, - resource_type='action', - file_path=entry_point) + entry_point_abs_path = get_pack_resource_file_abs_path( + pack_ref=pack, resource_type="action", file_path=entry_point + ) return entry_point_abs_path -def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cache=False): +def get_pack_file_abs_path( + pack_ref, file_path, resource_type=None, use_pack_cache=False +): """ Retrieve full absolute path to the pack file. @@ -258,36 +266,46 @@ def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cac :rtype: ``str`` """ - pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack_ref, use_pack_cache=use_pack_cache + ) if resource_type: - resource_type_plural = ' %ss ' % (resource_type) - resource_base_path = os.path.join(pack_base_path, '%ss/' % (resource_type)) + resource_type_plural = " %ss " % (resource_type) + resource_base_path = os.path.join(pack_base_path, "%ss/" % (resource_type)) else: - resource_type_plural = ' ' + resource_type_plural = " " resource_base_path = pack_base_path path_components = [] path_components.append(pack_base_path) # Normalize the path to prevent directory traversal - normalized_file_path = os.path.normpath('/' + file_path).lstrip('/') + normalized_file_path = os.path.normpath("/" + file_path).lstrip("/") if normalized_file_path != file_path: - msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path, - resource_type or 'action') + msg = INVALID_FILE_PATH_ERROR % ( + file_path, + resource_type_plural, + resource_base_path, + resource_type or "action", + ) raise ValueError(msg) path_components.append(normalized_file_path) - result = os.path.join(*path_components) # pylint: disable=E1120 + result = os.path.join(*path_components) # pylint: disable=E1120 assert normalized_file_path in result # Final safety check for common prefix to avoid traversal attack common_prefix = os.path.commonprefix([pack_base_path, result]) if common_prefix != pack_base_path: - msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path, - resource_type or 'action') + msg = INVALID_FILE_PATH_ERROR % ( + file_path, + resource_type_plural, + resource_base_path, + resource_type or "action", + ) raise ValueError(msg) return result @@ -313,19 +331,20 @@ def get_pack_resource_file_abs_path(pack_ref, resource_type, file_path): :rtype: ``str`` """ path_components = [] - if resource_type == 'action': - path_components.append('actions/') - elif resource_type == 'sensor': - path_components.append('sensors/') - elif resource_type == 'rule': - path_components.append('rules/') + if resource_type == "action": + path_components.append("actions/") + elif resource_type == "sensor": + path_components.append("sensors/") + elif resource_type == "rule": + path_components.append("rules/") else: - raise ValueError('Invalid resource type: %s' % (resource_type)) + raise ValueError("Invalid resource type: %s" % (resource_type)) path_components.append(file_path) file_path = os.path.join(*path_components) # pylint: disable=E1120 - result = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path, - resource_type=resource_type) + result = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path, resource_type=resource_type + ) return result @@ -341,7 +360,9 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False): :rtype: ``str`` """ - pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack_ref, use_pack_cache=use_pack_cache + ) if not os.path.isabs(file_path): return file_path @@ -350,8 +371,10 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False): common_prefix = os.path.commonprefix([pack_base_path, file_path]) if common_prefix != pack_base_path: - raise ValueError('file_path (%s) is not located inside the pack directory (%s)' % - (file_path, pack_base_path)) + raise ValueError( + "file_path (%s) is not located inside the pack directory (%s)" + % (file_path, pack_base_path) + ) relative_path = os.path.relpath(file_path, common_prefix) return relative_path @@ -381,15 +404,15 @@ def get_aliases_base_paths(): :rtype: ``list`` """ - aliases_base_paths = cfg.CONF.content.aliases_base_paths or '' + aliases_base_paths = cfg.CONF.content.aliases_base_paths or "" # Remove trailing colon (if present) - if aliases_base_paths.endswith(':'): + if aliases_base_paths.endswith(":"): aliases_base_paths = aliases_base_paths[:-1] result = [] - aliases_base_paths = aliases_base_paths.split(':') + aliases_base_paths = aliases_base_paths.split(":") result = aliases_base_paths result = [path for path in result if path] diff --git a/st2common/st2common/content/validators.py b/st2common/st2common/content/validators.py index bba9c446e3..8b1ab822c0 100644 --- a/st2common/st2common/content/validators.py +++ b/st2common/st2common/content/validators.py @@ -19,20 +19,16 @@ from st2common.constants.pack import USER_PACK_NAME_BLACKLIST -__all__ = [ - 'RequirementsValidator', - 'validate_pack_name' -] +__all__ = ["RequirementsValidator", "validate_pack_name"] class RequirementsValidator(object): - @staticmethod def validate(requirements_file): if not os.path.exists(requirements_file): - raise Exception('Requirements file %s not found.' % requirements_file) + raise Exception("Requirements file %s not found." % requirements_file) missing = [] - with open(requirements_file, 'r') as f: + with open(requirements_file, "r") as f: for line in f: rqmnt = line.strip() try: @@ -54,10 +50,9 @@ def validate_pack_name(name): :rtype: ``str`` """ if not name: - raise ValueError('Content pack name cannot be empty') + raise ValueError("Content pack name cannot be empty") if name.lower() in USER_PACK_NAME_BLACKLIST: - raise ValueError('Name "%s" is blacklisted and can\'t be used' % - (name.lower())) + raise ValueError('Name "%s" is blacklisted and can\'t be used' % (name.lower())) return name diff --git a/st2common/st2common/database_setup.py b/st2common/st2common/database_setup.py index 2678ecbf2e..2e2e7d2a17 100644 --- a/st2common/st2common/database_setup.py +++ b/st2common/st2common/database_setup.py @@ -23,29 +23,27 @@ from st2common.models import db from st2common.persistence import db_init -__all__ = [ - 'db_config', - 'db_setup', - 'db_teardown' -] +__all__ = ["db_config", "db_setup", "db_teardown"] def db_config(): - username = getattr(cfg.CONF.database, 'username', None) - password = getattr(cfg.CONF.database, 'password', None) - - return {'db_name': cfg.CONF.database.db_name, - 'db_host': cfg.CONF.database.host, - 'db_port': cfg.CONF.database.port, - 'username': username, - 'password': password, - 'ssl': cfg.CONF.database.ssl, - 'ssl_keyfile': cfg.CONF.database.ssl_keyfile, - 'ssl_certfile': cfg.CONF.database.ssl_certfile, - 'ssl_cert_reqs': cfg.CONF.database.ssl_cert_reqs, - 'ssl_ca_certs': cfg.CONF.database.ssl_ca_certs, - 'authentication_mechanism': cfg.CONF.database.authentication_mechanism, - 'ssl_match_hostname': cfg.CONF.database.ssl_match_hostname} + username = getattr(cfg.CONF.database, "username", None) + password = getattr(cfg.CONF.database, "password", None) + + return { + "db_name": cfg.CONF.database.db_name, + "db_host": cfg.CONF.database.host, + "db_port": cfg.CONF.database.port, + "username": username, + "password": password, + "ssl": cfg.CONF.database.ssl, + "ssl_keyfile": cfg.CONF.database.ssl_keyfile, + "ssl_certfile": cfg.CONF.database.ssl_certfile, + "ssl_cert_reqs": cfg.CONF.database.ssl_cert_reqs, + "ssl_ca_certs": cfg.CONF.database.ssl_ca_certs, + "authentication_mechanism": cfg.CONF.database.authentication_mechanism, + "ssl_match_hostname": cfg.CONF.database.ssl_match_hostname, + } def db_setup(ensure_indexes=True): @@ -53,7 +51,7 @@ def db_setup(ensure_indexes=True): Creates the database and indexes (optional). """ db_cfg = db_config() - db_cfg['ensure_indexes'] = ensure_indexes + db_cfg["ensure_indexes"] = ensure_indexes connection = db_init.db_setup_with_retry(**db_cfg) return connection diff --git a/st2common/st2common/exceptions/__init__.py b/st2common/st2common/exceptions/__init__.py index ec4e9430e9..065d3ff0fe 100644 --- a/st2common/st2common/exceptions/__init__.py +++ b/st2common/st2common/exceptions/__init__.py @@ -16,24 +16,26 @@ class StackStormBaseException(Exception): """ - The root of the exception class hierarchy for all - StackStorm server exceptions. + The root of the exception class hierarchy for all + StackStorm server exceptions. - For exceptions raised by plug-ins, see StackStormPluginException - class. + For exceptions raised by plug-ins, see StackStormPluginException + class. """ + pass class StackStormPluginException(StackStormBaseException): """ - The root of the exception class hierarchy for all - exceptions that are defined as part of a StackStorm - plug-in API. - - It is recommended that each API define a root exception - class for the API. This root exception class for the - API should inherit from the StackStormPluginException - class. + The root of the exception class hierarchy for all + exceptions that are defined as part of a StackStorm + plug-in API. + + It is recommended that each API define a root exception + class for the API. This root exception class for the + API should inherit from the StackStormPluginException + class. """ + pass diff --git a/st2common/st2common/exceptions/action.py b/st2common/st2common/exceptions/action.py index f7ed430266..f4bba2ee75 100644 --- a/st2common/st2common/exceptions/action.py +++ b/st2common/st2common/exceptions/action.py @@ -17,9 +17,9 @@ from st2common.exceptions import StackStormBaseException __all__ = [ - 'ParameterRenderingFailedException', - 'InvalidActionReferencedException', - 'InvalidActionParameterException' + "ParameterRenderingFailedException", + "InvalidActionReferencedException", + "InvalidActionParameterException", ] diff --git a/st2common/st2common/exceptions/actionalias.py b/st2common/st2common/exceptions/actionalias.py index 1c01cd5736..3172a72dc6 100644 --- a/st2common/st2common/exceptions/actionalias.py +++ b/st2common/st2common/exceptions/actionalias.py @@ -16,9 +16,7 @@ from __future__ import absolute_import from st2common.exceptions import StackStormBaseException -__all__ = [ - 'ActionAliasAmbiguityException' -] +__all__ = ["ActionAliasAmbiguityException"] class ActionAliasAmbiguityException(ValueError, StackStormBaseException): diff --git a/st2common/st2common/exceptions/api.py b/st2common/st2common/exceptions/api.py index f5aee1c1c0..054eb1bcf1 100644 --- a/st2common/st2common/exceptions/api.py +++ b/st2common/st2common/exceptions/api.py @@ -16,8 +16,7 @@ from __future__ import absolute_import from st2common.exceptions import StackStormBaseException -__all__ = [ -] +__all__ = [] class InternalServerErrorException(StackStormBaseException): diff --git a/st2common/st2common/exceptions/auth.py b/st2common/st2common/exceptions/auth.py index 429d597abd..5eab1915f5 100644 --- a/st2common/st2common/exceptions/auth.py +++ b/st2common/st2common/exceptions/auth.py @@ -18,19 +18,19 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'TokenNotProvidedError', - 'TokenNotFoundError', - 'TokenExpiredError', - 'TTLTooLargeException', - 'ApiKeyNotProvidedError', - 'ApiKeyNotFoundError', - 'MultipleAuthSourcesError', - 'NoAuthSourceProvidedError', - 'NoNicknameOriginProvidedError', - 'UserNotFoundError', - 'AmbiguousUserError', - 'NotServiceUserError', - 'SSOVerificationError' + "TokenNotProvidedError", + "TokenNotFoundError", + "TokenExpiredError", + "TTLTooLargeException", + "ApiKeyNotProvidedError", + "ApiKeyNotFoundError", + "MultipleAuthSourcesError", + "NoAuthSourceProvidedError", + "NoNicknameOriginProvidedError", + "UserNotFoundError", + "AmbiguousUserError", + "NotServiceUserError", + "SSOVerificationError", ] diff --git a/st2common/st2common/exceptions/connection.py b/st2common/st2common/exceptions/connection.py index 8cb9681b41..806d6e1046 100644 --- a/st2common/st2common/exceptions/connection.py +++ b/st2common/st2common/exceptions/connection.py @@ -16,14 +16,17 @@ class UnknownHostException(Exception): """Raised when a host is unknown (dns failure)""" + pass class ConnectionErrorException(Exception): """Raised on error connecting (connection refused/timed out)""" + pass class AuthenticationException(Exception): """Raised on authentication error (user/password/ssh key error)""" + pass diff --git a/st2common/st2common/exceptions/db.py b/st2common/st2common/exceptions/db.py index fcd607e964..776d927e0f 100644 --- a/st2common/st2common/exceptions/db.py +++ b/st2common/st2common/exceptions/db.py @@ -29,6 +29,7 @@ class StackStormDBObjectConflictError(StackStormBaseException): """ Exception that captures a DB object conflict error. """ + def __init__(self, message, conflict_id, model_object): super(StackStormDBObjectConflictError, self).__init__(message) self.conflict_id = conflict_id @@ -36,7 +37,9 @@ def __init__(self, message, conflict_id, model_object): class StackStormDBObjectWriteConflictError(StackStormBaseException): - def __init__(self, instance): - msg = 'Conflict saving DB object with id "%s" and rev "%s".' % (instance.id, instance.rev) + msg = 'Conflict saving DB object with id "%s" and rev "%s".' % ( + instance.id, + instance.rev, + ) super(StackStormDBObjectWriteConflictError, self).__init__(msg) diff --git a/st2common/st2common/exceptions/inquiry.py b/st2common/st2common/exceptions/inquiry.py index 0636d0f985..b5c3f30646 100644 --- a/st2common/st2common/exceptions/inquiry.py +++ b/st2common/st2common/exceptions/inquiry.py @@ -23,32 +23,33 @@ class InvalidInquiryInstance(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Action execution "%s" is not an inquiry.' % inquiry_id) + Exception.__init__( + self, 'Action execution "%s" is not an inquiry.' % inquiry_id + ) class InquiryTimedOut(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id) + Exception.__init__( + self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id + ) class InquiryAlreadyResponded(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Inquiry "%s" has already been responded to.' % inquiry_id) + Exception.__init__( + self, 'Inquiry "%s" has already been responded to.' % inquiry_id + ) class InquiryResponseUnauthorized(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id, user): msg = 'User "%s" does not have permission to respond to inquiry "%s".' Exception.__init__(self, msg % (user, inquiry_id)) class InvalidInquiryResponse(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id, error): msg = 'Response for inquiry "%s" did not pass schema validation. %s' Exception.__init__(self, msg % (inquiry_id, error)) diff --git a/st2common/st2common/exceptions/keyvalue.py b/st2common/st2common/exceptions/keyvalue.py index 6ef2702fe8..7fccb8b819 100644 --- a/st2common/st2common/exceptions/keyvalue.py +++ b/st2common/st2common/exceptions/keyvalue.py @@ -18,9 +18,9 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'CryptoKeyNotSetupException', - 'DataStoreKeyNotFoundError', - 'InvalidScopeException' + "CryptoKeyNotSetupException", + "DataStoreKeyNotFoundError", + "InvalidScopeException", ] diff --git a/st2common/st2common/exceptions/rbac.py b/st2common/st2common/exceptions/rbac.py index 308110c267..957b0fe5be 100644 --- a/st2common/st2common/exceptions/rbac.py +++ b/st2common/st2common/exceptions/rbac.py @@ -18,10 +18,10 @@ from st2common.rbac.types import GLOBAL_PERMISSION_TYPES __all__ = [ - 'AccessDeniedError', - 'ResourceTypeAccessDeniedError', - 'ResourceAccessDeniedError', - 'ResourceAccessDeniedPermissionIsolationError' + "AccessDeniedError", + "ResourceTypeAccessDeniedError", + "ResourceAccessDeniedError", + "ResourceAccessDeniedPermissionIsolationError", ] @@ -45,9 +45,13 @@ class ResourceTypeAccessDeniedError(AccessDeniedError): def __init__(self, user_db, permission_type): self.permission_type = permission_type - message = ('User "%s" doesn\'t have required permission "%s"' % (user_db.name, - permission_type)) - super(ResourceTypeAccessDeniedError, self).__init__(message=message, user_db=user_db) + message = 'User "%s" doesn\'t have required permission "%s"' % ( + user_db.name, + permission_type, + ) + super(ResourceTypeAccessDeniedError, self).__init__( + message=message, user_db=user_db + ) class ResourceAccessDeniedError(AccessDeniedError): @@ -59,15 +63,25 @@ def __init__(self, user_db, resource_api_or_db, permission_type): self.resource_api_db = resource_api_or_db self.permission_type = permission_type - resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown' + resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown" if resource_api_or_db and permission_type not in GLOBAL_PERMISSION_TYPES: - message = ('User "%s" doesn\'t have required permission "%s" on resource "%s"' % - (user_db.name, permission_type, resource_uid)) + message = ( + 'User "%s" doesn\'t have required permission "%s" on resource "%s"' + % ( + user_db.name, + permission_type, + resource_uid, + ) + ) else: - message = ('User "%s" doesn\'t have required permission "%s"' % - (user_db.name, permission_type)) - super(ResourceAccessDeniedError, self).__init__(message=message, user_db=user_db) + message = 'User "%s" doesn\'t have required permission "%s"' % ( + user_db.name, + permission_type, + ) + super(ResourceAccessDeniedError, self).__init__( + message=message, user_db=user_db + ) class ResourceAccessDeniedPermissionIsolationError(AccessDeniedError): @@ -80,9 +94,12 @@ def __init__(self, user_db, resource_api_or_db, permission_type): self.resource_api_db = resource_api_or_db self.permission_type = permission_type - resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown' + resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown" - message = ('User "%s" doesn\'t have access to resource "%s" due to resource permission ' - 'isolation.' % (user_db.name, resource_uid)) - super(ResourceAccessDeniedPermissionIsolationError, self).__init__(message=message, - user_db=user_db) + message = ( + 'User "%s" doesn\'t have access to resource "%s" due to resource permission ' + "isolation." % (user_db.name, resource_uid) + ) + super(ResourceAccessDeniedPermissionIsolationError, self).__init__( + message=message, user_db=user_db + ) diff --git a/st2common/st2common/exceptions/ssh.py b/st2common/st2common/exceptions/ssh.py index f720e54b8a..7a4e1ee516 100644 --- a/st2common/st2common/exceptions/ssh.py +++ b/st2common/st2common/exceptions/ssh.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'InvalidCredentialsException' -] +__all__ = ["InvalidCredentialsException"] class InvalidCredentialsException(Exception): diff --git a/st2common/st2common/exceptions/workflow.py b/st2common/st2common/exceptions/workflow.py index dd787417c2..2a346819be 100644 --- a/st2common/st2common/exceptions/workflow.py +++ b/st2common/st2common/exceptions/workflow.py @@ -27,28 +27,25 @@ def retry_on_connection_errors(exc): - LOG.warning('Determining if exception %s should be retried.', type(exc)) + LOG.warning("Determining if exception %s should be retried.", type(exc)) - retrying = ( - isinstance(exc, tooz.coordination.ToozConnectionError) or - isinstance(exc, mongoengine.connection.MongoEngineConnectionError) + retrying = isinstance(exc, tooz.coordination.ToozConnectionError) or isinstance( + exc, mongoengine.connection.MongoEngineConnectionError ) if retrying: - LOG.warning('Retrying operation due to connection error: %s', type(exc)) + LOG.warning("Retrying operation due to connection error: %s", type(exc)) return retrying def retry_on_transient_db_errors(exc): - LOG.warning('Determining if exception %s should be retried.', type(exc)) + LOG.warning("Determining if exception %s should be retried.", type(exc)) - retrying = ( - isinstance(exc, db_exc.StackStormDBObjectWriteConflictError) - ) + retrying = isinstance(exc, db_exc.StackStormDBObjectWriteConflictError) if retrying: - LOG.warning('Retrying operation due to transient database error: %s', type(exc)) + LOG.warning("Retrying operation due to transient database error: %s", type(exc)) return retrying @@ -62,38 +59,37 @@ class WorkflowExecutionException(st2_exc.StackStormBaseException): class WorkflowExecutionNotFoundException(st2_exc.StackStormBaseException): - def __init__(self, ac_ex_id): Exception.__init__( self, - 'Unable to identify any workflow execution that is ' - 'associated to action execution "%s".' % ac_ex_id + "Unable to identify any workflow execution that is " + 'associated to action execution "%s".' % ac_ex_id, ) class AmbiguousWorkflowExecutionException(st2_exc.StackStormBaseException): - def __init__(self, ac_ex_id): Exception.__init__( self, - 'More than one workflow execution is associated ' - 'to action execution "%s".' % ac_ex_id + "More than one workflow execution is associated " + 'to action execution "%s".' % ac_ex_id, ) class WorkflowExecutionIsCompletedException(st2_exc.StackStormBaseException): - def __init__(self, wf_ex_id): - Exception.__init__(self, 'Workflow execution "%s" is already completed.' % wf_ex_id) + Exception.__init__( + self, 'Workflow execution "%s" is already completed.' % wf_ex_id + ) class WorkflowExecutionIsRunningException(st2_exc.StackStormBaseException): - def __init__(self, wf_ex_id): - Exception.__init__(self, 'Workflow execution "%s" is already active.' % wf_ex_id) + Exception.__init__( + self, 'Workflow execution "%s" is already active.' % wf_ex_id + ) class WorkflowExecutionRerunException(st2_exc.StackStormBaseException): - def __init__(self, msg): Exception.__init__(self, msg) diff --git a/st2common/st2common/expressions/functions/data.py b/st2common/st2common/expressions/functions/data.py index d3783e652e..b240cb7238 100644 --- a/st2common/st2common/expressions/functions/data.py +++ b/st2common/st2common/expressions/functions/data.py @@ -24,13 +24,13 @@ __all__ = [ - 'from_json_string', - 'from_yaml_string', - 'json_escape', - 'jsonpath_query', - 'to_complex', - 'to_json_string', - 'to_yaml_string', + "from_json_string", + "from_yaml_string", + "json_escape", + "jsonpath_query", + "to_complex", + "to_json_string", + "to_yaml_string", ] @@ -42,19 +42,19 @@ def from_yaml_string(value): return yaml.safe_load(six.text_type(value)) -def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')): +def to_json_string(value, indent=None, sort_keys=False, separators=(",", ": ")): value = db_util.mongodb_to_python_types(value) options = {} if indent is not None: - options['indent'] = indent + options["indent"] = indent if sort_keys is not None: - options['sort_keys'] = sort_keys + options["sort_keys"] = sort_keys if separators is not None: - options['separators'] = separators + options["separators"] = separators return json.dumps(value, **options) @@ -62,19 +62,19 @@ def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')): def to_yaml_string(value, indent=None, allow_unicode=True): value = db_util.mongodb_to_python_types(value) - options = {'default_flow_style': False} + options = {"default_flow_style": False} if indent is not None: - options['indent'] = indent + options["indent"] = indent if allow_unicode is not None: - options['allow_unicode'] = allow_unicode + options["allow_unicode"] = allow_unicode return yaml.safe_dump(value, **options) def json_escape(value): - """ Adds escape sequences to problematic characters in the string + """Adds escape sequences to problematic characters in the string This filter simply passes the value to json.dumps as a convenient way of escaping characters in it However, before returning, we want to strip the double @@ -110,7 +110,7 @@ def to_complex(value): # Magic string to which None type is serialized when using use_none filter -NONE_MAGIC_VALUE = '%*****__%NONE%__*****%' +NONE_MAGIC_VALUE = "%*****__%NONE%__*****%" def use_none(value): diff --git a/st2common/st2common/expressions/functions/datastore.py b/st2common/st2common/expressions/functions/datastore.py index a8e903c377..bd0e5fbb09 100644 --- a/st2common/st2common/expressions/functions/datastore.py +++ b/st2common/st2common/expressions/functions/datastore.py @@ -22,9 +22,7 @@ from st2common.util.crypto import read_crypto_key from st2common.util.crypto import symmetric_decrypt -__all__ = [ - 'decrypt_kv' -] +__all__ = ["decrypt_kv"] def decrypt_kv(value): @@ -41,11 +39,13 @@ def decrypt_kv(value): # NOTE: If value is None this indicate key value item doesn't exist and we hrow a more # user-friendly error - if is_kv_item and value == '': + if is_kv_item and value == "": # Build original key name key_name = original_value.get_key_name() - raise ValueError('Referenced datastore item "%s" doesn\'t exist or it contains an empty ' - 'string' % (key_name)) + raise ValueError( + 'Referenced datastore item "%s" doesn\'t exist or it contains an empty ' + "string" % (key_name) + ) crypto_key_path = cfg.CONF.keyvalue.encryption_key_path crypto_key = read_crypto_key(key_path=crypto_key_path) diff --git a/st2common/st2common/expressions/functions/path.py b/st2common/st2common/expressions/functions/path.py index 6081be895c..d21f301aa1 100644 --- a/st2common/st2common/expressions/functions/path.py +++ b/st2common/st2common/expressions/functions/path.py @@ -16,10 +16,7 @@ from __future__ import absolute_import import os -__all__ = [ - 'basename', - 'dirname' -] +__all__ = ["basename", "dirname"] def basename(path): diff --git a/st2common/st2common/expressions/functions/regex.py b/st2common/st2common/expressions/functions/regex.py index 4db7fe0f65..4b17b7372f 100644 --- a/st2common/st2common/expressions/functions/regex.py +++ b/st2common/st2common/expressions/functions/regex.py @@ -17,12 +17,7 @@ import re import six -__all__ = [ - 'regex_match', - 'regex_replace', - 'regex_search', - 'regex_substring' -] +__all__ = ["regex_match", "regex_replace", "regex_search", "regex_substring"] def _get_regex_flags(ignorecase=False): diff --git a/st2common/st2common/expressions/functions/time.py b/st2common/st2common/expressions/functions/time.py index 543fc80938..d25b8acecc 100644 --- a/st2common/st2common/expressions/functions/time.py +++ b/st2common/st2common/expressions/functions/time.py @@ -19,14 +19,12 @@ import datetime -__all__ = [ - 'to_human_time_from_seconds' -] +__all__ = ["to_human_time_from_seconds"] if six.PY3: long_int = int else: - long_int = long # noqa # pylint: disable=E0602 + long_int = long # noqa # pylint: disable=E0602 def to_human_time_from_seconds(seconds): @@ -39,8 +37,11 @@ def to_human_time_from_seconds(seconds): :rtype: ``str`` """ - assert (isinstance(seconds, int) or isinstance(seconds, int) or - isinstance(seconds, float)) + assert ( + isinstance(seconds, int) + or isinstance(seconds, int) + or isinstance(seconds, float) + ) return _get_human_time(seconds) @@ -59,10 +60,10 @@ def _get_human_time(seconds): return None if seconds == 0: - return '0s' + return "0s" if seconds < 1: - return '%s\u03BCs' % seconds # Microseconds + return "%s\u03BCs" % seconds # Microseconds if isinstance(seconds, float): seconds = long_int(round(seconds)) # Let's lose microseconds. @@ -81,17 +82,17 @@ def _get_human_time(seconds): first_non_zero_pos = next((i for i, x in enumerate(time_parts) if x), None) if first_non_zero_pos is None: - return '0s' + return "0s" else: time_parts = time_parts[first_non_zero_pos:] if len(time_parts) == 1: - return '%ss' % tuple(time_parts) + return "%ss" % tuple(time_parts) elif len(time_parts) == 2: - return '%sm%ss' % tuple(time_parts) + return "%sm%ss" % tuple(time_parts) elif len(time_parts) == 3: - return '%sh%sm%ss' % tuple(time_parts) + return "%sh%sm%ss" % tuple(time_parts) elif len(time_parts) == 4: - return '%sd%sh%sm%ss' % tuple(time_parts) + return "%sd%sh%sm%ss" % tuple(time_parts) elif len(time_parts) == 5: - return '%sy%sd%sh%sm%ss' % tuple(time_parts) + return "%sy%sd%sh%sm%ss" % tuple(time_parts) diff --git a/st2common/st2common/expressions/functions/version.py b/st2common/st2common/expressions/functions/version.py index 2dc8d353f1..825d5965e3 100644 --- a/st2common/st2common/expressions/functions/version.py +++ b/st2common/st2common/expressions/functions/version.py @@ -17,13 +17,13 @@ import semver __all__ = [ - 'version_compare', - 'version_more_than', - 'version_less_than', - 'version_equal', - 'version_match', - 'version_bump_major', - 'version_bump_minor' + "version_compare", + "version_more_than", + "version_less_than", + "version_equal", + "version_match", + "version_bump_major", + "version_bump_minor", ] diff --git a/st2common/st2common/fields.py b/st2common/st2common/fields.py index 7217365874..b968e2fdb7 100644 --- a/st2common/st2common/fields.py +++ b/st2common/st2common/fields.py @@ -21,9 +21,7 @@ from st2common.util import date as date_utils -__all__ = [ - 'ComplexDateTimeField' -] +__all__ = ["ComplexDateTimeField"] SECOND_TO_MICROSECONDS = 1000000 @@ -60,7 +58,7 @@ def _microseconds_since_epoch_to_datetime(self, data): :type data: ``int`` """ result = datetime.datetime.utcfromtimestamp(data // SECOND_TO_MICROSECONDS) - microseconds_reminder = (data % SECOND_TO_MICROSECONDS) + microseconds_reminder = data % SECOND_TO_MICROSECONDS result = result.replace(microsecond=microseconds_reminder) result = date_utils.add_utc_tz(result) return result @@ -77,11 +75,13 @@ def _datetime_to_microseconds_since_epoch(self, value): # Verify that the value which is passed in contains UTC timezone # information. if not value.tzinfo or (value.tzinfo.utcoffset(value) != datetime.timedelta(0)): - raise ValueError('Value passed to this function needs to be in UTC timezone') + raise ValueError( + "Value passed to this function needs to be in UTC timezone" + ) seconds = calendar.timegm(value.timetuple()) microseconds_reminder = value.time().microsecond - result = (int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder) + result = int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder return result def __get__(self, instance, owner): @@ -99,8 +99,7 @@ def __set__(self, instance, value): def validate(self, value): value = self.to_python(value) if not isinstance(value, datetime.datetime): - self.error('Only datetime objects may used in a ' - 'ComplexDateTimeField') + self.error("Only datetime objects may used in a " "ComplexDateTimeField") def to_python(self, value): original_value = value diff --git a/st2common/st2common/garbage_collection/executions.py b/st2common/st2common/garbage_collection/executions.py index ba924e76f2..ae0f3296f4 100644 --- a/st2common/st2common/garbage_collection/executions.py +++ b/st2common/st2common/garbage_collection/executions.py @@ -32,15 +32,14 @@ from st2common.services import action as action_service from st2common.services import workflows as workflow_service -__all__ = [ - 'purge_executions', - 'purge_execution_output_objects' -] +__all__ = ["purge_executions", "purge_execution_output_objects"] -DONE_STATES = [action_constants.LIVEACTION_STATUS_SUCCEEDED, - action_constants.LIVEACTION_STATUS_FAILED, - action_constants.LIVEACTION_STATUS_TIMED_OUT, - action_constants.LIVEACTION_STATUS_CANCELED] +DONE_STATES = [ + action_constants.LIVEACTION_STATUS_SUCCEEDED, + action_constants.LIVEACTION_STATUS_FAILED, + action_constants.LIVEACTION_STATUS_TIMED_OUT, + action_constants.LIVEACTION_STATUS_CANCELED, +] def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False): @@ -57,90 +56,118 @@ def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False) :type purge_incomplete: ``bool`` """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging executions older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging executions older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) filters = {} if purge_incomplete: - filters['start_timestamp__lt'] = timestamp + filters["start_timestamp__lt"] = timestamp else: - filters['end_timestamp__lt'] = timestamp - filters['start_timestamp__lt'] = timestamp - filters['status'] = {'$in': DONE_STATES} + filters["end_timestamp__lt"] = timestamp + filters["start_timestamp__lt"] = timestamp + filters["status"] = {"$in": DONE_STATES} exec_filters = copy.copy(filters) if action_ref: - exec_filters['action__ref'] = action_ref + exec_filters["action__ref"] = action_ref liveaction_filters = copy.deepcopy(filters) if action_ref: - liveaction_filters['action'] = action_ref + liveaction_filters["action"] = action_ref to_delete_execution_dbs = [] # 1. Delete ActionExecutionDB objects try: # Note: We call list() on the query set object because it's lazyily evaluated otherwise - to_delete_execution_dbs = list(ActionExecution.query(only_fields=['id'], - no_dereference=True, - **exec_filters)) + to_delete_execution_dbs = list( + ActionExecution.query( + only_fields=["id"], no_dereference=True, **exec_filters + ) + ) deleted_count = ActionExecution.delete_by_query(**exec_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution instances: %s' - 'Please contact support.' % (exec_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution instances: %s" + "Please contact support." + % ( + exec_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution models failed for query with filters: %s.', - exec_filters) + logger.exception( + "Deletion of execution models failed for query with filters: %s.", + exec_filters, + ) else: - logger.info('Deleted %s action execution objects' % (deleted_count)) + logger.info("Deleted %s action execution objects" % (deleted_count)) # 2. Delete LiveActionDB objects try: deleted_count = LiveAction.delete_by_query(**liveaction_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete liveaction instances: %s' - 'Please contact support.' % (liveaction_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete liveaction instances: %s" + "Please contact support." + % ( + liveaction_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of liveaction models failed for query with filters: %s.', - liveaction_filters) + logger.exception( + "Deletion of liveaction models failed for query with filters: %s.", + liveaction_filters, + ) else: - logger.info('Deleted %s liveaction objects' % (deleted_count)) + logger.info("Deleted %s liveaction objects" % (deleted_count)) # 3. Delete ActionExecutionOutputDB objects - to_delete_exection_ids = [str(execution_db.id) for execution_db in to_delete_execution_dbs] + to_delete_exection_ids = [ + str(execution_db.id) for execution_db in to_delete_execution_dbs + ] output_dbs_filters = {} - output_dbs_filters['execution_id'] = {'$in': to_delete_exection_ids} + output_dbs_filters["execution_id"] = {"$in": to_delete_exection_ids} try: deleted_count = ActionExecutionOutput.delete_by_query(**output_dbs_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution output instances: %s' - 'Please contact support.' % (output_dbs_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution output instances: %s" + "Please contact support." % (output_dbs_filters, six.text_type(e)) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution output models failed for query with filters: %s.', - output_dbs_filters) + logger.exception( + "Deletion of execution output models failed for query with filters: %s.", + output_dbs_filters, + ) else: - logger.info('Deleted %s execution output objects' % (deleted_count)) + logger.info("Deleted %s execution output objects" % (deleted_count)) - zombie_execution_instances = len(ActionExecution.query(only_fields=['id'], - no_dereference=True, - **exec_filters)) - zombie_liveaction_instances = len(LiveAction.query(only_fields=['id'], - no_dereference=True, - **liveaction_filters)) + zombie_execution_instances = len( + ActionExecution.query(only_fields=["id"], no_dereference=True, **exec_filters) + ) + zombie_liveaction_instances = len( + LiveAction.query(only_fields=["id"], no_dereference=True, **liveaction_filters) + ) if (zombie_execution_instances > 0) or (zombie_liveaction_instances > 0): - logger.error('Zombie execution instances left: %d.', zombie_execution_instances) - logger.error('Zombie liveaction instances left: %s.', zombie_liveaction_instances) + logger.error("Zombie execution instances left: %d.", zombie_execution_instances) + logger.error( + "Zombie liveaction instances left: %s.", zombie_liveaction_instances + ) # Print stats - logger.info('All execution models older than timestamp %s were deleted.', timestamp) + logger.info("All execution models older than timestamp %s were deleted.", timestamp) def purge_execution_output_objects(logger, timestamp, action_ref=None): @@ -154,28 +181,34 @@ def purge_execution_output_objects(logger, timestamp, action_ref=None): :type action_ref: ``str`` """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging action execution output objects older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging action execution output objects older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) filters = {} - filters['timestamp__lt'] = timestamp + filters["timestamp__lt"] = timestamp if action_ref: - filters['action_ref'] = action_ref + filters["action_ref"] = action_ref try: deleted_count = ActionExecutionOutput.delete_by_query(**filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution output instances: %s' - 'Please contact support.' % (filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution output instances: %s" + "Please contact support." % (filters, six.text_type(e)) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution output models failed for query with filters: %s.', - filters) + logger.exception( + "Deletion of execution output models failed for query with filters: %s.", + filters, + ) else: - logger.info('Deleted %s execution output objects' % (deleted_count)) + logger.info("Deleted %s execution output objects" % (deleted_count)) def purge_orphaned_workflow_executions(logger): @@ -190,5 +223,5 @@ def purge_orphaned_workflow_executions(logger): # as a result of the original failure, the garbage collection routine here cancels # the workflow execution so it cannot be rerun from failed task(s). for ac_ex_db in workflow_service.identify_orphaned_workflows(): - lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction['id']) + lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction["id"]) action_service.request_cancellation(lv_ac_db, None) diff --git a/st2common/st2common/garbage_collection/inquiries.py b/st2common/st2common/garbage_collection/inquiries.py index 724033853f..ad95126b21 100644 --- a/st2common/st2common/garbage_collection/inquiries.py +++ b/st2common/st2common/garbage_collection/inquiries.py @@ -27,7 +27,7 @@ from st2common.util.date import get_datetime_utc_now __all__ = [ - 'purge_inquiries', + "purge_inquiries", ] @@ -44,7 +44,10 @@ def purge_inquiries(logger): """ # Get all existing Inquiries - filters = {'runner__name': 'inquirer', 'status': action_constants.LIVEACTION_STATUS_PENDING} + filters = { + "runner__name": "inquirer", + "status": action_constants.LIVEACTION_STATUS_PENDING, + } inquiries = list(ActionExecution.query(**filters)) gc_count = 0 @@ -52,7 +55,7 @@ def purge_inquiries(logger): # Inspect each Inquiry, and determine if TTL is expired for inquiry in inquiries: - ttl = int(inquiry.result.get('ttl')) + ttl = int(inquiry.result.get("ttl")) if ttl <= 0: logger.debug("Inquiry %s has a TTL of %s. Skipping." % (inquiry.id, ttl)) continue @@ -61,17 +64,22 @@ def purge_inquiries(logger): (get_datetime_utc_now() - inquiry.start_timestamp).total_seconds() / 60 ) - logger.debug("Inquiry %s has a TTL of %s and was started %s minute(s) ago" % ( - inquiry.id, ttl, min_since_creation)) + logger.debug( + "Inquiry %s has a TTL of %s and was started %s minute(s) ago" + % (inquiry.id, ttl, min_since_creation) + ) if min_since_creation > ttl: gc_count += 1 - logger.info("TTL expired for Inquiry %s. Marking as timed out." % inquiry.id) + logger.info( + "TTL expired for Inquiry %s. Marking as timed out." % inquiry.id + ) liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_TIMED_OUT, result=inquiry.result, - liveaction_id=inquiry.liveaction.get('id')) + liveaction_id=inquiry.liveaction.get("id"), + ) executions.update_execution(liveaction_db) # Call Inquiry runner's post_run to trigger callback to workflow @@ -82,8 +90,7 @@ def purge_inquiries(logger): # Request that root workflow resumes root_liveaction = action_service.get_root_liveaction(liveaction_db) action_service.request_resume( - root_liveaction, - UserDB(cfg.CONF.system_user.user) + root_liveaction, UserDB(cfg.CONF.system_user.user) ) logger.info('Marked %s ttl-expired Inquiries as "timed out".' % (gc_count)) diff --git a/st2common/st2common/garbage_collection/trigger_instances.py b/st2common/st2common/garbage_collection/trigger_instances.py index 47996614dd..0fbabb5e72 100644 --- a/st2common/st2common/garbage_collection/trigger_instances.py +++ b/st2common/st2common/garbage_collection/trigger_instances.py @@ -25,9 +25,7 @@ from st2common.persistence.trigger import TriggerInstance from st2common.util import isotime -__all__ = [ - 'purge_trigger_instances' -] +__all__ = ["purge_trigger_instances"] def purge_trigger_instances(logger, timestamp): @@ -36,23 +34,35 @@ def purge_trigger_instances(logger, timestamp): :type timestamp: ``datetime.datetime """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging trigger instances older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging trigger instances older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) - query_filters = {'occurrence_time__lt': isotime.parse(timestamp)} + query_filters = {"occurrence_time__lt": isotime.parse(timestamp)} try: deleted_count = TriggerInstance.delete_by_query(**query_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete trigger instances: %s' - 'Please contact support.' % (query_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete trigger instances: %s" + "Please contact support." + % ( + query_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deleting instances using query_filters %s failed.', query_filters) + logger.exception( + "Deleting instances using query_filters %s failed.", query_filters + ) else: - logger.info('Deleted %s trigger instance objects' % (deleted_count)) + logger.info("Deleted %s trigger instance objects" % (deleted_count)) # Print stats - logger.info('All trigger instance models older than timestamp %s were deleted.', timestamp) + logger.info( + "All trigger instance models older than timestamp %s were deleted.", timestamp + ) diff --git a/st2common/st2common/log.py b/st2common/st2common/log.py index 5335af5f53..fbf6205bb9 100644 --- a/st2common/st2common/log.py +++ b/st2common/st2common/log.py @@ -35,34 +35,30 @@ from st2common.util.misc import get_normalized_file_path __all__ = [ - 'getLogger', - 'setup', - - 'FormatNamedFileHandler', - 'ConfigurableSyslogHandler', - - 'LoggingStream', - - 'ignore_lib2to3_log_messages', - 'ignore_statsd_log_messages' + "getLogger", + "setup", + "FormatNamedFileHandler", + "ConfigurableSyslogHandler", + "LoggingStream", + "ignore_lib2to3_log_messages", + "ignore_statsd_log_messages", ] # NOTE: We set AUDIT to the highest log level which means AUDIT log messages will always be # included (e.g. also if log level is set to INFO). To avoid that, we need to explicitly filter # out AUDIT log level in service setup code. logging.AUDIT = logging.CRITICAL + 10 -logging.addLevelName(logging.AUDIT, 'AUDIT') +logging.addLevelName(logging.AUDIT, "AUDIT") LOGGER_KEYS = [ - 'debug', - 'info', - 'warning', - 'error', - 'critical', - 'exception', - 'log', - - 'audit' + "debug", + "info", + "warning", + "error", + "critical", + "exception", + "log", + "audit", ] # Note: This attribute is used by "find_caller" so it can correctly exclude this file when looking @@ -89,10 +85,10 @@ def find_caller(stack_info=False, stacklevel=1): on what runtine we're working in. """ if six.PY2: - rv = '(unknown file)', 0, '(unknown function)' + rv = "(unknown file)", 0, "(unknown function)" else: # python 3, has extra tuple element at the end for stack information - rv = '(unknown file)', 0, '(unknown function)', None + rv = "(unknown file)", 0, "(unknown function)", None try: f = logging.currentframe() @@ -107,7 +103,7 @@ def find_caller(stack_info=False, stacklevel=1): if not f: f = orig_f - while hasattr(f, 'f_code'): + while hasattr(f, "f_code"): co = f.f_code filename = os.path.normcase(co.co_filename) if filename in (_srcfile, logging._srcfile): # This line is modified. @@ -121,10 +117,10 @@ def find_caller(stack_info=False, stacklevel=1): sinfo = None if stack_info: sio = io.StringIO() - sio.write('Stack (most recent call last):\n') + sio.write("Stack (most recent call last):\n") traceback.print_stack(f, file=sio) sinfo = sio.getvalue() - if sinfo[-1] == '\n': + if sinfo[-1] == "\n": sinfo = sinfo[:-1] sio.close() rv = (filename, f.f_lineno, co.co_name, sinfo) @@ -139,8 +135,8 @@ def decorate_log_method(func): @wraps(func) def func_wrapper(*args, **kwargs): # Prefix extra keys with underscore - if 'extra' in kwargs: - kwargs['extra'] = prefix_dict_keys(dictionary=kwargs['extra'], prefix='_') + if "extra" in kwargs: + kwargs["extra"] = prefix_dict_keys(dictionary=kwargs["extra"], prefix="_") try: return func(*args, **kwargs) @@ -150,10 +146,11 @@ def func_wrapper(*args, **kwargs): # See: # - https://docs.python.org/release/2.7.3/library/logging.html#logging.Logger.exception # - https://docs.python.org/release/2.7.7/library/logging.html#logging.Logger.exception - if 'got an unexpected keyword argument \'extra\'' in six.text_type(e): - kwargs.pop('extra', None) + if "got an unexpected keyword argument 'extra'" in six.text_type(e): + kwargs.pop("extra", None) return func(*args, **kwargs) raise e + return func_wrapper @@ -179,11 +176,11 @@ def decorate_logger_methods(logger): def getLogger(name): # make sure that prefix isn't appended multiple times to preserve logging name hierarchy - prefix = 'st2.' + prefix = "st2." if name.startswith(prefix): logger = logging.getLogger(name) else: - logger_name = '{}{}'.format(prefix, name) + logger_name = "{}{}".format(prefix, name) logger = logging.getLogger(logger_name) logger = decorate_logger_methods(logger=logger) @@ -191,7 +188,6 @@ def getLogger(name): class LoggingStream(object): - def __init__(self, name, level=logging.ERROR): self._logger = getLogger(name) self._level = level @@ -219,11 +215,16 @@ def _add_exclusion_filters(handlers, excludes=None): def _redirect_stderr(): # It is ok to redirect stderr as none of the st2 handlers write to stderr. - sys.stderr = LoggingStream('STDERR') + sys.stderr = LoggingStream("STDERR") -def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_loggers=False, - st2_conf_path=None): +def setup( + config_file, + redirect_stderr=True, + excludes=None, + disable_existing_loggers=False, + st2_conf_path=None, +): """ Configure logging from file. @@ -232,16 +233,18 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log absolute path relative to st2.conf. :type st2_conf_path: ``str`` """ - if st2_conf_path and config_file[:2] == './' and not os.path.isfile(config_file): + if st2_conf_path and config_file[:2] == "./" and not os.path.isfile(config_file): # Logging config path is relative to st2.conf, resolve it to full absolute path directory = os.path.dirname(st2_conf_path) config_file_name = os.path.basename(config_file) config_file = os.path.join(directory, config_file_name) try: - logging.config.fileConfig(config_file, - defaults=None, - disable_existing_loggers=disable_existing_loggers) + logging.config.fileConfig( + config_file, + defaults=None, + disable_existing_loggers=disable_existing_loggers, + ) handlers = logging.getLoggerClass().manager.root.handlers _add_exclusion_filters(handlers=handlers, excludes=excludes) if redirect_stderr: @@ -251,13 +254,13 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log tb_msg = traceback.format_exc() msg = str(exc) - msg += '\n\n' + tb_msg + msg += "\n\n" + tb_msg # revert stderr redirection since there is no logger in place. sys.stderr = sys.__stderr__ # No logger yet therefore write to stderr - sys.stderr.write('ERROR: %s' % (msg)) + sys.stderr.write("ERROR: %s" % (msg)) raise exc_cls(six.text_type(msg)) @@ -271,10 +274,10 @@ def ignore_lib2to3_log_messages(): class MockLoggingModule(object): def getLogger(self, *args, **kwargs): - return logging.getLogger('lib2to3') + return logging.getLogger("lib2to3") lib2to3.pgen2.driver.logging = MockLoggingModule() - logging.getLogger('lib2to3').setLevel(logging.ERROR) + logging.getLogger("lib2to3").setLevel(logging.ERROR) def ignore_statsd_log_messages(): @@ -288,8 +291,8 @@ def ignore_statsd_log_messages(): class MockLoggingModule(object): def getLogger(self, *args, **kwargs): - return logging.getLogger('statsd') + return logging.getLogger("statsd") statsd.connection.logging = MockLoggingModule() statsd.client.logging = MockLoggingModule() - logging.getLogger('statsd').setLevel(logging.ERROR) + logging.getLogger("statsd").setLevel(logging.ERROR) diff --git a/st2common/st2common/logging/filters.py b/st2common/st2common/logging/filters.py index d997589a0e..1fef164028 100644 --- a/st2common/st2common/logging/filters.py +++ b/st2common/st2common/logging/filters.py @@ -17,9 +17,9 @@ import logging __all__ = [ - 'LoggerNameExclusionFilter', - 'LoggerFunctionNameExclusionFilter', - 'LogLevelFilter', + "LoggerNameExclusionFilter", + "LoggerFunctionNameExclusionFilter", + "LogLevelFilter", ] @@ -36,8 +36,11 @@ def filter(self, record): if len(self._exclusions) < 1: return True - module_decomposition = record.name.split('.') - exclude = len(module_decomposition) > 0 and module_decomposition[0] in self._exclusions + module_decomposition = record.name.split(".") + exclude = ( + len(module_decomposition) > 0 + and module_decomposition[0] in self._exclusions + ) return not exclude @@ -54,7 +57,7 @@ def filter(self, record): if len(self._exclusions) < 1: return True - function_name = getattr(record, 'funcName', None) + function_name = getattr(record, "funcName", None) if function_name in self._exclusions: return False diff --git a/st2common/st2common/logging/formatters.py b/st2common/st2common/logging/formatters.py index d20b240a5a..7c30e780a9 100644 --- a/st2common/st2common/logging/formatters.py +++ b/st2common/st2common/logging/formatters.py @@ -28,8 +28,8 @@ from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE __all__ = [ - 'ConsoleLogFormatter', - 'GelfLogFormatter', + "ConsoleLogFormatter", + "GelfLogFormatter", ] SIMPLE_TYPES = (int, float) + six.string_types @@ -37,16 +37,16 @@ # GELF logger specific constants HOSTNAME = socket.gethostname() -GELF_SPEC_VERSION = '1.1' +GELF_SPEC_VERSION = "1.1" COMMON_ATTRIBUTE_NAMES = [ - 'name', - 'process', - 'processName', - 'module', - 'filename', - 'funcName', - 'lineno' + "name", + "process", + "processName", + "module", + "filename", + "funcName", + "lineno", ] @@ -60,9 +60,9 @@ def serialize_object(obj): :rtype: ``str`` """ # Try to serialize the object - if getattr(obj, 'to_dict', None): + if getattr(obj, "to_dict", None): value = obj.to_dict() - elif getattr(obj, 'to_serializable_dict', None): + elif getattr(obj, "to_serializable_dict", None): value = obj.to_serializable_dict(mask_secrets=True) else: value = repr(obj) @@ -77,7 +77,9 @@ def process_attribute_value(key, value): if not cfg.CONF.log.mask_secrets: return value - blacklisted_attribute_names = MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist + blacklisted_attribute_names = ( + MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist + ) # NOTE: This can be expensive when processing large dicts or objects if isinstance(value, SIMPLE_TYPES): @@ -121,11 +123,16 @@ class BaseExtraLogFormatter(logging.Formatter): dictionary need to be prefixed with a slash ('_'). """ - PREFIX = '_' # Prefix for user provided attributes in the extra dict + PREFIX = "_" # Prefix for user provided attributes in the extra dict def _get_extra_attributes(self, record): - attributes = dict([(k, v) for k, v in six.iteritems(record.__dict__) - if k.startswith(self.PREFIX)]) + attributes = dict( + [ + (k, v) + for k, v in six.iteritems(record.__dict__) + if k.startswith(self.PREFIX) + ] + ) return attributes def _get_common_extra_attributes(self, record): @@ -182,17 +189,17 @@ def format(self, record): msg = super(ConsoleLogFormatter, self).format(record) if attributes: - msg = '%s (%s)' % (msg, attributes) + msg = "%s (%s)" % (msg, attributes) return msg def _dict_to_str(self, attributes): result = [] for key, value in six.iteritems(attributes): - item = '%s=%s' % (key[1:], repr(value)) + item = "%s=%s" % (key[1:], repr(value)) result.append(item) - result = ','.join(result) + result = ",".join(result) return result @@ -245,30 +252,32 @@ def format(self, record): exc_info = record.exc_info time_now_float = record.created time_now_sec = int(time_now_float) - level = self.PYTHON_TO_GELF_LEVEL_MAP.get(record.levelno, self.DEFAULT_LOG_LEVEL) + level = self.PYTHON_TO_GELF_LEVEL_MAP.get( + record.levelno, self.DEFAULT_LOG_LEVEL + ) common_attributes = self._get_common_extra_attributes(record=record) full_msg = super(GelfLogFormatter, self).format(record) data = { - 'version': GELF_SPEC_VERSION, - 'host': HOSTNAME, - 'short_message': msg, - 'full_message': full_msg, - 'timestamp': time_now_sec, - 'timestamp_f': time_now_float, - 'level': level + "version": GELF_SPEC_VERSION, + "host": HOSTNAME, + "short_message": msg, + "full_message": full_msg, + "timestamp": time_now_sec, + "timestamp_f": time_now_float, + "level": level, } if exc_info: # Include exception information exc_type, exc_value, exc_tb = exc_info - tb_str = ''.join(traceback.format_tb(exc_tb)) - data['_exception'] = six.text_type(exc_value) - data['_traceback'] = tb_str + tb_str = "".join(traceback.format_tb(exc_tb)) + data["_exception"] = six.text_type(exc_value) + data["_traceback"] = tb_str # Include common Python log record attributes - data['_python'] = common_attributes + data["_python"] = common_attributes # Include user extra attributes data.update(attributes) diff --git a/st2common/st2common/logging/handlers.py b/st2common/st2common/logging/handlers.py index ade4dfbb04..963ac197a0 100644 --- a/st2common/st2common/logging/handlers.py +++ b/st2common/st2common/logging/handlers.py @@ -24,26 +24,29 @@ from st2common.util import date as date_utils __all__ = [ - 'FormatNamedFileHandler', - 'ConfigurableSyslogHandler', + "FormatNamedFileHandler", + "ConfigurableSyslogHandler", ] class FormatNamedFileHandler(logging.handlers.RotatingFileHandler): - def __init__(self, filename, mode='a', maxBytes=0, backupCount=0, encoding=None, delay=False): + def __init__( + self, filename, mode="a", maxBytes=0, backupCount=0, encoding=None, delay=False + ): # We add aditional values to the context which can be used in the log filename timestamp = int(time.time()) - isotime_str = str(date_utils.get_datetime_utc_now()).replace(' ', '_') + isotime_str = str(date_utils.get_datetime_utc_now()).replace(" ", "_") pid = os.getpid() - format_values = { - 'timestamp': timestamp, - 'ts': isotime_str, - 'pid': pid - } + format_values = {"timestamp": timestamp, "ts": isotime_str, "pid": pid} filename = filename.format(**format_values) - super(FormatNamedFileHandler, self).__init__(filename, mode=mode, maxBytes=maxBytes, - backupCount=backupCount, encoding=encoding, - delay=delay) + super(FormatNamedFileHandler, self).__init__( + filename, + mode=mode, + maxBytes=maxBytes, + backupCount=backupCount, + encoding=encoding, + delay=delay, + ) class ConfigurableSyslogHandler(logging.handlers.SysLogHandler): @@ -55,12 +58,12 @@ def __init__(self, address=None, facility=None, socktype=None): if not socktype: protocol = cfg.CONF.syslog.protocol.lower() - if protocol == 'udp': + if protocol == "udp": socktype = socket.SOCK_DGRAM - elif protocol == 'tcp': + elif protocol == "tcp": socktype = socket.SOCK_STREAM else: - raise ValueError('Unsupported protocol: %s' % (protocol)) + raise ValueError("Unsupported protocol: %s" % (protocol)) if socktype: super(ConfigurableSyslogHandler, self).__init__(address, facility, socktype) diff --git a/st2common/st2common/logging/misc.py b/st2common/st2common/logging/misc.py index 36f8b17986..de7f673431 100644 --- a/st2common/st2common/logging/misc.py +++ b/st2common/st2common/logging/misc.py @@ -23,32 +23,26 @@ from st2common.logging.filters import LoggerFunctionNameExclusionFilter __all__ = [ - 'reopen_log_files', - - 'set_log_level_for_all_handlers', - 'set_log_level_for_all_loggers', - - 'add_global_filters_for_all_loggers' + "reopen_log_files", + "set_log_level_for_all_handlers", + "set_log_level_for_all_loggers", + "add_global_filters_for_all_loggers", ] LOG = logging.getLogger(__name__) # Because some loggers are just waste of attention span -SPECIAL_LOGGERS = { - 'swagger_spec_validator.ref_validators': logging.INFO -} +SPECIAL_LOGGERS = {"swagger_spec_validator.ref_validators": logging.INFO} # Log messages for function names which are very spammy and we want to filter out when DEBUG log # level is enabled IGNORED_FUNCTION_NAMES = [ # Used by pyamqp, logs every heartbit tick every 2 ms by default - 'heartbeat_tick' + "heartbeat_tick" ] # List of global filters which apply to all the loggers -GLOBAL_FILTERS = [ - LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES) -] +GLOBAL_FILTERS = [LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES)] def reopen_log_files(handlers): @@ -65,8 +59,10 @@ def reopen_log_files(handlers): if not isinstance(handler, logging.FileHandler): continue - LOG.info('Re-opening log file "%s" with mode "%s"\n' % - (handler.baseFilename, handler.mode)) + LOG.info( + 'Re-opening log file "%s" with mode "%s"\n' + % (handler.baseFilename, handler.mode) + ) try: handler.acquire() @@ -76,10 +72,10 @@ def reopen_log_files(handlers): try: handler.release() except RuntimeError as e: - if 'cannot release' in six.text_type(e): + if "cannot release" in six.text_type(e): # Release failed which most likely indicates that acquire failed # and lock was never acquired - LOG.warn('Failed to release lock', exc_info=True) + LOG.warn("Failed to release lock", exc_info=True) else: raise e @@ -112,7 +108,9 @@ def set_log_level_for_all_loggers(level=logging.DEBUG): logger = add_filters_for_logger(logger=logger, filters=GLOBAL_FILTERS) if logger.name in SPECIAL_LOGGERS: - set_log_level_for_all_handlers(logger=logger, level=SPECIAL_LOGGERS.get(logger.name)) + set_log_level_for_all_handlers( + logger=logger, level=SPECIAL_LOGGERS.get(logger.name) + ) else: set_log_level_for_all_handlers(logger=logger, level=level) @@ -152,7 +150,7 @@ def add_filters_for_logger(logger, filters): if not isinstance(logger, logging.Logger): return logger - if not hasattr(logger, 'addFilter'): + if not hasattr(logger, "addFilter"): return logger for logger_filter in filters: @@ -170,7 +168,7 @@ def get_logger_name_for_module(module, exclude_module_name=False): module_file = module.__file__ base_dir = os.path.dirname(os.path.abspath(module_file)) module_name = os.path.basename(module_file) - module_name = module_name.replace('.pyc', '').replace('.py', '') + module_name = module_name.replace(".pyc", "").replace(".py", "") split = base_dir.split(os.path.sep) split = [component for component in split if component] @@ -178,15 +176,15 @@ def get_logger_name_for_module(module, exclude_module_name=False): # Find first component which starts with st2 and use that as a starting point start_index = 0 for index, component in enumerate(reversed(split)): - if component.startswith('st2'): - start_index = ((len(split) - 1) - index) + if component.startswith("st2"): + start_index = (len(split) - 1) - index break split = split[start_index:] if exclude_module_name: - name = '.'.join(split) + name = ".".join(split) else: - name = '.'.join(split) + '.' + module_name + name = ".".join(split) + "." + module_name return name diff --git a/st2common/st2common/metrics/base.py b/st2common/st2common/metrics/base.py index 18801c901d..215780b86f 100644 --- a/st2common/st2common/metrics/base.py +++ b/st2common/st2common/metrics/base.py @@ -28,23 +28,22 @@ from st2common.exceptions.plugins import PluginLoadError __all__ = [ - 'BaseMetricsDriver', - - 'Timer', - 'Counter', - 'CounterWithTimer', - - 'metrics_initialize', - 'get_driver' + "BaseMetricsDriver", + "Timer", + "Counter", + "CounterWithTimer", + "metrics_initialize", + "get_driver", ] -if not hasattr(cfg.CONF, 'metrics'): +if not hasattr(cfg.CONF, "metrics"): from st2common.config import register_opts + register_opts() LOG = logging.getLogger(__name__) -PLUGIN_NAMESPACE = 'st2common.metrics.driver' +PLUGIN_NAMESPACE = "st2common.metrics.driver" # Stores reference to the metrics driver class instance. # NOTE: This value is populated lazily on the first get_driver() function call @@ -97,6 +96,7 @@ class Timer(object): """ Timer context manager for easily sending timer statistics. """ + def __init__(self, key, include_parameter=False): check_key(key) @@ -136,8 +136,9 @@ def __call__(self, func): def wrapper(*args, **kw): with self as metrics_timer: if self._include_parameter: - kw['metrics_timer'] = metrics_timer + kw["metrics_timer"] = metrics_timer return func(*args, **kw) + return wrapper @@ -145,6 +146,7 @@ class Counter(object): """ Counter context manager for easily sending counter statistics. """ + def __init__(self, key): check_key(key) self.key = key @@ -162,6 +164,7 @@ def __call__(self, func): def wrapper(*args, **kw): with self: return func(*args, **kw) + return wrapper @@ -209,8 +212,9 @@ def __call__(self, func): def wrapper(*args, **kw): with self as counter_with_timer: if self._include_parameter: - kw['metrics_counter_with_timer'] = counter_with_timer + kw["metrics_counter_with_timer"] = counter_with_timer return func(*args, **kw) + return wrapper @@ -223,7 +227,9 @@ def metrics_initialize(): try: METRICS = get_plugin_instance(PLUGIN_NAMESPACE, cfg.CONF.metrics.driver) except (NoMatches, MultipleMatches, NoSuchOptError) as error: - raise PluginLoadError('Error loading metrics driver. Check configuration: %s' % error) + raise PluginLoadError( + "Error loading metrics driver. Check configuration: %s" % error + ) return METRICS diff --git a/st2common/st2common/metrics/drivers/echo_driver.py b/st2common/st2common/metrics/drivers/echo_driver.py index 40b2ed3947..7cb115aab6 100644 --- a/st2common/st2common/metrics/drivers/echo_driver.py +++ b/st2common/st2common/metrics/drivers/echo_driver.py @@ -16,9 +16,7 @@ from st2common import log as logging from st2common.metrics.base import BaseMetricsDriver -__all__ = [ - 'EchoDriver' -] +__all__ = ["EchoDriver"] LOG = logging.getLogger(__name__) @@ -29,19 +27,19 @@ class EchoDriver(BaseMetricsDriver): """ def time(self, key, time): - LOG.debug('[metrics] time(key=%s, time=%s)' % (key, time)) + LOG.debug("[metrics] time(key=%s, time=%s)" % (key, time)) def inc_counter(self, key, amount=1): - LOG.debug('[metrics] counter.incr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] counter.incr(%s, %s)" % (key, amount)) def dec_counter(self, key, amount=1): - LOG.debug('[metrics] counter.decr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] counter.decr(%s, %s)" % (key, amount)) def set_gauge(self, key, value): - LOG.debug('[metrics] set_gauge(%s, %s)' % (key, value)) + LOG.debug("[metrics] set_gauge(%s, %s)" % (key, value)) def inc_gauge(self, key, amount=1): - LOG.debug('[metrics] gauge.incr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] gauge.incr(%s, %s)" % (key, amount)) def dec_gauge(self, key, amount=1): - LOG.debug('[metrics] gauge.decr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] gauge.decr(%s, %s)" % (key, amount)) diff --git a/st2common/st2common/metrics/drivers/noop_driver.py b/st2common/st2common/metrics/drivers/noop_driver.py index 6f816f2a69..658ee10a40 100644 --- a/st2common/st2common/metrics/drivers/noop_driver.py +++ b/st2common/st2common/metrics/drivers/noop_driver.py @@ -15,9 +15,7 @@ from st2common.metrics.base import BaseMetricsDriver -__all__ = [ - 'NoopDriver' -] +__all__ = ["NoopDriver"] class NoopDriver(BaseMetricsDriver): diff --git a/st2common/st2common/metrics/drivers/statsd_driver.py b/st2common/st2common/metrics/drivers/statsd_driver.py index c334837e9b..efbefde601 100644 --- a/st2common/st2common/metrics/drivers/statsd_driver.py +++ b/st2common/st2common/metrics/drivers/statsd_driver.py @@ -30,15 +30,9 @@ LOG = logging.getLogger(__name__) # Which exceptions thrown by statsd library should be considered as non-fatal -NON_FATAL_EXC_CLASSES = [ - socket.error, - IOError, - OSError -] +NON_FATAL_EXC_CLASSES = [socket.error, IOError, OSError] -__all__ = [ - 'StatsdDriver' -] +__all__ = ["StatsdDriver"] class StatsdDriver(BaseMetricsDriver): @@ -55,11 +49,15 @@ class StatsdDriver(BaseMetricsDriver): """ def __init__(self): - statsd.Connection.set_defaults(host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port, - sample_rate=cfg.CONF.metrics.sample_rate) - - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + statsd.Connection.set_defaults( + host=cfg.CONF.metrics.host, + port=cfg.CONF.metrics.port, + sample_rate=cfg.CONF.metrics.sample_rate, + ) + + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def time(self, key, time): """ Timer metric @@ -68,11 +66,12 @@ def time(self, key, time): assert isinstance(time, Number) key = get_full_key_name(key) - timer = statsd.Timer('') + timer = statsd.Timer("") timer.send(key, time) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def inc_counter(self, key, amount=1): """ Increment counter @@ -84,8 +83,9 @@ def inc_counter(self, key, amount=1): counter = statsd.Counter(key) counter.increment(delta=amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def dec_counter(self, key, amount=1): """ Decrement metric @@ -97,8 +97,9 @@ def dec_counter(self, key, amount=1): counter = statsd.Counter(key) counter.decrement(delta=amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def set_gauge(self, key, value): """ Set gauge value. @@ -110,8 +111,9 @@ def set_gauge(self, key, value): gauge = statsd.Gauge(key) gauge.send(None, value) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def inc_gauge(self, key, amount=1): """ Increment gauge value. @@ -123,8 +125,9 @@ def inc_gauge(self, key, amount=1): gauge = statsd.Gauge(key) gauge.increment(None, amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def dec_gauge(self, key, amount=1): """ Decrement gauge value. diff --git a/st2common/st2common/metrics/utils.py b/st2common/st2common/metrics/utils.py index f741743cd2..710aceff15 100644 --- a/st2common/st2common/metrics/utils.py +++ b/st2common/st2common/metrics/utils.py @@ -16,10 +16,7 @@ import six from oslo_config import cfg -__all__ = [ - 'get_full_key_name', - 'check_key' -] +__all__ = ["get_full_key_name", "check_key"] def get_full_key_name(key): @@ -27,14 +24,14 @@ def get_full_key_name(key): Return full metric key name, taking into account optional prefix which can be specified in the config. """ - parts = ['st2'] + parts = ["st2"] if cfg.CONF.metrics.prefix: parts.append(cfg.CONF.metrics.prefix) parts.append(key) - return '.'.join(parts) + return ".".join(parts) def check_key(key): diff --git a/st2common/st2common/middleware/cors.py b/st2common/st2common/middleware/cors.py index 1388e65e63..eaeac86f07 100644 --- a/st2common/st2common/middleware/cors.py +++ b/st2common/st2common/middleware/cors.py @@ -42,18 +42,18 @@ def __call__(self, environ, start_response): def custom_start_response(status, headers, exc_info=None): headers = ResponseHeaders(headers) - origin = request.headers.get('Origin') + origin = request.headers.get("Origin") origins = OrderedSet(cfg.CONF.api.allow_origin) # Build a list of the default allowed origins public_api_url = cfg.CONF.auth.api_url # Default gulp development server WebUI URL - origins.add('http://127.0.0.1:3000') + origins.add("http://127.0.0.1:3000") # By default WebUI simple http server listens on 8080 - origins.add('http://localhost:8080') - origins.add('http://127.0.0.1:8080') + origins.add("http://localhost:8080") + origins.add("http://127.0.0.1:8080") if public_api_url: # Public API URL @@ -62,7 +62,7 @@ def custom_start_response(status, headers, exc_info=None): origins = list(origins) if origin: - if '*' in origins: + if "*" in origins: origin_allowed = origin else: # See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header @@ -70,21 +70,32 @@ def custom_start_response(status, headers, exc_info=None): else: origin_allowed = list(origins)[0] - methods_allowed = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'] - request_headers_allowed = ['Content-Type', 'Authorization', HEADER_ATTRIBUTE_NAME, - HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER] - response_headers_allowed = ['Content-Type', 'X-Limit', 'X-Total-Count', - REQUEST_ID_HEADER] - - headers['Access-Control-Allow-Origin'] = origin_allowed - headers['Access-Control-Allow-Methods'] = ','.join(methods_allowed) - headers['Access-Control-Allow-Headers'] = ','.join(request_headers_allowed) - headers['Access-Control-Allow-Credentials'] = 'true' - headers['Access-Control-Expose-Headers'] = ','.join(response_headers_allowed) + methods_allowed = ["GET", "POST", "PUT", "DELETE", "OPTIONS"] + request_headers_allowed = [ + "Content-Type", + "Authorization", + HEADER_ATTRIBUTE_NAME, + HEADER_API_KEY_ATTRIBUTE_NAME, + REQUEST_ID_HEADER, + ] + response_headers_allowed = [ + "Content-Type", + "X-Limit", + "X-Total-Count", + REQUEST_ID_HEADER, + ] + + headers["Access-Control-Allow-Origin"] = origin_allowed + headers["Access-Control-Allow-Methods"] = ",".join(methods_allowed) + headers["Access-Control-Allow-Headers"] = ",".join(request_headers_allowed) + headers["Access-Control-Allow-Credentials"] = "true" + headers["Access-Control-Expose-Headers"] = ",".join( + response_headers_allowed + ) return start_response(status, headers._items, exc_info) - if request.method == 'OPTIONS': + if request.method == "OPTIONS": return Response()(environ, custom_start_response) else: return self.app(environ, custom_start_response) diff --git a/st2common/st2common/middleware/error_handling.py b/st2common/st2common/middleware/error_handling.py index 478cf3691b..d7ae59cde5 100644 --- a/st2common/st2common/middleware/error_handling.py +++ b/st2common/st2common/middleware/error_handling.py @@ -50,13 +50,13 @@ def __call__(self, environ, start_response): except NotFoundException: raise exc.HTTPNotFound() except Exception as e: - status = getattr(e, 'code', exc.HTTPInternalServerError.code) + status = getattr(e, "code", exc.HTTPInternalServerError.code) - if hasattr(e, 'detail') and not getattr(e, 'comment'): - setattr(e, 'comment', getattr(e, 'detail')) + if hasattr(e, "detail") and not getattr(e, "comment"): + setattr(e, "comment", getattr(e, "detail")) - if hasattr(e, 'body') and isinstance(getattr(e, 'body', None), dict): - body = getattr(e, 'body', None) + if hasattr(e, "body") and isinstance(getattr(e, "body", None), dict): + body = getattr(e, "body", None) else: body = {} @@ -69,40 +69,40 @@ def __call__(self, environ, start_response): elif isinstance(e, db_exceptions.StackStormDBObjectConflictError): status_code = exc.HTTPConflict.code message = six.text_type(e) - body['conflict-id'] = getattr(e, 'conflict_id', None) + body["conflict-id"] = getattr(e, "conflict_id", None) elif isinstance(e, rbac_exceptions.AccessDeniedError): status_code = exc.HTTPForbidden.code message = six.text_type(e) elif isinstance(e, (ValueValidationException, ValueError, ValidationError)): status_code = exc.HTTPBadRequest.code - message = getattr(e, 'message', six.text_type(e)) + message = getattr(e, "message", six.text_type(e)) else: status_code = exc.HTTPInternalServerError.code - message = 'Internal Server Error' + message = "Internal Server Error" # Log the error is_internal_server_error = status_code == exc.HTTPInternalServerError.code - error_msg = getattr(e, 'comment', six.text_type(e)) + error_msg = getattr(e, "comment", six.text_type(e)) extra = { - 'exception_class': e.__class__.__name__, - 'exception_message': six.text_type(e), - 'exception_data': e.__dict__ + "exception_class": e.__class__.__name__, + "exception_message": six.text_type(e), + "exception_data": e.__dict__, } if is_internal_server_error: - LOG.exception('API call failed: %s', error_msg, extra=extra) + LOG.exception("API call failed: %s", error_msg, extra=extra) else: - LOG.debug('API call failed: %s', error_msg, extra=extra) + LOG.debug("API call failed: %s", error_msg, extra=extra) if is_debugging_enabled(): LOG.debug(traceback.format_exc()) - body['faultstring'] = message + body["faultstring"] = message response_body = json_encode(body) headers = { - 'Content-Type': 'application/json', - 'Content-Length': str(len(response_body)) + "Content-Type": "application/json", + "Content-Length": str(len(response_body)), } resp = Response(response_body, status=status_code, headers=headers) diff --git a/st2common/st2common/middleware/instrumentation.py b/st2common/st2common/middleware/instrumentation.py index 8ff7445f75..e5d01d2223 100644 --- a/st2common/st2common/middleware/instrumentation.py +++ b/st2common/st2common/middleware/instrumentation.py @@ -21,10 +21,7 @@ from st2common.util.date import get_datetime_utc_now from st2common.router import NotFoundException -__all__ = [ - 'RequestInstrumentationMiddleware', - 'ResponseInstrumentationMiddleware' -] +__all__ = ["RequestInstrumentationMiddleware", "ResponseInstrumentationMiddleware"] LOG = logging.getLogger(__name__) @@ -54,10 +51,11 @@ def __call__(self, environ, start_response): # NOTE: We don't track per request and response metrics for /v1/executions/ and some # other endpoints because this would result in a lot of unique metrics which is an # anti-pattern and causes unnecessary load on the metrics server. - submit_metrics = endpoint.get('x-submit-metrics', True) - operation_id = endpoint.get('operationId', None) - is_get_one_endpoint = bool(operation_id) and (operation_id.endswith('.get') or - operation_id.endswith('.get_one')) + submit_metrics = endpoint.get("x-submit-metrics", True) + operation_id = endpoint.get("operationId", None) + is_get_one_endpoint = bool(operation_id) and ( + operation_id.endswith(".get") or operation_id.endswith(".get_one") + ) if is_get_one_endpoint: # NOTE: We don't submit metrics for any get one API endpoint since this would result @@ -65,22 +63,22 @@ def __call__(self, environ, start_response): submit_metrics = False if not submit_metrics: - LOG.debug('Not submitting request metrics for path: %s' % (request.path)) + LOG.debug("Not submitting request metrics for path: %s" % (request.path)) return self.app(environ, start_response) metrics_driver = get_driver() - key = '%s.request.total' % (self._service_name) + key = "%s.request.total" % (self._service_name) metrics_driver.inc_counter(key) - key = '%s.request.method.%s' % (self._service_name, request.method) + key = "%s.request.method.%s" % (self._service_name, request.method) metrics_driver.inc_counter(key) - path = request.path.replace('/', '_') - key = '%s.request.path.%s' % (self._service_name, path) + path = request.path.replace("/", "_") + key = "%s.request.path.%s" % (self._service_name, path) metrics_driver.inc_counter(key) - if self._service_name == 'stream': + if self._service_name == "stream": # For stream service, we also record current number of open connections. # Due to the way stream service works, we need to utilize eventlet posthook to # correctly set the counter when the connection is closed / full response is returned. @@ -88,34 +86,34 @@ def __call__(self, environ, start_response): # hooks for details # Increase request counter - key = '%s.request' % (self._service_name) + key = "%s.request" % (self._service_name) metrics_driver.inc_counter(key) # Increase "total number of connections" gauge - metrics_driver.inc_gauge('stream.connections', 1) + metrics_driver.inc_gauge("stream.connections", 1) start_time = get_datetime_utc_now() def update_metrics_hook(env): # Hook which is called at the very end after all the response has been sent and # connection closed - time_delta = (get_datetime_utc_now() - start_time) + time_delta = get_datetime_utc_now() - start_time duration = time_delta.total_seconds() # Send total request time metrics_driver.time(key, duration) # Decrease "current number of connections" gauge - metrics_driver.dec_gauge('stream.connections', 1) + metrics_driver.dec_gauge("stream.connections", 1) # NOTE: Some tests mock environ and there 'eventlet.posthooks' key is not available - if 'eventlet.posthooks' in environ: - environ['eventlet.posthooks'].append((update_metrics_hook, (), {})) + if "eventlet.posthooks" in environ: + environ["eventlet.posthooks"].append((update_metrics_hook, (), {})) return self.app(environ, start_response) else: # Track and time current number of processing requests - key = '%s.request' % (self._service_name) + key = "%s.request" % (self._service_name) with CounterWithTimer(key=key): return self.app(environ, start_response) @@ -138,11 +136,12 @@ def __init__(self, app, router, service_name): def __call__(self, environ, start_response): # Track and time current number of processing requests def custom_start_response(status, headers, exc_info=None): - status_code = int(status.split(' ')[0]) + status_code = int(status.split(" ")[0]) metrics_driver = get_driver() - metrics_driver.inc_counter('%s.response.status.%s' % (self._service_name, - status_code)) + metrics_driver.inc_counter( + "%s.response.status.%s" % (self._service_name, status_code) + ) return start_response(status, headers, exc_info) diff --git a/st2common/st2common/middleware/logging.py b/st2common/st2common/middleware/logging.py index d41622ff29..a044e2c59b 100644 --- a/st2common/st2common/middleware/logging.py +++ b/st2common/st2common/middleware/logging.py @@ -33,7 +33,7 @@ SECRET_QUERY_PARAMS = [ QUERY_PARAM_ATTRIBUTE_NAME, - QUERY_PARAM_API_KEY_ATTRIBUTE_NAME + QUERY_PARAM_API_KEY_ATTRIBUTE_NAME, ] + MASKED_ATTRIBUTES_BLACKLIST try: @@ -68,21 +68,23 @@ def __call__(self, environ, start_response): # Log the incoming request values = { - 'method': request.method, - 'path': request.path, - 'remote_addr': request.remote_addr, - 'query': query_params, - 'request_id': request.headers.get(REQUEST_ID_HEADER, None) + "method": request.method, + "path": request.path, + "remote_addr": request.remote_addr, + "query": query_params, + "request_id": request.headers.get(REQUEST_ID_HEADER, None), } - LOG.info('%(request_id)s - %(method)s %(path)s with query=%(query)s' % - values, extra=values) + LOG.info( + "%(request_id)s - %(method)s %(path)s with query=%(query)s" % values, + extra=values, + ) def custom_start_response(status, headers, exc_info=None): - status_code.append(int(status.split(' ')[0])) + status_code.append(int(status.split(" ")[0])) for name, value in headers: - if name.lower() == 'content-length': + if name.lower() == "content-length": content_length.append(int(value)) break @@ -95,7 +97,7 @@ def custom_start_response(status, headers, exc_info=None): except NotFoundException: endpoint = {} - log_result = endpoint.get('x-log-result', True) + log_result = endpoint.get("x-log-result", True) if isinstance(retval, (types.GeneratorType, itertools.chain)): # Note: We don't log the result when return value is a generator, because this would @@ -105,22 +107,28 @@ def custom_start_response(status, headers, exc_info=None): # Log the response values = { - 'method': request.method, - 'path': request.path, - 'remote_addr': request.remote_addr, - 'status': status_code[0], - 'runtime': float("{0:.3f}".format((clock() - start_time) * 10**3)), - 'content_length': content_length[0] if content_length else len(b''.join(retval)), - 'request_id': request.headers.get(REQUEST_ID_HEADER, None) + "method": request.method, + "path": request.path, + "remote_addr": request.remote_addr, + "status": status_code[0], + "runtime": float("{0:.3f}".format((clock() - start_time) * 10 ** 3)), + "content_length": content_length[0] + if content_length + else len(b"".join(retval)), + "request_id": request.headers.get(REQUEST_ID_HEADER, None), } - log_msg = '%(request_id)s - %(status)s %(content_length)s %(runtime)sms' % (values) + log_msg = "%(request_id)s - %(status)s %(content_length)s %(runtime)sms" % ( + values + ) LOG.info(log_msg, extra=values) if log_result: - values['result'] = retval[0] - log_msg = ('%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s' % - (values)) + values["result"] = retval[0] + log_msg = ( + "%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s" + % (values) + ) LOG.debug(log_msg, extra=values) return retval diff --git a/st2common/st2common/middleware/streaming.py b/st2common/st2common/middleware/streaming.py index 8f48dedbcf..eb09084b30 100644 --- a/st2common/st2common/middleware/streaming.py +++ b/st2common/st2common/middleware/streaming.py @@ -16,9 +16,7 @@ from __future__ import absolute_import import fnmatch -__all__ = [ - 'StreamingMiddleware' -] +__all__ = ["StreamingMiddleware"] class StreamingMiddleware(object): @@ -32,7 +30,7 @@ def __call__(self, environ, start_response): # middleware is not important since it acts as pass-through. matches = False - req_path = environ.get('PATH_INFO', None) + req_path = environ.get("PATH_INFO", None) if not self._path_whitelist: matches = True @@ -43,6 +41,6 @@ def __call__(self, environ, start_response): break if matches: - environ['eventlet.minimum_write_chunk_size'] = 0 + environ["eventlet.minimum_write_chunk_size"] = 0 return self.app(environ, start_response) diff --git a/st2common/st2common/models/api/action.py b/st2common/st2common/models/api/action.py index 70eaeddad9..1924f54460 100644 --- a/st2common/st2common/models/api/action.py +++ b/st2common/st2common/models/api/action.py @@ -23,7 +23,10 @@ from st2common.models.api.base import BaseAPI from st2common.models.api.base import APIUIDMixin from st2common.models.api.tag import TagsHelper -from st2common.models.api.notification import (NotificationSubSchemaAPI, NotificationsHelper) +from st2common.models.api.notification import ( + NotificationSubSchemaAPI, + NotificationsHelper, +) from st2common.models.db.action import ActionDB from st2common.models.db.actionalias import ActionAliasDB from st2common.models.db.executionstate import ActionExecutionStateDB @@ -34,17 +37,16 @@ __all__ = [ - 'ActionAPI', - 'ActionCreateAPI', - 'LiveActionAPI', - 'LiveActionCreateAPI', - 'RunnerTypeAPI', - - 'AliasExecutionAPI', - 'AliasMatchAndExecuteInputAPI', - 'ActionAliasAPI', - 'ActionAliasMatchAPI', - 'ActionAliasHelpAPI' + "ActionAPI", + "ActionCreateAPI", + "LiveActionAPI", + "LiveActionCreateAPI", + "RunnerTypeAPI", + "AliasExecutionAPI", + "AliasMatchAndExecuteInputAPI", + "ActionAliasAPI", + "ActionAliasMatchAPI", + "ActionAliasHelpAPI", ] @@ -56,6 +58,7 @@ class RunnerTypeAPI(BaseAPI): The representation of an RunnerType in the system. An RunnerType has a one-to-one mapping to a particular ActionRunner implementation. """ + model = RunnerTypeDB schema = { "title": "Runner", @@ -65,42 +68,40 @@ class RunnerTypeAPI(BaseAPI): "id": { "description": "The unique identifier for the action runner.", "type": "string", - "default": None - }, - "uid": { - "type": "string" + "default": None, }, + "uid": {"type": "string"}, "name": { "description": "The name of the action runner.", "type": "string", - "required": True + "required": True, }, "description": { "description": "The description of the action runner.", - "type": "string" + "type": "string", }, "enabled": { "description": "Enable or disable the action runner.", "type": "boolean", - "default": True + "default": True, }, "runner_package": { "description": "The python package that implements the " - "action runner for this type.", + "action runner for this type.", "type": "string", - "required": False + "required": False, }, "runner_module": { "description": "The python module that implements the " - "action runner for this type.", + "action runner for this type.", "type": "string", - "required": True + "required": True, }, "query_module": { "description": "The python module that implements the " - "results tracker (querier) for the runner.", + "results tracker (querier) for the runner.", "type": "string", - "required": False + "required": False, }, "runner_parameters": { "description": "Input parameters for the action runner.", @@ -108,24 +109,22 @@ class RunnerTypeAPI(BaseAPI): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False + "additionalProperties": False, }, "output_key": { "description": "Default key to expect results to be published to.", "type": "string", - "required": False + "required": False, }, "output_schema": { "description": "Schema for the runner's output.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_output_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()}, + "additionalProperties": False, + "default": {}, }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): @@ -138,25 +137,34 @@ def __init__(self, **kw): # modified one for key, value in kw.items(): setattr(self, key, value) - if not hasattr(self, 'runner_parameters'): - setattr(self, 'runner_parameters', dict()) + if not hasattr(self, "runner_parameters"): + setattr(self, "runner_parameters", dict()) @classmethod def to_model(cls, runner_type): name = runner_type.name description = runner_type.description - enabled = getattr(runner_type, 'enabled', True) - runner_package = getattr(runner_type, 'runner_package', runner_type.runner_module) + enabled = getattr(runner_type, "enabled", True) + runner_package = getattr( + runner_type, "runner_package", runner_type.runner_module + ) runner_module = str(runner_type.runner_module) - runner_parameters = getattr(runner_type, 'runner_parameters', dict()) - output_key = getattr(runner_type, 'output_key', None) - output_schema = getattr(runner_type, 'output_schema', dict()) - query_module = getattr(runner_type, 'query_module', None) - - model = cls.model(name=name, description=description, enabled=enabled, - runner_package=runner_package, runner_module=runner_module, - runner_parameters=runner_parameters, output_schema=output_schema, - query_module=query_module, output_key=output_key) + runner_parameters = getattr(runner_type, "runner_parameters", dict()) + output_key = getattr(runner_type, "output_key", None) + output_schema = getattr(runner_type, "output_schema", dict()) + query_module = getattr(runner_type, "query_module", None) + + model = cls.model( + name=name, + description=description, + enabled=enabled, + runner_package=runner_package, + runner_module=runner_module, + runner_parameters=runner_parameters, + output_schema=output_schema, + query_module=query_module, + output_key=output_key, + ) return model @@ -174,44 +182,42 @@ class ActionAPI(BaseAPI, APIUIDMixin): "properties": { "id": { "description": "The unique identifier for the action.", - "type": "string" + "type": "string", }, "ref": { "description": "System computed user friendly reference for the action. \ Provided value will be overridden by computed value.", - "type": "string" - }, - "uid": { - "type": "string" + "type": "string", }, + "uid": {"type": "string"}, "name": { "description": "The name of the action.", "type": "string", - "required": True + "required": True, }, "description": { "description": "The description of the action.", - "type": "string" + "type": "string", }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True + "default": True, }, "runner_type": { "description": "The type of runner that executes the action.", "type": "string", - "required": True + "required": True, }, "entry_point": { "description": "The entry point for the action.", "type": "string", - "default": "" + "default": "", }, "pack": { "description": "The content pack this action belongs to.", "type": "string", - "default": DEFAULT_PACK_NAME + "default": DEFAULT_PACK_NAME, }, "parameters": { "description": "Input parameters for the action.", @@ -219,22 +225,20 @@ class ActionAPI(BaseAPI, APIUIDMixin): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False, - "default": {} + "additionalProperties": False, + "default": {}, }, "output_schema": { "description": "Schema for the action's output.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_output_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()}, + "additionalProperties": False, + "default": {}, }, "tags": { "description": "User associated metadata assigned to this object.", "type": "array", - "items": {"type": "object"} + "items": {"type": "object"}, }, "notify": { "description": "Notification settings for action.", @@ -242,52 +246,52 @@ class ActionAPI(BaseAPI, APIUIDMixin): "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False + "additionalProperties": False, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): for key, value in kw.items(): setattr(self, key, value) - if not hasattr(self, 'parameters'): - setattr(self, 'parameters', dict()) - if not hasattr(self, 'entry_point'): - setattr(self, 'entry_point', '') + if not hasattr(self, "parameters"): + setattr(self, "parameters", dict()) + if not hasattr(self, "entry_point"): + setattr(self, "entry_point", "") @classmethod def from_model(cls, model, mask_secrets=False): action = cls._from_model(model) - action['runner_type'] = action.get('runner_type', {}).get('name', None) - action['tags'] = TagsHelper.from_model(model.tags) + action["runner_type"] = action.get("runner_type", {}).get("name", None) + action["tags"] = TagsHelper.from_model(model.tags) - if getattr(model, 'notify', None): - action['notify'] = NotificationsHelper.from_model(model.notify) + if getattr(model, "notify", None): + action["notify"] = NotificationsHelper.from_model(model.notify) return cls(**action) @classmethod def to_model(cls, action): - name = getattr(action, 'name', None) - description = getattr(action, 'description', None) - enabled = bool(getattr(action, 'enabled', True)) + name = getattr(action, "name", None) + description = getattr(action, "description", None) + enabled = bool(getattr(action, "enabled", True)) entry_point = str(action.entry_point) pack = str(action.pack) - runner_type = {'name': str(action.runner_type)} - parameters = getattr(action, 'parameters', dict()) - output_schema = getattr(action, 'output_schema', dict()) - tags = TagsHelper.to_model(getattr(action, 'tags', [])) + runner_type = {"name": str(action.runner_type)} + parameters = getattr(action, "parameters", dict()) + output_schema = getattr(action, "output_schema", dict()) + tags = TagsHelper.to_model(getattr(action, "tags", [])) ref = ResourceReference.to_string_reference(pack=pack, name=name) - if getattr(action, 'notify', None): + if getattr(action, "notify", None): notify = NotificationsHelper.to_model(action.notify) else: # We use embedded document model for ``notify`` in action model. If notify is @@ -296,12 +300,22 @@ def to_model(cls, action): # to use an empty document. notify = NotificationsHelper.to_model({}) - metadata_file = getattr(action, 'metadata_file', None) - - model = cls.model(name=name, description=description, enabled=enabled, - entry_point=entry_point, pack=pack, runner_type=runner_type, - tags=tags, parameters=parameters, output_schema=output_schema, - notify=notify, ref=ref, metadata_file=metadata_file) + metadata_file = getattr(action, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + enabled=enabled, + entry_point=entry_point, + pack=pack, + runner_type=runner_type, + tags=tags, + parameters=parameters, + output_schema=output_schema, + notify=notify, + ref=ref, + metadata_file=metadata_file, + ) return model @@ -310,28 +324,31 @@ class ActionCreateAPI(ActionAPI, APIUIDMixin): """ API model for create action operation. """ + schema = copy.deepcopy(ActionAPI.schema) - schema['properties']['data_files'] = { - 'description': 'Optional action script and data files which are written to the filesystem.', - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'file_path': { - 'type': 'string', - 'description': ('Path to the file relative to the pack actions directory ' - '(e.g. my_action.py)'), - 'required': True + schema["properties"]["data_files"] = { + "description": "Optional action script and data files which are written to the filesystem.", + "type": "array", + "items": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Path to the file relative to the pack actions directory " + "(e.g. my_action.py)" + ), + "required": True, }, - 'content': { - 'type': 'string', - 'description': 'Raw file content.', - 'required': True + "content": { + "type": "string", + "description": "Raw file content.", + "required": True, }, }, - 'additionalProperties': False + "additionalProperties": False, }, - 'default': [] + "default": [], } @@ -339,8 +356,9 @@ class ActionUpdateAPI(ActionAPI, APIUIDMixin): """ API model for update action operation. """ + schema = copy.deepcopy(ActionCreateAPI.schema) - del schema['properties']['pack']['default'] + del schema["properties"]["pack"]["default"] class LiveActionAPI(BaseAPI): @@ -356,27 +374,27 @@ class LiveActionAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the action execution.", - "type": "string" + "type": "string", }, "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "start_timestamp": { "description": "The start time when the action is executed.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "end_timestamp": { "description": "The timestamp when the action has finished.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "action": { "description": "Reference to the action to be executed.", "type": "string", - "required": True + "required": True, }, "parameters": { "description": "Input parameters for the action.", @@ -390,58 +408,56 @@ class LiveActionAPI(BaseAPI): {"type": "number"}, {"type": "object"}, {"type": "string"}, - {"type": "null"} + {"type": "null"}, ] } }, - 'additionalProperties': False + "additionalProperties": False, }, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] - }, - "context": { - "type": "object" - }, - "callback": { - "type": "object" - }, - "runner_info": { - "type": "object" - }, + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] + }, + "context": {"type": "object"}, + "callback": {"type": "object"}, + "runner_info": {"type": "object"}, "notify": { "description": "Notification settings for liveaction.", "type": "object", "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False + "additionalProperties": False, }, "delay": { - "description": ("How long (in milliseconds) to delay the execution before" - "scheduling."), + "description": ( + "How long (in milliseconds) to delay the execution before" + "scheduling." + ), "type": "integer", - } + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) if model.start_timestamp: - doc['start_timestamp'] = isotime.format(model.start_timestamp, offset=False) + doc["start_timestamp"] = isotime.format(model.start_timestamp, offset=False) if model.end_timestamp: - doc['end_timestamp'] = isotime.format(model.end_timestamp, offset=False) + doc["end_timestamp"] = isotime.format(model.end_timestamp, offset=False) - if getattr(model, 'notify', None): - doc['notify'] = NotificationsHelper.from_model(model.notify) + if getattr(model, "notify", None): + doc["notify"] = NotificationsHelper.from_model(model.notify) return cls(**doc) @@ -449,32 +465,40 @@ def from_model(cls, model, mask_secrets=False): def to_model(cls, live_action): action = live_action.action - if getattr(live_action, 'start_timestamp', None): + if getattr(live_action, "start_timestamp", None): start_timestamp = isotime.parse(live_action.start_timestamp) else: start_timestamp = None - if getattr(live_action, 'end_timestamp', None): + if getattr(live_action, "end_timestamp", None): end_timestamp = isotime.parse(live_action.end_timestamp) else: end_timestamp = None - status = getattr(live_action, 'status', None) - parameters = getattr(live_action, 'parameters', dict()) - context = getattr(live_action, 'context', dict()) - callback = getattr(live_action, 'callback', dict()) - result = getattr(live_action, 'result', None) - delay = getattr(live_action, 'delay', None) + status = getattr(live_action, "status", None) + parameters = getattr(live_action, "parameters", dict()) + context = getattr(live_action, "context", dict()) + callback = getattr(live_action, "callback", dict()) + result = getattr(live_action, "result", None) + delay = getattr(live_action, "delay", None) - if getattr(live_action, 'notify', None): + if getattr(live_action, "notify", None): notify = NotificationsHelper.to_model(live_action.notify) else: notify = None - model = cls.model(action=action, - start_timestamp=start_timestamp, end_timestamp=end_timestamp, - status=status, parameters=parameters, context=context, - callback=callback, result=result, notify=notify, delay=delay) + model = cls.model( + action=action, + start_timestamp=start_timestamp, + end_timestamp=end_timestamp, + status=status, + parameters=parameters, + context=context, + callback=callback, + result=result, + notify=notify, + delay=delay, + ) return model @@ -483,11 +507,12 @@ class LiveActionCreateAPI(LiveActionAPI): """ API model for action execution create (run action) operations. """ + schema = copy.deepcopy(LiveActionAPI.schema) - schema['properties']['user'] = { - 'description': 'User context under which action should run (admins only)', - 'type': 'string', - 'default': None + schema["properties"]["user"] = { + "description": "User context under which action should run (admins only)", + "type": "string", + "default": None, } @@ -496,6 +521,7 @@ class ActionExecutionStateAPI(BaseAPI): System entity that represents state of an action in the system. This is used only in tests for now. """ + model = ActionExecutionStateDB schema = { "title": "ActionExecutionState", @@ -504,25 +530,25 @@ class ActionExecutionStateAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the action execution state.", - "type": "string" + "type": "string", }, "execution_id": { "type": "string", "description": "ID of the action execution.", - "required": True + "required": True, }, "query_context": { "type": "object", "description": "query context to be used by querier.", - "required": True + "required": True, }, "query_module": { "type": "string", "description": "Name of the query module.", - "required": True - } + "required": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -531,8 +557,11 @@ def to_model(cls, state): query_module = state.query_module query_context = state.query_context - model = cls.model(execution_id=execution_id, query_module=query_module, - query_context=query_context) + model = cls.model( + execution_id=execution_id, + query_module=query_module, + query_context=query_context, + ) return model @@ -540,6 +569,7 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): """ Alias for an action in the system. """ + model = ActionAliasDB schema = { "title": "ActionAlias", @@ -548,42 +578,40 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "properties": { "id": { "description": "The unique identifier for the action alias.", - "type": "string" + "type": "string", }, "ref": { "description": ( "System computed user friendly reference for the alias. " "Provided value will be overridden by computed value." ), - "type": "string" - }, - "uid": { - "type": "string" + "type": "string", }, + "uid": {"type": "string"}, "name": { "type": "string", "description": "Name of the action alias.", - "required": True + "required": True, }, "pack": { "description": "The content pack this actionalias belongs to.", "type": "string", - "required": True + "required": True, }, "description": { "type": "string", "description": "Description of the action alias.", - "default": None + "default": None, }, "enabled": { "description": "Flag indicating of action alias is enabled.", "type": "boolean", - "default": True + "default": True, }, "action_ref": { "type": "string", "description": "Reference to the aliased action.", - "required": True + "required": True, }, "formats": { "type": "array", @@ -596,13 +624,13 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "display": {"type": "string"}, "representation": { "type": "array", - "items": {"type": "string"} - } - } - } + "items": {"type": "string"}, + }, + }, + }, ] }, - "description": "Possible parameter format." + "description": "Possible parameter format.", }, "ack": { "type": "object", @@ -610,56 +638,65 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "enabled": {"type": "boolean"}, "format": {"type": "string"}, "extra": {"type": "object"}, - "append_url": {"type": "boolean"} + "append_url": {"type": "boolean"}, }, - "description": "Acknowledgement message format." + "description": "Acknowledgement message format.", }, "result": { "type": "object", "properties": { "enabled": {"type": "boolean"}, "format": {"type": "string"}, - "extra": {"type": "object"} + "extra": {"type": "object"}, }, - "description": "Execution message format." + "description": "Execution message format.", }, "extra": { "type": "object", - "description": "Extra parameters, usually adapter-specific." + "description": "Extra parameters, usually adapter-specific.", }, "immutable_parameters": { "type": "object", - "description": "Parameters to be passed to the action on every execution." + "description": "Parameters to be passed to the action on every execution.", }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def to_model(cls, alias): name = alias.name - description = getattr(alias, 'description', None) + description = getattr(alias, "description", None) pack = alias.pack ref = ResourceReference.to_string_reference(pack=pack, name=name) - enabled = getattr(alias, 'enabled', True) + enabled = getattr(alias, "enabled", True) action_ref = alias.action_ref formats = alias.formats - ack = getattr(alias, 'ack', None) - result = getattr(alias, 'result', None) - extra = getattr(alias, 'extra', None) - immutable_parameters = getattr(alias, 'immutable_parameters', None) - metadata_file = getattr(alias, 'metadata_file', None) - - model = cls.model(name=name, description=description, pack=pack, ref=ref, - enabled=enabled, action_ref=action_ref, formats=formats, - ack=ack, result=result, extra=extra, - immutable_parameters=immutable_parameters, - metadata_file=metadata_file) + ack = getattr(alias, "ack", None) + result = getattr(alias, "result", None) + extra = getattr(alias, "extra", None) + immutable_parameters = getattr(alias, "immutable_parameters", None) + metadata_file = getattr(alias, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + pack=pack, + ref=ref, + enabled=enabled, + action_ref=action_ref, + formats=formats, + ack=ack, + result=result, + extra=extra, + immutable_parameters=immutable_parameters, + metadata_file=metadata_file, + ) return model @@ -667,6 +704,7 @@ class AliasExecutionAPI(BaseAPI): """ Alias for an action in the system. """ + model = None schema = { "title": "AliasExecution", @@ -676,48 +714,48 @@ class AliasExecutionAPI(BaseAPI): "name": { "type": "string", "description": "Name of the action alias which matched.", - "required": True + "required": True, }, "format": { "type": "string", "description": "Format string which matched.", - "required": True + "required": True, }, "command": { "type": "string", "description": "Command used in chat.", - "required": True + "required": True, }, "user": { "type": "string", "description": "User that requested the execution.", - "default": "channel" # TODO: This value doesnt get set + "default": "channel", # TODO: This value doesnt get set }, "source_channel": { "type": "string", "description": "Channel from which the execution was requested. This is not the " - "channel as defined by the notification system.", - "required": True + "channel as defined by the notification system.", + "required": True, }, "source_context": { "type": "object", "description": "ALL data included with the message (also called the message " - "envelope). This is currently only used by the Microsoft Teams " - "adapter.", - "required": False + "envelope). This is currently only used by the Microsoft Teams " + "adapter.", + "required": False, }, "notification_channel": { "type": "string", "description": "StackStorm notification channel to use to respond.", - "required": False + "required": False, }, "notification_route": { "type": "string", "description": "StackStorm notification route to use to respond.", - "required": False - } + "required": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -734,6 +772,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): """ API object used for alias execution "match and execute" API endpoint request payload. """ + model = None schema = { "title": "ActionAliasMatchAndExecuteInputAPI", @@ -743,7 +782,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): "command": { "type": "string", "description": "Command used in chat.", - "required": True + "required": True, }, "user": { "type": "string", @@ -753,22 +792,22 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): "type": "string", "description": "Channel from which the execution was requested. This is not the \ channel as defined by the notification system.", - "required": True + "required": True, }, "notification_channel": { "type": "string", "description": "StackStorm notification channel to use to respond.", "required": False, - "default": None + "default": None, }, "notification_route": { "type": "string", "description": "StackStorm notification route to use to respond.", "required": False, - "default": None - } + "default": None, + }, }, - "additionalProperties": False + "additionalProperties": False, } @@ -776,6 +815,7 @@ class ActionAliasMatchAPI(BaseAPI): """ API model used for alias match API endpoint. """ + model = None schema = { @@ -786,10 +826,10 @@ class ActionAliasMatchAPI(BaseAPI): "command": { "type": "string", "description": "Command string to try to match the aliases against.", - "required": True + "required": True, } }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -805,6 +845,7 @@ class ActionAliasHelpAPI(BaseAPI): """ API model used to display action-alias help API endpoint. """ + model = None schema = { @@ -816,28 +857,28 @@ class ActionAliasHelpAPI(BaseAPI): "type": "string", "description": "Find help strings containing keyword.", "required": False, - "default": "" + "default": "", }, "pack": { "type": "string", "description": "List help strings for a specific pack.", "required": False, - "default": "" + "default": "", }, "offset": { "type": "integer", "description": "List help strings from the offset position.", "required": False, - "default": 0 + "default": 0, }, "limit": { "type": "integer", "description": "Limit the number of help strings returned.", "required": False, - "default": 0 - } + "default": 0, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod diff --git a/st2common/st2common/models/api/actionrunner.py b/st2common/st2common/models/api/actionrunner.py index d2a2029e32..7b580e1c9b 100644 --- a/st2common/st2common/models/api/actionrunner.py +++ b/st2common/st2common/models/api/actionrunner.py @@ -17,7 +17,7 @@ from st2common import log as logging from st2common.models.api.base import BaseAPI -__all__ = ['ActionRunnerAPI'] +__all__ = ["ActionRunnerAPI"] LOG = logging.getLogger(__name__) @@ -29,12 +29,9 @@ class ActionRunnerAPI(BaseAPI): Attribute: ... """ + schema = { - 'type': 'object', - 'parameters': { - 'id': { - 'type': 'string' - } - }, - 'additionalProperties': False + "type": "object", + "parameters": {"id": {"type": "string"}}, + "additionalProperties": False, } diff --git a/st2common/st2common/models/api/auth.py b/st2common/st2common/models/api/auth.py index 8e5ed34e34..10672e99ec 100644 --- a/st2common/st2common/models/api/auth.py +++ b/st2common/st2common/models/api/auth.py @@ -36,13 +36,8 @@ class UserAPI(BaseAPI): schema = { "title": "User", "type": "object", - "properties": { - "name": { - "type": "string", - "required": True - } - }, - "additionalProperties": False + "properties": {"name": {"type": "string", "required": True}}, + "additionalProperties": False, } @classmethod @@ -58,34 +53,25 @@ class TokenAPI(BaseAPI): "title": "Token", "type": "object", "properties": { - "id": { - "type": "string" - }, - "user": { - "type": ["string", "null"] - }, - "token": { - "type": ["string", "null"] - }, - "ttl": { - "type": "integer", - "minimum": 1 - }, + "id": {"type": "string"}, + "user": {"type": ["string", "null"]}, + "token": {"type": ["string", "null"]}, + "ttl": {"type": "integer", "minimum": 1}, "expiry": { "type": ["string", "null"], - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, - "metadata": { - "type": ["object", "null"] - } + "metadata": {"type": ["object", "null"]}, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) - doc['expiry'] = isotime.format(model.expiry, offset=False) if model.expiry else None + doc["expiry"] = ( + isotime.format(model.expiry, offset=False) if model.expiry else None + ) return cls(**doc) @classmethod @@ -104,52 +90,44 @@ class ApiKeyAPI(BaseAPI, APIUIDMixin): "title": "ApiKey", "type": "object", "properties": { - "id": { - "type": "string" - }, - "uid": { - "type": "string" - }, - "user": { - "type": ["string", "null"], - "default": "" - }, - "key_hash": { - "type": ["string", "null"] - }, - "metadata": { - "type": ["object", "null"] - }, - 'created_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "id": {"type": "string"}, + "uid": {"type": "string"}, + "user": {"type": ["string", "null"], "default": ""}, + "key_hash": {"type": ["string", "null"]}, + "metadata": {"type": ["object", "null"]}, + "created_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True - } + "default": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) - doc['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \ - else None + doc["created_at"] = ( + isotime.format(model.created_at, offset=False) if model.created_at else None + ) return cls(**doc) @classmethod def to_model(cls, instance): # If PrimaryKey ID is provided, - we want to work with existing ST2 API key - id = getattr(instance, 'id', None) + id = getattr(instance, "id", None) user = str(instance.user) if instance.user else None - key_hash = getattr(instance, 'key_hash', None) - metadata = getattr(instance, 'metadata', {}) - enabled = bool(getattr(instance, 'enabled', True)) - model = cls.model(id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled) + key_hash = getattr(instance, "key_hash", None) + metadata = getattr(instance, "metadata", {}) + enabled = bool(getattr(instance, "enabled", True)) + model = cls.model( + id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled + ) return model @@ -158,45 +136,35 @@ class ApiKeyCreateResponseAPI(BaseAPI): "title": "APIKeyCreateResponse", "type": "object", "properties": { - "id": { - "type": "string" - }, - "uid": { - "type": "string" - }, - "user": { - "type": ["string", "null"], - "default": "" - }, - "key": { - "type": ["string", "null"] - }, - "metadata": { - "type": ["object", "null"] - }, - 'created_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "id": {"type": "string"}, + "uid": {"type": "string"}, + "user": {"type": ["string", "null"], "default": ""}, + "key": {"type": ["string", "null"]}, + "metadata": {"type": ["object", "null"]}, + "created_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True - } + "default": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model=model, mask_secrets=mask_secrets) attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} - attrs['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \ - else None + attrs["created_at"] = ( + isotime.format(model.created_at, offset=False) if model.created_at else None + ) # key_hash is ignored. - attrs.pop('key_hash', None) + attrs.pop("key_hash", None) # key is unknown so the calling code will have to update after conversion. - attrs['key'] = None + attrs["key"] = None return cls(**attrs) diff --git a/st2common/st2common/models/api/base.py b/st2common/st2common/models/api/base.py index 3669291e9c..6c052a43e3 100644 --- a/st2common/st2common/models/api/base.py +++ b/st2common/st2common/models/api/base.py @@ -23,10 +23,7 @@ from st2common.util import mongoescape as util_mongodb from st2common import log as logging -__all__ = [ - 'BaseAPI', - 'APIUIDMixin' -] +__all__ = ["BaseAPI", "APIUIDMixin"] LOG = logging.getLogger(__name__) @@ -43,13 +40,13 @@ def __init__(self, **kw): def __repr__(self): name = type(self).__name__ - attrs = ', '.join("'%s': %r" % item for item in six.iteritems(vars(self))) + attrs = ", ".join("'%s': %r" % item for item in six.iteritems(vars(self))) # The format here is so that eval can be applied. return "%s(**{%s})" % (name, attrs) def __str__(self): name = type(self).__name__ - attrs = ', '.join("%s=%r" % item for item in six.iteritems(vars(self))) + attrs = ", ".join("%s=%r" % item for item in six.iteritems(vars(self))) return "%s[%s]" % (name, attrs) @@ -66,12 +63,16 @@ def validate(self): """ from st2common.util import schema as util_schema - schema = getattr(self, 'schema', {}) + schema = getattr(self, "schema", {}) attributes = vars(self) - cleaned = util_schema.validate(instance=attributes, schema=schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=attributes, + schema=schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) # Note: We use type() instead of self.__class__ since self.__class__ confuses pylint return type(self)(**cleaned) @@ -80,8 +81,8 @@ def validate(self): def _from_model(cls, model, mask_secrets=False): doc = model.to_mongo() - if '_id' in doc: - doc['id'] = str(doc.pop('_id')) + if "_id" in doc: + doc["id"] = str(doc.pop("_id")) doc = util_mongodb.unescape_chars(doc) @@ -117,7 +118,7 @@ def to_model(cls, doc): class APIUIDMixin(object): - """" + """ " Mixin class for retrieving UID for API objects. """ @@ -142,9 +143,11 @@ def has_valid_uid(self): def cast_argument_value(value_type, value): if value_type == bool: + def cast_func(value): value = str(value) - return value.lower() in ['1', 'true'] + return value.lower() in ["1", "true"] + else: cast_func = value_type diff --git a/st2common/st2common/models/api/execution.py b/st2common/st2common/models/api/execution.py index 447a8679eb..87a5ff52c0 100644 --- a/st2common/st2common/models/api/execution.py +++ b/st2common/st2common/models/api/execution.py @@ -28,10 +28,7 @@ from st2common.models.api.action import RunnerTypeAPI, ActionAPI, LiveActionAPI from st2common import log as logging -__all__ = [ - 'ActionExecutionAPI', - 'ActionExecutionOutputAPI' -] +__all__ = ["ActionExecutionAPI", "ActionExecutionOutputAPI"] LOG = logging.getLogger(__name__) @@ -48,47 +45,44 @@ class ActionExecutionAPI(BaseAPI): model = ActionExecutionDB - SKIP = ['start_timestamp', 'end_timestamp'] + SKIP = ["start_timestamp", "end_timestamp"] schema = { "title": "ActionExecution", "description": "Record of the execution of an action.", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, + "id": {"type": "string", "required": True}, "trigger": TriggerAPI.schema, "trigger_type": TriggerTypeAPI.schema, "trigger_instance": TriggerInstanceAPI.schema, "rule": RuleAPI.schema, - "action": REQUIRED_ATTR_SCHEMAS['action'], - "runner": REQUIRED_ATTR_SCHEMAS['runner'], - "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'], + "action": REQUIRED_ATTR_SCHEMAS["action"], + "runner": REQUIRED_ATTR_SCHEMAS["runner"], + "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"], "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "start_timestamp": { "description": "The start time when the action is executed.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "end_timestamp": { "description": "The timestamp when the action has finished.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "elapsed_seconds": { "description": "Time duration in seconds taken for completion of this execution.", "type": "number", - "required": False + "required": False, }, "web_url": { "description": "History URL for this execution if you want to view in UI.", "type": "string", - "required": False + "required": False, }, "parameters": { "description": "Input parameters for the action.", @@ -101,28 +95,28 @@ class ActionExecutionAPI(BaseAPI): {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } }, - 'additionalProperties': False - }, - "context": { - "type": "object" + "additionalProperties": False, }, + "context": {"type": "object"}, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] }, "parent": {"type": "string"}, "children": { "type": "array", "items": {"type": "string"}, - "uniqueItems": True + "uniqueItems": True, }, "log": { "description": "Contains information about execution state transitions.", @@ -132,22 +126,21 @@ class ActionExecutionAPI(BaseAPI): "properties": { "timestamp": { "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, - "status": { - "type": "string", - "enum": LIVEACTION_STATUSES - } - } - } + "status": {"type": "string", "enum": LIVEACTION_STATUSES}, + }, + }, }, "delay": { - "description": ("How long (in milliseconds) to delay the execution before" - "scheduling."), + "description": ( + "How long (in milliseconds) to delay the execution before" + "scheduling." + ), "type": "integer", - } + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -155,16 +148,16 @@ def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) start_timestamp = model.start_timestamp start_timestamp_iso = isotime.format(start_timestamp, offset=False) - doc['start_timestamp'] = start_timestamp_iso + doc["start_timestamp"] = start_timestamp_iso end_timestamp = model.end_timestamp if end_timestamp: end_timestamp_iso = isotime.format(end_timestamp, offset=False) - doc['end_timestamp'] = end_timestamp_iso - doc['elapsed_seconds'] = (end_timestamp - start_timestamp).total_seconds() + doc["end_timestamp"] = end_timestamp_iso + doc["elapsed_seconds"] = (end_timestamp - start_timestamp).total_seconds() - for entry in doc.get('log', []): - entry['timestamp'] = isotime.format(entry['timestamp'], offset=False) + for entry in doc.get("log", []): + entry["timestamp"] = isotime.format(entry["timestamp"], offset=False) attrs = {attr: value for attr, value in six.iteritems(doc) if value} return cls(**attrs) @@ -172,11 +165,11 @@ def from_model(cls, model, mask_secrets=False): @classmethod def to_model(cls, instance): values = {} - for attr, meta in six.iteritems(cls.schema.get('properties', dict())): + for attr, meta in six.iteritems(cls.schema.get("properties", dict())): if not getattr(instance, attr, None): continue - default = copy.deepcopy(meta.get('default', None)) + default = copy.deepcopy(meta.get("default", None)) value = getattr(instance, attr, default) # pylint: disable=no-member @@ -188,8 +181,8 @@ def to_model(cls, instance): if attr not in ActionExecutionAPI.SKIP: values[attr] = value - values['start_timestamp'] = isotime.parse(instance.start_timestamp) - values['end_timestamp'] = isotime.parse(instance.end_timestamp) + values["start_timestamp"] = isotime.parse(instance.start_timestamp) + values["end_timestamp"] = isotime.parse(instance.end_timestamp) model = cls.model(**values) return model @@ -198,41 +191,24 @@ def to_model(cls, instance): class ActionExecutionOutputAPI(BaseAPI): model = ActionExecutionOutputDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' - }, - 'execution_id': { - 'type': 'string' - }, - 'action_ref': { - 'type': 'string' - }, - 'runner_ref': { - 'type': 'string' - }, - 'timestamp': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX - }, - 'output_type': { - 'type': 'string' - }, - 'data': { - 'type': 'string' - }, - 'delay': { - 'type': 'integer' - } + "type": "object", + "properties": { + "id": {"type": "string"}, + "execution_id": {"type": "string"}, + "action_ref": {"type": "string"}, + "runner_ref": {"type": "string"}, + "timestamp": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX}, + "output_type": {"type": "string"}, + "data": {"type": "string"}, + "delay": {"type": "integer"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=True): doc = cls._from_model(model, mask_secrets=mask_secrets) - doc['timestamp'] = isotime.format(model.timestamp, offset=False) + doc["timestamp"] = isotime.format(model.timestamp, offset=False) attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} return cls(**attrs) diff --git a/st2common/st2common/models/api/inquiry.py b/st2common/st2common/models/api/inquiry.py index e3194df28c..a45327aaa7 100644 --- a/st2common/st2common/models/api/inquiry.py +++ b/st2common/st2common/models/api/inquiry.py @@ -54,30 +54,11 @@ class InquiryAPI(BaseAPI): "description": "Record of an Inquiry", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, - "route": { - "type": "string", - "default": "", - "required": True - }, - "ttl": { - "type": "integer", - "default": 1440, - "required": True - }, - "users": { - "type": "array", - "default": [], - "required": True - }, - "roles": { - "type": "array", - "default": [], - "required": True - }, + "id": {"type": "string", "required": True}, + "route": {"type": "string", "default": "", "required": True}, + "ttl": {"type": "integer", "default": 1440, "required": True}, + "users": {"type": "array", "default": [], "required": True}, + "roles": {"type": "array", "default": [], "required": True}, "schema": { "type": "object", "default": { @@ -87,30 +68,32 @@ class InquiryAPI(BaseAPI): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, }, - "required": True + "required": True, }, - "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'], - "runner": REQUIRED_ATTR_SCHEMAS['runner'], + "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"], + "runner": REQUIRED_ATTR_SCHEMAS["runner"], "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "parent": {"type": "string"}, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] - } + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -118,23 +101,22 @@ def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) newdoc = { - 'id': doc['id'], - 'runner': doc.get('runner', None), - 'status': doc.get('status', None), - 'liveaction': doc.get('liveaction', None), - 'parent': doc.get('parent', None), - 'result': doc.get('result', None) + "id": doc["id"], + "runner": doc.get("runner", None), + "status": doc.get("status", None), + "liveaction": doc.get("liveaction", None), + "parent": doc.get("parent", None), + "result": doc.get("result", None), } - for field in ['route', 'ttl', 'users', 'roles', 'schema']: - newdoc[field] = doc['result'].get(field, None) + for field in ["route", "ttl", "users", "roles", "schema"]: + newdoc[field] = doc["result"].get(field, None) return cls(**newdoc) class InquiryResponseAPI(BaseAPI): - """A more pruned Inquiry model, containing only the fields needed for an API response - """ + """A more pruned Inquiry model, containing only the fields needed for an API response""" model = ActionExecutionDB schema = { @@ -142,30 +124,11 @@ class InquiryResponseAPI(BaseAPI): "description": "Record of an Inquiry", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, - "route": { - "type": "string", - "default": "", - "required": True - }, - "ttl": { - "type": "integer", - "default": 1440, - "required": True - }, - "users": { - "type": "array", - "default": [], - "required": True - }, - "roles": { - "type": "array", - "default": [], - "required": True - }, + "id": {"type": "string", "required": True}, + "route": {"type": "string", "default": "", "required": True}, + "ttl": {"type": "integer", "default": 1440, "required": True}, + "users": {"type": "array", "default": [], "required": True}, + "roles": {"type": "array", "default": [], "required": True}, "schema": { "type": "object", "default": { @@ -175,14 +138,14 @@ class InquiryResponseAPI(BaseAPI): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, }, - "required": True - } + "required": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -201,9 +164,7 @@ def from_model(cls, model, mask_secrets=False, skip_db=False): else: doc = model - newdoc = { - "id": doc["id"] - } + newdoc = {"id": doc["id"]} for field in ["route", "ttl", "users", "roles", "schema"]: newdoc[field] = doc["result"].get(field) @@ -211,16 +172,16 @@ def from_model(cls, model, mask_secrets=False, skip_db=False): @classmethod def from_inquiry_api(cls, inquiry_api, mask_secrets=False): - """ Allows translation of InquiryAPI directly to InquiryResponseAPI + """Allows translation of InquiryAPI directly to InquiryResponseAPI This bypasses the DB modeling, since there's no DB model for Inquiries yet. """ return cls( - id=getattr(inquiry_api, 'id', None), - route=getattr(inquiry_api, 'route', None), - ttl=getattr(inquiry_api, 'ttl', None), - users=getattr(inquiry_api, 'users', None), - roles=getattr(inquiry_api, 'roles', None), - schema=getattr(inquiry_api, 'schema', None) + id=getattr(inquiry_api, "id", None), + route=getattr(inquiry_api, "route", None), + ttl=getattr(inquiry_api, "ttl", None), + users=getattr(inquiry_api, "users", None), + roles=getattr(inquiry_api, "roles", None), + schema=getattr(inquiry_api, "schema", None), ) diff --git a/st2common/st2common/models/api/keyvalue.py b/st2common/st2common/models/api/keyvalue.py index 8365350ef7..a19cfcc33e 100644 --- a/st2common/st2common/models/api/keyvalue.py +++ b/st2common/st2common/models/api/keyvalue.py @@ -21,9 +21,16 @@ from oslo_config import cfg import six -from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, ALLOWED_SCOPES +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + ALLOWED_SCOPES, +) from st2common.constants.keyvalue import SYSTEM_SCOPE, USER_SCOPE -from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException +from st2common.exceptions.keyvalue import ( + CryptoKeyNotSetupException, + InvalidScopeException, +) from st2common.log import logging from st2common.util import isotime from st2common.util import date as date_utils @@ -32,10 +39,7 @@ from st2common.models.system.keyvalue import UserKeyReference from st2common.models.db.keyvalue import KeyValuePairDB -__all__ = [ - 'KeyValuePairAPI', - 'KeyValuePairSetAPI' -] +__all__ = ["KeyValuePairAPI", "KeyValuePairSetAPI"] LOG = logging.getLogger(__name__) @@ -44,50 +48,29 @@ class KeyValuePairAPI(BaseAPI): crypto_setup = False model = KeyValuePairDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' + "type": "object", + "properties": { + "id": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string"}, + "description": {"type": "string"}, + "value": {"type": "string", "required": True}, + "secret": {"type": "boolean", "required": False, "default": False}, + "encrypted": {"type": "boolean", "required": False, "default": False}, + "scope": { + "type": "string", + "required": False, + "default": FULL_SYSTEM_SCOPE, }, - "uid": { - "type": "string" - }, - 'name': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'value': { - 'type': 'string', - 'required': True - }, - 'secret': { - 'type': 'boolean', - 'required': False, - 'default': False - }, - 'encrypted': { - 'type': 'boolean', - 'required': False, - 'default': False - }, - 'scope': { - 'type': 'string', - 'required': False, - 'default': FULL_SYSTEM_SCOPE - }, - 'expire_timestamp': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "expire_timestamp": { + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, # Note: Those values are only used for input # TODO: Improve - 'ttl': { - 'type': 'integer' - } + "ttl": {"type": "integer"}, }, - 'additionalProperties': False + "additionalProperties": False, } @staticmethod @@ -96,19 +79,25 @@ def _setup_crypto(): # Crypto already set up return - LOG.info('Checking if encryption is enabled for key-value store.') + LOG.info("Checking if encryption is enabled for key-value store.") KeyValuePairAPI.is_encryption_enabled = cfg.CONF.keyvalue.enable_encryption - LOG.debug('Encryption enabled? : %s', KeyValuePairAPI.is_encryption_enabled) + LOG.debug("Encryption enabled? : %s", KeyValuePairAPI.is_encryption_enabled) if KeyValuePairAPI.is_encryption_enabled: KeyValuePairAPI.crypto_key_path = cfg.CONF.keyvalue.encryption_key_path - LOG.info('Encryption enabled. Looking for key in path %s', - KeyValuePairAPI.crypto_key_path) + LOG.info( + "Encryption enabled. Looking for key in path %s", + KeyValuePairAPI.crypto_key_path, + ) if not os.path.exists(KeyValuePairAPI.crypto_key_path): - msg = ('Encryption key file does not exist in path %s.' % - KeyValuePairAPI.crypto_key_path) + msg = ( + "Encryption key file does not exist in path %s." + % KeyValuePairAPI.crypto_key_path + ) LOG.exception(msg) - LOG.info('All API requests will now send out BAD_REQUEST ' + - 'if you ask to store secrets in key value store.') + LOG.info( + "All API requests will now send out BAD_REQUEST " + + "if you ask to store secrets in key value store." + ) KeyValuePairAPI.crypto_key = None else: KeyValuePairAPI.crypto_key = read_crypto_key( @@ -123,28 +112,30 @@ def from_model(cls, model, mask_secrets=True): doc = cls._from_model(model, mask_secrets=mask_secrets) - if getattr(model, 'expire_timestamp', None) and model.expire_timestamp: - doc['expire_timestamp'] = isotime.format(model.expire_timestamp, offset=False) + if getattr(model, "expire_timestamp", None) and model.expire_timestamp: + doc["expire_timestamp"] = isotime.format( + model.expire_timestamp, offset=False + ) encrypted = False - secret = getattr(model, 'secret', False) + secret = getattr(model, "secret", False) if secret: encrypted = True if not mask_secrets and secret: - doc['value'] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value) + doc["value"] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value) encrypted = False - scope = getattr(model, 'scope', SYSTEM_SCOPE) + scope = getattr(model, "scope", SYSTEM_SCOPE) if scope: - doc['scope'] = scope + doc["scope"] = scope - key = doc.get('name', None) + key = doc.get("name", None) if (scope == USER_SCOPE or scope == FULL_USER_SCOPE) and key: - doc['user'] = UserKeyReference.get_user(key) - doc['name'] = UserKeyReference.get_name(key) + doc["user"] = UserKeyReference.get_user(key) + doc["name"] = UserKeyReference.get_name(key) - doc['encrypted'] = encrypted + doc["encrypted"] = encrypted attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} return cls(**attrs) @@ -153,21 +144,22 @@ def to_model(cls, kvp): if not KeyValuePairAPI.crypto_setup: KeyValuePairAPI._setup_crypto() - kvp_id = getattr(kvp, 'id', None) - name = getattr(kvp, 'name', None) - description = getattr(kvp, 'description', None) + kvp_id = getattr(kvp, "id", None) + name = getattr(kvp, "name", None) + description = getattr(kvp, "description", None) value = kvp.value original_value = value secret = False - if getattr(kvp, 'ttl', None): - expire_timestamp = (date_utils.get_datetime_utc_now() + - datetime.timedelta(seconds=kvp.ttl)) + if getattr(kvp, "ttl", None): + expire_timestamp = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=kvp.ttl + ) else: expire_timestamp = None - encrypted = getattr(kvp, 'encrypted', False) - secret = getattr(kvp, 'secret', False) + encrypted = getattr(kvp, "encrypted", False) + secret = getattr(kvp, "secret", False) # If user transmitted the value in an pre-encrypted format, we perform the decryption here # to ensure data integrity. Besides that, we store data as-is. @@ -182,9 +174,11 @@ def to_model(cls, kvp): try: symmetric_decrypt(KeyValuePairAPI.crypto_key, value) except Exception: - msg = ('Failed to verify the integrity of the provided value for key "%s". Ensure ' - 'that the value is encrypted with the correct key and not corrupted.' % - (name)) + msg = ( + 'Failed to verify the integrity of the provided value for key "%s". Ensure ' + "that the value is encrypted with the correct key and not corrupted." + % (name) + ) raise ValueError(msg) # Additional safety check to ensure that the value hasn't been decrypted @@ -194,30 +188,39 @@ def to_model(cls, kvp): value = symmetric_encrypt(KeyValuePairAPI.crypto_key, value) - scope = getattr(kvp, 'scope', FULL_SYSTEM_SCOPE) + scope = getattr(kvp, "scope", FULL_SYSTEM_SCOPE) if scope not in ALLOWED_SCOPES: - raise InvalidScopeException('Invalid scope "%s"! Allowed scopes are %s.' % ( - scope, ALLOWED_SCOPES) + raise InvalidScopeException( + 'Invalid scope "%s"! Allowed scopes are %s.' % (scope, ALLOWED_SCOPES) ) # NOTE: For security reasons, encrypted always implies secret=True. See comment # above for explanation. if encrypted and not secret: - raise ValueError('encrypted option can only be used in combination with secret ' - 'option') + raise ValueError( + "encrypted option can only be used in combination with secret " "option" + ) - model = cls.model(id=kvp_id, name=name, description=description, value=value, - secret=secret, scope=scope, - expire_timestamp=expire_timestamp) + model = cls.model( + id=kvp_id, + name=name, + description=description, + value=value, + secret=secret, + scope=scope, + expire_timestamp=expire_timestamp, + ) return model @classmethod def _verif_key_is_set_up(cls, name): if not KeyValuePairAPI.crypto_key: - msg = ('Crypto key not found in %s. Unable to encrypt / decrypt value for key %s.' % - (KeyValuePairAPI.crypto_key_path, name)) + msg = "Crypto key not found in %s. Unable to encrypt / decrypt value for key %s." % ( + KeyValuePairAPI.crypto_key_path, + name, + ) raise CryptoKeyNotSetupException(msg) @@ -227,13 +230,12 @@ class KeyValuePairSetAPI(KeyValuePairAPI): """ schema = copy.deepcopy(KeyValuePairAPI.schema) - schema['properties']['ttl'] = { - 'description': 'Items TTL', - 'type': 'integer' - } - schema['properties']['user'] = { - 'description': ('User to which the value should be scoped to. Only applicable to ' - 'scope == user'), - 'type': 'string', - 'default': None + schema["properties"]["ttl"] = {"description": "Items TTL", "type": "integer"} + schema["properties"]["user"] = { + "description": ( + "User to which the value should be scoped to. Only applicable to " + "scope == user" + ), + "type": "string", + "default": None, } diff --git a/st2common/st2common/models/api/notification.py b/st2common/st2common/models/api/notification.py index fef0545f26..9d80ddbf7f 100644 --- a/st2common/st2common/models/api/notification.py +++ b/st2common/st2common/models/api/notification.py @@ -19,57 +19,60 @@ NotificationSubSchemaAPI = { "type": "object", "properties": { - "message": { - "type": "string", - "description": "Message to use for notification" - }, + "message": {"type": "string", "description": "Message to use for notification"}, "data": { "type": "object", - "description": "Data to be sent as part of notification" + "description": "Data to be sent as part of notification", }, "routes": { "type": "array", - "description": "Channels to post notifications to." + "description": "Channels to post notifications to.", }, "channels": { # Deprecated. Only here for backward compatibility. "type": "array", - "description": "Channels to post notifications to." + "description": "Channels to post notifications to.", }, }, - "additionalProperties": False + "additionalProperties": False, } class NotificationsHelper(object): - @staticmethod def to_model(notify_api_object): - if notify_api_object.get('on-success', None): - on_success = NotificationsHelper._to_model_sub_schema(notify_api_object['on-success']) + if notify_api_object.get("on-success", None): + on_success = NotificationsHelper._to_model_sub_schema( + notify_api_object["on-success"] + ) else: on_success = None - if notify_api_object.get('on-complete', None): + if notify_api_object.get("on-complete", None): on_complete = NotificationsHelper._to_model_sub_schema( - notify_api_object['on-complete']) + notify_api_object["on-complete"] + ) else: on_complete = None - if notify_api_object.get('on-failure', None): - on_failure = NotificationsHelper._to_model_sub_schema(notify_api_object['on-failure']) + if notify_api_object.get("on-failure", None): + on_failure = NotificationsHelper._to_model_sub_schema( + notify_api_object["on-failure"] + ) else: on_failure = None - model = NotificationSchema(on_success=on_success, on_failure=on_failure, - on_complete=on_complete) + model = NotificationSchema( + on_success=on_success, on_failure=on_failure, on_complete=on_complete + ) return model @staticmethod def _to_model_sub_schema(notification_settings_json): - message = notification_settings_json.get('message', None) - data = notification_settings_json.get('data', {}) - routes = (notification_settings_json.get('routes', None) or - notification_settings_json.get('channels', [])) + message = notification_settings_json.get("message", None) + data = notification_settings_json.get("data", {}) + routes = notification_settings_json.get( + "routes", None + ) or notification_settings_json.get("channels", []) model = NotificationSubSchema(message=message, data=data, routes=routes) return model @@ -77,15 +80,18 @@ def _to_model_sub_schema(notification_settings_json): @staticmethod def from_model(notify_model): notify = {} - if getattr(notify_model, 'on_complete', None): - notify['on-complete'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_complete) - if getattr(notify_model, 'on_success', None): - notify['on-success'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_success) - if getattr(notify_model, 'on_failure', None): - notify['on-failure'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_failure) + if getattr(notify_model, "on_complete", None): + notify["on-complete"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_complete + ) + if getattr(notify_model, "on_success", None): + notify["on-success"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_success + ) + if getattr(notify_model, "on_failure", None): + notify["on-failure"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_failure + ) return notify @@ -93,13 +99,14 @@ def from_model(notify_model): def _from_model_sub_schema(notify_sub_schema_model): notify_sub_schema = {} - if getattr(notify_sub_schema_model, 'message', None): - notify_sub_schema['message'] = notify_sub_schema_model.message - if getattr(notify_sub_schema_model, 'data', None): - notify_sub_schema['data'] = notify_sub_schema_model.data - routes = (getattr(notify_sub_schema_model, 'routes') or - getattr(notify_sub_schema_model, 'channels')) + if getattr(notify_sub_schema_model, "message", None): + notify_sub_schema["message"] = notify_sub_schema_model.message + if getattr(notify_sub_schema_model, "data", None): + notify_sub_schema["data"] = notify_sub_schema_model.data + routes = getattr(notify_sub_schema_model, "routes") or getattr( + notify_sub_schema_model, "channels" + ) if routes: - notify_sub_schema['routes'] = routes + notify_sub_schema["routes"] = routes return notify_sub_schema diff --git a/st2common/st2common/models/api/pack.py b/st2common/st2common/models/api/pack.py index 02c6d00f63..6de2893427 100644 --- a/st2common/st2common/models/api/pack.py +++ b/st2common/st2common/models/api/pack.py @@ -37,16 +37,14 @@ from st2common.util.pack import validate_config_against_schema __all__ = [ - 'PackAPI', - 'ConfigSchemaAPI', - 'ConfigAPI', - - 'ConfigItemSetAPI', - - 'PackInstallRequestAPI', - 'PackRegisterRequestAPI', - 'PackSearchRequestAPI', - 'PackAsyncAPI' + "PackAPI", + "ConfigSchemaAPI", + "ConfigAPI", + "ConfigItemSetAPI", + "PackInstallRequestAPI", + "PackRegisterRequestAPI", + "PackSearchRequestAPI", + "PackAsyncAPI", ] LOG = logging.getLogger(__name__) @@ -55,124 +53,117 @@ class PackAPI(BaseAPI): model = PackDB schema = { - 'type': 'object', - 'description': 'Content pack schema.', - 'properties': { - 'id': { - 'type': 'string', - 'description': 'Unique identifier for the pack.', - 'default': None + "type": "object", + "description": "Content pack schema.", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the pack.", + "default": None, }, - 'name': { - 'type': 'string', - 'description': 'Display name of the pack. If the name only contains lowercase' - 'letters, digits and underscores, the "ref" field is not required.', - 'required': True + "name": { + "type": "string", + "description": "Display name of the pack. If the name only contains lowercase" + 'letters, digits and underscores, the "ref" field is not required.', + "required": True, }, - 'ref': { - 'type': 'string', - 'description': 'Reference for the pack, used as an internal id.', - 'default': None, - 'pattern': PACK_REF_WHITELIST_REGEX + "ref": { + "type": "string", + "description": "Reference for the pack, used as an internal id.", + "default": None, + "pattern": PACK_REF_WHITELIST_REGEX, }, - 'uid': { - 'type': 'string' + "uid": {"type": "string"}, + "description": { + "type": "string", + "description": "Brief description of the pack and the service it integrates with.", + "required": True, }, - 'description': { - 'type': 'string', - 'description': 'Brief description of the pack and the service it integrates with.', - 'required': True + "keywords": { + "type": "array", + "description": "Keywords describing the pack.", + "items": {"type": "string"}, + "default": [], }, - 'keywords': { - 'type': 'array', - 'description': 'Keywords describing the pack.', - 'items': {'type': 'string'}, - 'default': [] + "version": { + "type": "string", + "description": "Pack version. Must follow the semver format " + '(for instance, "0.1.0").', + "pattern": PACK_VERSION_REGEX, + "required": True, }, - 'version': { - 'type': 'string', - 'description': 'Pack version. Must follow the semver format ' - '(for instance, "0.1.0").', - 'pattern': PACK_VERSION_REGEX, - 'required': True + "stackstorm_version": { + "type": "string", + "description": 'Required StackStorm version. Examples: ">1.6.0", ' + '">=1.8.0, <2.2.0"', + "pattern": ST2_VERSION_REGEX, }, - 'stackstorm_version': { - 'type': 'string', - 'description': 'Required StackStorm version. Examples: ">1.6.0", ' - '">=1.8.0, <2.2.0"', - 'pattern': ST2_VERSION_REGEX, + "python_versions": { + "type": "array", + "description": ( + "Major Python versions supported by this pack. E.g. " + '"2" for Python 2.7.x and "3" for Python 3.6.x' + ), + "items": {"type": "string", "enum": ["2", "3"]}, + "minItems": 1, + "maxItems": 2, + "uniqueItems": True, + "additionalItems": True, }, - 'python_versions': { - 'type': 'array', - 'description': ('Major Python versions supported by this pack. E.g. ' - '"2" for Python 2.7.x and "3" for Python 3.6.x'), - 'items': { - 'type': 'string', - 'enum': [ - '2', - '3' - ] - }, - 'minItems': 1, - 'maxItems': 2, - 'uniqueItems': True, - 'additionalItems': True + "author": { + "type": "string", + "description": "Pack author or authors.", + "required": True, }, - 'author': { - 'type': 'string', - 'description': 'Pack author or authors.', - 'required': True + "email": { + "type": "string", + "description": "E-mail of the pack author.", + "format": "email", }, - 'email': { - 'type': 'string', - 'description': 'E-mail of the pack author.', - 'format': 'email' + "contributors": { + "type": "array", + "items": {"type": "string", "maxLength": 100}, + "description": ( + "A list of people who have contributed to the pack. Format is: " + "Name e.g. Tomaz Muraus ." + ), }, - 'contributors': { - 'type': 'array', - 'items': { - 'type': 'string', - 'maxLength': 100 - }, - 'description': ('A list of people who have contributed to the pack. Format is: ' - 'Name e.g. Tomaz Muraus .') + "files": { + "type": "array", + "description": "A list of files inside the pack.", + "items": {"type": "string"}, + "default": [], }, - 'files': { - 'type': 'array', - 'description': 'A list of files inside the pack.', - 'items': {'type': 'string'}, - 'default': [] + "dependencies": { + "type": "array", + "description": "A list of other StackStorm packs this pack depends upon. " + 'The same format as in "st2 pack install" is used: ' + '"[=]".', + "items": {"type": "string"}, + "default": [], }, - 'dependencies': { - 'type': 'array', - 'description': 'A list of other StackStorm packs this pack depends upon. ' - 'The same format as in "st2 pack install" is used: ' - '"[=]".', - 'items': {'type': 'string'}, - 'default': [] + "system": { + "type": "object", + "description": "Specification for the system components and packages " + "required for the pack.", + "default": {}, }, - 'system': { - 'type': 'object', - 'description': 'Specification for the system components and packages ' - 'required for the pack.', - 'default': {} + "path": { + "type": "string", + "description": "Location of the pack on disk in st2 system.", + "required": False, }, - 'path': { - 'type': 'string', - 'description': 'Location of the pack on disk in st2 system.', - 'required': False - } }, # NOTE: We add this here explicitly so we can gracefuly add new attributs to pack.yaml # without breaking existing installations - 'additionalProperties': True + "additionalProperties": True, } def __init__(self, **values): # Note: If some version values are not explicitly surrounded by quotes they are recognized # as numbers so we cast them to string - if values.get('version', None): - values['version'] = str(values['version']) + if values.get("version", None): + values["version"] = str(values["version"]) super(PackAPI, self).__init__(**values) @@ -186,17 +177,21 @@ def validate(self): # Invalid version if "Failed validating 'pattern' in schema['properties']['version']" in msg: - new_msg = ('Pack version "%s" doesn\'t follow a valid semver format. Valid ' - 'versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc.' % - (self.version)) - new_msg += '\n\n' + msg + new_msg = ( + 'Pack version "%s" doesn\'t follow a valid semver format. Valid ' + "versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc." + % (self.version) + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) # Invalid ref / name if "Failed validating 'pattern' in schema['properties']['ref']" in msg: - new_msg = ('Pack ref / name can only contain valid word characters (a-z, 0-9 and ' - '_), dashes are not allowed.') - new_msg += '\n\n' + msg + new_msg = ( + "Pack ref / name can only contain valid word characters (a-z, 0-9 and " + "_), dashes are not allowed." + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) raise e @@ -206,24 +201,35 @@ def to_model(cls, pack): ref = pack.ref name = pack.name description = pack.description - keywords = getattr(pack, 'keywords', []) + keywords = getattr(pack, "keywords", []) version = str(pack.version) - stackstorm_version = getattr(pack, 'stackstorm_version', None) - python_versions = getattr(pack, 'python_versions', []) + stackstorm_version = getattr(pack, "stackstorm_version", None) + python_versions = getattr(pack, "python_versions", []) author = pack.author email = pack.email - contributors = getattr(pack, 'contributors', []) - files = getattr(pack, 'files', []) - pack_dir = getattr(pack, 'path', None) - dependencies = getattr(pack, 'dependencies', []) - system = getattr(pack, 'system', {}) - - model = cls.model(ref=ref, name=name, description=description, keywords=keywords, - version=version, author=author, email=email, contributors=contributors, - files=files, dependencies=dependencies, system=system, - stackstorm_version=stackstorm_version, path=pack_dir, - python_versions=python_versions) + contributors = getattr(pack, "contributors", []) + files = getattr(pack, "files", []) + pack_dir = getattr(pack, "path", None) + dependencies = getattr(pack, "dependencies", []) + system = getattr(pack, "system", {}) + + model = cls.model( + ref=ref, + name=name, + description=description, + keywords=keywords, + version=version, + author=author, + email=email, + contributors=contributors, + files=files, + dependencies=dependencies, + system=system, + stackstorm_version=stackstorm_version, + path=pack_dir, + python_versions=python_versions, + ) return model @@ -236,11 +242,11 @@ class ConfigSchemaAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the config schema.", - "type": "string" + "type": "string", }, "pack": { "description": "The content pack this config schema belongs to.", - "type": "string" + "type": "string", }, "attributes": { "description": "Config schema attributes.", @@ -248,11 +254,11 @@ class ConfigSchemaAPI(BaseAPI): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False, - "default": {} - } + "additionalProperties": False, + "default": {}, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -273,19 +279,19 @@ class ConfigAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the config.", - "type": "string" + "type": "string", }, "pack": { "description": "The content pack this config belongs to.", - "type": "string" + "type": "string", }, "values": { "description": "Config values.", "type": "object", - "default": {} - } + "default": {}, + }, }, - "additionalProperties": False + "additionalProperties": False, } def validate(self, validate_against_schema=False): @@ -310,13 +316,15 @@ def _validate_config_values_against_schema(self): instance = self.values or {} schema = config_schema_db.attributes or {} - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % (self.pack)) + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % (self.pack)) - cleaned = validate_config_against_schema(config_schema=schema, - config_object=instance, - config_path=config_path, - pack_name=self.pack) + cleaned = validate_config_against_schema( + config_schema=schema, + config_object=instance, + config_path=config_path, + pack_name=self.pack, + ) return cleaned @@ -330,15 +338,14 @@ def to_model(cls, config): class ConfigUpdateRequestAPI(BaseAPI): - schema = { - "type": "object" - } + schema = {"type": "object"} class ConfigItemSetAPI(BaseAPI): """ API class used with the config set API endpoint. """ + model = None schema = { "title": "", @@ -348,30 +355,27 @@ class ConfigItemSetAPI(BaseAPI): "name": { "description": "Config item name (key)", "type": "string", - "required": True + "required": True, }, "value": { "description": "Config item value.", "type": ["string", "number", "boolean", "array", "object"], - "required": True + "required": True, }, "scope": { "description": "Config item scope (system / user)", "type": "string", "default": SYSTEM_SCOPE, - "enum": [ - SYSTEM_SCOPE, - USER_SCOPE - ] + "enum": [SYSTEM_SCOPE, USER_SCOPE], }, "user": { "description": "User for user-scoped items (only available to admins).", "type": "string", "required": False, - "default": None - } + "default": None, + }, }, - "additionalProperties": False + "additionalProperties": False, } @@ -379,15 +383,13 @@ class PackInstallRequestAPI(BaseAPI): schema = { "type": "object", "properties": { - "packs": { - "type": "array" - }, + "packs": {"type": "array"}, "force": { "type": "boolean", "description": "Force pack installation", - "default": False - } - } + "default": False, + }, + }, } @@ -395,24 +397,14 @@ class PackRegisterRequestAPI(BaseAPI): schema = { "type": "object", "properties": { - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "packs": { - "type": "array", - "items": { - "type": "string" - } - }, + "types": {"type": "array", "items": {"type": "string"}}, + "packs": {"type": "array", "items": {"type": "string"}}, "fail_on_failure": { "type": "boolean", "description": "True to fail on failure", - "default": True - } - } + "default": True, + }, + }, } @@ -438,18 +430,13 @@ class PackSearchRequestAPI(BaseAPI): }, "additionalProperties": False, }, - ] + ], } class PackAsyncAPI(BaseAPI): schema = { "type": "object", - "properties": { - "execution_id": { - "type": "string", - "required": True - } - }, - "additionalProperties": False + "properties": {"execution_id": {"type": "string", "required": True}}, + "additionalProperties": False, } diff --git a/st2common/st2common/models/api/policy.py b/st2common/st2common/models/api/policy.py index a46dad9eda..211560d453 100644 --- a/st2common/st2common/models/api/policy.py +++ b/st2common/st2common/models/api/policy.py @@ -22,7 +22,7 @@ from st2common.util import schema as util_schema -__all__ = ['PolicyTypeAPI'] +__all__ = ["PolicyTypeAPI"] LOG = logging.getLogger(__name__) @@ -33,55 +33,34 @@ class PolicyTypeAPI(BaseAPI, APIUIDMixin): "title": "Policy Type", "type": "object", "properties": { - "id": { - "type": "string", - "default": None - }, - 'uid': { - 'type': 'string' - }, - "name": { - "type": "string", - "required": True - }, - "resource_type": { - "enum": ["action"], - "required": True - }, - "ref": { - "type": "string" - }, - "description": { - "type": "string" - }, - "enabled": { - "type": "boolean", - "default": True - }, - "module": { - "type": "string", - "required": True - }, + "id": {"type": "string", "default": None}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "resource_type": {"enum": ["action"], "required": True}, + "ref": {"type": "string"}, + "description": {"type": "string"}, + "enabled": {"type": "boolean", "default": True}, + "module": {"type": "string", "required": True}, "parameters": { "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_draft_schema() - }, - 'additionalProperties': False - } + "patternProperties": {r"^\w+$": util_schema.get_draft_schema()}, + "additionalProperties": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def to_model(cls, instance): - return cls.model(name=str(instance.name), - description=getattr(instance, 'description', None), - resource_type=str(instance.resource_type), - ref=getattr(instance, 'ref', None), - enabled=getattr(instance, 'enabled', None), - module=str(instance.module), - parameters=getattr(instance, 'parameters', dict())) + return cls.model( + name=str(instance.name), + description=getattr(instance, "description", None), + resource_type=str(instance.resource_type), + ref=getattr(instance, "ref", None), + enabled=getattr(instance, "enabled", None), + module=str(instance.module), + parameters=getattr(instance, "parameters", dict()), + ) class PolicyAPI(BaseAPI, APIUIDMixin): @@ -90,38 +69,15 @@ class PolicyAPI(BaseAPI, APIUIDMixin): "title": "Policy", "type": "object", "properties": { - "id": { - "type": "string", - "default": None - }, - 'uid': { - 'type': 'string' - }, - "name": { - "type": "string", - "required": True - }, - "pack": { - "type": "string" - }, - "ref": { - "type": "string" - }, - "description": { - "type": "string" - }, - "enabled": { - "type": "boolean", - "default": True - }, - "resource_ref": { - "type": "string", - "required": True - }, - "policy_type": { - "type": "string", - "required": True - }, + "id": {"type": "string", "default": None}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "ref": {"type": "string"}, + "description": {"type": "string"}, + "enabled": {"type": "boolean", "default": True}, + "resource_ref": {"type": "string", "required": True}, + "policy_type": {"type": "string", "required": True}, "parameters": { "type": "object", "patternProperties": { @@ -132,20 +88,19 @@ class PolicyAPI(BaseAPI, APIUIDMixin): {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } }, - 'additionalProperties': False - + "additionalProperties": False, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } def validate(self): @@ -156,15 +111,19 @@ def validate(self): # pylint: disable=no-member policy_type_db = PolicyType.get_by_ref(cleaned.policy_type) if not policy_type_db: - raise ValueError('Referenced policy_type "%s" doesnt exist' % (cleaned.policy_type)) + raise ValueError( + 'Referenced policy_type "%s" doesnt exist' % (cleaned.policy_type) + ) parameters_schema = policy_type_db.parameters - parameters = getattr(cleaned, 'parameters', {}) + parameters = getattr(cleaned, "parameters", {}) schema = util_schema.get_schema_for_resource_parameters( - parameters_schema=parameters_schema) + parameters_schema=parameters_schema + ) validator = util_schema.get_validator() - cleaned_parameters = util_schema.validate(parameters, schema, validator, use_default=True, - allow_default_none=True) + cleaned_parameters = util_schema.validate( + parameters, schema, validator, use_default=True, allow_default_none=True + ) cleaned.parameters = cleaned_parameters @@ -172,13 +131,15 @@ def validate(self): @classmethod def to_model(cls, instance): - return cls.model(id=getattr(instance, 'id', None), - name=str(instance.name), - description=getattr(instance, 'description', None), - pack=str(instance.pack), - ref=getattr(instance, 'ref', None), - enabled=getattr(instance, 'enabled', None), - resource_ref=str(instance.resource_ref), - policy_type=str(instance.policy_type), - parameters=getattr(instance, 'parameters', dict()), - metadata_file=getattr(instance, 'metadata_file', None)) + return cls.model( + id=getattr(instance, "id", None), + name=str(instance.name), + description=getattr(instance, "description", None), + pack=str(instance.pack), + ref=getattr(instance, "ref", None), + enabled=getattr(instance, "enabled", None), + resource_ref=str(instance.resource_ref), + policy_type=str(instance.policy_type), + parameters=getattr(instance, "parameters", dict()), + metadata_file=getattr(instance, "metadata_file", None), + ) diff --git a/st2common/st2common/models/api/rbac.py b/st2common/st2common/models/api/rbac.py index 556793b7a6..bd269ce3d6 100644 --- a/st2common/st2common/models/api/rbac.py +++ b/st2common/st2common/models/api/rbac.py @@ -25,67 +25,55 @@ from st2common.util.uid import parse_uid __all__ = [ - 'RoleAPI', - 'UserRoleAssignmentAPI', - - 'RoleDefinitionFileFormatAPI', - 'UserRoleAssignmentFileFormatAPI', - - 'AuthGroupToRoleMapAssignmentFileFormatAPI' + "RoleAPI", + "UserRoleAssignmentAPI", + "RoleDefinitionFileFormatAPI", + "UserRoleAssignmentFileFormatAPI", + "AuthGroupToRoleMapAssignmentFileFormatAPI", ] class RoleAPI(BaseAPI): model = RoleDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string' - }, - 'permission_grant_ids': { - 'type': 'array', - 'items': { - 'type': 'string' - } - }, - 'permission_grant_objects': { - 'type': 'array', - 'items': { - 'type': 'object' - } - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "name": {"type": "string", "required": True}, + "description": {"type": "string"}, + "permission_grant_ids": {"type": "array", "items": {"type": "string"}}, + "permission_grant_objects": {"type": "array", "items": {"type": "object"}}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod - def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects=True): + def from_model( + cls, model, mask_secrets=False, retrieve_permission_grant_objects=True + ): role = cls._from_model(model, mask_secrets=mask_secrets) # Convert ObjectIDs to strings - role['permission_grant_ids'] = [str(permission_grant) for permission_grant in - model.permission_grants] + role["permission_grant_ids"] = [ + str(permission_grant) for permission_grant in model.permission_grants + ] # Retrieve and include corresponding permission grant objects if retrieve_permission_grant_objects: from st2common.persistence.rbac import PermissionGrant - permission_grant_dbs = PermissionGrant.query(id__in=role['permission_grants']) + + permission_grant_dbs = PermissionGrant.query( + id__in=role["permission_grants"] + ) permission_grant_apis = [] for permission_grant_db in permission_grant_dbs: - permission_grant_api = PermissionGrantAPI.from_model(permission_grant_db) + permission_grant_api = PermissionGrantAPI.from_model( + permission_grant_db + ) permission_grant_apis.append(permission_grant_api) - role['permission_grant_objects'] = permission_grant_apis + role["permission_grant_objects"] = permission_grant_apis return cls(**role) @@ -93,56 +81,30 @@ def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects class UserRoleAssignmentAPI(BaseAPI): model = UserRoleAssignmentDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'user': { - 'type': 'string', - 'required': True - }, - 'role': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string' - }, - 'is_remote': { - 'type': 'boolean' - }, - 'source': { - 'type': 'string' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "user": {"type": "string", "required": True}, + "role": {"type": "string", "required": True}, + "description": {"type": "string"}, + "is_remote": {"type": "boolean"}, + "source": {"type": "string"}, }, - 'additionalProperties': False + "additionalProperties": False, } class PermissionGrantAPI(BaseAPI): model = PermissionGrantDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'resource_uid': { - 'type': 'string', - 'required': True - }, - 'resource_type': { - 'type': 'string', - 'required': True - }, - 'permission_types': { - 'type': 'array' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "resource_uid": {"type": "string", "required": True}, + "resource_type": {"type": "string", "required": True}, + "permission_types": {"type": "array"}, }, - 'additionalProperties': False + "additionalProperties": False, } @@ -152,53 +114,55 @@ class RoleDefinitionFileFormatAPI(BaseAPI): """ schema = { - 'type': 'object', - 'properties': { - 'name': { - 'type': 'string', - 'description': 'Role name', - 'required': True, - 'default': None + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Role name", + "required": True, + "default": None, }, - 'description': { - 'type': 'string', - 'description': 'Role description', - 'required': False + "description": { + "type": "string", + "description": "Role description", + "required": False, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this role is enabled. Note: Disabled roles ' - 'are simply ignored when loading definitions from disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this role is enabled. Note: Disabled roles " + "are simply ignored when loading definitions from disk." + ), + "default": True, }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': None + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": None, }, - 'permission_types': { - 'type': 'array', - 'description': 'A list of permission types to grant', - 'uniqueItems': True, - 'items': { - 'type': 'string', + "permission_types": { + "type": "array", + "description": "A list of permission types to grant", + "uniqueItems": True, + "items": { + "type": "string", # Note: We permission aditional validation for based on the # resource type in other place - 'enum': PermissionType.get_valid_values() + "enum": PermissionType.get_valid_values(), }, - 'default': [] - } - } - } - } + "default": [], + }, + }, + }, + }, }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self): @@ -208,31 +172,43 @@ def validate(self): # Custom validation # Validate that only the correct permission types are used - permission_grants = getattr(self, 'permission_grants', []) + permission_grants = getattr(self, "permission_grants", []) for permission_grant in permission_grants: - resource_uid = permission_grant.get('resource_uid', None) - permission_types = permission_grant.get('permission_types', []) + resource_uid = permission_grant.get("resource_uid", None) + permission_types = permission_grant.get("permission_types", []) if resource_uid: # Permission types which apply to a resource resource_type, _ = parse_uid(uid=resource_uid) - valid_permission_types = PermissionType.get_valid_permissions_for_resource_type( - resource_type=resource_type) + valid_permission_types = ( + PermissionType.get_valid_permissions_for_resource_type( + resource_type=resource_type + ) + ) for permission_type in permission_types: if permission_type not in valid_permission_types: - message = ('Invalid permission type "%s" for resource type "%s"' % - (permission_type, resource_type)) + message = ( + 'Invalid permission type "%s" for resource type "%s"' + % ( + permission_type, + resource_type, + ) + ) raise ValueError(message) else: # Right now we only support single permission type (list) which is global and # doesn't apply to a resource for permission_type in permission_types: if permission_type not in GLOBAL_PERMISSION_TYPES: - valid_global_permission_types = ', '.join(GLOBAL_PERMISSION_TYPES) - message = ('Invalid permission type "%s". Valid global permission types ' - 'which can be used without a resource id are: %s' % - (permission_type, valid_global_permission_types)) + valid_global_permission_types = ", ".join( + GLOBAL_PERMISSION_TYPES + ) + message = ( + 'Invalid permission type "%s". Valid global permission types ' + "which can be used without a resource id are: %s" + % (permission_type, valid_global_permission_types) + ) raise ValueError(message) return cleaned @@ -252,52 +228,53 @@ def validate(self, validate_role_exists=False): if validate_role_exists: # Validate that the referenced roles exist in the db rbac_service = get_rbac_backend().get_service_class() - rbac_service.validate_roles_exists(role_names=self.roles) # pylint: disable=no-member + rbac_service.validate_roles_exists( + role_names=self.roles + ) # pylint: disable=no-member return cleaned class UserRoleAssignmentFileFormatAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'username': { - 'type': 'string', - 'description': 'Username', - 'required': True, - 'default': None + "type": "object", + "properties": { + "username": { + "type": "string", + "description": "Username", + "required": True, + "default": None, }, - 'description': { - 'type': 'string', - 'description': 'Assignment description', - 'required': False, - 'default': None + "description": { + "type": "string", + "description": "Assignment description", + "required": False, + "default": None, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this assignment is enabled. Note: Disabled ' - 'assignments are simply ignored when loading definitions from ' - ' disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this assignment is enabled. Note: Disabled " + "assignments are simply ignored when loading definitions from " + " disk." + ), + "default": True, }, - 'roles': { - 'type': 'array', - 'description': 'Roles assigned to this user', - 'uniqueItems': True, - 'items': { - 'type': 'string' - }, - 'required': True + "roles": { + "type": "array", + "description": "Roles assigned to this user", + "uniqueItems": True, + "items": {"type": "string"}, + "required": True, + }, + "file_path": { + "type": "string", + "description": "Path of the file of where this assignment comes from.", + "default": None, + "required": False, }, - 'file_path': { - 'type': 'string', - 'description': 'Path of the file of where this assignment comes from.', - 'default': None, - 'required': False - } - }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self, validate_role_exists=False): @@ -307,44 +284,46 @@ def validate(self, validate_role_exists=False): class AuthGroupToRoleMapAssignmentFileFormatAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'group': { - 'type': 'string', - 'description': 'Name of the group as returned by auth backend.', - 'required': True + "type": "object", + "properties": { + "group": { + "type": "string", + "description": "Name of the group as returned by auth backend.", + "required": True, }, - 'description': { - 'type': 'string', - 'description': 'Mapping description', - 'required': False, - 'default': None + "description": { + "type": "string", + "description": "Mapping description", + "required": False, + "default": None, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this mapping is enabled. Note: Disabled ' - 'assignments are simply ignored when loading definitions from ' - ' disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this mapping is enabled. Note: Disabled " + "assignments are simply ignored when loading definitions from " + " disk." + ), + "default": True, }, - 'roles': { - 'type': 'array', - 'description': ('StackStorm roles which are assigned to each user which belongs ' - 'to that group.'), - 'uniqueItems': True, - 'items': { - 'type': 'string' - }, - 'required': True + "roles": { + "type": "array", + "description": ( + "StackStorm roles which are assigned to each user which belongs " + "to that group." + ), + "uniqueItems": True, + "items": {"type": "string"}, + "required": True, + }, + "file_path": { + "type": "string", + "description": "Path of the file of where this assignment comes from.", + "default": None, + "required": False, }, - 'file_path': { - 'type': 'string', - 'description': 'Path of the file of where this assignment comes from.', - 'default': None, - 'required': False - } }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self, validate_role_exists=False): diff --git a/st2common/st2common/models/api/rule.py b/st2common/st2common/models/api/rule.py index 716eeec207..8919a2ffc9 100644 --- a/st2common/st2common/models/api/rule.py +++ b/st2common/st2common/models/api/rule.py @@ -20,7 +20,12 @@ from st2common.models.api.base import BaseAPI from st2common.models.api.base import APIUIDMixin from st2common.models.api.tag import TagsHelper -from st2common.models.db.rule import RuleDB, RuleTypeDB, RuleTypeSpecDB, ActionExecutionSpecDB +from st2common.models.db.rule import ( + RuleDB, + RuleTypeDB, + RuleTypeSpecDB, + ActionExecutionSpecDB, +) from st2common.models.system.common import ResourceReference from st2common.persistence.trigger import Trigger import st2common.services.triggers as TriggerService @@ -30,61 +35,52 @@ class RuleTypeSpec(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'ref': { - 'type': 'string', - 'required': True - }, - 'parameters': { - 'type': 'object' - } + "type": "object", + "properties": { + "ref": {"type": "string", "required": True}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False + "additionalProperties": False, } class RuleTypeAPI(BaseAPI): model = RuleTypeDB schema = { - 'title': 'RuleType', - 'description': 'A specific type of rule.', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the rule type.', - 'type': 'string', - 'default': None - }, - 'name': { - 'description': 'The name for the rule type.', - 'type': 'string', - 'required': True + "title": "RuleType", + "description": "A specific type of rule.", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the rule type.", + "type": "string", + "default": None, }, - 'description': { - 'description': 'The description of the rule type.', - 'type': 'string' + "name": { + "description": "The name for the rule type.", + "type": "string", + "required": True, }, - 'enabled': { - 'type': 'boolean', - 'default': True + "description": { + "description": "The description of the rule type.", + "type": "string", }, - 'parameters': { - 'type': 'object' - } + "enabled": {"type": "boolean", "default": True}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, rule_type): - name = getattr(rule_type, 'name', None) - description = getattr(rule_type, 'description', None) - enabled = getattr(rule_type, 'enabled', False) - parameters = getattr(rule_type, 'parameters', {}) + name = getattr(rule_type, "name", None) + description = getattr(rule_type, "description", None) + enabled = getattr(rule_type, "enabled", False) + parameters = getattr(rule_type, "parameters", {}) - return cls.model(name=name, description=description, enabled=enabled, - parameters=parameters) + return cls.model( + name=name, description=description, enabled=enabled, parameters=parameters + ) class RuleAPI(BaseAPI, APIUIDMixin): @@ -113,100 +109,60 @@ class RuleAPI(BaseAPI, APIUIDMixin): status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + model = RuleDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, "ref": { "description": ( "System computed user friendly reference for the rule. " "Provided value will be overridden by computed value." ), - "type": "string" - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string', - 'default': DEFAULT_PACK_NAME - }, - 'description': { - 'type': 'string' + "type": "string", }, - 'type': RuleTypeSpec.schema, - 'trigger': { - 'type': 'object', - 'required': True, - 'properties': { - 'type': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string', - 'require': False - }, - 'parameters': { - 'type': 'object', - 'default': {} - }, - 'ref': { - 'type': 'string', - 'required': False - } + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string", "default": DEFAULT_PACK_NAME}, + "description": {"type": "string"}, + "type": RuleTypeSpec.schema, + "trigger": { + "type": "object", + "required": True, + "properties": { + "type": {"type": "string", "required": True}, + "description": {"type": "string", "require": False}, + "parameters": {"type": "object", "default": {}}, + "ref": {"type": "string", "required": False}, }, - 'additionalProperties': True - }, - 'criteria': { - 'type': 'object', - 'default': {} - }, - 'action': { - 'type': 'object', - 'required': True, - 'properties': { - 'ref': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string', - 'require': False - }, - 'parameters': { - 'type': 'object' - } + "additionalProperties": True, + }, + "criteria": {"type": "object", "default": {}}, + "action": { + "type": "object", + "required": True, + "properties": { + "ref": {"type": "string", "required": True}, + "description": {"type": "string", "require": False}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False - }, - 'enabled': { - 'type': 'boolean', - 'default': False - }, - 'context': { - 'type': 'object' + "additionalProperties": False, }, + "enabled": {"type": "boolean", "default": False}, + "context": {"type": "object"}, "tags": { "description": "User associated metadata assigned to this object.", "type": "array", - "items": {"type": "object"} + "items": {"type": "object"}, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod @@ -215,58 +171,62 @@ def from_model(cls, model, mask_secrets=False, ignore_missing_trigger=False): trigger_db = reference.get_model_by_resource_ref(Trigger, model.trigger) if not ignore_missing_trigger and not trigger_db: - raise ValueError('Missing TriggerDB object for rule %s' % (rule['id'])) + raise ValueError("Missing TriggerDB object for rule %s" % (rule["id"])) if trigger_db: - rule['trigger'] = { - 'type': trigger_db.type, - 'parameters': trigger_db.parameters, - 'ref': model.trigger + rule["trigger"] = { + "type": trigger_db.type, + "parameters": trigger_db.parameters, + "ref": model.trigger, } - rule['tags'] = TagsHelper.from_model(model.tags) + rule["tags"] = TagsHelper.from_model(model.tags) return cls(**rule) @classmethod def to_model(cls, rule): kwargs = {} - kwargs['name'] = getattr(rule, 'name', None) - kwargs['description'] = getattr(rule, 'description', None) + kwargs["name"] = getattr(rule, "name", None) + kwargs["description"] = getattr(rule, "description", None) # Validate trigger parameters # Note: This must happen before we create a trigger, otherwise create trigger could fail # with a cryptic error - trigger = getattr(rule, 'trigger', {}) - trigger_type_ref = trigger.get('type', None) - parameters = trigger.get('parameters', {}) + trigger = getattr(rule, "trigger", {}) + trigger_type_ref = trigger.get("type", None) + parameters = trigger.get("parameters", {}) - validator.validate_trigger_parameters(trigger_type_ref=trigger_type_ref, - parameters=parameters) + validator.validate_trigger_parameters( + trigger_type_ref=trigger_type_ref, parameters=parameters + ) # Create a trigger for the provided rule trigger_db = TriggerService.create_trigger_db_from_rule(rule) - kwargs['trigger'] = reference.get_str_resource_ref_from_model(trigger_db) + kwargs["trigger"] = reference.get_str_resource_ref_from_model(trigger_db) - kwargs['pack'] = getattr(rule, 'pack', DEFAULT_PACK_NAME) - kwargs['ref'] = ResourceReference.to_string_reference(pack=kwargs['pack'], - name=kwargs['name']) + kwargs["pack"] = getattr(rule, "pack", DEFAULT_PACK_NAME) + kwargs["ref"] = ResourceReference.to_string_reference( + pack=kwargs["pack"], name=kwargs["name"] + ) # Validate criteria - kwargs['criteria'] = dict(getattr(rule, 'criteria', {})) - validator.validate_criteria(kwargs['criteria']) + kwargs["criteria"] = dict(getattr(rule, "criteria", {})) + validator.validate_criteria(kwargs["criteria"]) - kwargs['action'] = ActionExecutionSpecDB(ref=rule.action['ref'], - parameters=rule.action.get('parameters', {})) + kwargs["action"] = ActionExecutionSpecDB( + ref=rule.action["ref"], parameters=rule.action.get("parameters", {}) + ) - rule_type = dict(getattr(rule, 'type', {})) + rule_type = dict(getattr(rule, "type", {})) if rule_type: - kwargs['type'] = RuleTypeSpecDB(ref=rule_type['ref'], - parameters=rule_type.get('parameters', {})) + kwargs["type"] = RuleTypeSpecDB( + ref=rule_type["ref"], parameters=rule_type.get("parameters", {}) + ) - kwargs['enabled'] = getattr(rule, 'enabled', False) - kwargs['context'] = getattr(rule, 'context', dict()) - kwargs['tags'] = TagsHelper.to_model(getattr(rule, 'tags', [])) - kwargs['metadata_file'] = getattr(rule, 'metadata_file', None) + kwargs["enabled"] = getattr(rule, "enabled", False) + kwargs["context"] = getattr(rule, "context", dict()) + kwargs["tags"] = TagsHelper.to_model(getattr(rule, "tags", [])) + kwargs["metadata_file"] = getattr(rule, "metadata_file", None) model = cls.model(**kwargs) return model @@ -277,13 +237,5 @@ class RuleViewAPI(RuleAPI): # Always deep-copy to avoid breaking the original. schema = copy.deepcopy(RuleAPI.schema) # Update the schema to include the description properties - schema['properties']['action'].update({ - 'description': { - 'type': 'string' - } - }) - schema['properties']['trigger'].update({ - 'description': { - 'type': 'string' - } - }) + schema["properties"]["action"].update({"description": {"type": "string"}}) + schema["properties"]["trigger"].update({"description": {"type": "string"}}) diff --git a/st2common/st2common/models/api/rule_enforcement.py b/st2common/st2common/models/api/rule_enforcement.py index c950b59bfe..d7aa1bc873 100644 --- a/st2common/st2common/models/api/rule_enforcement.py +++ b/st2common/st2common/models/api/rule_enforcement.py @@ -28,95 +28,98 @@ from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUSES from st2common.util import isotime -__all__ = [ - 'RuleEnforcementAPI', - 'RuleEnforcementViewAPI', - - 'RuleReferenceSpecDB' -] +__all__ = ["RuleEnforcementAPI", "RuleEnforcementViewAPI", "RuleReferenceSpecDB"] class RuleReferenceSpec(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'ref': { - 'type': 'string', - 'required': True, + "type": "object", + "properties": { + "ref": { + "type": "string", + "required": True, }, - 'uid': { - 'type': 'string', - 'required': True, + "uid": { + "type": "string", + "required": True, }, - 'id': { - 'type': 'string', - 'required': False, + "id": { + "type": "string", + "required": False, }, }, - 'additionalProperties': False + "additionalProperties": False, } class RuleEnforcementAPI(BaseAPI): model = RuleEnforcementDB schema = { - 'title': 'RuleEnforcement', - 'description': 'A specific instance of rule enforcement.', - 'type': 'object', - 'properties': { - 'trigger_instance_id': { - 'description': 'The unique identifier for the trigger instance ' + - 'that flipped the rule.', - 'type': 'string', - 'required': True + "title": "RuleEnforcement", + "description": "A specific instance of rule enforcement.", + "type": "object", + "properties": { + "trigger_instance_id": { + "description": "The unique identifier for the trigger instance " + + "that flipped the rule.", + "type": "string", + "required": True, }, - 'execution_id': { - 'description': 'ID of the action execution that was invoked as a response.', - 'type': 'string' + "execution_id": { + "description": "ID of the action execution that was invoked as a response.", + "type": "string", }, - 'failure_reason': { - 'description': 'Reason for failure to execute the action specified in the rule.', - 'type': 'string' + "failure_reason": { + "description": "Reason for failure to execute the action specified in the rule.", + "type": "string", }, - 'rule': RuleReferenceSpec.schema, - 'enforced_at': { - 'description': 'Timestamp when rule enforcement happened.', - 'type': 'string', - 'required': True + "rule": RuleReferenceSpec.schema, + "enforced_at": { + "description": "Timestamp when rule enforcement happened.", + "type": "string", + "required": True, }, "status": { "description": "Rule enforcement status.", "type": "string", - "enum": RULE_ENFORCEMENT_STATUSES + "enum": RULE_ENFORCEMENT_STATUSES, }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, rule_enforcement): - trigger_instance_id = getattr(rule_enforcement, 'trigger_instance_id', None) - execution_id = getattr(rule_enforcement, 'execution_id', None) - enforced_at = getattr(rule_enforcement, 'enforced_at', None) - failure_reason = getattr(rule_enforcement, 'failure_reason', None) - status = getattr(rule_enforcement, 'status', RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - rule_ref_model = dict(getattr(rule_enforcement, 'rule', {})) - rule = RuleReferenceSpecDB(ref=rule_ref_model['ref'], id=rule_ref_model['id'], - uid=rule_ref_model['uid']) + trigger_instance_id = getattr(rule_enforcement, "trigger_instance_id", None) + execution_id = getattr(rule_enforcement, "execution_id", None) + enforced_at = getattr(rule_enforcement, "enforced_at", None) + failure_reason = getattr(rule_enforcement, "failure_reason", None) + status = getattr(rule_enforcement, "status", RULE_ENFORCEMENT_STATUS_SUCCEEDED) + + rule_ref_model = dict(getattr(rule_enforcement, "rule", {})) + rule = RuleReferenceSpecDB( + ref=rule_ref_model["ref"], + id=rule_ref_model["id"], + uid=rule_ref_model["uid"], + ) if enforced_at: enforced_at = isotime.parse(enforced_at) - return cls.model(trigger_instance_id=trigger_instance_id, execution_id=execution_id, - failure_reason=failure_reason, enforced_at=enforced_at, rule=rule, - status=status) + return cls.model( + trigger_instance_id=trigger_instance_id, + execution_id=execution_id, + failure_reason=failure_reason, + enforced_at=enforced_at, + rule=rule, + status=status, + ) @classmethod def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) enforced_at = isotime.format(model.enforced_at, offset=False) - doc['enforced_at'] = enforced_at + doc["enforced_at"] = enforced_at attrs = {attr: value for attr, value in six.iteritems(doc) if value} return cls(**attrs) @@ -126,7 +129,7 @@ class RuleEnforcementViewAPI(RuleEnforcementAPI): schema = copy.deepcopy(RuleEnforcementAPI.schema) # Update the schema to include additional execution properties - schema['properties']['execution'] = copy.deepcopy(ActionExecutionAPI.schema) + schema["properties"]["execution"] = copy.deepcopy(ActionExecutionAPI.schema) # Update the schema to include additional trigger instance properties - schema['properties']['trigger_instance'] = copy.deepcopy(TriggerInstanceAPI.schema) + schema["properties"]["trigger_instance"] = copy.deepcopy(TriggerInstanceAPI.schema) diff --git a/st2common/st2common/models/api/sensor.py b/st2common/st2common/models/api/sensor.py index af9c687611..a2ba978adf 100644 --- a/st2common/st2common/models/api/sensor.py +++ b/st2common/st2common/models/api/sensor.py @@ -22,53 +22,34 @@ class SensorTypeAPI(BaseAPI): model = SensorTypeDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'artifact_uri': { - 'type': 'string', - }, - 'entry_point': { - 'type': 'string', - }, - 'enabled': { - 'description': 'Enable or disable the sensor.', - 'type': 'boolean', - 'default': True + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "description": {"type": "string"}, + "artifact_uri": { + "type": "string", }, - 'trigger_types': { - 'type': 'array', - 'default': [] + "entry_point": { + "type": "string", }, - 'poll_interval': { - 'type': 'number' + "enabled": { + "description": "Enable or disable the sensor.", + "type": "boolean", + "default": True, }, + "trigger_types": {"type": "array", "default": []}, + "poll_interval": {"type": "number"}, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod diff --git a/st2common/st2common/models/api/tag.py b/st2common/st2common/models/api/tag.py index 78d92568e0..0ed763a51f 100644 --- a/st2common/st2common/models/api/tag.py +++ b/st2common/st2common/models/api/tag.py @@ -16,19 +16,19 @@ from __future__ import absolute_import from st2common.models.db.stormbase import TagField -__all__ = [ - 'TagsHelper' -] +__all__ = ["TagsHelper"] class TagsHelper(object): - @staticmethod def to_model(tags): tags = tags or [] - return [TagField(name=tag.get('name', ''), value=tag.get('value', '')) for tag in tags] + return [ + TagField(name=tag.get("name", ""), value=tag.get("value", "")) + for tag in tags + ] @staticmethod def from_model(tags): tags = tags or [] - return [{'name': tag.name, 'value': tag.value} for tag in tags] + return [{"name": tag.name, "value": tag.value} for tag in tags] diff --git a/st2common/st2common/models/api/trace.py b/st2common/st2common/models/api/trace.py index f09faf6d36..4ce8ec8942 100644 --- a/st2common/st2common/models/api/trace.py +++ b/st2common/st2common/models/api/trace.py @@ -21,141 +21,148 @@ TraceComponentAPISchema = { - 'type': 'object', - 'properties': { - 'object_id': { - 'type': 'string', - 'description': 'Id of the component', - 'required': True + "type": "object", + "properties": { + "object_id": { + "type": "string", + "description": "Id of the component", + "required": True, }, - 'ref': { - 'type': 'string', - 'description': 'ref of the component', - 'required': False + "ref": { + "type": "string", + "description": "ref of the component", + "required": False, }, - 'updated_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "updated_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, - 'caused_by': { - 'type': 'object', - 'description': 'Component that is the cause or the predecesor.', - 'properties': { - 'id': { - 'description': 'Id of the causal component.', - 'type': 'string' + "caused_by": { + "type": "object", + "description": "Component that is the cause or the predecesor.", + "properties": { + "id": {"description": "Id of the causal component.", "type": "string"}, + "type": { + "description": "Type of the causal component.", + "type": "string", }, - 'type': { - 'description': 'Type of the causal component.', - 'type': 'string' - } - } - } + }, + }, }, - 'additionalProperties': False + "additionalProperties": False, } class TraceAPI(BaseAPI, APIUIDMixin): model = TraceDB schema = { - 'title': 'Trace', - 'desciption': 'Trace is a collection of all TriggerInstances, Rules and ActionExecutions \ + "title": "Trace", + "desciption": "Trace is a collection of all TriggerInstances, Rules and ActionExecutions \ that represent an activity which begins with the introduction of a \ TriggerInstance or request of an ActionExecution and ends with the \ - completion of an ActionExecution.', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for a Trace.', - 'type': 'string', - 'default': None + completion of an ActionExecution.", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for a Trace.", + "type": "string", + "default": None, }, - 'trace_tag': { - 'description': 'User assigned identifier for each Trace.', - 'type': 'string', - 'required': True + "trace_tag": { + "description": "User assigned identifier for each Trace.", + "type": "string", + "required": True, }, - 'action_executions': { - 'description': 'All ActionExecutions belonging to a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "action_executions": { + "description": "All ActionExecutions belonging to a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'rules': { - 'description': 'All rules that applied as part of a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "rules": { + "description": "All rules that applied as part of a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'trigger_instances': { - 'description': 'All TriggerInstances fired during a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "trigger_instances": { + "description": "All TriggerInstances fired during a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'start_timestamp': { - 'description': 'Timestamp when the Trace is started.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "start_timestamp": { + "description": "Timestamp when the Trace is started.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_component_model(cls, component): values = { - 'object_id': component['object_id'], - 'ref': component['ref'], - 'caused_by': component.get('caused_by', {}) + "object_id": component["object_id"], + "ref": component["ref"], + "caused_by": component.get("caused_by", {}), } - updated_at = component.get('updated_at', None) + updated_at = component.get("updated_at", None) if updated_at: - values['updated_at'] = isotime.parse(updated_at) + values["updated_at"] = isotime.parse(updated_at) return TraceComponentDB(**values) @classmethod def to_model(cls, instance): - values = { - 'trace_tag': instance.trace_tag - } - action_executions = getattr(instance, 'action_executions', []) - action_executions = [TraceAPI.to_component_model(component=action_execution) - for action_execution in action_executions] - values['action_executions'] = action_executions - - rules = getattr(instance, 'rules', []) + values = {"trace_tag": instance.trace_tag} + action_executions = getattr(instance, "action_executions", []) + action_executions = [ + TraceAPI.to_component_model(component=action_execution) + for action_execution in action_executions + ] + values["action_executions"] = action_executions + + rules = getattr(instance, "rules", []) rules = [TraceAPI.to_component_model(component=rule) for rule in rules] - values['rules'] = rules + values["rules"] = rules - trigger_instances = getattr(instance, 'trigger_instances', []) - trigger_instances = [TraceAPI.to_component_model(component=trigger_instance) - for trigger_instance in trigger_instances] - values['trigger_instances'] = trigger_instances + trigger_instances = getattr(instance, "trigger_instances", []) + trigger_instances = [ + TraceAPI.to_component_model(component=trigger_instance) + for trigger_instance in trigger_instances + ] + values["trigger_instances"] = trigger_instances - start_timestamp = getattr(instance, 'start_timestamp', None) + start_timestamp = getattr(instance, "start_timestamp", None) if start_timestamp: - values['start_timestamp'] = isotime.parse(start_timestamp) + values["start_timestamp"] = isotime.parse(start_timestamp) return cls.model(**values) @classmethod def from_component_model(cls, component_model): - return {'object_id': component_model.object_id, - 'ref': component_model.ref, - 'updated_at': isotime.format(component_model.updated_at, offset=False), - 'caused_by': component_model.caused_by} + return { + "object_id": component_model.object_id, + "ref": component_model.ref, + "updated_at": isotime.format(component_model.updated_at, offset=False), + "caused_by": component_model.caused_by, + } @classmethod def from_model(cls, model, mask_secrets=False): instance = cls._from_model(model, mask_secrets=mask_secrets) - instance['start_timestamp'] = isotime.format(model.start_timestamp, offset=False) + instance["start_timestamp"] = isotime.format( + model.start_timestamp, offset=False + ) if model.action_executions: - instance['action_executions'] = [cls.from_component_model(action_execution) - for action_execution in model.action_executions] + instance["action_executions"] = [ + cls.from_component_model(action_execution) + for action_execution in model.action_executions + ] if model.rules: - instance['rules'] = [cls.from_component_model(rule) for rule in model.rules] + instance["rules"] = [cls.from_component_model(rule) for rule in model.rules] if model.trigger_instances: - instance['trigger_instances'] = [cls.from_component_model(trigger_instance) - for trigger_instance in model.trigger_instances] + instance["trigger_instances"] = [ + cls.from_component_model(trigger_instance) + for trigger_instance in model.trigger_instances + ] return cls(**instance) @@ -173,12 +180,13 @@ class TraceContext(object): Optional property. :type trace_tag: ``str`` """ + def __init__(self, id_=None, trace_tag=None): self.id_ = id_ self.trace_tag = trace_tag def __str__(self): - return '{id_: %s, trace_tag: %s}' % (self.id_, self.trace_tag) + return "{id_: %s, trace_tag: %s}" % (self.id_, self.trace_tag) def __json__(self): return vars(self) diff --git a/st2common/st2common/models/api/trigger.py b/st2common/st2common/models/api/trigger.py index af88027fe0..cdb2cd9ddd 100644 --- a/st2common/st2common/models/api/trigger.py +++ b/st2common/st2common/models/api/trigger.py @@ -23,140 +23,113 @@ from st2common.models.db.trigger import TriggerTypeDB, TriggerDB, TriggerInstanceDB from st2common.models.system.common import ResourceReference -DATE_FORMAT = '%Y-%m-%d %H:%M:%S.%f' +DATE_FORMAT = "%Y-%m-%d %H:%M:%S.%f" class TriggerTypeAPI(BaseAPI): model = TriggerTypeDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'payload_schema': { - 'type': 'object', - 'default': {} - }, - 'parameters_schema': { - 'type': 'object', - 'default': {} - }, - 'tags': { - 'description': 'User associated metadata assigned to this object.', - 'type': 'array', - 'items': {'type': 'object'} + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "description": {"type": "string"}, + "payload_schema": {"type": "object", "default": {}}, + "parameters_schema": {"type": "object", "default": {}}, + "tags": { + "description": "User associated metadata assigned to this object.", + "type": "array", + "items": {"type": "object"}, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, trigger_type): - name = getattr(trigger_type, 'name', None) - description = getattr(trigger_type, 'description', None) - pack = getattr(trigger_type, 'pack', None) - payload_schema = getattr(trigger_type, 'payload_schema', {}) - parameters_schema = getattr(trigger_type, 'parameters_schema', {}) - tags = TagsHelper.to_model(getattr(trigger_type, 'tags', [])) - metadata_file = getattr(trigger_type, 'metadata_file', None) - - model = cls.model(name=name, description=description, pack=pack, - payload_schema=payload_schema, parameters_schema=parameters_schema, - tags=tags, metadata_file=metadata_file) + name = getattr(trigger_type, "name", None) + description = getattr(trigger_type, "description", None) + pack = getattr(trigger_type, "pack", None) + payload_schema = getattr(trigger_type, "payload_schema", {}) + parameters_schema = getattr(trigger_type, "parameters_schema", {}) + tags = TagsHelper.to_model(getattr(trigger_type, "tags", [])) + metadata_file = getattr(trigger_type, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + pack=pack, + payload_schema=payload_schema, + parameters_schema=parameters_schema, + tags=tags, + metadata_file=metadata_file, + ) return model @classmethod def from_model(cls, model, mask_secrets=False): triggertype = cls._from_model(model, mask_secrets=mask_secrets) - triggertype['tags'] = TagsHelper.from_model(model.tags) + triggertype["tags"] = TagsHelper.from_model(model.tags) return cls(**triggertype) class TriggerAPI(BaseAPI): model = TriggerDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string' - }, - 'pack': { - 'type': 'string' - }, - 'type': { - 'type': 'string', - 'required': True - }, - 'parameters': { - 'type': 'object' - }, - 'description': { - 'type': 'string' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string"}, + "pack": {"type": "string"}, + "type": {"type": "string", "required": True}, + "parameters": {"type": "object"}, + "description": {"type": "string"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): trigger = cls._from_model(model, mask_secrets=mask_secrets) # Hide ref count from API. - trigger.pop('ref_count', None) + trigger.pop("ref_count", None) return cls(**trigger) @classmethod def to_model(cls, trigger): - name = getattr(trigger, 'name', None) - description = getattr(trigger, 'description', None) - pack = getattr(trigger, 'pack', None) - _type = getattr(trigger, 'type', None) - parameters = getattr(trigger, 'parameters', {}) + name = getattr(trigger, "name", None) + description = getattr(trigger, "description", None) + pack = getattr(trigger, "pack", None) + _type = getattr(trigger, "type", None) + parameters = getattr(trigger, "parameters", {}) if _type and not parameters: trigger_type_ref = ResourceReference.from_string_reference(_type) name = trigger_type_ref.name - if hasattr(trigger, 'name') and trigger.name: + if hasattr(trigger, "name") and trigger.name: name = trigger.name else: # assign a name if none is provided. name = str(uuid.uuid4()) - model = cls.model(name=name, description=description, pack=pack, type=_type, - parameters=parameters) + model = cls.model( + name=name, + description=description, + pack=pack, + type=_type, + parameters=parameters, + ) return model def to_dict(self): @@ -167,38 +140,29 @@ def to_dict(self): class TriggerInstanceAPI(BaseAPI): model = TriggerInstanceDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' - }, - 'occurrence_time': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX - }, - 'payload': { - 'type': 'object' - }, - 'trigger': { - 'type': 'string', - 'default': None, - 'required': True + "type": "object", + "properties": { + "id": {"type": "string"}, + "occurrence_time": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX}, + "payload": {"type": "object"}, + "trigger": {"type": "string", "default": None, "required": True}, + "status": { + "type": "string", + "default": None, + "enum": TRIGGER_INSTANCE_STATUSES, }, - 'status': { - 'type': 'string', - 'default': None, - 'enum': TRIGGER_INSTANCE_STATUSES - } }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): instance = cls._from_model(model, mask_secrets=mask_secrets) - if instance.get('occurrence_time', None): - instance['occurrence_time'] = isotime.format(instance['occurrence_time'], offset=False) + if instance.get("occurrence_time", None): + instance["occurrence_time"] = isotime.format( + instance["occurrence_time"], offset=False + ) return cls(**instance) @@ -209,6 +173,10 @@ def to_model(cls, instance): occurrence_time = isotime.parse(instance.occurrence_time) status = instance.status - model = cls.model(trigger=trigger, payload=payload, occurrence_time=occurrence_time, - status=status) + model = cls.model( + trigger=trigger, + payload=payload, + occurrence_time=occurrence_time, + status=status, + ) return model diff --git a/st2common/st2common/models/api/webhook.py b/st2common/st2common/models/api/webhook.py index 9d1a37ed1d..eb7b04a29b 100644 --- a/st2common/st2common/models/api/webhook.py +++ b/st2common/st2common/models/api/webhook.py @@ -15,20 +15,15 @@ from st2common.models.api.base import BaseAPI -__all___ = [ - 'WebhookBodyAPI' -] +__all___ = ["WebhookBodyAPI"] class WebhookBodyAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { + "type": "object", + "properties": { # Holds actual webhook body - 'data': { - 'type': ['object', 'array'], - 'required': True - } + "data": {"type": ["object", "array"], "required": True} }, - 'additionalProperties': False + "additionalProperties": False, } diff --git a/st2common/st2common/models/base.py b/st2common/st2common/models/base.py index 342daf7028..35d5c884a7 100644 --- a/st2common/st2common/models/base.py +++ b/st2common/st2common/models/base.py @@ -17,9 +17,7 @@ Common model related classes. """ -__all__ = [ - 'DictSerializableClassMixin' -] +__all__ = ["DictSerializableClassMixin"] class DictSerializableClassMixin(object): diff --git a/st2common/st2common/models/db/__init__.py b/st2common/st2common/models/db/__init__.py index 4fd51b4f61..ee7261facd 100644 --- a/st2common/st2common/models/db/__init__.py +++ b/st2common/st2common/models/db/__init__.py @@ -40,32 +40,30 @@ LOG = logging.getLogger(__name__) MODEL_MODULE_NAMES = [ - 'st2common.models.db.auth', - 'st2common.models.db.action', - 'st2common.models.db.actionalias', - 'st2common.models.db.keyvalue', - 'st2common.models.db.execution', - 'st2common.models.db.executionstate', - 'st2common.models.db.execution_queue', - 'st2common.models.db.liveaction', - 'st2common.models.db.notification', - 'st2common.models.db.pack', - 'st2common.models.db.policy', - 'st2common.models.db.rbac', - 'st2common.models.db.rule', - 'st2common.models.db.rule_enforcement', - 'st2common.models.db.runner', - 'st2common.models.db.sensor', - 'st2common.models.db.trace', - 'st2common.models.db.trigger', - 'st2common.models.db.webhook', - 'st2common.models.db.workflow' + "st2common.models.db.auth", + "st2common.models.db.action", + "st2common.models.db.actionalias", + "st2common.models.db.keyvalue", + "st2common.models.db.execution", + "st2common.models.db.executionstate", + "st2common.models.db.execution_queue", + "st2common.models.db.liveaction", + "st2common.models.db.notification", + "st2common.models.db.pack", + "st2common.models.db.policy", + "st2common.models.db.rbac", + "st2common.models.db.rule", + "st2common.models.db.rule_enforcement", + "st2common.models.db.runner", + "st2common.models.db.sensor", + "st2common.models.db.trace", + "st2common.models.db.trigger", + "st2common.models.db.webhook", + "st2common.models.db.workflow", ] # A list of model names for which we don't perform extra index cleanup -INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [ - 'PermissionGrantDB' -] +INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = ["PermissionGrantDB"] # Reference to DB model classes used for db_ensure_indexes # NOTE: This variable is populated lazily inside get_model_classes() @@ -86,55 +84,78 @@ def get_model_classes(): result = [] for module_name in MODEL_MODULE_NAMES: module = importlib.import_module(module_name) - model_classes = getattr(module, 'MODELS', []) + model_classes = getattr(module, "MODELS", []) result.extend(model_classes) MODEL_CLASSES = result return MODEL_CLASSES -def _db_connect(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True): - - if '://' in db_host: +def _db_connect( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + if "://" in db_host: # Hostname is provided as a URI string. Make sure we don't log the password in case one is # included as part of the URI string. uri_dict = uri_parser.parse_uri(db_host) - username_string = uri_dict.get('username', username) or username + username_string = uri_dict.get("username", username) or username - if uri_dict.get('username', None) and username: + if uri_dict.get("username", None) and username: # Username argument has precedence over connection string username username_string = username hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict) - if len(uri_dict['nodelist']) > 1: - host_string = '%s (replica set)' % (hostnames) + if len(uri_dict["nodelist"]) > 1: + host_string = "%s (replica set)" % (hostnames) else: host_string = hostnames else: - host_string = '%s:%s' % (db_host, db_port) + host_string = "%s:%s" % (db_host, db_port) username_string = username - LOG.info('Connecting to database "%s" @ "%s" as user "%s".' % (db_name, host_string, - str(username_string))) - - ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + LOG.info( + 'Connecting to database "%s" @ "%s" as user "%s".' + % (db_name, host_string, str(username_string)) + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) # NOTE: We intentionally set "serverSelectionTimeoutMS" to 3 seconds. By default it's set to # 30 seconds, which means it will block up to 30 seconds and fail if there are any SSL related # or other errors connection_timeout = cfg.CONF.database.connection_timeout - connection = mongoengine.connection.connect(db_name, host=db_host, - port=db_port, tz_aware=True, - username=username, password=password, - connectTimeoutMS=connection_timeout, - serverSelectionTimeoutMS=connection_timeout, - **ssl_kwargs) + connection = mongoengine.connection.connect( + db_name, + host=db_host, + port=db_port, + tz_aware=True, + username=username, + password=password, + connectTimeoutMS=connection_timeout, + serverSelectionTimeoutMS=connection_timeout, + **ssl_kwargs, + ) # NOTE: Since pymongo 3.0, connect() method is lazy and not blocking (always returns success) # so we need to issue a command / query to check if connection has been @@ -142,32 +163,55 @@ def _db_connect(db_name, db_host, db_port, username=None, password=None, # See http://api.mongodb.com/python/current/api/pymongo/mongo_client.html for details try: # The ismaster command is cheap and does not require auth - connection.admin.command('ismaster') + connection.admin.command("ismaster") except (ConnectionFailure, ServerSelectionTimeoutError) as e: # NOTE: ServerSelectionTimeoutError can also be thrown if SSLHandShake fails in the server # Sadly the client doesn't include more information about the error so in such scenarios # user needs to check MongoDB server log - LOG.error('Failed to connect to database "%s" @ "%s" as user "%s": %s' % - (db_name, host_string, str(username_string), six.text_type(e))) + LOG.error( + 'Failed to connect to database "%s" @ "%s" as user "%s": %s' + % (db_name, host_string, str(username_string), six.text_type(e)) + ) raise e - LOG.info('Successfully connected to database "%s" @ "%s" as user "%s".' % ( - db_name, host_string, str(username_string))) + LOG.info( + 'Successfully connected to database "%s" @ "%s" as user "%s".' + % (db_name, host_string, str(username_string)) + ) return connection -def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True, - ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): - - connection = _db_connect(db_name, db_host, db_port, username=username, - password=password, ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) +def db_setup( + db_name, + db_host, + db_port, + username=None, + password=None, + ensure_indexes=True, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + connection = _db_connect( + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) # Create all the indexes upfront to prevent race-conditions caused by # lazy index creation @@ -192,7 +236,7 @@ def db_ensure_indexes(model_classes=None): ensured for all the models. :type model_classes: ``list`` """ - LOG.debug('Ensuring database indexes...') + LOG.debug("Ensuring database indexes...") if not model_classes: model_classes = get_model_classes() @@ -210,34 +254,44 @@ def db_ensure_indexes(model_classes=None): # Note: This condition would only be encountered when upgrading existing StackStorm # installation from MongoDB 3.2 to 3.4. msg = six.text_type(e) - if 'already exists with different options' in msg and 'uid_1' in msg: + if "already exists with different options" in msg and "uid_1" in msg: drop_obsolete_types_indexes(model_class=model_class) else: raise e except Exception as e: tb_msg = traceback.format_exc() - msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, six.text_type(e)) - msg += '\n\n' + tb_msg + msg = 'Failed to ensure indexes for model "%s": %s' % ( + class_name, + six.text_type(e), + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) if model_class.__name__ in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST: - LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name)) + LOG.debug( + 'Skipping index cleanup for blacklisted model "%s"...' % (class_name) + ) continue removed_count = cleanup_extra_indexes(model_class=model_class) if removed_count: - LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name)) + LOG.debug( + 'Removed "%s" extra indexes for model "%s"' + % (removed_count, class_name) + ) - LOG.debug('Indexes are ensured for models: %s' % - ', '.join(sorted((model_class.__name__ for model_class in model_classes)))) + LOG.debug( + "Indexes are ensured for models: %s" + % ", ".join(sorted((model_class.__name__ for model_class in model_classes))) + ) def cleanup_extra_indexes(model_class): """ Finds any extra indexes and removes those from mongodb. """ - extra_indexes = model_class.compare_indexes().get('extra', None) + extra_indexes = model_class.compare_indexes().get("extra", None) if not extra_indexes: return 0 @@ -248,10 +302,14 @@ def cleanup_extra_indexes(model_class): for extra_index in extra_indexes: try: c.drop_index(extra_index) - LOG.debug('Dropped index %s for model %s.', extra_index, model_class.__name__) + LOG.debug( + "Dropped index %s for model %s.", extra_index, model_class.__name__ + ) removed_count += 1 except OperationFailure: - LOG.warning('Attempt to cleanup index %s failed.', extra_index, exc_info=True) + LOG.warning( + "Attempt to cleanup index %s failed.", extra_index, exc_info=True + ) return removed_count @@ -266,14 +324,19 @@ def drop_obsolete_types_indexes(model_class): LOG.debug('Dropping obsolete types index for model "%s"' % (class_name)) collection = model_class._get_collection() - collection.update({}, {'$unset': {'_types': 1}}, multi=True) + collection.update({}, {"$unset": {"_types": 1}}, multi=True) info = collection.index_information() - indexes_to_drop = [key for key, value in six.iteritems(info) - if '_types' in dict(value['key']) or 'types' in value] + indexes_to_drop = [ + key + for key, value in six.iteritems(info) + if "_types" in dict(value["key"]) or "types" in value + ] - LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name, - str(indexes_to_drop))) + LOG.debug( + 'Will drop obsolete types indexes for model "%s": %s' + % (class_name, str(indexes_to_drop)) + ) for index in indexes_to_drop: collection.drop_index(index) @@ -286,57 +349,87 @@ def db_teardown(): mongoengine.connection.disconnect() -def db_cleanup(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): - - connection = _db_connect(db_name, db_host, db_port, username=username, - password=password, ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) - - LOG.info('Dropping database "%s" @ "%s:%s" as user "%s".', - db_name, db_host, db_port, str(username)) +def db_cleanup( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + connection = _db_connect( + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) + + LOG.info( + 'Dropping database "%s" @ "%s:%s" as user "%s".', + db_name, + db_host, + db_port, + str(username), + ) connection.drop_database(db_name) return connection -def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True): +def _get_ssl_kwargs( + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): # NOTE: In pymongo 3.9.0 some of the ssl related arguments have been renamed - # https://api.mongodb.com/python/current/changelog.html#changes-in-version-3-9-0 # Old names still work, but we should eventually update to new argument names. ssl_kwargs = { - 'ssl': ssl, + "ssl": ssl, } if ssl_keyfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_keyfile'] = ssl_keyfile + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_keyfile"] = ssl_keyfile if ssl_certfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_certfile'] = ssl_certfile + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_certfile"] = ssl_certfile if ssl_cert_reqs: - if ssl_cert_reqs == 'none': + if ssl_cert_reqs == "none": ssl_cert_reqs = ssl_lib.CERT_NONE - elif ssl_cert_reqs == 'optional': + elif ssl_cert_reqs == "optional": ssl_cert_reqs = ssl_lib.CERT_OPTIONAL - elif ssl_cert_reqs == 'required': + elif ssl_cert_reqs == "required": ssl_cert_reqs = ssl_lib.CERT_REQUIRED - ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs + ssl_kwargs["ssl_cert_reqs"] = ssl_cert_reqs if ssl_ca_certs: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_ca_certs"] = ssl_ca_certs if authentication_mechanism: - ssl_kwargs['ssl'] = True - ssl_kwargs['authentication_mechanism'] = authentication_mechanism - if ssl_kwargs.get('ssl', False): + ssl_kwargs["ssl"] = True + ssl_kwargs["authentication_mechanism"] = authentication_mechanism + if ssl_kwargs.get("ssl", False): # pass in ssl_match_hostname only if ssl is True. The right default value # for ssl_match_hostname in almost all cases is True. - ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname + ssl_kwargs["ssl_match_hostname"] = ssl_match_hostname return ssl_kwargs @@ -362,9 +455,9 @@ def get_by_pack(self, value): return self.get(pack=value, raise_exception=True) def get(self, *args, **kwargs): - exclude_fields = kwargs.pop('exclude_fields', None) - raise_exception = kwargs.pop('raise_exception', False) - only_fields = kwargs.pop('only_fields', None) + exclude_fields = kwargs.pop("exclude_fields", None) + raise_exception = kwargs.pop("raise_exception", False) + only_fields = kwargs.pop("only_fields", None) args = self._process_arg_filters(args) @@ -377,14 +470,17 @@ def get(self, *args, **kwargs): try: instances = instances.only(*only_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: - msg = ('Invalid or unsupported include attribute specified: %s' % six.text_type(e)) + msg = ( + "Invalid or unsupported include attribute specified: %s" + % six.text_type(e) + ) raise ValueError(msg) instance = instances[0] if instances else None log_query_and_profile_data_for_queryset(queryset=instances) if not instance and raise_exception: - msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs) + msg = "Unable to find the %s instance. %s" % (self.model.__name__, kwargs) raise db_exc.StackStormDBObjectNotFoundError(msg) return instance @@ -404,12 +500,12 @@ def count(self, *args, **kwargs): # **filters): def query(self, *args, **filters): # Python 2: Pop keyword parameters that aren't actually filters off of the kwargs - offset = filters.pop('offset', 0) - limit = filters.pop('limit', None) - order_by = filters.pop('order_by', None) - exclude_fields = filters.pop('exclude_fields', None) - only_fields = filters.pop('only_fields', None) - no_dereference = filters.pop('no_dereference', None) + offset = filters.pop("offset", 0) + limit = filters.pop("limit", None) + order_by = filters.pop("order_by", None) + exclude_fields = filters.pop("exclude_fields", None) + only_fields = filters.pop("only_fields", None) + no_dereference = filters.pop("no_dereference", None) order_by = order_by or [] exclude_fields = exclude_fields or [] @@ -419,7 +515,9 @@ def query(self, *args, **filters): # Process the filters # Note: Both of those functions manipulate "filters" variable so the order in which they # are called matters - filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by) + filters, order_by = self._process_datetime_range_filters( + filters=filters, order_by=order_by + ) filters = self._process_null_filters(filters=filters) result = self.model.objects(*args, **filters) @@ -429,7 +527,7 @@ def query(self, *args, **filters): result = result.exclude(*exclude_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: field = get_field_name_from_mongoengine_error(e) - msg = ('Invalid or unsupported exclude attribute specified: %s' % field) + msg = "Invalid or unsupported exclude attribute specified: %s" % field raise ValueError(msg) if only_fields: @@ -437,7 +535,7 @@ def query(self, *args, **filters): result = result.only(*only_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: field = get_field_name_from_mongoengine_error(e) - msg = ('Invalid or unsupported include attribute specified: %s' % field) + msg = "Invalid or unsupported include attribute specified: %s" % field raise ValueError(msg) if no_dereference: @@ -450,7 +548,7 @@ def query(self, *args, **filters): return result def distinct(self, *args, **kwargs): - field = kwargs.pop('field') + field = kwargs.pop("field") result = self.model.objects(**kwargs).distinct(field) log_query_and_profile_data_for_queryset(queryset=result) return result @@ -513,8 +611,10 @@ def _process_arg_filters(self, args): # Create a new QCombination object with the same operation and fixed filters _args += (visitor.QCombination(arg.operation, children),) else: - raise TypeError("Unknown argument type '%s' of argument '%s'" - % (type(arg), repr(arg))) + raise TypeError( + "Unknown argument type '%s' of argument '%s'" + % (type(arg), repr(arg)) + ) return _args @@ -526,35 +626,38 @@ def _process_null_filters(self, filters): for key, value in six.iteritems(filters): if value is None: null_filters[key] = value - elif isinstance(value, (str, six.text_type)) and value.lower() == 'null': + elif isinstance(value, (str, six.text_type)) and value.lower() == "null": null_filters[key] = value else: continue for key in null_filters.keys(): - result['%s__exists' % (key)] = False + result["%s__exists" % (key)] = False del result[key] return result def _process_datetime_range_filters(self, filters, order_by=None): - ranges = {k: v for k, v in six.iteritems(filters) - if type(v) in [str, six.text_type] and '..' in v} + ranges = { + k: v + for k, v in six.iteritems(filters) + if type(v) in [str, six.text_type] and ".." in v + } order_by_list = copy.deepcopy(order_by) if order_by else [] for k, v in six.iteritems(ranges): - values = v.split('..') + values = v.split("..") dt1 = isotime.parse(values[0]) dt2 = isotime.parse(values[1]) - k__gte = '%s__gte' % k - k__lte = '%s__lte' % k + k__gte = "%s__gte" % k + k__lte = "%s__lte" % k if dt1 < dt2: query = {k__gte: dt1, k__lte: dt2} - sort_key, reverse_sort_key = k, '-' + k + sort_key, reverse_sort_key = k, "-" + k else: query = {k__gte: dt2, k__lte: dt1} - sort_key, reverse_sort_key = '-' + k, k + sort_key, reverse_sort_key = "-" + k, k del filters[k] filters.update(query) @@ -569,7 +672,6 @@ def _process_datetime_range_filters(self, filters, order_by=None): class ChangeRevisionMongoDBAccess(MongoDBAccess): - def insert(self, instance): instance = self.model.objects.insert(instance) @@ -585,11 +687,11 @@ def update(self, instance, **kwargs): return self.save(instance) def save(self, instance, validate=True): - if not hasattr(instance, 'id') or not instance.id: + if not hasattr(instance, "id") or not instance.id: return self.insert(instance) else: try: - save_condition = {'id': instance.id, 'rev': instance.rev} + save_condition = {"id": instance.id, "rev": instance.rev} instance.rev = instance.rev + 1 instance.save(save_condition=save_condition, validate=validate) except mongoengine.SaveConditionError: @@ -601,8 +703,8 @@ def save(self, instance, validate=True): def get_host_names_for_uri_dict(uri_dict): hosts = [] - for host, port in uri_dict['nodelist']: - hosts.append('%s:%s' % (host, port)) + for host, port in uri_dict["nodelist"]: + hosts.append("%s:%s" % (host, port)) - hosts = ','.join(hosts) + hosts = ",".join(hosts) return hosts diff --git a/st2common/st2common/models/db/action.py b/st2common/st2common/models/db/action.py index 1c28b207c2..52a1ed0374 100644 --- a/st2common/st2common/models/db/action.py +++ b/st2common/st2common/models/db/action.py @@ -29,22 +29,26 @@ from st2common.constants.types import ResourceType __all__ = [ - 'RunnerTypeDB', - 'ActionDB', - 'LiveActionDB', - 'ActionExecutionDB', - 'ActionExecutionStateDB', - 'ActionAliasDB' + "RunnerTypeDB", + "ActionDB", + "LiveActionDB", + "ActionExecutionDB", + "ActionExecutionStateDB", + "ActionAliasDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." -class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin): +class ActionDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ The system entity that represents a Stack Action/Automation in the system. @@ -56,38 +60,46 @@ class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin, """ RESOURCE_TYPE = ResourceType.ACTION - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the action is enabled.') - entry_point = me.StringField( required=True, - help_text='The entry point to the action.') + default=True, + help_text="A flag indicating whether the action is enabled.", + ) + entry_point = me.StringField( + required=True, help_text="The entry point to the action." + ) pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) runner_type = me.DictField( - required=True, default={}, - help_text='The action runner to use for executing the action.') + required=True, + default={}, + help_text="The action runner to use for executing the action.", + ) parameters = stormbase.EscapedDynamicField( - help_text='The specification for parameters for the action.') + help_text="The specification for parameters for the action." + ) output_schema = stormbase.EscapedDynamicField( - help_text='The schema for output of the action.') + help_text="The schema for output of the action." + ) notify = me.EmbeddedDocumentField(NotificationSchema) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['pack']}, - {'fields': ['ref']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["pack"]}, + {"fields": ["ref"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -102,11 +114,17 @@ def is_workflow(self): :rtype: ``bool`` """ # pylint: disable=unsubscriptable-object - return self.runner_type['name'] in WORKFLOW_RUNNER_TYPES + return self.runner_type["name"] in WORKFLOW_RUNNER_TYPES # specialized access objects action_access = MongoDBAccess(ActionDB) -MODELS = [ActionDB, ActionExecutionDB, ActionExecutionStateDB, ActionAliasDB, - LiveActionDB, RunnerTypeDB] +MODELS = [ + ActionDB, + ActionExecutionDB, + ActionExecutionStateDB, + ActionAliasDB, + LiveActionDB, + RunnerTypeDB, +] diff --git a/st2common/st2common/models/db/actionalias.py b/st2common/st2common/models/db/actionalias.py index a696ff08b4..765630d8a4 100644 --- a/st2common/st2common/models/db/actionalias.py +++ b/st2common/st2common/models/db/actionalias.py @@ -21,18 +21,19 @@ from st2common.models.db import stormbase from st2common.constants.types import ResourceType -__all__ = [ - 'ActionAliasDB' -] +__all__ = ["ActionAliasDB"] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." -class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class ActionAliasDB( + stormbase.StormFoundationDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ Database entity that represent an Alias for an action. @@ -46,42 +47,48 @@ class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMi """ RESOURCE_TYPE = ResourceType.ACTION_ALIAS - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=True, - help_text='Name of the content pack.', - unique_with='name') + required=True, help_text="Name of the content pack.", unique_with="name" + ) enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the action alias is enabled.') - action_ref = me.StringField( required=True, - help_text='Reference of the Action map this alias.') + default=True, + help_text="A flag indicating whether the action alias is enabled.", + ) + action_ref = me.StringField( + required=True, help_text="Reference of the Action map this alias." + ) formats = me.ListField( - help_text='Possible parameter formats that an alias supports.') + help_text="Possible parameter formats that an alias supports." + ) ack = me.DictField( - help_text='Parameters pertaining to the acknowledgement message.' + help_text="Parameters pertaining to the acknowledgement message." ) result = me.DictField( - help_text='Parameters pertaining to the execution result message.' + help_text="Parameters pertaining to the execution result message." ) extra = me.DictField( - help_text='Additional parameters (usually adapter-specific) not covered in the schema.' + help_text="Additional parameters (usually adapter-specific) not covered in the schema." ) immutable_parameters = me.DictField( - help_text='Parameters to be passed to the action on every execution.') + help_text="Parameters to be passed to the action on every execution." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['enabled']}, - {'fields': ['formats']}, - ] + (stormbase.ContentPackResourceMixin().get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["enabled"]}, + {"fields": ["formats"]}, + ] + + ( + stormbase.ContentPackResourceMixin().get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -97,10 +104,12 @@ def get_format_strings(self): """ result = [] - formats = getattr(self, 'formats', []) + formats = getattr(self, "formats", []) for format_string in formats: - if isinstance(format_string, dict) and format_string.get('representation', None): - result.extend(format_string['representation']) + if isinstance(format_string, dict) and format_string.get( + "representation", None + ): + result.extend(format_string["representation"]) else: result.append(format_string) diff --git a/st2common/st2common/models/db/auth.py b/st2common/st2common/models/db/auth.py index 7ef30ee017..2531ecb11a 100644 --- a/st2common/st2common/models/db/auth.py +++ b/st2common/st2common/models/db/auth.py @@ -25,11 +25,7 @@ from st2common.rbac.backends import get_rbac_backend from st2common.util import date as date_utils -__all__ = [ - 'UserDB', - 'TokenDB', - 'ApiKeyDB' -] +__all__ = ["UserDB", "TokenDB", "ApiKeyDB"] class UserDB(stormbase.StormFoundationDB): @@ -42,10 +38,12 @@ class UserDB(stormbase.StormFoundationDB): is_service: True if this is a service account. nicknames: Nickname + origin pairs for ChatOps auth. """ + name = me.StringField(required=True, unique=True) is_service = me.BooleanField(required=True, default=False) - nicknames = me.DictField(required=False, - help_text='"Nickname + origin" pairs for ChatOps auth') + nicknames = me.DictField( + required=False, help_text='"Nickname + origin" pairs for ChatOps auth' + ) def get_roles(self, include_remote=True): """ @@ -57,7 +55,9 @@ def get_roles(self, include_remote=True): :rtype: ``list`` of :class:`RoleDB` """ rbac_service = get_rbac_backend().get_service_class() - result = rbac_service.get_roles_for_user(user_db=self, include_remote=include_remote) + result = rbac_service.get_roles_for_user( + user_db=self, include_remote=include_remote + ) return result def get_permission_assignments(self): @@ -75,11 +75,13 @@ class TokenDB(stormbase.StormFoundationDB): expiry: Date when this token expires. service: True if this is a service (system) token. """ + user = me.StringField(required=True) token = me.StringField(required=True, unique=True) expiry = me.DateTimeField(required=True) - metadata = me.DictField(required=False, - help_text='Arbitrary metadata associated with this token') + metadata = me.DictField( + required=False, help_text="Arbitrary metadata associated with this token" + ) service = me.BooleanField(required=True, default=False) @@ -91,23 +93,24 @@ class ApiKeyDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.API_KEY - UID_FIELDS = ['key_hash'] + UID_FIELDS = ["key_hash"] user = me.StringField(required=True) key_hash = me.StringField(required=True, unique=True) - metadata = me.DictField(required=False, - help_text='Arbitrary metadata associated with this token') - created_at = ComplexDateTimeField(default=date_utils.get_datetime_utc_now, - help_text='The creation time of this ApiKey.') - enabled = me.BooleanField(required=True, default=True, - help_text='A flag indicating whether the ApiKey is enabled.') - - meta = { - 'indexes': [ - {'fields': ['user']}, - {'fields': ['key_hash']} - ] - } + metadata = me.DictField( + required=False, help_text="Arbitrary metadata associated with this token" + ) + created_at = ComplexDateTimeField( + default=date_utils.get_datetime_utc_now, + help_text="The creation time of this ApiKey.", + ) + enabled = me.BooleanField( + required=True, + default=True, + help_text="A flag indicating whether the ApiKey is enabled.", + ) + + meta = {"indexes": [{"fields": ["user"]}, {"fields": ["key_hash"]}]} def __init__(self, *args, **values): super(ApiKeyDB, self).__init__(*args, **values) @@ -119,8 +122,8 @@ def mask_secrets(self, value): # In theory the key_hash is safe to return as it is one way. On the other # hand given that this is actually a secret no real point in letting the hash # escape. Since uid contains key_hash masking that as well. - result['key_hash'] = MASKED_ATTRIBUTE_VALUE - result['uid'] = MASKED_ATTRIBUTE_VALUE + result["key_hash"] = MASKED_ATTRIBUTE_VALUE + result["uid"] = MASKED_ATTRIBUTE_VALUE return result diff --git a/st2common/st2common/models/db/execution.py b/st2common/st2common/models/db/execution.py index 3e8f3c7742..a44e5072d6 100644 --- a/st2common/st2common/models/db/execution.py +++ b/st2common/st2common/models/db/execution.py @@ -27,10 +27,7 @@ from st2common.util.secrets import mask_secret_parameters from st2common.constants.types import ResourceType -__all__ = [ - 'ActionExecutionDB', - 'ActionExecutionOutputDB' -] +__all__ = ["ActionExecutionDB", "ActionExecutionOutputDB"] LOG = logging.getLogger(__name__) @@ -38,7 +35,7 @@ class ActionExecutionDB(stormbase.StormFoundationDB): RESOURCE_TYPE = ResourceType.EXECUTION - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] trigger = stormbase.EscapedDictField() trigger_type = stormbase.EscapedDictField() @@ -52,22 +49,25 @@ class ActionExecutionDB(stormbase.StormFoundationDB): workflow_execution = me.StringField() task_execution = me.StringField() status = me.StringField( - required=True, - help_text='The current status of the liveaction.') + required=True, help_text="The current status of the liveaction." + ) start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) end_timestamp = ComplexDateTimeField( - help_text='The timestamp when the liveaction has finished.') + help_text="The timestamp when the liveaction has finished." + ) parameters = stormbase.EscapedDynamicField( default={}, - help_text='The key-value pairs passed as to the action runner & action.') + help_text="The key-value pairs passed as to the action runner & action.", + ) result = stormbase.EscapedDynamicField( - default={}, - help_text='Action defined result.') + default={}, help_text="Action defined result." + ) context = me.DictField( - default={}, - help_text='Contextual information on the action execution.') + default={}, help_text="Contextual information on the action execution." + ) parent = me.StringField() children = me.ListField(field=me.StringField()) log = me.ListField(field=me.DictField()) @@ -76,49 +76,51 @@ class ActionExecutionDB(stormbase.StormFoundationDB): web_url = me.StringField(required=False) meta = { - 'indexes': [ - {'fields': ['rule.ref']}, - {'fields': ['action.ref']}, - {'fields': ['liveaction.id']}, - {'fields': ['start_timestamp']}, - {'fields': ['end_timestamp']}, - {'fields': ['status']}, - {'fields': ['parent']}, - {'fields': ['rule.name']}, - {'fields': ['runner.name']}, - {'fields': ['trigger.name']}, - {'fields': ['trigger_type.name']}, - {'fields': ['trigger_instance.id']}, - {'fields': ['context.user']}, - {'fields': ['-start_timestamp', 'action.ref', 'status']}, - {'fields': ['workflow_execution']}, - {'fields': ['task_execution']} + "indexes": [ + {"fields": ["rule.ref"]}, + {"fields": ["action.ref"]}, + {"fields": ["liveaction.id"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["end_timestamp"]}, + {"fields": ["status"]}, + {"fields": ["parent"]}, + {"fields": ["rule.name"]}, + {"fields": ["runner.name"]}, + {"fields": ["trigger.name"]}, + {"fields": ["trigger_type.name"]}, + {"fields": ["trigger_instance.id"]}, + {"fields": ["context.user"]}, + {"fields": ["-start_timestamp", "action.ref", "status"]}, + {"fields": ["workflow_execution"]}, + {"fields": ["task_execution"]}, ] } def get_uid(self): # TODO Construct id from non id field: uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=no-member - return ':'.join(uid) + return ":".join(uid) def mask_secrets(self, value): result = copy.deepcopy(value) - liveaction = result['liveaction'] + liveaction = result["liveaction"] parameters = {} # pylint: disable=no-member - parameters.update(value.get('action', {}).get('parameters', {})) - parameters.update(value.get('runner', {}).get('runner_parameters', {})) + parameters.update(value.get("action", {}).get("parameters", {})) + parameters.update(value.get("runner", {}).get("runner_parameters", {})) secret_parameters = get_secret_parameters(parameters=parameters) - result['parameters'] = mask_secret_parameters(parameters=result.get('parameters', {}), - secret_parameters=secret_parameters) + result["parameters"] = mask_secret_parameters( + parameters=result.get("parameters", {}), secret_parameters=secret_parameters + ) - if 'parameters' in liveaction: - liveaction['parameters'] = mask_secret_parameters(parameters=liveaction['parameters'], - secret_parameters=secret_parameters) + if "parameters" in liveaction: + liveaction["parameters"] = mask_secret_parameters( + parameters=liveaction["parameters"], secret_parameters=secret_parameters + ) - if liveaction.get('action', '') == 'st2.inquiry.respond': + if liveaction.get("action", "") == "st2.inquiry.respond": # Special case to mask parameters for `st2.inquiry.respond` action # In this case, this execution is just a plain python action, not # an inquiry, so we don't natively have a handle on the response @@ -130,22 +132,24 @@ def mask_secrets(self, value): # it's just a placeholder to tell mask_secret_parameters() # that this parameter is indeed a secret parameter and to # mask it. - result['parameters']['response'] = mask_secret_parameters( - parameters=liveaction['parameters']['response'], - secret_parameters={p: 'string' for p in liveaction['parameters']['response']} + result["parameters"]["response"] = mask_secret_parameters( + parameters=liveaction["parameters"]["response"], + secret_parameters={ + p: "string" for p in liveaction["parameters"]["response"] + }, ) # TODO(mierdin): This logic should be moved to the dedicated Inquiry # data model once it exists. - if self.runner.get('name') == "inquirer": + if self.runner.get("name") == "inquirer": - schema = result['result'].get('schema', {}) - response = result['result'].get('response', {}) + schema = result["result"].get("schema", {}) + response = result["result"].get("response", {}) # We can only mask response secrets if response and schema exist and are # not empty if response and schema: - result['result']['response'] = mask_inquiry_response(response, schema) + result["result"]["response"] = mask_inquiry_response(response, schema) return result def get_masked_parameters(self): @@ -155,7 +159,7 @@ def get_masked_parameters(self): :rtype: ``dict`` """ serializable_dict = self.to_serializable_dict(mask_secrets=True) - return serializable_dict['parameters'] + return serializable_dict["parameters"] class ActionExecutionOutputDB(stormbase.StormFoundationDB): @@ -174,22 +178,25 @@ class ActionExecutionOutputDB(stormbase.StormFoundationDB): data: Actual output data. This could either be line, chunk or similar, depending on the runner. """ + execution_id = me.StringField(required=True) action_ref = me.StringField(required=True) runner_ref = me.StringField(required=True) - timestamp = ComplexDateTimeField(required=True, default=date_utils.get_datetime_utc_now) - output_type = me.StringField(required=True, default='output') + timestamp = ComplexDateTimeField( + required=True, default=date_utils.get_datetime_utc_now + ) + output_type = me.StringField(required=True, default="output") delay = me.IntField() data = me.StringField() meta = { - 'indexes': [ - {'fields': ['execution_id']}, - {'fields': ['action_ref']}, - {'fields': ['runner_ref']}, - {'fields': ['timestamp']}, - {'fields': ['output_type']} + "indexes": [ + {"fields": ["execution_id"]}, + {"fields": ["action_ref"]}, + {"fields": ["runner_ref"]}, + {"fields": ["timestamp"]}, + {"fields": ["output_type"]}, ] } diff --git a/st2common/st2common/models/db/execution_queue.py b/st2common/st2common/models/db/execution_queue.py index 31dcebbd1a..8db0993363 100644 --- a/st2common/st2common/models/db/execution_queue.py +++ b/st2common/st2common/models/db/execution_queue.py @@ -25,15 +25,16 @@ from st2common.constants.types import ResourceType __all__ = [ - 'ActionExecutionSchedulingQueueItemDB', + "ActionExecutionSchedulingQueueItemDB", ] LOG = logging.getLogger(__name__) -class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB, - stormbase.ChangeRevisionFieldMixin): +class ActionExecutionSchedulingQueueItemDB( + stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin +): """ A model which represents a request for execution to be scheduled. @@ -42,36 +43,45 @@ class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB, """ RESOURCE_TYPE = ResourceType.EXECUTION_REQUEST - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] - liveaction_id = me.StringField(required=True, - help_text='Foreign key to the LiveActionDB which is to be scheduled') + liveaction_id = me.StringField( + required=True, + help_text="Foreign key to the LiveActionDB which is to be scheduled", + ) action_execution_id = me.StringField( - help_text='Foreign key to the ActionExecutionDB which is to be scheduled') + help_text="Foreign key to the ActionExecutionDB which is to be scheduled" + ) original_start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created and originally be scheduled to ' - 'run.') + help_text="The timestamp when the liveaction was created and originally be scheduled to " + "run.", + ) scheduled_start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when liveaction is scheduled to run.') + help_text="The timestamp when liveaction is scheduled to run.", + ) delay = me.IntField() - handling = me.BooleanField(default=False, - help_text='Flag indicating if this item is currently being handled / ' - 'processed by a scheduler service') + handling = me.BooleanField( + default=False, + help_text="Flag indicating if this item is currently being handled / " + "processed by a scheduler service", + ) meta = { - 'indexes': [ + "indexes": [ # NOTE: We limit index names to 65 characters total for compatibility with AWS # DocumentDB. # See https://github.com/StackStorm/st2/pull/4690 for details. - {'fields': ['action_execution_id'], 'name': 'ac_exc_id'}, - {'fields': ['liveaction_id'], 'name': 'lv_ac_id'}, - {'fields': ['original_start_timestamp'], 'name': 'orig_s_ts'}, - {'fields': ['scheduled_start_timestamp'], 'name': 'schd_s_ts'}, + {"fields": ["action_execution_id"], "name": "ac_exc_id"}, + {"fields": ["liveaction_id"], "name": "lv_ac_id"}, + {"fields": ["original_start_timestamp"], "name": "orig_s_ts"}, + {"fields": ["scheduled_start_timestamp"], "name": "schd_s_ts"}, ] } MODELS = [ActionExecutionSchedulingQueueItemDB] -EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess(ActionExecutionSchedulingQueueItemDB) +EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess( + ActionExecutionSchedulingQueueItemDB +) diff --git a/st2common/st2common/models/db/executionstate.py b/st2common/st2common/models/db/executionstate.py index db949b6658..94b883038d 100644 --- a/st2common/st2common/models/db/executionstate.py +++ b/st2common/st2common/models/db/executionstate.py @@ -21,33 +21,32 @@ from st2common.models.db import stormbase __all__ = [ - 'ActionExecutionStateDB', + "ActionExecutionStateDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class ActionExecutionStateDB(stormbase.StormFoundationDB): """ - Database entity that represents the state of Action execution. + Database entity that represents the state of Action execution. """ + execution_id = me.ObjectIdField( - required=True, - unique=True, - help_text='liveaction ID.') + required=True, unique=True, help_text="liveaction ID." + ) query_module = me.StringField( - required=True, - help_text='Reference to the runner model.') + required=True, help_text="Reference to the runner model." + ) query_context = me.DictField( required=True, - help_text='Context about the action execution that is needed for results query.') + help_text="Context about the action execution that is needed for results query.", + ) - meta = { - 'indexes': ['query_module'] - } + meta = {"indexes": ["query_module"]} # specialized access objects diff --git a/st2common/st2common/models/db/keyvalue.py b/st2common/st2common/models/db/keyvalue.py index debe58ebbb..ea7fda3b9d 100644 --- a/st2common/st2common/models/db/keyvalue.py +++ b/st2common/st2common/models/db/keyvalue.py @@ -21,9 +21,7 @@ from st2common.models.db import MongoDBAccess from st2common.models.db import stormbase -__all__ = [ - 'KeyValuePairDB' -] +__all__ = ["KeyValuePairDB"] class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -34,22 +32,20 @@ class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.KEY_VALUE_PAIR - UID_FIELDS = ['scope', 'name'] + UID_FIELDS = ["scope", "name"] - scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with='name') + scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with="name") name = me.StringField(required=True) value = me.StringField() secret = me.BooleanField(default=False) expire_timestamp = me.DateTimeField() meta = { - 'indexes': [ - {'fields': ['name']}, - { - 'fields': ['expire_timestamp'], - 'expireAfterSeconds': 0 - } - ] + stormbase.UIDFieldMixin.get_indexes() + "indexes": [ + {"fields": ["name"]}, + {"fields": ["expire_timestamp"], "expireAfterSeconds": 0}, + ] + + stormbase.UIDFieldMixin.get_indexes() } def __init__(self, *args, **values): diff --git a/st2common/st2common/models/db/liveaction.py b/st2common/st2common/models/db/liveaction.py index 6bc5fd77fa..29f5a13bfc 100644 --- a/st2common/st2common/models/db/liveaction.py +++ b/st2common/st2common/models/db/liveaction.py @@ -28,12 +28,12 @@ from st2common.util.secrets import mask_secret_parameters __all__ = [ - 'LiveActionDB', + "LiveActionDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class LiveActionDB(stormbase.StormFoundationDB): @@ -41,50 +41,56 @@ class LiveActionDB(stormbase.StormFoundationDB): task_execution = me.StringField() # TODO: Can status be an enum at the Mongo layer? status = me.StringField( - required=True, - help_text='The current status of the liveaction.') + required=True, help_text="The current status of the liveaction." + ) start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) end_timestamp = ComplexDateTimeField( - help_text='The timestamp when the liveaction has finished.') + help_text="The timestamp when the liveaction has finished." + ) action = me.StringField( - required=True, - help_text='Reference to the action that has to be executed.') + required=True, help_text="Reference to the action that has to be executed." + ) action_is_workflow = me.BooleanField( default=False, - help_text='A flag indicating whether the referenced action is a workflow.') + help_text="A flag indicating whether the referenced action is a workflow.", + ) parameters = stormbase.EscapedDynamicField( default={}, - help_text='The key-value pairs passed as to the action runner & execution.') + help_text="The key-value pairs passed as to the action runner & execution.", + ) result = stormbase.EscapedDynamicField( - default={}, - help_text='Action defined result.') + default={}, help_text="Action defined result." + ) context = me.DictField( - default={}, - help_text='Contextual information on the action execution.') + default={}, help_text="Contextual information on the action execution." + ) callback = me.DictField( default={}, - help_text='Callback information for the on completion of action execution.') + help_text="Callback information for the on completion of action execution.", + ) runner_info = me.DictField( default={}, - help_text='Information about the runner which executed this live action (hostname, pid).') + help_text="Information about the runner which executed this live action (hostname, pid).", + ) notify = me.EmbeddedDocumentField(NotificationSchema) delay = me.IntField( min_value=0, - help_text='How long (in milliseconds) to delay the execution before scheduling.' + help_text="How long (in milliseconds) to delay the execution before scheduling.", ) meta = { - 'indexes': [ - {'fields': ['-start_timestamp', 'action']}, - {'fields': ['start_timestamp']}, - {'fields': ['end_timestamp']}, - {'fields': ['action']}, - {'fields': ['status']}, - {'fields': ['context.trigger_instance.id']}, - {'fields': ['workflow_execution']}, - {'fields': ['task_execution']} + "indexes": [ + {"fields": ["-start_timestamp", "action"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["end_timestamp"]}, + {"fields": ["action"]}, + {"fields": ["status"]}, + {"fields": ["context.trigger_instance.id"]}, + {"fields": ["workflow_execution"]}, + {"fields": ["task_execution"]}, ] } @@ -92,7 +98,7 @@ def mask_secrets(self, value): from st2common.util import action_db result = copy.deepcopy(value) - execution_parameters = value['parameters'] + execution_parameters = value["parameters"] # TODO: This results into two DB looks, we should cache action and runner type object # for each liveaction... @@ -104,8 +110,9 @@ def mask_secrets(self, value): parameters = action_db.get_action_parameters_specs(action_ref=self.action) secret_parameters = get_secret_parameters(parameters=parameters) - result['parameters'] = mask_secret_parameters(parameters=execution_parameters, - secret_parameters=secret_parameters) + result["parameters"] = mask_secret_parameters( + parameters=execution_parameters, secret_parameters=secret_parameters + ) return result def get_masked_parameters(self): @@ -115,7 +122,7 @@ def get_masked_parameters(self): :rtype: ``dict`` """ serializable_dict = self.to_serializable_dict(mask_secrets=True) - return serializable_dict['parameters'] + return serializable_dict["parameters"] # specialized access objects diff --git a/st2common/st2common/models/db/marker.py b/st2common/st2common/models/db/marker.py index 1bddf3f604..7a053e5490 100644 --- a/st2common/st2common/models/db/marker.py +++ b/st2common/st2common/models/db/marker.py @@ -20,10 +20,7 @@ from st2common.models.db import stormbase from st2common.util import date as date_utils -__all__ = [ - 'MarkerDB', - 'DumperMarkerDB' -] +__all__ = ["MarkerDB", "DumperMarkerDB"] class MarkerDB(stormbase.StormFoundationDB): @@ -37,20 +34,21 @@ class MarkerDB(stormbase.StormFoundationDB): :param updated_at: Timestamp when marker was updated. :type updated_at: ``datetime.datetime`` """ + marker = me.StringField(required=True) updated_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) - meta = { - 'abstract': True - } + meta = {"abstract": True} class DumperMarkerDB(MarkerDB): """ Marker model used by Dumper (in exporter). """ + pass diff --git a/st2common/st2common/models/db/notification.py b/st2common/st2common/models/db/notification.py index e311f46b75..8ef793887b 100644 --- a/st2common/st2common/models/db/notification.py +++ b/st2common/st2common/models/db/notification.py @@ -21,43 +21,47 @@ class NotificationSubSchema(me.EmbeddedDocument): """ - Schema for notification settings to be specified for action success/failure. + Schema for notification settings to be specified for action success/failure. """ + message = me.StringField() data = stormbase.EscapedDynamicField( - default={}, - help_text='Payload to be sent as part of notification.') + default={}, help_text="Payload to be sent as part of notification." + ) routes = me.ListField( - default=['notify.default'], - help_text='Routes to post notifications to.') - channels = me.ListField( # Deprecated. Only here for backward compatibility reasons. - default=['notify.default'], - help_text='Routes to post notifications to.') + default=["notify.default"], help_text="Routes to post notifications to." + ) + channels = ( + me.ListField( # Deprecated. Only here for backward compatibility reasons. + default=["notify.default"], help_text="Routes to post notifications to." + ) + ) def __str__(self): result = [] - result.append('NotificationSubSchema@') + result.append("NotificationSubSchema@") result.append(str(id(self))) result.append('(message="%s", ' % str(self.message)) result.append('data="%s", ' % str(self.data)) result.append('routes="%s", ' % str(self.routes)) result.append('[**deprecated**]channels="%s")' % str(self.channels)) - return ''.join(result) + return "".join(result) class NotificationSchema(me.EmbeddedDocument): """ - Schema for notification settings to be specified for actions. + Schema for notification settings to be specified for actions. """ + on_success = me.EmbeddedDocumentField(NotificationSubSchema) on_failure = me.EmbeddedDocumentField(NotificationSubSchema) on_complete = me.EmbeddedDocumentField(NotificationSubSchema) def __str__(self): result = [] - result.append('NotifySchema@') + result.append("NotifySchema@") result.append(str(id(self))) result.append('(on_complete="%s", ' % str(self.on_complete)) result.append('on_success="%s", ' % str(self.on_success)) result.append('on_failure="%s")' % str(self.on_failure)) - return ''.join(result) + return "".join(result) diff --git a/st2common/st2common/models/db/pack.py b/st2common/st2common/models/db/pack.py index cf16910987..c92b009624 100644 --- a/st2common/st2common/models/db/pack.py +++ b/st2common/st2common/models/db/pack.py @@ -25,21 +25,16 @@ from st2common.util.secrets import get_secret_parameters from st2common.util.secrets import mask_secret_parameters -__all__ = [ - 'PackDB', - 'ConfigSchemaDB', - 'ConfigDB' -] +__all__ = ["PackDB", "ConfigSchemaDB", "ConfigDB"] -class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, - me.DynamicDocument): +class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, me.DynamicDocument): """ System entity which represents a pack. """ RESOURCE_TYPE = ResourceType.PACK - UID_FIELDS = ['ref'] + UID_FIELDS = ["ref"] ref = me.StringField(required=True, unique=True) name = me.StringField(required=True, unique=True) @@ -56,9 +51,7 @@ class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, dependencies = me.ListField(field=me.StringField()) system = me.DictField() - meta = { - 'indexes': stormbase.UIDFieldMixin.get_indexes() - } + meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()} def __init__(self, *args, **values): super(PackDB, self).__init__(*args, **values) @@ -73,22 +66,24 @@ class ConfigSchemaDB(stormbase.StormFoundationDB): pack = me.StringField( required=True, unique=True, - help_text='Name of the content pack this schema belongs to.') + help_text="Name of the content pack this schema belongs to.", + ) attributes = stormbase.EscapedDynamicField( - help_text='The specification for config schema attributes.') + help_text="The specification for config schema attributes." + ) class ConfigDB(stormbase.StormFoundationDB): """ System entity representing pack config. """ + pack = me.StringField( required=True, unique=True, - help_text='Name of the content pack this config belongs to.') - values = stormbase.EscapedDynamicField( - help_text='Config values.', - default={}) + help_text="Name of the content pack this config belongs to.", + ) + values = stormbase.EscapedDynamicField(help_text="Config values.", default={}) def mask_secrets(self, value): """ @@ -101,11 +96,12 @@ def mask_secrets(self, value): """ result = copy.deepcopy(value) - config_schema = config_schema_access.get_by_pack(result['pack']) + config_schema = config_schema_access.get_by_pack(result["pack"]) secret_parameters = get_secret_parameters(parameters=config_schema.attributes) - result['values'] = mask_secret_parameters(parameters=result['values'], - secret_parameters=secret_parameters) + result["values"] = mask_secret_parameters( + parameters=result["values"], secret_parameters=secret_parameters + ) return result diff --git a/st2common/st2common/models/db/policy.py b/st2common/st2common/models/db/policy.py index 69f709093c..8b9fcafef0 100644 --- a/st2common/st2common/models/db/policy.py +++ b/st2common/st2common/models/db/policy.py @@ -23,9 +23,7 @@ from st2common.constants.types import ResourceType -__all__ = ['PolicyTypeReference', - 'PolicyTypeDB', - 'PolicyDB'] +__all__ = ["PolicyTypeReference", "PolicyTypeDB", "PolicyDB"] LOG = logging.getLogger(__name__) @@ -34,7 +32,8 @@ class PolicyTypeReference(object): """ Class used for referring to policy types which belong to a resource type. """ - separator = '.' + + separator = "." def __init__(self, resource_type=None, name=None): self.resource_type = self.validate_resource_type(resource_type) @@ -54,14 +53,15 @@ def is_reference(cls, ref): @classmethod def from_string_reference(cls, ref): - return cls(resource_type=cls.get_resource_type(ref), - name=cls.get_name(ref)) + return cls(resource_type=cls.get_resource_type(ref), name=cls.get_name(ref)) @classmethod def to_string_reference(cls, resource_type=None, name=None): if not resource_type or not name: - raise ValueError('Both resource_type and name are required for building ref. ' - 'resource_type=%s, name=%s' % (resource_type, name)) + raise ValueError( + "Both resource_type and name are required for building ref. " + "resource_type=%s, name=%s" % (resource_type, name) + ) resource_type = cls.validate_resource_type(resource_type) return cls.separator.join([resource_type, name]) @@ -69,7 +69,7 @@ def to_string_reference(cls, resource_type=None, name=None): @classmethod def validate_resource_type(cls, resource_type): if not resource_type: - raise ValueError('Resource type should not be empty.') + raise ValueError("Resource type should not be empty.") if cls.separator in resource_type: raise ValueError('Resource type should not contain "%s".' % cls.separator) @@ -80,7 +80,7 @@ def validate_resource_type(cls, resource_type): def get_resource_type(cls, ref): try: if not cls.is_reference(ref): - raise ValueError('%s is not a valid reference.' % ref) + raise ValueError("%s is not a valid reference." % ref) return ref.split(cls.separator, 1)[0] except (ValueError, IndexError, AttributeError): @@ -90,15 +90,19 @@ def get_resource_type(cls, ref): def get_name(cls, ref): try: if not cls.is_reference(ref): - raise ValueError('%s is not a valid reference.' % ref) + raise ValueError("%s is not a valid reference." % ref) return ref.split(cls.separator, 1)[1] except (ValueError, IndexError, AttributeError): raise common_models.InvalidReferenceError(ref=ref) def __repr__(self): - return ('<%s resource_type=%s,name=%s,ref=%s>' % - (self.__class__.__name__, self.resource_type, self.name, self.ref)) + return "<%s resource_type=%s,name=%s,ref=%s>" % ( + self.__class__.__name__, + self.resource_type, + self.name, + self.ref, + ) class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -114,29 +118,35 @@ class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): module: The python module that implements the policy for this type. parameters: The specification for parameters for the policy type. """ + RESOURCE_TYPE = ResourceType.POLICY_TYPE - UID_FIELDS = ['resource_type', 'name'] + UID_FIELDS = ["resource_type", "name"] ref = me.StringField(required=True) resource_type = me.StringField( required=True, - unique_with='name', - help_text='The type of resource that this policy type can be applied to.') + unique_with="name", + help_text="The type of resource that this policy type can be applied to.", + ) enabled = me.BooleanField( required=True, default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + help_text="A flag indicating whether the runner for this type is enabled.", + ) module = me.StringField( required=True, - help_text='The python module that implements the policy for this type.') + help_text="The python module that implements the policy for this type.", + ) parameters = me.DictField( - help_text='The specification for parameters for the policy type.') + help_text="The specification for parameters for the policy type." + ) def __init__(self, *args, **kwargs): super(PolicyTypeDB, self).__init__(*args, **kwargs) self.uid = self.get_uid() - self.ref = PolicyTypeReference.to_string_reference(resource_type=self.resource_type, - name=self.name) + self.ref = PolicyTypeReference.to_string_reference( + resource_type=self.resource_type, name=self.name + ) def get_reference(self): """ @@ -147,8 +157,11 @@ def get_reference(self): return PolicyTypeReference(resource_type=self.resource_type, name=self.name) -class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class PolicyDB( + stormbase.StormFoundationDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ The representation for a policy in the system. @@ -158,43 +171,47 @@ class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, policy_type: The type of policy. parameters: The specification of input parameters for the policy. """ + RESOURCE_TYPE = ResourceType.POLICY - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) pack = me.StringField( required=False, default=pack_constants.DEFAULT_PACK_NAME, - unique_with='name', - help_text='Name of the content pack.') + unique_with="name", + help_text="Name of the content pack.", + ) description = me.StringField() enabled = me.BooleanField( required=True, default=True, - help_text='A flag indicating whether this policy is enabled in the system.') + help_text="A flag indicating whether this policy is enabled in the system.", + ) resource_ref = me.StringField( - required=True, - help_text='The resource that this policy is applied to.') + required=True, help_text="The resource that this policy is applied to." + ) policy_type = me.StringField( - required=True, - unique_with='resource_ref', - help_text='The type of policy.') + required=True, unique_with="resource_ref", help_text="The type of policy." + ) parameters = me.DictField( - help_text='The specification of input parameters for the policy.') + help_text="The specification of input parameters for the policy." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['resource_ref']}, + "indexes": [ + {"fields": ["name"]}, + {"fields": ["resource_ref"]}, ] } def __init__(self, *args, **kwargs): super(PolicyDB, self).__init__(*args, **kwargs) self.uid = self.get_uid() - self.ref = common_models.ResourceReference.to_string_reference(pack=self.pack, - name=self.name) + self.ref = common_models.ResourceReference.to_string_reference( + pack=self.pack, name=self.name + ) MODELS = [PolicyTypeDB, PolicyDB] diff --git a/st2common/st2common/models/db/rbac.py b/st2common/st2common/models/db/rbac.py index 68b41ea314..bb82ba88cb 100644 --- a/st2common/st2common/models/db/rbac.py +++ b/st2common/st2common/models/db/rbac.py @@ -21,14 +21,13 @@ __all__ = [ - 'RoleDB', - 'UserRoleAssignmentDB', - 'PermissionGrantDB', - 'GroupToRoleMappingDB', - - 'role_access', - 'user_role_assignment_access', - 'permission_grant_access' + "RoleDB", + "UserRoleAssignmentDB", + "PermissionGrantDB", + "GroupToRoleMappingDB", + "role_access", + "user_role_assignment_access", + "permission_grant_access", ] @@ -43,15 +42,16 @@ class RoleDB(stormbase.StormFoundationDB): permission_grants: A list of IDs to the permission grant which apply to this role. """ + name = me.StringField(required=True, unique=True) description = me.StringField() system = me.BooleanField(default=False) permission_grants = me.ListField(field=me.StringField()) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['system']}, + "indexes": [ + {"fields": ["name"]}, + {"fields": ["system"]}, ] } @@ -67,9 +67,10 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB): and "API" for API assignments. description: Optional assigment description. """ + user = me.StringField(required=True) - role = me.StringField(required=True, unique_with=['user', 'source']) - source = me.StringField(required=True, unique_with=['user', 'role']) + role = me.StringField(required=True, unique_with=["user", "source"]) + source = me.StringField(required=True, unique_with=["user", "role"]) description = me.StringField() # True if this is assigned created on authentication based on the remote groups provided by # the auth backends. @@ -78,12 +79,12 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB): is_remote = me.BooleanField(default=False) meta = { - 'indexes': [ - {'fields': ['user']}, - {'fields': ['role']}, - {'fields': ['source']}, - {'fields': ['is_remote']}, - {'fields': ['user', 'role']}, + "indexes": [ + {"fields": ["user"]}, + {"fields": ["role"]}, + {"fields": ["source"]}, + {"fields": ["is_remote"]}, + {"fields": ["user", "role"]}, ] } @@ -98,13 +99,14 @@ class PermissionGrantDB(stormbase.StormFoundationDB): convenience and to allow for more efficient queries. permission_types: A list of permission type granted to that resources. """ + resource_uid = me.StringField(required=False) resource_type = me.StringField(required=False) permission_types = me.ListField(field=me.StringField()) meta = { - 'indexes': [ - {'fields': ['resource_uid']}, + "indexes": [ + {"fields": ["resource_uid"]}, ] } @@ -120,12 +122,16 @@ class GroupToRoleMappingDB(stormbase.StormFoundationDB): and "API" for API assignments. description: Optional description for this mapping. """ + group = me.StringField(required=True, unique=True) roles = me.ListField(field=me.StringField()) source = me.StringField() description = me.StringField() - enabled = me.BooleanField(required=True, default=True, - help_text='A flag indicating whether the mapping is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="A flag indicating whether the mapping is enabled.", + ) # Specialized access objects diff --git a/st2common/st2common/models/db/reactor.py b/st2common/st2common/models/db/reactor.py index dc9f08b58e..8b8032654b 100644 --- a/st2common/st2common/models/db/reactor.py +++ b/st2common/st2common/models/db/reactor.py @@ -14,18 +14,17 @@ # limitations under the License. from __future__ import absolute_import -from st2common.models.db.rule import (ActionExecutionSpecDB, RuleDB) +from st2common.models.db.rule import ActionExecutionSpecDB, RuleDB from st2common.models.db.sensor import SensorTypeDB -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB __all__ = [ - 'ActionExecutionSpecDB', - 'RuleDB', - 'SensorTypeDB', - 'TriggerTypeDB', - 'TriggerDB', - 'TriggerInstanceDB' + "ActionExecutionSpecDB", + "RuleDB", + "SensorTypeDB", + "TriggerTypeDB", + "TriggerDB", + "TriggerInstanceDB", ] -MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB, - TriggerTypeDB] +MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB, TriggerTypeDB] diff --git a/st2common/st2common/models/db/rule.py b/st2common/st2common/models/db/rule.py index f056734f8c..f4f26ec669 100644 --- a/st2common/st2common/models/db/rule.py +++ b/st2common/st2common/models/db/rule.py @@ -28,25 +28,24 @@ class RuleTypeDB(stormbase.StormBaseDB): enabled = me.BooleanField( default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + help_text="A flag indicating whether the runner for this type is enabled.", + ) parameters = me.DictField( - help_text='The specification for parameters for the action.', - default={}) + help_text="The specification for parameters for the action.", default={} + ) class RuleTypeSpecDB(me.EmbeddedDocument): - ref = me.StringField(unique=False, - help_text='Type of rule.', - default='standard') + ref = me.StringField(unique=False, help_text="Type of rule.", default="standard") parameters = me.DictField(default={}) def __str__(self): result = [] - result.append('RuleTypeSpecDB@') + result.append("RuleTypeSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) class ActionExecutionSpecDB(me.EmbeddedDocument): @@ -55,15 +54,19 @@ class ActionExecutionSpecDB(me.EmbeddedDocument): def __str__(self): result = [] - result.append('ActionExecutionSpecDB@') + result.append("ActionExecutionSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) -class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin): +class RuleDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """Specifies the action to invoke on the occurrence of a Trigger. It also includes the transformation to perform to match the impedance between the payload of a TriggerInstance and input of a action. @@ -74,36 +77,39 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + RESOURCE_TYPE = ResourceType.RULE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) type = me.EmbeddedDocumentField(RuleTypeSpecDB, default=RuleTypeSpecDB()) trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - context = me.DictField( - default={}, - help_text='Contextual info on the rule' + context = me.DictField(default={}, help_text="Contextual info on the rule") + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", ) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') meta = { - 'indexes': [ - {'fields': ['enabled']}, - {'fields': ['action.ref']}, - {'fields': ['trigger']}, - {'fields': ['context.user']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["enabled"]}, + {"fields": ["action.ref"]}, + {"fields": ["trigger"]}, + {"fields": ["context.user"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def mask_secrets(self, value): @@ -120,7 +126,7 @@ def mask_secrets(self, value): """ result = copy.deepcopy(value) - action_ref = result.get('action', {}).get('ref', None) + action_ref = result.get("action", {}).get("ref", None) if not action_ref: return result @@ -131,9 +137,10 @@ def mask_secrets(self, value): return result secret_parameters = get_secret_parameters(parameters=action_db.parameters) - result['action']['parameters'] = mask_secret_parameters( - parameters=result['action']['parameters'], - secret_parameters=secret_parameters) + result["action"]["parameters"] = mask_secret_parameters( + parameters=result["action"]["parameters"], + secret_parameters=secret_parameters, + ) return result @@ -147,8 +154,9 @@ def _get_referenced_action_model(self, action_ref): :rtype: ``ActionDB`` """ # NOTE: We need to retrieve pack and name since that's needed for the PK - action_dbs = Action.query(only_fields=['pack', 'ref', 'name', 'parameters'], - ref=action_ref, limit=1) + action_dbs = Action.query( + only_fields=["pack", "ref", "name", "parameters"], ref=action_ref, limit=1 + ) if action_dbs: return action_dbs[0] diff --git a/st2common/st2common/models/db/rule_enforcement.py b/st2common/st2common/models/db/rule_enforcement.py index 80ea1f14fe..62d2a21faf 100644 --- a/st2common/st2common/models/db/rule_enforcement.py +++ b/st2common/st2common/models/db/rule_enforcement.py @@ -24,34 +24,27 @@ from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_SUCCEEDED from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_FAILED -__all__ = [ - 'RuleReferenceSpecDB', - 'RuleEnforcementDB' -] +__all__ = ["RuleReferenceSpecDB", "RuleEnforcementDB"] class RuleReferenceSpecDB(me.EmbeddedDocument): - ref = me.StringField(unique=False, - help_text='Reference to rule.', - required=True) - id = me.StringField(required=False, - help_text='Rule ID.') - uid = me.StringField(required=True, - help_text='Rule UID.') + ref = me.StringField(unique=False, help_text="Reference to rule.", required=True) + id = me.StringField(required=False, help_text="Rule ID.") + uid = me.StringField(required=True, help_text="Rule UID.") def __str__(self): result = [] - result.append('RuleReferenceSpecDB@') + result.append("RuleReferenceSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('id="%s", ' % self.id) result.append('uid="%s")' % self.uid) - return ''.join(result) + return "".join(result) class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin): - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] trigger_instance_id = me.StringField(required=True) execution_id = me.StringField(required=False) @@ -59,31 +52,34 @@ class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin): rule = me.EmbeddedDocumentField(RuleReferenceSpecDB, required=True) enforced_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the rule enforcement happened.') + help_text="The timestamp when the rule enforcement happened.", + ) status = me.StringField( required=True, default=RULE_ENFORCEMENT_STATUS_SUCCEEDED, - help_text='Rule enforcement status.') + help_text="Rule enforcement status.", + ) meta = { - 'indexes': [ - {'fields': ['trigger_instance_id']}, - {'fields': ['execution_id']}, - {'fields': ['rule.id']}, - {'fields': ['rule.ref']}, - {'fields': ['enforced_at']}, - {'fields': ['-enforced_at']}, - {'fields': ['-enforced_at', 'rule.ref']}, - {'fields': ['status']}, - ] + stormbase.TagsMixin.get_indexes() + "indexes": [ + {"fields": ["trigger_instance_id"]}, + {"fields": ["execution_id"]}, + {"fields": ["rule.id"]}, + {"fields": ["rule.ref"]}, + {"fields": ["enforced_at"]}, + {"fields": ["-enforced_at"]}, + {"fields": ["-enforced_at", "rule.ref"]}, + {"fields": ["status"]}, + ] + + stormbase.TagsMixin.get_indexes() } def __init__(self, *args, **values): super(RuleEnforcementDB, self).__init__(*args, **values) # Set status to succeeded for old / existing RuleEnforcementDB which predate status field - status = getattr(self, 'status', None) - failure_reason = getattr(self, 'failure_reason', None) + status = getattr(self, "status", None) + failure_reason = getattr(self, "failure_reason", None) if status in [None, RULE_ENFORCEMENT_STATUS_SUCCEEDED] and failure_reason: self.status = RULE_ENFORCEMENT_STATUS_FAILED @@ -92,8 +88,8 @@ def __init__(self, *args, **values): # with a consistent get_uid interface. def get_uid(self): # TODO Construct uid from non id field: - uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101 - return ':'.join(uid) + uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101 + return ":".join(uid) rule_enforcement_access = MongoDBAccess(RuleEnforcementDB) diff --git a/st2common/st2common/models/db/runner.py b/st2common/st2common/models/db/runner.py index c2f290f5b4..9097d35be6 100644 --- a/st2common/st2common/models/db/runner.py +++ b/st2common/st2common/models/db/runner.py @@ -22,13 +22,13 @@ from st2common.constants.types import ResourceType __all__ = [ - 'RunnerTypeDB', + "RunnerTypeDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -46,31 +46,37 @@ class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.RUNNER_TYPE - UID_FIELDS = ['name'] + UID_FIELDS = ["name"] enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + required=True, + default=True, + help_text="A flag indicating whether the runner for this type is enabled.", + ) runner_package = me.StringField( required=False, - help_text=('The python package that implements the action runner for this type. If' - 'not provided it assumes package name equals module name.')) + help_text=( + "The python package that implements the action runner for this type. If" + "not provided it assumes package name equals module name." + ), + ) runner_module = me.StringField( required=True, - help_text='The python module that implements the action runner for this type.') + help_text="The python module that implements the action runner for this type.", + ) runner_parameters = me.DictField( - help_text='The specification for parameters for the action runner.') + help_text="The specification for parameters for the action runner." + ) output_key = me.StringField( - help_text='Default key to expect results to be published to.') - output_schema = me.DictField( - help_text='The schema for runner output.') + help_text="Default key to expect results to be published to." + ) + output_schema = me.DictField(help_text="The schema for runner output.") query_module = me.StringField( required=False, - help_text='The python module that implements the query module for this runner.') + help_text="The python module that implements the query module for this runner.", + ) - meta = { - 'indexes': stormbase.UIDFieldMixin.get_indexes() - } + meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()} def __init__(self, *args, **values): super(RunnerTypeDB, self).__init__(*args, **values) diff --git a/st2common/st2common/models/db/sensor.py b/st2common/st2common/models/db/sensor.py index 6517fb3a75..31437ad321 100644 --- a/st2common/st2common/models/db/sensor.py +++ b/st2common/st2common/models/db/sensor.py @@ -20,13 +20,12 @@ from st2common.models.db import stormbase from st2common.constants.types import ResourceType -__all__ = [ - 'SensorTypeDB' -] +__all__ = ["SensorTypeDB"] -class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class SensorTypeDB( + stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin +): """ Description of a specific type of a sensor (think of it as a sensor template). @@ -40,25 +39,29 @@ class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, """ RESOURCE_TYPE = ResourceType.SENSOR_TYPE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") artifact_uri = me.StringField() entry_point = me.StringField() trigger_types = me.ListField(field=me.StringField()) poll_interval = me.IntField() - enabled = me.BooleanField(default=True, - help_text=u'Flag indicating whether the sensor is enabled.') + enabled = me.BooleanField( + default=True, help_text="Flag indicating whether the sensor is enabled." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['enabled']}, - {'fields': ['trigger_types']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["enabled"]}, + {"fields": ["trigger_types"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): diff --git a/st2common/st2common/models/db/stormbase.py b/st2common/st2common/models/db/stormbase.py index bf312c6e4f..50f79dde78 100644 --- a/st2common/st2common/models/db/stormbase.py +++ b/st2common/st2common/models/db/stormbase.py @@ -29,17 +29,15 @@ from st2common.constants.types import ResourceType __all__ = [ - 'StormFoundationDB', - 'StormBaseDB', - - 'EscapedDictField', - 'EscapedDynamicField', - 'TagField', - - 'RefFieldMixin', - 'UIDFieldMixin', - 'TagsMixin', - 'ContentPackResourceMixin' + "StormFoundationDB", + "StormBaseDB", + "EscapedDictField", + "EscapedDynamicField", + "TagField", + "RefFieldMixin", + "UIDFieldMixin", + "TagsMixin", + "ContentPackResourceMixin", ] JSON_UNFRIENDLY_TYPES = (datetime.datetime, bson.ObjectId) @@ -62,17 +60,19 @@ class StormFoundationDB(me.Document, DictSerializableClassMixin): # don't do that # see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes - meta = { - 'abstract': True - } + meta = {"abstract": True} def __str__(self): attrs = list() - for k in sorted(self._fields.keys()): # pylint: disable=E1101 + for k in sorted(self._fields.keys()): # pylint: disable=E1101 v = getattr(self, k) - v = '"%s"' % str(v) if type(v) in [str, six.text_type, datetime.datetime] else str(v) - attrs.append('%s=%s' % (k, v)) - return '%s(%s)' % (self.__class__.__name__, ', '.join(attrs)) + v = ( + '"%s"' % str(v) + if type(v) in [str, six.text_type, datetime.datetime] + else str(v) + ) + attrs.append("%s=%s" % (k, v)) + return "%s(%s)" % (self.__class__.__name__, ", ".join(attrs)) def get_resource_type(self): return self.RESOURCE_TYPE @@ -98,7 +98,7 @@ def to_serializable_dict(self, mask_secrets=False): :rtype: ``dict`` """ serializable_dict = {} - for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101 + for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101 v = getattr(self, k) if isinstance(v, JSON_UNFRIENDLY_TYPES): v = str(v) @@ -120,17 +120,15 @@ class StormBaseDB(StormFoundationDB): description = me.StringField() # see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes - meta = { - 'abstract': True - } + meta = {"abstract": True} class EscapedDictField(me.DictField): - def to_mongo(self, value, use_db_field=True, fields=None): value = mongoescape.escape_chars(value) - return super(EscapedDictField, self).to_mongo(value=value, use_db_field=use_db_field, - fields=fields) + return super(EscapedDictField, self).to_mongo( + value=value, use_db_field=use_db_field, fields=fields + ) def to_python(self, value): value = super(EscapedDictField, self).to_python(value) @@ -138,18 +136,18 @@ def to_python(self, value): def validate(self, value): if not isinstance(value, dict): - self.error('Only dictionaries may be used in a DictField') + self.error("Only dictionaries may be used in a DictField") if me.fields.key_not_string(value): self.error("Invalid dictionary key - documents must have only string keys") me.base.ComplexBaseField.validate(self, value) class EscapedDynamicField(me.DynamicField): - def to_mongo(self, value, use_db_field=True, fields=None): value = mongoescape.escape_chars(value) - return super(EscapedDynamicField, self).to_mongo(value=value, use_db_field=use_db_field, - fields=fields) + return super(EscapedDynamicField, self).to_mongo( + value=value, use_db_field=use_db_field, fields=fields + ) def to_python(self, value): value = super(EscapedDynamicField, self).to_python(value) @@ -161,6 +159,7 @@ class TagField(me.EmbeddedDocument): To be attached to a db model object for the purpose of providing supplemental information. """ + name = me.StringField(max_length=1024) value = me.StringField(max_length=1024) @@ -169,11 +168,12 @@ class TagsMixin(object): """ Mixin to include tags on an object. """ + tags = me.ListField(field=me.EmbeddedDocumentField(TagField)) @classmethod def get_indexes(cls): - return ['tags.name', 'tags.value'] + return ["tags.name", "tags.value"] class RefFieldMixin(object): @@ -192,7 +192,7 @@ class UIDFieldMixin(object): the system. """ - UID_SEPARATOR = ':' # TODO: Move to constants + UID_SEPARATOR = ":" # TODO: Move to constants RESOURCE_TYPE = abc.abstractproperty UID_FIELDS = abc.abstractproperty @@ -205,13 +205,7 @@ def get_indexes(cls): # models in the database before ensure_indexes() is called. # This field gets populated in the constructor which means it will be lazily assigned next # time the model is saved (e.g. once register-content is ran). - indexes = [ - { - 'fields': ['uid'], - 'unique': True, - 'sparse': True - } - ] + indexes = [{"fields": ["uid"], "unique": True, "sparse": True}] return indexes def get_uid(self): @@ -224,7 +218,7 @@ def get_uid(self): parts.append(self.RESOURCE_TYPE) for field in self.UID_FIELDS: - value = getattr(self, field, None) or '' + value = getattr(self, field, None) or "" parts.append(value) uid = self.UID_SEPARATOR.join(parts) @@ -257,8 +251,11 @@ class ContentPackResourceMixin(object): metadata_file = me.StringField( required=False, - help_text=('Path to the metadata file (file on disk which contains resource definition) ' - 'relative to the pack directory.')) + help_text=( + "Path to the metadata file (file on disk which contains resource definition) " + "relative to the pack directory." + ), + ) def get_pack_uid(self): """ @@ -276,7 +273,7 @@ def get_reference(self): :rtype: :class:`ResourceReference` """ - if getattr(self, 'ref', None): + if getattr(self, "ref", None): ref = ResourceReference.from_string_reference(ref=self.ref) else: ref = ResourceReference(pack=self.pack, name=self.name) @@ -287,7 +284,7 @@ def get_reference(self): def get_indexes(cls): return [ { - 'fields': ['metadata_file'], + "fields": ["metadata_file"], } ] @@ -298,9 +295,4 @@ class ChangeRevisionFieldMixin(object): @classmethod def get_indexes(cls): - return [ - { - 'fields': ['id', 'rev'], - 'unique': True - } - ] + return [{"fields": ["id", "rev"], "unique": True}] diff --git a/st2common/st2common/models/db/timer.py b/st2common/st2common/models/db/timer.py index 98bb7952e1..652d6a056a 100644 --- a/st2common/st2common/models/db/timer.py +++ b/st2common/st2common/models/db/timer.py @@ -30,10 +30,10 @@ class TimerDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.TIMER - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") type = me.StringField() parameters = me.DictField() diff --git a/st2common/st2common/models/db/trace.py b/st2common/st2common/models/db/trace.py index 00b7010d91..fe358e90c9 100644 --- a/st2common/st2common/models/db/trace.py +++ b/st2common/st2common/models/db/trace.py @@ -25,25 +25,24 @@ from st2common.models.db import MongoDBAccess -__all__ = [ - 'TraceDB', - 'TraceComponentDB' -] +__all__ = ["TraceDB", "TraceComponentDB"] class TraceComponentDB(me.EmbeddedDocument): - """ - """ + """""" + object_id = me.StringField() - ref = me.StringField(default='') + ref = me.StringField(default="") updated_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the TraceComponent was included.') - caused_by = me.DictField(help_text='Causal component.') + help_text="The timestamp when the TraceComponent was included.", + ) + caused_by = me.DictField(help_text="Causal component.") def __str__(self): - return 'TraceComponentDB@(object_id:{}, updated_at:{})'.format( - self.object_id, self.updated_at) + return "TraceComponentDB@(object_id:{}, updated_at:{})".format( + self.object_id, self.updated_at + ) class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): @@ -66,28 +65,37 @@ class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): RESOURCE_TYPE = ResourceType.TRACE - trace_tag = me.StringField(required=True, - help_text='A user specified reference to the trace.') - trigger_instances = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated TriggerInstances.') - rules = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated Rules.') - action_executions = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated ActionExecutions.') - start_timestamp = ComplexDateTimeField(default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the Trace was created.') + trace_tag = me.StringField( + required=True, help_text="A user specified reference to the trace." + ) + trigger_instances = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated TriggerInstances.", + ) + rules = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated Rules.", + ) + action_executions = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated ActionExecutions.", + ) + start_timestamp = ComplexDateTimeField( + default=date_utils.get_datetime_utc_now, + help_text="The timestamp when the Trace was created.", + ) meta = { - 'indexes': [ - {'fields': ['trace_tag']}, - {'fields': ['start_timestamp']}, - {'fields': ['action_executions.object_id']}, - {'fields': ['trigger_instances.object_id']}, - {'fields': ['rules.object_id']}, - {'fields': ['-start_timestamp', 'trace_tag']}, + "indexes": [ + {"fields": ["trace_tag"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["action_executions.object_id"]}, + {"fields": ["trigger_instances.object_id"]}, + {"fields": ["rules.object_id"]}, + {"fields": ["-start_timestamp", "trace_tag"]}, ] } diff --git a/st2common/st2common/models/db/trigger.py b/st2common/st2common/models/db/trigger.py index 0546c3b739..9b749c5241 100644 --- a/st2common/st2common/models/db/trigger.py +++ b/st2common/st2common/models/db/trigger.py @@ -24,16 +24,18 @@ from st2common.constants.types import ResourceType __all__ = [ - 'TriggerTypeDB', - 'TriggerDB', - 'TriggerInstanceDB', + "TriggerTypeDB", + "TriggerDB", + "TriggerInstanceDB", ] -class TriggerTypeDB(stormbase.StormBaseDB, - stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin, - stormbase.TagsMixin): +class TriggerTypeDB( + stormbase.StormBaseDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, + stormbase.TagsMixin, +): """Description of a specific kind/type of a trigger. The (pack, name) tuple is expected uniquely identify a trigger in the namespace of all triggers provided by a specific trigger_source. @@ -45,18 +47,20 @@ class TriggerTypeDB(stormbase.StormBaseDB, """ RESOURCE_TYPE = ResourceType.TRIGGER_TYPE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] ref = me.StringField(required=False) name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") payload_schema = me.DictField() parameters_schema = me.DictField(default={}) meta = { - 'indexes': (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -66,8 +70,9 @@ def __init__(self, *args, **values): self.uid = self.get_uid() -class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class TriggerDB( + stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin +): """ Attribute: name - Trigger name. @@ -77,21 +82,22 @@ class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, """ RESOURCE_TYPE = ResourceType.TRIGGER - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] ref = me.StringField(required=False) name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") type = me.StringField() parameters = me.DictField() ref_count = me.IntField(default=0) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['type']}, - {'fields': ['parameters']}, - ] + stormbase.UIDFieldMixin.get_indexes() + "indexes": [ + {"fields": ["name"]}, + {"fields": ["type"]}, + {"fields": ["parameters"]}, + ] + + stormbase.UIDFieldMixin.get_indexes() } def __init__(self, *args, **values): @@ -106,7 +112,7 @@ def get_uid(self): # Note: We sort the resulting JSON object so that the same dictionary always results # in the same hash - parameters = getattr(self, 'parameters', {}) + parameters = getattr(self, "parameters", {}) parameters = json.dumps(parameters, sort_keys=True) parameters = hashlib.md5(parameters.encode()).hexdigest() @@ -126,19 +132,20 @@ class TriggerInstanceDB(stormbase.StormFoundationDB): payload (dict): payload specific to the occurrence. occurrence_time (datetime): time of occurrence of the trigger. """ + trigger = me.StringField() payload = stormbase.EscapedDictField() occurrence_time = me.DateTimeField() status = me.StringField( - required=True, - help_text='Processing status of TriggerInstance.') + required=True, help_text="Processing status of TriggerInstance." + ) meta = { - 'indexes': [ - {'fields': ['occurrence_time']}, - {'fields': ['trigger']}, - {'fields': ['-occurrence_time', 'trigger']}, - {'fields': ['status']} + "indexes": [ + {"fields": ["occurrence_time"]}, + {"fields": ["trigger"]}, + {"fields": ["-occurrence_time", "trigger"]}, + {"fields": ["status"]}, ] } diff --git a/st2common/st2common/models/db/webhook.py b/st2common/st2common/models/db/webhook.py index 0ef2906b90..b608f6c355 100644 --- a/st2common/st2common/models/db/webhook.py +++ b/st2common/st2common/models/db/webhook.py @@ -29,7 +29,7 @@ class WebhookDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.WEBHOOK - UID_FIELDS = ['name'] + UID_FIELDS = ["name"] name = me.StringField(required=True) @@ -40,7 +40,7 @@ def __init__(self, *args, **values): def _normalize_name(self, name): # Remove trailing slash if present - if name.endswith('/'): + if name.endswith("/"): name = name[:-1] return name diff --git a/st2common/st2common/models/db/workflow.py b/st2common/st2common/models/db/workflow.py index dc73c1c55c..fd5cdb111e 100644 --- a/st2common/st2common/models/db/workflow.py +++ b/st2common/st2common/models/db/workflow.py @@ -24,16 +24,15 @@ from st2common.util import date as date_utils -__all__ = [ - 'WorkflowExecutionDB', - 'TaskExecutionDB' -] +__all__ = ["WorkflowExecutionDB", "TaskExecutionDB"] LOG = logging.getLogger(__name__) -class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin): +class WorkflowExecutionDB( + stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin +): RESOURCE_TYPE = types.ResourceType.EXECUTION action_execution = me.StringField(required=True) @@ -46,14 +45,12 @@ class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionF status = me.StringField(required=True) output = stormbase.EscapedDictField() errors = stormbase.EscapedDynamicField() - start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now) + start_timestamp = db_field_types.ComplexDateTimeField( + default=date_utils.get_datetime_utc_now + ) end_timestamp = db_field_types.ComplexDateTimeField() - meta = { - 'indexes': [ - {'fields': ['action_execution']} - ] - } + meta = {"indexes": [{"fields": ["action_execution"]}]} class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin): @@ -71,21 +68,20 @@ class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionField context = stormbase.EscapedDictField() status = me.StringField(required=True) result = stormbase.EscapedDictField() - start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now) + start_timestamp = db_field_types.ComplexDateTimeField( + default=date_utils.get_datetime_utc_now + ) end_timestamp = db_field_types.ComplexDateTimeField() meta = { - 'indexes': [ - {'fields': ['workflow_execution']}, - {'fields': ['task_id']}, - {'fields': ['task_id', 'task_route']}, - {'fields': ['workflow_execution', 'task_id']}, - {'fields': ['workflow_execution', 'task_id', 'task_route']} + "indexes": [ + {"fields": ["workflow_execution"]}, + {"fields": ["task_id"]}, + {"fields": ["task_id", "task_route"]}, + {"fields": ["workflow_execution", "task_id"]}, + {"fields": ["workflow_execution", "task_id", "task_route"]}, ] } -MODELS = [ - WorkflowExecutionDB, - TaskExecutionDB -] +MODELS = [WorkflowExecutionDB, TaskExecutionDB] diff --git a/st2common/st2common/models/system/action.py b/st2common/st2common/models/system/action.py index b5efe124f5..2afcbf649b 100644 --- a/st2common/st2common/models/system/action.py +++ b/st2common/st2common/models/system/action.py @@ -35,11 +35,11 @@ from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE __all__ = [ - 'ShellCommandAction', - 'ShellScriptAction', - 'RemoteAction', - 'RemoteScriptAction', - 'ResolvedActionParameters' + "ShellCommandAction", + "ShellScriptAction", + "RemoteAction", + "RemoteScriptAction", + "ResolvedActionParameters", ] LOG = logging.getLogger(__name__) @@ -48,21 +48,31 @@ # Flags which are passed to every sudo invocation SUDO_COMMON_OPTIONS = [ - '-E' # we want to preserve the environment of the user which ran sudo -] + "-E" +] # we want to preserve the environment of the user which ran sudo # Flags which are only passed to sudo when not running as current user and when # -u flag is used SUDO_DIFFERENT_USER_OPTIONS = [ - '-H' # we want $HOME to reflect the home directory of the requested / target user + "-H" # we want $HOME to reflect the home directory of the requested / target user ] class ShellCommandAction(object): - EXPORT_CMD = 'export' - - def __init__(self, name, action_exec_id, command, user, env_vars=None, sudo=False, - timeout=None, cwd=None, sudo_password=None): + EXPORT_CMD = "export" + + def __init__( + self, + name, + action_exec_id, + command, + user, + env_vars=None, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): self.name = name self.action_exec_id = action_exec_id self.command = command @@ -77,15 +87,15 @@ def get_full_command_string(self): # Note: We pass -E to sudo because we want to preserve user provided environment variables if self.sudo: command = quote_unix(self.command) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: if self.user and self.user != LOGGED_USER_USERNAME: # Need to use sudo to run as a different (requested) user user = quote_unix(self.user) - sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user)) + sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user)) command = quote_unix(self.command) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: command = self.command @@ -103,7 +113,10 @@ def get_sanitized_full_command_string(self): if self.sudo_password: # Mask sudo password - command_string = 'echo -e \'%s\n\' | %s' % (MASKED_ATTRIBUTE_VALUE, command_string) + command_string = "echo -e '%s\n' | %s" % ( + MASKED_ATTRIBUTE_VALUE, + command_string, + ) return command_string @@ -124,7 +137,7 @@ def _get_common_sudo_arguments(self): if self.sudo_password: # Note: We use subprocess.Popen in local runner so we provide password via subprocess # stdin (using echo -e won't work when using subprocess.Popen) - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS @@ -139,7 +152,7 @@ def _get_user_sudo_arguments(self, user): """ flags = self._get_common_sudo_arguments() flags += SUDO_DIFFERENT_USER_OPTIONS - flags += ['-u', user] + flags += ["-u", user] return flags @@ -150,21 +163,21 @@ def _get_env_vars_export_string(self): # If sudo_password is provided, explicitly disable bash history to make sure password # is not logged, because password is provided via command line if self.sudo and self.sudo_password: - env_vars['HISTFILE'] = '/dev/null' - env_vars['HISTSIZE'] = '0' + env_vars["HISTFILE"] = "/dev/null" + env_vars["HISTSIZE"] = "0" # Sort the dict to guarantee consistent order env_vars = collections.OrderedDict(sorted(env_vars.items())) # Environment variables could contain spaces and open us to shell # injection attacks. Always quote the key and the value. - exports = ' '.join( - '%s=%s' % (quote_unix(k), quote_unix(v)) + exports = " ".join( + "%s=%s" % (quote_unix(k), quote_unix(v)) for k, v in six.iteritems(env_vars) ) - shell_env_str = '%s %s' % (ShellCommandAction.EXPORT_CMD, exports) + shell_env_str = "%s %s" % (ShellCommandAction.EXPORT_CMD, exports) else: - shell_env_str = '' + shell_env_str = "" return shell_env_str @@ -180,8 +193,8 @@ def _get_command_string(self, cmd, args): assert isinstance(args, (list, tuple)) args = [quote_unix(arg) for arg in args] - args = ' '.join(args) - result = '%s %s' % (cmd, args) + args = " ".join(args) + result = "%s %s" % (cmd, args) return result def _get_error_result(self): @@ -195,24 +208,42 @@ def _get_error_result(self): _, exc_value, exc_traceback = sys.exc_info() exc_value = str(exc_value) - exc_traceback = ''.join(traceback.format_tb(exc_traceback)) + exc_traceback = "".join(traceback.format_tb(exc_traceback)) result = {} - result['failed'] = True - result['succeeded'] = False - result['error'] = exc_value - result['traceback'] = exc_traceback + result["failed"] = True + result["succeeded"] = False + result["error"] = exc_value + result["traceback"] = exc_traceback return result class ShellScriptAction(ShellCommandAction): - def __init__(self, name, action_exec_id, script_local_path_abs, named_args=None, - positional_args=None, env_vars=None, user=None, sudo=False, timeout=None, - cwd=None, sudo_password=None): - super(ShellScriptAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=None, user=user, env_vars=env_vars, - sudo=sudo, timeout=timeout, - cwd=cwd, sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + script_local_path_abs, + named_args=None, + positional_args=None, + env_vars=None, + user=None, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): + super(ShellScriptAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=None, + user=user, + env_vars=env_vars, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.script_local_path_abs = script_local_path_abs self.named_args = named_args self.positional_args = positional_args @@ -221,33 +252,38 @@ def get_full_command_string(self): return self._format_command() def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) if self.sudo: if script_arguments: - command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments)) + command = quote_unix( + "%s %s" % (self.script_local_path_abs, script_arguments) + ) else: command = quote_unix(self.script_local_path_abs) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: if self.user and self.user != LOGGED_USER_USERNAME: # Need to use sudo to run as a different user user = quote_unix(self.user) if script_arguments: - command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments)) + command = quote_unix( + "%s %s" % (self.script_local_path_abs, script_arguments) + ) else: command = quote_unix(self.script_local_path_abs) - sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user)) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user)) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: script_path = quote_unix(self.script_local_path_abs) if script_arguments: - command = '%s %s' % (script_path, script_arguments) + command = "%s %s" % (script_path, script_arguments) else: command = script_path return command @@ -270,8 +306,10 @@ def _get_script_arguments(self, named_args=None, positional_args=None): # add all named_args in the format name=value (e.g. --name=value) if named_args is not None: for (arg, value) in six.iteritems(named_args): - if value is None or (isinstance(value, (str, six.text_type)) and len(value) < 1): - LOG.debug('Ignoring arg %s as its value is %s.', arg, value) + if value is None or ( + isinstance(value, (str, six.text_type)) and len(value) < 1 + ): + LOG.debug("Ignoring arg %s as its value is %s.", arg, value) continue if isinstance(value, bool): @@ -279,24 +317,45 @@ def _get_script_arguments(self, named_args=None, positional_args=None): command_parts.append(arg) else: values = (quote_unix(arg), quote_unix(six.text_type(value))) - command_parts.append(six.text_type('%s=%s' % values)) + command_parts.append(six.text_type("%s=%s" % values)) # add the positional args if positional_args: quoted_pos_args = [quote_unix(pos_arg) for pos_arg in positional_args] - pos_args_string = ' '.join(quoted_pos_args) + pos_args_string = " ".join(quoted_pos_args) command_parts.append(pos_args_string) - return ' '.join(command_parts) + return " ".join(command_parts) class SSHCommandAction(ShellCommandAction): - def __init__(self, name, action_exec_id, command, env_vars, user, password=None, pkey=None, - hosts=None, parallel=True, sudo=False, timeout=None, cwd=None, passphrase=None, - sudo_password=None): - super(SSHCommandAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=command, env_vars=env_vars, user=user, - sudo=sudo, timeout=timeout, cwd=cwd, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + command, + env_vars, + user, + password=None, + pkey=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + passphrase=None, + sudo_password=None, + ): + super(SSHCommandAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=command, + env_vars=env_vars, + user=user, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.hosts = hosts self.parallel = parallel self.pkey = pkey @@ -329,25 +388,51 @@ def get_command(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('command: %s' % self.command) - str_rep.append('user: %s' % self.user) - str_rep.append('sudo: %s' % str(self.sudo)) - str_rep.append('parallel: %s' % str(self.parallel)) - str_rep.append('hosts: %s)' % str(self.hosts)) - return ', '.join(str_rep) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("command: %s" % self.command) + str_rep.append("user: %s" % self.user) + str_rep.append("sudo: %s" % str(self.sudo)) + str_rep.append("parallel: %s" % str(self.parallel)) + str_rep.append("hosts: %s)" % str(self.hosts)) + return ", ".join(str_rep) class RemoteAction(SSHCommandAction): - def __init__(self, name, action_exec_id, command, env_vars=None, on_behalf_user=None, - user=None, password=None, private_key=None, hosts=None, parallel=True, sudo=False, - timeout=None, cwd=None, passphrase=None, sudo_password=None): - super(RemoteAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=command, env_vars=env_vars, user=user, - hosts=hosts, parallel=parallel, sudo=sudo, - timeout=timeout, cwd=cwd, passphrase=passphrase, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + command, + env_vars=None, + on_behalf_user=None, + user=None, + password=None, + private_key=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + passphrase=None, + sudo_password=None, + ): + super(RemoteAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=command, + env_vars=env_vars, + user=user, + hosts=hosts, + parallel=parallel, + sudo=sudo, + timeout=timeout, + cwd=cwd, + passphrase=passphrase, + sudo_password=sudo_password, + ) self.password = password self.private_key = private_key self.passphrase = passphrase @@ -359,34 +444,61 @@ def get_on_behalf_user(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('command: %s' % self.command) - str_rep.append('user: %s' % self.user) - str_rep.append('on_behalf_user: %s' % self.on_behalf_user) - str_rep.append('sudo: %s' % str(self.sudo)) - str_rep.append('parallel: %s' % str(self.parallel)) - str_rep.append('hosts: %s)' % str(self.hosts)) - str_rep.append('timeout: %s)' % str(self.timeout)) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("command: %s" % self.command) + str_rep.append("user: %s" % self.user) + str_rep.append("on_behalf_user: %s" % self.on_behalf_user) + str_rep.append("sudo: %s" % str(self.sudo)) + str_rep.append("parallel: %s" % str(self.parallel)) + str_rep.append("hosts: %s)" % str(self.hosts)) + str_rep.append("timeout: %s)" % str(self.timeout)) - return ', '.join(str_rep) + return ", ".join(str_rep) class RemoteScriptAction(ShellScriptAction): - def __init__(self, name, action_exec_id, script_local_path_abs, script_local_libs_path_abs, - named_args=None, positional_args=None, env_vars=None, on_behalf_user=None, - user=None, password=None, private_key=None, remote_dir=None, hosts=None, - parallel=True, sudo=False, timeout=None, cwd=None, sudo_password=None): - super(RemoteScriptAction, self).__init__(name=name, action_exec_id=action_exec_id, - script_local_path_abs=script_local_path_abs, - user=user, - named_args=named_args, - positional_args=positional_args, env_vars=env_vars, - sudo=sudo, timeout=timeout, cwd=cwd, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + script_local_path_abs, + script_local_libs_path_abs, + named_args=None, + positional_args=None, + env_vars=None, + on_behalf_user=None, + user=None, + password=None, + private_key=None, + remote_dir=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): + super(RemoteScriptAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + script_local_path_abs=script_local_path_abs, + user=user, + named_args=named_args, + positional_args=positional_args, + env_vars=env_vars, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.script_local_libs_path_abs = script_local_libs_path_abs - self.script_local_dir, self.script_name = os.path.split(self.script_local_path_abs) - self.remote_dir = remote_dir if remote_dir is not None else '/tmp' + self.script_local_dir, self.script_name = os.path.split( + self.script_local_path_abs + ) + self.remote_dir = remote_dir if remote_dir is not None else "/tmp" self.remote_libs_path_abs = os.path.join(self.remote_dir, ACTION_LIBS_DIR) self.on_behalf_user = on_behalf_user self.password = password @@ -395,7 +507,7 @@ def __init__(self, name, action_exec_id, script_local_path_abs, script_local_lib self.hosts = hosts self.parallel = parallel self.command = self._format_command() - LOG.debug('RemoteScriptAction: command to run on remote box: %s', self.command) + LOG.debug("RemoteScriptAction: command to run on remote box: %s", self.command) def get_remote_script_abs_path(self): return self.remote_script @@ -413,11 +525,12 @@ def get_remote_base_dir(self): return self.remote_dir def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) if script_arguments: - command = '%s %s' % (self.remote_script, script_arguments) + command = "%s %s" % (self.remote_script, script_arguments) else: command = self.remote_script @@ -425,21 +538,23 @@ def _format_command(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('local_script: %s' % self.script_local_path_abs) - str_rep.append('local_libs: %s' % self.script_local_libs_path_abs) - str_rep.append('remote_dir: %s' % self.remote_dir) - str_rep.append('remote_libs: %s' % self.remote_libs_path_abs) - str_rep.append('named_args: %s' % self.named_args) - str_rep.append('positional_args: %s' % self.positional_args) - str_rep.append('user: %s' % self.user) - str_rep.append('on_behalf_user: %s' % self.on_behalf_user) - str_rep.append('sudo: %s' % self.sudo) - str_rep.append('parallel: %s' % self.parallel) - str_rep.append('hosts: %s)' % self.hosts) - - return ', '.join(str_rep) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("local_script: %s" % self.script_local_path_abs) + str_rep.append("local_libs: %s" % self.script_local_libs_path_abs) + str_rep.append("remote_dir: %s" % self.remote_dir) + str_rep.append("remote_libs: %s" % self.remote_libs_path_abs) + str_rep.append("named_args: %s" % self.named_args) + str_rep.append("positional_args: %s" % self.positional_args) + str_rep.append("user: %s" % self.user) + str_rep.append("on_behalf_user: %s" % self.on_behalf_user) + str_rep.append("sudo: %s" % self.sudo) + str_rep.append("parallel: %s" % self.parallel) + str_rep.append("hosts: %s)" % self.hosts) + + return ", ".join(str_rep) class ResolvedActionParameters(DictSerializableClassMixin): @@ -447,7 +562,9 @@ class ResolvedActionParameters(DictSerializableClassMixin): Class which contains resolved runner and action parameters for a particular action. """ - def __init__(self, action_db, runner_type_db, runner_parameters=None, action_parameters=None): + def __init__( + self, action_db, runner_type_db, runner_parameters=None, action_parameters=None + ): self._action_db = action_db self._runner_type_db = runner_type_db self._runner_parameters = runner_parameters @@ -456,28 +573,34 @@ def __init__(self, action_db, runner_type_db, runner_parameters=None, action_par def mask_secrets(self, value): result = copy.deepcopy(value) - runner_parameters = result['runner_parameters'] - action_parameters = result['action_parameters'] + runner_parameters = result["runner_parameters"] + action_parameters = result["action_parameters"] runner_parameters_specs = self._runner_type_db.runner_parameters action_parameters_sepcs = self._action_db.parameters - secret_runner_parameters = get_secret_parameters(parameters=runner_parameters_specs) - secret_action_parameters = get_secret_parameters(parameters=action_parameters_sepcs) - - runner_parameters = mask_secret_parameters(parameters=runner_parameters, - secret_parameters=secret_runner_parameters) - action_parameters = mask_secret_parameters(parameters=action_parameters, - secret_parameters=secret_action_parameters) - result['runner_parameters'] = runner_parameters - result['action_parameters'] = action_parameters + secret_runner_parameters = get_secret_parameters( + parameters=runner_parameters_specs + ) + secret_action_parameters = get_secret_parameters( + parameters=action_parameters_sepcs + ) + + runner_parameters = mask_secret_parameters( + parameters=runner_parameters, secret_parameters=secret_runner_parameters + ) + action_parameters = mask_secret_parameters( + parameters=action_parameters, secret_parameters=secret_action_parameters + ) + result["runner_parameters"] = runner_parameters + result["action_parameters"] = action_parameters return result def to_serializable_dict(self, mask_secrets=False): result = {} - result['runner_parameters'] = self._runner_parameters - result['action_parameters'] = self._action_parameters + result["runner_parameters"] = self._runner_parameters + result["action_parameters"] = self._action_parameters if mask_secrets and cfg.CONF.log.mask_secrets: result = self.mask_secrets(value=result) diff --git a/st2common/st2common/models/system/actionchain.py b/st2common/st2common/models/system/actionchain.py index 2c5ce24c3d..24a84cc6b6 100644 --- a/st2common/st2common/models/system/actionchain.py +++ b/st2common/st2common/models/system/actionchain.py @@ -31,45 +31,45 @@ class Node(object): "name": { "description": "The name of this node.", "type": "string", - "required": True + "required": True, }, "ref": { "type": "string", "description": "Ref of the action to be executed.", - "required": True + "required": True, }, "params": { "type": "object", - "description": ("Parameter for the execution (old name, here for backward " - "compatibility reasons)."), - "default": {} + "description": ( + "Parameter for the execution (old name, here for backward " + "compatibility reasons)." + ), + "default": {}, }, "parameters": { "type": "object", "description": "Parameter for the execution.", - "default": {} + "default": {}, }, "on-success": { "type": "string", "description": "Name of the node to invoke on successful completion of action" - " executed for this node.", - "default": "" + " executed for this node.", + "default": "", }, "on-failure": { "type": "string", "description": "Name of the node to invoke on failure of action executed for this" - " node.", - "default": "" + " node.", + "default": "", }, "publish": { "description": "The variables to publish from the result. Should be of the form" - " name.foo. o1: {{node_name.foo}} will result in creation of a" - " variable o1 which is now available for reference through" - " remainder of the chain as a global variable.", + " name.foo. o1: {{node_name.foo}} will result in creation of a" + " variable o1 which is now available for reference through" + " remainder of the chain as a global variable.", "type": "object", - "patternProperties": { - r"^\w+$": {} - } + "patternProperties": {r"^\w+$": {}}, }, "notify": { "description": "Notification settings for action.", @@ -77,43 +77,49 @@ class Node(object): "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False - } + "additionalProperties": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): - for prop in six.iterkeys(self.schema.get('properties', [])): + for prop in six.iterkeys(self.schema.get("properties", [])): value = kw.get(prop, None) # having '-' in the property name lead to challenges in referencing the property. # At hindsight the schema property should've been on_success rather than on-success. - prop = prop.replace('-', '_') + prop = prop.replace("-", "_") setattr(self, prop, value) def validate(self): - params = getattr(self, 'params', {}) - parameters = getattr(self, 'parameters', {}) + params = getattr(self, "params", {}) + parameters = getattr(self, "parameters", {}) if params and parameters: - msg = ('Either "params" or "parameters" attribute needs to be provided, but not ' - 'both') + msg = ( + 'Either "params" or "parameters" attribute needs to be provided, but not ' + "both" + ) raise ValueError(msg) return self def get_parameters(self): # Note: "params" is old deprecated attribute which will be removed in a future release - params = getattr(self, 'params', {}) - parameters = getattr(self, 'parameters', {}) + params = getattr(self, "params", {}) + parameters = getattr(self, "parameters", {}) return parameters or params def __repr__(self): - return ('' % - (self.name, self.ref, self.on_success, self.on_failure)) + return "" % ( + self.name, + self.ref, + self.on_success, + self.on_failure, + ) class ActionChain(object): @@ -127,31 +133,34 @@ class ActionChain(object): "description": "The chain.", "type": "array", "items": [Node.schema], - "required": True + "required": True, }, "default": { "type": "string", - "description": "name of the action to be executed." + "description": "name of the action to be executed.", }, "vars": { "description": "", "type": "object", - "patternProperties": { - r"^\w+$": {} - } - } + "patternProperties": {r"^\w+$": {}}, + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): - util_schema.validate(instance=kw, schema=self.schema, cls=util_schema.CustomValidator, - use_default=False, allow_default_none=True) - - for prop in six.iterkeys(self.schema.get('properties', [])): + util_schema.validate( + instance=kw, + schema=self.schema, + cls=util_schema.CustomValidator, + use_default=False, + allow_default_none=True, + ) + + for prop in six.iterkeys(self.schema.get("properties", [])): value = kw.get(prop, None) # special handling for chain property to create the Node object - if prop == 'chain': + if prop == "chain": nodes = [] for node in value: ac_node = Node(**node) diff --git a/st2common/st2common/models/system/common.py b/st2common/st2common/models/system/common.py index a56f6701ac..72ad6c3f84 100644 --- a/st2common/st2common/models/system/common.py +++ b/st2common/st2common/models/system/common.py @@ -14,17 +14,17 @@ # limitations under the License. __all__ = [ - 'InvalidReferenceError', - 'InvalidResourceReferenceError', - 'ResourceReference', + "InvalidReferenceError", + "InvalidResourceReferenceError", + "ResourceReference", ] -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class InvalidReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid reference: %s' % (ref) + message = "Invalid reference: %s" % (ref) self.ref = ref self.message = message super(InvalidReferenceError, self).__init__(message) @@ -32,7 +32,7 @@ def __init__(self, ref): class InvalidResourceReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid resource reference: %s' % (ref) + message = "Invalid resource reference: %s" % (ref) self.ref = ref self.message = message super(InvalidResourceReferenceError, self).__init__(message) @@ -42,6 +42,7 @@ class ResourceReference(object): """ Class used for referring to resources which belong to a content pack. """ + def __init__(self, pack=None, name=None): self.pack = self.validate_pack_name(pack=pack) self.name = name @@ -72,8 +73,10 @@ def to_string_reference(pack=None, name=None): pack = ResourceReference.validate_pack_name(pack=pack) return PACK_SEPARATOR.join([pack, name]) else: - raise ValueError('Both pack and name needed for building ref. pack=%s, name=%s' % - (pack, name)) + raise ValueError( + "Both pack and name needed for building ref. pack=%s, name=%s" + % (pack, name) + ) @staticmethod def validate_pack_name(pack): @@ -97,5 +100,8 @@ def get_name(ref): raise InvalidResourceReferenceError(ref=ref) def __repr__(self): - return ('' % - (self.pack, self.name, self.ref)) + return "" % ( + self.pack, + self.name, + self.ref, + ) diff --git a/st2common/st2common/models/system/keyvalue.py b/st2common/st2common/models/system/keyvalue.py index 0bac5949d8..018df95602 100644 --- a/st2common/st2common/models/system/keyvalue.py +++ b/st2common/st2common/models/system/keyvalue.py @@ -17,13 +17,13 @@ from st2common.constants.keyvalue import USER_SEPARATOR __all__ = [ - 'InvalidUserKeyReferenceError', + "InvalidUserKeyReferenceError", ] class InvalidUserKeyReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid resource reference: %s' % (ref) + message = "Invalid resource reference: %s" % (ref) self.ref = ref self.message = message super(InvalidUserKeyReferenceError, self).__init__(message) @@ -38,7 +38,7 @@ class UserKeyReference(object): def __init__(self, user, name): self._user = user self._name = name - self.ref = ('%s%s%s' % (self._user, USER_SEPARATOR, self._name)) + self.ref = "%s%s%s" % (self._user, USER_SEPARATOR, self._name) def __str__(self): return self.ref diff --git a/st2common/st2common/models/system/paramiko_command_action.py b/st2common/st2common/models/system/paramiko_command_action.py index a96183ef9e..685ffeb67c 100644 --- a/st2common/st2common/models/system/paramiko_command_action.py +++ b/st2common/st2common/models/system/paramiko_command_action.py @@ -23,7 +23,7 @@ from st2common.util.shell import quote_unix __all__ = [ - 'ParamikoRemoteCommandAction', + "ParamikoRemoteCommandAction", ] LOG = logging.getLogger(__name__) @@ -32,7 +32,6 @@ class ParamikoRemoteCommandAction(RemoteAction): - def get_full_command_string(self): # Note: We pass -E to sudo because we want to preserve user provided environment variables env_str = self._get_env_vars_export_string() @@ -40,24 +39,25 @@ def get_full_command_string(self): if self.sudo: if env_str: - command = quote_unix('%s && cd %s && %s' % (env_str, cwd, self.command)) + command = quote_unix("%s && cd %s && %s" % (env_str, cwd, self.command)) else: - command = quote_unix('cd %s && %s' % (cwd, self.command)) + command = quote_unix("cd %s && %s" % (cwd, self.command)) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) if self.sudo_password: - command = ('set +o history ; echo -e %s | %s' % - (quote_unix('%s\n' % (self.sudo_password)), command)) + command = "set +o history ; echo -e %s | %s" % ( + quote_unix("%s\n" % (self.sudo_password)), + command, + ) else: if env_str: - command = '%s && cd %s && %s' % (env_str, cwd, - self.command) + command = "%s && cd %s && %s" % (env_str, cwd, self.command) else: - command = 'cd %s && %s' % (cwd, self.command) + command = "cd %s && %s" % (cwd, self.command) - LOG.debug('Command to run on remote host will be: %s', command) + LOG.debug("Command to run on remote host will be: %s", command) return command def _get_common_sudo_arguments(self): @@ -69,7 +69,7 @@ def _get_common_sudo_arguments(self): flags = [] if self.sudo_password: - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS diff --git a/st2common/st2common/models/system/paramiko_script_action.py b/st2common/st2common/models/system/paramiko_script_action.py index a6ff26a751..284e87a708 100644 --- a/st2common/st2common/models/system/paramiko_script_action.py +++ b/st2common/st2common/models/system/paramiko_script_action.py @@ -20,7 +20,7 @@ from st2common.util.shell import quote_unix __all__ = [ - 'ParamikoRemoteScriptAction', + "ParamikoRemoteScriptAction", ] @@ -28,10 +28,10 @@ class ParamikoRemoteScriptAction(RemoteScriptAction): - def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) env_str = self._get_env_vars_export_string() cwd = quote_unix(self.get_cwd()) script_path = quote_unix(self.remote_script) @@ -39,36 +39,46 @@ def _format_command(self): if self.sudo: if script_arguments: if env_str: - command = quote_unix('%s && cd %s && %s %s' % ( - env_str, cwd, script_path, script_arguments)) + command = quote_unix( + "%s && cd %s && %s %s" + % (env_str, cwd, script_path, script_arguments) + ) else: - command = quote_unix('cd %s && %s %s' % ( - cwd, script_path, script_arguments)) + command = quote_unix( + "cd %s && %s %s" % (cwd, script_path, script_arguments) + ) else: if env_str: - command = quote_unix('%s && cd %s && %s' % ( - env_str, cwd, script_path)) + command = quote_unix( + "%s && cd %s && %s" % (env_str, cwd, script_path) + ) else: - command = quote_unix('cd %s && %s' % (cwd, script_path)) + command = quote_unix("cd %s && %s" % (cwd, script_path)) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) if self.sudo_password: - command = ('set +o history ; echo -e %s | %s' % - (quote_unix('%s\n' % (self.sudo_password)), command)) + command = "set +o history ; echo -e %s | %s" % ( + quote_unix("%s\n" % (self.sudo_password)), + command, + ) else: if script_arguments: if env_str: - command = '%s && cd %s && %s %s' % (env_str, cwd, - script_path, script_arguments) + command = "%s && cd %s && %s %s" % ( + env_str, + cwd, + script_path, + script_arguments, + ) else: - command = 'cd %s && %s %s' % (cwd, script_path, script_arguments) + command = "cd %s && %s %s" % (cwd, script_path, script_arguments) else: if env_str: - command = '%s && cd %s && %s' % (env_str, cwd, script_path) + command = "%s && cd %s && %s" % (env_str, cwd, script_path) else: - command = 'cd %s && %s' % (cwd, script_path) + command = "cd %s && %s" % (cwd, script_path) return command @@ -81,7 +91,7 @@ def _get_common_sudo_arguments(self): flags = [] if self.sudo_password: - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS diff --git a/st2common/st2common/models/utils/action_alias_utils.py b/st2common/st2common/models/utils/action_alias_utils.py index 06106a2794..bf6d47c8b4 100644 --- a/st2common/st2common/models/utils/action_alias_utils.py +++ b/st2common/st2common/models/utils/action_alias_utils.py @@ -18,9 +18,15 @@ import re import sys -from sre_parse import ( # pylint: disable=E0611 - parse, AT, AT_BEGINNING, AT_BEGINNING_STRING, - AT_END, AT_END_STRING, BRANCH, SUBPATTERN, +from sre_parse import ( # pylint: disable=E0611 + parse, + AT, + AT_BEGINNING, + AT_BEGINNING_STRING, + AT_END, + AT_END_STRING, + BRANCH, + SUBPATTERN, ) from st2common.util.jinja import render_values @@ -30,11 +36,10 @@ from st2common import log __all__ = [ - 'ActionAliasFormatParser', - - 'extract_parameters_for_action_alias_db', - 'extract_parameters', - 'search_regex_tokens', + "ActionAliasFormatParser", + "extract_parameters_for_action_alias_db", + "extract_parameters", + "search_regex_tokens", ] @@ -48,10 +53,9 @@ class ActionAliasFormatParser(object): - def __init__(self, alias_format=None, param_stream=None): - self._format = alias_format or '' - self._original_param_stream = param_stream or '' + self._format = alias_format or "" + self._original_param_stream = param_stream or "" self._param_stream = self._original_param_stream self._snippets = self.generate_snippets() @@ -76,26 +80,26 @@ def generate_snippets(self): # Formats for keys and values: key is a non-spaced string, # value is anything in quotes or curly braces, or a single word. - snippets['key'] = r'\s*(\S+?)\s*' - snippets['value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)' + snippets["key"] = r"\s*(\S+?)\s*" + snippets["value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)' # Extended value: also matches unquoted text (caution). - snippets['ext_value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)' + snippets["ext_value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)' # Key-value pair: - snippets['pairs'] = r'(?:^|\s+){key}=({value})'.format(**snippets) + snippets["pairs"] = r"(?:^|\s+){key}=({value})".format(**snippets) # End of string: multiple space-separated key-value pairs: - snippets['ending'] = r'.*?(({pairs}\s*)*)$'.format(**snippets) + snippets["ending"] = r".*?(({pairs}\s*)*)$".format(**snippets) # Default value in optional parameters: - snippets['default'] = r'\s*=\s*(?:{ext_value})\s*'.format(**snippets) + snippets["default"] = r"\s*=\s*(?:{ext_value})\s*".format(**snippets) # Optional parameter (has a default value): - snippets['optional'] = '{{' + snippets['key'] + snippets['default'] + '}}' + snippets["optional"] = "{{" + snippets["key"] + snippets["default"] + "}}" # Required parameter (no default value): - snippets['required'] = '{{' + snippets['key'] + '}}' + snippets["required"] = "{{" + snippets["key"] + "}}" return snippets @@ -105,11 +109,13 @@ def match_kv_pairs_at_end(self): # 1. Matching the arbitrary key-value pairs at the end of the command # to support extra parameters (not specified in the format string), # and cutting them from the command string afterwards. - ending_pairs = re.match(self._snippets['ending'], param_stream, re.DOTALL) + ending_pairs = re.match(self._snippets["ending"], param_stream, re.DOTALL) has_ending_pairs = ending_pairs and ending_pairs.group(1) if has_ending_pairs: - kv_pairs = re.findall(self._snippets['pairs'], ending_pairs.group(1), re.DOTALL) - param_stream = param_stream.replace(ending_pairs.group(1), '') + kv_pairs = re.findall( + self._snippets["pairs"], ending_pairs.group(1), re.DOTALL + ) + param_stream = param_stream.replace(ending_pairs.group(1), "") else: kv_pairs = [] param_stream = " %s " % (param_stream) @@ -118,27 +124,36 @@ def match_kv_pairs_at_end(self): def generate_optional_params_regex(self): # 2. Matching optional parameters (with default values). - return re.findall(self._snippets['optional'], self._format, re.DOTALL) + return re.findall(self._snippets["optional"], self._format, re.DOTALL) def transform_format_string_into_regex(self): # 3. Convert the mangled format string into a regex object # Transforming our format string into a regular expression, # substituting {{ ... }} with regex named groups, so that param_stream # matched against this expression yields a dict of params with values. - param_match = r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?' - reg = re.sub(r'(\s*)' + self._snippets['optional'], r'(?:' + param_match + r')?', - self._format) - reg = re.sub(r'(\s*)' + self._snippets['required'], param_match, reg) + param_match = ( + r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?' + ) + reg = re.sub( + r"(\s*)" + self._snippets["optional"], + r"(?:" + param_match + r")?", + self._format, + ) + reg = re.sub(r"(\s*)" + self._snippets["required"], param_match, reg) reg_tokens = parse(reg, flags=re.DOTALL) # Add a beginning anchor if none exists - if not search_regex_tokens(((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens): - reg = r'^\s*' + reg + if not search_regex_tokens( + ((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens + ): + reg = r"^\s*" + reg # Add an ending anchor if none exists - if not search_regex_tokens(((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True): - reg = reg + r'\s*$' + if not search_regex_tokens( + ((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True + ): + reg = reg + r"\s*$" return re.compile(reg, re.DOTALL) @@ -147,8 +162,10 @@ def match_params_in_stream(self, matched_stream): if not matched_stream: # If no match is found we throw since this indicates provided user string (command) # didn't match the provided format string - raise ParseException('Command "%s" doesn\'t match format string "%s"' % - (self._original_param_stream, self._format)) + raise ParseException( + 'Command "%s" doesn\'t match format string "%s"' + % (self._original_param_stream, self._format) + ) # Compiling results from the steps 1-3. if matched_stream: @@ -157,16 +174,16 @@ def match_params_in_stream(self, matched_stream): # Apply optional parameters/add the default parameters for param in self._optional: matched_value = result[param[0]] if matched_stream else None - matched_result = matched_value or ''.join(param[1:]) + matched_result = matched_value or "".join(param[1:]) if matched_result is not None: result[param[0]] = matched_result # Apply given parameters for pair in self._kv_pairs: - result[pair[0]] = ''.join(pair[2:]) + result[pair[0]] = "".join(pair[2:]) if self._format and not (self._param_stream.strip() or any(result.values())): - raise ParseException('No value supplied and no default value found.') + raise ParseException("No value supplied and no default value found.") return result @@ -196,8 +213,9 @@ def get_multiple_extracted_param_value(self): return results -def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_stream, - match_multiple=False): +def extract_parameters_for_action_alias_db( + action_alias_db, format_str, param_stream, match_multiple=False +): """ Extract parameters from the user input based on the provided format string. @@ -208,13 +226,14 @@ def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_st formats = action_alias_db.get_format_strings() if format_str not in formats: - raise ValueError('Format string "%s" is not available on the alias "%s"' % - (format_str, action_alias_db.name)) + raise ValueError( + 'Format string "%s" is not available on the alias "%s"' + % (format_str, action_alias_db.name) + ) result = extract_parameters( - format_str=format_str, - param_stream=param_stream, - match_multiple=match_multiple) + format_str=format_str, param_stream=param_stream, match_multiple=match_multiple + ) return result @@ -226,7 +245,9 @@ def extract_parameters(format_str, param_stream, match_multiple=False): return parser.get_extracted_param_value() -def inject_immutable_parameters(action_alias_db, multiple_execution_parameters, action_context): +def inject_immutable_parameters( + action_alias_db, multiple_execution_parameters, action_context +): """ Inject immutable parameters from the alias definiton on the execution parameters. Jinja expressions will be resolved. @@ -235,26 +256,34 @@ def inject_immutable_parameters(action_alias_db, multiple_execution_parameters, if not immutable_parameters: return multiple_execution_parameters - user = action_context.get('user', None) + user = action_context.get("user", None) context = {} - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE), - kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( - scope=kv_constants.FULL_USER_SCOPE, user=user) + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ), + kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( + scope=kv_constants.FULL_USER_SCOPE, user=user + ), + } } - }) + ) context.update(action_context) rendered_params = render_values(immutable_parameters, context) for exec_params in multiple_execution_parameters: - overriden = [param for param in immutable_parameters.keys() if param in exec_params] + overriden = [ + param for param in immutable_parameters.keys() if param in exec_params + ] if overriden: raise ValueError( "Immutable arguments cannot be overriden: {}".format( - ','.join(overriden))) + ",".join(overriden) + ) + ) exec_params.update(rendered_params) diff --git a/st2common/st2common/models/utils/action_param_utils.py b/st2common/st2common/models/utils/action_param_utils.py index 1ecf6dbbe8..3edbeae6ed 100644 --- a/st2common/st2common/models/utils/action_param_utils.py +++ b/st2common/st2common/models/utils/action_param_utils.py @@ -33,7 +33,7 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): merged_meta = {} # ?? Runner immutable param's meta shouldn't be allowed to be modified by action whatsoever. - if runner_meta and runner_meta.get('immutable', False): + if runner_meta and runner_meta.get("immutable", False): merged_meta = runner_meta for key in all_keys: @@ -42,8 +42,10 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): elif key in runner_meta_keys and key not in action_meta_keys: merged_meta[key] = runner_meta[key] else: - if key in ['immutable']: - merged_meta[key] = runner_meta.get(key, False) or action_meta.get(key, False) + if key in ["immutable"]: + merged_meta[key] = runner_meta.get(key, False) or action_meta.get( + key, False + ) else: merged_meta[key] = action_meta.get(key) return merged_meta @@ -51,12 +53,12 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): def get_params_view(action_db=None, runner_db=None, merged_only=False): if runner_db: - runner_params = fast_deepcopy(getattr(runner_db, 'runner_parameters', {})) or {} + runner_params = fast_deepcopy(getattr(runner_db, "runner_parameters", {})) or {} else: runner_params = {} if action_db: - action_params = fast_deepcopy(getattr(action_db, 'parameters', {})) or {} + action_params = fast_deepcopy(getattr(action_db, "parameters", {})) or {} else: action_params = {} @@ -64,19 +66,22 @@ def get_params_view(action_db=None, runner_db=None, merged_only=False): merged_params = {} for param in parameters: - merged_params[param] = _merge_param_meta_values(action_meta=action_params.get(param), - runner_meta=runner_params.get(param)) + merged_params[param] = _merge_param_meta_values( + action_meta=action_params.get(param), runner_meta=runner_params.get(param) + ) if merged_only: return merged_params def is_required(param_meta): - return param_meta.get('required', False) + return param_meta.get("required", False) def is_immutable(param_meta): - return param_meta.get('immutable', False) + return param_meta.get("immutable", False) - immutable = {param for param in parameters if is_immutable(merged_params.get(param))} + immutable = { + param for param in parameters if is_immutable(merged_params.get(param)) + } required = {param for param in parameters if is_required(merged_params.get(param))} required = required - immutable optional = parameters - required - immutable @@ -89,8 +94,7 @@ def is_immutable(param_meta): def cast_params(action_ref, params, cast_overrides=None): - """ - """ + """""" params = params or {} action_db = action_db_util.get_action_by_ref(action_ref) @@ -98,7 +102,7 @@ def cast_params(action_ref, params, cast_overrides=None): raise ValueError('Action with ref "%s" doesn\'t exist' % (action_ref)) action_parameters_schema = action_db.parameters - runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type["name"]) runner_parameters_schema = runnertype_db.runner_parameters # combine into 1 list of parameter schemas parameters_schema = {} @@ -110,29 +114,37 @@ def cast_params(action_ref, params, cast_overrides=None): for k, v in six.iteritems(params): parameter_schema = parameters_schema.get(k, None) if not parameter_schema: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No schema.', k, v) + LOG.debug("Will skip cast of param[name: %s, value: %s]. No schema.", k, v) continue - parameter_type = parameter_schema.get('type', None) + parameter_type = parameter_schema.get("type", None) if not parameter_type: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No type.', k, v) + LOG.debug("Will skip cast of param[name: %s, value: %s]. No type.", k, v) continue # Pick up cast from teh override and then from the system suppied ones. cast = cast_overrides.get(parameter_type, None) if cast_overrides else None if not cast: cast = get_cast(cast_type=parameter_type) if not cast: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No cast for %s.', k, v, - parameter_type) + LOG.debug( + "Will skip cast of param[name: %s, value: %s]. No cast for %s.", + k, + v, + parameter_type, + ) continue - LOG.debug('Casting param: %s of type %s to type: %s', v, type(v), parameter_type) + LOG.debug( + "Casting param: %s of type %s to type: %s", v, type(v), parameter_type + ) try: params[k] = cast(v) except Exception as e: v_type = type(v).__name__ - msg = ('Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. ' - 'Perhaps the value is of an invalid type?' % - (v, v_type, k, parameter_type, six.text_type(e))) + msg = ( + 'Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. ' + "Perhaps the value is of an invalid type?" + % (v, v_type, k, parameter_type, six.text_type(e)) + ) raise ValueError(msg) return params @@ -145,8 +157,13 @@ def validate_action_parameters(action_ref, inputs): parameters = action_db_util.get_action_parameters_specs(action_ref) # Check required parameters that have no default defined. - required = set([param for param, meta in six.iteritems(parameters) - if meta.get('required', False) and 'default' not in meta]) + required = set( + [ + param + for param, meta in six.iteritems(parameters) + if meta.get("required", False) and "default" not in meta + ] + ) requires = sorted(required.difference(input_set)) diff --git a/st2common/st2common/models/utils/profiling.py b/st2common/st2common/models/utils/profiling.py index c9d26636b0..47add2adc3 100644 --- a/st2common/st2common/models/utils/profiling.py +++ b/st2common/st2common/models/utils/profiling.py @@ -23,10 +23,10 @@ from st2common import log as logging __all__ = [ - 'enable_profiling', - 'disable_profiling', - 'is_enabled', - 'log_query_and_profile_data_for_queryset' + "enable_profiling", + "disable_profiling", + "is_enabled", + "log_query_and_profile_data_for_queryset", ] LOG = logging.getLogger(__name__) @@ -72,13 +72,13 @@ def log_query_and_profile_data_for_queryset(queryset): # Note: Some mongoengine methods don't return queryset (e.g. count) return queryset - query = getattr(queryset, '_query', None) - mongo_query = getattr(queryset, '_mongo_query', query) - ordering = getattr(queryset, '_ordering', None) - limit = getattr(queryset, '_limit', None) - collection = getattr(queryset, '_collection', None) - collection_name = getattr(collection, 'name', None) - only_fields = getattr(queryset, 'only_fields', None) + query = getattr(queryset, "_query", None) + mongo_query = getattr(queryset, "_mongo_query", query) + ordering = getattr(queryset, "_ordering", None) + limit = getattr(queryset, "_limit", None) + collection = getattr(queryset, "_collection", None) + collection_name = getattr(collection, "name", None) + only_fields = getattr(queryset, "only_fields", None) # Note: We need to clone the queryset when using explain because explain advances the cursor # internally which changes the function result @@ -86,42 +86,46 @@ def log_query_and_profile_data_for_queryset(queryset): explain_info = cloned_queryset.explain(format=True) if mongo_query is not None and collection_name is not None: - mongo_shell_query = construct_mongo_shell_query(mongo_query=mongo_query, - collection_name=collection_name, - ordering=ordering, - limit=limit, - only_fields=only_fields) - extra = {'mongo_query': mongo_query, 'mongo_shell_query': mongo_shell_query} - LOG.debug('MongoDB query: %s' % (mongo_shell_query), extra=extra) - LOG.debug('MongoDB explain data: %s' % (explain_info)) + mongo_shell_query = construct_mongo_shell_query( + mongo_query=mongo_query, + collection_name=collection_name, + ordering=ordering, + limit=limit, + only_fields=only_fields, + ) + extra = {"mongo_query": mongo_query, "mongo_shell_query": mongo_shell_query} + LOG.debug("MongoDB query: %s" % (mongo_shell_query), extra=extra) + LOG.debug("MongoDB explain data: %s" % (explain_info)) return queryset -def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit, - only_fields=None): +def construct_mongo_shell_query( + mongo_query, collection_name, ordering, limit, only_fields=None +): result = [] # Select collection - part = 'db.{collection}'.format(collection=collection_name) + part = "db.{collection}".format(collection=collection_name) result.append(part) # Include filters (if any) if mongo_query: filter_predicate = mongo_query else: - filter_predicate = '' + filter_predicate = "" - part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate) + part = "find({filter_predicate})".format(filter_predicate=filter_predicate) # Include only fields (projection) if only_fields: - projection_items = ['\'%s\': 1' % (field) for field in only_fields] - projection = ', '.join(projection_items) - part = 'find({filter_predicate}, {{{projection}}})'.format( - filter_predicate=filter_predicate, projection=projection) + projection_items = ["'%s': 1" % (field) for field in only_fields] + projection = ", ".join(projection_items) + part = "find({filter_predicate}, {{{projection}}})".format( + filter_predicate=filter_predicate, projection=projection + ) else: - part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate) + part = "find({filter_predicate})".format(filter_predicate=filter_predicate) result.append(part) @@ -129,17 +133,18 @@ def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit, if ordering: sort_predicate = [] for field_name, direction in ordering: - sort_predicate.append('{name}: {direction}'.format(name=field_name, - direction=direction)) + sort_predicate.append( + "{name}: {direction}".format(name=field_name, direction=direction) + ) - sort_predicate = ', '.join(sort_predicate) - part = 'sort({{{sort_predicate}}})'.format(sort_predicate=sort_predicate) + sort_predicate = ", ".join(sort_predicate) + part = "sort({{{sort_predicate}}})".format(sort_predicate=sort_predicate) result.append(part) # Include limit info (if any) if limit is not None: - part = 'limit({limit})'.format(limit=limit) + part = "limit({limit})".format(limit=limit) result.append(part) - result = '.'.join(result) + ';' + result = ".".join(result) + ";" return result diff --git a/st2common/st2common/models/utils/sensor_type_utils.py b/st2common/st2common/models/utils/sensor_type_utils.py index f67a65e530..cd4b068db9 100644 --- a/st2common/st2common/models/utils/sensor_type_utils.py +++ b/st2common/st2common/models/utils/sensor_type_utils.py @@ -21,11 +21,7 @@ from st2common.models.db.sensor import SensorTypeDB from st2common.services import triggers as trigger_service -__all__ = [ - 'to_sensor_db_model', - 'get_sensor_entry_point', - 'create_trigger_types' -] +__all__ = ["to_sensor_db_model", "get_sensor_entry_point", "create_trigger_types"] def to_sensor_db_model(sensor_api_model=None): @@ -38,37 +34,40 @@ def to_sensor_db_model(sensor_api_model=None): :rtype: :class:`SensorTypeDB` """ - class_name = getattr(sensor_api_model, 'class_name', None) - pack = getattr(sensor_api_model, 'pack', None) + class_name = getattr(sensor_api_model, "class_name", None) + pack = getattr(sensor_api_model, "pack", None) entry_point = get_sensor_entry_point(sensor_api_model) - artifact_uri = getattr(sensor_api_model, 'artifact_uri', None) - description = getattr(sensor_api_model, 'description', None) - trigger_types = getattr(sensor_api_model, 'trigger_types', []) - poll_interval = getattr(sensor_api_model, 'poll_interval', None) - enabled = getattr(sensor_api_model, 'enabled', True) - metadata_file = getattr(sensor_api_model, 'metadata_file', None) - - poll_interval = getattr(sensor_api_model, 'poll_interval', None) + artifact_uri = getattr(sensor_api_model, "artifact_uri", None) + description = getattr(sensor_api_model, "description", None) + trigger_types = getattr(sensor_api_model, "trigger_types", []) + poll_interval = getattr(sensor_api_model, "poll_interval", None) + enabled = getattr(sensor_api_model, "enabled", True) + metadata_file = getattr(sensor_api_model, "metadata_file", None) + + poll_interval = getattr(sensor_api_model, "poll_interval", None) if poll_interval and (poll_interval < MINIMUM_POLL_INTERVAL): - raise ValueError('Minimum possible poll_interval is %s seconds' % - (MINIMUM_POLL_INTERVAL)) + raise ValueError( + "Minimum possible poll_interval is %s seconds" % (MINIMUM_POLL_INTERVAL) + ) # Add pack and metadata fileto each trigger type item for trigger_type in trigger_types: - trigger_type['pack'] = pack - trigger_type['metadata_file'] = metadata_file + trigger_type["pack"] = pack + trigger_type["metadata_file"] = metadata_file trigger_type_refs = create_trigger_types(trigger_types) - return _create_sensor_type(pack=pack, - name=class_name, - description=description, - artifact_uri=artifact_uri, - entry_point=entry_point, - trigger_types=trigger_type_refs, - poll_interval=poll_interval, - enabled=enabled, - metadata_file=metadata_file) + return _create_sensor_type( + pack=pack, + name=class_name, + description=description, + artifact_uri=artifact_uri, + entry_point=entry_point, + trigger_types=trigger_type_refs, + poll_interval=poll_interval, + enabled=enabled, + metadata_file=metadata_file, + ) def create_trigger_types(trigger_types, metadata_file=None): @@ -87,29 +86,44 @@ def create_trigger_types(trigger_types, metadata_file=None): return trigger_type_refs -def _create_sensor_type(pack=None, name=None, description=None, artifact_uri=None, - entry_point=None, trigger_types=None, poll_interval=10, - enabled=True, metadata_file=None): - - sensor_type = SensorTypeDB(pack=pack, name=name, description=description, - artifact_uri=artifact_uri, entry_point=entry_point, - poll_interval=poll_interval, enabled=enabled, - trigger_types=trigger_types, metadata_file=metadata_file) +def _create_sensor_type( + pack=None, + name=None, + description=None, + artifact_uri=None, + entry_point=None, + trigger_types=None, + poll_interval=10, + enabled=True, + metadata_file=None, +): + + sensor_type = SensorTypeDB( + pack=pack, + name=name, + description=description, + artifact_uri=artifact_uri, + entry_point=entry_point, + poll_interval=poll_interval, + enabled=enabled, + trigger_types=trigger_types, + metadata_file=metadata_file, + ) return sensor_type def get_sensor_entry_point(sensor_api_model): - file_path = getattr(sensor_api_model, 'artifact_uri', None) - class_name = getattr(sensor_api_model, 'class_name', None) - pack = getattr(sensor_api_model, 'pack', None) + file_path = getattr(sensor_api_model, "artifact_uri", None) + class_name = getattr(sensor_api_model, "class_name", None) + pack = getattr(sensor_api_model, "pack", None) if pack == SYSTEM_PACK_NAME: # Special case for sensors which come included with the default installation entry_point = class_name else: - module_path = file_path.split('/%s/' % (pack))[1] - module_path = module_path.replace(os.path.sep, '.') - module_path = module_path.replace('.py', '') - entry_point = '%s.%s' % (module_path, class_name) + module_path = file_path.split("/%s/" % (pack))[1] + module_path = module_path.replace(os.path.sep, ".") + module_path = module_path.replace(".py", "") + entry_point = "%s.%s" % (module_path, class_name) return entry_point diff --git a/st2common/st2common/operators.py b/st2common/st2common/operators.py index fc38d63215..6896e87658 100644 --- a/st2common/st2common/operators.py +++ b/st2common/st2common/operators.py @@ -24,10 +24,10 @@ from st2common.util.payload import PayloadLookup __all__ = [ - 'SEARCH', - 'get_operator', - 'get_allowed_operators', - 'UnrecognizedConditionError', + "SEARCH", + "get_operator", + "get_allowed_operators", + "UnrecognizedConditionError", ] @@ -40,7 +40,7 @@ def get_operator(op): if op in operators: return operators[op] else: - raise Exception('Invalid operator: ' + op) + raise Exception("Invalid operator: " + op) class UnrecognizedConditionError(Exception): @@ -106,35 +106,57 @@ def search(value, criteria_pattern, criteria_condition, check_function): type: "equals" pattern: "Approved" """ - if criteria_condition == 'any': + if criteria_condition == "any": # Any item of the list can match all patterns - rtn = any([ - # Any payload item can match - all([ - # Match all patterns - check_function( - child_criterion_k, child_criterion_v, - PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX)) - for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern) - ]) - for child_payload in value - ]) - elif criteria_condition == 'all': + rtn = any( + [ + # Any payload item can match + all( + [ + # Match all patterns + check_function( + child_criterion_k, + child_criterion_v, + PayloadLookup( + child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX + ), + ) + for child_criterion_k, child_criterion_v in six.iteritems( + criteria_pattern + ) + ] + ) + for child_payload in value + ] + ) + elif criteria_condition == "all": # Every item of the list must match all patterns - rtn = all([ - # All payload items must match - all([ - # Match all patterns - check_function( - child_criterion_k, child_criterion_v, - PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX)) - for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern) - ]) - for child_payload in value - ]) + rtn = all( + [ + # All payload items must match + all( + [ + # Match all patterns + check_function( + child_criterion_k, + child_criterion_v, + PayloadLookup( + child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX + ), + ) + for child_criterion_k, child_criterion_v in six.iteritems( + criteria_pattern + ) + ] + ) + for child_payload in value + ] + ) else: - raise UnrecognizedConditionError("The '%s' search condition is not recognized, only 'any' " - "and 'all' are allowed" % criteria_condition) + raise UnrecognizedConditionError( + "The '%s' search condition is not recognized, only 'any' " + "and 'all' are allowed" % criteria_condition + ) return rtn @@ -298,13 +320,17 @@ def _timediff(diff_target, period_seconds, operator): def timediff_lt(value, criteria_pattern): if criteria_pattern is None: return False - return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=less_than) + return _timediff( + diff_target=value, period_seconds=criteria_pattern, operator=less_than + ) def timediff_gt(value, criteria_pattern): if criteria_pattern is None: return False - return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=greater_than) + return _timediff( + diff_target=value, period_seconds=criteria_pattern, operator=greater_than + ) def exists(value, criteria_pattern): @@ -344,48 +370,48 @@ def ensure_operators_are_strings(value, criteria_pattern): :return: tuple(value, criteria_pattern) """ if isinstance(value, bytes): - value = value.decode('utf-8') + value = value.decode("utf-8") if isinstance(criteria_pattern, bytes): - criteria_pattern = criteria_pattern.decode('utf-8') + criteria_pattern = criteria_pattern.decode("utf-8") return value, criteria_pattern # operator match strings -MATCH_WILDCARD = 'matchwildcard' -MATCH_REGEX = 'matchregex' -REGEX = 'regex' -IREGEX = 'iregex' -EQUALS_SHORT = 'eq' -EQUALS_LONG = 'equals' -NEQUALS_LONG = 'nequals' -NEQUALS_SHORT = 'neq' -IEQUALS_SHORT = 'ieq' -IEQUALS_LONG = 'iequals' -CONTAINS_LONG = 'contains' -ICONTAINS_LONG = 'icontains' -NCONTAINS_LONG = 'ncontains' -INCONTAINS_LONG = 'incontains' -STARTSWITH_LONG = 'startswith' -ISTARTSWITH_LONG = 'istartswith' -ENDSWITH_LONG = 'endswith' -IENDSWITH_LONG = 'iendswith' -LESS_THAN_SHORT = 'lt' -LESS_THAN_LONG = 'lessthan' -GREATER_THAN_SHORT = 'gt' -GREATER_THAN_LONG = 'greaterthan' -TIMEDIFF_LT_SHORT = 'td_lt' -TIMEDIFF_LT_LONG = 'timediff_lt' -TIMEDIFF_GT_SHORT = 'td_gt' -TIMEDIFF_GT_LONG = 'timediff_gt' -KEY_EXISTS = 'exists' -KEY_NOT_EXISTS = 'nexists' -INSIDE_LONG = 'inside' -INSIDE_SHORT = 'in' -NINSIDE_LONG = 'ninside' -NINSIDE_SHORT = 'nin' -SEARCH = 'search' +MATCH_WILDCARD = "matchwildcard" +MATCH_REGEX = "matchregex" +REGEX = "regex" +IREGEX = "iregex" +EQUALS_SHORT = "eq" +EQUALS_LONG = "equals" +NEQUALS_LONG = "nequals" +NEQUALS_SHORT = "neq" +IEQUALS_SHORT = "ieq" +IEQUALS_LONG = "iequals" +CONTAINS_LONG = "contains" +ICONTAINS_LONG = "icontains" +NCONTAINS_LONG = "ncontains" +INCONTAINS_LONG = "incontains" +STARTSWITH_LONG = "startswith" +ISTARTSWITH_LONG = "istartswith" +ENDSWITH_LONG = "endswith" +IENDSWITH_LONG = "iendswith" +LESS_THAN_SHORT = "lt" +LESS_THAN_LONG = "lessthan" +GREATER_THAN_SHORT = "gt" +GREATER_THAN_LONG = "greaterthan" +TIMEDIFF_LT_SHORT = "td_lt" +TIMEDIFF_LT_LONG = "timediff_lt" +TIMEDIFF_GT_SHORT = "td_gt" +TIMEDIFF_GT_LONG = "timediff_gt" +KEY_EXISTS = "exists" +KEY_NOT_EXISTS = "nexists" +INSIDE_LONG = "inside" +INSIDE_SHORT = "in" +NINSIDE_LONG = "ninside" +NINSIDE_SHORT = "nin" +SEARCH = "search" # operator lookups operators = { diff --git a/st2common/st2common/persistence/action.py b/st2common/st2common/persistence/action.py index 0a91fc5cef..1f3d17ee01 100644 --- a/st2common/st2common/persistence/action.py +++ b/st2common/st2common/persistence/action.py @@ -23,12 +23,12 @@ from st2common.persistence.runner import RunnerType __all__ = [ - 'Action', - 'ActionAlias', - 'ActionExecution', - 'ActionExecutionState', - 'LiveAction', - 'RunnerType' + "Action", + "ActionAlias", + "ActionExecution", + "ActionExecutionState", + "LiveAction", + "RunnerType", ] diff --git a/st2common/st2common/persistence/auth.py b/st2common/st2common/persistence/auth.py index f03e3ab4e1..51f0a59ea1 100644 --- a/st2common/st2common/persistence/auth.py +++ b/st2common/st2common/persistence/auth.py @@ -14,9 +14,13 @@ # limitations under the License. from __future__ import absolute_import -from st2common.exceptions.auth import (TokenNotFoundError, ApiKeyNotFoundError, - UserNotFoundError, AmbiguousUserError, - NoNicknameOriginProvidedError) +from st2common.exceptions.auth import ( + TokenNotFoundError, + ApiKeyNotFoundError, + UserNotFoundError, + AmbiguousUserError, + NoNicknameOriginProvidedError, +) from st2common.models.db import MongoDBAccess from st2common.models.db.auth import UserDB, TokenDB, ApiKeyDB from st2common.persistence.base import Access @@ -35,7 +39,7 @@ def get_by_nickname(cls, nickname, origin): if not origin: raise NoNicknameOriginProvidedError() - result = cls.query(**{('nicknames__%s' % origin): nickname}) + result = cls.query(**{("nicknames__%s" % origin): nickname}) if not result.first(): raise UserNotFoundError() @@ -51,7 +55,7 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For User name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) @@ -64,13 +68,15 @@ def _get_impl(cls): @classmethod def add_or_update(cls, model_object, publish=True, validate=True): - if not getattr(model_object, 'user', None): - raise ValueError('User is not provided in the token.') - if not getattr(model_object, 'token', None): - raise ValueError('Token value is not set.') - if not getattr(model_object, 'expiry', None): - raise ValueError('Token expiry is not provided in the token.') - return super(Token, cls).add_or_update(model_object, publish=publish, validate=validate) + if not getattr(model_object, "user", None): + raise ValueError("User is not provided in the token.") + if not getattr(model_object, "token", None): + raise ValueError("Token value is not set.") + if not getattr(model_object, "expiry", None): + raise ValueError("Token expiry is not provided in the token.") + return super(Token, cls).add_or_update( + model_object, publish=publish, validate=validate + ) @classmethod def get(cls, value): @@ -96,7 +102,7 @@ def get(cls, value): result = cls.query(key_hash=value_hash).first() if not result: - raise ApiKeyNotFoundError('ApiKey with key_hash=%s not found.' % value_hash) + raise ApiKeyNotFoundError("ApiKey with key_hash=%s not found." % value_hash) return result @@ -109,4 +115,4 @@ def get_by_key_or_id(cls, value): try: return cls.get_by_id(value) except: - raise ApiKeyNotFoundError('ApiKey with key or id=%s not found.' % value) + raise ApiKeyNotFoundError("ApiKey with key or id=%s not found." % value) diff --git a/st2common/st2common/persistence/base.py b/st2common/st2common/persistence/base.py index ea1325762f..a477defe49 100644 --- a/st2common/st2common/persistence/base.py +++ b/st2common/st2common/persistence/base.py @@ -23,12 +23,7 @@ from st2common.models.system.common import ResourceReference -__all__ = [ - 'Access', - - 'ContentPackResource', - 'StatusBasedResource' -] +__all__ = ["Access", "ContentPackResource", "StatusBasedResource"] LOG = logging.getLogger(__name__) @@ -123,48 +118,60 @@ def aggregate(cls, *args, **kwargs): return cls._get_impl().aggregate(*args, **kwargs) @classmethod - def insert(cls, model_object, publish=True, dispatch_trigger=True, - log_not_unique_error_as_debug=False): + def insert( + cls, + model_object, + publish=True, + dispatch_trigger=True, + log_not_unique_error_as_debug=False, + ): # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used from mongoengine import NotUniqueError if model_object.id: - raise ValueError('id for object %s was unexpected.' % model_object) + raise ValueError("id for object %s was unexpected." % model_object) try: model_object = cls._get_impl().insert(model_object) except NotUniqueError as e: if log_not_unique_error_as_debug: - LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e)) + LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e)) else: - LOG.exception('Conflict while trying to save in DB.') + LOG.exception("Conflict while trying to save in DB.") # On a conflict determine the conflicting object and return its id in # the raised exception. conflict_object = cls._get_by_object(model_object) conflict_id = str(conflict_object.id) if conflict_object else None message = six.text_type(e) - raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id, - model_object=model_object) + raise StackStormDBObjectConflictError( + message=message, conflict_id=conflict_id, model_object=model_object + ) # Publish internal event on the message bus if publish: try: cls.publish_create(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_create_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @classmethod - def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True, - log_not_unique_error_as_debug=False): + def add_or_update( + cls, + model_object, + publish=True, + dispatch_trigger=True, + validate=True, + log_not_unique_error_as_debug=False, + ): # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used from mongoengine import NotUniqueError @@ -174,16 +181,17 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida model_object = cls._get_impl().add_or_update(model_object, validate=True) except NotUniqueError as e: if log_not_unique_error_as_debug: - LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e)) + LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e)) else: - LOG.exception('Conflict while trying to save in DB.') + LOG.exception("Conflict while trying to save in DB.") # On a conflict determine the conflicting object and return its id in # the raised exception. conflict_object = cls._get_by_object(model_object) conflict_id = str(conflict_object.id) if conflict_object else None message = six.text_type(e) - raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id, - model_object=model_object) + raise StackStormDBObjectConflictError( + message=message, conflict_id=conflict_id, model_object=model_object + ) is_update = str(pre_persist_id) == str(model_object.id) @@ -195,7 +203,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida else: cls.publish_create(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: @@ -205,7 +213,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida else: cls.dispatch_create_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @@ -227,14 +235,14 @@ def update(cls, model_object, publish=True, dispatch_trigger=True, **kwargs): try: cls.publish_update(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_update_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @@ -247,14 +255,14 @@ def delete(cls, model_object, publish=True, dispatch_trigger=True): try: cls.publish_delete(model_object) except Exception: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_delete_trigger(model_object) except Exception: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return persisted_object @@ -289,14 +297,18 @@ def dispatch_create_trigger(cls, model_object): """ Dispatch a resource-specific trigger which indicates a new resource has been created. """ - return cls._dispatch_operation_trigger(operation='create', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="create", model_object=model_object + ) @classmethod def dispatch_update_trigger(cls, model_object): """ Dispatch a resource-specific trigger which indicates an existing resource has been updated. """ - return cls._dispatch_operation_trigger(operation='update', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="update", model_object=model_object + ) @classmethod def dispatch_delete_trigger(cls, model_object): @@ -304,14 +316,18 @@ def dispatch_delete_trigger(cls, model_object): Dispatch a resource-specific trigger which indicates an existing resource has been deleted. """ - return cls._dispatch_operation_trigger(operation='delete', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="delete", model_object=model_object + ) @classmethod def _get_trigger_ref_for_operation(cls, operation): trigger_ref = cls.operation_to_trigger_ref_map.get(operation, None) if not trigger_ref: - raise ValueError('Trigger ref not specified for operation: %s' % (operation)) + raise ValueError( + "Trigger ref not specified for operation: %s" % (operation) + ) return trigger_ref @@ -322,11 +338,13 @@ def _dispatch_operation_trigger(cls, operation, model_object): trigger = cls._get_trigger_ref_for_operation(operation=operation) - object_payload = cls.api_model_cls.from_model(model_object, mask_secrets=True).__json__() - payload = { - 'object': object_payload - } - return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload) + object_payload = cls.api_model_cls.from_model( + model_object, mask_secrets=True + ).__json__() + payload = {"object": object_payload} + return cls._dispatch_trigger( + operation=operation, trigger=trigger, payload=payload + ) @classmethod def _dispatch_trigger(cls, operation, trigger, payload): @@ -338,23 +356,23 @@ def _dispatch_trigger(cls, operation, trigger, payload): class ContentPackResource(Access): - @classmethod def get_by_ref(cls, ref): if not ref: return None ref_obj = ResourceReference.from_string_reference(ref=ref) - result = cls.query(name=ref_obj.name, - pack=ref_obj.pack).first() + result = cls.query(name=ref_obj.name, pack=ref_obj.pack).first() return result @classmethod def _get_by_object(cls, object): # For an object with a resourcepack pack.name is unique. - name = getattr(object, 'name', '') - pack = getattr(object, 'pack', '') - return cls.get_by_ref(ResourceReference.to_string_reference(pack=pack, name=name)) + name = getattr(object, "name", "") + pack = getattr(object, "pack", "") + return cls.get_by_ref( + ResourceReference.to_string_reference(pack=pack, name=name) + ) class StatusBasedResource(Access): @@ -372,4 +390,4 @@ def publish_status(cls, model_object): """ publisher = cls._get_publisher() if publisher: - publisher.publish_state(model_object, getattr(model_object, 'status', None)) + publisher.publish_state(model_object, getattr(model_object, "status", None)) diff --git a/st2common/st2common/persistence/cleanup.py b/st2common/st2common/persistence/cleanup.py index 5831a47cca..06c48dec86 100644 --- a/st2common/st2common/persistence/cleanup.py +++ b/st2common/st2common/persistence/cleanup.py @@ -24,11 +24,7 @@ from st2common.script_setup import setup as common_setup from st2common.script_setup import teardown as common_teardown -__all__ = [ - 'db_cleanup', - 'db_cleanup_with_retry', - 'main' -] +__all__ = ["db_cleanup", "db_cleanup_with_retry", "main"] LOG = logging.getLogger(__name__) @@ -42,26 +38,47 @@ def db_cleanup(): return connection -def db_cleanup_with_retry(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, - ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): +def db_cleanup_with_retry( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): """ This method is a retry version of db_cleanup. """ - return db_func_with_retry(db_cleanup_func, - db_name, db_host, db_port, - username=username, password=password, - ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, - ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + return db_func_with_retry( + db_cleanup_func, + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) def setup(argv): - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) def teardown(): @@ -75,5 +92,5 @@ def main(argv): # This script registers actions and rules from content-packs. -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/st2common/st2common/persistence/db_init.py b/st2common/st2common/persistence/db_init.py index 04a2a3a753..678ca71ccd 100644 --- a/st2common/st2common/persistence/db_init.py +++ b/st2common/st2common/persistence/db_init.py @@ -22,9 +22,7 @@ from st2common import log as logging from st2common.models.db import db_setup -__all__ = [ - 'db_setup_with_retry' -] +__all__ = ["db_setup_with_retry"] LOG = logging.getLogger(__name__) @@ -36,9 +34,11 @@ def _retry_if_connection_error(error): # Ideally, a special execption or atleast some exception code. # If this does become an issue look for "Cannot connect to database" at the # start of error msg. - is_connection_error = isinstance(error, mongoengine.connection.MongoEngineConnectionError) + is_connection_error = isinstance( + error, mongoengine.connection.MongoEngineConnectionError + ) if is_connection_error: - LOG.warn('Retry on ConnectionError - %s', error) + LOG.warn("Retry on ConnectionError - %s", error) return is_connection_error @@ -52,25 +52,45 @@ def db_func_with_retry(db_func, *args, **kwargs): # reading of config values however this is lesser code. retrying_obj = retrying.Retrying( retry_on_exception=_retry_if_connection_error, - wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul * 1000, + wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul + * 1000, wait_exponential_max=cfg.CONF.database.connection_retry_backoff_max_s * 1000, - stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000 + stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000, ) return retrying_obj.call(db_func, *args, **kwargs) -def db_setup_with_retry(db_name, db_host, db_port, username=None, password=None, - ensure_indexes=True, ssl=False, ssl_keyfile=None, - ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): +def db_setup_with_retry( + db_name, + db_host, + db_port, + username=None, + password=None, + ensure_indexes=True, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): """ This method is a retry version of db_setup. """ - return db_func_with_retry(db_setup, db_name, db_host, db_port, - username=username, password=password, - ensure_indexes=ensure_indexes, - ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, - ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + return db_func_with_retry( + db_setup, + db_name, + db_host, + db_port, + username=username, + password=password, + ensure_indexes=ensure_indexes, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) diff --git a/st2common/st2common/persistence/execution.py b/st2common/st2common/persistence/execution.py index 6af949786d..2073dda17b 100644 --- a/st2common/st2common/persistence/execution.py +++ b/st2common/st2common/persistence/execution.py @@ -21,8 +21,8 @@ from st2common.persistence.base import Access __all__ = [ - 'ActionExecution', - 'ActionExecutionOutput', + "ActionExecution", + "ActionExecutionOutput", ] diff --git a/st2common/st2common/persistence/execution_queue.py b/st2common/st2common/persistence/execution_queue.py index 2ec5f05924..eaedc22f4c 100644 --- a/st2common/st2common/persistence/execution_queue.py +++ b/st2common/st2common/persistence/execution_queue.py @@ -18,9 +18,7 @@ from st2common.models.db.execution_queue import EXECUTION_QUEUE_ACCESS from st2common.persistence import base as persistence -__all__ = [ - 'ActionExecutionSchedulingQueue' -] +__all__ = ["ActionExecutionSchedulingQueue"] class ActionExecutionSchedulingQueue(persistence.Access): diff --git a/st2common/st2common/persistence/executionstate.py b/st2common/st2common/persistence/executionstate.py index 8e94a714aa..7e2debd138 100644 --- a/st2common/st2common/persistence/executionstate.py +++ b/st2common/st2common/persistence/executionstate.py @@ -19,9 +19,7 @@ from st2common.models.db.executionstate import actionexecstate_access from st2common.persistence import base as persistence -__all__ = [ - 'ActionExecutionState' -] +__all__ = ["ActionExecutionState"] class ActionExecutionState(persistence.Access): @@ -35,5 +33,7 @@ def _get_impl(cls): @classmethod def _get_publisher(cls): if not cls.publisher: - cls.publisher = transport.actionexecutionstate.ActionExecutionStatePublisher() + cls.publisher = ( + transport.actionexecutionstate.ActionExecutionStatePublisher() + ) return cls.publisher diff --git a/st2common/st2common/persistence/keyvalue.py b/st2common/st2common/persistence/keyvalue.py index 634bd72302..10676998f5 100644 --- a/st2common/st2common/persistence/keyvalue.py +++ b/st2common/st2common/persistence/keyvalue.py @@ -34,24 +34,30 @@ class KeyValuePair(Access): publisher = None api_model_cls = KeyValuePairAPI - dispatch_trigger_for_operations = ['create', 'update', 'value_change', 'delete'] + dispatch_trigger_for_operations = ["create", "update", "value_change", "delete"] operation_to_trigger_ref_map = { - 'create': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_CREATE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_CREATE_TRIGGER['pack']), - 'update': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_UPDATE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_UPDATE_TRIGGER['pack']), - 'value_change': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['pack']), - 'delete': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_DELETE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_DELETE_TRIGGER['pack']), + "create": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_CREATE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_CREATE_TRIGGER["pack"], + ), + "update": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_UPDATE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_UPDATE_TRIGGER["pack"], + ), + "value_change": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["pack"], + ), + "delete": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_DELETE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_DELETE_TRIGGER["pack"], + ), } @classmethod - def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True): + def add_or_update( + cls, model_object, publish=True, dispatch_trigger=True, validate=True + ): """ Note: We override add_or_update because we also want to publish high level "value_change" event for this resource. @@ -62,32 +68,36 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida # Not an update existing_model_object = None - model_object = super(KeyValuePair, cls).add_or_update(model_object=model_object, - publish=publish, - dispatch_trigger=dispatch_trigger) + model_object = super(KeyValuePair, cls).add_or_update( + model_object=model_object, + publish=publish, + dispatch_trigger=dispatch_trigger, + ) # Dispatch a value_change event which is specific to this resource if existing_model_object and existing_model_object.value != model_object.value: - cls.dispatch_value_change_trigger(old_model_object=existing_model_object, - new_model_object=model_object) + cls.dispatch_value_change_trigger( + old_model_object=existing_model_object, new_model_object=model_object + ) return model_object @classmethod def dispatch_value_change_trigger(cls, old_model_object, new_model_object): - operation = 'value_change' + operation = "value_change" trigger = cls._get_trigger_ref_for_operation(operation=operation) - old_object_payload = cls.api_model_cls.from_model(old_model_object, - mask_secrets=True).__json__() - new_object_payload = cls.api_model_cls.from_model(new_model_object, - mask_secrets=True).__json__() - payload = { - 'old_object': old_object_payload, - 'new_object': new_object_payload - } + old_object_payload = cls.api_model_cls.from_model( + old_model_object, mask_secrets=True + ).__json__() + new_object_payload = cls.api_model_cls.from_model( + new_model_object, mask_secrets=True + ).__json__() + payload = {"old_object": old_object_payload, "new_object": new_object_payload} - return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload) + return cls._dispatch_trigger( + operation=operation, trigger=trigger, payload=payload + ) @classmethod def get_by_names(cls, names): @@ -124,5 +134,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For KeyValuePair name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/liveaction.py b/st2common/st2common/persistence/liveaction.py index 61b16b1878..aa7551592a 100644 --- a/st2common/st2common/persistence/liveaction.py +++ b/st2common/st2common/persistence/liveaction.py @@ -19,9 +19,7 @@ from st2common.models.db.liveaction import liveaction_access from st2common.persistence import base as persistence -__all__ = [ - 'LiveAction' -] +__all__ = ["LiveAction"] class LiveAction(persistence.StatusBasedResource): diff --git a/st2common/st2common/persistence/marker.py b/st2common/st2common/persistence/marker.py index 1f35bbcdf2..6be08a25ec 100644 --- a/st2common/st2common/persistence/marker.py +++ b/st2common/st2common/persistence/marker.py @@ -19,9 +19,7 @@ from st2common.models.db.marker import DumperMarkerDB from st2common.persistence.base import Access -__all__ = [ - 'Marker' -] +__all__ = ["Marker"] class Marker(Access): diff --git a/st2common/st2common/persistence/pack.py b/st2common/st2common/persistence/pack.py index 5b2ff39102..01ca6b20cb 100644 --- a/st2common/st2common/persistence/pack.py +++ b/st2common/st2common/persistence/pack.py @@ -19,11 +19,7 @@ from st2common.models.db.pack import config_schema_access from st2common.models.db.pack import config_access -__all__ = [ - 'Pack', - 'ConfigSchema', - 'Config' -] +__all__ = ["Pack", "ConfigSchema", "Config"] class Pack(base.Access): diff --git a/st2common/st2common/persistence/policy.py b/st2common/st2common/persistence/policy.py index 468ce07f69..8b6700c194 100644 --- a/st2common/st2common/persistence/policy.py +++ b/st2common/st2common/persistence/policy.py @@ -30,16 +30,20 @@ def _get_impl(cls): def get_by_ref(cls, ref): if ref: ref_obj = PolicyTypeReference.from_string_reference(ref=ref) - result = cls.query(name=ref_obj.name, resource_type=ref_obj.resource_type).first() + result = cls.query( + name=ref_obj.name, resource_type=ref_obj.resource_type + ).first() return result else: return None @classmethod def _get_by_object(cls, object): - name = getattr(object, 'name', '') - resource_type = getattr(object, 'resource_type', '') - ref = PolicyTypeReference.to_string_reference(resource_type=resource_type, name=name) + name = getattr(object, "name", "") + resource_type = getattr(object, "resource_type", "") + ref = PolicyTypeReference.to_string_reference( + resource_type=resource_type, name=name + ) return cls.get_by_ref(ref) diff --git a/st2common/st2common/persistence/rbac.py b/st2common/st2common/persistence/rbac.py index bdac61d888..e14b973aeb 100644 --- a/st2common/st2common/persistence/rbac.py +++ b/st2common/st2common/persistence/rbac.py @@ -20,12 +20,7 @@ from st2common.models.db.rbac import permission_grant_access from st2common.models.db.rbac import group_to_role_mapping_access -__all__ = [ - 'Role', - 'UserRoleAssignment', - 'PermissionGrant', - 'GroupToRoleMapping' -] +__all__ = ["Role", "UserRoleAssignment", "PermissionGrant", "GroupToRoleMapping"] class Role(base.Access): diff --git a/st2common/st2common/persistence/reactor.py b/st2common/st2common/persistence/reactor.py index c060877513..0fa35c6bdf 100644 --- a/st2common/st2common/persistence/reactor.py +++ b/st2common/st2common/persistence/reactor.py @@ -16,12 +16,6 @@ from __future__ import absolute_import from st2common.persistence.rule import Rule from st2common.persistence.sensor import SensorType -from st2common.persistence.trigger import (Trigger, TriggerInstance, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerInstance, TriggerType -__all__ = [ - 'Rule', - 'SensorType', - 'Trigger', - 'TriggerInstance', - 'TriggerType' -] +__all__ = ["Rule", "SensorType", "Trigger", "TriggerInstance", "TriggerType"] diff --git a/st2common/st2common/persistence/rule.py b/st2common/st2common/persistence/rule.py index 741b9d4967..0a64e4bb1f 100644 --- a/st2common/st2common/persistence/rule.py +++ b/st2common/st2common/persistence/rule.py @@ -36,5 +36,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For RuleType name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/runner.py b/st2common/st2common/persistence/runner.py index 77440707f2..63cfa36d9e 100644 --- a/st2common/st2common/persistence/runner.py +++ b/st2common/st2common/persistence/runner.py @@ -28,5 +28,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For RunnerType name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/sensor.py b/st2common/st2common/persistence/sensor.py index 67367c7fc4..1a3a3679da 100644 --- a/st2common/st2common/persistence/sensor.py +++ b/st2common/st2common/persistence/sensor.py @@ -19,9 +19,7 @@ from st2common.models.db.sensor import sensor_type_access from st2common.persistence.base import ContentPackResource -__all__ = [ - 'SensorType' -] +__all__ = ["SensorType"] class SensorType(ContentPackResource): diff --git a/st2common/st2common/persistence/trace.py b/st2common/st2common/persistence/trace.py index 5e7276a1f0..ce5472f2aa 100644 --- a/st2common/st2common/persistence/trace.py +++ b/st2common/st2common/persistence/trace.py @@ -26,14 +26,16 @@ def _get_impl(cls): return cls.impl @classmethod - def push_components(cls, instance, action_executions=None, rules=None, trigger_instances=None): + def push_components( + cls, instance, action_executions=None, rules=None, trigger_instances=None + ): update_kwargs = {} if action_executions: - update_kwargs['push_all__action_executions'] = action_executions + update_kwargs["push_all__action_executions"] = action_executions if rules: - update_kwargs['push_all__rules'] = rules + update_kwargs["push_all__rules"] = rules if trigger_instances: - update_kwargs['push_all__trigger_instances'] = trigger_instances + update_kwargs["push_all__trigger_instances"] = trigger_instances if update_kwargs: return cls.update(instance, **update_kwargs) return instance diff --git a/st2common/st2common/persistence/trigger.py b/st2common/st2common/persistence/trigger.py index 1cdc4ef4ac..3567a15829 100644 --- a/st2common/st2common/persistence/trigger.py +++ b/st2common/st2common/persistence/trigger.py @@ -18,14 +18,14 @@ from st2common import log as logging from st2common import transport from st2common.exceptions.db import StackStormDBObjectNotFoundError -from st2common.models.db.trigger import triggertype_access, trigger_access, triggerinstance_access -from st2common.persistence.base import (Access, ContentPackResource) +from st2common.models.db.trigger import ( + triggertype_access, + trigger_access, + triggerinstance_access, +) +from st2common.persistence.base import Access, ContentPackResource -__all__ = [ - 'TriggerType', - 'Trigger', - 'TriggerInstance' -] +__all__ = ["TriggerType", "Trigger", "TriggerInstance"] LOG = logging.getLogger(__name__) @@ -57,7 +57,7 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru # Found in the innards of mongoengine. # e.g. {'pk': ObjectId('5609e91832ed356d04a93cc0')} delete_query = model_object._object_key - delete_query['ref_count__lte'] = 0 + delete_query["ref_count__lte"] = 0 cls._get_impl().delete_by_query(**delete_query) # Since delete_by_query cannot tell if teh delete actually happened check with a get call @@ -73,14 +73,14 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru try: cls.publish_delete(model_object) except Exception: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if confirmed_delete and dispatch_trigger: try: cls.dispatch_delete_trigger(model_object) except Exception: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object diff --git a/st2common/st2common/persistence/workflow.py b/st2common/st2common/persistence/workflow.py index aa02c320e1..8d993ef4fe 100644 --- a/st2common/st2common/persistence/workflow.py +++ b/st2common/st2common/persistence/workflow.py @@ -21,10 +21,7 @@ from st2common.persistence import base as persistence -__all__ = [ - 'WorkflowExecution', - 'TaskExecution' -] +__all__ = ["WorkflowExecution", "TaskExecution"] class WorkflowExecution(persistence.StatusBasedResource): diff --git a/st2common/st2common/policies/__init__.py b/st2common/st2common/policies/__init__.py index df49fa1f14..ef39e129c9 100644 --- a/st2common/st2common/policies/__init__.py +++ b/st2common/st2common/policies/__init__.py @@ -18,7 +18,4 @@ from st2common.policies.base import ResourcePolicyApplicator -__all__ = [ - 'get_driver', - 'ResourcePolicyApplicator' -] +__all__ = ["get_driver", "ResourcePolicyApplicator"] diff --git a/st2common/st2common/policies/base.py b/st2common/st2common/policies/base.py index 5bfc3fa58e..a22fa2fb42 100644 --- a/st2common/st2common/policies/base.py +++ b/st2common/st2common/policies/base.py @@ -24,10 +24,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'ResourcePolicyApplicator', - 'get_driver' -] +__all__ = ["ResourcePolicyApplicator", "get_driver"] @six.add_metaclass(abc.ABCMeta) @@ -72,9 +69,9 @@ def _get_lock_name(self, values): lock_uid = [] for key, value in six.iteritems(values): - lock_uid.append('%s=%s' % (key, value)) + lock_uid.append("%s=%s" % (key, value)) - lock_uid = ','.join(lock_uid) + lock_uid = ",".join(lock_uid) return lock_uid @@ -88,5 +85,7 @@ def get_driver(policy_ref, policy_type, **parameters): # interested in continue - if (issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith('Base')): + if issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith( + "Base" + ): return obj(policy_ref, policy_type, **parameters) diff --git a/st2common/st2common/policies/concurrency.py b/st2common/st2common/policies/concurrency.py index a453214b72..fcf96467c3 100644 --- a/st2common/st2common/policies/concurrency.py +++ b/st2common/st2common/policies/concurrency.py @@ -18,24 +18,23 @@ from st2common.policies import base from st2common.services import coordination -__all__ = [ - 'BaseConcurrencyApplicator' -] +__all__ = ["BaseConcurrencyApplicator"] class BaseConcurrencyApplicator(base.ResourcePolicyApplicator): - def __init__(self, policy_ref, policy_type, threshold=0, action='delay'): - super(BaseConcurrencyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type) + def __init__(self, policy_ref, policy_type, threshold=0, action="delay"): + super(BaseConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, policy_type=policy_type + ) self.threshold = threshold self.policy_action = action self.coordinator = coordination.get_coordinator(start_heart=True) def _get_status_for_policy_action(self, action): - if action == 'delay': + if action == "delay": status = action_constants.LIVEACTION_STATUS_DELAYED - elif action == 'cancel': + elif action == "cancel": status = action_constants.LIVEACTION_STATUS_CANCELING return status diff --git a/st2common/st2common/rbac/backends/__init__.py b/st2common/st2common/rbac/backends/__init__.py index cf6429c124..bb7ad3d58f 100644 --- a/st2common/st2common/rbac/backends/__init__.py +++ b/st2common/st2common/rbac/backends/__init__.py @@ -22,15 +22,11 @@ from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance', - 'get_rbac_backend' -] +__all__ = ["get_available_backends", "get_backend_instance", "get_rbac_backend"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2common.rbac.backend' +BACKENDS_NAMESPACE = "st2common.rbac.backend" # Cache which maps backed name -> backend class instance # NOTE: We use cache to avoid slow stevedore dynamic filesystem instrospection on every @@ -44,7 +40,9 @@ def get_available_backends(): def get_backend_instance(name, use_cache=True): if name not in BACKENDS_CACHE or not use_cache: - rbac_backend = driver_loader.get_backend_instance(namespace=BACKENDS_NAMESPACE, name=name) + rbac_backend = driver_loader.get_backend_instance( + namespace=BACKENDS_NAMESPACE, name=name + ) BACKENDS_CACHE[name] = rbac_backend rbac_backend = BACKENDS_CACHE[name] diff --git a/st2common/st2common/rbac/backends/base.py b/st2common/st2common/rbac/backends/base.py index 8e2c54c4fd..f9661d0b4b 100644 --- a/st2common/st2common/rbac/backends/base.py +++ b/st2common/st2common/rbac/backends/base.py @@ -23,17 +23,16 @@ from st2common.exceptions.rbac import AccessDeniedError __all__ = [ - 'BaseRBACBackend', - 'BaseRBACPermissionResolver', - 'BaseRBACService', - 'BaseRBACUtils', - 'BaseRBACRemoteGroupToRoleSyncer' + "BaseRBACBackend", + "BaseRBACPermissionResolver", + "BaseRBACService", + "BaseRBACUtils", + "BaseRBACRemoteGroupToRoleSyncer", ] @six.add_metaclass(abc.ABCMeta) class BaseRBACBackend(object): - def get_resolver_for_resource_type(self, resource_type): """ Method which returns PermissionResolver class for the provided resource type. @@ -67,7 +66,6 @@ def get_utils_class(self): @six.add_metaclass(abc.ABCMeta) class BaseRBACPermissionResolver(object): - def user_has_permission(self, user_db, permission_type): """ Method for checking user permissions which are not tied to a particular resource. @@ -177,7 +175,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api): raise NotImplementedError() @staticmethod - def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False): + def assert_user_is_admin_if_user_query_param_is_provided( + user_db, user, require_rbac=False + ): """ Function which asserts that the request user is administator if "user" query parameter is provided and doesn't match the current user. @@ -273,12 +273,12 @@ def get_user_db_from_request(request): """ Retrieve UserDB object from the provided request. """ - auth_context = request.context.get('auth', {}) + auth_context = request.context.get("auth", {}) if not auth_context: return None - user_db = auth_context.get('user', None) + user_db = auth_context.get("user", None) return user_db diff --git a/st2common/st2common/rbac/backends/noop.py b/st2common/st2common/rbac/backends/noop.py index 15ca5a3a75..4d3b8fb127 100644 --- a/st2common/st2common/rbac/backends/noop.py +++ b/st2common/st2common/rbac/backends/noop.py @@ -25,11 +25,11 @@ from st2common.exceptions.rbac import AccessDeniedError __all__ = [ - 'NoOpRBACBackend', - 'NoOpRBACPermissionResolver', - 'NoOpRBACService', - 'NoOpRBACUtils', - 'NoOpRBACRemoteGroupToRoleSyncer' + "NoOpRBACBackend", + "NoOpRBACPermissionResolver", + "NoOpRBACService", + "NoOpRBACUtils", + "NoOpRBACRemoteGroupToRoleSyncer", ] @@ -37,6 +37,7 @@ class NoOpRBACBackend(BaseRBACBackend): """ NoOp RBAC backend. """ + def get_resolver_for_resource_type(self, resource_type): return NoOpRBACPermissionResolver() @@ -79,7 +80,6 @@ def validate_roles_exists(role_names): class NoOpRBACUtils(BaseRBACUtils): - @staticmethod def assert_user_is_admin(user_db): """ @@ -141,7 +141,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api): return True @staticmethod - def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False): + def assert_user_is_admin_if_user_query_param_is_provided( + user_db, user, require_rbac=False + ): """ Function which asserts that the request user is administator if "user" query parameter is provided and doesn't match the current user. diff --git a/st2common/st2common/rbac/migrations.py b/st2common/st2common/rbac/migrations.py index 9e9fc9db18..951bbddf19 100644 --- a/st2common/st2common/rbac/migrations.py +++ b/st2common/st2common/rbac/migrations.py @@ -23,11 +23,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'run_all', - - 'insert_system_roles' -] +__all__ = ["run_all", "insert_system_roles"] def run_all(): @@ -40,7 +36,7 @@ def insert_system_roles(): """ system_roles = SystemRole.get_valid_values() - LOG.debug('Inserting system roles (%s)' % (str(system_roles))) + LOG.debug("Inserting system roles (%s)" % (str(system_roles))) for role_name in system_roles: description = role_name diff --git a/st2common/st2common/rbac/types.py b/st2common/st2common/rbac/types.py index 1c6b0ea352..cceb819d7b 100644 --- a/st2common/st2common/rbac/types.py +++ b/st2common/st2common/rbac/types.py @@ -21,19 +21,16 @@ from st2common.constants.types import ResourceType as SystemResourceType __all__ = [ - 'SystemRole', - 'PermissionType', - 'ResourceType', - - 'RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP', - 'PERMISION_TYPE_TO_DESCRIPTION_MAP', - - 'ALL_PERMISSION_TYPES', - 'GLOBAL_PERMISSION_TYPES', - 'GLOBAL_PACK_PERMISSION_TYPES', - 'LIST_PERMISSION_TYPES', - - 'get_resource_permission_types_with_descriptions' + "SystemRole", + "PermissionType", + "ResourceType", + "RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP", + "PERMISION_TYPE_TO_DESCRIPTION_MAP", + "ALL_PERMISSION_TYPES", + "GLOBAL_PERMISSION_TYPES", + "GLOBAL_PACK_PERMISSION_TYPES", + "LIST_PERMISSION_TYPES", + "get_resource_permission_types_with_descriptions", ] @@ -43,120 +40,120 @@ class PermissionType(Enum): """ # Note: There is no create endpoint for runner types right now - RUNNER_LIST = 'runner_type_list' - RUNNER_VIEW = 'runner_type_view' - RUNNER_MODIFY = 'runner_type_modify' - RUNNER_ALL = 'runner_type_all' + RUNNER_LIST = "runner_type_list" + RUNNER_VIEW = "runner_type_view" + RUNNER_MODIFY = "runner_type_modify" + RUNNER_ALL = "runner_type_all" - PACK_LIST = 'pack_list' - PACK_VIEW = 'pack_view' - PACK_CREATE = 'pack_create' - PACK_MODIFY = 'pack_modify' - PACK_DELETE = 'pack_delete' + PACK_LIST = "pack_list" + PACK_VIEW = "pack_view" + PACK_CREATE = "pack_create" + PACK_MODIFY = "pack_modify" + PACK_DELETE = "pack_delete" # Pack-management specific permissions # Note: Right now those permissions are global and apply to all the packs. # In the future we plan to support globs. - PACK_INSTALL = 'pack_install' - PACK_UNINSTALL = 'pack_uninstall' - PACK_REGISTER = 'pack_register' - PACK_CONFIG = 'pack_config' - PACK_SEARCH = 'pack_search' - PACK_VIEWS_INDEX_HEALTH = 'pack_views_index_health' + PACK_INSTALL = "pack_install" + PACK_UNINSTALL = "pack_uninstall" + PACK_REGISTER = "pack_register" + PACK_CONFIG = "pack_config" + PACK_SEARCH = "pack_search" + PACK_VIEWS_INDEX_HEALTH = "pack_views_index_health" - PACK_ALL = 'pack_all' + PACK_ALL = "pack_all" # Note: Right now we only have read endpoints + update for sensors types - SENSOR_LIST = 'sensor_type_list' - SENSOR_VIEW = 'sensor_type_view' - SENSOR_MODIFY = 'sensor_type_modify' - SENSOR_ALL = 'sensor_type_all' - - ACTION_LIST = 'action_list' - ACTION_VIEW = 'action_view' - ACTION_CREATE = 'action_create' - ACTION_MODIFY = 'action_modify' - ACTION_DELETE = 'action_delete' - ACTION_EXECUTE = 'action_execute' - ACTION_ALL = 'action_all' - - ACTION_ALIAS_LIST = 'action_alias_list' - ACTION_ALIAS_VIEW = 'action_alias_view' - ACTION_ALIAS_CREATE = 'action_alias_create' - ACTION_ALIAS_MODIFY = 'action_alias_modify' - ACTION_ALIAS_MATCH = 'action_alias_match' - ACTION_ALIAS_HELP = 'action_alias_help' - ACTION_ALIAS_DELETE = 'action_alias_delete' - ACTION_ALIAS_ALL = 'action_alias_all' + SENSOR_LIST = "sensor_type_list" + SENSOR_VIEW = "sensor_type_view" + SENSOR_MODIFY = "sensor_type_modify" + SENSOR_ALL = "sensor_type_all" + + ACTION_LIST = "action_list" + ACTION_VIEW = "action_view" + ACTION_CREATE = "action_create" + ACTION_MODIFY = "action_modify" + ACTION_DELETE = "action_delete" + ACTION_EXECUTE = "action_execute" + ACTION_ALL = "action_all" + + ACTION_ALIAS_LIST = "action_alias_list" + ACTION_ALIAS_VIEW = "action_alias_view" + ACTION_ALIAS_CREATE = "action_alias_create" + ACTION_ALIAS_MODIFY = "action_alias_modify" + ACTION_ALIAS_MATCH = "action_alias_match" + ACTION_ALIAS_HELP = "action_alias_help" + ACTION_ALIAS_DELETE = "action_alias_delete" + ACTION_ALIAS_ALL = "action_alias_all" # Note: Execution create is granted with "action_execute" - EXECUTION_LIST = 'execution_list' - EXECUTION_VIEW = 'execution_view' - EXECUTION_RE_RUN = 'execution_rerun' - EXECUTION_STOP = 'execution_stop' - EXECUTION_ALL = 'execution_all' - EXECUTION_VIEWS_FILTERS_LIST = 'execution_views_filters_list' - - RULE_LIST = 'rule_list' - RULE_VIEW = 'rule_view' - RULE_CREATE = 'rule_create' - RULE_MODIFY = 'rule_modify' - RULE_DELETE = 'rule_delete' - RULE_ALL = 'rule_all' - - RULE_ENFORCEMENT_LIST = 'rule_enforcement_list' - RULE_ENFORCEMENT_VIEW = 'rule_enforcement_view' + EXECUTION_LIST = "execution_list" + EXECUTION_VIEW = "execution_view" + EXECUTION_RE_RUN = "execution_rerun" + EXECUTION_STOP = "execution_stop" + EXECUTION_ALL = "execution_all" + EXECUTION_VIEWS_FILTERS_LIST = "execution_views_filters_list" + + RULE_LIST = "rule_list" + RULE_VIEW = "rule_view" + RULE_CREATE = "rule_create" + RULE_MODIFY = "rule_modify" + RULE_DELETE = "rule_delete" + RULE_ALL = "rule_all" + + RULE_ENFORCEMENT_LIST = "rule_enforcement_list" + RULE_ENFORCEMENT_VIEW = "rule_enforcement_view" # TODO - Maybe "datastore_item" / key_value_item ? - KEY_VALUE_VIEW = 'key_value_pair_view' - KEY_VALUE_SET = 'key_value_pair_set' - KEY_VALUE_DELETE = 'key_value_pair_delete' - - WEBHOOK_LIST = 'webhook_list' - WEBHOOK_VIEW = 'webhook_view' - WEBHOOK_CREATE = 'webhook_create' - WEBHOOK_SEND = 'webhook_send' - WEBHOOK_DELETE = 'webhook_delete' - WEBHOOK_ALL = 'webhook_all' - - TIMER_LIST = 'timer_list' - TIMER_VIEW = 'timer_view' - TIMER_ALL = 'timer_all' - - API_KEY_LIST = 'api_key_list' - API_KEY_VIEW = 'api_key_view' - API_KEY_CREATE = 'api_key_create' - API_KEY_MODIFY = 'api_key_modify' - API_KEY_DELETE = 'api_key_delete' - API_KEY_ALL = 'api_key_all' - - TRACE_LIST = 'trace_list' - TRACE_VIEW = 'trace_view' - TRACE_ALL = 'trace_all' + KEY_VALUE_VIEW = "key_value_pair_view" + KEY_VALUE_SET = "key_value_pair_set" + KEY_VALUE_DELETE = "key_value_pair_delete" + + WEBHOOK_LIST = "webhook_list" + WEBHOOK_VIEW = "webhook_view" + WEBHOOK_CREATE = "webhook_create" + WEBHOOK_SEND = "webhook_send" + WEBHOOK_DELETE = "webhook_delete" + WEBHOOK_ALL = "webhook_all" + + TIMER_LIST = "timer_list" + TIMER_VIEW = "timer_view" + TIMER_ALL = "timer_all" + + API_KEY_LIST = "api_key_list" + API_KEY_VIEW = "api_key_view" + API_KEY_CREATE = "api_key_create" + API_KEY_MODIFY = "api_key_modify" + API_KEY_DELETE = "api_key_delete" + API_KEY_ALL = "api_key_all" + + TRACE_LIST = "trace_list" + TRACE_VIEW = "trace_view" + TRACE_ALL = "trace_all" # Note: Trigger permissions types are also used for Timer API endpoint since timer is just # a special type of a trigger - TRIGGER_LIST = 'trigger_list' - TRIGGER_VIEW = 'trigger_view' - TRIGGER_ALL = 'trigger_all' + TRIGGER_LIST = "trigger_list" + TRIGGER_VIEW = "trigger_view" + TRIGGER_ALL = "trigger_all" - POLICY_TYPE_LIST = 'policy_type_list' - POLICY_TYPE_VIEW = 'policy_type_view' - POLICY_TYPE_ALL = 'policy_type_all' + POLICY_TYPE_LIST = "policy_type_list" + POLICY_TYPE_VIEW = "policy_type_view" + POLICY_TYPE_ALL = "policy_type_all" - POLICY_LIST = 'policy_list' - POLICY_VIEW = 'policy_view' - POLICY_CREATE = 'policy_create' - POLICY_MODIFY = 'policy_modify' - POLICY_DELETE = 'policy_delete' - POLICY_ALL = 'policy_all' + POLICY_LIST = "policy_list" + POLICY_VIEW = "policy_view" + POLICY_CREATE = "policy_create" + POLICY_MODIFY = "policy_modify" + POLICY_DELETE = "policy_delete" + POLICY_ALL = "policy_all" - STREAM_VIEW = 'stream_view' + STREAM_VIEW = "stream_view" - INQUIRY_LIST = 'inquiry_list' - INQUIRY_VIEW = 'inquiry_view' - INQUIRY_RESPOND = 'inquiry_respond' - INQUIRY_ALL = 'inquiry_all' + INQUIRY_LIST = "inquiry_list" + INQUIRY_VIEW = "inquiry_view" + INQUIRY_RESPOND = "inquiry_respond" + INQUIRY_ALL = "inquiry_all" @classmethod def get_valid_permissions_for_resource_type(cls, resource_type): @@ -183,10 +180,10 @@ def get_resource_type(cls, permission_type): elif permission_type == PermissionType.EXECUTION_VIEWS_FILTERS_LIST: return ResourceType.EXECUTION - split = permission_type.split('_') + split = permission_type.split("_") assert len(split) >= 2 - return '_'.join(split[:-1]) + return "_".join(split[:-1]) @classmethod def get_permission_name(cls, permission_type): @@ -195,12 +192,12 @@ def get_permission_name(cls, permission_type): :rtype: ``str`` """ - split = permission_type.split('_') + split = permission_type.split("_") assert len(split) >= 2 # Special case for PACK_VIEWS_INDEX_HEALTH if permission_type == PermissionType.PACK_VIEWS_INDEX_HEALTH: - split = permission_type.split('_', 1) + split = permission_type.split("_", 1) return split[1] return split[-1] @@ -224,14 +221,16 @@ def get_permission_type(cls, resource_type, permission_name): """ # Special case for sensor type (sensor_type -> sensor) if resource_type == ResourceType.SENSOR: - resource_type = 'sensor' + resource_type = "sensor" - permission_enum = '%s_%s' % (resource_type.upper(), permission_name.upper()) + permission_enum = "%s_%s" % (resource_type.upper(), permission_name.upper()) result = getattr(cls, permission_enum, None) if not result: - raise ValueError('Unsupported permission type for type "%s" and name "%s"' % - (resource_type, permission_name)) + raise ValueError( + 'Unsupported permission type for type "%s" and name "%s"' + % (resource_type, permission_name) + ) return result @@ -240,6 +239,7 @@ class ResourceType(Enum): """ Resource types on which permissions can be granted. """ + RUNNER = SystemResourceType.RUNNER_TYPE PACK = SystemResourceType.PACK @@ -266,9 +266,10 @@ class SystemRole(Enum): """ Default system roles which can't be manipulated (modified or removed). """ - SYSTEM_ADMIN = 'system_admin' # Special role which can't be revoked. - ADMIN = 'admin' - OBSERVER = 'observer' + + SYSTEM_ADMIN = "system_admin" # Special role which can't be revoked. + ADMIN = "admin" + OBSERVER = "observer" # Maps a list of available permission types for each resource @@ -292,35 +293,31 @@ class SystemRole(Enum): PermissionType.PACK_SEARCH, PermissionType.PACK_VIEWS_INDEX_HEALTH, PermissionType.PACK_ALL, - PermissionType.SENSOR_VIEW, PermissionType.SENSOR_MODIFY, PermissionType.SENSOR_ALL, - PermissionType.ACTION_VIEW, PermissionType.ACTION_CREATE, PermissionType.ACTION_MODIFY, PermissionType.ACTION_DELETE, PermissionType.ACTION_EXECUTE, PermissionType.ACTION_ALL, - PermissionType.ACTION_ALIAS_VIEW, PermissionType.ACTION_ALIAS_CREATE, PermissionType.ACTION_ALIAS_MODIFY, PermissionType.ACTION_ALIAS_DELETE, PermissionType.ACTION_ALIAS_ALL, - PermissionType.RULE_VIEW, PermissionType.RULE_CREATE, PermissionType.RULE_MODIFY, PermissionType.RULE_DELETE, - PermissionType.RULE_ALL + PermissionType.RULE_ALL, ], ResourceType.SENSOR: [ PermissionType.SENSOR_LIST, PermissionType.SENSOR_VIEW, PermissionType.SENSOR_MODIFY, - PermissionType.SENSOR_ALL + PermissionType.SENSOR_ALL, ], ResourceType.ACTION: [ PermissionType.ACTION_LIST, @@ -329,7 +326,7 @@ class SystemRole(Enum): PermissionType.ACTION_MODIFY, PermissionType.ACTION_DELETE, PermissionType.ACTION_EXECUTE, - PermissionType.ACTION_ALL + PermissionType.ACTION_ALL, ], ResourceType.ACTION_ALIAS: [ PermissionType.ACTION_ALIAS_LIST, @@ -339,7 +336,7 @@ class SystemRole(Enum): PermissionType.ACTION_ALIAS_MATCH, PermissionType.ACTION_ALIAS_HELP, PermissionType.ACTION_ALIAS_DELETE, - PermissionType.ACTION_ALIAS_ALL + PermissionType.ACTION_ALIAS_ALL, ], ResourceType.RULE: [ PermissionType.RULE_LIST, @@ -347,7 +344,7 @@ class SystemRole(Enum): PermissionType.RULE_CREATE, PermissionType.RULE_MODIFY, PermissionType.RULE_DELETE, - PermissionType.RULE_ALL + PermissionType.RULE_ALL, ], ResourceType.RULE_ENFORCEMENT: [ PermissionType.RULE_ENFORCEMENT_LIST, @@ -364,7 +361,7 @@ class SystemRole(Enum): ResourceType.KEY_VALUE_PAIR: [ PermissionType.KEY_VALUE_VIEW, PermissionType.KEY_VALUE_SET, - PermissionType.KEY_VALUE_DELETE + PermissionType.KEY_VALUE_DELETE, ], ResourceType.WEBHOOK: [ PermissionType.WEBHOOK_LIST, @@ -372,12 +369,12 @@ class SystemRole(Enum): PermissionType.WEBHOOK_CREATE, PermissionType.WEBHOOK_SEND, PermissionType.WEBHOOK_DELETE, - PermissionType.WEBHOOK_ALL + PermissionType.WEBHOOK_ALL, ], ResourceType.TIMER: [ PermissionType.TIMER_LIST, PermissionType.TIMER_VIEW, - PermissionType.TIMER_ALL + PermissionType.TIMER_ALL, ], ResourceType.API_KEY: [ PermissionType.API_KEY_LIST, @@ -385,17 +382,17 @@ class SystemRole(Enum): PermissionType.API_KEY_CREATE, PermissionType.API_KEY_MODIFY, PermissionType.API_KEY_DELETE, - PermissionType.API_KEY_ALL + PermissionType.API_KEY_ALL, ], ResourceType.TRACE: [ PermissionType.TRACE_LIST, PermissionType.TRACE_VIEW, - PermissionType.TRACE_ALL + PermissionType.TRACE_ALL, ], ResourceType.TRIGGER: [ PermissionType.TRIGGER_LIST, PermissionType.TRIGGER_VIEW, - PermissionType.TRIGGER_ALL + PermissionType.TRIGGER_ALL, ], ResourceType.POLICY_TYPE: [ PermissionType.POLICY_TYPE_LIST, @@ -415,13 +412,16 @@ class SystemRole(Enum): PermissionType.INQUIRY_VIEW, PermissionType.INQUIRY_RESPOND, PermissionType.INQUIRY_ALL, - ] + ], } ALL_PERMISSION_TYPES = list(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP.values()) ALL_PERMISSION_TYPES = list(itertools.chain(*ALL_PERMISSION_TYPES)) -LIST_PERMISSION_TYPES = [permission_type for permission_type in ALL_PERMISSION_TYPES if - permission_type.endswith('_list')] +LIST_PERMISSION_TYPES = [ + permission_type + for permission_type in ALL_PERMISSION_TYPES + if permission_type.endswith("_list") +] # List of global permissions (ones which don't apply to a specific resource) GLOBAL_PERMISSION_TYPES = [ @@ -433,169 +433,198 @@ class SystemRole(Enum): PermissionType.PACK_CONFIG, PermissionType.PACK_SEARCH, PermissionType.PACK_VIEWS_INDEX_HEALTH, - # Action alias global permission types PermissionType.ACTION_ALIAS_MATCH, PermissionType.ACTION_ALIAS_HELP, - # API key global permission types PermissionType.API_KEY_CREATE, - # Policy global permission types PermissionType.POLICY_CREATE, - # Execution PermissionType.EXECUTION_VIEWS_FILTERS_LIST, - # Stream PermissionType.STREAM_VIEW, - # Inquiry PermissionType.INQUIRY_LIST, PermissionType.INQUIRY_RESPOND, - PermissionType.INQUIRY_VIEW - + PermissionType.INQUIRY_VIEW, ] + LIST_PERMISSION_TYPES -GLOBAL_PACK_PERMISSION_TYPES = [permission_type for permission_type in GLOBAL_PERMISSION_TYPES if - permission_type.startswith('pack_')] +GLOBAL_PACK_PERMISSION_TYPES = [ + permission_type + for permission_type in GLOBAL_PERMISSION_TYPES + if permission_type.startswith("pack_") +] # Maps a permission type to the corresponding description PERMISION_TYPE_TO_DESCRIPTION_MAP = { - PermissionType.PACK_LIST: 'Ability to list (view all) packs.', - PermissionType.PACK_VIEW: 'Ability to view a pack.', - PermissionType.PACK_CREATE: 'Ability to create a new pack.', - PermissionType.PACK_MODIFY: 'Ability to modify (update) an existing pack.', - PermissionType.PACK_DELETE: 'Ability to delete an existing pack.', - PermissionType.PACK_INSTALL: 'Ability to install packs.', - PermissionType.PACK_UNINSTALL: 'Ability to uninstall packs.', - PermissionType.PACK_REGISTER: 'Ability to register packs and corresponding resources.', - PermissionType.PACK_CONFIG: 'Ability to configure a pack.', - PermissionType.PACK_SEARCH: 'Ability to query registry and search packs.', - PermissionType.PACK_VIEWS_INDEX_HEALTH: 'Ability to query health of pack registries.', - PermissionType.PACK_ALL: ('Ability to perform all the supported operations on a particular ' - 'pack.'), - - PermissionType.SENSOR_LIST: 'Ability to list (view all) sensors.', - PermissionType.SENSOR_VIEW: 'Ability to view a sensor', - PermissionType.SENSOR_MODIFY: ('Ability to modify (update) an existing sensor. Also implies ' - '"sensor_type_view" permission.'), - PermissionType.SENSOR_ALL: ('Ability to perform all the supported operations on a particular ' - 'sensor.'), - - PermissionType.ACTION_LIST: 'Ability to list (view all) actions.', - PermissionType.ACTION_VIEW: 'Ability to view an action.', - PermissionType.ACTION_CREATE: ('Ability to create a new action. Also implies "action_view" ' - 'permission.'), - PermissionType.ACTION_MODIFY: ('Ability to modify (update) an existing action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_DELETE: ('Ability to delete an existing action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_EXECUTE: ('Ability to execute (run) an action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_ALL: ('Ability to perform all the supported operations on a particular ' - 'action.'), - - PermissionType.ACTION_ALIAS_LIST: 'Ability to list (view all) action aliases.', - PermissionType.ACTION_ALIAS_VIEW: 'Ability to view an action alias.', - PermissionType.ACTION_ALIAS_CREATE: ('Ability to create a new action alias. Also implies' - ' "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_MODIFY: ('Ability to modify (update) an existing action alias. ' - 'Also implies "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_MATCH: ('Ability to use action alias match API endpoint.'), - PermissionType.ACTION_ALIAS_HELP: ('Ability to use action alias help API endpoint.'), - PermissionType.ACTION_ALIAS_DELETE: ('Ability to delete an existing action alias. Also ' - 'implies "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_ALL: ('Ability to perform all the supported operations on a ' - 'particular action alias.'), - - PermissionType.EXECUTION_LIST: 'Ability to list (view all) executions.', - PermissionType.EXECUTION_VIEW: 'Ability to view an execution.', - PermissionType.EXECUTION_RE_RUN: 'Ability to create a new action.', - PermissionType.EXECUTION_STOP: 'Ability to stop (cancel) a running execution.', - PermissionType.EXECUTION_ALL: ('Ability to perform all the supported operations on a ' - 'particular execution.'), - PermissionType.EXECUTION_VIEWS_FILTERS_LIST: ('Ability view all the distinct execution ' - 'filters.'), - - PermissionType.RULE_LIST: 'Ability to list (view all) rules.', - PermissionType.RULE_VIEW: 'Ability to view a rule.', - PermissionType.RULE_CREATE: ('Ability to create a new rule. Also implies "rule_view" ' - 'permission'), - PermissionType.RULE_MODIFY: ('Ability to modify (update) an existing rule. Also implies ' - '"rule_view" permission.'), - PermissionType.RULE_DELETE: ('Ability to delete an existing rule. Also implies "rule_view" ' - 'permission.'), - PermissionType.RULE_ALL: ('Ability to perform all the supported operations on a particular ' - 'rule.'), - - PermissionType.RULE_ENFORCEMENT_LIST: 'Ability to list (view all) rule enforcements.', - PermissionType.RULE_ENFORCEMENT_VIEW: 'Ability to view a rule enforcement.', - - PermissionType.RUNNER_LIST: 'Ability to list (view all) runners.', - PermissionType.RUNNER_VIEW: 'Ability to view a runner.', - PermissionType.RUNNER_MODIFY: ('Ability to modify (update) an existing runner. Also implies ' - '"runner_type_view" permission.'), - PermissionType.RUNNER_ALL: ('Ability to perform all the supported operations on a particular ' - 'runner.'), - - PermissionType.WEBHOOK_LIST: 'Ability to list (view all) webhooks.', - PermissionType.WEBHOOK_VIEW: ('Ability to view a webhook.'), - PermissionType.WEBHOOK_CREATE: ('Ability to create a new webhook.'), - PermissionType.WEBHOOK_SEND: ('Ability to send / POST data to an existing webhook.'), - PermissionType.WEBHOOK_DELETE: ('Ability to delete an existing webhook.'), - PermissionType.WEBHOOK_ALL: ('Ability to perform all the supported operations on a particular ' - 'webhook.'), - - PermissionType.TIMER_LIST: 'Ability to list (view all) timers.', - PermissionType.TIMER_VIEW: ('Ability to view a timer.'), - PermissionType.TIMER_ALL: ('Ability to perform all the supported operations on timers'), - - PermissionType.API_KEY_LIST: 'Ability to list (view all) API keys.', - PermissionType.API_KEY_VIEW: ('Ability to view an API Key.'), - PermissionType.API_KEY_CREATE: ('Ability to create a new API Key.'), - PermissionType.API_KEY_MODIFY: ('Ability to modify (update) an existing API key. Also implies ' - '"api_key_view" permission.'), - PermissionType.API_KEY_DELETE: ('Ability to delete an existing API Keys.'), - PermissionType.API_KEY_ALL: ('Ability to perform all the supported operations on an API Key.'), - - PermissionType.KEY_VALUE_VIEW: ('Ability to view Key-Value Pairs.'), - PermissionType.KEY_VALUE_SET: ('Ability to set a Key-Value Pair.'), - PermissionType.KEY_VALUE_DELETE: ('Ability to delete an existing Key-Value Pair.'), - - PermissionType.TRACE_LIST: ('Ability to list (view all) traces.'), - PermissionType.TRACE_VIEW: ('Ability to view a trace.'), - PermissionType.TRACE_ALL: ('Ability to perform all the supported operations on traces.'), - - PermissionType.TRIGGER_LIST: ('Ability to list (view all) triggers.'), - PermissionType.TRIGGER_VIEW: ('Ability to view a trigger.'), - PermissionType.TRIGGER_ALL: ('Ability to perform all the supported operations on triggers.'), - - PermissionType.POLICY_TYPE_LIST: ('Ability to list (view all) policy types.'), - PermissionType.POLICY_TYPE_VIEW: ('Ability to view a policy types.'), - PermissionType.POLICY_TYPE_ALL: ('Ability to perform all the supported operations on policy' - ' types.'), - - PermissionType.POLICY_LIST: 'Ability to list (view all) policies.', - PermissionType.POLICY_VIEW: ('Ability to view a policy.'), - PermissionType.POLICY_CREATE: ('Ability to create a new policy.'), - PermissionType.POLICY_MODIFY: ('Ability to modify an existing policy.'), - PermissionType.POLICY_DELETE: ('Ability to delete an existing policy.'), - PermissionType.POLICY_ALL: ('Ability to perform all the supported operations on a particular ' - 'policy.'), - - PermissionType.STREAM_VIEW: ('Ability to view / listen to the events on the stream API ' - 'endpoint.'), - - PermissionType.INQUIRY_LIST: 'Ability to list existing Inquiries', - PermissionType.INQUIRY_VIEW: 'Ability to view an existing Inquiry. Also implies ' - '"inquiry_respond" permission.', - PermissionType.INQUIRY_RESPOND: 'Ability to respond to an existing Inquiry (in general - user ' - 'still needs access per specific inquiry parameters). Also ' - 'implies "inquiry_view" permission.', - PermissionType.INQUIRY_ALL: ('Ability to perform all supported operations on a particular ' - 'Inquiry.') + PermissionType.PACK_LIST: "Ability to list (view all) packs.", + PermissionType.PACK_VIEW: "Ability to view a pack.", + PermissionType.PACK_CREATE: "Ability to create a new pack.", + PermissionType.PACK_MODIFY: "Ability to modify (update) an existing pack.", + PermissionType.PACK_DELETE: "Ability to delete an existing pack.", + PermissionType.PACK_INSTALL: "Ability to install packs.", + PermissionType.PACK_UNINSTALL: "Ability to uninstall packs.", + PermissionType.PACK_REGISTER: "Ability to register packs and corresponding resources.", + PermissionType.PACK_CONFIG: "Ability to configure a pack.", + PermissionType.PACK_SEARCH: "Ability to query registry and search packs.", + PermissionType.PACK_VIEWS_INDEX_HEALTH: "Ability to query health of pack registries.", + PermissionType.PACK_ALL: ( + "Ability to perform all the supported operations on a particular " "pack." + ), + PermissionType.SENSOR_LIST: "Ability to list (view all) sensors.", + PermissionType.SENSOR_VIEW: "Ability to view a sensor", + PermissionType.SENSOR_MODIFY: ( + "Ability to modify (update) an existing sensor. Also implies " + '"sensor_type_view" permission.' + ), + PermissionType.SENSOR_ALL: ( + "Ability to perform all the supported operations on a particular " "sensor." + ), + PermissionType.ACTION_LIST: "Ability to list (view all) actions.", + PermissionType.ACTION_VIEW: "Ability to view an action.", + PermissionType.ACTION_CREATE: ( + 'Ability to create a new action. Also implies "action_view" ' "permission." + ), + PermissionType.ACTION_MODIFY: ( + "Ability to modify (update) an existing action. Also implies " + '"action_view" permission.' + ), + PermissionType.ACTION_DELETE: ( + "Ability to delete an existing action. Also implies " + '"action_view" permission.' + ), + PermissionType.ACTION_EXECUTE: ( + "Ability to execute (run) an action. Also implies " '"action_view" permission.' + ), + PermissionType.ACTION_ALL: ( + "Ability to perform all the supported operations on a particular " "action." + ), + PermissionType.ACTION_ALIAS_LIST: "Ability to list (view all) action aliases.", + PermissionType.ACTION_ALIAS_VIEW: "Ability to view an action alias.", + PermissionType.ACTION_ALIAS_CREATE: ( + "Ability to create a new action alias. Also implies" + ' "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_MODIFY: ( + "Ability to modify (update) an existing action alias. " + 'Also implies "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_MATCH: ( + "Ability to use action alias match API endpoint." + ), + PermissionType.ACTION_ALIAS_HELP: ( + "Ability to use action alias help API endpoint." + ), + PermissionType.ACTION_ALIAS_DELETE: ( + "Ability to delete an existing action alias. Also " + 'implies "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_ALL: ( + "Ability to perform all the supported operations on a " + "particular action alias." + ), + PermissionType.EXECUTION_LIST: "Ability to list (view all) executions.", + PermissionType.EXECUTION_VIEW: "Ability to view an execution.", + PermissionType.EXECUTION_RE_RUN: "Ability to create a new action.", + PermissionType.EXECUTION_STOP: "Ability to stop (cancel) a running execution.", + PermissionType.EXECUTION_ALL: ( + "Ability to perform all the supported operations on a " "particular execution." + ), + PermissionType.EXECUTION_VIEWS_FILTERS_LIST: ( + "Ability view all the distinct execution " "filters." + ), + PermissionType.RULE_LIST: "Ability to list (view all) rules.", + PermissionType.RULE_VIEW: "Ability to view a rule.", + PermissionType.RULE_CREATE: ( + 'Ability to create a new rule. Also implies "rule_view" ' "permission" + ), + PermissionType.RULE_MODIFY: ( + "Ability to modify (update) an existing rule. Also implies " + '"rule_view" permission.' + ), + PermissionType.RULE_DELETE: ( + 'Ability to delete an existing rule. Also implies "rule_view" ' "permission." + ), + PermissionType.RULE_ALL: ( + "Ability to perform all the supported operations on a particular " "rule." + ), + PermissionType.RULE_ENFORCEMENT_LIST: "Ability to list (view all) rule enforcements.", + PermissionType.RULE_ENFORCEMENT_VIEW: "Ability to view a rule enforcement.", + PermissionType.RUNNER_LIST: "Ability to list (view all) runners.", + PermissionType.RUNNER_VIEW: "Ability to view a runner.", + PermissionType.RUNNER_MODIFY: ( + "Ability to modify (update) an existing runner. Also implies " + '"runner_type_view" permission.' + ), + PermissionType.RUNNER_ALL: ( + "Ability to perform all the supported operations on a particular " "runner." + ), + PermissionType.WEBHOOK_LIST: "Ability to list (view all) webhooks.", + PermissionType.WEBHOOK_VIEW: ("Ability to view a webhook."), + PermissionType.WEBHOOK_CREATE: ("Ability to create a new webhook."), + PermissionType.WEBHOOK_SEND: ( + "Ability to send / POST data to an existing webhook." + ), + PermissionType.WEBHOOK_DELETE: ("Ability to delete an existing webhook."), + PermissionType.WEBHOOK_ALL: ( + "Ability to perform all the supported operations on a particular " "webhook." + ), + PermissionType.TIMER_LIST: "Ability to list (view all) timers.", + PermissionType.TIMER_VIEW: ("Ability to view a timer."), + PermissionType.TIMER_ALL: ( + "Ability to perform all the supported operations on timers" + ), + PermissionType.API_KEY_LIST: "Ability to list (view all) API keys.", + PermissionType.API_KEY_VIEW: ("Ability to view an API Key."), + PermissionType.API_KEY_CREATE: ("Ability to create a new API Key."), + PermissionType.API_KEY_MODIFY: ( + "Ability to modify (update) an existing API key. Also implies " + '"api_key_view" permission.' + ), + PermissionType.API_KEY_DELETE: ("Ability to delete an existing API Keys."), + PermissionType.API_KEY_ALL: ( + "Ability to perform all the supported operations on an API Key." + ), + PermissionType.KEY_VALUE_VIEW: ("Ability to view Key-Value Pairs."), + PermissionType.KEY_VALUE_SET: ("Ability to set a Key-Value Pair."), + PermissionType.KEY_VALUE_DELETE: ("Ability to delete an existing Key-Value Pair."), + PermissionType.TRACE_LIST: ("Ability to list (view all) traces."), + PermissionType.TRACE_VIEW: ("Ability to view a trace."), + PermissionType.TRACE_ALL: ( + "Ability to perform all the supported operations on traces." + ), + PermissionType.TRIGGER_LIST: ("Ability to list (view all) triggers."), + PermissionType.TRIGGER_VIEW: ("Ability to view a trigger."), + PermissionType.TRIGGER_ALL: ( + "Ability to perform all the supported operations on triggers." + ), + PermissionType.POLICY_TYPE_LIST: ("Ability to list (view all) policy types."), + PermissionType.POLICY_TYPE_VIEW: ("Ability to view a policy types."), + PermissionType.POLICY_TYPE_ALL: ( + "Ability to perform all the supported operations on policy" " types." + ), + PermissionType.POLICY_LIST: "Ability to list (view all) policies.", + PermissionType.POLICY_VIEW: ("Ability to view a policy."), + PermissionType.POLICY_CREATE: ("Ability to create a new policy."), + PermissionType.POLICY_MODIFY: ("Ability to modify an existing policy."), + PermissionType.POLICY_DELETE: ("Ability to delete an existing policy."), + PermissionType.POLICY_ALL: ( + "Ability to perform all the supported operations on a particular " "policy." + ), + PermissionType.STREAM_VIEW: ( + "Ability to view / listen to the events on the stream API " "endpoint." + ), + PermissionType.INQUIRY_LIST: "Ability to list existing Inquiries", + PermissionType.INQUIRY_VIEW: "Ability to view an existing Inquiry. Also implies " + '"inquiry_respond" permission.', + PermissionType.INQUIRY_RESPOND: "Ability to respond to an existing Inquiry (in general - user " + "still needs access per specific inquiry parameters). Also " + 'implies "inquiry_view" permission.', + PermissionType.INQUIRY_ALL: ( + "Ability to perform all supported operations on a particular " "Inquiry." + ), } @@ -607,10 +636,13 @@ def get_resource_permission_types_with_descriptions(): """ result = {} - for resource_type, permission_types in six.iteritems(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP): + for resource_type, permission_types in six.iteritems( + RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP + ): result[resource_type] = {} for permission_type in permission_types: - result[resource_type][permission_type] = \ - PERMISION_TYPE_TO_DESCRIPTION_MAP[permission_type] + result[resource_type][permission_type] = PERMISION_TYPE_TO_DESCRIPTION_MAP[ + permission_type + ] return result diff --git a/st2common/st2common/router.py b/st2common/st2common/router.py index 29b34031b4..47ef009b98 100644 --- a/st2common/st2common/router.py +++ b/st2common/st2common/router.py @@ -43,15 +43,12 @@ from st2common.util.http import parse_content_type_header __all__ = [ - 'Router', - - 'Response', - - 'NotFoundException', - - 'abort', - 'abort_unauthorized', - 'exc' + "Router", + "Response", + "NotFoundException", + "abort", + "abort_unauthorized", + "exc", ] LOG = logging.getLogger(__name__) @@ -63,24 +60,24 @@ def op_resolver(op_id): :rtype: ``tuple`` """ - module_name, func_name = op_id.split(':', 1) - controller_name = func_name.split('.')[0] + module_name, func_name = op_id.split(":", 1) + controller_name = func_name.split(".")[0] __import__(module_name) module = sys.modules[module_name] controller_instance = getattr(module, controller_name) - method_callable = functools.reduce(getattr, func_name.split('.'), module) + method_callable = functools.reduce(getattr, func_name.split("."), module) return controller_instance, method_callable -def abort(status_code=exc.HTTPInternalServerError.code, message='Unhandled exception'): +def abort(status_code=exc.HTTPInternalServerError.code, message="Unhandled exception"): raise exc.status_map[status_code](message) def abort_unauthorized(msg=None): - raise exc.HTTPUnauthorized('Unauthorized - %s' % msg if msg else 'Unauthorized') + raise exc.HTTPUnauthorized("Unauthorized - %s" % msg if msg else "Unauthorized") def extend_with_default(validator_class): @@ -92,12 +89,16 @@ def set_defaults(validator, properties, instance, schema): instance.setdefault(property, subschema["default"]) for error in validate_properties( - validator, properties, instance, schema, + validator, + properties, + instance, + schema, ): yield error return jsonschema.validators.extend( - validator_class, {"properties": set_defaults}, + validator_class, + {"properties": set_defaults}, ) @@ -109,7 +110,8 @@ def set_additional_check(validator, properties, instance, schema): yield error return jsonschema.validators.extend( - validator_class, {"x-additional-check": set_additional_check}, + validator_class, + {"x-additional-check": set_additional_check}, ) @@ -126,7 +128,8 @@ def set_type_draft4(validator, types, instance, schema): yield error return jsonschema.validators.extend( - validator_class, {"type": set_type_draft4}, + validator_class, + {"type": set_type_draft4}, ) @@ -141,27 +144,40 @@ class NotFoundException(Exception): class Response(webob.Response): - def __init__(self, body=None, status=None, headerlist=None, app_iter=None, content_type=None, - *args, **kwargs): + def __init__( + self, + body=None, + status=None, + headerlist=None, + app_iter=None, + content_type=None, + *args, + **kwargs, + ): # Do some sanity checking, and turn json_body into an actual body - if app_iter is None and body is None and ('json_body' in kwargs or 'json' in kwargs): - if 'json_body' in kwargs: - json_body = kwargs.pop('json_body') + if ( + app_iter is None + and body is None + and ("json_body" in kwargs or "json" in kwargs) + ): + if "json_body" in kwargs: + json_body = kwargs.pop("json_body") else: - json_body = kwargs.pop('json') - body = json_encode(json_body).encode('UTF-8') + json_body = kwargs.pop("json") + body = json_encode(json_body).encode("UTF-8") if content_type is None: - content_type = 'application/json' + content_type = "application/json" - super(Response, self).__init__(body, status, headerlist, app_iter, content_type, - *args, **kwargs) + super(Response, self).__init__( + body, status, headerlist, app_iter, content_type, *args, **kwargs + ) def _json_body__get(self): return super(Response, self)._json_body__get() def _json_body__set(self, value): - self.body = json_encode(value).encode('UTF-8') + self.body = json_encode(value).encode("UTF-8") def _json_body__del(self): return super(Response, self)._json_body__del() @@ -182,44 +198,51 @@ def __init__(self, arguments=None, debug=False, auth=True, is_gunicorn=True): self.routes = routes.Mapper() def add_spec(self, spec, transforms): - info = spec.get('info', {}) - LOG.debug('Adding API: %s %s', info.get('title', 'untitled'), info.get('version', '0.0.0')) + info = spec.get("info", {}) + LOG.debug( + "Adding API: %s %s", + info.get("title", "untitled"), + info.get("version", "0.0.0"), + ) self.spec = spec - self.spec_resolver = jsonschema.RefResolver('', self.spec) + self.spec_resolver = jsonschema.RefResolver("", self.spec) validate(copy.deepcopy(self.spec)) for filter in transforms: - for (path, methods) in six.iteritems(spec['paths']): + for (path, methods) in six.iteritems(spec["paths"]): if not re.search(filter, path): continue for (method, endpoint) in six.iteritems(methods): - conditions = { - 'method': [method.upper()] - } + conditions = {"method": [method.upper()]} connect_kw = {} - if 'x-requirements' in endpoint: - connect_kw['requirements'] = endpoint['x-requirements'] + if "x-requirements" in endpoint: + connect_kw["requirements"] = endpoint["x-requirements"] - m = self.routes.submapper(_api_path=path, _api_method=method, - conditions=conditions) + m = self.routes.submapper( + _api_path=path, _api_method=method, conditions=conditions + ) for transform in transforms[filter]: m.connect(None, re.sub(filter, transform, path), **connect_kw) - module_name = endpoint['operationId'].split(':', 1)[0] + module_name = endpoint["operationId"].split(":", 1)[0] __import__(module_name) for route in sorted(self.routes.matchlist, key=lambda r: r.routepath): - LOG.debug('Route registered: %+6s %s', route.conditions['method'][0], route.routepath) + LOG.debug( + "Route registered: %+6s %s", + route.conditions["method"][0], + route.routepath, + ) def match(self, req): path = url_unquote(req.path) LOG.debug("Match path: %s", path) - if len(path) > 1 and path.endswith('/'): + if len(path) > 1 and path.endswith("/"): path = path[:-1] match = self.routes.match(path, req.environ) @@ -235,9 +258,9 @@ def match(self, req): path_vars = dict(path_vars) - path = path_vars.pop('_api_path') - method = path_vars.pop('_api_method') - endpoint = self.spec['paths'][path][method] + path = path_vars.pop("_api_path") + method = path_vars.pop("_api_method") + endpoint = self.spec["paths"][path][method] return endpoint, path_vars @@ -256,127 +279,140 @@ def __call__(self, req): LOG.debug("Parsed endpoint: %s", endpoint) LOG.debug("Parsed path_vars: %s", path_vars) - context = copy.copy(getattr(self, 'mock_context', {})) + context = copy.copy(getattr(self, "mock_context", {})) cookie_token = None # Handle security - if 'security' in endpoint: - security = endpoint.get('security') + if "security" in endpoint: + security = endpoint.get("security") else: - security = self.spec.get('security', []) + security = self.spec.get("security", []) if self.auth and security: try: - security_definitions = self.spec.get('securityDefinitions', {}) + security_definitions = self.spec.get("securityDefinitions", {}) for statement in security: declaration, options = statement.copy().popitem() definition = security_definitions[declaration] - if definition['type'] == 'apiKey': - if definition['in'] == 'header': - token = req.headers.get(definition['name']) - elif definition['in'] == 'query': - token = req.GET.get(definition['name']) - elif definition['in'] == 'cookie': - token = req.cookies.get(definition['name']) + if definition["type"] == "apiKey": + if definition["in"] == "header": + token = req.headers.get(definition["name"]) + elif definition["in"] == "query": + token = req.GET.get(definition["name"]) + elif definition["in"] == "cookie": + token = req.cookies.get(definition["name"]) else: token = None if token: - _, auth_func = op_resolver(definition['x-operationId']) + _, auth_func = op_resolver(definition["x-operationId"]) auth_resp = auth_func(token) # Include information on how user authenticated inside the context - if 'auth-token' in definition['name'].lower(): - auth_method = 'authentication token' - elif 'api-key' in definition['name'].lower(): - auth_method = 'API key' - - context['user'] = User.get_by_name(auth_resp.user) - context['auth_info'] = { - 'method': auth_method, - 'location': definition['in'] + if "auth-token" in definition["name"].lower(): + auth_method = "authentication token" + elif "api-key" in definition["name"].lower(): + auth_method = "API key" + + context["user"] = User.get_by_name(auth_resp.user) + context["auth_info"] = { + "method": auth_method, + "location": definition["in"], } # Also include token expiration time when authenticated via auth token - if 'auth-token' in definition['name'].lower(): - context['auth_info']['token_expire'] = auth_resp.expiry - - if 'x-set-cookie' in definition: - max_age = auth_resp.expiry - date_utils.get_datetime_utc_now() - cookie_token = cookies.make_cookie(definition['x-set-cookie'], - token, - max_age=max_age, - httponly=True) + if "auth-token" in definition["name"].lower(): + context["auth_info"]["token_expire"] = auth_resp.expiry + + if "x-set-cookie" in definition: + max_age = ( + auth_resp.expiry - date_utils.get_datetime_utc_now() + ) + cookie_token = cookies.make_cookie( + definition["x-set-cookie"], + token, + max_age=max_age, + httponly=True, + ) break - if 'user' not in context: - raise auth_exc.NoAuthSourceProvidedError('One of Token or API key required.') - except (auth_exc.NoAuthSourceProvidedError, - auth_exc.MultipleAuthSourcesError) as e: + if "user" not in context: + raise auth_exc.NoAuthSourceProvidedError( + "One of Token or API key required." + ) + except ( + auth_exc.NoAuthSourceProvidedError, + auth_exc.MultipleAuthSourcesError, + ) as e: LOG.error(six.text_type(e)) return abort_unauthorized(six.text_type(e)) except auth_exc.TokenNotProvidedError as e: - LOG.exception('Token is not provided.') + LOG.exception("Token is not provided.") return abort_unauthorized(six.text_type(e)) except auth_exc.TokenNotFoundError as e: - LOG.exception('Token is not found.') + LOG.exception("Token is not found.") return abort_unauthorized(six.text_type(e)) except auth_exc.TokenExpiredError as e: - LOG.exception('Token has expired.') + LOG.exception("Token has expired.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyNotProvidedError as e: - LOG.exception('API key is not provided.') + LOG.exception("API key is not provided.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyNotFoundError as e: - LOG.exception('API key is not found.') + LOG.exception("API key is not found.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyDisabledError as e: - LOG.exception('API key is disabled.') + LOG.exception("API key is disabled.") return abort_unauthorized(six.text_type(e)) if cfg.CONF.rbac.enable: - user_db = context['user'] + user_db = context["user"] - permission_type = endpoint.get('x-permissions', None) + permission_type = endpoint.get("x-permissions", None) if permission_type: rbac_backend = get_rbac_backend() - resolver = rbac_backend.get_resolver_for_permission_type(permission_type) - has_permission = resolver.user_has_permission(user_db, permission_type) + resolver = rbac_backend.get_resolver_for_permission_type( + permission_type + ) + has_permission = resolver.user_has_permission( + user_db, permission_type + ) if not has_permission: - raise rbac_exc.ResourceTypeAccessDeniedError(user_db, - permission_type) + raise rbac_exc.ResourceTypeAccessDeniedError( + user_db, permission_type + ) # Collect parameters kw = {} - for param in endpoint.get('parameters', []) + endpoint.get('x-parameters', []): - name = param['name'] - argument_name = param.get('x-as', None) or name - source = param['in'] - default = param.get('default', None) + for param in endpoint.get("parameters", []) + endpoint.get("x-parameters", []): + name = param["name"] + argument_name = param.get("x-as", None) or name + source = param["in"] + default = param.get("default", None) # Collecting params from different sources - if source == 'query': + if source == "query": kw[argument_name] = req.GET.get(name, default) - elif source == 'path': + elif source == "path": kw[argument_name] = path_vars[name] - elif source == 'header': + elif source == "header": kw[argument_name] = req.headers.get(name, default) - elif source == 'formData': + elif source == "formData": kw[argument_name] = req.POST.get(name, default) - elif source == 'environ': + elif source == "environ": kw[argument_name] = req.environ.get(name.upper(), default) - elif source == 'context': + elif source == "context": kw[argument_name] = context.get(name, default) - elif source == 'request': + elif source == "request": kw[argument_name] = getattr(req, name) - elif source == 'body': - content_type = req.headers.get('Content-Type', 'application/json') + elif source == "body": + content_type = req.headers.get("Content-Type", "application/json") content_type = parse_content_type_header(content_type=content_type)[0] - schema = param['schema'] + schema = param["schema"] # NOTE: HACK: Workaround for eventlet wsgi server which sets Content-Type to # text/plain if Content-Type is not provided in the request. @@ -384,65 +420,76 @@ def __call__(self, req): # expect application/json so we explicitly set it to that # if not provided (set to text/plain by the base http server) and if it's not # /v1/workflows/inspection API endpoints. - if not self.is_gunicorn and content_type == 'text/plain': - operation_id = endpoint['operationId'] + if not self.is_gunicorn and content_type == "text/plain": + operation_id = endpoint["operationId"] - if ('workflow_inspection_controller' not in operation_id): - content_type = 'application/json' + if "workflow_inspection_controller" not in operation_id: + content_type = "application/json" # Note: We also want to perform validation if no body is explicitly provided - in a # lot of POST, PUT scenarios, body is mandatory - if not req.body and content_type == 'application/json': - req.body = b'{}' + if not req.body and content_type == "application/json": + req.body = b"{}" try: - if content_type == 'application/json': + if content_type == "application/json": data = req.json - elif content_type == 'text/plain': + elif content_type == "text/plain": data = req.body - elif content_type in ['application/x-www-form-urlencoded', - 'multipart/form-data']: + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: data = urlparse.parse_qs(req.body) else: - raise ValueError('Unsupported Content-Type: "%s"' % (content_type)) + raise ValueError( + 'Unsupported Content-Type: "%s"' % (content_type) + ) except Exception as e: - detail = 'Failed to parse request body: %s' % six.text_type(e) + detail = "Failed to parse request body: %s" % six.text_type(e) raise exc.HTTPBadRequest(detail=detail) # Special case for Python 3 - if six.PY3 and content_type == 'text/plain' and isinstance(data, six.binary_type): + if ( + six.PY3 + and content_type == "text/plain" + and isinstance(data, six.binary_type) + ): # Convert bytes to text type (string / unicode) - data = data.decode('utf-8') + data = data.decode("utf-8") try: CustomValidator(schema, resolver=self.spec_resolver).validate(data) except (jsonschema.ValidationError, ValueError) as e: - raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)), - comment=traceback.format_exc()) + raise exc.HTTPBadRequest( + detail=getattr(e, "message", six.text_type(e)), + comment=traceback.format_exc(), + ) - if content_type == 'text/plain': + if content_type == "text/plain": kw[argument_name] = data else: + class Body(object): def __init__(self, **entries): self.__dict__.update(entries) - ref = schema.get('$ref', None) + ref = schema.get("$ref", None) if ref: with self.spec_resolver.resolving(ref) as resolved: schema = resolved - if 'x-api-model' in schema: - input_type = schema.get('type', []) - _, Model = op_resolver(schema['x-api-model']) + if "x-api-model" in schema: + input_type = schema.get("type", []) + _, Model = op_resolver(schema["x-api-model"]) if input_type and not isinstance(input_type, (list, tuple)): input_type = [input_type] # root attribute is not an object, we need to use wrapper attribute to # make it work with **kwarg expansion - if input_type and 'array' in input_type: - data = {'data': data} + if input_type and "array" in input_type: + data = {"data": data} instance = self._get_model_instance(model_cls=Model, data=data) @@ -451,143 +498,178 @@ def __init__(self, **entries): try: instance = instance.validate() except (jsonschema.ValidationError, ValueError) as e: - raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)), - comment=traceback.format_exc()) + raise exc.HTTPBadRequest( + detail=getattr(e, "message", six.text_type(e)), + comment=traceback.format_exc(), + ) else: - LOG.debug('Missing x-api-model definition for %s, using generic Body ' - 'model.' % (endpoint['operationId'])) + LOG.debug( + "Missing x-api-model definition for %s, using generic Body " + "model." % (endpoint["operationId"]) + ) model = Body instance = self._get_model_instance(model_cls=model, data=data) kw[argument_name] = instance # Making sure all required params are present - required = param.get('required', False) + required = param.get("required", False) if required and kw[argument_name] is None: detail = 'Required parameter "%s" is missing' % name raise exc.HTTPBadRequest(detail=detail) # Validating and casting param types - param_type = param.get('type', None) + param_type = param.get("type", None) if kw[argument_name] is not None: - if param_type == 'boolean': - positive = ('true', '1', 'yes', 'y') - negative = ('false', '0', 'no', 'n') + if param_type == "boolean": + positive = ("true", "1", "yes", "y") + negative = ("false", "0", "no", "n") if str(kw[argument_name]).lower() not in positive + negative: detail = 'Parameter "%s" is not of type boolean' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = str(kw[argument_name]).lower() in positive - elif param_type == 'integer': - regex = r'^-?[0-9]+$' + elif param_type == "integer": + regex = r"^-?[0-9]+$" if not re.search(regex, str(kw[argument_name])): detail = 'Parameter "%s" is not of type integer' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = int(kw[argument_name]) - elif param_type == 'number': - regex = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$' + elif param_type == "number": + regex = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$" if not re.search(regex, str(kw[argument_name])): detail = 'Parameter "%s" is not of type float' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = float(kw[argument_name]) - elif param_type == 'array' and param.get('items', {}).get('type', None) == 'string': + elif ( + param_type == "array" + and param.get("items", {}).get("type", None) == "string" + ): if kw[argument_name] is None: kw[argument_name] = [] elif isinstance(kw[argument_name], (list, tuple)): # argument is already an array pass else: - kw[argument_name] = kw[argument_name].split(',') + kw[argument_name] = kw[argument_name].split(",") # Call the controller try: - controller_instance, func = op_resolver(endpoint['operationId']) + controller_instance, func = op_resolver(endpoint["operationId"]) except Exception as e: - LOG.exception('Failed to load controller for operation "%s": %s' % - (endpoint['operationId'], six.text_type(e))) + LOG.exception( + 'Failed to load controller for operation "%s": %s' + % (endpoint["operationId"], six.text_type(e)) + ) raise e try: resp = func(**kw) except DataStoreKeyNotFoundError as e: - LOG.warning('Failed to call controller function "%s" for operation "%s": %s' % - (func.__name__, endpoint['operationId'], six.text_type(e))) + LOG.warning( + 'Failed to call controller function "%s" for operation "%s": %s' + % (func.__name__, endpoint["operationId"], six.text_type(e)) + ) raise e except Exception as e: - LOG.exception('Failed to call controller function "%s" for operation "%s": %s' % - (func.__name__, endpoint['operationId'], six.text_type(e))) + LOG.exception( + 'Failed to call controller function "%s" for operation "%s": %s' + % (func.__name__, endpoint["operationId"], six.text_type(e)) + ) raise e # Handle response if resp is None: resp = Response() - if not hasattr(resp, '__call__'): + if not hasattr(resp, "__call__"): resp = Response(json=resp) - operation_id = endpoint['operationId'] + operation_id = endpoint["operationId"] # Process the response removing attributes based on the exclude_attribute and # include_attributes query param filter values (if specified) - include_attributes = kw.get('include_attributes', None) - exclude_attributes = kw.get('exclude_attributes', None) - has_include_or_exclude_attributes = bool(include_attributes) or bool(exclude_attributes) + include_attributes = kw.get("include_attributes", None) + exclude_attributes = kw.get("exclude_attributes", None) + has_include_or_exclude_attributes = bool(include_attributes) or bool( + exclude_attributes + ) # NOTE: We do NOT want to process stream controller response - is_streamming_controller = endpoint.get('x-is-streaming-endpoint', - bool('st2stream' in operation_id)) - - if not is_streamming_controller and resp.body and has_include_or_exclude_attributes: + is_streamming_controller = endpoint.get( + "x-is-streaming-endpoint", bool("st2stream" in operation_id) + ) + + if ( + not is_streamming_controller + and resp.body + and has_include_or_exclude_attributes + ): # NOTE: We need to check for response.body attribute since resp.json throws if JSON # response is not available - mandatory_include_fields = getattr(controller_instance, - 'mandatory_include_fields_response', []) - data = self._process_response(data=resp.json, - mandatory_include_fields=mandatory_include_fields, - include_attributes=include_attributes, - exclude_attributes=exclude_attributes) + mandatory_include_fields = getattr( + controller_instance, "mandatory_include_fields_response", [] + ) + data = self._process_response( + data=resp.json, + mandatory_include_fields=mandatory_include_fields, + include_attributes=include_attributes, + exclude_attributes=exclude_attributes, + ) resp.json = data - responses = endpoint.get('responses', {}) + responses = endpoint.get("responses", {}) response_spec = responses.get(str(resp.status_code), None) - default_response_spec = responses.get('default', None) + default_response_spec = responses.get("default", None) if not response_spec and default_response_spec: - LOG.debug('No custom response spec found for endpoint "%s", using a default one' % - (endpoint['operationId'])) - response_spec_name = 'default' + LOG.debug( + 'No custom response spec found for endpoint "%s", using a default one' + % (endpoint["operationId"]) + ) + response_spec_name = "default" else: response_spec_name = str(resp.status_code) response_spec = response_spec or default_response_spec - if response_spec and 'schema' in response_spec and not has_include_or_exclude_attributes: + if ( + response_spec + and "schema" in response_spec + and not has_include_or_exclude_attributes + ): # NOTE: We don't perform response validation when include or exclude attributes are # provided because this means partial response which likely won't pass the validation - LOG.debug('Using response spec "%s" for endpoint %s and status code %s' % - (response_spec_name, endpoint['operationId'], resp.status_code)) + LOG.debug( + 'Using response spec "%s" for endpoint %s and status code %s' + % (response_spec_name, endpoint["operationId"], resp.status_code) + ) try: - validator = CustomValidator(response_spec['schema'], resolver=self.spec_resolver) + validator = CustomValidator( + response_spec["schema"], resolver=self.spec_resolver + ) - response_type = response_spec['schema'].get('type', 'json') - if response_type == 'string': + response_type = response_spec["schema"].get("type", "json") + if response_type == "string": validator.validate(resp.text) else: validator.validate(resp.json) except (jsonschema.ValidationError, ValueError): - LOG.exception('Response validation failed.') - resp.headers.add('Warning', '199 OpenAPI "Response validation failed"') + LOG.exception("Response validation failed.") + resp.headers.add("Warning", '199 OpenAPI "Response validation failed"') else: - LOG.debug('No response spec found for endpoint "%s"' % (endpoint['operationId'])) + LOG.debug( + 'No response spec found for endpoint "%s"' % (endpoint["operationId"]) + ) if cookie_token: - resp.headerlist.append(('Set-Cookie', cookie_token)) + resp.headerlist.append(("Set-Cookie", cookie_token)) return resp @@ -604,17 +686,24 @@ def _get_model_instance(self, model_cls, data): instance = model_cls(**data) except TypeError as e: # Throw a more user-friendly exception when input data is not an object - if 'type object argument after ** must be a mapping, not' in six.text_type(e): + if "type object argument after ** must be a mapping, not" in six.text_type( + e + ): type_string = get_json_type_for_python_value(data) - msg = ('Input body needs to be an object, got: %s' % (type_string)) + msg = "Input body needs to be an object, got: %s" % (type_string) raise ValueError(msg) raise e return instance - def _process_response(self, data, mandatory_include_fields=None, include_attributes=None, - exclude_attributes=None): + def _process_response( + self, + data, + mandatory_include_fields=None, + include_attributes=None, + exclude_attributes=None, + ): """ Process controller response data such as removing attributes based on the values of exclude_attributes and include_attributes query param filters and similar. @@ -628,8 +717,10 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu # NOTE: include_attributes and exclude_attributes are mutually exclusive if include_attributes and exclude_attributes: - msg = ('exclude_attributes and include_attributes arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_attributes and include_attributes arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) # Common case - filters are not provided @@ -637,16 +728,20 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu return data # Skip processing of error responses - if isinstance(data, dict) and data.get('faultstring', None): + if isinstance(data, dict) and data.get("faultstring", None): return data # We only care about the first part of the field name since deep filtering happens inside # MongoDB. Deep filtering here would also be quite expensive and waste of CPU cycles. - cleaned_include_attributes = [attribute.split('.')[0] for attribute in include_attributes] + cleaned_include_attributes = [ + attribute.split(".")[0] for attribute in include_attributes + ] # Add in mandatory fields which always need to be present in the response (primary keys) cleaned_include_attributes += mandatory_include_fields - cleaned_exclude_attributes = [attribute.split('.')[0] for attribute in exclude_attributes] + cleaned_exclude_attributes = [ + attribute.split(".")[0] for attribute in exclude_attributes + ] # NOTE: Since those parameters are mutually exclusive we could perform more efficient # filtering when just exclude_attributes is provided. Instead of creating a new dict, we @@ -675,6 +770,6 @@ def process_item(item): # get_one response result = process_item(data) else: - raise ValueError('Unsupported type: %s' % (type(data))) + raise ValueError("Unsupported type: %s" % (type(data))) return result diff --git a/st2common/st2common/runners/__init__.py b/st2common/st2common/runners/__init__.py index bcccaaf48d..d6468f78e2 100644 --- a/st2common/st2common/runners/__init__.py +++ b/st2common/st2common/runners/__init__.py @@ -19,14 +19,9 @@ from st2common.util import driver_loader -__all__ = [ - 'BACKENDS_NAMESPACE', +__all__ = ["BACKENDS_NAMESPACE", "get_available_backends", "get_backend_driver"] - 'get_available_backends', - 'get_backend_driver' -] - -BACKENDS_NAMESPACE = 'st2common.runners.runner' +BACKENDS_NAMESPACE = "st2common.runners.runner" def get_available_backends(): diff --git a/st2common/st2common/runners/base.py b/st2common/st2common/runners/base.py index 6a9656b9a1..a5692b66e6 100644 --- a/st2common/st2common/runners/base.py +++ b/st2common/st2common/runners/base.py @@ -42,45 +42,43 @@ subprocess = concurrency.get_subprocess_module() __all__ = [ - 'ActionRunner', - 'AsyncActionRunner', - 'PollingAsyncActionRunner', - 'GitWorktreeActionRunner', - 'PollingAsyncActionRunner', - 'ShellRunnerMixin', - - 'get_runner_module', - - 'get_runner', - 'get_metadata', + "ActionRunner", + "AsyncActionRunner", + "PollingAsyncActionRunner", + "GitWorktreeActionRunner", + "PollingAsyncActionRunner", + "ShellRunnerMixin", + "get_runner_module", + "get_runner", + "get_metadata", ] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters -RUNNER_COMMAND = 'cmd' -RUNNER_CONTENT_VERSION = 'content_version' -RUNNER_DEBUG = 'debug' +RUNNER_COMMAND = "cmd" +RUNNER_CONTENT_VERSION = "content_version" +RUNNER_DEBUG = "debug" def get_runner(name, config=None): """ Load the module and return an instance of the runner. """ - LOG.debug('Runner loading Python module: %s', name) + LOG.debug("Runner loading Python module: %s", name) module = get_runner_module(name=name) - LOG.debug('Instance of runner module: %s', module) + LOG.debug("Instance of runner module: %s", module) if config: - runner_kwargs = {'config': config} + runner_kwargs = {"config": config} else: runner_kwargs = {} runner = module.get_runner(**runner_kwargs) - LOG.debug('Instance of runner: %s', runner) + LOG.debug("Instance of runner: %s", runner) return runner @@ -95,19 +93,21 @@ def get_runner_module(name): try: module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False) except NoMatches: - name = name.replace('_', '-') + name = name.replace("_", "-") try: module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False) except Exception as e: available_runners = get_available_plugins(namespace=RUNNERS_NAMESPACE) - available_runners = ', '.join(available_runners) - msg = ('Failed to find runner %s. Make sure that the runner is available and installed ' - 'in StackStorm virtual environment. Available runners are: %s' % - (name, available_runners)) + available_runners = ", ".join(available_runners) + msg = ( + "Failed to find runner %s. Make sure that the runner is available and installed " + "in StackStorm virtual environment. Available runners are: %s" + % (name, available_runners) + ) LOG.exception(msg) - raise exc.ActionRunnerCreateError('%s\n\n%s' % (msg, six.text_type(e))) + raise exc.ActionRunnerCreateError("%s\n\n%s" % (msg, six.text_type(e))) return module @@ -120,9 +120,9 @@ def get_metadata(package_name): """ import pkg_resources - file_path = pkg_resources.resource_filename(package_name, 'runner.yaml') + file_path = pkg_resources.resource_filename(package_name, "runner.yaml") - with open(file_path, 'r') as fp: + with open(file_path, "r") as fp: content = fp.read() metadata = yaml.safe_load(content) @@ -158,14 +158,14 @@ def __init__(self, runner_id): def pre_run(self): # Handle runner "enabled" attribute - runner_enabled = getattr(self.runner_type, 'enabled', True) - runner_name = getattr(self.runner_type, 'name', 'unknown') + runner_enabled = getattr(self.runner_type, "enabled", True) + runner_name = getattr(self.runner_type, "name", "unknown") if not runner_enabled: msg = 'Runner "%s" has been disabled by the administrator.' % runner_name raise ValueError(msg) - runner_parameters = getattr(self, 'runner_parameters', {}) or {} + runner_parameters = getattr(self, "runner_parameters", {}) or {} self._debug = runner_parameters.get(RUNNER_DEBUG, False) # Run will need to take an action argument @@ -175,18 +175,20 @@ def run(self, action_parameters): raise NotImplementedError() def pause(self): - runner_name = getattr(self.runner_type, 'name', 'unknown') - raise NotImplementedError('Pause is not supported for runner %s.' % runner_name) + runner_name = getattr(self.runner_type, "name", "unknown") + raise NotImplementedError("Pause is not supported for runner %s." % runner_name) def resume(self): - runner_name = getattr(self.runner_type, 'name', 'unknown') - raise NotImplementedError('Resume is not supported for runner %s.' % runner_name) + runner_name = getattr(self.runner_type, "name", "unknown") + raise NotImplementedError( + "Resume is not supported for runner %s." % runner_name + ) def cancel(self): return ( action_constants.LIVEACTION_STATUS_CANCELED, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def post_run(self, status, result): @@ -213,8 +215,8 @@ def get_user(self): :rtype: ``str`` """ - context = getattr(self, 'context', {}) or {} - user = context.get('user', cfg.CONF.system_user.user) + context = getattr(self, "context", {}) or {} + user = context.get("user", cfg.CONF.system_user.user) return user @@ -228,18 +230,18 @@ def _get_common_action_env_variables(self): :rtype: ``dict`` """ result = {} - result['ST2_ACTION_PACK_NAME'] = self.get_pack_ref() - result['ST2_ACTION_EXECUTION_ID'] = str(self.execution_id) - result['ST2_ACTION_API_URL'] = get_full_public_api_url() + result["ST2_ACTION_PACK_NAME"] = self.get_pack_ref() + result["ST2_ACTION_EXECUTION_ID"] = str(self.execution_id) + result["ST2_ACTION_API_URL"] = get_full_public_api_url() if self.auth_token: - result['ST2_ACTION_AUTH_TOKEN'] = self.auth_token.token + result["ST2_ACTION_AUTH_TOKEN"] = self.auth_token.token return result def __str__(self): - attrs = ', '.join(['%s=%s' % (k, v) for k, v in six.iteritems(self.__dict__)]) - return '%s@%s(%s)' % (self.__class__.__name__, str(id(self)), attrs) + attrs = ", ".join(["%s=%s" % (k, v) for k, v in six.iteritems(self.__dict__)]) + return "%s@%s(%s)" % (self.__class__.__name__, str(id(self)), attrs) @six.add_metaclass(abc.ABCMeta) @@ -248,7 +250,6 @@ class AsyncActionRunner(ActionRunner): class PollingAsyncActionRunner(AsyncActionRunner): - @classmethod def is_polling_enabled(cls): return True @@ -264,7 +265,7 @@ class GitWorktreeActionRunner(ActionRunner): This revision is specified using "content_version" runner parameter. """ - WORKTREE_DIRECTORY_PREFIX = 'st2-git-worktree-' + WORKTREE_DIRECTORY_PREFIX = "st2-git-worktree-" def __init__(self, runner_id): super(GitWorktreeActionRunner, self).__init__(runner_id=runner_id) @@ -284,11 +285,13 @@ def pre_run(self): # Override entry_point so it points to git worktree directory pack_name = self.get_pack_name() - entry_point = self._get_entry_point_for_worktree_path(pack_name=pack_name, - entry_point=self.entry_point, - worktree_path=self.git_worktree_path) + entry_point = self._get_entry_point_for_worktree_path( + pack_name=pack_name, + entry_point=self.entry_point, + worktree_path=self.git_worktree_path, + ) - assert(entry_point.startswith(self.git_worktree_path)) + assert entry_point.startswith(self.git_worktree_path) self.entry_point = entry_point @@ -298,9 +301,11 @@ def post_run(self, status, result): # Remove git worktree directories (if used and available) if self.git_worktree_path and self.git_worktree_revision: pack_name = self.get_pack_name() - self.cleanup_git_worktree(worktree_path=self.git_worktree_path, - content_version=self.git_worktree_revision, - pack_name=pack_name) + self.cleanup_git_worktree( + worktree_path=self.git_worktree_path, + content_version=self.git_worktree_revision, + pack_name=pack_name, + ) def create_git_worktree(self, content_version): """ @@ -318,51 +323,59 @@ def create_git_worktree(self, content_version): self.git_worktree_path = worktree_path extra = { - 'pack_name': pack_name, - 'pack_directory': pack_directory, - 'content_version': content_version, - 'worktree_path': worktree_path + "pack_name": pack_name, + "pack_directory": pack_directory, + "content_version": content_version, + "worktree_path": worktree_path, } if not os.path.isdir(pack_directory): - msg = ('Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t ' - 'exist.' % (pack_name, pack_directory)) + msg = ( + 'Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t ' + "exist." % (pack_name, pack_directory) + ) raise ValueError(msg) args = [ - 'git', - '-C', + "git", + "-C", pack_directory, - 'worktree', - 'add', + "worktree", + "add", worktree_path, - content_version + content_version, ] cmd = list2cmdline(args) - LOG.debug('Creating git worktree for pack "%s", content version "%s" and execution ' - 'id "%s" in "%s"' % (pack_name, content_version, self.execution_id, - worktree_path), extra=extra) - LOG.debug('Command: %s' % (cmd)) - exit_code, stdout, stderr, timed_out = run_command(cmd=cmd, - cwd=pack_directory, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True) + LOG.debug( + 'Creating git worktree for pack "%s", content version "%s" and execution ' + 'id "%s" in "%s"' + % (pack_name, content_version, self.execution_id, worktree_path), + extra=extra, + ) + LOG.debug("Command: %s" % (cmd)) + exit_code, stdout, stderr, timed_out = run_command( + cmd=cmd, + cwd=pack_directory, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) if exit_code != 0: - self._handle_git_worktree_error(pack_name=pack_name, pack_directory=pack_directory, - content_version=content_version, - exit_code=exit_code, stdout=stdout, stderr=stderr) + self._handle_git_worktree_error( + pack_name=pack_name, + pack_directory=pack_directory, + content_version=content_version, + exit_code=exit_code, + stdout=stdout, + stderr=stderr, + ) else: LOG.debug('Git worktree created in "%s"' % (worktree_path), extra=extra) # Make sure system / action runner user can access that directory - args = [ - 'chmod', - '777', - worktree_path - ] + args = ["chmod", "777", worktree_path] cmd = list2cmdline(args) run_command(cmd=cmd, shell=True) @@ -375,15 +388,19 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version): :rtype: ``bool`` """ # Safety check to make sure we don't remove something outside /tmp - assert(worktree_path.startswith('/tmp')) - assert(worktree_path.startswith('/tmp/%s' % (self.WORKTREE_DIRECTORY_PREFIX))) + assert worktree_path.startswith("/tmp") + assert worktree_path.startswith("/tmp/%s" % (self.WORKTREE_DIRECTORY_PREFIX)) if self._debug: - LOG.debug('Not removing git worktree "%s" because debug mode is enabled' % - (worktree_path)) + LOG.debug( + 'Not removing git worktree "%s" because debug mode is enabled' + % (worktree_path) + ) else: - LOG.debug('Removing git worktree "%s" for pack "%s" and content version "%s"' % - (worktree_path, pack_name, content_version)) + LOG.debug( + 'Removing git worktree "%s" for pack "%s" and content version "%s"' + % (worktree_path, pack_name, content_version) + ) try: shutil.rmtree(worktree_path, ignore_errors=True) @@ -392,36 +409,43 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version): return True - def _handle_git_worktree_error(self, pack_name, pack_directory, content_version, exit_code, - stdout, stderr): + def _handle_git_worktree_error( + self, pack_name, pack_directory, content_version, exit_code, stdout, stderr + ): """ Handle "git worktree" related errors and throw a more user-friendly exception. """ error_prefix = 'Failed to create git worktree for pack "%s": ' % (pack_name) if isinstance(stdout, six.binary_type): - stdout = stdout.decode('utf-8') + stdout = stdout.decode("utf-8") if isinstance(stderr, six.binary_type): - stderr = stderr.decode('utf-8') + stderr = stderr.decode("utf-8") # 1. Installed version of git which doesn't support worktree command if "git: 'worktree' is not a git command." in stderr: - msg = ('Installed git version doesn\'t support git worktree command. ' - 'To be able to utilize this functionality you need to use git ' - '>= 2.5.0.') + msg = ( + "Installed git version doesn't support git worktree command. " + "To be able to utilize this functionality you need to use git " + ">= 2.5.0." + ) raise ValueError(error_prefix + msg) # 2. Provided pack directory is not a git repository if "Not a git repository" in stderr: - msg = ('Pack directory "%s" is not a git repository. To utilize this functionality, ' - 'pack directory needs to be a git repository.' % (pack_directory)) + msg = ( + 'Pack directory "%s" is not a git repository. To utilize this functionality, ' + "pack directory needs to be a git repository." % (pack_directory) + ) raise ValueError(error_prefix + msg) # 3. Invalid revision provided if "invalid reference" in stderr: - msg = ('Invalid content_version "%s" provided. Make sure that git repository is up ' - 'to date and contains that revision.' % (content_version)) + msg = ( + 'Invalid content_version "%s" provided. Make sure that git repository is up ' + "to date and contains that revision." % (content_version) + ) raise ValueError(error_prefix + msg) def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_path): @@ -433,10 +457,10 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa """ pack_base_path = get_pack_base_path(pack_name=pack_name) - new_entry_point = entry_point.replace(pack_base_path, '') + new_entry_point = entry_point.replace(pack_base_path, "") # Remove leading slash (if any) - if new_entry_point.startswith('/'): + if new_entry_point.startswith("/"): new_entry_point = new_entry_point[1:] new_entry_point = os.path.join(worktree_path, new_entry_point) @@ -444,7 +468,7 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa # Check to prevent directory traversal common_prefix = os.path.commonprefix([worktree_path, new_entry_point]) if common_prefix != worktree_path: - raise ValueError('entry_point is not located inside the pack directory') + raise ValueError("entry_point is not located inside the pack directory") return new_entry_point @@ -483,11 +507,11 @@ def _get_script_args(self, action_parameters): is_script_run_as_cmd = self.runner_parameters.get(RUNNER_COMMAND, None) - pos_args = '' + pos_args = "" named_args = {} if is_script_run_as_cmd: - pos_args = self.runner_parameters.get(RUNNER_COMMAND, '') + pos_args = self.runner_parameters.get(RUNNER_COMMAND, "") named_args = action_parameters else: pos_args, named_args = action_utils.get_args(action_parameters, self.action) diff --git a/st2common/st2common/runners/base_action.py b/st2common/st2common/runners/base_action.py index bc915d2b4f..244a4235c9 100644 --- a/st2common/st2common/runners/base_action.py +++ b/st2common/st2common/runners/base_action.py @@ -21,9 +21,7 @@ from st2common.runners.utils import get_logger_for_python_runner_action from st2common.runners.utils import PackConfigDict -__all__ = [ - 'Action' -] +__all__ = ["Action"] @six.add_metaclass(abc.ABCMeta) @@ -45,16 +43,17 @@ def __init__(self, config=None, action_service=None): self.config = config or {} self.action_service = action_service - if action_service and getattr(action_service, '_action_wrapper', None): - log_level = getattr(action_service._action_wrapper, '_log_level', 'debug') - pack_name = getattr(action_service._action_wrapper, '_pack', 'unknown') + if action_service and getattr(action_service, "_action_wrapper", None): + log_level = getattr(action_service._action_wrapper, "_log_level", "debug") + pack_name = getattr(action_service._action_wrapper, "_pack", "unknown") else: - log_level = 'debug' - pack_name = 'unknown' + log_level = "debug" + pack_name = "unknown" self.config = PackConfigDict(pack_name, self.config) - self.logger = get_logger_for_python_runner_action(action_name=self.__class__.__name__, - log_level=log_level) + self.logger = get_logger_for_python_runner_action( + action_name=self.__class__.__name__, log_level=log_level + ) @abc.abstractmethod def run(self, **kwargs): diff --git a/st2common/st2common/runners/parallel_ssh.py b/st2common/st2common/runners/parallel_ssh.py index 28f8756415..c41175c02c 100644 --- a/st2common/st2common/runners/parallel_ssh.py +++ b/st2common/st2common/runners/parallel_ssh.py @@ -35,13 +35,26 @@ class ParallelSSHClient(object): - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] - CONNECT_ERROR = 'Cannot connect to host.' - - def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_material=None, port=22, - bastion_host=None, concurrency=10, raise_on_any_error=False, connect=True, - passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None, - sudo_password=False): + KEYS_TO_TRANSFORM = ["stdout", "stderr"] + CONNECT_ERROR = "Cannot connect to host." + + def __init__( + self, + hosts, + user=None, + password=None, + pkey_file=None, + pkey_material=None, + port=22, + bastion_host=None, + concurrency=10, + raise_on_any_error=False, + connect=True, + passphrase=None, + handle_stdout_line_func=None, + handle_stderr_line_func=None, + sudo_password=False, + ): """ :param handle_stdout_line_func: Callback function which is called dynamically each time a new stdout line is received. @@ -65,7 +78,7 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia self._sudo_password = sudo_password if not hosts: - raise Exception('Need an non-empty list of hosts to talk to.') + raise Exception("Need an non-empty list of hosts to talk to.") self._pool = concurrency_lib.get_green_pool_class()(concurrency) self._hosts_client = {} @@ -74,8 +87,8 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia if connect: connect_results = self.connect(raise_on_any_error=raise_on_any_error) - extra = {'_connect_results': connect_results} - LOG.debug('Connect to hosts complete.', extra=extra) + extra = {"_connect_results": connect_results} + LOG.debug("Connect to hosts complete.", extra=extra) def connect(self, raise_on_any_error=False): """ @@ -92,17 +105,28 @@ def connect(self, raise_on_any_error=False): for host in self._hosts: while not concurrency_lib.is_green_pool_free(self._pool): concurrency_lib.sleep(self._scan_interval) - self._pool.spawn(self._connect, host=host, results=results, - raise_on_any_error=raise_on_any_error) + self._pool.spawn( + self._connect, + host=host, + results=results, + raise_on_any_error=raise_on_any_error, + ) concurrency_lib.green_pool_wait_all(self._pool) if self._successful_connects < 1: # We definitely have to raise an exception in this case. - LOG.error('Unable to connect to any of the hosts.', - extra={'connect_results': results}) - msg = ('Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s' % - (self._hosts, json.dumps(results, indent=2))) + LOG.error( + "Unable to connect to any of the hosts.", + extra={"connect_results": results}, + ) + msg = ( + "Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s" + % ( + self._hosts, + json.dumps(results, indent=2), + ) + ) raise NoHostsConnectedToException(msg) return results @@ -124,10 +148,7 @@ def run(self, cmd, timeout=None): :rtype: ``dict`` of ``str`` to ``dict`` """ - options = { - 'cmd': cmd, - 'timeout': timeout - } + options = {"cmd": cmd, "timeout": timeout} results = self._execute_in_pool(self._run_command, **options) return results @@ -152,13 +173,13 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): """ if not os.path.exists(local_path): - raise Exception('Local path %s does not exist.' % local_path) + raise Exception("Local path %s does not exist." % local_path) options = { - 'local_path': local_path, - 'remote_path': remote_path, - 'mode': mode, - 'mirror_local_mode': mirror_local_mode + "local_path": local_path, + "remote_path": remote_path, + "mode": mode, + "mirror_local_mode": mirror_local_mode, } return self._execute_in_pool(self._put_files, **options) @@ -173,9 +194,7 @@ def mkdir(self, path): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path - } + options = {"path": path} return self._execute_in_pool(self._mkdir, **options) def delete_file(self, path): @@ -188,9 +207,7 @@ def delete_file(self, path): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path - } + options = {"path": path} return self._execute_in_pool(self._delete_file, **options) def delete_dir(self, path, force=False, timeout=None): @@ -203,10 +220,7 @@ def delete_dir(self, path, force=False, timeout=None): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path, - 'force': force - } + options = {"path": path, "force": force} return self._execute_in_pool(self._delete_dir, **options) def close(self): @@ -218,7 +232,7 @@ def close(self): try: self._hosts_client[host].close() except: - LOG.exception('Failed shutting down SSH connection to host: %s', host) + LOG.exception("Failed shutting down SSH connection to host: %s", host) def _execute_in_pool(self, execute_method, **kwargs): results = {} @@ -237,36 +251,41 @@ def _execute_in_pool(self, execute_method, **kwargs): def _connect(self, host, results, raise_on_any_error=False): (hostname, port) = self._get_host_port_info(host) - extra = {'host': host, 'port': port, 'user': self._ssh_user} + extra = {"host": host, "port": port, "user": self._ssh_user} if self._ssh_password: - extra['password'] = '' + extra["password"] = "" elif self._ssh_key_file: - extra['key_file_path'] = self._ssh_key_file + extra["key_file_path"] = self._ssh_key_file else: - extra['private_key'] = '' - - LOG.debug('Connecting to host.', extra=extra) - - client = ParamikoSSHClient(hostname=hostname, port=port, - username=self._ssh_user, - password=self._ssh_password, - bastion_host=self._bastion_host, - key_files=self._ssh_key_file, - key_material=self._ssh_key_material, - passphrase=self._passphrase, - handle_stdout_line_func=self._handle_stdout_line_func, - handle_stderr_line_func=self._handle_stderr_line_func) + extra["private_key"] = "" + + LOG.debug("Connecting to host.", extra=extra) + + client = ParamikoSSHClient( + hostname=hostname, + port=port, + username=self._ssh_user, + password=self._ssh_password, + bastion_host=self._bastion_host, + key_files=self._ssh_key_file, + key_material=self._ssh_key_material, + passphrase=self._passphrase, + handle_stdout_line_func=self._handle_stdout_line_func, + handle_stderr_line_func=self._handle_stderr_line_func, + ) try: client.connect() except SSHException as ex: LOG.exception(ex) if raise_on_any_error: raise - error_dict = self._generate_error_result(exc=ex, message='Connection error.') + error_dict = self._generate_error_result( + exc=ex, message="Connection error." + ) self._bad_hosts[hostname] = error_dict results[hostname] = error_dict except Exception as ex: - error = 'Failed connecting to host %s.' % hostname + error = "Failed connecting to host %s." % hostname LOG.exception(error) if raise_on_any_error: raise @@ -276,16 +295,19 @@ def _connect(self, host, results, raise_on_any_error=False): else: self._successful_connects += 1 self._hosts_client[hostname] = client - results[hostname] = {'message': 'Connected to host.'} + results[hostname] = {"message": "Connected to host."} def _run_command(self, host, cmd, results, timeout=None): try: - LOG.debug('Running command: %s on host: %s.', cmd, host) + LOG.debug("Running command: %s on host: %s.", cmd, host) client = self._hosts_client[host] - (stdout, stderr, exit_code) = client.run(cmd, timeout=timeout, - call_line_handler_func=True) + (stdout, stderr, exit_code) = client.run( + cmd, timeout=timeout, call_line_handler_func=True + ) - result = self._handle_command_result(stdout=stdout, stderr=stderr, exit_code=exit_code) + result = self._handle_command_result( + stdout=stdout, stderr=stderr, exit_code=exit_code + ) results[host] = result except Exception as ex: cmd = self._sanitize_command_string(cmd=cmd) @@ -293,20 +315,24 @@ def _run_command(self, host, cmd, results, timeout=None): LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) - def _put_files(self, local_path, remote_path, host, results, mode=None, - mirror_local_mode=False): + def _put_files( + self, local_path, remote_path, host, results, mode=None, mirror_local_mode=False + ): try: - LOG.debug('Copying file to host: %s' % host) + LOG.debug("Copying file to host: %s" % host) if os.path.isdir(local_path): result = self._hosts_client[host].put_dir(local_path, remote_path) else: - result = self._hosts_client[host].put(local_path, remote_path, - mirror_local_mode=mirror_local_mode, - mode=mode) - LOG.debug('Result of copy: %s' % result) + result = self._hosts_client[host].put( + local_path, + remote_path, + mirror_local_mode=mirror_local_mode, + mode=mode, + ) + LOG.debug("Result of copy: %s" % result) results[host] = result except Exception as ex: - error = 'Failed sending file(s) in path %s to host %s' % (local_path, host) + error = "Failed sending file(s) in path %s to host %s" % (local_path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) @@ -324,16 +350,18 @@ def _delete_file(self, host, path, results): result = self._hosts_client[host].delete_file(path) results[host] = result except Exception as ex: - error = 'Failed deleting file %s on host %s.' % (path, host) + error = "Failed deleting file %s on host %s." % (path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) def _delete_dir(self, host, path, results, force=False, timeout=None): try: - result = self._hosts_client[host].delete_dir(path, force=force, timeout=timeout) + result = self._hosts_client[host].delete_dir( + path, force=force, timeout=timeout + ) results[host] = result except Exception as ex: - error = 'Failed deleting dir %s on host %s.' % (path, host) + error = "Failed deleting dir %s on host %s." % (path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) @@ -347,20 +375,27 @@ def _get_host_port_info(self, host_str): def _handle_command_result(self, stdout, stderr, exit_code): # Detect if user provided an invalid sudo password or sudo is not configured for that user if self._sudo_password: - if re.search(r'sudo: \d+ incorrect password attempts', stderr): - match = re.search(r'\[sudo\] password for (.+?)\:', stderr) + if re.search(r"sudo: \d+ incorrect password attempts", stderr): + match = re.search(r"\[sudo\] password for (.+?)\:", stderr) if match: username = match.groups()[0] else: - username = 'unknown' + username = "unknown" - error = ('Invalid sudo password provided or sudo is not configured for this user ' - '(%s)' % (username)) + error = ( + "Invalid sudo password provided or sudo is not configured for this user " + "(%s)" % (username) + ) raise ValueError(error) - is_succeeded = (exit_code == 0) - result_dict = {'stdout': stdout, 'stderr': stderr, 'return_code': exit_code, - 'succeeded': is_succeeded, 'failed': not is_succeeded} + is_succeeded = exit_code == 0 + result_dict = { + "stdout": stdout, + "stderr": stderr, + "return_code": exit_code, + "succeeded": is_succeeded, + "failed": not is_succeeded, + } result = jsonify.json_loads(result_dict, ParallelSSHClient.KEYS_TO_TRANSFORM) return result @@ -375,8 +410,11 @@ def _sanitize_command_string(cmd): if not cmd: return cmd - result = re.sub(r'ST2_ACTION_AUTH_TOKEN=(.+?)\s+?', 'ST2_ACTION_AUTH_TOKEN=%s ' % - (MASKED_ATTRIBUTE_VALUE), cmd) + result = re.sub( + r"ST2_ACTION_AUTH_TOKEN=(.+?)\s+?", + "ST2_ACTION_AUTH_TOKEN=%s " % (MASKED_ATTRIBUTE_VALUE), + cmd, + ) return result @staticmethod @@ -388,8 +426,8 @@ def _generate_error_result(exc, message): :param message: Error message which will be prefixed to the exception exception message. :type message: ``str`` """ - exc_message = getattr(exc, 'message', str(exc)) - error_message = '%s %s' % (message, exc_message) + exc_message = getattr(exc, "message", str(exc)) + error_message = "%s %s" % (message, exc_message) traceback_message = traceback.format_exc() if isinstance(exc, SSHCommandTimeoutError): @@ -399,21 +437,24 @@ def _generate_error_result(exc, message): timeout = False return_code = 255 - stdout = getattr(exc, 'stdout', None) or '' - stderr = getattr(exc, 'stderr', None) or '' + stdout = getattr(exc, "stdout", None) or "" + stderr = getattr(exc, "stderr", None) or "" error_dict = { - 'failed': True, - 'succeeded': False, - 'timeout': timeout, - 'return_code': return_code, - 'stdout': stdout, - 'stderr': stderr, - 'error': error_message, - 'traceback': traceback_message, + "failed": True, + "succeeded": False, + "timeout": timeout, + "return_code": return_code, + "stdout": stdout, + "stderr": stderr, + "error": error_message, + "traceback": traceback_message, } return error_dict def __repr__(self): - return ('' % - (repr(self._hosts), self._ssh_user, id(self))) + return "" % ( + repr(self._hosts), + self._ssh_user, + id(self), + ) diff --git a/st2common/st2common/runners/paramiko_ssh.py b/st2common/st2common/runners/paramiko_ssh.py index c42c4eb89f..7530a532d9 100644 --- a/st2common/st2common/runners/paramiko_ssh.py +++ b/st2common/st2common/runners/paramiko_ssh.py @@ -35,14 +35,13 @@ from st2common.util.misc import strip_shell_chars from st2common.util.misc import sanitize_output from st2common.util.shell import quote_unix -from st2common.constants.runners import DEFAULT_SSH_PORT, REMOTE_RUNNER_PRIVATE_KEY_HEADER +from st2common.constants.runners import ( + DEFAULT_SSH_PORT, + REMOTE_RUNNER_PRIVATE_KEY_HEADER, +) from st2common.util import concurrency -__all__ = [ - 'ParamikoSSHClient', - - 'SSHCommandTimeoutError' -] +__all__ = ["ParamikoSSHClient", "SSHCommandTimeoutError"] class SSHCommandTimeoutError(Exception): @@ -63,13 +62,21 @@ def __init__(self, cmd, timeout, ssh_connect_timeout, stdout=None, stderr=None): self.ssh_connect_timeout = ssh_connect_timeout self.stdout = stdout self.stderr = stderr - self.message = ('Command didn\'t finish in %s seconds or the SSH connection ' - 'did not succeed in %s seconds' % (timeout, ssh_connect_timeout)) + self.message = ( + "Command didn't finish in %s seconds or the SSH connection " + "did not succeed in %s seconds" % (timeout, ssh_connect_timeout) + ) super(SSHCommandTimeoutError, self).__init__(self.message) def __repr__(self): - return ('' % - (self.cmd, self.timeout, self.ssh_connect_timeout)) + return ( + '' + % ( + self.cmd, + self.timeout, + self.ssh_connect_timeout, + ) + ) def __str__(self): return self.message @@ -86,9 +93,20 @@ class ParamikoSSHClient(object): # How long to sleep while waiting for command to finish to prevent busy waiting SLEEP_DELAY = 0.2 - def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None, - bastion_host=None, key_files=None, key_material=None, timeout=None, - passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None): + def __init__( + self, + hostname, + port=DEFAULT_SSH_PORT, + username=None, + password=None, + bastion_host=None, + key_files=None, + key_material=None, + timeout=None, + passphrase=None, + handle_stdout_line_func=None, + handle_stderr_line_func=None, + ): """ Authentication is always attempted in the following order: @@ -114,8 +132,7 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None self._handle_stderr_line_func = handle_stderr_line_func self.ssh_config_file = os.path.expanduser( - cfg.CONF.ssh_runner.ssh_config_file_path or - '~/.ssh/config' + cfg.CONF.ssh_runner.ssh_config_file_path or "~/.ssh/config" ) if self.timeout and int(self.ssh_connect_timeout) > int(self.timeout) - 2: @@ -140,14 +157,16 @@ def connect(self): :rtype: ``bool`` """ if self.bastion_host: - self.logger.debug('Bastion host specified, connecting') + self.logger.debug("Bastion host specified, connecting") self.bastion_client = self._connect(host=self.bastion_host) transport = self.bastion_client.get_transport() real_addr = (self.hostname, self.port) # fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour # see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad - local_addr = ('', 0) - self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr) + local_addr = ("", 0) + self.bastion_socket = transport.open_channel( + "direct-tcpip", real_addr, local_addr + ) self.client = self._connect(host=self.hostname, socket=self.bastion_socket) return True @@ -173,17 +192,24 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): """ if not local_path or not remote_path: - raise Exception('Need both local_path and remote_path. local: %s, remote: %s' % - local_path, remote_path) + raise Exception( + "Need both local_path and remote_path. local: %s, remote: %s" + % local_path, + remote_path, + ) local_path = quote_unix(local_path) remote_path = quote_unix(remote_path) - extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode, - '_mirror_local_mode': mirror_local_mode} - self.logger.debug('Uploading file', extra=extra) + extra = { + "_local_path": local_path, + "_remote_path": remote_path, + "_mode": mode, + "_mirror_local_mode": mirror_local_mode, + } + self.logger.debug("Uploading file", extra=extra) if not os.path.exists(local_path): - raise Exception('Path %s does not exist locally.' % local_path) + raise Exception("Path %s does not exist locally." % local_path) rattrs = self.sftp.put(local_path, remote_path) @@ -199,7 +225,7 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): remote_mode = rattrs.st_mode # Only bitshift if we actually got an remote_mode if remote_mode is not None: - remote_mode = (remote_mode & 0o7777) + remote_mode = remote_mode & 0o7777 if local_mode != remote_mode: self.sftp.chmod(remote_path, local_mode) @@ -225,9 +251,13 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): :rtype: ``list`` of ``str`` """ - extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode, - '_mirror_local_mode': mirror_local_mode} - self.logger.debug('Uploading dir', extra=extra) + extra = { + "_local_path": local_path, + "_remote_path": remote_path, + "_mode": mode, + "_mirror_local_mode": mirror_local_mode, + } + self.logger.debug("Uploading dir", extra=extra) if os.path.basename(local_path): strip = os.path.dirname(local_path) @@ -237,10 +267,10 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): remote_paths = [] for context, dirs, files in os.walk(local_path): - rcontext = context.replace(strip, '', 1) + rcontext = context.replace(strip, "", 1) # normalize pathname separators with POSIX separator - rcontext = rcontext.replace(os.sep, '/') - rcontext = rcontext.lstrip('/') + rcontext = rcontext.replace(os.sep, "/") + rcontext = rcontext.lstrip("/") rcontext = posixpath.join(remote_path, rcontext) if not self.exists(rcontext): @@ -255,8 +285,12 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): local_path = os.path.join(context, f) n = posixpath.join(rcontext, f) # Note that quote_unix is done by put anyways. - p = self.put(local_path=local_path, remote_path=n, - mirror_local_mode=mirror_local_mode, mode=mode) + p = self.put( + local_path=local_path, + remote_path=n, + mirror_local_mode=mirror_local_mode, + mode=mode, + ) remote_paths.append(p) return remote_paths @@ -290,8 +324,8 @@ def mkdir(self, dir_path): """ dir_path = quote_unix(dir_path) - extra = {'_dir_path': dir_path} - self.logger.debug('mkdir', extra=extra) + extra = {"_dir_path": dir_path} + self.logger.debug("mkdir", extra=extra) return self.sftp.mkdir(dir_path) def delete_file(self, path): @@ -307,8 +341,8 @@ def delete_file(self, path): """ path = quote_unix(path) - extra = {'_path': path} - self.logger.debug('Deleting file', extra=extra) + extra = {"_path": path} + self.logger.debug("Deleting file", extra=extra) self.sftp.unlink(path) return True @@ -331,15 +365,15 @@ def delete_dir(self, path, force=False, timeout=None): """ path = quote_unix(path) - extra = {'_path': path} + extra = {"_path": path} if force: - command = 'rm -rf %s' % path - extra['_command'] = command - extra['_force'] = force - self.logger.debug('Deleting dir', extra=extra) + command = "rm -rf %s" % path + extra["_command"] = command + extra["_force"] = force + self.logger.debug("Deleting dir", extra=extra) return self.run(command, timeout=timeout) - self.logger.debug('Deleting dir', extra=extra) + self.logger.debug("Deleting dir", extra=extra) return self.sftp.rmdir(path) def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): @@ -359,8 +393,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): if quote: cmd = quote_unix(cmd) - extra = {'_cmd': cmd} - self.logger.info('Executing command', extra=extra) + extra = {"_cmd": cmd} + self.logger.info("Executing command", extra=extra) # Use the system default buffer size bufsize = -1 @@ -369,7 +403,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): chan = transport.open_session() start_time = time.time() - if cmd.startswith('sudo'): + if cmd.startswith("sudo"): # Note that fabric does this as well. If you set pty, stdout and stderr # streams will be combined into one. # NOTE: If pty is used, every new line character \n will be converted to \r\n which @@ -386,7 +420,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): # Create a stdin file and immediately close it to prevent any # interactive script from hanging the process. - stdin = chan.makefile('wb', bufsize) + stdin = chan.makefile("wb", bufsize) stdin.close() # Receive all the output @@ -400,12 +434,14 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): exit_status_ready = chan.exit_status_ready() if exit_status_ready: - stdout_data = self._consume_stdout(chan=chan, - call_line_handler_func=call_line_handler_func) + stdout_data = self._consume_stdout( + chan=chan, call_line_handler_func=call_line_handler_func + ) stdout_data = stdout_data.getvalue() - stderr_data = self._consume_stderr(chan=chan, - call_line_handler_func=call_line_handler_func) + stderr_data = self._consume_stderr( + chan=chan, call_line_handler_func=call_line_handler_func + ) stderr_data = stderr_data.getvalue() stdout.write(stdout_data) @@ -413,7 +449,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): while not exit_status_ready: current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if timeout and (elapsed_time > timeout): # TODO: Is this the right way to clean up? @@ -421,16 +457,22 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty) stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty) - raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, - ssh_connect_timeout=self.ssh_connect_timeout, - stdout=stdout, stderr=stderr) - - stdout_data = self._consume_stdout(chan=chan, - call_line_handler_func=call_line_handler_func) + raise SSHCommandTimeoutError( + cmd=cmd, + timeout=timeout, + ssh_connect_timeout=self.ssh_connect_timeout, + stdout=stdout, + stderr=stderr, + ) + + stdout_data = self._consume_stdout( + chan=chan, call_line_handler_func=call_line_handler_func + ) stdout_data = stdout_data.getvalue() - stderr_data = self._consume_stderr(chan=chan, - call_line_handler_func=call_line_handler_func) + stderr_data = self._consume_stderr( + chan=chan, call_line_handler_func=call_line_handler_func + ) stderr_data = stderr_data.getvalue() stdout.write(stdout_data) @@ -453,8 +495,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty) stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty) - extra = {'_status': status, '_stdout': stdout, '_stderr': stderr} - self.logger.debug('Command finished', extra=extra) + extra = {"_status": status, "_stdout": stdout, "_stderr": stderr} + self.logger.debug("Command finished", extra=extra) return [stdout, stderr, status] @@ -499,7 +541,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False): data = chan.recv(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -512,7 +554,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False): data = chan.recv(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -520,14 +562,14 @@ def _consume_stdout(self, chan, call_line_handler_func=False): if self._handle_stdout_line_func and call_line_handler_func: data = strip_shell_chars(stdout.getvalue()) - lines = data.split('\n') + lines = data.split("\n") lines = [line for line in lines if line] for line in lines: # Note: If this function performs network operating no sleep is # needed, otherwise if a long blocking operating is performed, # sleep is recommended to yield and prevent from busy looping - self._handle_stdout_line_func(line=line + '\n') + self._handle_stdout_line_func(line=line + "\n") stdout.seek(0) @@ -545,7 +587,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False): data = chan.recv_stderr(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -558,7 +600,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False): data = chan.recv_stderr(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -566,14 +608,14 @@ def _consume_stderr(self, chan, call_line_handler_func=False): if self._handle_stderr_line_func and call_line_handler_func: data = strip_shell_chars(stderr.getvalue()) - lines = data.split('\n') + lines = data.split("\n") lines = [line for line in lines if line] for line in lines: # Note: If this function performs network operating no sleep is # needed, otherwise if a long blocking operating is performed, # sleep is recommended to yield and prevent from busy looping - self._handle_stderr_line_func(line=line + '\n') + self._handle_stderr_line_func(line=line + "\n") stderr.seek(0) @@ -581,9 +623,9 @@ def _consume_stderr(self, chan, call_line_handler_func=False): def _get_decoded_data(self, data): try: - return data.decode('utf-8') + return data.decode("utf-8") except: - self.logger.exception('Non UTF-8 character found in data: %s', data) + self.logger.exception("Non UTF-8 character found in data: %s", data) raise def _get_pkey_object(self, key_material, passphrase): @@ -604,13 +646,17 @@ def _get_pkey_object(self, key_material, passphrase): # exception letting the user know we expect the contents a not a path. # Note: We do it here and not up the stack to avoid false positives. contains_header = REMOTE_RUNNER_PRIVATE_KEY_HEADER in key_material.lower() - if not contains_header and (key_material.count('/') >= 1 or key_material.count('\\') >= 1): - msg = ('"private_key" parameter needs to contain private key data / content and not ' - 'a path') + if not contains_header and ( + key_material.count("/") >= 1 or key_material.count("\\") >= 1 + ): + msg = ( + '"private_key" parameter needs to contain private key data / content and not ' + "a path" + ) elif passphrase: - msg = 'Invalid passphrase or invalid/unsupported key type' + msg = "Invalid passphrase or invalid/unsupported key type" else: - msg = 'Invalid or unsupported key type' + msg = "Invalid or unsupported key type" raise paramiko.ssh_exception.SSHException(msg) @@ -636,19 +682,23 @@ def _connect(self, host, socket=None): :rtype: :class:`paramiko.SSHClient` """ - conninfo = {'hostname': host, - 'allow_agent': False, - 'look_for_keys': False, - 'timeout': self.ssh_connect_timeout} + conninfo = { + "hostname": host, + "allow_agent": False, + "look_for_keys": False, + "timeout": self.ssh_connect_timeout, + } ssh_config_file_info = {} if cfg.CONF.ssh_runner.use_ssh_config: ssh_config_file_info = self._get_ssh_config_for_host(host) - ssh_config_username = ssh_config_file_info.get('user', None) - ssh_config_port = ssh_config_file_info.get('port', None) + ssh_config_username = ssh_config_file_info.get("user", None) + ssh_config_port = ssh_config_file_info.get("port", None) - self.username = (self.username or ssh_config_username or cfg.CONF.system_user.user) + self.username = ( + self.username or ssh_config_username or cfg.CONF.system_user.user + ) # If a custom non-default port is provided in the SSH config file we use that over the # default port value provided via runner parameter @@ -660,78 +710,92 @@ def _connect(self, host, socket=None): # If both key file and key material are provided as action parameters, # throw an error informing user only one is required. if self.key_files and self.key_material: - msg = ('key_files and key_material arguments are mutually exclusive. Supply only one.') + msg = "key_files and key_material arguments are mutually exclusive. Supply only one." raise ValueError(msg) # If neither key material nor password is provided, only then we look at key file and decide # if we want to use the user supplied one or the one in SSH config. if not self.key_material and not self.password: - self.key_files = (self.key_files or ssh_config_file_info.get('identityfile', None) or - cfg.CONF.system_user.ssh_key_file) + self.key_files = ( + self.key_files + or ssh_config_file_info.get("identityfile", None) + or cfg.CONF.system_user.ssh_key_file + ) if self.passphrase and not (self.key_files or self.key_material): - raise ValueError('passphrase should accompany private key material') + raise ValueError("passphrase should accompany private key material") credentials_provided = self.password or self.key_files or self.key_material if not credentials_provided: - msg = ('Either password or key file location or key material should be supplied ' + - 'for action. You can also add an entry for host %s in SSH config file %s.' % - (host, self.ssh_config_file)) + msg = ( + "Either password or key file location or key material should be supplied " + + "for action. You can also add an entry for host %s in SSH config file %s." + % (host, self.ssh_config_file) + ) raise ValueError(msg) - conninfo['username'] = self.username - conninfo['port'] = self.port + conninfo["username"] = self.username + conninfo["port"] = self.port if self.password: - conninfo['password'] = self.password + conninfo["password"] = self.password if self.key_files: - conninfo['key_filename'] = self.key_files + conninfo["key_filename"] = self.key_files passphrase_reqd = self._is_key_file_needs_passphrase(self.key_files) if passphrase_reqd and not self.passphrase: - msg = ('Private key file %s is passphrase protected. Supply a passphrase.' % - self.key_files) + msg = ( + "Private key file %s is passphrase protected. Supply a passphrase." + % self.key_files + ) raise paramiko.ssh_exception.PasswordRequiredException(msg) if self.passphrase: # Optional passphrase for unlocking the private key - conninfo['password'] = self.passphrase + conninfo["password"] = self.passphrase if self.key_material: - conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material, - passphrase=self.passphrase) + conninfo["pkey"] = self._get_pkey_object( + key_material=self.key_material, passphrase=self.passphrase + ) if not self.password and not (self.key_files or self.key_material): - conninfo['allow_agent'] = True - conninfo['look_for_keys'] = True - - extra = {'_hostname': host, '_port': self.port, - '_username': self.username, '_timeout': self.ssh_connect_timeout} - self.logger.debug('Connecting to server', extra=extra) - - self.socket = socket or ssh_config_file_info.get('sock', None) + conninfo["allow_agent"] = True + conninfo["look_for_keys"] = True + + extra = { + "_hostname": host, + "_port": self.port, + "_username": self.username, + "_timeout": self.ssh_connect_timeout, + } + self.logger.debug("Connecting to server", extra=extra) + + self.socket = socket or ssh_config_file_info.get("sock", None) if self.socket: - conninfo['sock'] = socket + conninfo["sock"] = socket client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - extra = {'_conninfo': conninfo} - self.logger.debug('Connection info', extra=extra) + extra = {"_conninfo": conninfo} + self.logger.debug("Connection info", extra=extra) try: client.connect(**conninfo) except SSHException as e: paramiko_msg = six.text_type(e) - if conninfo.get('password', None): - conninfo['password'] = '' + if conninfo.get("password", None): + conninfo["password"] = "" - msg = ('Error connecting to host %s ' % host + - 'with connection parameters %s.' % conninfo + - 'Paramiko error: %s.' % paramiko_msg) + msg = ( + "Error connecting to host %s " % host + + "with connection parameters %s." % conninfo + + "Paramiko error: %s." % paramiko_msg + ) raise SSHException(msg) return client @@ -744,25 +808,29 @@ def _get_ssh_config_for_host(self, host): with open(self.ssh_config_file) as f: ssh_config_parser.parse(f) except IOError as e: - raise Exception('Error accessing ssh config file %s. Code: %s Reason %s' % - (self.ssh_config_file, e.errno, e.strerror)) + raise Exception( + "Error accessing ssh config file %s. Code: %s Reason %s" + % (self.ssh_config_file, e.errno, e.strerror) + ) ssh_config = ssh_config_parser.lookup(host) - self.logger.info('Parsed SSH config file contents: %s', ssh_config) + self.logger.info("Parsed SSH config file contents: %s", ssh_config) if ssh_config: - for k in ('hostname', 'user', 'port'): + for k in ("hostname", "user", "port"): if k in ssh_config: ssh_config_info[k] = ssh_config[k] - if 'identityfile' in ssh_config: - key_file = ssh_config['identityfile'] + if "identityfile" in ssh_config: + key_file = ssh_config["identityfile"] if type(key_file) is list: key_file = key_file[0] - ssh_config_info['identityfile'] = key_file + ssh_config_info["identityfile"] = key_file - if 'proxycommand' in ssh_config: - ssh_config_info['sock'] = paramiko.ProxyCommand(ssh_config['proxycommand']) + if "proxycommand" in ssh_config: + ssh_config_info["sock"] = paramiko.ProxyCommand( + ssh_config["proxycommand"] + ) return ssh_config_info @@ -779,5 +847,9 @@ def _is_key_file_needs_passphrase(file): return False def __repr__(self): - return ('' % - (self.hostname, self.port, self.username, id(self))) + return "" % ( + self.hostname, + self.port, + self.username, + id(self), + ) diff --git a/st2common/st2common/runners/paramiko_ssh_runner.py b/st2common/st2common/runners/paramiko_ssh_runner.py index 6c1ab053e9..f41882935f 100644 --- a/st2common/st2common/runners/paramiko_ssh_runner.py +++ b/st2common/st2common/runners/paramiko_ssh_runner.py @@ -29,34 +29,31 @@ from st2common.exceptions.actionrunner import ActionRunnerPreRunError from st2common.services.action import store_execution_output_data -__all__ = [ - 'BaseParallelSSHRunner' -] +__all__ = ["BaseParallelSSHRunner"] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_HOSTS = 'hosts' -RUNNER_USERNAME = 'username' -RUNNER_PASSWORD = 'password' -RUNNER_PRIVATE_KEY = 'private_key' -RUNNER_PARALLEL = 'parallel' -RUNNER_SUDO = 'sudo' -RUNNER_SUDO_PASSWORD = 'sudo_password' -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_REMOTE_DIR = 'dir' -RUNNER_COMMAND = 'cmd' -RUNNER_CWD = 'cwd' -RUNNER_ENV = 'env' -RUNNER_KWARG_OP = 'kwarg_op' -RUNNER_TIMEOUT = 'timeout' -RUNNER_SSH_PORT = 'port' -RUNNER_BASTION_HOST = 'bastion_host' -RUNNER_PASSPHRASE = 'passphrase' +RUNNER_HOSTS = "hosts" +RUNNER_USERNAME = "username" +RUNNER_PASSWORD = "password" +RUNNER_PRIVATE_KEY = "private_key" +RUNNER_PARALLEL = "parallel" +RUNNER_SUDO = "sudo" +RUNNER_SUDO_PASSWORD = "sudo_password" +RUNNER_ON_BEHALF_USER = "user" +RUNNER_REMOTE_DIR = "dir" +RUNNER_COMMAND = "cmd" +RUNNER_CWD = "cwd" +RUNNER_ENV = "env" +RUNNER_KWARG_OP = "kwarg_op" +RUNNER_TIMEOUT = "timeout" +RUNNER_SSH_PORT = "port" +RUNNER_BASTION_HOST = "bastion_host" +RUNNER_PASSPHRASE = "passphrase" class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin): - def __init__(self, runner_id): super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id) self._hosts = None @@ -68,7 +65,7 @@ def __init__(self, runner_id): self._password = None self._private_key = None self._passphrase = None - self._kwarg_op = '--' + self._kwarg_op = "--" self._cwd = None self._env = None self._ssh_port = None @@ -83,13 +80,16 @@ def __init__(self, runner_id): def pre_run(self): super(BaseParallelSSHRunner, self).pre_run() - LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"', - self.liveaction_id) - hosts = self.runner_parameters.get(RUNNER_HOSTS, '').split(',') + LOG.debug( + 'Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"', + self.liveaction_id, + ) + hosts = self.runner_parameters.get(RUNNER_HOSTS, "").split(",") self._hosts = [h.strip() for h in hosts if len(h) > 0] if len(self._hosts) < 1: - raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' - % self.liveaction_id) + raise ActionRunnerPreRunError( + "No hosts specified to run action for action %s." % self.liveaction_id + ) self._username = self.runner_parameters.get(RUNNER_USERNAME, None) self._password = self.runner_parameters.get(RUNNER_PASSWORD, None) self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None) @@ -103,85 +103,105 @@ def pre_run(self): self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None) if self.context: - self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user) + self._on_behalf_user = self.context.get( + RUNNER_ON_BEHALF_USER, self._on_behalf_user + ) self._cwd = self.runner_parameters.get(RUNNER_CWD, None) self._env = self.runner_parameters.get(RUNNER_ENV, {}) - self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--') - self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT, - REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT) + self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, "--") + self._timeout = self.runner_parameters.get( + RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT + ) self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None) - LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.', - self.runner_id, self.liveaction_id) + LOG.info( + '[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.', + self.runner_id, + self.liveaction_id, + ) concurrency = int(len(self._hosts) / 3) + 1 if self._parallel else 1 if concurrency > self._max_concurrency: - LOG.debug('Limiting parallel SSH concurrency to %d.', concurrency) + LOG.debug("Limiting parallel SSH concurrency to %d.", concurrency) concurrency = self._max_concurrency client_kwargs = { - 'hosts': self._hosts, - 'user': self._username, - 'port': self._ssh_port, - 'concurrency': concurrency, - 'bastion_host': self._bastion_host, - 'raise_on_any_error': False, - 'connect': True + "hosts": self._hosts, + "user": self._username, + "port": self._ssh_port, + "concurrency": concurrency, + "bastion_host": self._bastion_host, + "raise_on_any_error": False, + "connect": True, } def make_store_stdout_line_func(execution_db, action_db): def store_stdout_line(line): if cfg.CONF.actionrunner.stream_output: - store_execution_output_data(execution_db=execution_db, action_db=action_db, - data=line, output_type='stdout') + store_execution_output_data( + execution_db=execution_db, + action_db=action_db, + data=line, + output_type="stdout", + ) return store_stdout_line def make_store_stderr_line_func(execution_db, action_db): def store_stderr_line(line): if cfg.CONF.actionrunner.stream_output: - store_execution_output_data(execution_db=execution_db, action_db=action_db, - data=line, output_type='stderr') + store_execution_output_data( + execution_db=execution_db, + action_db=action_db, + data=line, + output_type="stderr", + ) return store_stderr_line - handle_stdout_line_func = make_store_stdout_line_func(execution_db=self.execution, - action_db=self.action) - handle_stderr_line_func = make_store_stderr_line_func(execution_db=self.execution, - action_db=self.action) + handle_stdout_line_func = make_store_stdout_line_func( + execution_db=self.execution, action_db=self.action + ) + handle_stderr_line_func = make_store_stderr_line_func( + execution_db=self.execution, action_db=self.action + ) if len(self._hosts) == 1: # We only support streaming output when running action on one host. That is because # the action output is tied to a particulat execution. User can still achieve output # streaming for multiple hosts by running one execution per host. - client_kwargs['handle_stdout_line_func'] = handle_stdout_line_func - client_kwargs['handle_stderr_line_func'] = handle_stderr_line_func + client_kwargs["handle_stdout_line_func"] = handle_stdout_line_func + client_kwargs["handle_stderr_line_func"] = handle_stderr_line_func else: - LOG.debug('Real-time action output streaming is disabled, because action is running ' - 'on more than one host') + LOG.debug( + "Real-time action output streaming is disabled, because action is running " + "on more than one host" + ) if self._password: - client_kwargs['password'] = self._password + client_kwargs["password"] = self._password elif self._private_key: # Determine if the private_key is a path to the key file or the raw key material - is_key_material = self._is_private_key_material(private_key=self._private_key) + is_key_material = self._is_private_key_material( + private_key=self._private_key + ) if is_key_material: # Raw key material - client_kwargs['pkey_material'] = self._private_key + client_kwargs["pkey_material"] = self._private_key else: # Assume it's a path to the key file, verify the file exists - client_kwargs['pkey_file'] = self._private_key + client_kwargs["pkey_file"] = self._private_key if self._passphrase: - client_kwargs['passphrase'] = self._passphrase + client_kwargs["passphrase"] = self._passphrase else: # Default to stanley key file specified in the config - client_kwargs['pkey_file'] = self._ssh_key_file + client_kwargs["pkey_file"] = self._ssh_key_file if self._sudo_password: - client_kwargs['sudo_password'] = True + client_kwargs["sudo_password"] = True self._parallel_ssh_client = ParallelSSHClient(**client_kwargs) @@ -213,21 +233,22 @@ def _get_env_vars(self): @staticmethod def _get_result_status(result, allow_partial_failure): - if 'error' in result and 'traceback' in result: + if "error" in result and "traceback" in result: # Assume this is a global failure where the result dictionary doesn't contain entry # per host timeout = False - success = result.get('succeeded', False) - status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success, - timeout=timeout) + success = result.get("succeeded", False) + status = BaseParallelSSHRunner._get_status_for_success_and_timeout( + success=success, timeout=timeout + ) return status success = not allow_partial_failure timeout = True for r in six.itervalues(result): - r_succeess = r.get('succeeded', False) if r else False - r_timeout = r.get('timeout', False) if r else False + r_succeess = r.get("succeeded", False) if r else False + r_timeout = r.get("timeout", False) if r else False timeout &= r_timeout @@ -240,8 +261,9 @@ def _get_result_status(result, allow_partial_failure): if not success: break - status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success, - timeout=timeout) + status = BaseParallelSSHRunner._get_status_for_success_and_timeout( + success=success, timeout=timeout + ) return status diff --git a/st2common/st2common/runners/utils.py b/st2common/st2common/runners/utils.py index 82f1a3477c..70f7139f3f 100644 --- a/st2common/st2common/runners/utils.py +++ b/st2common/st2common/runners/utils.py @@ -27,14 +27,11 @@ __all__ = [ - 'PackConfigDict', - - 'get_logger_for_python_runner_action', - 'get_action_class_instance', - - 'make_read_and_store_stream_func', - - 'invoke_post_run', + "PackConfigDict", + "get_logger_for_python_runner_action", + "get_action_class_instance", + "make_read_and_store_stream_func", + "invoke_post_run", ] LOG = logging.getLogger(__name__) @@ -61,6 +58,7 @@ class PackConfigDict(dict): This class throws a user-friendly exception in case user tries to access config item which doesn't exist in the dict. """ + def __init__(self, pack_name, *args): super(PackConfigDict, self).__init__(*args) self._pack_name = pack_name @@ -72,8 +70,8 @@ def __getitem__(self, key): # Note: We use late import to avoid performance overhead from oslo_config import cfg - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, self._pack_name + '.yaml') + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, self._pack_name + ".yaml") msg = CONFIG_MISSING_ITEM_ERROR % (self._pack_name, key, config_path) raise ValueError(msg) @@ -83,11 +81,11 @@ def __setitem__(self, key, value): super(PackConfigDict, self).__setitem__(key, value) -def get_logger_for_python_runner_action(action_name, log_level='debug'): +def get_logger_for_python_runner_action(action_name, log_level="debug"): """ Set up a logger which logs all the messages with level DEBUG and above to stderr. """ - logger_name = 'actions.python.%s' % (action_name) + logger_name = "actions.python.%s" % (action_name) if logger_name not in LOGGERS: level_name = log_level.upper() @@ -97,7 +95,7 @@ def get_logger_for_python_runner_action(action_name, log_level='debug'): console = stdlib_logging.StreamHandler() console.setLevel(log_level_constant) - formatter = stdlib_logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + formatter = stdlib_logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") console.setFormatter(formatter) logger.addHandler(console) logger.setLevel(log_level_constant) @@ -123,8 +121,8 @@ def get_action_class_instance(action_cls, config=None, action_service=None): :type action_service: :class:`ActionService` """ kwargs = {} - kwargs['config'] = config - kwargs['action_service'] = action_service + kwargs["config"] = config + kwargs["action_service"] = action_service # Note: This is done for backward compatibility reasons. We first try to pass # "action_service" argument to the action class constructor, but if that doesn't work (e.g. old @@ -133,13 +131,15 @@ def get_action_class_instance(action_cls, config=None, action_service=None): try: action_instance = action_cls(**kwargs) except TypeError as e: - if 'unexpected keyword argument \'action_service\'' not in six.text_type(e): + if "unexpected keyword argument 'action_service'" not in six.text_type(e): raise e - LOG.debug('Action class (%s) constructor doesn\'t take "action_service" argument, ' - 'falling back to late assignment...' % (action_cls.__class__.__name__)) + LOG.debug( + 'Action class (%s) constructor doesn\'t take "action_service" argument, ' + "falling back to late assignment..." % (action_cls.__class__.__name__) + ) - action_service = kwargs.pop('action_service', None) + action_service = kwargs.pop("action_service", None) action_instance = action_cls(**kwargs) action_instance.action_service = action_service @@ -166,7 +166,7 @@ def read_and_store_stream(stream, buff): break if isinstance(line, six.binary_type): - line = line.decode('utf-8') + line = line.decode("utf-8") buff.write(line) @@ -175,7 +175,9 @@ def read_and_store_stream(stream, buff): continue if cfg.CONF.actionrunner.stream_output: - store_data_func(execution_db=execution_db, action_db=action_db, data=line) + store_data_func( + execution_db=execution_db, action_db=action_db, data=line + ) except RuntimeError: # process was terminated abruptly pass @@ -193,31 +195,40 @@ def invoke_post_run(liveaction_db, action_db=None): from st2common.util import action_db as action_db_utils from st2common.content import utils as content_utils - LOG.info('Invoking post run for action execution %s.', liveaction_db.id) + LOG.info("Invoking post run for action execution %s.", liveaction_db.id) # Identify action and runner. if not action_db: action_db = action_db_utils.get_action_by_ref(liveaction_db.action) if not action_db: - LOG.error('Unable to invoke post run. Action %s no longer exists.', liveaction_db.action) + LOG.error( + "Unable to invoke post run. Action %s no longer exists.", + liveaction_db.action, + ) return - LOG.info('Action execution %s runs %s of runner type %s.', - liveaction_db.id, action_db.name, action_db.runner_type['name']) + LOG.info( + "Action execution %s runs %s of runner type %s.", + liveaction_db.id, + action_db.name, + action_db.runner_type["name"], + ) # Get instance of the action runner and related configuration. - runner_type_db = action_db_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_db_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) runner = runners.get_runner(name=runner_type_db.name) entry_point = content_utils.get_entry_point_abs_path( - pack=action_db.pack, - entry_point=action_db.entry_point) + pack=action_db.pack, entry_point=action_db.entry_point + ) libs_dir_path = content_utils.get_action_libs_abs_path( - pack=action_db.pack, - entry_point=action_db.entry_point) + pack=action_db.pack, entry_point=action_db.entry_point + ) # Configure the action runner. runner.runner_type_db = runner_type_db @@ -226,8 +237,8 @@ def invoke_post_run(liveaction_db, action_db=None): runner.liveaction = liveaction_db runner.liveaction_id = str(liveaction_db.id) runner.entry_point = entry_point - runner.context = getattr(liveaction_db, 'context', dict()) - runner.callback = getattr(liveaction_db, 'callback', dict()) + runner.context = getattr(liveaction_db, "context", dict()) + runner.callback = getattr(liveaction_db, "callback", dict()) runner.libs_dir_path = libs_dir_path # Invoke the post_run method. diff --git a/st2common/st2common/script_setup.py b/st2common/st2common/script_setup.py index 03be4b4427..0abb7e8269 100644 --- a/st2common/st2common/script_setup.py +++ b/st2common/st2common/script_setup.py @@ -32,13 +32,7 @@ from st2common.logging.filters import LogLevelFilter from st2common.transport.bootstrap_utils import register_exchanges_with_retry -__all__ = [ - 'setup', - 'teardown', - - 'db_setup', - 'db_teardown' -] +__all__ = ["setup", "teardown", "db_setup", "db_teardown"] LOG = logging.getLogger(__name__) @@ -47,11 +41,15 @@ def register_common_cli_options(): """ Register common CLI options. """ - cfg.CONF.register_cli_opt(cfg.BoolOpt('verbose', short='v', default=False)) + cfg.CONF.register_cli_opt(cfg.BoolOpt("verbose", short="v", default=False)) -def setup(config, setup_db=True, register_mq_exchanges=True, - register_internal_trigger_types=False): +def setup( + config, + setup_db=True, + register_mq_exchanges=True, + register_internal_trigger_types=False, +): """ Common setup function. @@ -76,7 +74,9 @@ def setup(config, setup_db=True, register_mq_exchanges=True, # Set up logging log_level = stdlib_logging.DEBUG - stdlib_logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s', level=log_level) + stdlib_logging.basicConfig( + format="%(asctime)s %(levelname)s [-] %(message)s", level=log_level + ) if not cfg.CONF.verbose: # Note: We still want to print things at the following log levels: INFO, ERROR, CRITICAL diff --git a/st2common/st2common/service_setup.py b/st2common/st2common/service_setup.py index 14cd708cca..bd01d205e7 100644 --- a/st2common/st2common/service_setup.py +++ b/st2common/st2common/service_setup.py @@ -53,22 +53,29 @@ __all__ = [ - 'setup', - 'teardown', - - 'db_setup', - 'db_teardown', - - 'register_service_in_service_registry' + "setup", + "teardown", + "db_setup", + "db_teardown", + "register_service_in_service_registry", ] LOG = logging.getLogger(__name__) -def setup(service, config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=True, register_runners=True, service_registry=False, - capabilities=None, config_args=None): +def setup( + service, + config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=True, + register_runners=True, + service_registry=False, + capabilities=None, + config_args=None, +): """ Common setup function. @@ -99,29 +106,38 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, else: config.parse_args() - version = '%s.%s.%s' % (sys.version_info[0], sys.version_info[1], sys.version_info[2]) - LOG.debug('Using Python: %s (%s)' % (version, sys.executable)) + version = "%s.%s.%s" % ( + sys.version_info[0], + sys.version_info[1], + sys.version_info[2], + ) + LOG.debug("Using Python: %s (%s)" % (version, sys.executable)) config_file_paths = cfg.CONF.config_file config_file_paths = [os.path.abspath(path) for path in config_file_paths] - LOG.debug('Using config files: %s', ','.join(config_file_paths)) + LOG.debug("Using config files: %s", ",".join(config_file_paths)) # Setup logging. logging_config_path = config.get_logging_config_path() logging_config_path = os.path.abspath(logging_config_path) - LOG.debug('Using logging config: %s', logging_config_path) + LOG.debug("Using logging config: %s", logging_config_path) - is_debug_enabled = (cfg.CONF.debug or cfg.CONF.system.debug) + is_debug_enabled = cfg.CONF.debug or cfg.CONF.system.debug try: - logging.setup(logging_config_path, redirect_stderr=cfg.CONF.log.redirect_stderr, - excludes=cfg.CONF.log.excludes) + logging.setup( + logging_config_path, + redirect_stderr=cfg.CONF.log.redirect_stderr, + excludes=cfg.CONF.log.excludes, + ) except KeyError as e: tb_msg = traceback.format_exc() - if 'log.setLevel' in tb_msg: - msg = 'Invalid log level selected. Log level names need to be all uppercase.' - msg += '\n\n' + getattr(e, 'message', six.text_type(e)) + if "log.setLevel" in tb_msg: + msg = ( + "Invalid log level selected. Log level names need to be all uppercase." + ) + msg += "\n\n" + getattr(e, "message", six.text_type(e)) raise KeyError(msg) else: raise e @@ -134,10 +150,14 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, # duplicate "AUDIT" messages in production deployments where default service log level is # set to "INFO" and we already log messages with level AUDIT to a special dedicated log # file. - ignore_audit_log_messages = (handler.level >= stdlib_logging.INFO and - handler.level < stdlib_logging.AUDIT) + ignore_audit_log_messages = ( + handler.level >= stdlib_logging.INFO + and handler.level < stdlib_logging.AUDIT + ) if not is_debug_enabled and ignore_audit_log_messages: - LOG.debug('Excluding log messages with level "AUDIT" for handler "%s"' % (handler)) + LOG.debug( + 'Excluding log messages with level "AUDIT" for handler "%s"' % (handler) + ) handler.addFilter(LogLevelFilter(log_levels=exclude_log_levels)) if not is_debug_enabled: @@ -184,8 +204,9 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, # Register service in the service registry if cfg.CONF.coordination.service_registry and service_registry: # NOTE: It's important that we pass start_heart=True to start the hearbeat process - register_service_in_service_registry(service=service, capabilities=capabilities, - start_heart=True) + register_service_in_service_registry( + service=service, capabilities=capabilities, start_heart=True + ) if sys.version_info[0] == 2: LOG.warning(PYTHON2_DEPRECATION) @@ -220,7 +241,7 @@ def register_service_in_service_registry(service, capabilities=None, start_heart # 1. Create a group with the name of the service if not isinstance(service, six.binary_type): - group_id = service.encode('utf-8') + group_id = service.encode("utf-8") else: group_id = service @@ -231,10 +252,12 @@ def register_service_in_service_registry(service, capabilities=None, start_heart # Include common capabilities such as hostname and process ID proc_info = system_info.get_process_info() - capabilities['hostname'] = proc_info['hostname'] - capabilities['pid'] = proc_info['pid'] + capabilities["hostname"] = proc_info["hostname"] + capabilities["pid"] = proc_info["pid"] # 1. Join the group as a member - LOG.debug('Joining service registry group "%s" as member_id "%s" with capabilities "%s"' % - (group_id, member_id, capabilities)) + LOG.debug( + 'Joining service registry group "%s" as member_id "%s" with capabilities "%s"' + % (group_id, member_id, capabilities) + ) return coordinator.join_group(group_id, capabilities=capabilities).get() diff --git a/st2common/st2common/services/access.py b/st2common/st2common/services/access.py index 72f7f192bb..9d88c39c42 100644 --- a/st2common/st2common/services/access.py +++ b/st2common/st2common/services/access.py @@ -27,15 +27,14 @@ from st2common.persistence.auth import Token, User from st2common import log as logging -__all__ = [ - 'create_token', - 'delete_token' -] +__all__ = ["create_token", "delete_token"] LOG = logging.getLogger(__name__) -def create_token(username, ttl=None, metadata=None, add_missing_user=True, service=False): +def create_token( + username, ttl=None, metadata=None, add_missing_user=True, service=False +): """ :param username: Username of the user to create the token for. If the account for this user doesn't exist yet it will be created. @@ -57,8 +56,10 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi if ttl: # Note: We allow arbitrary large TTLs for service tokens. if not service and ttl > cfg.CONF.auth.token_ttl: - msg = ('TTL specified %s is greater than max allowed %s.' % (ttl, - cfg.CONF.auth.token_ttl)) + msg = "TTL specified %s is greater than max allowed %s." % ( + ttl, + cfg.CONF.auth.token_ttl, + ) raise TTLTooLargeException(msg) else: ttl = cfg.CONF.auth.token_ttl @@ -71,22 +72,27 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi user_db = UserDB(name=username) User.add_or_update(user_db) - extra = {'username': username, 'user': user_db} + extra = {"username": username, "user": user_db} LOG.audit('Registered new user "%s".' % (username), extra=extra) else: raise UserNotFoundError() token = uuid.uuid4().hex expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) - token = TokenDB(user=username, token=token, expiry=expiry, metadata=metadata, service=service) + token = TokenDB( + user=username, token=token, expiry=expiry, metadata=metadata, service=service + ) Token.add_or_update(token) - username_string = username if username else 'an anonymous user' + username_string = username if username else "an anonymous user" token_expire_string = isotime.format(expiry, offset=False) - extra = {'username': username, 'token_expiration': token_expire_string} + extra = {"username": username, "token_expiration": token_expire_string} - LOG.audit('Access granted to "%s" with the token set to expire at "%s".' % - (username_string, token_expire_string), extra=extra) + LOG.audit( + 'Access granted to "%s" with the token set to expire at "%s".' + % (username_string, token_expire_string), + extra=extra, + ) return token diff --git a/st2common/st2common/services/action.py b/st2common/st2common/services/action.py index c7e7495d69..46e44800cc 100644 --- a/st2common/st2common/services/action.py +++ b/st2common/st2common/services/action.py @@ -34,15 +34,13 @@ __all__ = [ - 'request', - 'create_request', - 'publish_request', - 'is_action_canceled_or_canceling', - - 'request_pause', - 'request_resume', - - 'store_execution_output_data', + "request", + "create_request", + "publish_request", + "is_action_canceled_or_canceling", + "request_pause", + "request_resume", + "store_execution_output_data", ] LOG = logging.getLogger(__name__) @@ -51,7 +49,7 @@ def _get_immutable_params(parameters): if not parameters: return [] - return [k for k, v in six.iteritems(parameters) if v.get('immutable', False)] + return [k for k, v in six.iteritems(parameters) if v.get("immutable", False)] def create_request(liveaction, action_db=None, runnertype_db=None): @@ -77,10 +75,10 @@ def create_request(liveaction, action_db=None, runnertype_db=None): # action can be invoked by a system user and so we want to use the user context # from the original workflow action. parent_context = executions.get_parent_context(liveaction) or {} - parent_user = parent_context.get('user', None) + parent_user = parent_context.get("user", None) if parent_user: - liveaction.context['user'] = parent_user + liveaction.context["user"] = parent_user # Validate action if not action_db: @@ -89,31 +87,44 @@ def create_request(liveaction, action_db=None, runnertype_db=None): if not action_db: raise ValueError('Action "%s" cannot be found.' % liveaction.action) if not action_db.enabled: - raise ValueError('Unable to execute. Action "%s" is disabled.' % liveaction.action) + raise ValueError( + 'Unable to execute. Action "%s" is disabled.' % liveaction.action + ) if not runnertype_db: - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) - if not hasattr(liveaction, 'parameters'): + if not hasattr(liveaction, "parameters"): liveaction.parameters = dict() # For consistency add pack to the context here in addition to RunnerContainer.dispatch() method - liveaction.context['pack'] = action_db.pack + liveaction.context["pack"] = action_db.pack # Validate action parameters. schema = util_schema.get_schema_for_action_parameters(action_db, runnertype_db) validator = util_schema.get_validator() - util_schema.validate(liveaction.parameters, schema, validator, use_default=True, - allow_default_none=True) + util_schema.validate( + liveaction.parameters, + schema, + validator, + use_default=True, + allow_default_none=True, + ) # validate that no immutable params are being overriden. Although possible to # ignore the override it is safer to inform the user to avoid surprises. immutables = _get_immutable_params(action_db.parameters) immutables.extend(_get_immutable_params(runnertype_db.runner_parameters)) - overridden_immutables = [p for p in six.iterkeys(liveaction.parameters) if p in immutables] + overridden_immutables = [ + p for p in six.iterkeys(liveaction.parameters) if p in immutables + ] if len(overridden_immutables) > 0: - raise ValueError('Override of immutable parameter(s) %s is unsupported.' - % str(overridden_immutables)) + raise ValueError( + "Override of immutable parameter(s) %s is unsupported." + % str(overridden_immutables) + ) # Set notification settings for action. # XXX: There are cases when we don't want notifications to be sent for a particular @@ -140,17 +151,24 @@ def create_request(liveaction, action_db=None, runnertype_db=None): _cleanup_liveaction(liveaction) raise trace_exc.TraceNotFoundException(six.text_type(e)) - execution = executions.create_execution_object(liveaction=liveaction, action_db=action_db, - runnertype_db=runnertype_db, publish=False) + execution = executions.create_execution_object( + liveaction=liveaction, + action_db=action_db, + runnertype_db=runnertype_db, + publish=False, + ) if trace_db: trace_service.add_or_update_given_trace_db( trace_db=trace_db, action_executions=[ - trace_service.get_trace_component_for_action_execution(execution, liveaction) - ]) + trace_service.get_trace_component_for_action_execution( + execution, liveaction + ) + ], + ) - get_driver().inc_counter('action.executions.%s' % (liveaction.status)) + get_driver().inc_counter("action.executions.%s" % (liveaction.status)) return liveaction, execution @@ -170,8 +188,11 @@ def publish_request(liveaction, execution): # TODO: This results in two queries, optimize it # extra = {'liveaction_db': liveaction, 'execution_db': execution} extra = {} - LOG.audit('Action execution requested. LiveAction.id=%s, ActionExecution.id=%s' % - (liveaction.id, execution.id), extra=extra) + LOG.audit( + "Action execution requested. LiveAction.id=%s, ActionExecution.id=%s" + % (liveaction.id, execution.id), + extra=extra, + ) return liveaction, execution @@ -190,33 +211,34 @@ def update_status(liveaction, new_status, result=None, publish=True): old_status = liveaction.status updates = { - 'liveaction_id': liveaction.id, - 'status': new_status, - 'result': result, - 'publish': False + "liveaction_id": liveaction.id, + "status": new_status, + "result": result, + "publish": False, } if new_status in action_constants.LIVEACTION_COMPLETED_STATES: - updates['end_timestamp'] = date_utils.get_datetime_utc_now() + updates["end_timestamp"] = date_utils.get_datetime_utc_now() liveaction = action_utils.update_liveaction_status(**updates) action_execution = executions.update_execution(liveaction) - msg = ('The status of action execution is changed from %s to %s. ' - '' % (old_status, - new_status, liveaction.id, action_execution.id)) + msg = ( + "The status of action execution is changed from %s to %s. " + "" + % (old_status, new_status, liveaction.id, action_execution.id) + ) - extra = { - 'action_execution_db': action_execution, - 'liveaction_db': liveaction - } + extra = {"action_execution_db": action_execution, "liveaction_db": liveaction} LOG.audit(msg, extra=extra) LOG.info(msg) # Invoke post run if liveaction status is completed or paused. - if (new_status in action_constants.LIVEACTION_COMPLETED_STATES or - new_status == action_constants.LIVEACTION_STATUS_PAUSED): + if ( + new_status in action_constants.LIVEACTION_COMPLETED_STATES + or new_status == action_constants.LIVEACTION_STATUS_PAUSED + ): runners_utils.invoke_post_run(liveaction) if publish: @@ -227,14 +249,18 @@ def update_status(liveaction, new_status, result=None, publish=True): def is_action_canceled_or_canceling(liveaction_id): liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) - return liveaction_db.status in [action_constants.LIVEACTION_STATUS_CANCELED, - action_constants.LIVEACTION_STATUS_CANCELING] + return liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_CANCELED, + action_constants.LIVEACTION_STATUS_CANCELING, + ] def is_action_paused_or_pausing(liveaction_id): liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) - return liveaction_db.status in [action_constants.LIVEACTION_STATUS_PAUSED, - action_constants.LIVEACTION_STATUS_PAUSING] + return liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PAUSING, + ] def request_cancellation(liveaction, requester): @@ -250,18 +276,17 @@ def request_cancellation(liveaction, requester): if liveaction.status not in action_constants.LIVEACTION_CANCELABLE_STATES: raise Exception( 'Unable to cancel liveaction "%s" because it is already in a ' - 'completed state.' % liveaction.id + "completed state." % liveaction.id ) - result = { - 'message': 'Action canceled by user.', - 'user': requester - } + result = {"message": "Action canceled by user.", "user": requester} # Run cancelation sequence for liveaction that is in running state or # if the liveaction is operating under a workflow. - if ('parent' in liveaction.context or - liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING): + if ( + "parent" in liveaction.context + or liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING + ): status = action_constants.LIVEACTION_STATUS_CANCELING else: status = action_constants.LIVEACTION_STATUS_CANCELED @@ -286,17 +311,19 @@ def request_pause(liveaction, requester): if not action_db: raise ValueError( 'Unable to pause liveaction "%s" because the action "%s" ' - 'is not found.' % (liveaction.id, liveaction.action) + "is not found." % (liveaction.id, liveaction.action) ) - if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: raise runner_exc.InvalidActionRunnerOperationError( 'Unable to pause liveaction "%s" because it is not supported by the ' - '"%s" runner.' % (liveaction.id, action_db.runner_type['name']) + '"%s" runner.' % (liveaction.id, action_db.runner_type["name"]) ) - if (liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING or - liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED): + if ( + liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING + or liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED + ): execution = ActionExecution.get(liveaction__id=str(liveaction.id)) return (liveaction, execution) @@ -326,18 +353,18 @@ def request_resume(liveaction, requester): if not action_db: raise ValueError( 'Unable to resume liveaction "%s" because the action "%s" ' - 'is not found.' % (liveaction.id, liveaction.action) + "is not found." % (liveaction.id, liveaction.action) ) - if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: raise runner_exc.InvalidActionRunnerOperationError( 'Unable to resume liveaction "%s" because it is not supported by the ' - '"%s" runner.' % (liveaction.id, action_db.runner_type['name']) + '"%s" runner.' % (liveaction.id, action_db.runner_type["name"]) ) running_states = [ action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] if liveaction.status in running_states: @@ -367,13 +394,13 @@ def get_parent_liveaction(liveaction_db): :rtype: LiveActionDB """ - parent = liveaction_db.context.get('parent') + parent = liveaction_db.context.get("parent") if not parent: return None - parent_execution_db = ActionExecution.get(id=parent['execution_id']) - parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction['id']) + parent_execution_db = ActionExecution.get(id=parent["execution_id"]) + parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction["id"]) return parent_liveaction_db @@ -409,7 +436,11 @@ def get_root_liveaction(liveaction_db): parent_liveaction_db = get_parent_liveaction(liveaction_db) - return get_root_liveaction(parent_liveaction_db) if parent_liveaction_db else liveaction_db + return ( + get_root_liveaction(parent_liveaction_db) + if parent_liveaction_db + else liveaction_db + ) def get_root_execution(execution_db): @@ -425,36 +456,48 @@ def get_root_execution(execution_db): parent_execution_db = get_parent_execution(execution_db) - return get_root_execution(parent_execution_db) if parent_execution_db else execution_db + return ( + get_root_execution(parent_execution_db) if parent_execution_db else execution_db + ) -def store_execution_output_data(execution_db, action_db, data, output_type='output', - timestamp=None): +def store_execution_output_data( + execution_db, action_db, data, output_type="output", timestamp=None +): """ Store output from an execution as a new document in the collection. """ execution_id = str(execution_db.id) if action_db is None: - action_ref = execution_db.action.get('ref', 'unknown') - runner_ref = execution_db.action.get('runner_type', 'unknown') + action_ref = execution_db.action.get("ref", "unknown") + runner_ref = execution_db.action.get("runner_type", "unknown") else: action_ref = action_db.ref - runner_ref = getattr(action_db, 'runner_type', {}).get('name', 'unknown') + runner_ref = getattr(action_db, "runner_type", {}).get("name", "unknown") return store_execution_output_data_ex( - execution_id, action_ref, runner_ref, data, - output_type=output_type, timestamp=timestamp + execution_id, + action_ref, + runner_ref, + data, + output_type=output_type, + timestamp=timestamp, ) -def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, output_type='output', - timestamp=None): +def store_execution_output_data_ex( + execution_id, action_ref, runner_ref, data, output_type="output", timestamp=None +): timestamp = timestamp or date_utils.get_datetime_utc_now() output_db = ActionExecutionOutputDB( - execution_id=execution_id, action_ref=action_ref, runner_ref=runner_ref, - timestamp=timestamp, output_type=output_type, data=data + execution_id=execution_id, + action_ref=action_ref, + runner_ref=runner_ref, + timestamp=timestamp, + output_type=output_type, + data=data, ) output_db = ActionExecutionOutput.add_or_update( @@ -467,29 +510,29 @@ def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, o def is_children_active(liveaction_id): execution_db = ActionExecution.get(liveaction__id=str(liveaction_id)) - if execution_db.runner['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if execution_db.runner["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: return False children_execution_dbs = ActionExecution.query(parent=str(execution_db.id)) - inactive_statuses = ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED, action_constants.LIVEACTION_STATUS_PENDING] - ) + inactive_statuses = action_constants.LIVEACTION_COMPLETED_STATES + [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PENDING, + ] completed = [ child_exec_db.status in inactive_statuses for child_exec_db in children_execution_dbs ] - return (not all(completed)) + return not all(completed) def _cleanup_liveaction(liveaction): try: LiveAction.delete(liveaction) except: - LOG.exception('Failed cleaning up LiveAction: %s.', liveaction) + LOG.exception("Failed cleaning up LiveAction: %s.", liveaction) pass diff --git a/st2common/st2common/services/config.py b/st2common/st2common/services/config.py index f23d91ee9c..bef8f483dd 100644 --- a/st2common/st2common/services/config.py +++ b/st2common/st2common/services/config.py @@ -28,13 +28,15 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'set_datastore_value_for_config_key', + "set_datastore_value_for_config_key", ] LOG = logging.getLogger(__name__) -def set_datastore_value_for_config_key(pack_name, key_name, value, secret=False, user=None): +def set_datastore_value_for_config_key( + pack_name, key_name, value, secret=False, user=None +): """ Set config value in the datastore. diff --git a/st2common/st2common/services/coordination.py b/st2common/st2common/services/coordination.py index 42556ea084..1a068632ab 100644 --- a/st2common/st2common/services/coordination.py +++ b/st2common/st2common/services/coordination.py @@ -31,19 +31,17 @@ COORDINATOR = None __all__ = [ - 'configured', - - 'get_coordinator', - 'get_coordinator_if_set', - 'get_member_id', - - 'coordinator_setup', - 'coordinator_teardown' + "configured", + "get_coordinator", + "get_coordinator_if_set", + "get_member_id", + "coordinator_setup", + "coordinator_teardown", ] class NoOpLock(locking.Lock): - def __init__(self, name='noop'): + def __init__(self, name="noop"): super(NoOpLock, self).__init__(name=name) def acquire(self, blocking=True): @@ -61,6 +59,7 @@ class NoOpAsyncResult(object): In most scenarios, tooz library returns an async result, a future and this class wrapper is here to correctly mimic tooz API and behavior. """ + def __init__(self, result=None): self._result = result @@ -108,7 +107,7 @@ def stand_down_group_leader(group_id): @classmethod def create_group(cls, group_id): - cls.groups[group_id] = {'members': {}} + cls.groups[group_id] = {"members": {}} return NoOpAsyncResult() @classmethod @@ -116,17 +115,17 @@ def get_groups(cls): return NoOpAsyncResult(result=cls.groups.keys()) @classmethod - def join_group(cls, group_id, capabilities=''): + def join_group(cls, group_id, capabilities=""): member_id = get_member_id() - cls.groups[group_id]['members'][member_id] = {'capabilities': capabilities} + cls.groups[group_id]["members"][member_id] = {"capabilities": capabilities} return NoOpAsyncResult() @classmethod def leave_group(cls, group_id): member_id = get_member_id() - del cls.groups[group_id]['members'][member_id] + del cls.groups[group_id]["members"][member_id] return NoOpAsyncResult() @classmethod @@ -137,15 +136,15 @@ def delete_group(cls, group_id): @classmethod def get_members(cls, group_id): try: - member_ids = cls.groups[group_id]['members'].keys() + member_ids = cls.groups[group_id]["members"].keys() except KeyError: - raise GroupNotCreated('Group doesnt exist') + raise GroupNotCreated("Group doesnt exist") return NoOpAsyncResult(result=member_ids) @classmethod def get_member_capabilities(cls, group_id, member_id): - member_capabiliteis = cls.groups[group_id]['members'][member_id]['capabilities'] + member_capabiliteis = cls.groups[group_id]["members"][member_id]["capabilities"] return NoOpAsyncResult(result=member_capabiliteis) @staticmethod @@ -158,7 +157,7 @@ def get_leader(group_id): @staticmethod def get_lock(name): - return NoOpLock(name='noop') + return NoOpLock(name="noop") def configured(): @@ -168,8 +167,10 @@ def configured(): :rtype: ``bool`` """ backend_configured = cfg.CONF.coordination.url is not None - mock_backend = backend_configured and (cfg.CONF.coordination.url.startswith('zake') or - cfg.CONF.coordination.url.startswith('file')) + mock_backend = backend_configured and ( + cfg.CONF.coordination.url.startswith("zake") + or cfg.CONF.coordination.url.startswith("file") + ) return backend_configured and not mock_backend @@ -189,7 +190,9 @@ def coordinator_setup(start_heart=True): member_id = get_member_id() if url: - coordinator = coordination.get_coordinator(url, member_id, lock_timeout=lock_timeout) + coordinator = coordination.get_coordinator( + url, member_id, lock_timeout=lock_timeout + ) else: # Use a no-op backend # Note: We don't use tooz to obtain a reference since for this to work we would need to @@ -217,17 +220,21 @@ def get_coordinator(start_heart=True, use_cache=True): global COORDINATOR if not configured(): - LOG.warn('Coordination backend is not configured. Code paths which use coordination ' - 'service will use best effort approach and race conditions are possible.') + LOG.warn( + "Coordination backend is not configured. Code paths which use coordination " + "service will use best effort approach and race conditions are possible." + ) if not use_cache: return coordinator_setup(start_heart=start_heart) if not COORDINATOR: COORDINATOR = coordinator_setup(start_heart=start_heart) - LOG.debug('Initializing and caching new coordinator instance: %s' % (str(COORDINATOR))) + LOG.debug( + "Initializing and caching new coordinator instance: %s" % (str(COORDINATOR)) + ) else: - LOG.debug('Using cached coordinator instance: %s' % (str(COORDINATOR))) + LOG.debug("Using cached coordinator instance: %s" % (str(COORDINATOR))) return COORDINATOR @@ -247,5 +254,5 @@ def get_member_id(): :rtype: ``bytes`` """ proc_info = system_info.get_process_info() - member_id = six.b('%s_%d' % (proc_info['hostname'], proc_info['pid'])) + member_id = six.b("%s_%d" % (proc_info["hostname"], proc_info["pid"])) return member_id diff --git a/st2common/st2common/services/datastore.py b/st2common/st2common/services/datastore.py index 986ffd0d03..9655499e49 100644 --- a/st2common/st2common/services/datastore.py +++ b/st2common/st2common/services/datastore.py @@ -24,11 +24,7 @@ from st2common.util.date import get_datetime_utc_now from st2common.constants.keyvalue import DATASTORE_KEY_SEPARATOR, SYSTEM_SCOPE -__all__ = [ - 'BaseDatastoreService', - 'ActionDatastoreService', - 'SensorDatastoreService' -] +__all__ = ["BaseDatastoreService", "ActionDatastoreService", "SensorDatastoreService"] class BaseDatastoreService(object): @@ -63,7 +59,7 @@ def get_user_info(self): """ client = self.get_api_client() - self._logger.debug('Retrieving user information') + self._logger.debug("Retrieving user information") result = client.get_user_info() return result @@ -85,7 +81,7 @@ def list_values(self, local=True, prefix=None): :rtype: ``list`` of :class:`KeyValuePair` """ client = self.get_api_client() - self._logger.debug('Retrieving all the values from the datastore') + self._logger.debug("Retrieving all the values from the datastore") key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) kvps = client.keys.get_all(prefix=key_prefix) @@ -113,21 +109,19 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): :rtype: ``str`` or ``None`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) client = self.get_api_client() - self._logger.debug('Retrieving value from the datastore (name=%s)', name) + self._logger.debug("Retrieving value from the datastore (name=%s)", name) try: - params = {'decrypt': str(decrypt).lower(), 'scope': scope} + params = {"decrypt": str(decrypt).lower(), "scope": scope} kvp = client.keys.get_by_id(id=name, params=params) except Exception as e: self._logger.exception( - 'Exception retrieving value from datastore (name=%s): %s', - name, - e + "Exception retrieving value from datastore (name=%s): %s", name, e ) return None @@ -136,7 +130,9 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): return None - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): """ Set a value for the provided key. @@ -165,14 +161,14 @@ def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encry :rtype: ``bool`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) value = str(value) client = self.get_api_client() - self._logger.debug('Setting value in the datastore (name=%s)', name) + self._logger.debug("Setting value in the datastore (name=%s)", name) instance = KeyValuePair() instance.id = name @@ -208,7 +204,7 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): :rtype: ``bool`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) @@ -218,16 +214,14 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): instance.id = name instance.name = name - self._logger.debug('Deleting value from the datastore (name=%s)', name) + self._logger.debug("Deleting value from the datastore (name=%s)", name) try: - params = {'scope': scope} + params = {"scope": scope} client.keys.delete(instance=instance, params=params) except Exception as e: self._logger.exception( - 'Exception deleting value from datastore (name=%s): %s', - name, - e + "Exception deleting value from datastore (name=%s): %s", name, e ) return False @@ -237,7 +231,7 @@ def get_api_client(self): """ Retrieve API client instance. """ - raise NotImplementedError('get_api_client() not implemented') + raise NotImplementedError("get_api_client() not implemented") def _get_full_key_name(self, name, local): """ @@ -282,7 +276,7 @@ def _get_key_name_with_prefix(self, name): return full_name def _get_datastore_key_prefix(self): - prefix = '%s.%s' % (self._pack_name, self._class_name) + prefix = "%s.%s" % (self._pack_name, self._class_name) return prefix @@ -299,8 +293,9 @@ def __init__(self, logger, pack_name, class_name, auth_token): :param auth_token: Auth token used to authenticate with StackStorm API. :type auth_token: ``str`` """ - super(ActionDatastoreService, self).__init__(logger=logger, pack_name=pack_name, - class_name=class_name) + super(ActionDatastoreService, self).__init__( + logger=logger, pack_name=pack_name, class_name=class_name + ) self._auth_token = auth_token self._client = None @@ -310,7 +305,7 @@ def get_api_client(self): Retrieve API client instance. """ if not self._client: - self._logger.debug('Creating new Client object.') + self._logger.debug("Creating new Client object.") api_url = get_full_public_api_url() client = Client(api_url=api_url, token=self._auth_token) @@ -330,8 +325,9 @@ class SensorDatastoreService(BaseDatastoreService): """ def __init__(self, logger, pack_name, class_name, api_username): - super(SensorDatastoreService, self).__init__(logger=logger, pack_name=pack_name, - class_name=class_name) + super(SensorDatastoreService, self).__init__( + logger=logger, pack_name=pack_name, class_name=class_name + ) self._api_username = api_username self._token_expire = get_datetime_utc_now() @@ -344,12 +340,15 @@ def get_api_client(self): if not self._client or token_expire: # Note: Late import to avoid high import cost (time wise) from st2common.services.access import create_token - self._logger.debug('Creating new Client object.') + + self._logger.debug("Creating new Client object.") ttl = cfg.CONF.auth.service_token_ttl api_url = get_full_public_api_url() - temporary_token = create_token(username=self._api_username, ttl=ttl, service=True) + temporary_token = create_token( + username=self._api_username, ttl=ttl, service=True + ) self._client = Client(api_url=api_url, token=temporary_token.token) self._token_expire = get_datetime_utc_now() + timedelta(seconds=ttl) diff --git a/st2common/st2common/services/executions.py b/st2common/st2common/services/executions.py index e259977bdc..51447796b0 100644 --- a/st2common/st2common/services/executions.py +++ b/st2common/st2common/services/executions.py @@ -51,13 +51,13 @@ __all__ = [ - 'create_execution_object', - 'update_execution', - 'abandon_execution_if_incomplete', - 'is_execution_canceled', - 'AscendingSortedDescendantView', - 'DFSDescendantView', - 'get_descendants' + "create_execution_object", + "update_execution", + "abandon_execution_if_incomplete", + "is_execution_canceled", + "AscendingSortedDescendantView", + "DFSDescendantView", + "get_descendants", ] LOG = logging.getLogger(__name__) @@ -66,13 +66,13 @@ # into a ActionExecution compatible dictionary. # Those attributes are LiveAction specific and are therefore stored in a "liveaction" key LIVEACTION_ATTRIBUTES = [ - 'id', - 'callback', - 'action', - 'action_is_workflow', - 'runner_info', - 'parameters', - 'notify' + "id", + "callback", + "action", + "action_is_workflow", + "runner_info", + "parameters", + "notify", ] @@ -80,11 +80,11 @@ def _decompose_liveaction(liveaction_db): """ Splits the liveaction into an ActionExecution compatible dict. """ - decomposed = {'liveaction': {}} + decomposed = {"liveaction": {}} liveaction_api = vars(LiveActionAPI.from_model(liveaction_db)) for k in liveaction_api.keys(): if k in LIVEACTION_ATTRIBUTES: - decomposed['liveaction'][k] = liveaction_api[k] + decomposed["liveaction"][k] = liveaction_api[k] else: decomposed[k] = getattr(liveaction_db, k) return decomposed @@ -94,49 +94,53 @@ def _create_execution_log_entry(status): """ Create execution log entry object for the provided execution status. """ - return { - 'timestamp': date_utils.get_datetime_utc_now(), - 'status': status - } + return {"timestamp": date_utils.get_datetime_utc_now(), "status": status} -def create_execution_object(liveaction, action_db=None, runnertype_db=None, publish=True): +def create_execution_object( + liveaction, action_db=None, runnertype_db=None, publish=True +): if not action_db: action_db = action_utils.get_action_by_ref(liveaction.action) if not runnertype_db: - runnertype_db = RunnerType.get_by_name(action_db.runner_type['name']) + runnertype_db = RunnerType.get_by_name(action_db.runner_type["name"]) attrs = { - 'action': vars(ActionAPI.from_model(action_db)), - 'parameters': liveaction['parameters'], - 'runner': vars(RunnerTypeAPI.from_model(runnertype_db)) + "action": vars(ActionAPI.from_model(action_db)), + "parameters": liveaction["parameters"], + "runner": vars(RunnerTypeAPI.from_model(runnertype_db)), } attrs.update(_decompose_liveaction(liveaction)) - if 'rule' in liveaction.context: - rule = reference.get_model_from_ref(Rule, liveaction.context.get('rule', {})) - attrs['rule'] = vars(RuleAPI.from_model(rule)) + if "rule" in liveaction.context: + rule = reference.get_model_from_ref(Rule, liveaction.context.get("rule", {})) + attrs["rule"] = vars(RuleAPI.from_model(rule)) - if 'trigger_instance' in liveaction.context: - trigger_instance_id = liveaction.context.get('trigger_instance', {}) - trigger_instance_id = trigger_instance_id.get('id', None) + if "trigger_instance" in liveaction.context: + trigger_instance_id = liveaction.context.get("trigger_instance", {}) + trigger_instance_id = trigger_instance_id.get("id", None) trigger_instance = TriggerInstance.get_by_id(trigger_instance_id) - trigger = reference.get_model_by_resource_ref(db_api=Trigger, - ref=trigger_instance.trigger) - trigger_type = reference.get_model_by_resource_ref(db_api=TriggerType, - ref=trigger.type) + trigger = reference.get_model_by_resource_ref( + db_api=Trigger, ref=trigger_instance.trigger + ) + trigger_type = reference.get_model_by_resource_ref( + db_api=TriggerType, ref=trigger.type + ) trigger_instance = reference.get_model_from_ref( - TriggerInstance, liveaction.context.get('trigger_instance', {})) - attrs['trigger_instance'] = vars(TriggerInstanceAPI.from_model(trigger_instance)) - attrs['trigger'] = vars(TriggerAPI.from_model(trigger)) - attrs['trigger_type'] = vars(TriggerTypeAPI.from_model(trigger_type)) + TriggerInstance, liveaction.context.get("trigger_instance", {}) + ) + attrs["trigger_instance"] = vars( + TriggerInstanceAPI.from_model(trigger_instance) + ) + attrs["trigger"] = vars(TriggerAPI.from_model(trigger)) + attrs["trigger_type"] = vars(TriggerTypeAPI.from_model(trigger_type)) parent = _get_parent_execution(liveaction) if parent: - attrs['parent'] = str(parent.id) + attrs["parent"] = str(parent.id) - attrs['log'] = [_create_execution_log_entry(liveaction['status'])] + attrs["log"] = [_create_execution_log_entry(liveaction["status"])] # TODO: This object initialization takes 20-30or so ms execution = ActionExecutionDB(**attrs) @@ -146,24 +150,30 @@ def create_execution_object(liveaction, action_db=None, runnertype_db=None, publ # NOTE: User input data is already validate as part of the API request, # other data is set by us. Skipping validation here makes operation 10%-30% faster - execution = ActionExecution.add_or_update(execution, publish=publish, validate=False) + execution = ActionExecution.add_or_update( + execution, publish=publish, validate=False + ) if parent and str(execution.id) not in parent.children: values = {} - values['push__children'] = str(execution.id) + values["push__children"] = str(execution.id) ActionExecution.update(parent, **values) return execution def _get_parent_execution(child_liveaction_db): - parent_execution_id = child_liveaction_db.context.get('parent', {}).get('execution_id', None) + parent_execution_id = child_liveaction_db.context.get("parent", {}).get( + "execution_id", None + ) if parent_execution_id: try: return ActionExecution.get_by_id(parent_execution_id) except: - LOG.exception('No valid execution object found in db for id: %s' % parent_execution_id) + LOG.exception( + "No valid execution object found in db for id: %s" % parent_execution_id + ) return None return None @@ -180,12 +190,12 @@ def update_execution(liveaction_db, publish=True): kw = {} for k, v in six.iteritems(decomposed): - kw['set__' + k] = v + kw["set__" + k] = v if liveaction_db.status != execution.status: # Note: If the status changes we store this transition in the "log" attribute of action # execution - kw['push__log'] = _create_execution_log_entry(liveaction_db.status) + kw["push__log"] = _create_execution_log_entry(liveaction_db.status) execution = ActionExecution.update(execution, publish=publish, **kw) return execution @@ -201,19 +211,25 @@ def abandon_execution_if_incomplete(liveaction_id, publish=True): # No need to abandon and already complete action if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES: - raise ValueError('LiveAction %s already in a completed state %s.' % - (liveaction_id, liveaction_db.status)) + raise ValueError( + "LiveAction %s already in a completed state %s." + % (liveaction_id, liveaction_db.status) + ) # Update status to reflect execution being abandoned. liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_ABANDONED, liveaction_db=liveaction_db, - result={}) + result={}, + ) execution_db = update_execution(liveaction_db, publish=publish) - LOG.info('Marked execution %s as %s.', execution_db.id, - action_constants.LIVEACTION_STATUS_ABANDONED) + LOG.info( + "Marked execution %s as %s.", + execution_db.id, + action_constants.LIVEACTION_STATUS_ABANDONED, + ) # Invoke post run on the action to execute post run operations such as callback. runners_utils.invoke_post_run(liveaction_db) @@ -236,10 +252,10 @@ def get_parent_context(liveaction_db): :return: If found the parent context else None. :rtype: dict """ - context = getattr(liveaction_db, 'context', None) + context = getattr(liveaction_db, "context", None) if not context: return None - return context.get('parent', None) + return context.get("parent", None) class AscendingSortedDescendantView(object): @@ -267,8 +283,8 @@ def result(self): DESCENDANT_VIEWS = { - 'sorted': AscendingSortedDescendantView, - 'default': DFSDescendantView + "sorted": AscendingSortedDescendantView, + "default": DFSDescendantView, } @@ -278,9 +294,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None): the supplied actionexecution_id. """ descendants = DESCENDANT_VIEWS.get(result_fmt, DFSDescendantView)() - children = ActionExecution.query(parent=actionexecution_id, - **{'order_by': ['start_timestamp']}) - LOG.debug('Found %s children for id %s.', len(children), actionexecution_id) + children = ActionExecution.query( + parent=actionexecution_id, **{"order_by": ["start_timestamp"]} + ) + LOG.debug("Found %s children for id %s.", len(children), actionexecution_id) current_level = [(child, 1) for child in children] while current_level: @@ -291,8 +308,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None): continue if level != -1 and level == descendant_depth: continue - children = ActionExecution.query(parent=parent_id, **{'order_by': ['start_timestamp']}) - LOG.debug('Found %s children for id %s.', len(children), parent_id) + children = ActionExecution.query( + parent=parent_id, **{"order_by": ["start_timestamp"]} + ) + LOG.debug("Found %s children for id %s.", len(children), parent_id) # prepend for DFS for idx in range(len(children)): current_level.insert(idx, (children[idx], level + 1)) diff --git a/st2common/st2common/services/inquiry.py b/st2common/st2common/services/inquiry.py index 5b511b3a97..09be3cc8f1 100644 --- a/st2common/st2common/services/inquiry.py +++ b/st2common/st2common/services/inquiry.py @@ -40,9 +40,11 @@ def check_inquiry(inquiry): - LOG.debug('Checking action execution "%s" to see if is an inquiry.' % str(inquiry.id)) + LOG.debug( + 'Checking action execution "%s" to see if is an inquiry.' % str(inquiry.id) + ) - if inquiry.runner.get('name') != 'inquirer': + if inquiry.runner.get("name") != "inquirer": raise inquiry_exceptions.InvalidInquiryInstance(str(inquiry.id)) LOG.debug('Checking if the inquiry "%s" has timed out.' % str(inquiry.id)) @@ -69,7 +71,7 @@ def check_permission(inquiry, requester): users_passed = False # Determine role-level permissions - roles = getattr(inquiry, 'roles', []) + roles = getattr(inquiry, "roles", []) if not roles: # No roles definition so we treat it as a pass @@ -79,14 +81,16 @@ def check_permission(inquiry, requester): rbac_utils = get_rbac_backend().get_utils_class() user_has_role = rbac_utils.user_has_role(user_db, role) - LOG.debug('Checking user %s is in role %s - %s' % (user_db, role, user_has_role)) + LOG.debug( + "Checking user %s is in role %s - %s" % (user_db, role, user_has_role) + ) if user_has_role: roles_passed = True break # Determine user-level permissions - users = getattr(inquiry, 'users', []) + users = getattr(inquiry, "users", []) if not users or user_db.name in users: users_passed = True @@ -98,7 +102,7 @@ def check_permission(inquiry, requester): def validate_response(inquiry, response): schema = inquiry.schema - LOG.debug('Validating inquiry response: %s against schema: %s' % (response, schema)) + LOG.debug("Validating inquiry response: %s against schema: %s" % (response, schema)) try: schema_utils.validate( @@ -106,12 +110,14 @@ def validate_response(inquiry, response): schema=schema, cls=schema_utils.CustomValidator, use_default=True, - allow_default_none=True + allow_default_none=True, ) except Exception as e: msg = 'Response for inquiry "%s" did not pass schema validation.' LOG.exception(msg % str(inquiry.id)) - raise inquiry_exceptions.InvalidInquiryResponse(str(inquiry.id), six.text_type(e)) + raise inquiry_exceptions.InvalidInquiryResponse( + str(inquiry.id), six.text_type(e) + ) def respond(inquiry, response, requester=None): @@ -120,14 +126,14 @@ def respond(inquiry, response, requester=None): requester = cfg.CONF.system_user.user # Retrieve the liveaction from the database. - liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get('id')) + liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get("id")) # Resume the parent workflow first. If the action execution for the inquiry is updated first, # it triggers handling of the action execution completion which will interact with the paused # parent workflow. The resuming logic that is executed here will then race with the completion # of the inquiry action execution, which will randomly result in the parent workflow stuck in # paused state. - if liveaction_db.context.get('parent'): + if liveaction_db.context.get("parent"): LOG.debug('Resuming workflow parent(s) for inquiry "%s".' % str(inquiry.id)) # For action execution under Action Chain workflows, request the entire @@ -136,7 +142,9 @@ def respond(inquiry, response, requester=None): # there is no other paused branches, the conductor will resume the rest of the workflow. resume_target = ( action_service.get_parent_liveaction(liveaction_db) - if workflow_service.is_action_execution_under_workflow_context(liveaction_db) + if workflow_service.is_action_execution_under_workflow_context( + liveaction_db + ) else action_service.get_root_liveaction(liveaction_db) ) @@ -147,14 +155,14 @@ def respond(inquiry, response, requester=None): LOG.debug('Updating response for inquiry "%s".' % str(inquiry.id)) result = copy.deepcopy(inquiry.result) - result['response'] = response + result["response"] = response liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_SUCCEEDED, end_timestamp=date_utils.get_datetime_utc_now(), runner_info=sys_info_utils.get_process_info(), result=result, - liveaction_id=str(liveaction_db.id) + liveaction_id=str(liveaction_db.id), ) # Sync the liveaction with the corresponding action execution. @@ -164,7 +172,7 @@ def respond(inquiry, response, requester=None): LOG.debug('Invoking post run for inquiry "%s".' % str(inquiry.id)) runner_container = container.get_runner_container() action_db = action_utils.get_action_by_ref(liveaction_db.action) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) runner = runner_container._get_runner(runnertype_db, action_db, liveaction_db) runner.post_run(status=action_constants.LIVEACTION_STATUS_SUCCEEDED, result=result) diff --git a/st2common/st2common/services/keyvalues.py b/st2common/st2common/services/keyvalues.py index 722603eee5..d38f28ca93 100644 --- a/st2common/st2common/services/keyvalues.py +++ b/st2common/st2common/services/keyvalues.py @@ -28,11 +28,10 @@ from st2common.persistence.keyvalue import KeyValuePair __all__ = [ - 'get_kvp_for_name', - 'get_values_for_names', - - 'KeyValueLookup', - 'UserKeyValueLookup' + "get_kvp_for_name", + "get_values_for_names", + "KeyValueLookup", + "UserKeyValueLookup", ] LOG = logging.getLogger(__name__) @@ -81,17 +80,17 @@ def get_key_name(self): :rtype: ``str`` """ key_name_parts = [DATASTORE_PARENT_SCOPE, self.scope] - key_name = self._key_prefix.split(':', 1) + key_name = self._key_prefix.split(":", 1) if len(key_name) == 1: key_name = key_name[0] elif len(key_name) >= 2: key_name = key_name[1] else: - key_name = '' + key_name = "" key_name_parts.append(key_name) - key_name = '.'.join(key_name_parts) + key_name = ".".join(key_name_parts) return key_name @@ -99,7 +98,9 @@ class KeyValueLookup(BaseKeyValueLookup): scope = SYSTEM_SCOPE - def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE): + def __init__( + self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE + ): if not scope: scope = FULL_SYSTEM_SCOPE @@ -107,7 +108,7 @@ def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_S scope = FULL_SYSTEM_SCOPE self._prefix = prefix - self._key_prefix = key_prefix or '' + self._key_prefix = key_prefix or "" self._value_cache = cache or {} self._scope = scope @@ -129,7 +130,7 @@ def __getattr__(self, name): def _get(self, name): # get the value for this key and save in value_cache if self._key_prefix: - key = '%s.%s' % (self._key_prefix, name) + key = "%s.%s" % (self._key_prefix, name) else: key = name @@ -144,12 +145,16 @@ def _get(self, name): # the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja, # will expect to do a dictionary style lookup for key_base and key_value as subsequent # calls. Saving the value in cache avoids extra DB calls. - return KeyValueLookup(prefix=self._prefix, key_prefix=key, cache=self._value_cache, - scope=self._scope) + return KeyValueLookup( + prefix=self._prefix, + key_prefix=key, + cache=self._value_cache, + scope=self._scope, + ) def _get_kv(self, key): scope = self._scope - LOG.debug('Lookup system kv: scope: %s and key: %s', scope, key) + LOG.debug("Lookup system kv: scope: %s and key: %s", scope, key) try: kvp = KeyValuePair.get_by_scope_and_name(scope=scope, name=key) @@ -157,15 +162,17 @@ def _get_kv(self, key): kvp = None if kvp: - LOG.debug('Got value %s from datastore.', kvp.value) - return kvp.value if kvp else '' + LOG.debug("Got value %s from datastore.", kvp.value) + return kvp.value if kvp else "" class UserKeyValueLookup(BaseKeyValueLookup): scope = USER_SCOPE - def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE): + def __init__( + self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE + ): if not scope: scope = FULL_USER_SCOPE @@ -173,7 +180,7 @@ def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_US scope = FULL_USER_SCOPE self._prefix = prefix - self._key_prefix = key_prefix or '' + self._key_prefix = key_prefix or "" self._value_cache = cache or {} self._user = user self._scope = scope @@ -190,7 +197,7 @@ def __getattr__(self, name): def _get(self, name): # get the value for this key and save in value_cache if self._key_prefix: - key = '%s.%s' % (self._key_prefix, name) + key = "%s.%s" % (self._key_prefix, name) else: key = UserKeyReference(name=name, user=self._user).ref @@ -205,8 +212,13 @@ def _get(self, name): # the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja, # will expect to do a dictionary style lookup for key_base and key_value as subsequent # calls. Saving the value in cache avoids extra DB calls. - return UserKeyValueLookup(prefix=self._prefix, user=self._user, key_prefix=key, - cache=self._value_cache, scope=self._scope) + return UserKeyValueLookup( + prefix=self._prefix, + user=self._user, + key_prefix=key, + cache=self._value_cache, + scope=self._scope, + ) def _get_kv(self, key): scope = self._scope @@ -216,7 +228,7 @@ def _get_kv(self, key): except StackStormDBObjectNotFoundError: kvp = None - return kvp.value if kvp else '' + return kvp.value if kvp else "" def get_key_reference(scope, name, user=None): @@ -232,12 +244,15 @@ def get_key_reference(scope, name, user=None): :rtype: ``str`` """ - if (scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE): + if scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE: return name - elif (scope == USER_SCOPE or scope == FULL_USER_SCOPE): + elif scope == USER_SCOPE or scope == FULL_USER_SCOPE: if not user: - raise InvalidUserException('A valid user must be specified for user key ref.') + raise InvalidUserException( + "A valid user must be specified for user key ref." + ) return UserKeyReference(name=name, user=user).ref else: - raise InvalidScopeException('Scope "%s" is not valid. Allowed scopes are %s.' % - (scope, ALLOWED_SCOPES)) + raise InvalidScopeException( + 'Scope "%s" is not valid. Allowed scopes are %s.' % (scope, ALLOWED_SCOPES) + ) diff --git a/st2common/st2common/services/packs.py b/st2common/st2common/services/packs.py index 7088b5f368..9f2794ed78 100644 --- a/st2common/st2common/services/packs.py +++ b/st2common/st2common/services/packs.py @@ -27,21 +27,15 @@ from six.moves import range __all__ = [ - 'get_pack_by_ref', - 'fetch_pack_index', - 'get_pack_from_index', - 'search_pack_index' + "get_pack_by_ref", + "fetch_pack_index", + "get_pack_from_index", + "search_pack_index", ] -EXCLUDE_FIELDS = [ - "repo_url", - "email" -] +EXCLUDE_FIELDS = ["repo_url", "email"] -SEARCH_PRIORITY = [ - "name", - "keywords" -] +SEARCH_PRIORITY = ["name", "keywords"] LOG = logging.getLogger(__name__) @@ -55,7 +49,7 @@ def _build_index_list(index_url): index_urls = cfg.CONF.content.index_url[::-1] elif isinstance(index_url, str): index_urls = [index_url] - elif hasattr(index_url, '__iter__'): + elif hasattr(index_url, "__iter__"): index_urls = index_url else: raise TypeError('"index_url" should either be a string or an iterable object.') @@ -73,23 +67,23 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None): verify = True if proxy_config: - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) - ca_bundle_path = proxy_config.get('proxy_ca_bundle_path', None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) + ca_bundle_path = proxy_config.get("proxy_ca_bundle_path", None) if https_proxy: - proxies_dict['https'] = https_proxy + proxies_dict["https"] = https_proxy verify = ca_bundle_path or True if http_proxy: - proxies_dict['http'] = http_proxy + proxies_dict["http"] = http_proxy for index_url in index_urls: index_status = { - 'url': index_url, - 'packs': 0, - 'message': None, - 'error': None, + "url": index_url, + "packs": 0, + "message": None, + "error": None, } index_json = None @@ -98,32 +92,32 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None): request.raise_for_status() index_json = request.json() except ValueError as e: - index_status['error'] = 'malformed' - index_status['message'] = repr(e) + index_status["error"] = "malformed" + index_status["message"] = repr(e) except requests.exceptions.RequestException as e: - index_status['error'] = 'unresponsive' - index_status['message'] = repr(e) + index_status["error"] = "unresponsive" + index_status["message"] = repr(e) except Exception as e: - index_status['error'] = 'other errors' - index_status['message'] = repr(e) + index_status["error"] = "other errors" + index_status["message"] = repr(e) if index_json == {}: - index_status['error'] = 'empty' - index_status['message'] = 'The index URL returned an empty object.' + index_status["error"] = "empty" + index_status["message"] = "The index URL returned an empty object." elif type(index_json) is list: - index_status['error'] = 'malformed' - index_status['message'] = 'Expected an index object, got a list instead.' - elif index_json and 'packs' not in index_json: - index_status['error'] = 'malformed' - index_status['message'] = 'Index object is missing "packs" attribute.' + index_status["error"] = "malformed" + index_status["message"] = "Expected an index object, got a list instead." + elif index_json and "packs" not in index_json: + index_status["error"] = "malformed" + index_status["message"] = 'Index object is missing "packs" attribute.' - if index_status['error']: + if index_status["error"]: logger.error("Index parsing error: %s" % json.dumps(index_status, indent=4)) else: # TODO: Notify on a duplicate pack aka pack being overwritten from a different index - packs_data = index_json['packs'] - index_status['message'] = 'Success.' - index_status['packs'] = len(packs_data) + packs_data = index_json["packs"] + index_status["message"] = "Success." + index_status["packs"] = len(packs_data) index.update(packs_data) status.append(index_status) @@ -147,8 +141,9 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi logger = logger or LOG index_urls = _build_index_list(index_url) - index, status = _fetch_and_compile_index(index_urls=index_urls, logger=logger, - proxy_config=proxy_config) + index, status = _fetch_and_compile_index( + index_urls=index_urls, logger=logger, proxy_config=proxy_config + ) # If one of the indexes on the list is unresponsive, we do not throw # immediately. The only case where an exception is raised is when no @@ -156,11 +151,14 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi # This behavior allows for mirrors / backups and handling connection # or network issues in one of the indexes. if not index and not allow_empty: - raise ValueError("No results from the %s: tried %s.\nStatus: %s" % ( - ("index" if len(index_urls) == 1 else "indexes"), - ", ".join(index_urls), - json.dumps(status, indent=4) - )) + raise ValueError( + "No results from the %s: tried %s.\nStatus: %s" + % ( + ("index" if len(index_urls) == 1 else "indexes"), + ", ".join(index_urls), + json.dumps(status, indent=4), + ) + ) return (index, status) @@ -177,13 +175,15 @@ def get_pack_from_index(pack, proxy_config=None): return index.get(pack) -def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, proxy_config=None): +def search_pack_index( + query, exclude=None, priority=None, case_sensitive=True, proxy_config=None +): """ Search the pack index by query. Returns a list of matches for a query. """ if not query: - raise ValueError('Query must be specified.') + raise ValueError("Query must be specified.") if not exclude: exclude = EXCLUDE_FIELDS @@ -198,7 +198,7 @@ def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, p matches = [[] for i in range(len(priority) + 1)] for pack in six.itervalues(index): for key, value in six.iteritems(pack): - if not hasattr(value, '__contains__'): + if not hasattr(value, "__contains__"): value = str(value) if not case_sensitive: diff --git a/st2common/st2common/services/policies.py b/st2common/st2common/services/policies.py index 50ba28f304..46e24ce290 100644 --- a/st2common/st2common/services/policies.py +++ b/st2common/st2common/services/policies.py @@ -25,13 +25,10 @@ def has_policies(lv_ac_db, policy_types=None): - query_params = { - 'resource_ref': lv_ac_db.action, - 'enabled': True - } + query_params = {"resource_ref": lv_ac_db.action, "enabled": True} if policy_types: - query_params['policy_type__in'] = policy_types + query_params["policy_type__in"] = policy_types policy_dbs = pc_db_access.Policy.query(**query_params) @@ -42,11 +39,19 @@ def apply_pre_run_policies(lv_ac_db): LOG.debug('Applying pre-run policies for liveaction "%s".' % str(lv_ac_db.id)) policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True) - LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action)) + LOG.debug( + 'Identified %s policies for the action "%s".' + % (len(policy_dbs), lv_ac_db.action) + ) for policy_db in policy_dbs: - LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type)) - driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + LOG.debug( + 'Getting driver for policy "%s" (%s).' + % (policy_db.ref, policy_db.policy_type) + ) + driver = engine.get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) try: message = 'Applying policy "%s" (%s) for liveaction "%s".' @@ -54,7 +59,9 @@ def apply_pre_run_policies(lv_ac_db): lv_ac_db = driver.apply_before(lv_ac_db) except: message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".' - LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))) + LOG.exception( + message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)) + ) if lv_ac_db.status == ac_const.LIVEACTION_STATUS_DELAYED: break @@ -66,11 +73,19 @@ def apply_post_run_policies(lv_ac_db): LOG.debug('Applying post run policies for liveaction "%s".' % str(lv_ac_db.id)) policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True) - LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action)) + LOG.debug( + 'Identified %s policies for the action "%s".' + % (len(policy_dbs), lv_ac_db.action) + ) for policy_db in policy_dbs: - LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type)) - driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + LOG.debug( + 'Getting driver for policy "%s" (%s).' + % (policy_db.ref, policy_db.policy_type) + ) + driver = engine.get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) try: message = 'Applying policy "%s" (%s) for liveaction "%s".' @@ -78,6 +93,8 @@ def apply_post_run_policies(lv_ac_db): lv_ac_db = driver.apply_after(lv_ac_db) except: message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".' - LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))) + LOG.exception( + message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)) + ) return lv_ac_db diff --git a/st2common/st2common/services/queries.py b/st2common/st2common/services/queries.py index e6d769e365..20c7a0c990 100644 --- a/st2common/st2common/services/queries.py +++ b/st2common/st2common/services/queries.py @@ -25,13 +25,15 @@ def setup_query(liveaction_id, runnertype_db, query_context): - if not getattr(runnertype_db, 'query_module', None): - raise Exception('The runner "%s" does not have a query module.' % runnertype_db.name) + if not getattr(runnertype_db, "query_module", None): + raise Exception( + 'The runner "%s" does not have a query module.' % runnertype_db.name + ) state_db = ActionExecutionStateDB( execution_id=liveaction_id, query_module=runnertype_db.query_module, - query_context=query_context + query_context=query_context, ) ActionExecutionState.add_or_update(state_db) diff --git a/st2common/st2common/services/rules.py b/st2common/st2common/services/rules.py index d9be718e27..ebb8083433 100644 --- a/st2common/st2common/services/rules.py +++ b/st2common/st2common/services/rules.py @@ -22,10 +22,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'get_rules_given_trigger', - 'get_rules_with_trigger_ref' -] +__all__ = ["get_rules_given_trigger", "get_rules_with_trigger_ref"] def get_rules_given_trigger(trigger): @@ -34,13 +31,15 @@ def get_rules_given_trigger(trigger): return get_rules_with_trigger_ref(trigger_ref=trigger) if isinstance(trigger, dict): - trigger_ref = trigger.get('ref', None) + trigger_ref = trigger.get("ref", None) if trigger_ref: return get_rules_with_trigger_ref(trigger_ref=trigger_ref) else: - raise ValueError('Trigger dict %s is missing ``ref``.' % trigger) + raise ValueError("Trigger dict %s is missing ``ref``." % trigger) - raise ValueError('Unknown type %s for trigger. Cannot do rule lookups.' % type(trigger)) + raise ValueError( + "Unknown type %s for trigger. Cannot do rule lookups." % type(trigger) + ) def get_rules_with_trigger_ref(trigger_ref=None, enabled=True): @@ -56,5 +55,5 @@ def get_rules_with_trigger_ref(trigger_ref=None, enabled=True): if not trigger_ref: return None - LOG.debug('Querying rules with trigger %s', trigger_ref) + LOG.debug("Querying rules with trigger %s", trigger_ref) return Rule.query(trigger=trigger_ref, enabled=enabled) diff --git a/st2common/st2common/services/sensor_watcher.py b/st2common/st2common/services/sensor_watcher.py index 0105ba46d6..1c54881663 100644 --- a/st2common/st2common/services/sensor_watcher.py +++ b/st2common/st2common/services/sensor_watcher.py @@ -32,9 +32,9 @@ class SensorWatcher(ConsumerMixin): - - def __init__(self, create_handler, update_handler, delete_handler, - queue_suffix=None): + def __init__( + self, create_handler, update_handler, delete_handler, queue_suffix=None + ): """ :param create_handler: Function which is called on SensorDB create event. :type create_handler: ``callable`` @@ -57,34 +57,41 @@ def __init__(self, create_handler, update_handler, delete_handler, self._handlers = { publishers.CREATE_RK: create_handler, publishers.UPDATE_RK: update_handler, - publishers.DELETE_RK: delete_handler + publishers.DELETE_RK: delete_handler, } def get_consumers(self, Consumer, channel): - consumers = [Consumer(queues=[self._sensor_watcher_q], - accept=['pickle'], - callbacks=[self.process_task])] + consumers = [ + Consumer( + queues=[self._sensor_watcher_q], + accept=["pickle"], + callbacks=[self.process_task], + ) + ] return consumers def process_task(self, body, message): - LOG.debug('process_task') - LOG.debug(' body: %s', body) - LOG.debug(' message.properties: %s', message.properties) - LOG.debug(' message.delivery_info: %s', message.delivery_info) + LOG.debug("process_task") + LOG.debug(" body: %s", body) + LOG.debug(" message.properties: %s", message.properties) + LOG.debug(" message.delivery_info: %s", message.delivery_info) - routing_key = message.delivery_info.get('routing_key', '') + routing_key = message.delivery_info.get("routing_key", "") handler = self._handlers.get(routing_key, None) try: if not handler: - LOG.info('Skipping message %s as no handler was found.', message) + LOG.info("Skipping message %s as no handler was found.", message) return try: handler(body) except Exception as e: - LOG.exception('Handling failed. Message body: %s. Exception: %s', - body, six.text_type(e)) + LOG.exception( + "Handling failed. Message body: %s. Exception: %s", + body, + six.text_type(e), + ) finally: message.ack() @@ -93,11 +100,11 @@ def start(self): self.connection = transport_utils.get_connection() self._updates_thread = concurrency.spawn(self.run) except: - LOG.exception('Failed to start sensor_watcher.') + LOG.exception("Failed to start sensor_watcher.") self.connection.release() def stop(self): - LOG.debug('Shutting down sensor watcher.') + LOG.debug("Shutting down sensor watcher.") try: if self._updates_thread: self._updates_thread = concurrency.kill(self._updates_thread) @@ -108,15 +115,19 @@ def stop(self): try: bound_sensor_watch_q.delete() except: - LOG.error('Unable to delete sensor watcher queue: %s', self._sensor_watcher_q) + LOG.error( + "Unable to delete sensor watcher queue: %s", + self._sensor_watcher_q, + ) finally: if self.connection: self.connection.release() @staticmethod def _get_queue(queue_suffix): - queue_name = queue_utils.get_queue_name(queue_name_base='st2.sensor.watch', - queue_name_suffix=queue_suffix, - add_random_uuid_to_suffix=True - ) - return reactor.get_sensor_cud_queue(queue_name, routing_key='#') + queue_name = queue_utils.get_queue_name( + queue_name_base="st2.sensor.watch", + queue_name_suffix=queue_suffix, + add_random_uuid_to_suffix=True, + ) + return reactor.get_sensor_cud_queue(queue_name, routing_key="#") diff --git a/st2common/st2common/services/trace.py b/st2common/st2common/services/trace.py index 3eb92bd2f1..4dadef0964 100644 --- a/st2common/st2common/services/trace.py +++ b/st2common/st2common/services/trace.py @@ -32,22 +32,24 @@ LOG = logging.getLogger(__name__) __all__ = [ - 'get_trace_db_by_action_execution', - 'get_trace_db_by_rule', - 'get_trace_db_by_trigger_instance', - 'get_trace', - 'add_or_update_given_trace_context', - 'add_or_update_given_trace_db', - 'get_trace_component_for_action_execution', - 'get_trace_component_for_rule', - 'get_trace_component_for_trigger_instance' + "get_trace_db_by_action_execution", + "get_trace_db_by_rule", + "get_trace_db_by_trigger_instance", + "get_trace", + "add_or_update_given_trace_context", + "add_or_update_given_trace_db", + "get_trace_component_for_action_execution", + "get_trace_component_for_rule", + "get_trace_component_for_trigger_instance", ] ACTION_SENSOR_TRIGGER_REF = ResourceReference.to_string_reference( - pack=ACTION_SENSOR_TRIGGER['pack'], name=ACTION_SENSOR_TRIGGER['name']) + pack=ACTION_SENSOR_TRIGGER["pack"], name=ACTION_SENSOR_TRIGGER["name"] +) NOTIFY_TRIGGER_REF = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER['pack'], name=NOTIFY_TRIGGER['name']) + pack=NOTIFY_TRIGGER["pack"], name=NOTIFY_TRIGGER["name"] +) def _get_valid_trace_context(trace_context): @@ -74,14 +76,17 @@ def _get_single_trace_by_component(**component_filter): return None elif len(traces) > 1: raise UniqueTraceNotFoundException( - 'More than 1 trace matching %s found.' % component_filter) + "More than 1 trace matching %s found." % component_filter + ) return traces[0] def get_trace_db_by_action_execution(action_execution=None, action_execution_id=None): if action_execution: action_execution_id = str(action_execution.id) - return _get_single_trace_by_component(action_executions__object_id=action_execution_id) + return _get_single_trace_by_component( + action_executions__object_id=action_execution_id + ) def get_trace_db_by_rule(rule=None, rule_id=None): @@ -94,7 +99,9 @@ def get_trace_db_by_rule(rule=None, rule_id=None): def get_trace_db_by_trigger_instance(trigger_instance=None, trigger_instance_id=None): if trigger_instance: trigger_instance_id = str(trigger_instance.id) - return _get_single_trace_by_component(trigger_instances__object_id=trigger_instance_id) + return _get_single_trace_by_component( + trigger_instances__object_id=trigger_instance_id + ) def get_trace(trace_context, ignore_trace_tag=False): @@ -111,16 +118,20 @@ def get_trace(trace_context, ignore_trace_tag=False): trace_context = _get_valid_trace_context(trace_context) if not trace_context.id_ and not trace_context.trace_tag: - raise ValueError('Atleast one of id_ or trace_tag should be specified.') + raise ValueError("Atleast one of id_ or trace_tag should be specified.") if trace_context.id_: try: return Trace.get_by_id(trace_context.id_) except (ValidationError, ValueError): - LOG.warning('Database lookup for Trace with id="%s" failed.', - trace_context.id_, exc_info=True) + LOG.warning( + 'Database lookup for Trace with id="%s" failed.', + trace_context.id_, + exc_info=True, + ) raise StackStormDBObjectNotFoundError( - 'Unable to find Trace with id="%s"' % trace_context.id_) + 'Unable to find Trace with id="%s"' % trace_context.id_ + ) if ignore_trace_tag: return None @@ -130,7 +141,8 @@ def get_trace(trace_context, ignore_trace_tag=False): # Assume this method only handles 1 trace. if len(traces) > 1: raise UniqueTraceNotFoundException( - 'More than 1 Trace matching %s found.' % trace_context.trace_tag) + "More than 1 Trace matching %s found." % trace_context.trace_tag + ) return traces[0] @@ -168,14 +180,17 @@ def get_trace_db_by_live_action(liveaction): # This cover case for child execution of a workflow. parent_context = executions.get_parent_context(liveaction_db=liveaction) if not trace_context and parent_context: - parent_execution_id = parent_context.get('execution_id', None) + parent_execution_id = parent_context.get("execution_id", None) if parent_execution_id: # go straight to a trace_db. If there is a parent execution then that must # be associated with a Trace. - trace_db = get_trace_db_by_action_execution(action_execution_id=parent_execution_id) + trace_db = get_trace_db_by_action_execution( + action_execution_id=parent_execution_id + ) if not trace_db: - raise StackStormDBObjectNotFoundError('No trace found for execution %s' % - parent_execution_id) + raise StackStormDBObjectNotFoundError( + "No trace found for execution %s" % parent_execution_id + ) return (created, trace_db) # 3. Check if the action_execution associated with liveaction leads to a trace_db execution = ActionExecution.get(liveaction__id=str(liveaction.id)) @@ -184,13 +199,14 @@ def get_trace_db_by_live_action(liveaction): # 4. No trace_db found, therefore create one. This typically happens # when execution is run by hand. if not trace_db: - trace_db = TraceDB(trace_tag='execution-%s' % str(liveaction.id)) + trace_db = TraceDB(trace_tag="execution-%s" % str(liveaction.id)) created = True return (created, trace_db) -def add_or_update_given_trace_context(trace_context, action_executions=None, rules=None, - trigger_instances=None): +def add_or_update_given_trace_context( + trace_context, action_executions=None, rules=None, trigger_instances=None +): """ Will update an existing Trace or add a new Trace. This method will only look for exact Trace as identified by the trace_context. Even if the trace_context contain a trace_tag @@ -222,14 +238,17 @@ def add_or_update_given_trace_context(trace_context, action_executions=None, rul # since trace_db is None need to end up with a valid trace_context trace_context = _get_valid_trace_context(trace_context) trace_db = TraceDB(trace_tag=trace_context.trace_tag) - return add_or_update_given_trace_db(trace_db=trace_db, - action_executions=action_executions, - rules=rules, - trigger_instances=trigger_instances) + return add_or_update_given_trace_db( + trace_db=trace_db, + action_executions=action_executions, + rules=rules, + trigger_instances=trigger_instances, + ) -def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, - trigger_instances=None): +def add_or_update_given_trace_db( + trace_db, action_executions=None, rules=None, trigger_instances=None +): """ Will update an existing Trace. @@ -251,12 +270,14 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, :rtype: ``TraceDB`` """ if trace_db is None: - raise ValueError('trace_db should be non-None.') + raise ValueError("trace_db should be non-None.") if not action_executions: action_executions = [] - action_executions = [_to_trace_component_db(component=action_execution) - for action_execution in action_executions] + action_executions = [ + _to_trace_component_db(component=action_execution) + for action_execution in action_executions + ] if not rules: rules = [] @@ -264,16 +285,20 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, if not trigger_instances: trigger_instances = [] - trigger_instances = [_to_trace_component_db(component=trigger_instance) - for trigger_instance in trigger_instances] + trigger_instances = [ + _to_trace_component_db(component=trigger_instance) + for trigger_instance in trigger_instances + ] # If an id exists then this is an update and we do not want to perform # an upsert so use push_components which will use the push operator. if trace_db.id: - return Trace.push_components(trace_db, - action_executions=action_executions, - rules=rules, - trigger_instances=trigger_instances) + return Trace.push_components( + trace_db, + action_executions=action_executions, + rules=rules, + trigger_instances=trigger_instances, + ) trace_db.action_executions = action_executions trace_db.rules = rules @@ -295,23 +320,25 @@ def get_trace_component_for_action_execution(action_execution_db, liveaction_db) :rtype: ``dict`` """ if not action_execution_db: - raise ValueError('action_execution_db expected.') + raise ValueError("action_execution_db expected.") trace_component = { - 'id': str(action_execution_db.id), - 'ref': str(action_execution_db.action.get('ref', '')) + "id": str(action_execution_db.id), + "ref": str(action_execution_db.action.get("ref", "")), } caused_by = {} parent_context = executions.get_parent_context(liveaction_db=liveaction_db) if liveaction_db and parent_context: - caused_by['type'] = 'action_execution' - caused_by['id'] = liveaction_db.context['parent'].get('execution_id', None) + caused_by["type"] = "action_execution" + caused_by["id"] = liveaction_db.context["parent"].get("execution_id", None) elif action_execution_db.rule and action_execution_db.trigger_instance: # Once RuleEnforcement is available that can be used instead. - caused_by['type'] = 'rule' - caused_by['id'] = '%s:%s' % (action_execution_db.rule['id'], - action_execution_db.trigger_instance['id']) + caused_by["type"] = "rule" + caused_by["id"] = "%s:%s" % ( + action_execution_db.rule["id"], + action_execution_db.trigger_instance["id"], + ) - trace_component['caused_by'] = caused_by + trace_component["caused_by"] = caused_by return trace_component @@ -328,13 +355,13 @@ def get_trace_component_for_rule(rule_db, trigger_instance_db): :rtype: ``dict`` """ trace_component = {} - trace_component = {'id': str(rule_db.id), 'ref': rule_db.ref} + trace_component = {"id": str(rule_db.id), "ref": rule_db.ref} caused_by = {} if trigger_instance_db: # Once RuleEnforcement is available that can be used instead. - caused_by['type'] = 'trigger_instance' - caused_by['id'] = str(trigger_instance_db.id) - trace_component['caused_by'] = caused_by + caused_by["type"] = "trigger_instance" + caused_by["id"] = str(trigger_instance_db.id) + trace_component["caused_by"] = caused_by return trace_component @@ -349,18 +376,20 @@ def get_trace_component_for_trigger_instance(trigger_instance_db): """ trace_component = {} trace_component = { - 'id': str(trigger_instance_db.id), - 'ref': trigger_instance_db.trigger + "id": str(trigger_instance_db.id), + "ref": trigger_instance_db.trigger, } caused_by = {} # Special handling for ACTION_SENSOR_TRIGGER and NOTIFY_TRIGGER where we # know how to maintain the links. - if trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF or \ - trigger_instance_db.trigger == NOTIFY_TRIGGER_REF: - caused_by['type'] = 'action_execution' + if ( + trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF + or trigger_instance_db.trigger == NOTIFY_TRIGGER_REF + ): + caused_by["type"] = "action_execution" # For both action trigger and notidy trigger execution_id is stored in the payload. - caused_by['id'] = trigger_instance_db.payload['execution_id'] - trace_component['caused_by'] = caused_by + caused_by["id"] = trigger_instance_db.payload["execution_id"] + trace_component["caused_by"] = caused_by return trace_component @@ -376,10 +405,12 @@ def _to_trace_component_db(component): """ if not isinstance(component, (six.string_types, dict)): print(type(component)) - raise ValueError('Expected component to be str or dict') + raise ValueError("Expected component to be str or dict") - object_id = component if isinstance(component, six.string_types) else component['id'] - ref = component.get('ref', '') if isinstance(component, dict) else '' - caused_by = component.get('caused_by', {}) if isinstance(component, dict) else {} + object_id = ( + component if isinstance(component, six.string_types) else component["id"] + ) + ref = component.get("ref", "") if isinstance(component, dict) else "" + caused_by = component.get("caused_by", {}) if isinstance(component, dict) else {} return TraceComponentDB(object_id=object_id, ref=ref, caused_by=caused_by) diff --git a/st2common/st2common/services/trigger_dispatcher.py b/st2common/st2common/services/trigger_dispatcher.py index 6843a1eb74..6343a555b9 100644 --- a/st2common/st2common/services/trigger_dispatcher.py +++ b/st2common/st2common/services/trigger_dispatcher.py @@ -23,9 +23,7 @@ from st2common.transport.reactor import TriggerDispatcher from st2common.validators.api.reactor import validate_trigger_payload -__all__ = [ - 'TriggerDispatcherService' -] +__all__ = ["TriggerDispatcherService"] class TriggerDispatcherService(object): @@ -37,7 +35,9 @@ def __init__(self, logger): self._logger = logger self._dispatcher = TriggerDispatcher(self._logger) - def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False): + def dispatch( + self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False + ): """ Method which dispatches the trigger. @@ -56,12 +56,19 @@ def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_er """ # empty strings trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None - self._logger.debug('Added trace_context %s to trigger %s.', trace_context, trigger) - return self.dispatch_with_context(trigger, payload=payload, trace_context=trace_context, - throw_on_validation_error=throw_on_validation_error) - - def dispatch_with_context(self, trigger, payload=None, trace_context=None, - throw_on_validation_error=False): + self._logger.debug( + "Added trace_context %s to trigger %s.", trace_context, trigger + ) + return self.dispatch_with_context( + trigger, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=throw_on_validation_error, + ) + + def dispatch_with_context( + self, trigger, payload=None, trace_context=None, throw_on_validation_error=False + ): """ Method which dispatches the trigger. @@ -81,18 +88,25 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None, # Note: We perform validation even if it's disabled in the config so we can at least warn # the user if validation fals (but not throw if it's disabled) try: - validate_trigger_payload(trigger_type_ref=trigger, payload=payload, - throw_on_inexistent_trigger=True) + validate_trigger_payload( + trigger_type_ref=trigger, + payload=payload, + throw_on_inexistent_trigger=True, + ) except (ValidationError, ValueError, Exception) as e: - self._logger.warn('Failed to validate payload (%s) for trigger "%s": %s' % - (str(payload), trigger, six.text_type(e))) + self._logger.warn( + 'Failed to validate payload (%s) for trigger "%s": %s' + % (str(payload), trigger, six.text_type(e)) + ) # If validation is disabled, still dispatch a trigger even if it failed validation # This condition prevents unexpected restriction. if cfg.CONF.system.validate_trigger_payload: - msg = ('Trigger payload validation failed and validation is enabled, not ' - 'dispatching a trigger "%s" (%s): %s' % (trigger, str(payload), - six.text_type(e))) + msg = ( + "Trigger payload validation failed and validation is enabled, not " + 'dispatching a trigger "%s" (%s): %s' + % (trigger, str(payload), six.text_type(e)) + ) if throw_on_validation_error: raise ValueError(msg) @@ -100,5 +114,7 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None, self._logger.warn(msg) return None - self._logger.debug('Dispatching trigger %s with payload %s.', trigger, payload) - return self._dispatcher.dispatch(trigger, payload=payload, trace_context=trace_context) + self._logger.debug("Dispatching trigger %s with payload %s.", trigger, payload) + return self._dispatcher.dispatch( + trigger, payload=payload, trace_context=trace_context + ) diff --git a/st2common/st2common/services/triggers.py b/st2common/st2common/services/triggers.py index 6448aa2533..bbdce26b81 100644 --- a/st2common/st2common/services/triggers.py +++ b/st2common/st2common/services/triggers.py @@ -23,25 +23,22 @@ from st2common.exceptions.triggers import TriggerDoesNotExistException from st2common.exceptions.db import StackStormDBObjectNotFoundError from st2common.exceptions.db import StackStormDBObjectConflictError -from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI) +from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI from st2common.models.system.common import ResourceReference -from st2common.persistence.trigger import (Trigger, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerType __all__ = [ - 'add_trigger_models', - - 'get_trigger_db_by_ref', - 'get_trigger_db_by_id', - 'get_trigger_db_by_uid', - 'get_trigger_db_by_ref_or_dict', - 'get_trigger_db_given_type_and_params', - 'get_trigger_type_db', - - 'create_trigger_db', - 'create_trigger_type_db', - - 'create_or_update_trigger_db', - 'create_or_update_trigger_type_db' + "add_trigger_models", + "get_trigger_db_by_ref", + "get_trigger_db_by_id", + "get_trigger_db_by_uid", + "get_trigger_db_by_ref_or_dict", + "get_trigger_db_given_type_and_params", + "get_trigger_type_db", + "create_trigger_db", + "create_trigger_type_db", + "create_or_update_trigger_db", + "create_or_update_trigger_type_db", ] LOG = logging.getLogger(__name__) @@ -50,8 +47,7 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): try: parameters = parameters or {} - trigger_dbs = Trigger.query(type=type, - parameters=parameters) + trigger_dbs = Trigger.query(type=type, parameters=parameters) trigger_db = trigger_dbs[0] if len(trigger_dbs) > 0 else None @@ -59,23 +55,24 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): # pymongo and mongoengine # Work around for cron-timer when in some scenarios finding an object fails when Python # value types are unicode :/ - is_cron_trigger = (type == CRON_TIMER_TRIGGER_REF) + is_cron_trigger = type == CRON_TIMER_TRIGGER_REF has_parameters = bool(parameters) if not trigger_db and six.PY2 and is_cron_trigger and has_parameters: non_unicode_literal_parameters = {} for key, value in six.iteritems(parameters): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(value, six.text_type): # We only encode unicode to str - value = value.encode('utf-8') + value = value.encode("utf-8") non_unicode_literal_parameters[key] = value parameters = non_unicode_literal_parameters - trigger_dbs = Trigger.query(type=type, - parameters=non_unicode_literal_parameters).no_cache() + trigger_dbs = Trigger.query( + type=type, parameters=non_unicode_literal_parameters + ).no_cache() # Note: We need to directly access the object, using len or accessing the query set # twice won't work - there seems to bug a bug with cursor where accessing it twice @@ -93,8 +90,14 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): return trigger_db except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for type="%s" parameters="%s" resulted ' + - 'in exception : %s.', type, parameters, e, exc_info=True) + LOG.debug( + 'Database lookup for type="%s" parameters="%s" resulted ' + + "in exception : %s.", + type, + parameters, + e, + exc_info=True, + ) return None @@ -109,26 +112,30 @@ def get_trigger_db_by_ref_or_dict(trigger): else: # If id / uid is available we try to look up Trigger by id. This way we can avoid bug in # pymongo / mongoengine related to "parameters" dictionary lookups - trigger_id = trigger.get('id', None) - trigger_uid = trigger.get('uid', None) + trigger_id = trigger.get("id", None) + trigger_uid = trigger.get("uid", None) # TODO: Remove parameters dictionary look up when we can confirm each trigger dictionary # passed to this method always contains id or uid if trigger_id: - LOG.debug('Looking up TriggerDB by id: %s', trigger_id) + LOG.debug("Looking up TriggerDB by id: %s", trigger_id) trigger_db = get_trigger_db_by_id(id=trigger_id) elif trigger_uid: - LOG.debug('Looking up TriggerDB by uid: %s', trigger_uid) + LOG.debug("Looking up TriggerDB by uid: %s", trigger_uid) trigger_db = get_trigger_db_by_uid(uid=trigger_uid) else: # Last resort - look it up by parameters - trigger_type = trigger.get('type', None) - parameters = trigger.get('parameters', {}) - - LOG.debug('Looking up TriggerDB by type and parameters: type=%s, parameters=%s', - trigger_type, parameters) - trigger_db = get_trigger_db_given_type_and_params(type=trigger_type, - parameters=parameters) + trigger_type = trigger.get("type", None) + parameters = trigger.get("parameters", {}) + + LOG.debug( + "Looking up TriggerDB by type and parameters: type=%s, parameters=%s", + trigger_type, + parameters, + ) + trigger_db = get_trigger_db_given_type_and_params( + type=trigger_type, parameters=parameters + ) return trigger_db @@ -145,8 +152,12 @@ def get_trigger_db_by_id(id): try: return Trigger.get_by_id(id) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for id="%s" resulted in exception : %s.', - id, e, exc_info=True) + LOG.debug( + 'Database lookup for id="%s" resulted in exception : %s.', + id, + e, + exc_info=True, + ) return None @@ -163,8 +174,12 @@ def get_trigger_db_by_uid(uid): try: return Trigger.get_by_uid(uid) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for uid="%s" resulted in exception : %s.', - uid, e, exc_info=True) + LOG.debug( + 'Database lookup for uid="%s" resulted in exception : %s.', + uid, + e, + exc_info=True, + ) return None @@ -181,8 +196,12 @@ def get_trigger_db_by_ref(ref): try: return Trigger.get_by_ref(ref) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None @@ -192,16 +211,17 @@ def _get_trigger_db(trigger): # XXX: Do not make this method public. if isinstance(trigger, dict): - name = trigger.get('name', None) - pack = trigger.get('pack', None) + name = trigger.get("name", None) + pack = trigger.get("pack", None) if name and pack: ref = ResourceReference.to_string_reference(name=name, pack=pack) return get_trigger_db_by_ref(ref) - return get_trigger_db_given_type_and_params(type=trigger['type'], - parameters=trigger.get('parameters', {})) + return get_trigger_db_given_type_and_params( + type=trigger["type"], parameters=trigger.get("parameters", {}) + ) else: - raise Exception('Unrecognized object') + raise Exception("Unrecognized object") def get_trigger_type_db(ref): @@ -216,8 +236,12 @@ def get_trigger_type_db(ref): try: return TriggerType.get_by_ref(ref) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None @@ -225,22 +249,23 @@ def get_trigger_type_db(ref): def _get_trigger_dict_given_rule(rule): trigger = rule.trigger trigger_dict = {} - triggertype_ref = ResourceReference.from_string_reference(trigger.get('type')) - trigger_dict['pack'] = trigger_dict.get('pack', triggertype_ref.pack) - trigger_dict['type'] = triggertype_ref.ref - trigger_dict['parameters'] = rule.trigger.get('parameters', {}) + triggertype_ref = ResourceReference.from_string_reference(trigger.get("type")) + trigger_dict["pack"] = trigger_dict.get("pack", triggertype_ref.pack) + trigger_dict["type"] = triggertype_ref.ref + trigger_dict["parameters"] = rule.trigger.get("parameters", {}) return trigger_dict def create_trigger_db(trigger_api): # TODO: This is used only in trigger API controller. We should get rid of this. - trigger_ref = ResourceReference.to_string_reference(name=trigger_api.name, - pack=trigger_api.pack) + trigger_ref = ResourceReference.to_string_reference( + name=trigger_api.name, pack=trigger_api.pack + ) trigger_db = get_trigger_db_by_ref(trigger_ref) if not trigger_db: trigger_db = TriggerAPI.to_model(trigger_api) - LOG.debug('Verified trigger and formulated TriggerDB=%s', trigger_db) + LOG.debug("Verified trigger and formulated TriggerDB=%s", trigger_db) trigger_db = Trigger.add_or_update(trigger_db) return trigger_db @@ -269,15 +294,16 @@ def create_or_update_trigger_db(trigger, log_not_unique_error_as_debug=False): if is_update: trigger_db.id = existing_trigger_db.id - trigger_db = Trigger.add_or_update(trigger_db, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + trigger_db = Trigger.add_or_update( + trigger_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) - extra = {'trigger_db': trigger_db} + extra = {"trigger_db": trigger_db} if is_update: - LOG.audit('Trigger updated. Trigger.id=%s' % (trigger_db.id), extra=extra) + LOG.audit("Trigger updated. Trigger.id=%s" % (trigger_db.id), extra=extra) else: - LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra) + LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra) return trigger_db @@ -288,10 +314,11 @@ def create_trigger_db_from_rule(rule): # For simple triggertypes (triggertype with no parameters), we create a trigger when # registering triggertype. So if we hit the case that there is no trigger in db but # parameters is empty, then this case is a run time error. - if not trigger_dict.get('parameters', {}) and not existing_trigger_db: + if not trigger_dict.get("parameters", {}) and not existing_trigger_db: raise TriggerDoesNotExistException( - 'A simple trigger should have been created when registering ' - 'triggertype. Cannot create trigger: %s.' % (trigger_dict)) + "A simple trigger should have been created when registering " + "triggertype. Cannot create trigger: %s." % (trigger_dict) + ) if not existing_trigger_db: trigger_db = create_or_update_trigger_db(trigger_dict) @@ -316,7 +343,7 @@ def increment_trigger_ref_count(rule_api): trigger_dict = _get_trigger_dict_given_rule(rule_api) # Special reference counting for trigger with parameters. - if trigger_dict.get('parameters', None): + if trigger_dict.get("parameters", None): trigger_db = _get_trigger_db(trigger_dict) Trigger.update(trigger_db, inc__ref_count=1) @@ -326,7 +353,7 @@ def cleanup_trigger_db_for_rule(rule_db): existing_trigger_db = get_trigger_db_by_ref(rule_db.trigger) if not existing_trigger_db or not existing_trigger_db.parameters: # nothing to be done here so moving on. - LOG.debug('ref_count decrement for %s not required.', existing_trigger_db) + LOG.debug("ref_count decrement for %s not required.", existing_trigger_db) return Trigger.update(existing_trigger_db, dec__ref_count=1) Trigger.delete_if_unreferenced(existing_trigger_db) @@ -350,15 +377,17 @@ def create_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False): """ trigger_type_api = TriggerTypeAPI(**trigger_type) trigger_type_api.validate() - ref = ResourceReference.to_string_reference(name=trigger_type_api.name, - pack=trigger_type_api.pack) + ref = ResourceReference.to_string_reference( + name=trigger_type_api.name, pack=trigger_type_api.pack + ) trigger_type_db = get_trigger_type_db(ref) if not trigger_type_db: trigger_type_db = TriggerTypeAPI.to_model(trigger_type_api) - LOG.debug('verified trigger and formulated TriggerDB=%s', trigger_type_db) - trigger_type_db = TriggerType.add_or_update(trigger_type_db, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + LOG.debug("verified trigger and formulated TriggerDB=%s", trigger_type_db) + trigger_type_db = TriggerType.add_or_update( + trigger_type_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) return trigger_type_db @@ -378,16 +407,21 @@ def create_shadow_trigger(trigger_type_db, log_not_unique_error_as_debug=False): trigger_type_ref = trigger_type_db.get_reference().ref if trigger_type_db.parameters_schema: - LOG.debug('Skip shadow trigger for TriggerType with parameters %s.', trigger_type_ref) + LOG.debug( + "Skip shadow trigger for TriggerType with parameters %s.", trigger_type_ref + ) return None - trigger = {'name': trigger_type_db.name, - 'pack': trigger_type_db.pack, - 'type': trigger_type_ref, - 'parameters': {}} + trigger = { + "name": trigger_type_db.name, + "pack": trigger_type_db.pack, + "type": trigger_type_ref, + "parameters": {}, + } - return create_or_update_trigger_db(trigger, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + return create_or_update_trigger_db( + trigger, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False): @@ -412,8 +446,9 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_api.validate() trigger_type_api = TriggerTypeAPI.to_model(trigger_type_api) - ref = ResourceReference.to_string_reference(name=trigger_type_api.name, - pack=trigger_type_api.pack) + ref = ResourceReference.to_string_reference( + name=trigger_type_api.name, pack=trigger_type_api.pack + ) existing_trigger_type_db = get_trigger_type_db(ref) if existing_trigger_type_db: @@ -425,8 +460,10 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_api.id = existing_trigger_type_db.id try: - trigger_type_db = TriggerType.add_or_update(trigger_type_api, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + trigger_type_db = TriggerType.add_or_update( + trigger_type_api, + log_not_unique_error_as_debug=log_not_unique_error_as_debug, + ) except StackStormDBObjectConflictError: # Operation is idempotent and trigger could have already been created by # another process. Ignore object already exists because it simply means @@ -434,26 +471,37 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_db = get_trigger_type_db(ref) is_update = True - extra = {'trigger_type_db': trigger_type_db} + extra = {"trigger_type_db": trigger_type_db} if is_update: - LOG.audit('TriggerType updated. TriggerType.id=%s' % (trigger_type_db.id), extra=extra) + LOG.audit( + "TriggerType updated. TriggerType.id=%s" % (trigger_type_db.id), extra=extra + ) else: - LOG.audit('TriggerType created. TriggerType.id=%s' % (trigger_type_db.id), extra=extra) + LOG.audit( + "TriggerType created. TriggerType.id=%s" % (trigger_type_db.id), extra=extra + ) return trigger_type_db -def _create_trigger_type(pack, name, description=None, payload_schema=None, - parameters_schema=None, tags=None, metadata_file=None): +def _create_trigger_type( + pack, + name, + description=None, + payload_schema=None, + parameters_schema=None, + tags=None, + metadata_file=None, +): trigger_type = { - 'name': name, - 'pack': pack, - 'description': description, - 'payload_schema': payload_schema, - 'parameters_schema': parameters_schema, - 'tags': tags, - 'metadata_file': metadata_file + "name": name, + "pack": pack, + "description": description, + "payload_schema": payload_schema, + "parameters_schema": parameters_schema, + "tags": tags, + "metadata_file": metadata_file, } return create_or_update_trigger_type_db(trigger_type=trigger_type) @@ -464,11 +512,12 @@ def _validate_trigger_type(trigger_type): XXX: We need validator objects that define the required and optional fields. For now, manually check them. """ - required_fields = ['name'] + required_fields = ["name"] for field in required_fields: if field not in trigger_type: - raise TriggerTypeRegistrationException('Invalid trigger type. Missing field "%s"' % - (field)) + raise TriggerTypeRegistrationException( + 'Invalid trigger type. Missing field "%s"' % (field) + ) def _create_trigger(trigger_type): @@ -476,37 +525,46 @@ def _create_trigger(trigger_type): :param trigger_type: TriggerType db object. :type trigger_type: :class:`TriggerTypeDB` """ - if hasattr(trigger_type, 'parameters_schema') and not trigger_type['parameters_schema']: + if ( + hasattr(trigger_type, "parameters_schema") + and not trigger_type["parameters_schema"] + ): trigger_dict = { - 'name': trigger_type.name, - 'pack': trigger_type.pack, - 'type': trigger_type.get_reference().ref + "name": trigger_type.name, + "pack": trigger_type.pack, + "type": trigger_type.get_reference().ref, } try: return create_or_update_trigger_db(trigger=trigger_dict) except: - LOG.exception('Validation failed for Trigger=%s.', trigger_dict) + LOG.exception("Validation failed for Trigger=%s.", trigger_dict) raise TriggerTypeRegistrationException( - 'Unable to create Trigger for TriggerType=%s.' % trigger_type.name) + "Unable to create Trigger for TriggerType=%s." % trigger_type.name + ) else: - LOG.debug('Won\'t create Trigger object as TriggerType %s expects ' + - 'parameters.', trigger_type) + LOG.debug( + "Won't create Trigger object as TriggerType %s expects " + "parameters.", + trigger_type, + ) return None def _add_trigger_models(trigger_type): - pack = trigger_type['pack'] - description = trigger_type['description'] if 'description' in trigger_type else '' - payload_schema = trigger_type['payload_schema'] if 'payload_schema' in trigger_type else {} - parameters_schema = trigger_type['parameters_schema'] \ - if 'parameters_schema' in trigger_type else {} - tags = trigger_type.get('tags', []) - metadata_file = trigger_type.get('metadata_file', None) + pack = trigger_type["pack"] + description = trigger_type["description"] if "description" in trigger_type else "" + payload_schema = ( + trigger_type["payload_schema"] if "payload_schema" in trigger_type else {} + ) + parameters_schema = ( + trigger_type["parameters_schema"] if "parameters_schema" in trigger_type else {} + ) + tags = trigger_type.get("tags", []) + metadata_file = trigger_type.get("metadata_file", None) trigger_type = _create_trigger_type( pack=pack, - name=trigger_type['name'], + name=trigger_type["name"], description=description, payload_schema=payload_schema, parameters_schema=parameters_schema, @@ -526,8 +584,13 @@ def add_trigger_models(trigger_types): :rtype: ``list`` of ``tuple`` (trigger_type, trigger) """ - [r for r in (_validate_trigger_type(trigger_type) - for trigger_type in trigger_types) if r is not None] + [ + r + for r in ( + _validate_trigger_type(trigger_type) for trigger_type in trigger_types + ) + if r is not None + ] result = [] for trigger_type in trigger_types: diff --git a/st2common/st2common/services/triggerwatcher.py b/st2common/st2common/services/triggerwatcher.py index 4830c349f4..b82a46043a 100644 --- a/st2common/st2common/services/triggerwatcher.py +++ b/st2common/st2common/services/triggerwatcher.py @@ -33,8 +33,15 @@ class TriggerWatcher(ConsumerMixin): sleep_interval = 0 # sleep to co-operatively yield after processing each message - def __init__(self, create_handler, update_handler, delete_handler, - trigger_types=None, queue_suffix=None, exclusive=False): + def __init__( + self, + create_handler, + update_handler, + delete_handler, + trigger_types=None, + queue_suffix=None, + exclusive=False, + ): """ :param create_handler: Function which is called on TriggerDB create event. :type create_handler: ``callable`` @@ -69,39 +76,49 @@ def __init__(self, create_handler, update_handler, delete_handler, self._handlers = { publishers.CREATE_RK: create_handler, publishers.UPDATE_RK: update_handler, - publishers.DELETE_RK: delete_handler + publishers.DELETE_RK: delete_handler, } def get_consumers(self, Consumer, channel): - return [Consumer(queues=[self._trigger_watch_q], - accept=['pickle'], - callbacks=[self.process_task])] + return [ + Consumer( + queues=[self._trigger_watch_q], + accept=["pickle"], + callbacks=[self.process_task], + ) + ] def process_task(self, body, message): - LOG.debug('process_task') - LOG.debug(' body: %s', body) - LOG.debug(' message.properties: %s', message.properties) - LOG.debug(' message.delivery_info: %s', message.delivery_info) + LOG.debug("process_task") + LOG.debug(" body: %s", body) + LOG.debug(" message.properties: %s", message.properties) + LOG.debug(" message.delivery_info: %s", message.delivery_info) - routing_key = message.delivery_info.get('routing_key', '') + routing_key = message.delivery_info.get("routing_key", "") handler = self._handlers.get(routing_key, None) try: if not handler: - LOG.debug('Skipping message %s as no handler was found.', message) + LOG.debug("Skipping message %s as no handler was found.", message) return - trigger_type = getattr(body, 'type', None) + trigger_type = getattr(body, "type", None) if self._trigger_types and trigger_type not in self._trigger_types: - LOG.debug('Skipping message %s since trigger_type doesn\'t match (type=%s)', - message, trigger_type) + LOG.debug( + "Skipping message %s since trigger_type doesn't match (type=%s)", + message, + trigger_type, + ) return try: handler(body) except Exception as e: - LOG.exception('Handling failed. Message body: %s. Exception: %s', - body, six.text_type(e)) + LOG.exception( + "Handling failed. Message body: %s. Exception: %s", + body, + six.text_type(e), + ) finally: message.ack() @@ -113,7 +130,7 @@ def start(self): self._updates_thread = concurrency.spawn(self.run) self._load_thread = concurrency.spawn(self._load_triggers_from_db) except: - LOG.exception('Failed to start watcher.') + LOG.exception("Failed to start watcher.") self.connection.release() def stop(self): @@ -128,8 +145,9 @@ def stop(self): # waiting for a message on the queue. def on_consume_end(self, connection, channel): - super(TriggerWatcher, self).on_consume_end(connection=connection, - channel=channel) + super(TriggerWatcher, self).on_consume_end( + connection=connection, channel=channel + ) concurrency.sleep(seconds=self.sleep_interval) def on_iteration(self): @@ -139,13 +157,16 @@ def on_iteration(self): def _load_triggers_from_db(self): for trigger_type in self._trigger_types: for trigger in Trigger.query(type=trigger_type): - LOG.debug('Found existing trigger: %s in db.' % trigger) + LOG.debug("Found existing trigger: %s in db." % trigger) self._handlers[publishers.CREATE_RK](trigger) @staticmethod def _get_queue(queue_suffix, exclusive): - queue_name = queue_utils.get_queue_name(queue_name_base='st2.trigger.watch', - queue_name_suffix=queue_suffix, - add_random_uuid_to_suffix=True - ) - return reactor.get_trigger_cud_queue(queue_name, routing_key='#', exclusive=exclusive) + queue_name = queue_utils.get_queue_name( + queue_name_base="st2.trigger.watch", + queue_name_suffix=queue_suffix, + add_random_uuid_to_suffix=True, + ) + return reactor.get_trigger_cud_queue( + queue_name, routing_key="#", exclusive=exclusive + ) diff --git a/st2common/st2common/services/workflows.py b/st2common/st2common/services/workflows.py index db64c681b1..16ddcb5a02 100644 --- a/st2common/st2common/services/workflows.py +++ b/st2common/st2common/services/workflows.py @@ -54,59 +54,61 @@ LOG = logging.getLogger(__name__) LOG_FUNCTIONS = { - 'audit': LOG.audit, - 'debug': LOG.debug, - 'info': LOG.info, - 'warning': LOG.warning, - 'error': LOG.error, - 'critical': LOG.critical, + "audit": LOG.audit, + "debug": LOG.debug, + "info": LOG.info, + "warning": LOG.warning, + "error": LOG.error, + "critical": LOG.critical, } -def update_progress(wf_ex_db, message, severity='info', log=True, stream=True): +def update_progress(wf_ex_db, message, severity="info", log=True, stream=True): if not wf_ex_db: return if log and severity in LOG_FUNCTIONS: - LOG_FUNCTIONS[severity]('[%s] %s', wf_ex_db.context['st2']['action_execution_id'], message) + LOG_FUNCTIONS[severity]( + "[%s] %s", wf_ex_db.context["st2"]["action_execution_id"], message + ) if stream: ac_svc.store_execution_output_data_ex( - wf_ex_db.context['st2']['action_execution_id'], - wf_ex_db.context['st2']['action'], - wf_ex_db.context['st2']['runner'], - '%s\n' % message, + wf_ex_db.context["st2"]["action_execution_id"], + wf_ex_db.context["st2"]["action"], + wf_ex_db.context["st2"]["runner"], + "%s\n" % message, ) def is_action_execution_under_workflow_context(ac_ex_db): # The action execution is executed under the context of a workflow # if it contains the orquesta key in its context dictionary. - return ac_ex_db.context and 'orquesta' in ac_ex_db.context + return ac_ex_db.context and "orquesta" in ac_ex_db.context def format_inspection_result(result): errors = [] categories = { - 'contents': 'content', - 'context': 'context', - 'expressions': 'expression', - 'semantics': 'semantic', - 'syntax': 'syntax' + "contents": "content", + "context": "context", + "expressions": "expression", + "semantics": "semantic", + "syntax": "syntax", } # For context and expression errors, rename the attribute from type to language. - for category in ['context', 'expressions']: + for category in ["context", "expressions"]: for entry in result.get(category, []): - if 'language' not in entry: - entry['language'] = entry['type'] - del entry['type'] + if "language" not in entry: + entry["language"] = entry["type"] + del entry["type"] # For all categories, put the category value in the type attribute. for category, entries in six.iteritems(result): for entry in entries: - entry['type'] = categories[category] + entry["type"] = categories[category] errors.append(entry) return errors @@ -121,7 +123,7 @@ def inspect(wf_spec, st2_ctx, raise_exception=True): errors += inspect_task_contents(wf_spec) # Sort the list of errors by type and path. - errors = sorted(errors, key=lambda e: (e['type'], e['schema_path'])) + errors = sorted(errors, key=lambda e: (e["type"], e["schema_path"])) if errors and raise_exception: raise orquesta_exc.WorkflowInspectionError(errors) @@ -131,10 +133,10 @@ def inspect(wf_spec, st2_ctx, raise_exception=True): def inspect_task_contents(wf_spec): result = [] - spec_path = 'tasks' - schema_path = 'properties.tasks.patternProperties.^\\w+$' - action_schema_path = schema_path + '.properties.action' - action_input_schema_path = schema_path + '.properties.input' + spec_path = "tasks" + schema_path = "properties.tasks.patternProperties.^\\w+$" + action_schema_path = schema_path + ".properties.action" + action_input_schema_path = schema_path + ".properties.input" def is_action_an_expression(action): if isinstance(action, six.string_types): @@ -143,9 +145,9 @@ def is_action_an_expression(action): return True for task_name, task_spec in six.iteritems(wf_spec.tasks): - action_ref = getattr(task_spec, 'action', None) - action_spec_path = spec_path + '.' + task_name + '.action' - action_input_spec_path = spec_path + '.' + task_name + '.input' + action_ref = getattr(task_spec, "action", None) + action_spec_path = spec_path + "." + task_name + ".action" + action_input_spec_path = spec_path + "." + task_name + ".input" # Move on if action is empty or an expression. if not action_ref or is_action_an_expression(action_ref): @@ -154,10 +156,11 @@ def is_action_an_expression(action): # Check that the format of the action is a valid resource reference. if not sys_models.ResourceReference.is_resource_reference(action_ref): entry = { - 'type': 'content', - 'message': 'The action reference "%s" is not formatted correctly.' % action_ref, - 'spec_path': action_spec_path, - 'schema_path': action_schema_path + "type": "content", + "message": 'The action reference "%s" is not formatted correctly.' + % action_ref, + "spec_path": action_spec_path, + "schema_path": action_schema_path, } result.append(entry) @@ -166,31 +169,37 @@ def is_action_an_expression(action): # Check that the action is registered in the database. if not action_utils.get_action_by_ref(ref=action_ref): entry = { - 'type': 'content', - 'message': 'The action "%s" is not registered in the database.' % action_ref, - 'spec_path': action_spec_path, - 'schema_path': action_schema_path + "type": "content", + "message": 'The action "%s" is not registered in the database.' + % action_ref, + "spec_path": action_spec_path, + "schema_path": action_schema_path, } result.append(entry) continue # Check the action parameters. - params = getattr(task_spec, 'input', None) or {} + params = getattr(task_spec, "input", None) or {} if params and not isinstance(params, dict): continue - requires, unexpected = action_param_utils.validate_action_parameters(action_ref, params) + requires, unexpected = action_param_utils.validate_action_parameters( + action_ref, params + ) for param in requires: - message = 'Action "%s" is missing required input "%s".' % (action_ref, param) + message = 'Action "%s" is missing required input "%s".' % ( + action_ref, + param, + ) entry = { - 'type': 'content', - 'message': message, - 'spec_path': action_input_spec_path, - 'schema_path': action_input_schema_path + "type": "content", + "message": message, + "spec_path": action_input_spec_path, + "schema_path": action_input_schema_path, } result.append(entry) @@ -199,10 +208,10 @@ def is_action_an_expression(action): message = 'Action "%s" has unexpected input "%s".' % (action_ref, param) entry = { - 'type': 'content', - 'message': message, - 'spec_path': action_input_spec_path + '.' + param, - 'schema_path': action_input_schema_path + '.patternProperties.^\\w+$' + "type": "content", + "message": message, + "spec_path": action_input_spec_path + "." + param, + "schema_path": action_input_schema_path + ".patternProperties.^\\w+$", } result.append(entry) @@ -211,35 +220,35 @@ def is_action_an_expression(action): def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): - LOG.info('[%s] Processing action execution request for workflow.', str(ac_ex_db.id)) + LOG.info("[%s] Processing action execution request for workflow.", str(ac_ex_db.id)) # Load workflow definition into workflow spec model. - spec_module = specs_loader.get_spec_module('native') + spec_module = specs_loader.get_spec_module("native") wf_spec = spec_module.instantiate(wf_def) # Inspect the workflow spec. inspect(wf_spec, st2_ctx, raise_exception=True) # Identify the action to execute. - action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action['ref']) + action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action["ref"]) if not action_db: - error = 'Unable to find action "%s".' % ac_ex_db.action['ref'] + error = 'Unable to find action "%s".' % ac_ex_db.action["ref"] raise ac_exc.InvalidActionReferencedException(error) # Identify the runner for the action. - runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) # Render action execution parameters. runner_params, action_params = param_utils.render_final_params( runner_type_db.runner_parameters, action_db.parameters, ac_ex_db.parameters, - ac_ex_db.context + ac_ex_db.context, ) # Instantiate the workflow conductor. - conductor_params = {'inputs': action_params, 'context': st2_ctx} + conductor_params = {"inputs": action_params, "context": st2_ctx} conductor = conducting.WorkflowConductor(wf_spec, **conductor_params) # Serialize the conductor which initializes some internal values. @@ -248,33 +257,32 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): # Create a record for workflow execution. wf_ex_db = wf_db_models.WorkflowExecutionDB( action_execution=str(ac_ex_db.id), - spec=data['spec'], - graph=data['graph'], - input=data['input'], - context=data['context'], - state=data['state'], - status=data['state']['status'], - output=data['output'], - errors=data['errors'] + spec=data["spec"], + graph=data["graph"], + input=data["input"], + context=data["context"], + state=data["state"], + status=data["state"]["status"], + output=data["output"], + errors=data["errors"], ) # Inspect that the list of tasks in the notify parameter exist in the workflow spec. - if runner_params.get('notify'): - invalid_tasks = list(set(runner_params.get('notify')) - set(wf_spec.tasks.keys())) + if runner_params.get("notify"): + invalid_tasks = list( + set(runner_params.get("notify")) - set(wf_spec.tasks.keys()) + ) if invalid_tasks: raise wf_exc.WorkflowExecutionException( - 'The following tasks in the notify parameter do not exist ' - 'in the workflow definition: %s.' % ', '.join(invalid_tasks) + "The following tasks in the notify parameter do not exist " + "in the workflow definition: %s." % ", ".join(invalid_tasks) ) # Write notify instruction to record. if notify_cfg: # Set up the notify instruction in the workflow execution record. - wf_ex_db.notify = { - 'config': notify_cfg, - 'tasks': runner_params.get('notify') - } + wf_ex_db.notify = {"config": notify_cfg, "tasks": runner_params.get("notify")} # Insert new record into the database and do not publish to the message bus yet. wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False) @@ -286,12 +294,12 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): # Set the initial workflow status to requested. conductor.request_workflow_status(statuses.REQUESTED) data = conductor.serialize() - wf_ex_db.state = data['state'] - wf_ex_db.status = data['state']['status'] + wf_ex_db.state = data["state"] + wf_ex_db.status = data["state"]["status"] # Put the ID of the workflow execution record in the context. - wf_ex_db.context['st2']['workflow_execution_id'] = str(wf_ex_db.id) - wf_ex_db.state['contexts'][0]['st2']['workflow_execution_id'] = str(wf_ex_db.id) + wf_ex_db.context["st2"]["workflow_execution_id"] = str(wf_ex_db.id) + wf_ex_db.state["contexts"][0]["st2"]["workflow_execution_id"] = str(wf_ex_db.id) # Update the workflow execution record. wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) @@ -308,15 +316,17 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_pause(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing pause request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing pause request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -343,7 +353,7 @@ def request_pause(ac_ex_db): wf_ex_db.state = conductor.workflow_state.serialize() wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) - LOG.info('[%s] Completed processing pause request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Completed processing pause request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -351,15 +361,17 @@ def request_pause(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_resume(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing resume request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing resume request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -375,7 +387,9 @@ def request_resume(ac_ex_db): raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id)) if wf_ex_db.status in statuses.RUNNING_STATUSES: - msg = '[%s] Workflow execution "%s" is not resumed because it is already active.' + msg = ( + '[%s] Workflow execution "%s" is not resumed because it is already active.' + ) LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id)) return @@ -385,7 +399,9 @@ def request_resume(ac_ex_db): raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id)) if conductor.get_workflow_status() in statuses.RUNNING_STATUSES: - msg = '[%s] Workflow execution "%s" is not resumed because it is already active.' + msg = ( + '[%s] Workflow execution "%s" is not resumed because it is already active.' + ) LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id)) return @@ -400,7 +416,7 @@ def request_resume(ac_ex_db): # Publish status change. wf_db_access.WorkflowExecution.publish_status(wf_ex_db) - LOG.info('[%s] Completed processing resume request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Completed processing resume request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -408,15 +424,17 @@ def request_resume(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_cancellation(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing cancelation request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing cancelation request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -446,13 +464,16 @@ def request_cancellation(ac_ex_db): # Cascade the cancellation up to the root of the workflow. root_ac_ex_db = ac_svc.get_root_execution(ac_ex_db) - if root_ac_ex_db != ac_ex_db and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES: - LOG.info('[%s] Cascading cancelation request to parent workflow.', wf_ac_ex_id) - root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction['id']) + if ( + root_ac_ex_db != ac_ex_db + and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES + ): + LOG.info("[%s] Cascading cancelation request to parent workflow.", wf_ac_ex_id) + root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction["id"]) ac_svc.request_cancellation(root_lv_ac_db, None) - LOG.debug('[%s] %s', wf_ac_ex_id, conductor.serialize()) - LOG.info('[%s] Completed processing cancelation request for workflow.', wf_ac_ex_id) + LOG.debug("[%s] %s", wf_ac_ex_id, conductor.serialize()) + LOG.info("[%s] Completed processing cancelation request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -460,20 +481,22 @@ def request_cancellation(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_rerun(ac_ex_db, st2_ctx, options=None): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing rerun request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing rerun request for workflow.", wf_ac_ex_id) - wf_ex_id = st2_ctx.get('workflow_execution_id') + wf_ex_id = st2_ctx.get("workflow_execution_id") if not wf_ex_id: - msg = 'Unable to rerun workflow execution because workflow_execution_id is not provided.' + msg = "Unable to rerun workflow execution because workflow_execution_id is not provided." raise wf_exc.WorkflowExecutionRerunException(msg) try: @@ -487,8 +510,8 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id) wf_ex_db.action_execution = wf_ac_ex_id - wf_ex_db.context['st2'] = st2_ctx['st2'] - wf_ex_db.context['parent'] = st2_ctx['parent'] + wf_ex_db.context["st2"] = st2_ctx["st2"] + wf_ex_db.context["parent"] = st2_ctx["parent"] conductor = deserialize_conductor(wf_ex_db) try: @@ -497,26 +520,29 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): if options: task_requests = [] - task_names = options.get('tasks', []) - task_resets = options.get('reset', []) + task_names = options.get("tasks", []) + task_resets = options.get("reset", []) for task_name in task_names: reset_items = task_name in task_resets - task_state_entries = conductor.workflow_state.get_tasks(task_id=task_name) + task_state_entries = conductor.workflow_state.get_tasks( + task_id=task_name + ) if not task_state_entries: problems.append(task_name) continue for _, task_state_entry in task_state_entries: - route = task_state_entry['route'] + route = task_state_entry["route"] req = orquesta_reqs.TaskRerunRequest.new( - task_name, route, reset_items=reset_items) + task_name, route, reset_items=reset_items + ) task_requests.append(req) if problems: - msg = 'Unable to rerun workflow because one or more tasks is not found: %s' - raise Exception(msg % ','.join(problems)) + msg = "Unable to rerun workflow because one or more tasks is not found: %s" + raise Exception(msg % ",".join(problems)) conductor.request_workflow_rerun(task_requests=task_requests) except Exception as e: @@ -527,10 +553,10 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id) data = conductor.serialize() - wf_ex_db.status = data['state']['status'] - wf_ex_db.spec = data['spec'] - wf_ex_db.graph = data['graph'] - wf_ex_db.state = data['state'] + wf_ex_db.status = data["state"]["status"] + wf_ex_db.spec = data["spec"] + wf_ex_db.graph = data["graph"] + wf_ex_db.state = data["state"] wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) @@ -542,12 +568,12 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): - task_id = task_ex_req['id'] - task_route = task_ex_req['route'] - task_spec = task_ex_req['spec'] - task_ctx = task_ex_req['ctx'] - task_actions = task_ex_req['actions'] - task_delay = task_ex_req.get('delay') + task_id = task_ex_req["id"] + task_route = task_ex_req["route"] + task_spec = task_ex_req["spec"] + task_ctx = task_ex_req["ctx"] + task_actions = task_ex_req["actions"] + task_delay = task_ex_req.get("delay") msg = 'Processing task execution request for task "%s", route "%s".' update_progress(wf_ex_db, msg % (task_id, str(task_route)), stream=False) @@ -557,11 +583,14 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): workflow_execution=str(wf_ex_db.id), task_id=task_id, task_route=task_route, - order_by=['-start_timestamp'] + order_by=["-start_timestamp"], ) - if (len(task_ex_dbs) > 0 and task_ex_dbs[0].itemized and - task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING): + if ( + len(task_ex_dbs) > 0 + and task_ex_dbs[0].itemized + and task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING + ): task_ex_db = task_ex_dbs[0] task_ex_id = str(task_ex_db.id) msg = 'Task execution "%s" retrieved for task "%s", route "%s".' @@ -576,15 +605,15 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): task_spec=task_spec.serialize(), delay=task_delay, itemized=task_spec.has_items(), - items_count=task_ex_req.get('items_count'), - items_concurrency=task_ex_req.get('concurrency'), + items_count=task_ex_req.get("items_count"), + items_concurrency=task_ex_req.get("concurrency"), context=task_ctx, - status=statuses.REQUESTED + status=statuses.REQUESTED, ) # Prepare the result format for itemized task execution. if task_ex_db.itemized: - task_ex_db.result = {'items': [None] * task_ex_db.items_count} + task_ex_db.result = {"items": [None] * task_ex_db.items_count} # Insert new record into the database. task_ex_db = wf_db_access.TaskExecution.insert(task_ex_db, publish=False) @@ -627,26 +656,35 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): # Request action execution for each actions in the task request. for ac_ex_req in task_actions: - ac_ex_delay = eval_action_execution_delay(task_ex_req, ac_ex_req, task_ex_db.itemized) - request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay) + ac_ex_delay = eval_action_execution_delay( + task_ex_req, ac_ex_req, task_ex_db.itemized + ) + request_action_execution( + wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay + ) task_ex_db = wf_db_access.TaskExecution.get_by_id(str(task_ex_db.id)) except Exception as e: msg = 'Failed action execution(s) for task "%s", route "%s".' msg = msg % (task_id, str(task_route)) LOG.exception(msg) - msg = '%s %s: %s' % (msg, type(e).__name__, six.text_type(e)) - update_progress(wf_ex_db, msg, severity='error', log=False) - msg = '%s: %s' % (type(e).__name__, six.text_type(e)) - error = {'type': 'error', 'message': msg, 'task_id': task_id, 'route': task_route} - update_task_execution(str(task_ex_db.id), statuses.FAILED, {'errors': [error]}) + msg = "%s %s: %s" % (msg, type(e).__name__, six.text_type(e)) + update_progress(wf_ex_db, msg, severity="error", log=False) + msg = "%s: %s" % (type(e).__name__, six.text_type(e)) + error = { + "type": "error", + "message": msg, + "task_id": task_id, + "route": task_route, + } + update_task_execution(str(task_ex_db.id), statuses.FAILED, {"errors": [error]}) raise e return task_ex_db def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): - task_ex_delay = task_ex_req.get('delay') - items_concurrency = task_ex_req.get('concurrency') + task_ex_delay = task_ex_req.get("delay") + items_concurrency = task_ex_req.get("concurrency") # If there is a task delay and not with items, return the delay value. if task_ex_delay and not itemized: @@ -658,7 +696,7 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): # If there is a task delay and task has items with concurrency, # return the delay value up if item id is less than the concurrency value. - if task_ex_delay and itemized and ac_ex_req['item_id'] < items_concurrency: + if task_ex_delay and itemized and ac_ex_req["item_id"] < items_concurrency: return task_ex_delay return None @@ -667,20 +705,22 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=None): - action_ref = ac_ex_req['action'] - action_input = ac_ex_req['input'] - item_id = ac_ex_req.get('item_id') + action_ref = ac_ex_req["action"] + action_input = ac_ex_req["input"] + item_id = ac_ex_req.get("item_id") # If the task is with items and item_id is not provided, raise exception. if task_ex_db.itemized and item_id is None: - msg = 'Unable to request action execution. Identifier for the item is not provided.' + msg = "Unable to request action execution. Identifier for the item is not provided." raise Exception(msg) # Identify the action to execute. @@ -691,40 +731,40 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non raise ac_exc.InvalidActionReferencedException(error) # Identify the runner for the action. - runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) # Identify action pack name - pack_name = action_ref.split('.')[0] if action_ref else st2_ctx.get('pack') + pack_name = action_ref.split(".")[0] if action_ref else st2_ctx.get("pack") # Set context for the action execution. ac_ex_ctx = { - 'pack': pack_name, - 'user': st2_ctx.get('user'), - 'parent': st2_ctx, - 'orquesta': { - 'workflow_execution_id': str(wf_ex_db.id), - 'task_execution_id': str(task_ex_db.id), - 'task_name': task_ex_db.task_name, - 'task_id': task_ex_db.task_id, - 'task_route': task_ex_db.task_route - } + "pack": pack_name, + "user": st2_ctx.get("user"), + "parent": st2_ctx, + "orquesta": { + "workflow_execution_id": str(wf_ex_db.id), + "task_execution_id": str(task_ex_db.id), + "task_name": task_ex_db.task_name, + "task_id": task_ex_db.task_id, + "task_route": task_ex_db.task_route, + }, } - if st2_ctx.get('api_user'): - ac_ex_ctx['api_user'] = st2_ctx.get('api_user') + if st2_ctx.get("api_user"): + ac_ex_ctx["api_user"] = st2_ctx.get("api_user") - if st2_ctx.get('source_channel'): - ac_ex_ctx['source_channel'] = st2_ctx.get('source_channel') + if st2_ctx.get("source_channel"): + ac_ex_ctx["source_channel"] = st2_ctx.get("source_channel") if item_id is not None: - ac_ex_ctx['orquesta']['item_id'] = item_id + ac_ex_ctx["orquesta"]["item_id"] = item_id # Render action execution parameters and setup action execution object. ac_ex_params = param_utils.render_live_params( runner_type_db.runner_parameters or {}, action_db.parameters or {}, action_input or {}, - ac_ex_ctx + ac_ex_ctx, ) # The delay spec is in seconds and scheduler expects milliseconds. @@ -738,13 +778,19 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non task_execution=str(task_ex_db.id), delay=delay, context=ac_ex_ctx, - parameters=ac_ex_params + parameters=ac_ex_params, ) # Set notification if instructed. - if (wf_ex_db.notify and wf_ex_db.notify.get('config') and - wf_ex_db.notify.get('tasks') and task_ex_db.task_name in wf_ex_db.notify['tasks']): - lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(wf_ex_db.notify['config']) + if ( + wf_ex_db.notify + and wf_ex_db.notify.get("config") + and wf_ex_db.notify.get("tasks") + and task_ex_db.task_name in wf_ex_db.notify["tasks"] + ): + lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model( + wf_ex_db.notify["config"] + ) # Set the task execution to running first otherwise a race can occur # where the action execution finishes first and the completion handler @@ -765,13 +811,13 @@ def handle_action_execution_pending(ac_ex_db): # Check that the action execution is paused. if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PENDING: raise Exception( - 'Unable to handle pending of action execution. The action execution ' + "Unable to handle pending of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -780,14 +826,14 @@ def handle_action_execution_pending(ac_ex_db): msg = 'Handling pending of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context) # Update task flow in the workflow execution. - ac_ex_ctx = ac_ex_db.context.get('orquesta') + ac_ex_ctx = ac_ex_db.context.get("orquesta") update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True) @@ -795,13 +841,13 @@ def handle_action_execution_pause(ac_ex_db): # Check that the action execution is paused. if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PAUSED: raise Exception( - 'Unable to handle pause of action execution. The action execution ' + "Unable to handle pause of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -814,27 +860,27 @@ def handle_action_execution_pause(ac_ex_db): msg = 'Handling pause of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context) # Update task flow in the workflow execution. - ac_ex_ctx = ac_ex_db.context.get('orquesta') + ac_ex_ctx = ac_ex_db.context.get("orquesta") update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True) def handle_action_execution_resume(ac_ex_db): - if 'orquesta' not in ac_ex_db.context: + if "orquesta" not in ac_ex_db.context: raise Exception( - 'Unable to handle resume of action execution. The action execution ' - '%s is not an orquesta workflow task.' % str(ac_ex_db.id) + "Unable to handle resume of action execution. The action execution " + "%s is not an orquesta workflow task." % str(ac_ex_db.id) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -843,7 +889,7 @@ def handle_action_execution_resume(ac_ex_db): msg = 'Handling resume of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution to running. @@ -854,18 +900,22 @@ def handle_action_execution_resume(ac_ex_db): # If action execution has a parent, cascade status change upstream and do not publish # the status change because we do not want to trigger resume of other peer subworkflows. - if 'parent' in ac_ex_db.context: - parent_ac_ex_id = ac_ex_db.context['parent']['execution_id'] + if "parent" in ac_ex_db.context: + parent_ac_ex_id = ac_ex_db.context["parent"]["execution_id"] parent_ac_ex_db = ex_db_access.ActionExecution.get_by_id(parent_ac_ex_id) if parent_ac_ex_db.status == ac_const.LIVEACTION_STATUS_PAUSED: action_utils.update_liveaction_status( - liveaction_id=parent_ac_ex_db.liveaction['id'], + liveaction_id=parent_ac_ex_db.liveaction["id"], status=ac_const.LIVEACTION_STATUS_RUNNING, - publish=False) + publish=False, + ) # If there are grand parents, handle the resume of the parent action execution. - if 'orquesta' in parent_ac_ex_db.context and 'parent' in parent_ac_ex_db.context: + if ( + "orquesta" in parent_ac_ex_db.context + and "parent" in parent_ac_ex_db.context + ): handle_action_execution_resume(parent_ac_ex_db) @@ -873,18 +923,19 @@ def handle_action_execution_resume(ac_ex_db): retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def handle_action_execution_completion(ac_ex_db): # Check that the action execution is completed. if ac_ex_db.status not in ac_const.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Unable to handle completion of action execution. The action execution ' + "Unable to handle completion of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Acquire lock before write operations. with coord_svc.get_coordinator(start_heart=True).get_lock(wf_ex_id): @@ -894,9 +945,12 @@ def handle_action_execution_completion(ac_ex_db): msg = ( 'Handling completion of action execution "%s" ' - 'in status "%s" for task "%s", route "%s".' % ( - str(ac_ex_db.id), ac_ex_db.status, task_ex_db.task_id, - str(task_ex_db.task_route) + 'in status "%s" for task "%s", route "%s".' + % ( + str(ac_ex_db.id), + ac_ex_db.status, + task_ex_db.task_id, + str(task_ex_db.task_route), ) ) update_progress(wf_ex_db, msg) @@ -907,14 +961,16 @@ def handle_action_execution_completion(ac_ex_db): resume_task_execution(task_ex_id) # Update task execution if completed. - update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context) + update_task_execution( + task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context + ) # Update task flow in the workflow execution. update_task_state( task_ex_id, ac_ex_db.status, ac_ex_result=ac_ex_db.result, - ac_ex_ctx=ac_ex_db.context.get('orquesta') + ac_ex_ctx=ac_ex_db.context.get("orquesta"), ) # Request the next set of tasks if workflow execution is not complete. @@ -926,13 +982,13 @@ def handle_action_execution_completion(ac_ex_db): def deserialize_conductor(wf_ex_db): data = { - 'spec': wf_ex_db.spec, - 'graph': wf_ex_db.graph, - 'input': wf_ex_db.input, - 'context': wf_ex_db.context, - 'state': wf_ex_db.state, - 'output': wf_ex_db.output, - 'errors': wf_ex_db.errors + "spec": wf_ex_db.spec, + "graph": wf_ex_db.graph, + "input": wf_ex_db.input, + "context": wf_ex_db.context, + "state": wf_ex_db.state, + "output": wf_ex_db.output, + "errors": wf_ex_db.errors, } return conducting.WorkflowConductor.deserialize(data) @@ -948,18 +1004,22 @@ def refresh_conductor(wf_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) -def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True): + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) +def update_task_state( + task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True +): # Return if action execution status is not in the list of statuses to process. - statuses_to_process = ( - copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) + - [ac_const.LIVEACTION_STATUS_PAUSED, ac_const.LIVEACTION_STATUS_PENDING] - ) + statuses_to_process = copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) + [ + ac_const.LIVEACTION_STATUS_PAUSED, + ac_const.LIVEACTION_STATUS_PENDING, + ] if ac_ex_status not in statuses_to_process: return @@ -973,22 +1033,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), task_ex_db.status) update_progress(wf_ex_db, msg, stream=False) - if not ac_ex_ctx or 'item_id' not in ac_ex_ctx or ac_ex_ctx['item_id'] < 0: + if not ac_ex_ctx or "item_id" not in ac_ex_ctx or ac_ex_ctx["item_id"] < 0: ac_ex_event = events.ActionExecutionEvent(ac_ex_status, result=ac_ex_result) else: accumulated_result = [ - item.get('result') if item else None - for item in task_ex_db.result['items'] + item.get("result") if item else None for item in task_ex_db.result["items"] ] ac_ex_event = events.TaskItemActionExecutionEvent( - ac_ex_ctx['item_id'], + ac_ex_ctx["item_id"], ac_ex_status, result=ac_ex_result, - accumulated_result=accumulated_result + accumulated_result=accumulated_result, ) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) conductor.update_task_state(task_ex_db.task_id, task_ex_db.task_route, ac_ex_event) # Update workflow execution and related liveaction and action execution. @@ -997,19 +1056,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non conductor, update_lv_ac_on_statuses=statuses_to_process, pub_lv_ac=publish, - pub_ac_ex=publish + pub_ac_ex=publish, ) @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_next_tasks(wf_ex_db, task_ex_id=None): iteration = 0 @@ -1018,7 +1079,9 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): # If workflow is in requested status, set it to running. if conductor.get_workflow_status() in [statuses.REQUESTED, statuses.SCHEDULED]: - update_progress(wf_ex_db, 'Requesting conductor to start running workflow execution.') + update_progress( + wf_ex_db, "Requesting conductor to start running workflow execution." + ) conductor.request_workflow_status(statuses.RUNNING) # Identify the list of next set of tasks. Don't pass the task id to the conductor @@ -1028,93 +1091,104 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): msg = 'Identifying next set (iter %s) of tasks after completion of task "%s", route "%s".' msg = msg % (str(iteration), task_ex_db.task_id, str(task_ex_db.task_route)) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() else: msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".' msg = msg % (str(iteration), conductor.get_workflow_status()) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() # If there is no new tasks, update execution records to handle possible completion. if not next_tasks: # Update workflow execution and related liveaction and action execution. - update_progress(wf_ex_db, 'No tasks identified to execute next.') - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "No tasks identified to execute next.") + update_progress(wf_ex_db, "\n", log=False) update_execution_records(wf_ex_db, conductor) if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES: msg = 'The workflow execution is completed with status "%s".' update_progress(wf_ex_db, msg % conductor.get_workflow_status()) - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "\n", log=False) # Iterate while there are next tasks identified for processing. In the case for # task with no action execution defined, the task execution will complete # immediately with a new set of tasks available. while next_tasks: - msg = 'Identified the following set of tasks to execute next: %s' - tasks_list = ', '.join(["%s (route %s)" % (t['id'], str(t['route'])) for t in next_tasks]) + msg = "Identified the following set of tasks to execute next: %s" + tasks_list = ", ".join( + ["%s (route %s)" % (t["id"], str(t["route"])) for t in next_tasks] + ) update_progress(wf_ex_db, msg % tasks_list) # Mark the tasks as running in the task flow before actual task execution. for task in next_tasks: msg = 'Mark task "%s", route "%s", in conductor as running.' - update_progress(wf_ex_db, msg % (task['id'], str(task['route'])), stream=False) + update_progress( + wf_ex_db, msg % (task["id"], str(task["route"])), stream=False + ) # If task has items and items list is empty, then actions under the next task is empty # and will not be processed in the for loop below. Handle this use case separately and # mark it as running in the conductor. The task will be completed automatically when # it is requested for task execution. - if task['spec'].has_items() and 'items_count' in task and task['items_count'] == 0: + if ( + task["spec"].has_items() + and "items_count" in task + and task["items_count"] == 0 + ): ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) # If task contains multiple action execution (i.e. with items), # then mark each item individually. - for action in task['actions']: - if 'item_id' not in action or action['item_id'] is None: + for action in task["actions"]: + if "item_id" not in action or action["item_id"] is None: ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING) else: - msg = 'Mark task "%s", route "%s", item "%s" in conductor as running.' - msg = msg % (task['id'], str(task['route']), action['item_id']) + msg = ( + 'Mark task "%s", route "%s", item "%s" in conductor as running.' + ) + msg = msg % (task["id"], str(task["route"]), action["item_id"]) update_progress(wf_ex_db, msg) ac_ex_event = events.TaskItemActionExecutionEvent( - action['item_id'], - statuses.RUNNING + action["item_id"], statuses.RUNNING ) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) # Update workflow execution and related liveaction and action execution. - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) update_execution_records(wf_ex_db, conductor) # Request task execution for the tasks. for task in next_tasks: try: msg = 'Requesting execution for task "%s", route "%s".' - update_progress(wf_ex_db, msg % (task['id'], str(task['route']))) + update_progress(wf_ex_db, msg % (task["id"], str(task["route"]))) # Pass down appropriate st2 context to the task and action execution(s). - root_st2_ctx = wf_ex_db.context.get('st2', {}) + root_st2_ctx = wf_ex_db.context.get("st2", {}) st2_ctx = { - 'execution_id': wf_ex_db.action_execution, - 'user': root_st2_ctx.get('user'), - 'pack': root_st2_ctx.get('pack') + "execution_id": wf_ex_db.action_execution, + "user": root_st2_ctx.get("user"), + "pack": root_st2_ctx.get("pack"), } - if root_st2_ctx.get('api_user'): - st2_ctx['api_user'] = root_st2_ctx.get('api_user') + if root_st2_ctx.get("api_user"): + st2_ctx["api_user"] = root_st2_ctx.get("api_user") - if root_st2_ctx.get('source_channel'): - st2_ctx['source_channel'] = root_st2_ctx.get('source_channel') + if root_st2_ctx.get("source_channel"): + st2_ctx["source_channel"] = root_st2_ctx.get("source_channel") # Request the task execution. request_task_execution(wf_ex_db, st2_ctx, task) except Exception as e: msg = 'Failed task execution for task "%s", route "%s".' - msg = msg % (task['id'], str(task['route'])) - update_progress(wf_ex_db, '%s %s' % (msg, str(e)), severity='error', log=False) + msg = msg % (task["id"], str(task["route"])) + update_progress( + wf_ex_db, "%s %s" % (msg, str(e)), severity="error", log=False + ) LOG.exception(msg) fail_workflow_execution(str(wf_ex_db.id), e, task=task) return @@ -1125,25 +1199,30 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".' msg = msg % (str(iteration), conductor.get_workflow_status()) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() if not next_tasks: - update_progress(wf_ex_db, 'No tasks identified to execute next.') - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "No tasks identified to execute next.") + update_progress(wf_ex_db, "\n", log=False) @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None): - if ac_ex_status not in statuses.COMPLETED_STATUSES + [statuses.PAUSED, statuses.PENDING]: + if ac_ex_status not in statuses.COMPLETED_STATUSES + [ + statuses.PAUSED, + statuses.PENDING, + ]: return task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) @@ -1153,31 +1232,43 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx if not task_ex_db.itemized or (task_ex_db.itemized and task_ex_db.items_count == 0): if ac_ex_status != task_ex_db.status: msg = 'Updating task execution "%s" for task "%s" from status "%s" to "%s".' - msg = msg % (task_ex_id, task_ex_db.task_id, task_ex_db.status, ac_ex_status) + msg = msg % ( + task_ex_id, + task_ex_db.task_id, + task_ex_db.status, + ac_ex_status, + ) update_progress(wf_ex_db, msg) task_ex_db.status = ac_ex_status task_ex_db.result = ac_ex_result if ac_ex_result else task_ex_db.result elif task_ex_db.itemized and ac_ex_ctx: - if 'orquesta' not in ac_ex_ctx or 'item_id' not in ac_ex_ctx['orquesta']: - msg = 'Context information for the item is not provided. %s' % str(ac_ex_ctx) - update_progress(wf_ex_db, msg, severity='error', log=False) + if "orquesta" not in ac_ex_ctx or "item_id" not in ac_ex_ctx["orquesta"]: + msg = "Context information for the item is not provided. %s" % str( + ac_ex_ctx + ) + update_progress(wf_ex_db, msg, severity="error", log=False) raise Exception(msg) - item_id = ac_ex_ctx['orquesta']['item_id'] + item_id = ac_ex_ctx["orquesta"]["item_id"] msg = 'Processing action execution for task "%s", route "%s", item "%s".' msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), item_id) - update_progress(wf_ex_db, msg, severity='debug') + update_progress(wf_ex_db, msg, severity="debug") - task_ex_db.result['items'][item_id] = {'status': ac_ex_status, 'result': ac_ex_result} + task_ex_db.result["items"][item_id] = { + "status": ac_ex_status, + "result": ac_ex_result, + } item_statuses = [ - item.get('status', statuses.UNSET) if item else statuses.UNSET - for item in task_ex_db.result['items'] + item.get("status", statuses.UNSET) if item else statuses.UNSET + for item in task_ex_db.result["items"] ] - task_completed = all([status in statuses.COMPLETED_STATUSES for status in item_statuses]) + task_completed = all( + [status in statuses.COMPLETED_STATUSES for status in item_statuses] + ) if task_completed: new_task_status = ( @@ -1187,11 +1278,15 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx ) msg = 'Updating task execution from status "%s" to "%s".' - update_progress(wf_ex_db, msg % (task_ex_db.status, new_task_status), severity='debug') + update_progress( + wf_ex_db, msg % (task_ex_db.status, new_task_status), severity="debug" + ) task_ex_db.status = new_task_status else: - msg = 'Task execution is not complete because not all items are complete: %s' - update_progress(wf_ex_db, msg % ', '.join(item_statuses), severity='debug') + msg = ( + "Task execution is not complete because not all items are complete: %s" + ) + update_progress(wf_ex_db, msg % ", ".join(item_statuses), severity="debug") if task_ex_db.status in statuses.COMPLETED_STATUSES: task_ex_db.end_timestamp = date_utils.get_datetime_utc_now() @@ -1202,19 +1297,23 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def resume_task_execution(task_ex_id): # Update task execution to running. task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(task_ex_db.workflow_execution) msg = 'Updating task execution from status "%s" to "%s".' - update_progress(wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity='debug') + update_progress( + wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity="debug" + ) task_ex_db.status = statuses.RUNNING # Write update to the database. @@ -1224,17 +1323,21 @@ def resume_task_execution(task_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def update_workflow_execution(wf_ex_id): conductor, wf_ex_db = refresh_conductor(wf_ex_id) # There is nothing to update if workflow execution is not completed or paused. - if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [statuses.PAUSED]: + if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [ + statuses.PAUSED + ]: # Update workflow execution and related liveaction and action execution. update_execution_records(wf_ex_db, conductor) @@ -1242,12 +1345,14 @@ def update_workflow_execution(wf_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def resume_workflow_execution(wf_ex_id, task_ex_id): # Update workflow execution to running. conductor, wf_ex_db = refresh_conductor(wf_ex_id) @@ -1265,12 +1370,14 @@ def resume_workflow_execution(wf_ex_id, task_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def fail_workflow_execution(wf_ex_id, exception, task=None): conductor, wf_ex_db = refresh_conductor(wf_ex_id) @@ -1278,7 +1385,7 @@ def fail_workflow_execution(wf_ex_id, exception, task=None): conductor.request_workflow_status(statuses.FAILED) if task is not None and isinstance(task, dict): - conductor.log_error(exception, task_id=task.get('id'), route=task.get('route')) + conductor.log_error(exception, task_id=task.get("id"), route=task.get("route")) else: conductor.log_error(exception) @@ -1286,8 +1393,14 @@ def fail_workflow_execution(wf_ex_id, exception, task=None): update_execution_records(wf_ex_db, conductor) -def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, - pub_wf_ex=False, pub_lv_ac=True, pub_ac_ex=True): +def update_execution_records( + wf_ex_db, + conductor, + update_lv_ac_on_statuses=None, + pub_wf_ex=False, + pub_lv_ac=True, + pub_ac_ex=True, +): # If the workflow execution is completed, then render the workflow output. if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES: conductor.render_workflow_output() @@ -1295,7 +1408,7 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, # Determine if workflow status has changed. wf_old_status = wf_ex_db.status wf_ex_db.status = conductor.get_workflow_status() - status_changed = (wf_old_status != wf_ex_db.status) + status_changed = wf_old_status != wf_ex_db.status if status_changed: msg = 'Updating workflow execution from status "%s" to "%s".' @@ -1314,53 +1427,58 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=pub_wf_ex) # Return if workflow execution status is not specified in update_lv_ac_on_statuses. - if (isinstance(update_lv_ac_on_statuses, list) and - wf_ex_db.status not in update_lv_ac_on_statuses): + if ( + isinstance(update_lv_ac_on_statuses, list) + and wf_ex_db.status not in update_lv_ac_on_statuses + ): return # Update the corresponding liveaction and action execution for the workflow. wf_ac_ex_db = ex_db_access.ActionExecution.get_by_id(wf_ex_db.action_execution) - wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction['id']) + wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction["id"]) # Gather result for liveaction and action execution. - result = {'output': wf_ex_db.output or None} + result = {"output": wf_ex_db.output or None} if wf_ex_db.status in statuses.ABENDED_STATUSES: - result['errors'] = wf_ex_db.errors + result["errors"] = wf_ex_db.errors if wf_ex_db.errors: - msg = 'Workflow execution completed with errors.' - update_progress(wf_ex_db, msg, severity='error') + msg = "Workflow execution completed with errors." + update_progress(wf_ex_db, msg, severity="error") for wf_ex_error in wf_ex_db.errors: - update_progress(wf_ex_db, wf_ex_error, severity='error') + update_progress(wf_ex_db, wf_ex_error, severity="error") # Sync update with corresponding liveaction and action execution. if pub_lv_ac or pub_ac_ex: - pub_lv_ac = (wf_lv_ac_db.status != wf_ex_db.status) + pub_lv_ac = wf_lv_ac_db.status != wf_ex_db.status pub_ac_ex = pub_lv_ac if wf_lv_ac_db.status != wf_ex_db.status: - kwargs = {'severity': 'debug', 'stream': False} + kwargs = {"severity": "debug", "stream": False} msg = 'Updating workflow liveaction from status "%s" to "%s".' update_progress(wf_ex_db, msg % (wf_lv_ac_db.status, wf_ex_db.status), **kwargs) - msg = 'Workflow liveaction status change %s be published.' - update_progress(wf_ex_db, msg % 'will' if pub_lv_ac else 'will not', **kwargs) - msg = 'Workflow action execution status change %s be published.' - update_progress(wf_ex_db, msg % 'will' if pub_ac_ex else 'will not', **kwargs) + msg = "Workflow liveaction status change %s be published." + update_progress(wf_ex_db, msg % "will" if pub_lv_ac else "will not", **kwargs) + msg = "Workflow action execution status change %s be published." + update_progress(wf_ex_db, msg % "will" if pub_ac_ex else "will not", **kwargs) wf_lv_ac_db = action_utils.update_liveaction_status( status=wf_ex_db.status, result=result, end_timestamp=wf_ex_db.end_timestamp, liveaction_db=wf_lv_ac_db, - publish=pub_lv_ac) + publish=pub_lv_ac, + ) ex_svc.update_execution(wf_lv_ac_db, publish=pub_ac_ex) # Invoke post run on the liveaction for the workflow execution. if status_changed and wf_lv_ac_db.status in ac_const.LIVEACTION_COMPLETED_STATES: - update_progress(wf_ex_db, 'Workflow action execution is completed and invoking post run.') + update_progress( + wf_ex_db, "Workflow action execution is completed and invoking post run." + ) runners_utils.invoke_post_run(wf_lv_ac_db) @@ -1376,36 +1494,40 @@ def identify_orphaned_workflows(): # does not necessary means it is the max idle time. The use of workflow_executions_idled_ttl # to filter is to reduce the number of action executions that need to be evaluated. query_filters = { - 'runner__name': 'orquesta', - 'status': ac_const.LIVEACTION_STATUS_RUNNING, - 'start_timestamp__lte': expiry_dt + "runner__name": "orquesta", + "status": ac_const.LIVEACTION_STATUS_RUNNING, + "start_timestamp__lte": expiry_dt, } ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters) for ac_ex_db in ac_ex_dbs: # Figure out the runtime for the action execution. status_change_logs = sorted( - [log for log in ac_ex_db.log if log['status'] == ac_const.LIVEACTION_STATUS_RUNNING], - key=lambda x: x['timestamp'], - reverse=True + [ + log + for log in ac_ex_db.log + if log["status"] == ac_const.LIVEACTION_STATUS_RUNNING + ], + key=lambda x: x["timestamp"], + reverse=True, ) if len(status_change_logs) <= 0: continue - runtime = (utc_now_dt - status_change_logs[0]['timestamp']).total_seconds() + runtime = (utc_now_dt - status_change_logs[0]["timestamp"]).total_seconds() # Fetch the task executions for the workflow execution. # Ensure that the root action execution is not being selected. - wf_ex_id = ac_ex_db.context['workflow_execution'] + wf_ex_id = ac_ex_db.context["workflow_execution"] wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) - query_filters = {'workflow_execution': wf_ex_id, 'id__ne': ac_ex_db.id} + query_filters = {"workflow_execution": wf_ex_id, "id__ne": ac_ex_db.id} tk_ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters) # The workflow execution is orphaned if there are # no task executions and runtime passed expiry. if len(tk_ac_ex_dbs) <= 0 and runtime > gc_max_idle: - msg = 'The action execution is orphaned and will be canceled by the garbage collector.' + msg = "The action execution is orphaned and will be canceled by the garbage collector." update_progress(wf_ex_db, msg) orphaned.append(ac_ex_db) continue @@ -1415,7 +1537,8 @@ def identify_orphaned_workflows(): has_active_tasks = len([t for t in tk_ac_ex_dbs if t.end_timestamp is None]) > 0 completed_tasks = [ - t for t in tk_ac_ex_dbs + t + for t in tk_ac_ex_dbs if t.end_timestamp is not None and t.end_timestamp <= expiry_dt ] @@ -1423,11 +1546,16 @@ def identify_orphaned_workflows(): most_recent_completed_task_expired = ( completed_tasks[-1].end_timestamp <= expiry_dt - if len(completed_tasks) > 0 else False + if len(completed_tasks) > 0 + else False ) - if len(tk_ac_ex_dbs) > 0 and not has_active_tasks and most_recent_completed_task_expired: - msg = 'The action execution is orphaned and will be canceled by the garbage collector.' + if ( + len(tk_ac_ex_dbs) > 0 + and not has_active_tasks + and most_recent_completed_task_expired + ): + msg = "The action execution is orphaned and will be canceled by the garbage collector." update_progress(wf_ex_db, msg) orphaned.append(ac_ex_db) continue diff --git a/st2common/st2common/signal_handlers.py b/st2common/st2common/signal_handlers.py index 0fc2766175..bd785403f4 100644 --- a/st2common/st2common/signal_handlers.py +++ b/st2common/st2common/signal_handlers.py @@ -26,7 +26,7 @@ from st2common.logging.misc import reopen_log_files __all__ = [ - 'register_common_signal_handlers', + "register_common_signal_handlers", ] diff --git a/st2common/st2common/stream/listener.py b/st2common/st2common/stream/listener.py index 6edbef1750..347c4cfc75 100644 --- a/st2common/st2common/stream/listener.py +++ b/st2common/st2common/stream/listener.py @@ -33,11 +33,10 @@ from st2common import log as logging __all__ = [ - 'StreamListener', - 'ExecutionOutputListener', - - 'get_listener', - 'get_listener_if_set' + "StreamListener", + "ExecutionOutputListener", + "get_listener", + "get_listener_if_set", ] LOG = logging.getLogger(__name__) @@ -49,23 +48,24 @@ class BaseListener(ConsumerMixin): - def __init__(self, connection): self.connection = connection self.queues = [] self._stopped = False def get_consumers(self, consumer, channel): - raise NotImplementedError('get_consumers() is not implemented') + raise NotImplementedError("get_consumers() is not implemented") def processor(self, model=None): def process(body, message): meta = message.delivery_info - event_name = '%s__%s' % (meta.get('exchange'), meta.get('routing_key')) + event_name = "%s__%s" % (meta.get("exchange"), meta.get("routing_key")) try: if model: - body = model.from_model(body, mask_secrets=cfg.CONF.api.mask_secrets) + body = model.from_model( + body, mask_secrets=cfg.CONF.api.mask_secrets + ) self.emit(event_name, body) finally: @@ -78,10 +78,17 @@ def emit(self, event, body): for queue in self.queues: queue.put(pack) - def generator(self, events=None, action_refs=None, execution_ids=None, - end_event=None, end_statuses=None, end_execution_id=None): + def generator( + self, + events=None, + action_refs=None, + execution_ids=None, + end_event=None, + end_statuses=None, + end_execution_id=None, + ): queue = eventlet.Queue() - queue.put('') + queue.put("") self.queues.append(queue) try: stop = False @@ -95,16 +102,19 @@ def generator(self, events=None, action_refs=None, execution_ids=None, event_name, body = message # check to see if this is the last message to send. if event_name == end_event: - if body is not None and \ - body.status in end_statuses and \ - end_execution_id is not None and \ - body.id == end_execution_id: + if ( + body is not None + and body.status in end_statuses + and end_execution_id is not None + and body.id == end_execution_id + ): stop = True # TODO: We now do late filtering, but this could also be performed on the # message bus level if we modified our exchange layout and utilize routing keys # Filter on event name - include_event = self._should_include_event(event_names_whitelist=events, - event_name=event_name) + include_event = self._should_include_event( + event_names_whitelist=events, event_name=event_name + ) if not include_event: LOG.debug('Skipping event "%s"' % (event_name)) continue @@ -112,14 +122,18 @@ def generator(self, events=None, action_refs=None, execution_ids=None, # Filter on action ref action_ref = self._get_action_ref_for_body(body=body) if action_refs and action_ref not in action_refs: - LOG.debug('Skipping event "%s" with action_ref "%s"' % (event_name, - action_ref)) + LOG.debug( + 'Skipping event "%s" with action_ref "%s"' + % (event_name, action_ref) + ) continue # Filter on execution id execution_id = self._get_execution_id_for_body(body=body) if execution_ids and execution_id not in execution_ids: - LOG.debug('Skipping event "%s" with execution_id "%s"' % (event_name, - execution_id)) + LOG.debug( + 'Skipping event "%s" with execution_id "%s"' + % (event_name, execution_id) + ) continue yield message @@ -154,7 +168,7 @@ def _get_action_ref_for_body(self, body): action_ref = None if isinstance(body, ActionExecutionAPI): - action_ref = body.action.get('ref', None) if body.action else None + action_ref = body.action.get("ref", None) if body.action else None elif isinstance(body, LiveActionAPI): action_ref = body.action elif isinstance(body, (ActionExecutionOutputAPI)): @@ -187,21 +201,26 @@ class StreamListener(BaseListener): def get_consumers(self, consumer, channel): return [ - consumer(queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor()]), - - consumer(queues=[STREAM_EXECUTION_ALL_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionAPI)]), - - consumer(queues=[STREAM_LIVEACTION_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(LiveActionAPI)]), - - consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionOutputAPI)]) + consumer( + queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor()], + ), + consumer( + queues=[STREAM_EXECUTION_ALL_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionAPI)], + ), + consumer( + queues=[STREAM_LIVEACTION_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(LiveActionAPI)], + ), + consumer( + queues=[STREAM_EXECUTION_OUTPUT_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionOutputAPI)], + ), ] @@ -214,13 +233,16 @@ class ExecutionOutputListener(BaseListener): def get_consumers(self, consumer, channel): return [ - consumer(queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionAPI)]), - - consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionOutputAPI)]) + consumer( + queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionAPI)], + ), + consumer( + queues=[STREAM_EXECUTION_OUTPUT_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionOutputAPI)], + ), ] @@ -235,29 +257,29 @@ def get_listener(name): global _stream_listener global _execution_output_listener - if name == 'stream': + if name == "stream": if not _stream_listener: with transport_utils.get_connection() as conn: _stream_listener = StreamListener(conn) eventlet.spawn_n(listen, _stream_listener) return _stream_listener - elif name == 'execution_output': + elif name == "execution_output": if not _execution_output_listener: with transport_utils.get_connection() as conn: _execution_output_listener = ExecutionOutputListener(conn) eventlet.spawn_n(listen, _execution_output_listener) return _execution_output_listener else: - raise ValueError('Invalid listener name: %s' % (name)) + raise ValueError("Invalid listener name: %s" % (name)) def get_listener_if_set(name): global _stream_listener global _execution_output_listener - if name == 'stream': + if name == "stream": return _stream_listener - elif name == 'execution_output': + elif name == "execution_output": return _execution_output_listener else: - raise ValueError('Invalid listener name: %s' % (name)) + raise ValueError("Invalid listener name: %s" % (name)) diff --git a/st2common/st2common/transport/__init__.py b/st2common/st2common/transport/__init__.py index cc384c878e..632c08dc0e 100644 --- a/st2common/st2common/transport/__init__.py +++ b/st2common/st2common/transport/__init__.py @@ -21,12 +21,12 @@ # TODO(manas) : Exchanges, Queues and RoutingKey design discussion pending. __all__ = [ - 'liveaction', - 'actionexecutionstate', - 'execution', - 'workflow', - 'publishers', - 'reactor', - 'utils', - 'connection_retry_wrapper' + "liveaction", + "actionexecutionstate", + "execution", + "workflow", + "publishers", + "reactor", + "utils", + "connection_retry_wrapper", ] diff --git a/st2common/st2common/transport/actionexecutionstate.py b/st2common/st2common/transport/actionexecutionstate.py index 268bffe0fc..46fe095fbf 100644 --- a/st2common/st2common/transport/actionexecutionstate.py +++ b/st2common/st2common/transport/actionexecutionstate.py @@ -21,18 +21,16 @@ from st2common.transport import publishers -__all__ = [ - 'ActionExecutionStatePublisher' -] +__all__ = ["ActionExecutionStatePublisher"] -ACTIONEXECUTIONSTATE_XCHG = Exchange('st2.actionexecutionstate', - type='topic') +ACTIONEXECUTIONSTATE_XCHG = Exchange("st2.actionexecutionstate", type="topic") class ActionExecutionStatePublisher(publishers.CUDPublisher): - def __init__(self): - super(ActionExecutionStatePublisher, self).__init__(exchange=ACTIONEXECUTIONSTATE_XCHG) + super(ActionExecutionStatePublisher, self).__init__( + exchange=ACTIONEXECUTIONSTATE_XCHG + ) def get_queue(name, routing_key): diff --git a/st2common/st2common/transport/announcement.py b/st2common/st2common/transport/announcement.py index 84c8bf27a7..e79506c608 100644 --- a/st2common/st2common/transport/announcement.py +++ b/st2common/st2common/transport/announcement.py @@ -22,17 +22,12 @@ from st2common.models.api.trace import TraceContext from st2common.transport import publishers -__all__ = [ - 'AnnouncementPublisher', - 'AnnouncementDispatcher', - - 'get_queue' -] +__all__ = ["AnnouncementPublisher", "AnnouncementDispatcher", "get_queue"] LOG = logging.getLogger(__name__) # Exchange for Announcements -ANNOUNCEMENT_XCHG = Exchange('st2.announcement', type='topic') +ANNOUNCEMENT_XCHG = Exchange("st2.announcement", type="topic") class AnnouncementPublisher(object): @@ -68,16 +63,19 @@ def dispatch(self, routing_key, payload, trace_context=None): assert isinstance(payload, (type(None), dict)) assert isinstance(trace_context, (type(None), dict, TraceContext)) - payload = { - 'payload': payload, - TRACE_CONTEXT: trace_context - } + payload = {"payload": payload, TRACE_CONTEXT: trace_context} - self._logger.debug('Dispatching announcement (routing_key=%s,payload=%s)', - routing_key, payload) + self._logger.debug( + "Dispatching announcement (routing_key=%s,payload=%s)", routing_key, payload + ) self._publisher.publish(payload=payload, routing_key=routing_key) -def get_queue(name=None, routing_key='#', exclusive=False, auto_delete=False): - return Queue(name, ANNOUNCEMENT_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) +def get_queue(name=None, routing_key="#", exclusive=False, auto_delete=False): + return Queue( + name, + ANNOUNCEMENT_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) diff --git a/st2common/st2common/transport/bootstrap.py b/st2common/st2common/transport/bootstrap.py index 4c75072fe9..20d9277fae 100644 --- a/st2common/st2common/transport/bootstrap.py +++ b/st2common/st2common/transport/bootstrap.py @@ -24,8 +24,9 @@ def _setup(): config.parse_args() # 2. setup logging. - logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s', - level=logging.DEBUG) + logging.basicConfig( + format="%(asctime)s %(levelname)s [-] %(message)s", level=logging.DEBUG + ) def main(): @@ -34,5 +35,5 @@ def main(): # The scripts sets up Exchanges in RabbitMQ. -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/st2common/transport/bootstrap_utils.py b/st2common/st2common/transport/bootstrap_utils.py index d787adc493..2eea9ad64b 100644 --- a/st2common/st2common/transport/bootstrap_utils.py +++ b/st2common/st2common/transport/bootstrap_utils.py @@ -50,15 +50,14 @@ from st2common.transport.queues import WORKFLOW_EXECUTION_WORK_QUEUE from st2common.transport.queues import WORKFLOW_EXECUTION_RESUME_QUEUE -LOG = logging.getLogger('st2common.transport.bootstrap') +LOG = logging.getLogger("st2common.transport.bootstrap") __all__ = [ - 'register_exchanges', - 'register_exchanges_with_retry', - 'register_kombu_serializers', - - 'EXCHANGES', - 'QUEUES' + "register_exchanges", + "register_exchanges_with_retry", + "register_kombu_serializers", + "EXCHANGES", + "QUEUES", ] # List of exchanges which are pre-declared on service set up. @@ -72,7 +71,7 @@ TRIGGER_INSTANCE_XCHG, SENSOR_CUD_XCHG, WORKFLOW_EXECUTION_XCHG, - WORKFLOW_EXECUTION_STATUS_MGMT_XCHG + WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, ] # List of queues which are pre-declared on service startup. @@ -85,41 +84,40 @@ NOTIFIER_ACTIONUPDATE_WORK_QUEUE, RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE, RULESENGINE_WORK_QUEUE, - STREAM_ANNOUNCEMENT_WORK_QUEUE, STREAM_EXECUTION_ALL_WORK_QUEUE, STREAM_LIVEACTION_WORK_QUEUE, STREAM_EXECUTION_OUTPUT_QUEUE, - WORKFLOW_EXECUTION_WORK_QUEUE, WORKFLOW_EXECUTION_RESUME_QUEUE, - # Those queues are dynamically / late created on some class init but we still need to # pre-declare them for redis Kombu backend to work. - reactor.get_trigger_cud_queue(name='st2.preinit', routing_key='init'), - reactor.get_sensor_cud_queue(name='st2.preinit', routing_key='init') + reactor.get_trigger_cud_queue(name="st2.preinit", routing_key="init"), + reactor.get_sensor_cud_queue(name="st2.preinit", routing_key="init"), ] def _do_register_exchange(exchange, connection, channel, retry_wrapper): try: kwargs = { - 'exchange': exchange.name, - 'type': exchange.type, - 'durable': exchange.durable, - 'auto_delete': exchange.auto_delete, - 'arguments': exchange.arguments, - 'nowait': False, - 'passive': False + "exchange": exchange.name, + "type": exchange.type, + "durable": exchange.durable, + "auto_delete": exchange.auto_delete, + "arguments": exchange.arguments, + "nowait": False, + "passive": False, } # Use the retry wrapper to increase resiliency in recoverable errors. - retry_wrapper.ensured(connection=connection, - obj=channel, - to_ensure_func=channel.exchange_declare, - **kwargs) - LOG.debug('Registered exchange %s (%s).' % (exchange.name, str(kwargs))) + retry_wrapper.ensured( + connection=connection, + obj=channel, + to_ensure_func=channel.exchange_declare, + **kwargs, + ) + LOG.debug("Registered exchange %s (%s)." % (exchange.name, str(kwargs))) except Exception: - LOG.exception('Failed to register exchange: %s.', exchange.name) + LOG.exception("Failed to register exchange: %s.", exchange.name) def _do_predeclare_queue(channel, queue): @@ -132,23 +130,31 @@ def _do_predeclare_queue(channel, queue): bound_queue.declare(nowait=False) LOG.debug('Predeclared queue for exchange "%s"' % (queue.exchange.name)) except Exception: - LOG.exception('Failed to predeclare queue for exchange "%s"' % (queue.exchange.name)) + LOG.exception( + 'Failed to predeclare queue for exchange "%s"' % (queue.exchange.name) + ) return bound_queue def register_exchanges(): - LOG.debug('Registering exchanges...') + LOG.debug("Registering exchanges...") connection_urls = transport_utils.get_messaging_urls() with transport_utils.get_connection() as conn: # Use ConnectionRetryWrapper to deal with rmq clustering etc. - retry_wrapper = ConnectionRetryWrapper(cluster_size=len(connection_urls), logger=LOG) + retry_wrapper = ConnectionRetryWrapper( + cluster_size=len(connection_urls), logger=LOG + ) def wrapped_register_exchanges(connection, channel): for exchange in EXCHANGES: - _do_register_exchange(exchange=exchange, connection=connection, channel=channel, - retry_wrapper=retry_wrapper) + _do_register_exchange( + exchange=exchange, + connection=connection, + channel=channel, + retry_wrapper=retry_wrapper, + ) retry_wrapper.run(connection=conn, wrapped_callback=wrapped_register_exchanges) @@ -166,7 +172,7 @@ def retry_if_io_error(exception): retrying_obj = retrying.Retrying( retry_on_exception=retry_if_io_error, wait_fixed=cfg.CONF.messaging.connection_retry_wait, - stop_max_attempt_number=cfg.CONF.messaging.connection_retries + stop_max_attempt_number=cfg.CONF.messaging.connection_retries, ) return retrying_obj.call(register_exchanges) @@ -181,24 +187,33 @@ def register_kombu_serializers(): https://github.com/celery/kombu/blob/3.0/kombu/utils/encoding.py#L47 """ + def pickle_dumps(obj, dumper=pickle.dumps): return dumper(obj, protocol=pickle_protocol) if six.PY3: + def str_to_bytes(s): if isinstance(s, str): - return s.encode('utf-8') + return s.encode("utf-8") return s def unpickle(s): return pickle_loads(str_to_bytes(s)) + else: - def str_to_bytes(s): # noqa - if isinstance(s, unicode): # noqa # pylint: disable=E0602 - return s.encode('utf-8') + + def str_to_bytes(s): # noqa + if isinstance(s, unicode): # noqa # pylint: disable=E0602 + return s.encode("utf-8") return s + unpickle = pickle_loads # noqa - register('pickle', pickle_dumps, unpickle, - content_type='application/x-python-serialize', - content_encoding='binary') + register( + "pickle", + pickle_dumps, + unpickle, + content_type="application/x-python-serialize", + content_encoding="binary", + ) diff --git a/st2common/st2common/transport/connection_retry_wrapper.py b/st2common/st2common/transport/connection_retry_wrapper.py index d0c906fff6..492aa24f32 100644 --- a/st2common/st2common/transport/connection_retry_wrapper.py +++ b/st2common/st2common/transport/connection_retry_wrapper.py @@ -19,7 +19,7 @@ from st2common.util import concurrency -__all__ = ['ConnectionRetryWrapper', 'ClusterRetryContext'] +__all__ = ["ConnectionRetryWrapper", "ClusterRetryContext"] class ClusterRetryContext(object): @@ -27,6 +27,7 @@ class ClusterRetryContext(object): Stores retry context for cluster retries. It makes certain assumptions on how cluster_size and retry should be determined. """ + def __init__(self, cluster_size): # No of nodes in a cluster self.cluster_size = cluster_size @@ -101,6 +102,7 @@ def wrapped_callback(connection, channel): retry_wrapper.run(connection=connection, wrapped_callback=wrapped_callback) """ + def __init__(self, cluster_size, logger, ensure_max_retries=3): self._retry_context = ClusterRetryContext(cluster_size=cluster_size) self._logger = logger @@ -109,7 +111,7 @@ def __init__(self, cluster_size, logger, ensure_max_retries=3): self._ensure_max_retries = ensure_max_retries def errback(self, exc, interval): - self._logger.error('Rabbitmq connection error: %s', exc.message) + self._logger.error("Rabbitmq connection error: %s", exc.message) def run(self, connection, wrapped_callback): """ @@ -141,8 +143,10 @@ def run(self, connection, wrapped_callback): raise # -1, 0 and 1+ are handled properly by eventlet.sleep - self._logger.debug('Received RabbitMQ server error, sleeping for %s seconds ' - 'before retrying: %s' % (wait, six.text_type(e))) + self._logger.debug( + "Received RabbitMQ server error, sleeping for %s seconds " + "before retrying: %s" % (wait, six.text_type(e)) + ) concurrency.sleep(wait) connection.close() @@ -154,22 +158,28 @@ def run(self, connection, wrapped_callback): def log_error_on_conn_failure(exc, interval): self._logger.debug( - 'Failed to re-establish connection to RabbitMQ server, ' - 'retrying in %s seconds: %s' % (interval, six.text_type(exc)) + "Failed to re-establish connection to RabbitMQ server, " + "retrying in %s seconds: %s" % (interval, six.text_type(exc)) ) try: # NOTE: This function blocks and tries to restablish a connection for # indefinetly if "max_retries" argument is not specified - connection.ensure_connection(max_retries=self._ensure_max_retries, - errback=log_error_on_conn_failure) + connection.ensure_connection( + max_retries=self._ensure_max_retries, + errback=log_error_on_conn_failure, + ) except Exception: - self._logger.exception('Connections to RabbitMQ cannot be re-established: %s', - six.text_type(e)) + self._logger.exception( + "Connections to RabbitMQ cannot be re-established: %s", + six.text_type(e), + ) raise except Exception as e: - self._logger.exception('Connections to RabbitMQ cannot be re-established: %s', - six.text_type(e)) + self._logger.exception( + "Connections to RabbitMQ cannot be re-established: %s", + six.text_type(e), + ) # Not being able to publish a message could be a significant issue for an app. raise finally: @@ -177,7 +187,7 @@ def log_error_on_conn_failure(exc, interval): try: channel.close() except Exception: - self._logger.warning('Error closing channel.', exc_info=True) + self._logger.warning("Error closing channel.", exc_info=True) def ensured(self, connection, obj, to_ensure_func, **kwargs): """ @@ -191,7 +201,6 @@ def ensured(self, connection, obj, to_ensure_func, **kwargs): :type obj: Must support mixin kombu.abstract.MaybeChannelBound """ ensuring_func = connection.ensure( - obj, to_ensure_func, - errback=self.errback, - max_retries=3) + obj, to_ensure_func, errback=self.errback, max_retries=3 + ) ensuring_func(**kwargs) diff --git a/st2common/st2common/transport/consumers.py b/st2common/st2common/transport/consumers.py index 7f626f72a4..dd2f47cb55 100644 --- a/st2common/st2common/transport/consumers.py +++ b/st2common/st2common/transport/consumers.py @@ -25,12 +25,11 @@ from st2common.util import concurrency __all__ = [ - 'QueueConsumer', - 'StagedQueueConsumer', - 'ActionsQueueConsumer', - - 'MessageHandler', - 'StagedMessageHandler' + "QueueConsumer", + "StagedQueueConsumer", + "ActionsQueueConsumer", + "MessageHandler", + "StagedMessageHandler", ] LOG = logging.getLogger(__name__) @@ -47,7 +46,9 @@ def shutdown(self): self._dispatcher.shutdown() def get_consumers(self, Consumer, channel): - consumer = Consumer(queues=self._queues, accept=['pickle'], callbacks=[self.process]) + consumer = Consumer( + queues=self._queues, accept=["pickle"], callbacks=[self.process] + ) # use prefetch_count=1 for fair dispatch. This way workers that finish an item get the next # task and the work does not get queued behind any single large item. @@ -58,11 +59,15 @@ def get_consumers(self, Consumer, channel): def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) self._dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -71,7 +76,9 @@ def _process_message(self, body): try: self._handler.process(body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) class StagedQueueConsumer(QueueConsumer): @@ -82,11 +89,15 @@ class StagedQueueConsumer(QueueConsumer): def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) response = self._handler.pre_ack_process(body) self._dispatcher.dispatch(self._process_message, response) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -110,17 +121,21 @@ def __init__(self, connection, queues, handler): workflows_pool_size = cfg.CONF.actionrunner.workflows_pool_size actions_pool_size = cfg.CONF.actionrunner.actions_pool_size - self._workflows_dispatcher = BufferedDispatcher(dispatch_pool_size=workflows_pool_size, - name='workflows-dispatcher') - self._actions_dispatcher = BufferedDispatcher(dispatch_pool_size=actions_pool_size, - name='actions-dispatcher') + self._workflows_dispatcher = BufferedDispatcher( + dispatch_pool_size=workflows_pool_size, name="workflows-dispatcher" + ) + self._actions_dispatcher = BufferedDispatcher( + dispatch_pool_size=actions_pool_size, name="actions-dispatcher" + ) def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) - action_is_workflow = getattr(body, 'action_is_workflow', False) + action_is_workflow = getattr(body, "action_is_workflow", False) if action_is_workflow: # Use workflow dispatcher queue dispatcher = self._workflows_dispatcher @@ -131,7 +146,9 @@ def process(self, body, message): LOG.debug('Using BufferedDispatcher pool: "%s"', str(dispatcher)) dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -149,11 +166,15 @@ class VariableMessageQueueConsumer(QueueConsumer): def process(self, body, message): try: if not self._handler.message_types.get(type(body)): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) self._dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -164,12 +185,13 @@ class MessageHandler(object): message_type = None def __init__(self, connection, queues): - self._queue_consumer = self.get_queue_consumer(connection=connection, - queues=queues) + self._queue_consumer = self.get_queue_consumer( + connection=connection, queues=queues + ) self._consumer_thread = None def start(self, wait=False): - LOG.info('Starting %s...', self.__class__.__name__) + LOG.info("Starting %s...", self.__class__.__name__) self._consumer_thread = concurrency.spawn(self._queue_consumer.run) if wait: @@ -179,7 +201,7 @@ def wait(self): self._consumer_thread.wait() def shutdown(self): - LOG.info('Shutting down %s...', self.__class__.__name__) + LOG.info("Shutting down %s...", self.__class__.__name__) self._queue_consumer.shutdown() @abc.abstractmethod @@ -224,4 +246,6 @@ class VariableMessageHandler(MessageHandler): """ def get_queue_consumer(self, connection, queues): - return VariableMessageQueueConsumer(connection=connection, queues=queues, handler=self) + return VariableMessageQueueConsumer( + connection=connection, queues=queues, handler=self + ) diff --git a/st2common/st2common/transport/execution.py b/st2common/st2common/transport/execution.py index e35279ac71..5d2880fd6f 100644 --- a/st2common/st2common/transport/execution.py +++ b/st2common/st2common/transport/execution.py @@ -20,15 +20,14 @@ from st2common.transport import publishers __all__ = [ - 'ActionExecutionPublisher', - 'ActionExecutionOutputPublisher', - - 'get_queue', - 'get_output_queue' + "ActionExecutionPublisher", + "ActionExecutionOutputPublisher", + "get_queue", + "get_output_queue", ] -EXECUTION_XCHG = Exchange('st2.execution', type='topic') -EXECUTION_OUTPUT_XCHG = Exchange('st2.execution.output', type='topic') +EXECUTION_XCHG = Exchange("st2.execution", type="topic") +EXECUTION_OUTPUT_XCHG = Exchange("st2.execution.output", type="topic") class ActionExecutionPublisher(publishers.CUDPublisher): @@ -38,14 +37,26 @@ def __init__(self): class ActionExecutionOutputPublisher(publishers.CUDPublisher): def __init__(self): - super(ActionExecutionOutputPublisher, self).__init__(exchange=EXECUTION_OUTPUT_XCHG) + super(ActionExecutionOutputPublisher, self).__init__( + exchange=EXECUTION_OUTPUT_XCHG + ) def get_queue(name=None, routing_key=None, exclusive=False, auto_delete=False): - return Queue(name, EXECUTION_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) + return Queue( + name, + EXECUTION_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) def get_output_queue(name=None, routing_key=None, exclusive=False, auto_delete=False): - return Queue(name, EXECUTION_OUTPUT_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) + return Queue( + name, + EXECUTION_OUTPUT_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) diff --git a/st2common/st2common/transport/liveaction.py b/st2common/st2common/transport/liveaction.py index 97dd08400b..670c5ebb2e 100644 --- a/st2common/st2common/transport/liveaction.py +++ b/st2common/st2common/transport/liveaction.py @@ -21,23 +21,19 @@ from st2common.transport import publishers -__all__ = [ - 'LiveActionPublisher', +__all__ = ["LiveActionPublisher", "get_queue", "get_status_management_queue"] - 'get_queue', - 'get_status_management_queue' -] - -LIVEACTION_XCHG = Exchange('st2.liveaction', type='topic') -LIVEACTION_STATUS_MGMT_XCHG = Exchange('st2.liveaction.status', type='topic') +LIVEACTION_XCHG = Exchange("st2.liveaction", type="topic") +LIVEACTION_STATUS_MGMT_XCHG = Exchange("st2.liveaction.status", type="topic") class LiveActionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin): - def __init__(self): publishers.CUDPublisher.__init__(self, exchange=LIVEACTION_XCHG) - publishers.StatePublisherMixin.__init__(self, exchange=LIVEACTION_STATUS_MGMT_XCHG) + publishers.StatePublisherMixin.__init__( + self, exchange=LIVEACTION_STATUS_MGMT_XCHG + ) def get_queue(name, routing_key): diff --git a/st2common/st2common/transport/publishers.py b/st2common/st2common/transport/publishers.py index 7942fdfffe..202220acb1 100644 --- a/st2common/st2common/transport/publishers.py +++ b/st2common/st2common/transport/publishers.py @@ -25,16 +25,16 @@ from st2common.transport.connection_retry_wrapper import ConnectionRetryWrapper __all__ = [ - 'PoolPublisher', - 'SharedPoolPublishers', - 'CUDPublisher', - 'StatePublisherMixin' + "PoolPublisher", + "SharedPoolPublishers", + "CUDPublisher", + "StatePublisherMixin", ] -ANY_RK = '*' -CREATE_RK = 'create' -UPDATE_RK = 'update' -DELETE_RK = 'delete' +ANY_RK = "*" +CREATE_RK = "create" +UPDATE_RK = "update" +DELETE_RK = "delete" LOG = logging.getLogger(__name__) @@ -47,19 +47,21 @@ def __init__(self, urls=None): :type urls: ``list`` """ urls = urls or transport_utils.get_messaging_urls() - connection = transport_utils.get_connection(urls=urls, - connection_kwargs={'failover_strategy': - 'round-robin'}) + connection = transport_utils.get_connection( + urls=urls, connection_kwargs={"failover_strategy": "round-robin"} + ) self.pool = connection.Pool(limit=10) self.cluster_size = len(urls) def errback(self, exc, interval): - LOG.error('Rabbitmq connection error: %s', exc.message, exc_info=False) + LOG.error("Rabbitmq connection error: %s", exc.message, exc_info=False) - def publish(self, payload, exchange, routing_key=''): - with Timer(key='amqp.pool_publisher.publish_with_retries.' + exchange.name): + def publish(self, payload, exchange, routing_key=""): + with Timer(key="amqp.pool_publisher.publish_with_retries." + exchange.name): with self.pool.acquire(block=True) as connection: - retry_wrapper = ConnectionRetryWrapper(cluster_size=self.cluster_size, logger=LOG) + retry_wrapper = ConnectionRetryWrapper( + cluster_size=self.cluster_size, logger=LOG + ) def do_publish(connection, channel): # ProducerPool ends up creating it own ConnectionPool which ends up @@ -68,18 +70,18 @@ def do_publish(connection, channel): # Producer for each publish. producer = Producer(channel) kwargs = { - 'body': payload, - 'exchange': exchange, - 'routing_key': routing_key, - 'serializer': 'pickle', - 'content_encoding': 'utf-8' + "body": payload, + "exchange": exchange, + "routing_key": routing_key, + "serializer": "pickle", + "content_encoding": "utf-8", } retry_wrapper.ensured( connection=connection, obj=producer, to_ensure_func=producer.publish, - **kwargs + **kwargs, ) retry_wrapper.run(connection=connection, wrapped_callback=do_publish) @@ -91,6 +93,7 @@ class SharedPoolPublishers(object): server is usually the same. This sharing allows from the same PoolPublisher to be reused for publishing purposes. Sharing publishers leads to shared connections. """ + shared_publishers = {} def get_publisher(self, urls): @@ -99,7 +102,7 @@ def get_publisher(self, urls): # ordering in supplied list. urls_copy = copy.copy(urls) urls_copy.sort() - publisher_key = ''.join(urls_copy) + publisher_key = "".join(urls_copy) publisher = self.shared_publishers.get(publisher_key, None) if not publisher: # Use original urls here to preserve order. @@ -115,15 +118,15 @@ def __init__(self, exchange): self._exchange = exchange def publish_create(self, payload): - with Timer(key='amqp.publish.create'): + with Timer(key="amqp.publish.create"): self._publisher.publish(payload, self._exchange, CREATE_RK) def publish_update(self, payload): - with Timer(key='amqp.publish.update'): + with Timer(key="amqp.publish.update"): self._publisher.publish(payload, self._exchange, UPDATE_RK) def publish_delete(self, payload): - with Timer(key='amqp.publish.delete'): + with Timer(key="amqp.publish.delete"): self._publisher.publish(payload, self._exchange, DELETE_RK) @@ -135,6 +138,6 @@ def __init__(self, exchange): def publish_state(self, payload, state): if not state: - raise Exception('Unable to publish unassigned state.') - with Timer(key='amqp.publish.state'): + raise Exception("Unable to publish unassigned state.") + with Timer(key="amqp.publish.state"): self._state_publisher.publish(payload, self._state_exchange, state) diff --git a/st2common/st2common/transport/queues.py b/st2common/st2common/transport/queues.py index faf6d27fbf..f6f9bcb4ef 100644 --- a/st2common/st2common/transport/queues.py +++ b/st2common/st2common/transport/queues.py @@ -34,120 +34,109 @@ from st2common.transport import workflow __all__ = [ - 'ACTIONSCHEDULER_REQUEST_QUEUE', - - 'ACTIONRUNNER_WORK_QUEUE', - 'ACTIONRUNNER_CANCEL_QUEUE', - 'ACTIONRUNNER_PAUSE_QUEUE', - 'ACTIONRUNNER_RESUME_QUEUE', - - 'EXPORTER_WORK_QUEUE', - - 'NOTIFIER_ACTIONUPDATE_WORK_QUEUE', - - 'RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE', - - 'RULESENGINE_WORK_QUEUE', - - 'STREAM_ANNOUNCEMENT_WORK_QUEUE', - 'STREAM_EXECUTION_ALL_WORK_QUEUE', - 'STREAM_EXECUTION_UPDATE_WORK_QUEUE', - 'STREAM_LIVEACTION_WORK_QUEUE', - - 'WORKFLOW_EXECUTION_WORK_QUEUE', - 'WORKFLOW_EXECUTION_RESUME_QUEUE' + "ACTIONSCHEDULER_REQUEST_QUEUE", + "ACTIONRUNNER_WORK_QUEUE", + "ACTIONRUNNER_CANCEL_QUEUE", + "ACTIONRUNNER_PAUSE_QUEUE", + "ACTIONRUNNER_RESUME_QUEUE", + "EXPORTER_WORK_QUEUE", + "NOTIFIER_ACTIONUPDATE_WORK_QUEUE", + "RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE", + "RULESENGINE_WORK_QUEUE", + "STREAM_ANNOUNCEMENT_WORK_QUEUE", + "STREAM_EXECUTION_ALL_WORK_QUEUE", + "STREAM_EXECUTION_UPDATE_WORK_QUEUE", + "STREAM_LIVEACTION_WORK_QUEUE", + "WORKFLOW_EXECUTION_WORK_QUEUE", + "WORKFLOW_EXECUTION_RESUME_QUEUE", ] # Used by the action scheduler service ACTIONSCHEDULER_REQUEST_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.req', - routing_key=action_constants.LIVEACTION_STATUS_REQUESTED) + "st2.actionrunner.req", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED +) # Used by the action runner service ACTIONRUNNER_WORK_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.work', - routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED) + "st2.actionrunner.work", routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED +) ACTIONRUNNER_CANCEL_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.cancel', - routing_key=action_constants.LIVEACTION_STATUS_CANCELING) + "st2.actionrunner.cancel", routing_key=action_constants.LIVEACTION_STATUS_CANCELING +) ACTIONRUNNER_PAUSE_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.pause', - routing_key=action_constants.LIVEACTION_STATUS_PAUSING) + "st2.actionrunner.pause", routing_key=action_constants.LIVEACTION_STATUS_PAUSING +) ACTIONRUNNER_RESUME_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.resume', - routing_key=action_constants.LIVEACTION_STATUS_RESUMING) + "st2.actionrunner.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING +) # Used by the exporter service EXPORTER_WORK_QUEUE = execution.get_queue( - 'st2.exporter.work', - routing_key=publishers.UPDATE_RK) + "st2.exporter.work", routing_key=publishers.UPDATE_RK +) # Used by the notifier service NOTIFIER_ACTIONUPDATE_WORK_QUEUE = execution.get_queue( - 'st2.notifiers.execution.work', - routing_key=publishers.UPDATE_RK) + "st2.notifiers.execution.work", routing_key=publishers.UPDATE_RK +) # Used by the results tracker service RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE = actionexecutionstate.get_queue( - 'st2.resultstracker.work', - routing_key=publishers.CREATE_RK) + "st2.resultstracker.work", routing_key=publishers.CREATE_RK +) # Used by the rules engine service RULESENGINE_WORK_QUEUE = reactor.get_trigger_instances_queue( - name='st2.trigger_instances_dispatch.rules_engine', - routing_key='#') + name="st2.trigger_instances_dispatch.rules_engine", routing_key="#" +) # Used by the stream service STREAM_ANNOUNCEMENT_WORK_QUEUE = announcement.get_queue( - routing_key=publishers.ANY_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True +) STREAM_EXECUTION_ALL_WORK_QUEUE = execution.get_queue( - routing_key=publishers.ANY_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True +) STREAM_EXECUTION_UPDATE_WORK_QUEUE = execution.get_queue( - routing_key=publishers.UPDATE_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.UPDATE_RK, exclusive=True, auto_delete=True +) STREAM_LIVEACTION_WORK_QUEUE = Queue( None, liveaction.LIVEACTION_XCHG, routing_key=publishers.ANY_RK, exclusive=True, - auto_delete=True) + auto_delete=True, +) # TODO: Perhaps we should use pack.action name as routing key # so we can do more efficient filtering later, if needed STREAM_EXECUTION_OUTPUT_QUEUE = execution.get_output_queue( - name=None, - routing_key=publishers.CREATE_RK, - exclusive=True, - auto_delete=True) + name=None, routing_key=publishers.CREATE_RK, exclusive=True, auto_delete=True +) # Used by the workflow engine service WORKFLOW_EXECUTION_WORK_QUEUE = workflow.get_status_management_queue( - name='st2.workflow.work', - routing_key=action_constants.LIVEACTION_STATUS_REQUESTED) + name="st2.workflow.work", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED +) WORKFLOW_EXECUTION_RESUME_QUEUE = workflow.get_status_management_queue( - name='st2.workflow.resume', - routing_key=action_constants.LIVEACTION_STATUS_RESUMING) + name="st2.workflow.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING +) WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE = execution.get_queue( - 'st2.workflow.action.update', - routing_key=publishers.UPDATE_RK) + "st2.workflow.action.update", routing_key=publishers.UPDATE_RK +) diff --git a/st2common/st2common/transport/reactor.py b/st2common/st2common/transport/reactor.py index 613a1d08ed..c9dc84725c 100644 --- a/st2common/st2common/transport/reactor.py +++ b/st2common/st2common/transport/reactor.py @@ -22,26 +22,24 @@ from st2common.transport import publishers __all__ = [ - 'TriggerCUDPublisher', - 'TriggerInstancePublisher', - - 'TriggerDispatcher', - - 'get_sensor_cud_queue', - 'get_trigger_cud_queue', - 'get_trigger_instances_queue' + "TriggerCUDPublisher", + "TriggerInstancePublisher", + "TriggerDispatcher", + "get_sensor_cud_queue", + "get_trigger_cud_queue", + "get_trigger_instances_queue", ] LOG = logging.getLogger(__name__) # Exchange for Trigger CUD events -TRIGGER_CUD_XCHG = Exchange('st2.trigger', type='topic') +TRIGGER_CUD_XCHG = Exchange("st2.trigger", type="topic") # Exchange for TriggerInstance events -TRIGGER_INSTANCE_XCHG = Exchange('st2.trigger_instances_dispatch', type='topic') +TRIGGER_INSTANCE_XCHG = Exchange("st2.trigger_instances_dispatch", type="topic") # Exchane for Sensor CUD events -SENSOR_CUD_XCHG = Exchange('st2.sensor', type='topic') +SENSOR_CUD_XCHG = Exchange("st2.sensor", type="topic") class SensorCUDPublisher(publishers.CUDPublisher): @@ -96,14 +94,12 @@ def dispatch(self, trigger, payload=None, trace_context=None): assert isinstance(payload, (type(None), dict)) assert isinstance(trace_context, (type(None), TraceContext)) - payload = { - 'trigger': trigger, - 'payload': payload, - TRACE_CONTEXT: trace_context - } - routing_key = 'trigger_instance' + payload = {"trigger": trigger, "payload": payload, TRACE_CONTEXT: trace_context} + routing_key = "trigger_instance" - self._logger.debug('Dispatching trigger (trigger=%s,payload=%s)', trigger, payload) + self._logger.debug( + "Dispatching trigger (trigger=%s,payload=%s)", trigger, payload + ) self._publisher.publish_trigger(payload=payload, routing_key=routing_key) diff --git a/st2common/st2common/transport/utils.py b/st2common/st2common/transport/utils.py index bea2df1e57..e479713ddc 100644 --- a/st2common/st2common/transport/utils.py +++ b/st2common/st2common/transport/utils.py @@ -22,22 +22,18 @@ from st2common import log as logging -__all__ = [ - 'get_connection', - - 'get_messaging_urls' -] +__all__ = ["get_connection", "get_messaging_urls"] LOG = logging.getLogger(__name__) def get_messaging_urls(): - ''' + """ Determines the right messaging urls to supply. In case the `cluster_urls` config is specified then that is used. Else the single `url` property is used. :rtype: ``list`` - ''' + """ if cfg.CONF.messaging.cluster_urls: return cfg.CONF.messaging.cluster_urls return [cfg.CONF.messaging.url] @@ -57,33 +53,41 @@ def get_connection(urls=None, connection_kwargs=None): kwargs = {} - ssl_kwargs = _get_ssl_kwargs(ssl=cfg.CONF.messaging.ssl, - ssl_keyfile=cfg.CONF.messaging.ssl_keyfile, - ssl_certfile=cfg.CONF.messaging.ssl_certfile, - ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs, - login_method=cfg.CONF.messaging.login_method) + ssl_kwargs = _get_ssl_kwargs( + ssl=cfg.CONF.messaging.ssl, + ssl_keyfile=cfg.CONF.messaging.ssl_keyfile, + ssl_certfile=cfg.CONF.messaging.ssl_certfile, + ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs, + login_method=cfg.CONF.messaging.login_method, + ) # NOTE: "connection_kwargs" argument passed to this function has precedence over config values - if len(ssl_kwargs) == 1 and ssl_kwargs['ssl'] is True: - kwargs.update({'ssl': True}) + if len(ssl_kwargs) == 1 and ssl_kwargs["ssl"] is True: + kwargs.update({"ssl": True}) elif len(ssl_kwargs) >= 2: - ssl_kwargs.pop('ssl') - kwargs.update({'ssl': ssl_kwargs}) + ssl_kwargs.pop("ssl") + kwargs.update({"ssl": ssl_kwargs}) - kwargs['login_method'] = cfg.CONF.messaging.login_method + kwargs["login_method"] = cfg.CONF.messaging.login_method kwargs.update(connection_kwargs) # NOTE: This line contains no secret values so it's OK to log it - LOG.debug('Using SSL context for RabbitMQ connection: %s' % (ssl_kwargs)) + LOG.debug("Using SSL context for RabbitMQ connection: %s" % (ssl_kwargs)) connection = Connection(urls, **kwargs) return connection -def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, login_method=None): +def _get_ssl_kwargs( + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + login_method=None, +): """ Return SSL keyword arguments to be used with the kombu.Connection class. """ @@ -93,27 +97,27 @@ def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_req # because user could still specify to use SSL by including "?ssl=true" query param at the # end of the connection URL string if ssl is True: - ssl_kwargs['ssl'] = True + ssl_kwargs["ssl"] = True if ssl_keyfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['keyfile'] = ssl_keyfile + ssl_kwargs["ssl"] = True + ssl_kwargs["keyfile"] = ssl_keyfile if ssl_certfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['certfile'] = ssl_certfile + ssl_kwargs["ssl"] = True + ssl_kwargs["certfile"] = ssl_certfile if ssl_cert_reqs: - if ssl_cert_reqs == 'none': + if ssl_cert_reqs == "none": ssl_cert_reqs = ssl_lib.CERT_NONE - elif ssl_cert_reqs == 'optional': + elif ssl_cert_reqs == "optional": ssl_cert_reqs = ssl_lib.CERT_OPTIONAL - elif ssl_cert_reqs == 'required': + elif ssl_cert_reqs == "required": ssl_cert_reqs = ssl_lib.CERT_REQUIRED - ssl_kwargs['cert_reqs'] = ssl_cert_reqs + ssl_kwargs["cert_reqs"] = ssl_cert_reqs if ssl_ca_certs: - ssl_kwargs['ssl'] = True - ssl_kwargs['ca_certs'] = ssl_ca_certs + ssl_kwargs["ssl"] = True + ssl_kwargs["ca_certs"] = ssl_ca_certs return ssl_kwargs diff --git a/st2common/st2common/transport/workflow.py b/st2common/st2common/transport/workflow.py index 2b9815fcb7..0302611a36 100644 --- a/st2common/st2common/transport/workflow.py +++ b/st2common/st2common/transport/workflow.py @@ -21,22 +21,22 @@ from st2common.transport import publishers -__all__ = [ - 'WorkflowExecutionPublisher', +__all__ = ["WorkflowExecutionPublisher", "get_queue", "get_status_management_queue"] - 'get_queue', - 'get_status_management_queue' -] +WORKFLOW_EXECUTION_XCHG = kombu.Exchange("st2.workflow", type="topic") +WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange( + "st2.workflow.status", type="topic" +) -WORKFLOW_EXECUTION_XCHG = kombu.Exchange('st2.workflow', type='topic') -WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange('st2.workflow.status', type='topic') - - -class WorkflowExecutionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin): +class WorkflowExecutionPublisher( + publishers.CUDPublisher, publishers.StatePublisherMixin +): def __init__(self): publishers.CUDPublisher.__init__(self, exchange=WORKFLOW_EXECUTION_XCHG) - publishers.StatePublisherMixin.__init__(self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG) + publishers.StatePublisherMixin.__init__( + self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG + ) def get_queue(name, routing_key): @@ -44,4 +44,6 @@ def get_queue(name, routing_key): def get_status_management_queue(name, routing_key): - return kombu.Queue(name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key) + return kombu.Queue( + name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key + ) diff --git a/st2common/st2common/triggers.py b/st2common/st2common/triggers.py index a18dadedb9..ec0dba378e 100644 --- a/st2common/st2common/triggers.py +++ b/st2common/st2common/triggers.py @@ -22,52 +22,63 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.triggers import (INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER) +from st2common.constants.triggers import INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER from st2common.exceptions.db import StackStormDBObjectConflictError from st2common.services.triggers import create_trigger_type_db from st2common.services.triggers import create_shadow_trigger from st2common.services.triggers import get_trigger_type_db from st2common.models.system.common import ResourceReference -__all__ = [ - 'register_internal_trigger_types' -] +__all__ = ["register_internal_trigger_types"] LOG = logging.getLogger(__name__) def _register_internal_trigger_type(trigger_definition): try: - trigger_type_db = create_trigger_type_db(trigger_type=trigger_definition, - log_not_unique_error_as_debug=True) + trigger_type_db = create_trigger_type_db( + trigger_type=trigger_definition, log_not_unique_error_as_debug=True + ) except (NotUniqueError, StackStormDBObjectConflictError): # We ignore conflict error since this operation is idempotent and race is not an issue - LOG.debug('Internal trigger type "%s" already exists, ignoring error...' % - (trigger_definition['name'])) - - ref = ResourceReference.to_string_reference(name=trigger_definition['name'], - pack=trigger_definition['pack']) + LOG.debug( + 'Internal trigger type "%s" already exists, ignoring error...' + % (trigger_definition["name"]) + ) + + ref = ResourceReference.to_string_reference( + name=trigger_definition["name"], pack=trigger_definition["pack"] + ) trigger_type_db = get_trigger_type_db(ref) if trigger_type_db: - LOG.debug('Registered internal trigger: %s.', trigger_definition['name']) + LOG.debug("Registered internal trigger: %s.", trigger_definition["name"]) # trigger types with parameters do no require a shadow trigger. if trigger_type_db and not trigger_type_db.parameters_schema: try: - trigger_db = create_shadow_trigger(trigger_type_db, - log_not_unique_error_as_debug=True) - - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger created for parameter-less internal TriggerType. Trigger.id=%s' % - (trigger_db.id), extra=extra) + trigger_db = create_shadow_trigger( + trigger_type_db, log_not_unique_error_as_debug=True + ) + + extra = {"trigger_db": trigger_db} + LOG.audit( + "Trigger created for parameter-less internal TriggerType. Trigger.id=%s" + % (trigger_db.id), + extra=extra, + ) except (NotUniqueError, StackStormDBObjectConflictError): - LOG.debug('Shadow trigger "%s" already exists. Ignoring.', - trigger_type_db.get_reference().ref, exc_info=True) + LOG.debug( + 'Shadow trigger "%s" already exists. Ignoring.', + trigger_type_db.get_reference().ref, + exc_info=True, + ) except (ValidationError, ValueError): - LOG.exception('Validation failed in shadow trigger. TriggerType=%s.', - trigger_type_db.get_reference().ref) + LOG.exception( + "Validation failed in shadow trigger. TriggerType=%s.", + trigger_type_db.get_reference().ref, + ) raise return trigger_type_db @@ -89,16 +100,21 @@ def register_internal_trigger_types(): for _, trigger_definitions in six.iteritems(INTERNAL_TRIGGER_TYPES): for trigger_definition in trigger_definitions: - LOG.debug('Registering internal trigger: %s', trigger_definition['name']) + LOG.debug("Registering internal trigger: %s", trigger_definition["name"]) - is_action_trigger = trigger_definition['name'] == ACTION_SENSOR_TRIGGER['name'] + is_action_trigger = ( + trigger_definition["name"] == ACTION_SENSOR_TRIGGER["name"] + ) if is_action_trigger and not action_sensor_enabled: continue try: trigger_type_db = _register_internal_trigger_type( - trigger_definition=trigger_definition) + trigger_definition=trigger_definition + ) except Exception: - LOG.exception('Failed registering internal trigger: %s.', trigger_definition) + LOG.exception( + "Failed registering internal trigger: %s.", trigger_definition + ) raise else: registered_trigger_types_db.append(trigger_type_db) diff --git a/st2common/st2common/util/action_db.py b/st2common/st2common/util/action_db.py index 610b698c18..4880693348 100644 --- a/st2common/st2common/util/action_db.py +++ b/st2common/st2common/util/action_db.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -42,15 +43,15 @@ __all__ = [ - 'get_action_parameters_specs', - 'get_runnertype_by_id', - 'get_runnertype_by_name', - 'get_action_by_id', - 'get_action_by_ref', - 'get_liveaction_by_id', - 'update_liveaction_status', - 'serialize_positional_argument', - 'get_args' + "get_action_parameters_specs", + "get_runnertype_by_id", + "get_runnertype_by_name", + "get_action_by_id", + "get_action_by_ref", + "get_liveaction_by_id", + "update_liveaction_status", + "serialize_positional_argument", + "get_args", ] @@ -71,11 +72,11 @@ def get_action_parameters_specs(action_ref): if not action_db: return parameters - runner_type_name = action_db.runner_type['name'] + runner_type_name = action_db.runner_type["name"] runner_type_db = get_runnertype_by_name(runnertype_name=runner_type_name) # Runner type parameters should be added first before the action parameters. - parameters.update(runner_type_db['runner_parameters']) + parameters.update(runner_type_db["runner_parameters"]) parameters.update(action_db.parameters) return parameters @@ -83,60 +84,76 @@ def get_action_parameters_specs(action_ref): def get_runnertype_by_id(runnertype_id): """ - Get RunnerType by id. + Get RunnerType by id. - On error, raise StackStormDBObjectNotFoundError + On error, raise StackStormDBObjectNotFoundError """ try: runnertype = RunnerType.get_by_id(runnertype_id) except (ValueError, ValidationError) as e: - LOG.warning('Database lookup for runnertype with id="%s" resulted in ' - 'exception: %s', runnertype_id, e) - raise StackStormDBObjectNotFoundError('Unable to find runnertype with ' - 'id="%s"' % runnertype_id) + LOG.warning( + 'Database lookup for runnertype with id="%s" resulted in ' "exception: %s", + runnertype_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find runnertype with " 'id="%s"' % runnertype_id + ) return runnertype def get_runnertype_by_name(runnertype_name): """ - Get an runnertype by name. - On error, raise ST2ObjectNotFoundError. + Get an runnertype by name. + On error, raise ST2ObjectNotFoundError. """ try: runnertypes = RunnerType.query(name=runnertype_name) except (ValueError, ValidationError) as e: - LOG.error('Database lookup for name="%s" resulted in exception: %s', - runnertype_name, e) - raise StackStormDBObjectNotFoundError('Unable to find runnertype with name="%s"' - % runnertype_name) + LOG.error( + 'Database lookup for name="%s" resulted in exception: %s', + runnertype_name, + e, + ) + raise StackStormDBObjectNotFoundError( + 'Unable to find runnertype with name="%s"' % runnertype_name + ) if not runnertypes: - raise StackStormDBObjectNotFoundError('Unable to find RunnerType with name="%s"' - % runnertype_name) + raise StackStormDBObjectNotFoundError( + 'Unable to find RunnerType with name="%s"' % runnertype_name + ) if len(runnertypes) > 1: - LOG.warning('More than one RunnerType returned from DB lookup by name. ' - 'Result list is: %s', runnertypes) + LOG.warning( + "More than one RunnerType returned from DB lookup by name. " + "Result list is: %s", + runnertypes, + ) return runnertypes[0] def get_action_by_id(action_id): """ - Get Action by id. + Get Action by id. - On error, raise StackStormDBObjectNotFoundError + On error, raise StackStormDBObjectNotFoundError """ action = None try: action = Action.get_by_id(action_id) except (ValueError, ValidationError) as e: - LOG.warning('Database lookup for action with id="%s" resulted in ' - 'exception: %s', action_id, e) - raise StackStormDBObjectNotFoundError('Unable to find action with ' - 'id="%s"' % action_id) + LOG.warning( + 'Database lookup for action with id="%s" resulted in ' "exception: %s", + action_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find action with " 'id="%s"' % action_id + ) return action @@ -153,56 +170,78 @@ def get_action_by_ref(ref): try: return Action.get_by_ref(ref) except ValueError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None def get_liveaction_by_id(liveaction_id): """ - Get LiveAction by id. + Get LiveAction by id. - On error, raise ST2DBObjectNotFoundError. + On error, raise ST2DBObjectNotFoundError. """ liveaction = None try: liveaction = LiveAction.get_by_id(liveaction_id) except (ValidationError, ValueError) as e: - LOG.error('Database lookup for LiveAction with id="%s" resulted in ' - 'exception: %s', liveaction_id, e) - raise StackStormDBObjectNotFoundError('Unable to find LiveAction with ' - 'id="%s"' % liveaction_id) + LOG.error( + 'Database lookup for LiveAction with id="%s" resulted in ' "exception: %s", + liveaction_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find LiveAction with " 'id="%s"' % liveaction_id + ) return liveaction -def update_liveaction_status(status=None, result=None, context=None, end_timestamp=None, - liveaction_id=None, runner_info=None, liveaction_db=None, - publish=True): +def update_liveaction_status( + status=None, + result=None, + context=None, + end_timestamp=None, + liveaction_id=None, + runner_info=None, + liveaction_db=None, + publish=True, +): """ - Update the status of the specified LiveAction to the value provided in - new_status. + Update the status of the specified LiveAction to the value provided in + new_status. - The LiveAction may be specified using either liveaction_id, or as an - liveaction_db instance. + The LiveAction may be specified using either liveaction_id, or as an + liveaction_db instance. """ if (liveaction_id is None) and (liveaction_db is None): - raise ValueError('Must specify an liveaction_id or an liveaction_db when ' - 'calling update_LiveAction_status') + raise ValueError( + "Must specify an liveaction_id or an liveaction_db when " + "calling update_LiveAction_status" + ) if liveaction_db is None: liveaction_db = get_liveaction_by_id(liveaction_id) if status not in LIVEACTION_STATUSES: - raise ValueError('Attempting to set status for LiveAction "%s" ' - 'to unknown status string. Unknown status is "%s"' - % (liveaction_db, status)) + raise ValueError( + 'Attempting to set status for LiveAction "%s" ' + 'to unknown status string. Unknown status is "%s"' % (liveaction_db, status) + ) - if result and cfg.CONF.system.validate_output_schema and status == LIVEACTION_STATUS_SUCCEEDED: + if ( + result + and cfg.CONF.system.validate_output_schema + and status == LIVEACTION_STATUS_SUCCEEDED + ): action_db = get_action_by_ref(liveaction_db.action) - runner_db = get_runnertype_by_name(action_db.runner_type['name']) + runner_db = get_runnertype_by_name(action_db.runner_type["name"]) result, status = output_schema.validate_output( runner_db.output_schema, action_db.output_schema, @@ -214,21 +253,33 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta # If liveaction_db status is set then we need to decrement the counter # because it is transitioning to a new state if liveaction_db.status: - get_driver().dec_counter('action.executions.%s' % (liveaction_db.status)) + get_driver().dec_counter("action.executions.%s" % (liveaction_db.status)) # If status is provided then we need to increment the timer because the action # is transitioning into this new state if status: - get_driver().inc_counter('action.executions.%s' % (status)) + get_driver().inc_counter("action.executions.%s" % (status)) - extra = {'liveaction_db': liveaction_db} - LOG.debug('Updating ActionExection: "%s" with status="%s"', liveaction_db.id, status, - extra=extra) + extra = {"liveaction_db": liveaction_db} + LOG.debug( + 'Updating ActionExection: "%s" with status="%s"', + liveaction_db.id, + status, + extra=extra, + ) # If liveaction is already canceled, then do not allow status to be updated. - if liveaction_db.status == LIVEACTION_STATUS_CANCELED and status != LIVEACTION_STATUS_CANCELED: - LOG.info('Unable to update ActionExecution "%s" with status="%s". ' - 'ActionExecution is already canceled.', liveaction_db.id, status, extra=extra) + if ( + liveaction_db.status == LIVEACTION_STATUS_CANCELED + and status != LIVEACTION_STATUS_CANCELED + ): + LOG.info( + 'Unable to update ActionExecution "%s" with status="%s". ' + "ActionExecution is already canceled.", + liveaction_db.id, + status, + extra=extra, + ) return liveaction_db old_status = liveaction_db.status @@ -250,11 +301,11 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta # manipulated fields liveaction_db = LiveAction.add_or_update(liveaction_db) - LOG.debug('Updated status for LiveAction object.', extra=extra) + LOG.debug("Updated status for LiveAction object.", extra=extra) if publish and status != old_status: LiveAction.publish_status(liveaction_db) - LOG.debug('Published status for LiveAction object.', extra=extra) + LOG.debug("Published status for LiveAction object.", extra=extra) return liveaction_db @@ -267,9 +318,9 @@ def serialize_positional_argument(argument_type, argument_value): sense for shell script actions (only the outter / top level value is serialized). """ - if argument_type in ['string', 'number', 'float']: + if argument_type in ["string", "number", "float"]: if argument_value is None: - argument_value = six.text_type('') + argument_value = six.text_type("") return argument_value if isinstance(argument_value, (int, float)): @@ -277,25 +328,25 @@ def serialize_positional_argument(argument_type, argument_value): if not isinstance(argument_value, six.text_type): # cast string non-unicode values to unicode - argument_value = argument_value.decode('utf-8') - elif argument_type == 'boolean': + argument_value = argument_value.decode("utf-8") + elif argument_type == "boolean": # Booleans are serialized as string "1" and "0" if argument_value is not None: - argument_value = '1' if bool(argument_value) else '0' + argument_value = "1" if bool(argument_value) else "0" else: - argument_value = '' - elif argument_type in ['array', 'list']: + argument_value = "" + elif argument_type in ["array", "list"]: # Lists are serialized a comma delimited string (foo,bar,baz) - argument_value = ','.join(map(str, argument_value)) if argument_value else '' - elif argument_type == 'object': + argument_value = ",".join(map(str, argument_value)) if argument_value else "" + elif argument_type == "object": # Objects are serialized as JSON - argument_value = json.dumps(argument_value) if argument_value else '' - elif argument_type == 'null': + argument_value = json.dumps(argument_value) if argument_value else "" + elif argument_type == "null": # None / null is serialized as en empty string - argument_value = '' + argument_value = "" else: # Other values are simply cast to unicode string - argument_value = six.text_type(argument_value) if argument_value else '' + argument_value = six.text_type(argument_value) if argument_value else "" return argument_value @@ -315,12 +366,13 @@ def get_args(action_parameters, action_db): positional_args = [] positional_args_keys = set() for _, arg in six.iteritems(position_args_dict): - arg_type = action_db_parameters.get(arg, {}).get('type', None) + arg_type = action_db_parameters.get(arg, {}).get("type", None) # Perform serialization for positional arguments arg_value = action_parameters.get(arg, None) - arg_value = serialize_positional_argument(argument_type=arg_type, - argument_value=arg_value) + arg_value = serialize_positional_argument( + argument_type=arg_type, argument_value=arg_value + ) positional_args.append(arg_value) positional_args_keys.add(arg) @@ -340,7 +392,7 @@ def _get_position_arg_dict(action_parameters, action_db): for param in action_db_params: param_meta = action_db_params.get(param, None) if param_meta is not None: - pos = param_meta.get('position') + pos = param_meta.get("position") if pos is not None: args_dict[pos] = param args_dict = OrderedDict(sorted(args_dict.items())) diff --git a/st2common/st2common/util/actionalias_helpstring.py b/st2common/st2common/util/actionalias_helpstring.py index ddee088c8c..109328f926 100644 --- a/st2common/st2common/util/actionalias_helpstring.py +++ b/st2common/st2common/util/actionalias_helpstring.py @@ -18,9 +18,7 @@ from st2common.util.actionalias_matching import normalise_alias_format_string -__all__ = [ - 'generate_helpstring_result' -] +__all__ = ["generate_helpstring_result"] def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=0): @@ -44,7 +42,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= matches = [] count = 0 if not (isinstance(limit, int) and isinstance(offset, int)): - raise TypeError('limit or offset argument is not an integer') + raise TypeError("limit or offset argument is not an integer") for alias in aliases: # Skip disable aliases. if not alias.enabled: @@ -56,7 +54,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= display, _, _ = normalise_alias_format_string(format_) if display: # Skip help strings not containing keyword. - if not re.search(filter or '', display, flags=re.IGNORECASE): + if not re.search(filter or "", display, flags=re.IGNORECASE): continue # Skip over help strings not within the requested offset/limit range. if (offset == 0 and limit > 0) and count >= limit: @@ -65,14 +63,18 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= elif (offset > 0 and limit == 0) and count < offset: count += 1 continue - elif (offset > 0 and limit > 0) and (count < offset or count >= offset + limit): + elif (offset > 0 and limit > 0) and ( + count < offset or count >= offset + limit + ): count += 1 continue - matches.append({ - "pack": alias.pack, - "display": display, - "description": alias.description - }) + matches.append( + { + "pack": alias.pack, + "display": display, + "description": alias.description, + } + ) count += 1 return {"available": count, "helpstrings": matches} diff --git a/st2common/st2common/util/actionalias_matching.py b/st2common/st2common/util/actionalias_matching.py index 3827b12d93..1b20fad414 100644 --- a/st2common/st2common/util/actionalias_matching.py +++ b/st2common/st2common/util/actionalias_matching.py @@ -24,15 +24,15 @@ from st2common.models.utils.action_alias_utils import extract_parameters __all__ = [ - 'list_format_strings_from_aliases', - 'normalise_alias_format_string', - 'match_command_to_alias', - 'get_matching_alias', + "list_format_strings_from_aliases", + "normalise_alias_format_string", + "match_command_to_alias", + "get_matching_alias", ] def list_format_strings_from_aliases(aliases, match_multiple=False): - ''' + """ List patterns from a collection of alias objects :param aliases: The list of aliases @@ -40,34 +40,40 @@ def list_format_strings_from_aliases(aliases, match_multiple=False): :return: A description of potential execution patterns in a list of aliases. :rtype: ``list`` of ``list`` - ''' + """ patterns = [] for alias in aliases: for format_ in alias.formats: - display, representations, _match_multiple = normalise_alias_format_string(format_) + display, representations, _match_multiple = normalise_alias_format_string( + format_ + ) if display and len(representations) == 0: - patterns.append({ - 'alias': alias, - 'format': format_, - 'display': display, - 'representation': '', - }) - else: - patterns.extend([ + patterns.append( { - 'alias': alias, - 'format': format_, - 'display': display, - 'representation': representation, - 'match_multiple': _match_multiple, + "alias": alias, + "format": format_, + "display": display, + "representation": "", } - for representation in representations - ]) + ) + else: + patterns.extend( + [ + { + "alias": alias, + "format": format_, + "display": display, + "representation": representation, + "match_multiple": _match_multiple, + } + for representation in representations + ] + ) return patterns def normalise_alias_format_string(alias_format): - ''' + """ StackStorm action aliases come in two forms; 1. A string holding the format, which is also used as the help string. 2. A dictionary containing "display" and/or "representation" keys. @@ -80,7 +86,7 @@ def normalise_alias_format_string(alias_format): :return: The representation of the alias :rtype: ``tuple`` of (``str``, ``str``) - ''' + """ display = None representation = [] match_multiple = False @@ -89,14 +95,16 @@ def normalise_alias_format_string(alias_format): display = alias_format representation.append(alias_format) elif isinstance(alias_format, dict): - display = alias_format.get('display') - representation = alias_format.get('representation') or [] + display = alias_format.get("display") + representation = alias_format.get("representation") or [] if isinstance(representation, six.string_types): representation = [representation] - match_multiple = alias_format.get('match_multiple', match_multiple) + match_multiple = alias_format.get("match_multiple", match_multiple) else: - raise TypeError("alias_format '%s' is neither a dictionary or string type." - % repr(alias_format)) + raise TypeError( + "alias_format '%s' is neither a dictionary or string type." + % repr(alias_format) + ) return (display, representation, match_multiple) @@ -110,8 +118,9 @@ def match_command_to_alias(command, aliases, match_multiple=False): formats = list_format_strings_from_aliases([alias], match_multiple) for format_ in formats: try: - extract_parameters(format_str=format_['representation'], - param_stream=command) + extract_parameters( + format_str=format_["representation"], param_stream=command + ) except ParseException: continue @@ -125,35 +134,41 @@ def get_matching_alias(command): """ # 1. Get aliases action_alias_dbs = ActionAlias.query( - Q(formats__match_multiple=None) | Q(formats__match_multiple=False), - enabled=True) + Q(formats__match_multiple=None) | Q(formats__match_multiple=False), enabled=True + ) # 2. Match alias(es) to command matches = match_command_to_alias(command=command, aliases=action_alias_dbs) if len(matches) > 1: - raise ActionAliasAmbiguityException("Command '%s' matched more than 1 pattern" % - command, - matches=matches, - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched more than 1 pattern" % command, + matches=matches, + command=command, + ) elif len(matches) == 0: match_multiple_action_alias_dbs = ActionAlias.query( - formats__match_multiple=True, - enabled=True) + formats__match_multiple=True, enabled=True + ) - matches = match_command_to_alias(command=command, aliases=match_multiple_action_alias_dbs, - match_multiple=True) + matches = match_command_to_alias( + command=command, + aliases=match_multiple_action_alias_dbs, + match_multiple=True, + ) if len(matches) > 1: - raise ActionAliasAmbiguityException("Command '%s' matched more than 1 (multi) pattern" % - command, - matches=matches, - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched more than 1 (multi) pattern" % command, + matches=matches, + command=command, + ) if len(matches) == 0: - raise ActionAliasAmbiguityException("Command '%s' matched no patterns" % - command, - matches=[], - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched no patterns" % command, + matches=[], + command=command, + ) return matches[0] diff --git a/st2common/st2common/util/api.py b/st2common/st2common/util/api.py index 4e0e3f4938..2c378ad726 100644 --- a/st2common/st2common/util/api.py +++ b/st2common/st2common/util/api.py @@ -21,8 +21,8 @@ from st2common.util.url import get_url_without_trailing_slash __all__ = [ - 'get_base_public_api_url', - 'get_full_public_api_url', + "get_base_public_api_url", + "get_full_public_api_url", ] LOG = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def get_base_public_api_url(): api_url = get_url_without_trailing_slash(cfg.CONF.auth.api_url) else: LOG.warn('"auth.api_url" configuration option is not configured') - api_url = 'http://%s:%s' % (cfg.CONF.api.host, cfg.CONF.api.port) + api_url = "http://%s:%s" % (cfg.CONF.api.host, cfg.CONF.api.port) return api_url @@ -52,5 +52,5 @@ def get_full_public_api_url(api_version=DEFAULT_API_VERSION): :rtype: ``str`` """ api_url = get_base_public_api_url() - api_url = '%s/%s' % (api_url, api_version) + api_url = "%s/%s" % (api_url, api_version) return api_url diff --git a/st2common/st2common/util/argument_parser.py b/st2common/st2common/util/argument_parser.py index 28645ad15f..757f171661 100644 --- a/st2common/st2common/util/argument_parser.py +++ b/st2common/st2common/util/argument_parser.py @@ -16,9 +16,7 @@ from __future__ import absolute_import import argparse -__all__ = [ - 'generate_argument_parser_for_metadata' -] +__all__ = ["generate_argument_parser_for_metadata"] def generate_argument_parser_for_metadata(metadata): @@ -32,37 +30,37 @@ def generate_argument_parser_for_metadata(metadata): :return: Generated argument parser instance. :rtype: :class:`argparse.ArgumentParser` """ - parameters = metadata['parameters'] + parameters = metadata["parameters"] - parser = argparse.ArgumentParser(description=metadata['description']) + parser = argparse.ArgumentParser(description=metadata["description"]) for parameter_name, parameter_options in parameters.items(): - name = parameter_name.replace('_', '-') - description = parameter_options['description'] - _type = parameter_options['type'] - required = parameter_options.get('required', False) - default_value = parameter_options.get('default', None) - immutable = parameter_options.get('immutable', False) + name = parameter_name.replace("_", "-") + description = parameter_options["description"] + _type = parameter_options["type"] + required = parameter_options.get("required", False) + default_value = parameter_options.get("default", None) + immutable = parameter_options.get("immutable", False) # Immutable arguments can't be controlled by the user if immutable: continue - args = ['--%s' % (name)] - kwargs = {'help': description, 'required': required} + args = ["--%s" % (name)] + kwargs = {"help": description, "required": required} if default_value is not None: - kwargs['default'] = default_value + kwargs["default"] = default_value - if _type == 'string': - kwargs['type'] = str - elif _type == 'integer': - kwargs['type'] = int - elif _type == 'boolean': + if _type == "string": + kwargs["type"] = str + elif _type == "integer": + kwargs["type"] = int + elif _type == "boolean": if default_value is False: - kwargs['action'] = 'store_false' + kwargs["action"] = "store_false" else: - kwargs['action'] = 'store_true' + kwargs["action"] = "store_true" parser.add_argument(*args, **kwargs) diff --git a/st2common/st2common/util/auth.py b/st2common/st2common/util/auth.py index 38294c92a7..90e81d938e 100644 --- a/st2common/st2common/util/auth.py +++ b/st2common/st2common/util/auth.py @@ -28,11 +28,11 @@ from st2common.util import hash as hash_utils __all__ = [ - 'validate_token', - 'validate_token_and_source', - 'generate_api_key', - 'validate_api_key', - 'validate_api_key_and_source' + "validate_token", + "validate_token_and_source", + "generate_api_key", + "validate_api_key", + "validate_api_key_and_source", ] LOG = logging.getLogger(__name__) @@ -53,7 +53,7 @@ def validate_token(token_string): if token.expiry <= date_utils.get_datetime_utc_now(): # TODO: purge expired tokens LOG.audit('Token with id "%s" has expired.' % (token.id)) - raise exceptions.TokenExpiredError('Token has expired.') + raise exceptions.TokenExpiredError("Token has expired.") LOG.audit('Token with id "%s" is validated.' % (token.id)) @@ -74,14 +74,14 @@ def validate_token_and_source(token_in_headers, token_in_query_params): :rtype: :class:`.TokenDB` """ if not token_in_headers and not token_in_query_params: - LOG.audit('Token is not found in header or query parameters.') - raise exceptions.TokenNotProvidedError('Token is not provided.') + LOG.audit("Token is not found in header or query parameters.") + raise exceptions.TokenNotProvidedError("Token is not provided.") if token_in_headers: - LOG.audit('Token provided in headers') + LOG.audit("Token provided in headers") if token_in_query_params: - LOG.audit('Token provided in query parameters') + LOG.audit("Token provided in query parameters") return validate_token(token_in_headers or token_in_query_params) @@ -103,7 +103,8 @@ def generate_api_key(): base64_encoded = base64.b64encode( six.b(hashed_seed), - six.b(random.choice(['rA', 'aZ', 'gQ', 'hH', 'hG', 'aR', 'DD']))).rstrip(b'==') + six.b(random.choice(["rA", "aZ", "gQ", "hH", "hG", "aR", "DD"])), + ).rstrip(b"==") base64_encoded = base64_encoded.decode() return base64_encoded @@ -127,7 +128,7 @@ def validate_api_key(api_key): api_key_db = ApiKey.get(api_key) if not api_key_db.enabled: - raise exceptions.ApiKeyDisabledError('API key is disabled.') + raise exceptions.ApiKeyDisabledError("API key is disabled.") LOG.audit('API key with id "%s" is validated.' % (api_key_db.id)) @@ -148,13 +149,13 @@ def validate_api_key_and_source(api_key_in_headers, api_key_query_params): :rtype: :class:`.ApiKeyDB` """ if not api_key_in_headers and not api_key_query_params: - LOG.audit('API key is not found in header or query parameters.') - raise exceptions.ApiKeyNotProvidedError('API key is not provided.') + LOG.audit("API key is not found in header or query parameters.") + raise exceptions.ApiKeyNotProvidedError("API key is not provided.") if api_key_in_headers: - LOG.audit('API key provided in headers') + LOG.audit("API key provided in headers") if api_key_query_params: - LOG.audit('API key provided in query parameters') + LOG.audit("API key provided in query parameters") return validate_api_key(api_key_in_headers or api_key_query_params) diff --git a/st2common/st2common/util/casts.py b/st2common/st2common/util/casts.py index fa94272e47..aadad8a4a1 100644 --- a/st2common/st2common/util/casts.py +++ b/st2common/st2common/util/casts.py @@ -89,12 +89,12 @@ def _cast_none(x): # These types as they appear in json schema. CASTS = { - 'array': _cast_object, - 'boolean': _cast_boolean, - 'integer': _cast_integer, - 'number': _cast_number, - 'object': _cast_object, - 'string': _cast_string + "array": _cast_object, + "boolean": _cast_boolean, + "integer": _cast_integer, + "number": _cast_number, + "object": _cast_object, + "string": _cast_string, } diff --git a/st2common/st2common/util/compat.py b/st2common/st2common/util/compat.py index 9288f5f3a0..1926f97dba 100644 --- a/st2common/st2common/util/compat.py +++ b/st2common/st2common/util/compat.py @@ -24,16 +24,15 @@ __all__ = [ - 'mock_open_name', - - 'to_unicode', - 'to_ascii', + "mock_open_name", + "to_unicode", + "to_ascii", ] if six.PY3: - mock_open_name = 'builtins.open' + mock_open_name = "builtins.open" else: - mock_open_name = '__builtin__.open' + mock_open_name = "__builtin__.open" def to_unicode(value): @@ -63,4 +62,4 @@ def to_ascii(value): if six.PY3: value = value.encode() - return value.decode('ascii', errors='ignore') + return value.decode("ascii", errors="ignore") diff --git a/st2common/st2common/util/concurrency.py b/st2common/st2common/util/concurrency.py index 50312fa78f..239407ade0 100644 --- a/st2common/st2common/util/concurrency.py +++ b/st2common/st2common/util/concurrency.py @@ -31,34 +31,30 @@ except ImportError: gevent = None -CONCURRENCY_LIBRARY = 'eventlet' +CONCURRENCY_LIBRARY = "eventlet" __all__ = [ - 'set_concurrency_library', - 'get_concurrency_library', - - 'get_subprocess_module', - 'subprocess_popen', - - 'spawn', - 'wait', - 'cancel', - 'kill', - 'sleep', - - 'get_greenlet_exit_exception_class', - - 'get_green_pool_class', - 'is_green_pool_free', - 'green_pool_wait_all' + "set_concurrency_library", + "get_concurrency_library", + "get_subprocess_module", + "subprocess_popen", + "spawn", + "wait", + "cancel", + "kill", + "sleep", + "get_greenlet_exit_exception_class", + "get_green_pool_class", + "is_green_pool_free", + "green_pool_wait_all", ] def set_concurrency_library(library): global CONCURRENCY_LIBRARY - if library not in ['eventlet', 'gevent']: - raise ValueError('Unsupported concurrency library: %s' % (library)) + if library not in ["eventlet", "gevent"]: + raise ValueError("Unsupported concurrency library: %s" % (library)) CONCURRENCY_LIBRARY = library @@ -69,107 +65,111 @@ def get_concurrency_library(): def get_subprocess_module(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": from eventlet.green import subprocess # pylint: disable=import-error + return subprocess - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": from gevent import subprocess # pylint: disable=import-error + return subprocess def subprocess_popen(*args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": from eventlet.green import subprocess # pylint: disable=import-error + return subprocess.Popen(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": from gevent import subprocess # pylint: disable=import-error + return subprocess.Popen(*args, **kwargs) def spawn(func, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.spawn(func, *args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.spawn(func, *args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def wait(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.wait(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.join(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def cancel(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.cancel(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.kill(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def kill(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.kill(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.kill(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def sleep(*args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.sleep(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.sleep(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def get_greenlet_exit_exception_class(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.support.greenlets.GreenletExit - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.GreenletExit else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def get_green_pool_class(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.GreenPool - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.pool.Pool else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def is_green_pool_free(pool): """ Return True if the provided green pool is free, False otherwise. """ - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return pool.free() - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return not pool.full() else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def green_pool_wait_all(pool): """ Wait for all the green threads in the pool to finish. """ - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return pool.waitall() - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": # NOTE: This mimicks eventlet.waitall() functionallity better than # pool.join() return all(gl.ready() for gl in pool.greenlets) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") diff --git a/st2common/st2common/util/config_loader.py b/st2common/st2common/util/config_loader.py index 620707e643..30db039bdc 100644 --- a/st2common/st2common/util/config_loader.py +++ b/st2common/st2common/util/config_loader.py @@ -30,9 +30,7 @@ from st2common.util.config_parser import ContentPackConfigParser from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ContentPackConfigLoader' -] +__all__ = ["ContentPackConfigLoader"] LOG = logging.getLogger(__name__) @@ -79,15 +77,16 @@ def get_config(self): # 2. Retrieve values from "global" pack config file (if available) and resolve them if # necessary - config = self._get_values_for_config(config_schema_db=config_schema_db, - config_db=config_db) + config = self._get_values_for_config( + config_schema_db=config_schema_db, config_db=config_db + ) result.update(config) return result def _get_values_for_config(self, config_schema_db, config_db): - schema_values = getattr(config_schema_db, 'attributes', {}) - config_values = getattr(config_db, 'values', {}) + schema_values = getattr(config_schema_db, "attributes", {}) + config_values = getattr(config_db, "values", {}) config = copy.deepcopy(config_values or {}) @@ -131,24 +130,34 @@ def _assign_dynamic_config_values(self, schema, config, parent_keys=None): # Inspect nested object properties if is_dictionary: parent_keys += [str(config_item_key)] - self._assign_dynamic_config_values(schema=schema_item.get('properties', {}), - config=config[config_item_key], - parent_keys=parent_keys) + self._assign_dynamic_config_values( + schema=schema_item.get("properties", {}), + config=config[config_item_key], + parent_keys=parent_keys, + ) # Inspect nested list items elif is_list: parent_keys += [str(config_item_key)] - self._assign_dynamic_config_values(schema=schema_item.get('items', {}), - config=config[config_item_key], - parent_keys=parent_keys) + self._assign_dynamic_config_values( + schema=schema_item.get("items", {}), + config=config[config_item_key], + parent_keys=parent_keys, + ) else: - is_jinja_expression = jinja_utils.is_jinja_expression(value=config_item_value) + is_jinja_expression = jinja_utils.is_jinja_expression( + value=config_item_value + ) if is_jinja_expression: # Resolve / render the Jinja template expression - full_config_item_key = '.'.join(parent_keys + [str(config_item_key)]) - value = self._get_datastore_value_for_expression(key=full_config_item_key, + full_config_item_key = ".".join( + parent_keys + [str(config_item_key)] + ) + value = self._get_datastore_value_for_expression( + key=full_config_item_key, value=config_item_value, - config_schema_item=schema_item) + config_schema_item=schema_item, + ) config[config_item_key] = value else: @@ -167,12 +176,12 @@ def _assign_default_values(self, schema, config): :rtype: ``dict`` """ for schema_item_key, schema_item in six.iteritems(schema): - has_default_value = 'default' in schema_item + has_default_value = "default" in schema_item has_config_value = schema_item_key in config - default_value = schema_item.get('default', None) - is_object = schema_item.get('type', None) == 'object' - has_properties = schema_item.get('properties', None) + default_value = schema_item.get("default", None) + is_object = schema_item.get("type", None) == "object" + has_properties = schema_item.get("properties", None) if has_default_value and not has_config_value: # Config value is not provided, but default value is, use a default value @@ -183,8 +192,9 @@ def _assign_default_values(self, schema, config): if not config.get(schema_item_key, None): config[schema_item_key] = {} - self._assign_default_values(schema=schema_item['properties'], - config=config[schema_item_key]) + self._assign_default_values( + schema=schema_item["properties"], config=config[schema_item_key] + ) return config @@ -198,18 +208,21 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non from st2common.services.config import deserialize_key_value config_schema_item = config_schema_item or {} - secret = config_schema_item.get('secret', False) + secret = config_schema_item.get("secret", False) try: - value = render_template_with_system_and_user_context(value=value, - user=self.user) + value = render_template_with_system_and_user_context( + value=value, user=self.user + ) except Exception as e: # Throw a more user-friendly exception on failed render exc_class = type(e) original_msg = six.text_type(e) - msg = ('Failed to render dynamic configuration value for key "%s" with value ' - '"%s" for pack "%s" config: %s %s ' % (key, value, self.pack_name, - exc_class, original_msg)) + msg = ( + 'Failed to render dynamic configuration value for key "%s" with value ' + '"%s" for pack "%s" config: %s %s ' + % (key, value, self.pack_name, exc_class, original_msg) + ) raise RuntimeError(msg) if value: @@ -222,21 +235,17 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non def get_config(pack, user): - """Returns config for given pack and user. - """ + """Returns config for given pack and user.""" LOG.debug('Attempting to get config for pack "%s" and user "%s"' % (pack, user)) if pack and user: - LOG.debug('Pack and user found. Loading config.') - config_loader = ContentPackConfigLoader( - pack_name=pack, - user=user - ) + LOG.debug("Pack and user found. Loading config.") + config_loader = ContentPackConfigLoader(pack_name=pack, user=user) config = config_loader.get_config() else: config = {} - LOG.debug('Config: %s', config) + LOG.debug("Config: %s", config) return config diff --git a/st2common/st2common/util/config_parser.py b/st2common/st2common/util/config_parser.py index 247dca88fa..40c9e30313 100644 --- a/st2common/st2common/util/config_parser.py +++ b/st2common/st2common/util/config_parser.py @@ -21,10 +21,7 @@ from st2common.content import utils -__all__ = [ - 'ContentPackConfigParser', - 'ContentPackConfig' -] +__all__ = ["ContentPackConfigParser", "ContentPackConfig"] class ContentPackConfigParser(object): @@ -32,8 +29,8 @@ class ContentPackConfigParser(object): Class responsible for obtaining and parsing content pack configs. """ - GLOBAL_CONFIG_NAME = 'config.yaml' - LOCAL_CONFIG_SUFFIX = '_config.yaml' + GLOBAL_CONFIG_NAME = "config.yaml" + LOCAL_CONFIG_SUFFIX = "_config.yaml" def __init__(self, pack_name): self.pack_name = pack_name @@ -85,8 +82,7 @@ def get_global_config_path(self): if not self.pack_path: return None - global_config_path = os.path.join(self.pack_path, - self.GLOBAL_CONFIG_NAME) + global_config_path = os.path.join(self.pack_path, self.GLOBAL_CONFIG_NAME) return global_config_path @classmethod @@ -95,7 +91,7 @@ def get_and_parse_config(cls, config_path): return None if os.path.exists(config_path) and os.path.isfile(config_path): - with io.open(config_path, 'r', encoding='utf8') as fp: + with io.open(config_path, "r", encoding="utf8") as fp: config = yaml.safe_load(fp.read()) return ContentPackConfig(file_path=config_path, config=config) diff --git a/st2common/st2common/util/crypto.py b/st2common/st2common/util/crypto.py index 230c4ada8e..d01e20557b 100644 --- a/st2common/st2common/util/crypto.py +++ b/st2common/st2common/util/crypto.py @@ -51,23 +51,18 @@ from cryptography.hazmat.backends import default_backend __all__ = [ - 'KEYCZAR_HEADER_SIZE', - 'KEYCZAR_AES_BLOCK_SIZE', - 'KEYCZAR_HLEN', - - 'read_crypto_key', - - 'symmetric_encrypt', - 'symmetric_decrypt', - - 'cryptography_symmetric_encrypt', - 'cryptography_symmetric_decrypt', - + "KEYCZAR_HEADER_SIZE", + "KEYCZAR_AES_BLOCK_SIZE", + "KEYCZAR_HLEN", + "read_crypto_key", + "symmetric_encrypt", + "symmetric_decrypt", + "cryptography_symmetric_encrypt", + "cryptography_symmetric_decrypt", # NOTE: Keyczar functions are here for testing reasons - they are only used by tests - 'keyczar_symmetric_encrypt', - 'keyczar_symmetric_decrypt', - - 'AESKey' + "keyczar_symmetric_encrypt", + "keyczar_symmetric_decrypt", + "AESKey", ] # Keyczar related constants @@ -94,13 +89,19 @@ class AESKey(object): mode = None size = None - def __init__(self, aes_key_string, hmac_key_string, hmac_key_size, mode='CBC', - size=DEFAULT_AES_KEY_SIZE): - if mode not in ['CBC']: - raise ValueError('Unsupported mode: %s' % (mode)) + def __init__( + self, + aes_key_string, + hmac_key_string, + hmac_key_size, + mode="CBC", + size=DEFAULT_AES_KEY_SIZE, + ): + if mode not in ["CBC"]: + raise ValueError("Unsupported mode: %s" % (mode)) if size < MINIMUM_AES_KEY_SIZE: - raise ValueError('Unsafe key size: %s' % (size)) + raise ValueError("Unsafe key size: %s" % (size)) self.aes_key_string = aes_key_string self.hmac_key_string = hmac_key_string @@ -121,7 +122,7 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE): :rtype: :class:`AESKey` """ if key_size < MINIMUM_AES_KEY_SIZE: - raise ValueError('Unsafe key size: %s' % (key_size)) + raise ValueError("Unsafe key size: %s" % (key_size)) aes_key_bytes = os.urandom(int(key_size / 8)) aes_key_string = Base64WSEncode(aes_key_bytes) @@ -129,8 +130,13 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE): hmac_key_bytes = os.urandom(int(key_size / 8)) hmac_key_string = Base64WSEncode(hmac_key_bytes) - return AESKey(aes_key_string=aes_key_string, hmac_key_string=hmac_key_string, - hmac_key_size=key_size, mode='CBC', size=key_size) + return AESKey( + aes_key_string=aes_key_string, + hmac_key_string=hmac_key_string, + hmac_key_size=key_size, + mode="CBC", + size=key_size, + ) def to_json(self): """ @@ -140,19 +146,22 @@ def to_json(self): :rtype: ``str`` """ data = { - 'hmacKey': { - 'hmacKeyString': self.hmac_key_string, - 'size': self.hmac_key_size + "hmacKey": { + "hmacKeyString": self.hmac_key_string, + "size": self.hmac_key_size, }, - 'aesKeyString': self.aes_key_string, - 'mode': self.mode.upper(), - 'size': int(self.size) + "aesKeyString": self.aes_key_string, + "mode": self.mode.upper(), + "size": int(self.size), } return json.dumps(data) def __repr__(self): - return ('' % (self.hmac_key_size, self.mode, - self.size)) + return "" % ( + self.hmac_key_size, + self.mode, + self.size, + ) def read_crypto_key(key_path): @@ -164,17 +173,19 @@ def read_crypto_key(key_path): :rtype: :class:`AESKey` """ - with open(key_path, 'r') as fp: + with open(key_path, "r") as fp: content = fp.read() content = json.loads(content) try: - aes_key = AESKey(aes_key_string=content['aesKeyString'], - hmac_key_string=content['hmacKey']['hmacKeyString'], - hmac_key_size=content['hmacKey']['size'], - mode=content['mode'].upper(), - size=content['size']) + aes_key = AESKey( + aes_key_string=content["aesKeyString"], + hmac_key_string=content["hmacKey"]["hmacKeyString"], + hmac_key_size=content["hmacKey"]["size"], + mode=content["mode"].upper(), + size=content["size"], + ) except KeyError as e: msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e)) raise KeyError(msg) @@ -187,7 +198,9 @@ def symmetric_encrypt(encrypt_key, plaintext): def symmetric_decrypt(decrypt_key, ciphertext): - return cryptography_symmetric_decrypt(decrypt_key=decrypt_key, ciphertext=ciphertext) + return cryptography_symmetric_decrypt( + decrypt_key=decrypt_key, ciphertext=ciphertext + ) def cryptography_symmetric_encrypt(encrypt_key, plaintext): @@ -206,9 +219,12 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): NOTE: Header itself is unused, but it's added so the format is compatible with keyczar format. """ - assert isinstance(encrypt_key, AESKey), 'encrypt_key needs to be AESKey class instance' - assert isinstance(plaintext, (six.text_type, six.string_types, six.binary_type)), \ - 'plaintext needs to either be a string/unicode or bytes' + assert isinstance( + encrypt_key, AESKey + ), "encrypt_key needs to be AESKey class instance" + assert isinstance( + plaintext, (six.text_type, six.string_types, six.binary_type) + ), "plaintext needs to either be a string/unicode or bytes" aes_key_bytes = encrypt_key.aes_key_bytes hmac_key_bytes = encrypt_key.hmac_key_bytes @@ -218,7 +234,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): if isinstance(plaintext, (six.text_type, six.string_types)): # Convert data to bytes - data = plaintext.encode('utf-8') + data = plaintext.encode("utf-8") else: data = plaintext @@ -234,7 +250,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): # NOTE: We don't care about actual Keyczar header value, we only care about the length (5 # bytes) so we simply add 5 0's - header_bytes = b'00000' + header_bytes = b"00000" ciphertext_bytes = encryptor.update(data) + encryptor.finalize() msg_bytes = header_bytes + iv_bytes + ciphertext_bytes @@ -263,9 +279,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): NOTE 2: This function is loosely based on keyczar AESKey.Decrypt() (Apache 2.0 license). """ - assert isinstance(decrypt_key, AESKey), 'decrypt_key needs to be AESKey class instance' - assert isinstance(ciphertext, (six.text_type, six.string_types, six.binary_type)), \ - 'ciphertext needs to either be a string/unicode or bytes' + assert isinstance( + decrypt_key, AESKey + ), "decrypt_key needs to be AESKey class instance" + assert isinstance( + ciphertext, (six.text_type, six.string_types, six.binary_type) + ), "ciphertext needs to either be a string/unicode or bytes" aes_key_bytes = decrypt_key.aes_key_bytes hmac_key_bytes = decrypt_key.hmac_key_bytes @@ -280,10 +299,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): # Verify ciphertext contains IV + HMAC signature if len(data_bytes) < (KEYCZAR_AES_BLOCK_SIZE + KEYCZAR_HLEN): - raise ValueError('Invalid or malformed ciphertext (too short)') + raise ValueError("Invalid or malformed ciphertext (too short)") iv_bytes = data_bytes[:KEYCZAR_AES_BLOCK_SIZE] # first block is IV - ciphertext_bytes = data_bytes[KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN] # strip IV and signature + ciphertext_bytes = data_bytes[ + KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN + ] # strip IV and signature signature_bytes = data_bytes[-KEYCZAR_HLEN:] # last 20 bytes are signature # Verify HMAC signature @@ -302,6 +323,7 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): decrypted = pkcs5_unpad(decrypted) return decrypted + ### # NOTE: Those methods below are deprecated and only used for testing purposes ## @@ -329,11 +351,12 @@ def keyczar_symmetric_encrypt(encrypt_key, plaintext): from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error from keyczar.keyinfo import GetMode # pylint: disable=import-error - encrypt_key = KeyczarAesKey(encrypt_key.aes_key_string, - KeyczarHmacKey(encrypt_key.hmac_key_string, - encrypt_key.hmac_key_size), - encrypt_key.size, - GetMode(encrypt_key.mode)) + encrypt_key = KeyczarAesKey( + encrypt_key.aes_key_string, + KeyczarHmacKey(encrypt_key.hmac_key_string, encrypt_key.hmac_key_size), + encrypt_key.size, + GetMode(encrypt_key.mode), + ) return binascii.hexlify(encrypt_key.Encrypt(plaintext)).upper() @@ -356,11 +379,12 @@ def keyczar_symmetric_decrypt(decrypt_key, ciphertext): from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error from keyczar.keyinfo import GetMode # pylint: disable=import-error - decrypt_key = KeyczarAesKey(decrypt_key.aes_key_string, - KeyczarHmacKey(decrypt_key.hmac_key_string, - decrypt_key.hmac_key_size), - decrypt_key.size, - GetMode(decrypt_key.mode)) + decrypt_key = KeyczarAesKey( + decrypt_key.aes_key_string, + KeyczarHmacKey(decrypt_key.hmac_key_string, decrypt_key.hmac_key_size), + decrypt_key.size, + GetMode(decrypt_key.mode), + ) return decrypt_key.Decrypt(binascii.unhexlify(ciphertext)) @@ -370,7 +394,7 @@ def pkcs5_pad(data): Pad data using PKCS5 """ pad = KEYCZAR_AES_BLOCK_SIZE - len(data) % KEYCZAR_AES_BLOCK_SIZE - data = data + pad * chr(pad).encode('utf-8') + data = data + pad * chr(pad).encode("utf-8") return data @@ -380,7 +404,7 @@ def pkcs5_unpad(data): """ if isinstance(data, six.binary_type): # Make sure we are operating with a string type - data = data.decode('utf-8') + data = data.decode("utf-8") pad = ord(data[-1]) data = data[:-pad] @@ -404,9 +428,9 @@ def Base64WSEncode(s): """ if isinstance(s, six.text_type): # Make sure input string is always converted to bytes (if not already) - s = s.encode('utf-8') + s = s.encode("utf-8") - return base64.urlsafe_b64encode(s).decode('utf-8').replace("=", "") + return base64.urlsafe_b64encode(s).decode("utf-8").replace("=", "") def Base64WSDecode(s): @@ -427,12 +451,12 @@ def Base64WSDecode(s): NOTE: Taken from keyczar (Apache 2.0 license) """ - s = ''.join(s.splitlines()) + s = "".join(s.splitlines()) s = str(s.replace(" ", "")) # kill whitespace, make string (not unicode) d = len(s) % 4 if d == 1: - raise ValueError('Base64 decoding errors') + raise ValueError("Base64 decoding errors") elif d == 2: s += "==" elif d == 3: @@ -442,4 +466,4 @@ def Base64WSDecode(s): return base64.urlsafe_b64decode(s) except TypeError as e: # Decoding raises TypeError if s contains invalid characters. - raise ValueError('Base64 decoding error: %s' % (six.text_type(e))) + raise ValueError("Base64 decoding error: %s" % (six.text_type(e))) diff --git a/st2common/st2common/util/date.py b/st2common/st2common/util/date.py index 979c3e8eb3..8df0e4659f 100644 --- a/st2common/st2common/util/date.py +++ b/st2common/st2common/util/date.py @@ -24,12 +24,7 @@ import dateutil.parser -__all__ = [ - 'get_datetime_utc_now', - 'add_utc_tz', - 'convert_to_utc', - 'parse' -] +__all__ = ["get_datetime_utc_now", "add_utc_tz", "convert_to_utc", "parse"] def get_datetime_utc_now(): @@ -45,14 +40,14 @@ def get_datetime_utc_now(): def append_milliseconds_to_time(date, millis): """ - Return time UTC datetime object offset by provided milliseconds. + Return time UTC datetime object offset by provided milliseconds. """ return convert_to_utc(date + datetime.timedelta(milliseconds=millis)) def add_utc_tz(dt): if dt.tzinfo and dt.tzinfo.utcoffset(dt) != datetime.timedelta(0): - raise ValueError('datetime already contains a non UTC timezone') + raise ValueError("datetime already contains a non UTC timezone") return dt.replace(tzinfo=dateutil.tz.tzutc()) diff --git a/st2common/st2common/util/debugging.py b/st2common/st2common/util/debugging.py index dd5d74d2a2..66abbbe1ad 100644 --- a/st2common/st2common/util/debugging.py +++ b/st2common/st2common/util/debugging.py @@ -25,11 +25,7 @@ from st2common.logging.misc import set_log_level_for_all_loggers -__all__ = [ - 'enable_debugging', - 'disable_debugging', - 'is_enabled' -] +__all__ = ["enable_debugging", "disable_debugging", "is_enabled"] ENABLE_DEBUGGING = False diff --git a/st2common/st2common/util/deprecation.py b/st2common/st2common/util/deprecation.py index 160423a5e2..a178a9473d 100644 --- a/st2common/st2common/util/deprecation.py +++ b/st2common/st2common/util/deprecation.py @@ -23,10 +23,14 @@ def deprecated(func): as deprecated. It will result in a warning being emitted when the function is used. """ + def new_func(*args, **kwargs): - warnings.warn("Call to deprecated function {}.".format(func.__name__), - category=DeprecationWarning) + warnings.warn( + "Call to deprecated function {}.".format(func.__name__), + category=DeprecationWarning, + ) return func(*args, **kwargs) + new_func.__name__ = func.__name__ new_func.__doc__ = func.__doc__ new_func.__dict__.update(func.__dict__) diff --git a/st2common/st2common/util/driver_loader.py b/st2common/st2common/util/driver_loader.py index 285c22ed79..50f5044c41 100644 --- a/st2common/st2common/util/driver_loader.py +++ b/st2common/st2common/util/driver_loader.py @@ -21,15 +21,11 @@ from st2common import log as logging -__all__ = [ - 'get_available_backends', - 'get_backend_driver', - 'get_backend_instance' -] +__all__ = ["get_available_backends", "get_backend_driver", "get_backend_instance"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2common.rbac.backend' +BACKENDS_NAMESPACE = "st2common.rbac.backend" def get_available_backends(namespace, invoke_on_load=False): @@ -62,8 +58,9 @@ def get_backend_driver(namespace, name, invoke_on_load=False): LOG.debug('Retrieving driver for backend "%s"' % (name)) try: - manager = DriverManager(namespace=namespace, name=name, - invoke_on_load=invoke_on_load) + manager = DriverManager( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) except RuntimeError: message = 'Invalid "%s" backend specified: %s' % (namespace, name) LOG.exception(message) @@ -79,7 +76,9 @@ def get_backend_instance(namespace, name, invoke_on_load=False): :param name: Backend name. :type name: ``str`` """ - cls = get_backend_driver(namespace=namespace, name=name, invoke_on_load=invoke_on_load) + cls = get_backend_driver( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) cls_instance = cls() return cls_instance diff --git a/st2common/st2common/util/enum.py b/st2common/st2common/util/enum.py index ddcc138ea5..84a6e968f5 100644 --- a/st2common/st2common/util/enum.py +++ b/st2common/st2common/util/enum.py @@ -16,15 +16,16 @@ from __future__ import absolute_import import inspect -__all__ = [ - 'Enum' -] +__all__ = ["Enum"] class Enum(object): @classmethod def get_valid_values(cls): keys = list(cls.__dict__.keys()) - values = [getattr(cls, key) for key in keys if (not key.startswith('_') and - not inspect.ismethod(getattr(cls, key)))] + values = [ + getattr(cls, key) + for key in keys + if (not key.startswith("_") and not inspect.ismethod(getattr(cls, key))) + ] return values diff --git a/st2common/st2common/util/file_system.py b/st2common/st2common/util/file_system.py index d6d2458aec..e26adaedfd 100644 --- a/st2common/st2common/util/file_system.py +++ b/st2common/st2common/util/file_system.py @@ -26,10 +26,7 @@ import six -__all__ = [ - 'get_file_list', - 'recursive_chown' -] +__all__ = ["get_file_list", "recursive_chown"] def get_file_list(directory, exclude_patterns=None): @@ -48,9 +45,9 @@ def get_file_list(directory, exclude_patterns=None): :rtype: ``list`` """ result = [] - if not directory.endswith('/'): + if not directory.endswith("/"): # Make sure trailing slash is present - directory = directory + '/' + directory = directory + "/" def include_file(file_path): if not exclude_patterns: @@ -63,7 +60,7 @@ def include_file(file_path): return True for (dirpath, dirnames, filenames) in os.walk(directory): - base_path = dirpath.replace(directory, '') + base_path = dirpath.replace(directory, "") for filename in filenames: if base_path: diff --git a/st2common/st2common/util/green/shell.py b/st2common/st2common/util/green/shell.py index 4fd71ef7cf..4b6d79935b 100644 --- a/st2common/st2common/util/green/shell.py +++ b/st2common/st2common/util/green/shell.py @@ -27,20 +27,31 @@ from st2common import log as logging from st2common.util import concurrency -__all__ = [ - 'run_command' -] +__all__ = ["run_command"] TIMEOUT_EXIT_CODE = -9 LOG = logging.getLogger(__name__) -def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, - cwd=None, env=None, timeout=60, preexec_func=None, kill_func=None, - read_stdout_func=None, read_stderr_func=None, - read_stdout_buffer=None, read_stderr_buffer=None, stdin_value=None, - bufsize=0): +def run_command( + cmd, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + cwd=None, + env=None, + timeout=60, + preexec_func=None, + kill_func=None, + read_stdout_func=None, + read_stderr_func=None, + read_stdout_buffer=None, + read_stderr_buffer=None, + stdin_value=None, + bufsize=0, +): """ Run the provided command in a subprocess and wait until it completes. @@ -89,59 +100,77 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, :rtype: ``tuple`` (exit_code, stdout, stderr, timed_out) """ - LOG.debug('Entering st2common.util.green.run_command.') + LOG.debug("Entering st2common.util.green.run_command.") assert isinstance(cmd, (list, tuple) + six.string_types) - if (read_stdout_func and not read_stderr_func) or (read_stderr_func and not read_stdout_func): - raise ValueError('Both read_stdout_func and read_stderr_func arguments need ' - 'to be provided.') + if (read_stdout_func and not read_stderr_func) or ( + read_stderr_func and not read_stdout_func + ): + raise ValueError( + "Both read_stdout_func and read_stderr_func arguments need " + "to be provided." + ) if read_stdout_func and not (read_stdout_buffer or read_stderr_buffer): - raise ValueError('read_stdout_buffer and read_stderr_buffer arguments need to be provided ' - 'when read_stdout_func is provided') + raise ValueError( + "read_stdout_buffer and read_stderr_buffer arguments need to be provided " + "when read_stdout_func is provided" + ) if not env: - LOG.debug('env argument not provided. using process env (os.environ).') + LOG.debug("env argument not provided. using process env (os.environ).") env = os.environ.copy() subprocess = concurrency.get_subprocess_module() # Note: We are using eventlet / gevent friendly implementation of subprocess which uses # GreenPipe so it doesn't block - LOG.debug('Creating subprocess.') - process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr, - env=env, cwd=cwd, shell=shell, preexec_fn=preexec_func, - bufsize=bufsize) + LOG.debug("Creating subprocess.") + process = concurrency.subprocess_popen( + args=cmd, + stdin=stdin, + stdout=stdout, + stderr=stderr, + env=env, + cwd=cwd, + shell=shell, + preexec_fn=preexec_func, + bufsize=bufsize, + ) if read_stdout_func: - LOG.debug('Spawning read_stdout_func function') - read_stdout_thread = concurrency.spawn(read_stdout_func, process.stdout, read_stdout_buffer) + LOG.debug("Spawning read_stdout_func function") + read_stdout_thread = concurrency.spawn( + read_stdout_func, process.stdout, read_stdout_buffer + ) if read_stderr_func: - LOG.debug('Spawning read_stderr_func function') - read_stderr_thread = concurrency.spawn(read_stderr_func, process.stderr, read_stderr_buffer) + LOG.debug("Spawning read_stderr_func function") + read_stderr_thread = concurrency.spawn( + read_stderr_func, process.stderr, read_stderr_buffer + ) def on_timeout_expired(timeout): global timed_out try: - LOG.debug('Starting process wait inside timeout handler.') + LOG.debug("Starting process wait inside timeout handler.") process.wait(timeout=timeout) except subprocess.TimeoutExpired: # Command has timed out, kill the process and propagate the error. # Note: We explicitly set the returncode to indicate the timeout. - LOG.debug('Command execution timeout reached.') + LOG.debug("Command execution timeout reached.") # NOTE: It's important we set returncode twice - here and below to avoid race in this # function because "kill_func()" is async and "process.kill()" is not. process.returncode = TIMEOUT_EXIT_CODE if kill_func: - LOG.debug('Calling kill_func.') + LOG.debug("Calling kill_func.") kill_func(process=process) else: - LOG.debug('Killing process.') + LOG.debug("Killing process.") process.kill() # NOTE: It's imporant to set returncode here as well, since call to process.kill() sets @@ -149,25 +178,27 @@ def on_timeout_expired(timeout): process.returncode = TIMEOUT_EXIT_CODE if read_stdout_func and read_stderr_func: - LOG.debug('Killing read_stdout_thread and read_stderr_thread') + LOG.debug("Killing read_stdout_thread and read_stderr_thread") concurrency.kill(read_stdout_thread) concurrency.kill(read_stderr_thread) - LOG.debug('Spawning timeout handler thread.') + LOG.debug("Spawning timeout handler thread.") timeout_thread = concurrency.spawn(on_timeout_expired, timeout) - LOG.debug('Attaching to process.') + LOG.debug("Attaching to process.") if stdin_value: if six.PY3: - stdin_value = stdin_value.encode('utf-8') + stdin_value = stdin_value.encode("utf-8") process.stdin.write(stdin_value) if read_stdout_func and read_stderr_func: - LOG.debug('Using real-time stdout and stderr read mode, calling process.wait()') + LOG.debug("Using real-time stdout and stderr read mode, calling process.wait()") process.wait() else: - LOG.debug('Using delayed stdout and stderr read mode, calling process.communicate()') + LOG.debug( + "Using delayed stdout and stderr read mode, calling process.communicate()" + ) stdout, stderr = process.communicate() concurrency.cancel(timeout_thread) @@ -182,11 +213,11 @@ def on_timeout_expired(timeout): stderr = read_stderr_buffer.getvalue() if exit_code == TIMEOUT_EXIT_CODE: - LOG.debug('Timeout.') + LOG.debug("Timeout.") timed_out = True else: - LOG.debug('No timeout.') + LOG.debug("No timeout.") timed_out = False - LOG.debug('Returning.') + LOG.debug("Returning.") return (exit_code, stdout, stderr, timed_out) diff --git a/st2common/st2common/util/greenpooldispatch.py b/st2common/st2common/util/greenpooldispatch.py index d85ebfbf5d..156d530116 100644 --- a/st2common/st2common/util/greenpooldispatch.py +++ b/st2common/st2common/util/greenpooldispatch.py @@ -21,9 +21,7 @@ from st2common import log as logging -__all__ = [ - 'BufferedDispatcher' -] +__all__ = ["BufferedDispatcher"] # If the thread pool has been occupied with no empty threads for more than this number of seconds # a message will be logged @@ -38,14 +36,20 @@ class BufferedDispatcher(object): - - def __init__(self, dispatch_pool_size=50, monitor_thread_empty_q_sleep_time=5, - monitor_thread_no_workers_sleep_time=1, name=None): + def __init__( + self, + dispatch_pool_size=50, + monitor_thread_empty_q_sleep_time=5, + monitor_thread_no_workers_sleep_time=1, + name=None, + ): self._pool_limit = dispatch_pool_size self._dispatcher_pool = eventlet.GreenPool(dispatch_pool_size) self._dispatch_monitor_thread = eventlet.greenthread.spawn(self._flush) self._monitor_thread_empty_q_sleep_time = monitor_thread_empty_q_sleep_time - self._monitor_thread_no_workers_sleep_time = monitor_thread_no_workers_sleep_time + self._monitor_thread_no_workers_sleep_time = ( + monitor_thread_no_workers_sleep_time + ) self._name = name self._work_buffer = six.moves.queue.Queue() @@ -77,7 +81,9 @@ def _flush_now(self): now = time.time() if (now - self._pool_last_free_ts) >= POOL_BUSY_THRESHOLD_SECONDS: - LOG.info(POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS)) + LOG.info( + POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS) + ) return @@ -90,8 +96,15 @@ def _flush_now(self): def __repr__(self): free_count = self._dispatcher_pool.free() - values = (self.name, self._pool_limit, free_count, self._monitor_thread_empty_q_sleep_time, - self._monitor_thread_no_workers_sleep_time) - return ('' % - values) + values = ( + self.name, + self._pool_limit, + free_count, + self._monitor_thread_empty_q_sleep_time, + self._monitor_thread_no_workers_sleep_time, + ) + return ( + "" + % values + ) diff --git a/st2common/st2common/util/gunicorn_workers.py b/st2common/st2common/util/gunicorn_workers.py index 61eebe84e4..69942ac309 100644 --- a/st2common/st2common/util/gunicorn_workers.py +++ b/st2common/st2common/util/gunicorn_workers.py @@ -20,9 +20,7 @@ import six from gunicorn.workers.sync import SyncWorker -__all__ = [ - 'EventletSyncWorker' -] +__all__ = ["EventletSyncWorker"] class EventletSyncWorker(SyncWorker): @@ -44,7 +42,7 @@ def handle_quit(self, sig, frame): except AssertionError as e: msg = six.text_type(e) - if 'do not call blocking functions from the mainloop' in msg: + if "do not call blocking functions from the mainloop" in msg: # Workaround for "do not call blocking functions from the mainloop" issue sys.exit(0) diff --git a/st2common/st2common/util/hash.py b/st2common/st2common/util/hash.py index f0a5596379..3d7c83328c 100644 --- a/st2common/st2common/util/hash.py +++ b/st2common/st2common/util/hash.py @@ -19,12 +19,10 @@ import hashlib -__all__ = [ - 'hash' -] +__all__ = ["hash"] -FIXED_SALT = 'saltnpepper' +FIXED_SALT = "saltnpepper" def hash(value, salt=FIXED_SALT): diff --git a/st2common/st2common/util/http.py b/st2common/st2common/util/http.py index e11a277be6..26aa6d445d 100644 --- a/st2common/st2common/util/http.py +++ b/st2common/st2common/util/http.py @@ -18,17 +18,20 @@ http_client = six.moves.http_client -__all__ = [ - 'HTTP_SUCCESS', - 'parse_content_type_header' +__all__ = ["HTTP_SUCCESS", "parse_content_type_header"] + +HTTP_SUCCESS = [ + http_client.OK, + http_client.CREATED, + http_client.ACCEPTED, + http_client.NON_AUTHORITATIVE_INFORMATION, + http_client.NO_CONTENT, + http_client.RESET_CONTENT, + http_client.PARTIAL_CONTENT, + http_client.MULTI_STATUS, + http_client.IM_USED, ] -HTTP_SUCCESS = [http_client.OK, http_client.CREATED, http_client.ACCEPTED, - http_client.NON_AUTHORITATIVE_INFORMATION, http_client.NO_CONTENT, - http_client.RESET_CONTENT, http_client.PARTIAL_CONTENT, - http_client.MULTI_STATUS, http_client.IM_USED, - ] - def parse_content_type_header(content_type): """ @@ -37,13 +40,13 @@ def parse_content_type_header(content_type): :rype: ``tuple`` """ - if ';' in content_type: - split = content_type.split(';') + if ";" in content_type: + split = content_type.split(";") media = split[0] options = {} for pair in split[1:]: - split_pair = pair.split('=', 1) + split_pair = pair.split("=", 1) if len(split_pair) != 2: continue diff --git a/st2common/st2common/util/ip_utils.py b/st2common/st2common/util/ip_utils.py index 4e2a00357a..53253432d8 100644 --- a/st2common/st2common/util/ip_utils.py +++ b/st2common/st2common/util/ip_utils.py @@ -21,11 +21,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'is_ipv4', - 'is_ipv6', - 'split_host_port' -] +__all__ = ["is_ipv4", "is_ipv6", "split_host_port"] BRACKET_PATTERN = r"^\[.*\]" # IPv6 bracket pattern to specify port COMPILED_BRACKET_PATTERN = re.compile(BRACKET_PATTERN) @@ -91,30 +87,32 @@ def split_host_port(host_str): # Check if it's square bracket style. match = COMPILED_BRACKET_PATTERN.match(host_str) if match: - LOG.debug('Square bracket style.') + LOG.debug("Square bracket style.") # Check if square bracket style no port. match = COMPILED_HOST_ONLY_IN_BRACKET_PATTERN.match(host_str) if match: - hostname = match.group().strip('[]') + hostname = match.group().strip("[]") return (hostname, port) - hostname, separator, port = hostname.rpartition(':') + hostname, separator, port = hostname.rpartition(":") try: - LOG.debug('host_str: %s, hostname: %s port: %s' % (host_str, hostname, port)) + LOG.debug( + "host_str: %s, hostname: %s port: %s" % (host_str, hostname, port) + ) port = int(port) - hostname = hostname.strip('[]') + hostname = hostname.strip("[]") return (hostname, port) except: - raise Exception('Invalid port %s specified.' % port) + raise Exception("Invalid port %s specified." % port) else: - LOG.debug('Non-bracket address. host_str: %s' % host_str) - if ':' in host_str: - LOG.debug('Non-bracket with port.') - hostname, separator, port = hostname.rpartition(':') + LOG.debug("Non-bracket address. host_str: %s" % host_str) + if ":" in host_str: + LOG.debug("Non-bracket with port.") + hostname, separator, port = hostname.rpartition(":") try: port = int(port) return (hostname, port) except: - raise Exception('Invalid port %s specified.' % port) + raise Exception("Invalid port %s specified." % port) return (hostname, port) diff --git a/st2common/st2common/util/isotime.py b/st2common/st2common/util/isotime.py index 0830393bf8..0c6ca1c4d4 100644 --- a/st2common/st2common/util/isotime.py +++ b/st2common/st2common/util/isotime.py @@ -25,17 +25,14 @@ from st2common.util import date as date_utils import six -__all__ = [ - 'format', - 'validate', - 'parse' -] +__all__ = ["format", "validate", "parse"] -ISO8601_FORMAT = '%Y-%m-%dT%H:%M:%S' -ISO8601_FORMAT_MICROSECOND = '%Y-%m-%dT%H:%M:%S.%f' -ISO8601_UTC_REGEX = \ - r'^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$' +ISO8601_FORMAT = "%Y-%m-%dT%H:%M:%S" +ISO8601_FORMAT_MICROSECOND = "%Y-%m-%dT%H:%M:%S.%f" +ISO8601_UTC_REGEX = ( + r"^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$" +) def format(dt, usec=True, offset=True): @@ -53,20 +50,21 @@ def format(dt, usec=True, offset=True): fmt = ISO8601_FORMAT_MICROSECOND if usec else ISO8601_FORMAT if offset: - ost = dt.strftime('%z') - ost = (ost[:3] + ':' + ost[3:]) if ost else '+00:00' + ost = dt.strftime("%z") + ost = (ost[:3] + ":" + ost[3:]) if ost else "+00:00" else: - tz = dt.tzinfo.tzname(dt) if dt.tzinfo else 'UTC' - ost = 'Z' if tz == 'UTC' else tz + tz = dt.tzinfo.tzname(dt) if dt.tzinfo else "UTC" + ost = "Z" if tz == "UTC" else tz return dt.strftime(fmt) + ost def validate(value, raise_exception=True): - if (isinstance(value, datetime.datetime) or - (type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value))): + if isinstance(value, datetime.datetime) or ( + type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value) + ): return True if raise_exception: - raise ValueError('Datetime value does not match expected format.') + raise ValueError("Datetime value does not match expected format.") return False diff --git a/st2common/st2common/util/jinja.py b/st2common/st2common/util/jinja.py index 9234986f9f..4472246908 100644 --- a/st2common/st2common/util/jinja.py +++ b/st2common/st2common/util/jinja.py @@ -22,21 +22,14 @@ from st2common.util.compat import to_unicode -__all__ = [ - 'get_jinja_environment', - 'render_values', - 'is_jinja_expression' -] +__all__ = ["get_jinja_environment", "render_values", "is_jinja_expression"] -JINJA_EXPRESSIONS_START_MARKERS = [ - '{{', - '{%' -] +JINJA_EXPRESSIONS_START_MARKERS = ["{{", "{%"] -JINJA_REGEX = '({{(.*)}})' +JINJA_REGEX = "({{(.*)}})" JINJA_REGEX_PTRN = re.compile(JINJA_REGEX) -JINJA_BLOCK_REGEX = '({%(.*)%})' +JINJA_BLOCK_REGEX = "({%(.*)%})" JINJA_BLOCK_REGEX_PTRN = re.compile(JINJA_BLOCK_REGEX) @@ -53,59 +46,52 @@ def get_filters(): from st2common.expressions.functions import path return { - 'decrypt_kv': datastore.decrypt_kv, - - 'from_json_string': data.from_json_string, - 'from_yaml_string': data.from_yaml_string, - 'json_escape': data.json_escape, - 'jsonpath_query': data.jsonpath_query, - 'to_complex': data.to_complex, - 'to_json_string': data.to_json_string, - 'to_yaml_string': data.to_yaml_string, - - 'regex_match': regex.regex_match, - 'regex_replace': regex.regex_replace, - 'regex_search': regex.regex_search, - 'regex_substring': regex.regex_substring, - - 'to_human_time_from_seconds': time.to_human_time_from_seconds, - - 'version_compare': version.version_compare, - 'version_more_than': version.version_more_than, - 'version_less_than': version.version_less_than, - 'version_equal': version.version_equal, - 'version_match': version.version_match, - 'version_bump_major': version.version_bump_major, - 'version_bump_minor': version.version_bump_minor, - 'version_bump_patch': version.version_bump_patch, - 'version_strip_patch': version.version_strip_patch, - 'use_none': data.use_none, - - 'basename': path.basename, - 'dirname': path.dirname + "decrypt_kv": datastore.decrypt_kv, + "from_json_string": data.from_json_string, + "from_yaml_string": data.from_yaml_string, + "json_escape": data.json_escape, + "jsonpath_query": data.jsonpath_query, + "to_complex": data.to_complex, + "to_json_string": data.to_json_string, + "to_yaml_string": data.to_yaml_string, + "regex_match": regex.regex_match, + "regex_replace": regex.regex_replace, + "regex_search": regex.regex_search, + "regex_substring": regex.regex_substring, + "to_human_time_from_seconds": time.to_human_time_from_seconds, + "version_compare": version.version_compare, + "version_more_than": version.version_more_than, + "version_less_than": version.version_less_than, + "version_equal": version.version_equal, + "version_match": version.version_match, + "version_bump_major": version.version_bump_major, + "version_bump_minor": version.version_bump_minor, + "version_bump_patch": version.version_bump_patch, + "version_strip_patch": version.version_strip_patch, + "use_none": data.use_none, + "basename": path.basename, + "dirname": path.dirname, } def get_jinja_environment(allow_undefined=False, trim_blocks=True, lstrip_blocks=True): - ''' + """ jinja2.Environment object that is setup with right behaviors and custom filters. :param strict_undefined: If should allow undefined variables in templates :type strict_undefined: ``bool`` - ''' + """ # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used import jinja2 undefined = jinja2.Undefined if allow_undefined else jinja2.StrictUndefined env = jinja2.Environment( # nosec - undefined=undefined, - trim_blocks=trim_blocks, - lstrip_blocks=lstrip_blocks + undefined=undefined, trim_blocks=trim_blocks, lstrip_blocks=lstrip_blocks ) env.filters.update(get_filters()) - env.tests['in'] = lambda item, list: item in list + env.tests["in"] = lambda item, list: item in list return env @@ -130,7 +116,7 @@ def render_values(mapping=None, context=None, allow_undefined=False): # This mean __context is a reserve key word although backwards compat is preserved by making # sure that real context is updated later and therefore will override the __context value. super_context = {} - super_context['__context'] = context + super_context["__context"] = context super_context.update(context) env = get_jinja_environment(allow_undefined=allow_undefined) @@ -150,7 +136,7 @@ def render_values(mapping=None, context=None, allow_undefined=False): v = str(v) try: - LOG.info('Rendering string %s. Super context=%s', v, super_context) + LOG.info("Rendering string %s. Super context=%s", v, super_context) rendered_v = env.from_string(v).render(super_context) except Exception as e: # Attach key and value which failed the rendering @@ -166,7 +152,12 @@ def render_values(mapping=None, context=None, allow_undefined=False): if reverse_json_dumps: rendered_v = json.loads(rendered_v) rendered_mapping[k] = rendered_v - LOG.info('Mapping: %s, rendered_mapping: %s, context: %s', mapping, rendered_mapping, context) + LOG.info( + "Mapping: %s, rendered_mapping: %s, context: %s", + mapping, + rendered_mapping, + context, + ) return rendered_mapping @@ -194,6 +185,6 @@ def convert_jinja_to_raw_block(value): if isinstance(value, six.string_types): if JINJA_REGEX_PTRN.findall(value) or JINJA_BLOCK_REGEX_PTRN.findall(value): - return '{% raw %}' + value + '{% endraw %}' + return "{% raw %}" + value + "{% endraw %}" return value diff --git a/st2common/st2common/util/jsonify.py b/st2common/st2common/util/jsonify.py index 1f47cec1b0..16a95dde99 100644 --- a/st2common/st2common/util/jsonify.py +++ b/st2common/st2common/util/jsonify.py @@ -25,18 +25,12 @@ import six -__all__ = [ - 'json_encode', - 'json_loads', - 'try_loads', - - 'get_json_type_for_python_value' -] +__all__ = ["json_encode", "json_loads", "try_loads", "get_json_type_for_python_value"] class GenericJSON(JSONEncoder): def default(self, obj): # pylint: disable=method-hidden - if hasattr(obj, '__json__') and six.callable(obj.__json__): + if hasattr(obj, "__json__") and six.callable(obj.__json__): return obj.__json__() else: return JSONEncoder.default(self, obj) @@ -47,7 +41,7 @@ def json_encode(obj, indent=4): def load_file(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return json.load(fd) @@ -92,16 +86,16 @@ def get_json_type_for_python_value(value): :rtype: ``str`` """ if isinstance(value, six.text_type): - return 'string' + return "string" elif isinstance(value, (int, float)): - return 'number' + return "number" elif isinstance(value, dict): - return 'object' + return "object" elif isinstance(value, (list, tuple)): - return 'array' + return "array" elif isinstance(value, bool): - return 'boolean' + return "boolean" elif value is None: - return 'null' + return "null" else: - return 'unknown' + return "unknown" diff --git a/st2common/st2common/util/keyvalue.py b/st2common/st2common/util/keyvalue.py index 05246d2d32..cad32250a8 100644 --- a/st2common/st2common/util/keyvalue.py +++ b/st2common/st2common/util/keyvalue.py @@ -24,22 +24,23 @@ from st2common.rbac.backends import get_rbac_backend from st2common.persistence.keyvalue import KeyValuePair from st2common.services.config import deserialize_key_value -from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE, - ALLOWED_SCOPES) +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + USER_SCOPE, + ALLOWED_SCOPES, +) from st2common.models.db.auth import UserDB from st2common.exceptions.rbac import AccessDeniedError -__all__ = [ - 'get_datastore_full_scope', - 'get_key' -] +__all__ = ["get_datastore_full_scope", "get_key"] LOG = logging.getLogger(__name__) def _validate_scope(scope): if scope not in ALLOWED_SCOPES: - msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES) + msg = "Scope %s is not in allowed scopes list: %s." % (scope, ALLOWED_SCOPES) raise ValueError(msg) @@ -48,9 +49,9 @@ def _validate_decrypt_query_parameter(decrypt, scope, is_admin, user_db): Validate that the provider user is either admin or requesting to decrypt value for themselves. """ - is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE) + is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE if decrypt and (not is_user_scope and not is_admin): - msg = 'Decrypt option requires administrator access' + msg = "Decrypt option requires administrator access" raise AccessDeniedError(message=msg, user_db=user_db) @@ -61,7 +62,7 @@ def get_datastore_full_scope(scope): if DATASTORE_PARENT_SCOPE in scope: return scope - return '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope) + return "%s%s%s" % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope) def _derive_scope_and_key(key, user, scope=None): @@ -75,10 +76,10 @@ def _derive_scope_and_key(key, user, scope=None): if scope is not None: return scope, key - if key.startswith('system.'): - return FULL_SYSTEM_SCOPE, key[key.index('.') + 1:] + if key.startswith("system."): + return FULL_SYSTEM_SCOPE, key[key.index(".") + 1 :] - return FULL_USER_SCOPE, '%s:%s' % (user, key) + return FULL_USER_SCOPE, "%s:%s" % (user, key) def get_key(key=None, user_db=None, scope=None, decrypt=False): @@ -86,10 +87,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): Retrieve key from KVP store """ if not isinstance(key, six.string_types): - raise TypeError('Given key is not typeof string.') + raise TypeError("Given key is not typeof string.") if not isinstance(decrypt, bool): - raise TypeError('Decrypt parameter is not typeof bool.') + raise TypeError("Decrypt parameter is not typeof bool.") if not user_db: # Use system user @@ -98,9 +99,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): scope, key_id = _derive_scope_and_key(key=key, user=user_db.name, scope=scope) scope = get_datastore_full_scope(scope) - LOG.debug('get_key key_id: %s, scope: %s, user: %s, decrypt: %s' % (key_id, scope, - str(user_db.name), - decrypt)) + LOG.debug( + "get_key key_id: %s, scope: %s, user: %s, decrypt: %s" + % (key_id, scope, str(user_db.name), decrypt) + ) _validate_scope(scope=scope) @@ -108,8 +110,9 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): is_admin = rbac_utils.user_is_admin(user_db=user_db) # User needs to be either admin or requesting item for itself - _validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, is_admin=is_admin, - user_db=user_db) + _validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, is_admin=is_admin, user_db=user_db + ) # Get the key value pair by scope and name. kvp = KeyValuePair.get_by_scope_and_name(scope, key_id) diff --git a/st2common/st2common/util/loader.py b/st2common/st2common/util/loader.py index 0e27a0da32..1c5a5a4b54 100644 --- a/st2common/st2common/util/loader.py +++ b/st2common/st2common/util/loader.py @@ -28,19 +28,14 @@ from st2common.exceptions.plugins import IncompatiblePluginException from st2common import log as logging -__all__ = [ - 'register_plugin', - 'register_plugin_class', - - 'load_meta_file' -] +__all__ = ["register_plugin", "register_plugin_class", "load_meta_file"] LOG = logging.getLogger(__name__) -PYTHON_EXTENSION = '.py' -ALLOWED_EXTS = ['.json', '.yaml', '.yml'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +PYTHON_EXTENSION = ".py" +ALLOWED_EXTS = [".json", ".yaml", ".yml"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} # Cache for dynamically loaded runner modules RUNNER_MODULES_CACHE = defaultdict(dict) @@ -48,7 +43,9 @@ def _register_plugin_path(plugin_dir_abs_path): if not os.path.isdir(plugin_dir_abs_path): - raise Exception('Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path)) + raise Exception( + 'Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path) + ) for x in sys.path: if plugin_dir_abs_path in (x, x + os.sep): @@ -59,15 +56,21 @@ def _register_plugin_path(plugin_dir_abs_path): def _get_plugin_module(plugin_file_path): plugin_module = os.path.basename(plugin_file_path) if plugin_module.endswith(PYTHON_EXTENSION): - plugin_module = plugin_module[:plugin_module.rfind('.py')] + plugin_module = plugin_module[: plugin_module.rfind(".py")] else: plugin_module = None return plugin_module def _get_classes_in_module(module): - return [kls for name, kls in inspect.getmembers(module, - lambda member: inspect.isclass(member) and member.__module__ == module.__name__)] + return [ + kls + for name, kls in inspect.getmembers( + module, + lambda member: inspect.isclass(member) + and member.__module__ == module.__name__, + ) + ] def _get_plugin_classes(module_name): @@ -92,7 +95,7 @@ def _get_plugin_methods(plugin_klass): method_names = [] for name, method in methods: method_properties = method.__dict__ - is_abstract = method_properties.get('__isabstractmethod__', False) + is_abstract = method_properties.get("__isabstractmethod__", False) if is_abstract: continue @@ -102,16 +105,18 @@ def _get_plugin_methods(plugin_klass): def _validate_methods(plugin_base_class, plugin_klass): - ''' + """ XXX: This is hacky but we'd like to validate the methods in plugin_impl at least has all the *abstract* methods in plugin_base_class. - ''' + """ expected_methods = plugin_base_class.__abstractmethods__ plugin_methods = _get_plugin_methods(plugin_klass) for method in expected_methods: if method not in plugin_methods: - message = 'Class "%s" doesn\'t implement required "%s" method from the base class' + message = ( + 'Class "%s" doesn\'t implement required "%s" method from the base class' + ) raise IncompatiblePluginException(message % (plugin_klass.__name__, method)) @@ -147,8 +152,10 @@ def register_plugin_class(base_class, file_path, class_name): klass = getattr(module, class_name, None) if not klass: - raise Exception('Plugin file "%s" doesn\'t expose class named "%s"' % - (file_path, class_name)) + raise Exception( + 'Plugin file "%s" doesn\'t expose class named "%s"' + % (file_path, class_name) + ) _register_plugin(base_class, klass) return klass @@ -173,12 +180,14 @@ def register_plugin(plugin_base_class, plugin_abs_file_path): registered_plugins.append(klass) except Exception as e: LOG.exception(e) - LOG.debug('Skipping class %s as it doesn\'t match specs.', klass) + LOG.debug("Skipping class %s as it doesn't match specs.", klass) continue if len(registered_plugins) == 0: - raise Exception('Found no classes in plugin file "%s" matching requirements.' % - (plugin_abs_file_path)) + raise Exception( + 'Found no classes in plugin file "%s" matching requirements.' + % (plugin_abs_file_path) + ) return registered_plugins @@ -189,16 +198,17 @@ def load_meta_file(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return PARSER_FUNCS[file_ext](f) def get_available_plugins(namespace): - """Return names of the available / installed plugins for a given namespace. - """ + """Return names of the available / installed plugins for a given namespace.""" from stevedore.extension import ExtensionManager manager = ExtensionManager(namespace=namespace, invoke_on_load=False) @@ -206,9 +216,10 @@ def get_available_plugins(namespace): def get_plugin_instance(namespace, name, invoke_on_load=True): - """Return class instance for the provided plugin name and namespace. - """ + """Return class instance for the provided plugin name and namespace.""" from stevedore.driver import DriverManager - manager = DriverManager(namespace=namespace, name=name, invoke_on_load=invoke_on_load) + manager = DriverManager( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) return manager.driver diff --git a/st2common/st2common/util/misc.py b/st2common/st2common/util/misc.py index 28773abedb..6a1027e9fe 100644 --- a/st2common/st2common/util/misc.py +++ b/st2common/st2common/util/misc.py @@ -26,18 +26,17 @@ import six __all__ = [ - 'prefix_dict_keys', - 'compare_path_file_name', - 'get_field_name_from_mongoengine_error', - - 'sanitize_output', - 'strip_shell_chars', - 'rstrip_last_char', - 'lowercase_value' + "prefix_dict_keys", + "compare_path_file_name", + "get_field_name_from_mongoengine_error", + "sanitize_output", + "strip_shell_chars", + "rstrip_last_char", + "lowercase_value", ] -def prefix_dict_keys(dictionary, prefix='_'): +def prefix_dict_keys(dictionary, prefix="_"): """ Prefix dictionary keys with a provided prefix. @@ -52,7 +51,7 @@ def prefix_dict_keys(dictionary, prefix='_'): result = {} for key, value in six.iteritems(dictionary): - result['%s%s' % (prefix, key)] = value + result["%s%s" % (prefix, key)] = value return result @@ -89,7 +88,7 @@ def sanitize_output(input_str, uses_pty=False): output = strip_shell_chars(input_str) if uses_pty: - output = output.replace('\r\n', '\n') + output = output.replace("\r\n", "\n") return output @@ -105,8 +104,8 @@ def strip_shell_chars(input_str): :rtype: ``str`` """ - stripped_str = rstrip_last_char(input_str, '\n') - stripped_str = rstrip_last_char(stripped_str, '\r') + stripped_str = rstrip_last_char(input_str, "\n") + stripped_str = rstrip_last_char(stripped_str, "\r") return stripped_str @@ -127,7 +126,7 @@ def rstrip_last_char(input_str, char_to_strip): return input_str if input_str.endswith(char_to_strip): - return input_str[:-len(char_to_strip)] + return input_str[: -len(char_to_strip)] return input_str @@ -153,10 +152,10 @@ def get_normalized_file_path(file_path): :rtype: ``str`` """ - if hasattr(sys, 'frozen'): # support for py2exe - file_path = 'logging%s__init__%s' % (os.sep, file_path[-4:]) - elif file_path[-4:].lower() in ['.pyc', '.pyo']: - file_path = file_path[:-4] + '.py' + if hasattr(sys, "frozen"): # support for py2exe + file_path = "logging%s__init__%s" % (os.sep, file_path[-4:]) + elif file_path[-4:].lower() in [".pyc", ".pyo"]: + file_path = file_path[:-4] + ".py" else: file_path = file_path @@ -193,7 +192,7 @@ def get_field_name_from_mongoengine_error(exc): """ msg = str(exc) - match = re.match("Cannot resolve field \"(.+?)\"", msg) + match = re.match('Cannot resolve field "(.+?)"', msg) if match: return match.groups()[0] @@ -201,7 +200,9 @@ def get_field_name_from_mongoengine_error(exc): return msg -def ignore_and_log_exception(exc_classes=(Exception,), logger=None, level=logging.WARNING): +def ignore_and_log_exception( + exc_classes=(Exception,), logger=None, level=logging.WARNING +): """ Decorator which catches the provided exception classes and logs them instead of letting them bubble all the way up. @@ -214,13 +215,14 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except exc_classes as e: - if len(args) >= 1 and getattr(args[0], '__class__', None): - func_name = '%s.%s' % (args[0].__class__.__name__, func.__name__) + if len(args) >= 1 and getattr(args[0], "__class__", None): + func_name = "%s.%s" % (args[0].__class__.__name__, func.__name__) else: func_name = func.__name__ - message = ('Exception in fuction "%s": %s' % (func_name, str(e))) + message = 'Exception in fuction "%s": %s' % (func_name, str(e)) logger.log(level, message) return wrapper + return decorator diff --git a/st2common/st2common/util/mongoescape.py b/st2common/st2common/util/mongoescape.py index 6d42b4972c..d75d9502f4 100644 --- a/st2common/st2common/util/mongoescape.py +++ b/st2common/st2common/util/mongoescape.py @@ -21,17 +21,22 @@ from st2common.util.ujson import fast_deepcopy # Note: Because of old rule escaping code, two different characters can be translated back to dot -RULE_CRITERIA_UNESCAPED = ['.'] -RULE_CRITERIA_ESCAPED = [u'\u2024'] -RULE_CRITERIA_ESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED))) -RULE_CRITERIA_UNESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED))) +RULE_CRITERIA_UNESCAPED = ["."] +RULE_CRITERIA_ESCAPED = ["\u2024"] +RULE_CRITERIA_ESCAPE_TRANSLATION = dict( + list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED)) +) +RULE_CRITERIA_UNESCAPE_TRANSLATION = dict( + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) +) # http://docs.mongodb.org/manual/faq/developers/#faq-dollar-sign-escaping -UNESCAPED = ['.', '$'] -ESCAPED = [u'\uFF0E', u'\uFF04'] +UNESCAPED = [".", "$"] +ESCAPED = ["\uFF0E", "\uFF04"] ESCAPE_TRANSLATION = dict(list(zip(UNESCAPED, ESCAPED))) UNESCAPE_TRANSLATION = dict( - list(zip(ESCAPED, UNESCAPED)) + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) + list(zip(ESCAPED, UNESCAPED)) + + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) ) diff --git a/st2common/st2common/util/monkey_patch.py b/st2common/st2common/util/monkey_patch.py index 5a042fd656..76b4a191de 100644 --- a/st2common/st2common/util/monkey_patch.py +++ b/st2common/st2common/util/monkey_patch.py @@ -22,13 +22,13 @@ import sys __all__ = [ - 'monkey_patch', - 'use_select_poll_workaround', - 'is_use_debugger_flag_provided' + "monkey_patch", + "use_select_poll_workaround", + "is_use_debugger_flag_provided", ] -USE_DEBUGGER_FLAG = '--use-debugger' -PARENT_ARGS_FLAG = '--parent-args=' +USE_DEBUGGER_FLAG = "--use-debugger" +PARENT_ARGS_FLAG = "--parent-args=" def monkey_patch(patch_thread=None): @@ -48,7 +48,9 @@ def monkey_patch(patch_thread=None): if patch_thread is None: patch_thread = not is_use_debugger_flag_provided() - eventlet.monkey_patch(os=True, select=True, socket=True, thread=patch_thread, time=True) + eventlet.monkey_patch( + os=True, select=True, socket=True, thread=patch_thread, time=True + ) def use_select_poll_workaround(nose_only=True): @@ -80,20 +82,20 @@ def use_select_poll_workaround(nose_only=True): import eventlet # Work around to get tests to pass with eventlet >= 0.20.0 - if not nose_only or (nose_only and 'nose' in sys.modules.keys()): + if not nose_only or (nose_only and "nose" in sys.modules.keys()): # Add back blocking poll() to eventlet monkeypatched select - original_poll = eventlet.patcher.original('select').poll + original_poll = eventlet.patcher.original("select").poll select.poll = original_poll - sys.modules['select'] = select + sys.modules["select"] = select subprocess.select = select if sys.version_info >= (3, 6, 5): # If we also don't patch selectors.select, it will fail with Python >= 3.6.5 import selectors # pylint: disable=import-error - sys.modules['selectors'] = selectors - selectors.select = sys.modules['select'] + sys.modules["selectors"] = selectors + selectors.select = sys.modules["select"] def is_use_debugger_flag_provided(): diff --git a/st2common/st2common/util/output_schema.py b/st2common/st2common/util/output_schema.py index 607f1af0bb..2bde19c3c0 100644 --- a/st2common/st2common/util/output_schema.py +++ b/st2common/st2common/util/output_schema.py @@ -26,37 +26,36 @@ def _validate_runner(runner_schema, result): - LOG.debug('Validating runner output: %s', runner_schema) + LOG.debug("Validating runner output: %s", runner_schema) runner_schema = { "type": "object", "properties": runner_schema, - "additionalProperties": False + "additionalProperties": False, } - schema.validate(result, runner_schema, cls=schema.get_validator('custom')) + schema.validate(result, runner_schema, cls=schema.get_validator("custom")) def _validate_action(action_schema, result, output_key): - LOG.debug('Validating action output: %s', action_schema) + LOG.debug("Validating action output: %s", action_schema) final_result = result[output_key] action_schema = { "type": "object", "properties": action_schema, - "additionalProperties": False + "additionalProperties": False, } - schema.validate(final_result, action_schema, cls=schema.get_validator('custom')) + schema.validate(final_result, action_schema, cls=schema.get_validator("custom")) def validate_output(runner_schema, action_schema, result, status, output_key): - """ Validate output of action with runner and action schema. - """ + """Validate output of action with runner and action schema.""" try: - LOG.debug('Validating action output: %s', result) - LOG.debug('Output Key: %s', output_key) + LOG.debug("Validating action output: %s", result) + LOG.debug("Output Key: %s", output_key) if runner_schema: _validate_runner(runner_schema, result) @@ -64,26 +63,26 @@ def validate_output(runner_schema, action_schema, result, status, output_key): _validate_action(action_schema, result, output_key) except jsonschema.ValidationError: - LOG.exception('Failed to validate output.') + LOG.exception("Failed to validate output.") _, ex, _ = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. result = { - 'error': str(ex), - 'message': 'Error validating output. See error output for more details.', + "error": str(ex), + "message": "Error validating output. See error output for more details.", } return (result, status) except: - LOG.exception('Failed to validate output.') + LOG.exception("Failed to validate output.") _, ex, tb = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. result = { - 'traceback': ''.join(traceback.format_tb(tb, 20)), - 'error': str(ex), - 'message': 'Error validating output. See error output for more details.', + "traceback": "".join(traceback.format_tb(tb, 20)), + "error": str(ex), + "message": "Error validating output. See error output for more details.", } return (result, status) diff --git a/st2common/st2common/util/pack.py b/st2common/st2common/util/pack.py index 6ac4e4fc48..43dde60051 100644 --- a/st2common/st2common/util/pack.py +++ b/st2common/st2common/util/pack.py @@ -30,27 +30,28 @@ from st2common.util import jinja as jinja_utils __all__ = [ - 'get_pack_ref_from_metadata', - 'get_pack_metadata', - 'get_pack_warnings', - - 'get_pack_common_libs_path_for_pack_ref', - 'get_pack_common_libs_path_for_pack_db', - - 'validate_config_against_schema', - - 'normalize_pack_version' + "get_pack_ref_from_metadata", + "get_pack_metadata", + "get_pack_warnings", + "get_pack_common_libs_path_for_pack_ref", + "get_pack_common_libs_path_for_pack_db", + "validate_config_against_schema", + "normalize_pack_version", ] # Common format for python 2.7 warning if six.PY2: - PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \ - "Python 2 support will be dropped in future releases. " \ - "Please consider updating your packs to work with Python 3.x" + PACK_PYTHON2_WARNING = ( + "DEPRECATION WARNING: Pack %s only supports Python 2.x. " + "Python 2 support will be dropped in future releases. " + "Please consider updating your packs to work with Python 3.x" + ) else: - PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \ - "Python 2 support has been removed since st2 v3.4.0. " \ - "Please update your packs to work with Python 3.x" + PACK_PYTHON2_WARNING = ( + "DEPRECATION WARNING: Pack %s only supports Python 2.x. " + "Python 2 support has been removed since st2 v3.4.0. " + "Please update your packs to work with Python 3.x" + ) def get_pack_ref_from_metadata(metadata, pack_directory_name=None): @@ -69,19 +70,23 @@ def get_pack_ref_from_metadata(metadata, pack_directory_name=None): # which are in sub-directories) # 2. If attribute is not available, but pack name is and pack name meets the valid name # criteria, we use that - if metadata.get('ref', None): - pack_ref = metadata['ref'] - elif pack_directory_name and re.match(PACK_REF_WHITELIST_REGEX, pack_directory_name): + if metadata.get("ref", None): + pack_ref = metadata["ref"] + elif pack_directory_name and re.match( + PACK_REF_WHITELIST_REGEX, pack_directory_name + ): pack_ref = pack_directory_name else: - if re.match(PACK_REF_WHITELIST_REGEX, metadata['name']): - pack_ref = metadata['name'] + if re.match(PACK_REF_WHITELIST_REGEX, metadata["name"]): + pack_ref = metadata["name"] else: - msg = ('Pack name "%s" contains invalid characters and "ref" attribute is not ' - 'available. You either need to add "ref" attribute which contains only word ' - 'characters to the pack metadata file or update name attribute to contain only' - 'word characters.') - raise ValueError(msg % (metadata['name'])) + msg = ( + 'Pack name "%s" contains invalid characters and "ref" attribute is not ' + 'available. You either need to add "ref" attribute which contains only word ' + "characters to the pack metadata file or update name attribute to contain only" + "word characters." + ) + raise ValueError(msg % (metadata["name"])) return pack_ref @@ -95,7 +100,9 @@ def get_pack_metadata(pack_dir): manifest_path = os.path.join(pack_dir, MANIFEST_FILE_NAME) if not os.path.isfile(manifest_path): - raise ValueError('Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME)) + raise ValueError( + 'Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME) + ) meta_loader = MetaLoader() content = meta_loader.load(manifest_path) @@ -112,15 +119,16 @@ def get_pack_warnings(pack_metadata): :rtype: ``str`` """ warning = None - versions = pack_metadata.get('python_versions', None) - pack_name = pack_metadata.get('name', None) - if versions and set(versions) == set(['2']): + versions = pack_metadata.get("python_versions", None) + pack_name = pack_metadata.get("name", None) + if versions and set(versions) == set(["2"]): warning = PACK_PYTHON2_WARNING % pack_name return warning -def validate_config_against_schema(config_schema, config_object, config_path, - pack_name=None): +def validate_config_against_schema( + config_schema, config_object, config_path, pack_name=None +): """ Validate provided config dictionary against the provided config schema dictionary. @@ -128,35 +136,49 @@ def validate_config_against_schema(config_schema, config_object, config_path, # NOTE: Lazy improt to avoid performance overhead of importing this module when it's not used import jsonschema - pack_name = pack_name or 'unknown' + pack_name = pack_name or "unknown" - schema = util_schema.get_schema_for_resource_parameters(parameters_schema=config_schema, - allow_additional_properties=True) + schema = util_schema.get_schema_for_resource_parameters( + parameters_schema=config_schema, allow_additional_properties=True + ) instance = config_object try: - cleaned = util_schema.validate(instance=instance, schema=schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=instance, + schema=schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) for key in cleaned: - if (jinja_utils.is_jinja_expression(value=cleaned.get(key)) and - "decrypt_kv" in cleaned.get(key) and config_schema.get(key).get('secret')): - raise ValueValidationException('Values specified as "secret: True" in config ' - 'schema are automatically decrypted by default. Use ' - 'of "decrypt_kv" jinja filter is not allowed for ' - 'such values. Please check the specified values in ' - 'the config or the default values in the schema.') + if ( + jinja_utils.is_jinja_expression(value=cleaned.get(key)) + and "decrypt_kv" in cleaned.get(key) + and config_schema.get(key).get("secret") + ): + raise ValueValidationException( + 'Values specified as "secret: True" in config ' + "schema are automatically decrypted by default. Use " + 'of "decrypt_kv" jinja filter is not allowed for ' + "such values. Please check the specified values in " + "the config or the default values in the schema." + ) except jsonschema.ValidationError as e: - attribute = getattr(e, 'path', []) + attribute = getattr(e, "path", []) if isinstance(attribute, (tuple, list, collections.Iterable)): attribute = [str(item) for item in attribute] - attribute = '.'.join(attribute) + attribute = ".".join(attribute) else: attribute = str(attribute) - msg = ('Failed validating attribute "%s" in config for pack "%s" (%s): %s' % - (attribute, pack_name, config_path, six.text_type(e))) + msg = 'Failed validating attribute "%s" in config for pack "%s" (%s): %s' % ( + attribute, + pack_name, + config_path, + six.text_type(e), + ) raise jsonschema.ValidationError(msg) return cleaned @@ -183,12 +205,12 @@ def get_pack_common_libs_path_for_pack_db(pack_db): :rtype: ``str`` """ - pack_dir = getattr(pack_db, 'path', None) + pack_dir = getattr(pack_db, "path", None) if not pack_dir: return None - libs_path = os.path.join(pack_dir, 'lib') + libs_path = os.path.join(pack_dir, "lib") return libs_path @@ -202,8 +224,8 @@ def normalize_pack_version(version): """ version = str(version) - version_seperator_count = version.count('.') + version_seperator_count = version.count(".") if version_seperator_count == 1: - version = version + '.0' + version = version + ".0" return version diff --git a/st2common/st2common/util/pack_management.py b/st2common/st2common/util/pack_management.py index 48b9457203..0fde5b1d86 100644 --- a/st2common/st2common/util/pack_management.py +++ b/st2common/st2common/util/pack_management.py @@ -48,29 +48,33 @@ from st2common.util.versioning import get_python_version __all__ = [ - 'download_pack', - - 'get_repo_url', - 'eval_repo_url', - - 'apply_pack_owner_group', - 'apply_pack_permissions', - - 'get_and_set_proxy_config' + "download_pack", + "get_repo_url", + "eval_repo_url", + "apply_pack_owner_group", + "apply_pack_permissions", + "get_and_set_proxy_config", ] LOG = logging.getLogger(__name__) -CONFIG_FILE = 'config.yaml' +CONFIG_FILE = "config.yaml" CURRENT_STACKSTORM_VERSION = get_stackstorm_version() CURRENT_PYTHON_VERSION = get_python_version() -SUDO_BINARY = find_executable('sudo') +SUDO_BINARY = find_executable("sudo") -def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, force=False, - proxy_config=None, force_owner_group=True, force_permissions=True, - logger=LOG): +def download_pack( + pack, + abs_repo_base="/opt/stackstorm/packs", + verify_ssl=True, + force=False, + proxy_config=None, + force_owner_group=True, + force_permissions=True, + logger=LOG, +): """ Download the pack and move it to /opt/stackstorm/packs. @@ -105,11 +109,11 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, result = [pack_url, None, None] temp_dir_name = hashlib.md5(pack_url.encode()).hexdigest() - lock_file = LockFile('/tmp/%s' % (temp_dir_name)) + lock_file = LockFile("/tmp/%s" % (temp_dir_name)) lock_file_path = lock_file.lock_file if force: - logger.debug('Force mode is enabled, deleting lock file...') + logger.debug("Force mode is enabled, deleting lock file...") try: os.unlink(lock_file_path) @@ -119,31 +123,42 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, with lock_file: try: - user_home = os.path.expanduser('~') + user_home = os.path.expanduser("~") abs_local_path = os.path.join(user_home, temp_dir_name) - if pack_url.startswith('file://'): + if pack_url.startswith("file://"): # Local pack - local_pack_directory = os.path.abspath(os.path.join(pack_url.split('file://')[1])) + local_pack_directory = os.path.abspath( + os.path.join(pack_url.split("file://")[1]) + ) else: local_pack_directory = None # If it's a local pack which is not a git repository, just copy the directory content # over if local_pack_directory and not os.path.isdir( - os.path.join(local_pack_directory, '.git')): + os.path.join(local_pack_directory, ".git") + ): if not os.path.isdir(local_pack_directory): - raise ValueError('Local pack directory "%s" doesn\'t exist' % - (local_pack_directory)) + raise ValueError( + 'Local pack directory "%s" doesn\'t exist' + % (local_pack_directory) + ) - logger.debug('Detected local pack directory which is not a git repository, just ' - 'copying files over...') + logger.debug( + "Detected local pack directory which is not a git repository, just " + "copying files over..." + ) shutil.copytree(local_pack_directory, abs_local_path) else: # 1. Clone / download the repo - clone_repo(temp_dir=abs_local_path, repo_url=pack_url, verify_ssl=verify_ssl, - ref=pack_version) + clone_repo( + temp_dir=abs_local_path, + repo_url=pack_url, + verify_ssl=verify_ssl, + ref=pack_version, + ) pack_metadata = get_pack_metadata(pack_dir=abs_local_path) pack_ref = get_pack_ref(pack_dir=abs_local_path) @@ -154,12 +169,15 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, verify_pack_version(pack_metadata=pack_metadata) # 3. Move pack to the final location - move_result = move_pack(abs_repo_base=abs_repo_base, pack_name=pack_ref, - abs_local_path=abs_local_path, - pack_metadata=pack_metadata, - force_owner_group=force_owner_group, - force_permissions=force_permissions, - logger=logger) + move_result = move_pack( + abs_repo_base=abs_repo_base, + pack_name=pack_ref, + abs_local_path=abs_local_path, + pack_metadata=pack_metadata, + force_owner_group=force_owner_group, + force_permissions=force_permissions, + logger=logger, + ) result[2] = move_result finally: cleanup_repo(abs_local_path=abs_local_path) @@ -167,21 +185,21 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, return tuple(result) -def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): +def clone_repo(temp_dir, repo_url, verify_ssl=True, ref="master"): # Switch to non-interactive mode - os.environ['GIT_TERMINAL_PROMPT'] = '0' - os.environ['GIT_ASKPASS'] = '/bin/echo' + os.environ["GIT_TERMINAL_PROMPT"] = "0" + os.environ["GIT_ASKPASS"] = "/bin/echo" # Disable SSL cert checking if explictly asked if not verify_ssl: - os.environ['GIT_SSL_NO_VERIFY'] = 'true' + os.environ["GIT_SSL_NO_VERIFY"] = "true" # Clone the repo from git; we don't use shallow copying # because we want the user to work with the repo in the # future. repo = Repo.clone_from(repo_url, temp_dir) - is_local_repo = repo_url.startswith('file://') + is_local_repo = repo_url.startswith("file://") try: active_branch = repo.active_branch @@ -194,18 +212,20 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): # Special case for local git repos - we allow users to install from repos which are checked out # at a specific commit (aka detached HEAD) if is_local_repo and not active_branch and not ref: - LOG.debug('Installing pack from git repo on disk, skipping branch checkout') + LOG.debug("Installing pack from git repo on disk, skipping branch checkout") return temp_dir use_branch = False # Special case when a default repo branch is not "master" # No ref provided so we just use a default active branch - if (not ref or ref == active_branch.name) and repo.active_branch.object == repo.head.commit: + if ( + not ref or ref == active_branch.name + ) and repo.active_branch.object == repo.head.commit: gitref = repo.active_branch.object else: # Try to match the reference to a branch name (i.e. "master") - gitref = get_gitref(repo, 'origin/%s' % ref) + gitref = get_gitref(repo, "origin/%s" % ref) if gitref: use_branch = True @@ -215,7 +235,7 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): # Try to match the reference to a "vX.Y.Z" tag if not gitref and re.match(PACK_VERSION_REGEX, ref): - gitref = get_gitref(repo, 'v%s' % ref) + gitref = get_gitref(repo, "v%s" % ref) # Giving up ¯\_(ツ)_/¯ if not gitref: @@ -224,43 +244,52 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): valid_versions = get_valid_versions_for_repo(repo=repo) if len(valid_versions) >= 1: - valid_versions_string = ', '.join(valid_versions) + valid_versions_string = ", ".join(valid_versions) - msg += ' Available versions are: %s.' + msg += " Available versions are: %s." format_values.append(valid_versions_string) raise ValueError(msg % tuple(format_values)) # We're trying to figure out which branch the ref is actually on, # since there's no direct way to check for this in git-python. - branches = repo.git.branch('-a', '--contains', gitref.hexsha) # pylint: disable=no-member + branches = repo.git.branch( + "-a", "--contains", gitref.hexsha + ) # pylint: disable=no-member # Git tags aren't necessarily on a branch. # If this is the case, gitref will be the tag name, but branches will be # empty. # We also need to checkout tags slightly differently than branches. if branches: - branches = branches.replace('*', '').split() + branches = branches.replace("*", "").split() if active_branch.name not in branches or use_branch: - branch = 'origin/%s' % ref if use_branch else branches[0] - short_branch = ref if use_branch else branches[0].split('/')[-1] - repo.git.checkout('-b', short_branch, branch) + branch = "origin/%s" % ref if use_branch else branches[0] + short_branch = ref if use_branch else branches[0].split("/")[-1] + repo.git.checkout("-b", short_branch, branch) branch = repo.head.reference else: branch = repo.active_branch.name repo.git.checkout(gitref.hexsha) # pylint: disable=no-member - repo.git.branch('-f', branch, gitref.hexsha) # pylint: disable=no-member + repo.git.branch("-f", branch, gitref.hexsha) # pylint: disable=no-member repo.git.checkout(branch) else: - repo.git.checkout('v%s' % ref) # pylint: disable=no-member + repo.git.checkout("v%s" % ref) # pylint: disable=no-member return temp_dir -def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_owner_group=True, - force_permissions=True, logger=LOG): +def move_pack( + abs_repo_base, + pack_name, + abs_local_path, + pack_metadata, + force_owner_group=True, + force_permissions=True, + logger=LOG, +): """ Move pack directory into the final location. """ @@ -270,8 +299,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own to = abs_repo_base dest_pack_path = os.path.join(abs_repo_base, pack_name) if os.path.exists(dest_pack_path): - logger.debug('Removing existing pack %s in %s to replace.', pack_name, - dest_pack_path) + logger.debug( + "Removing existing pack %s in %s to replace.", pack_name, dest_pack_path + ) # Ensure to preserve any existing configuration old_config_file = os.path.join(dest_pack_path, CONFIG_FILE) @@ -282,7 +312,7 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own shutil.rmtree(dest_pack_path) - logger.debug('Moving pack from %s to %s.', abs_local_path, to) + logger.debug("Moving pack from %s to %s.", abs_local_path, to) shutil.move(abs_local_path, dest_pack_path) # post move fix all permissions @@ -299,9 +329,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own if warning: logger.warning(warning) - message = 'Success.' + message = "Success." elif message: - message = 'Failure : %s' % message + message = "Failure : %s" % message return (desired, message) @@ -316,20 +346,25 @@ def apply_pack_owner_group(pack_path): pack_group = utils.get_pack_group() if pack_group: - LOG.debug('Changing owner group of "{}" directory to {}'.format(pack_path, pack_group)) + LOG.debug( + 'Changing owner group of "{}" directory to {}'.format(pack_path, pack_group) + ) if SUDO_BINARY: - args = ['sudo', 'chgrp', '-R', pack_group, pack_path] + args = ["sudo", "chgrp", "-R", pack_group, pack_path] else: # Environments where sudo is not available (e.g. docker) - args = ['chgrp', '-R', pack_group, pack_path] + args = ["chgrp", "-R", pack_group, pack_path] exit_code, _, stderr, _ = shell.run_command(args) if exit_code != 0: # Non fatal, but we still log it - LOG.debug('Failed to change owner group on directory "{}" to "{}": {}' - .format(pack_path, pack_group, stderr)) + LOG.debug( + 'Failed to change owner group on directory "{}" to "{}": {}'.format( + pack_path, pack_group, stderr + ) + ) return True @@ -370,13 +405,13 @@ def get_repo_url(pack, proxy_config=None): name_or_url = pack_and_version[0] version = pack_and_version[1] if len(pack_and_version) > 1 else None - if len(name_or_url.split('/')) == 1: + if len(name_or_url.split("/")) == 1: pack = get_pack_from_index(name_or_url, proxy_config=proxy_config) if not pack: raise Exception('No record of the "%s" pack in the index.' % (name_or_url)) - return (pack['repo_url'], version or pack['version']) + return (pack["repo_url"], version or pack["version"]) else: return (eval_repo_url(name_or_url), version) @@ -386,12 +421,12 @@ def eval_repo_url(repo_url): Allow passing short GitHub or GitLab SSH style URLs. """ if not repo_url: - raise Exception('No valid repo_url provided or could be inferred.') + raise Exception("No valid repo_url provided or could be inferred.") if repo_url.startswith("gitlab@") or repo_url.startswith("file://"): return repo_url else: - if len(repo_url.split('/')) == 2 and 'git@' not in repo_url: - url = 'https://github.com/{}'.format(repo_url) + if len(repo_url.split("/")) == 2 and "git@" not in repo_url: + url = "https://github.com/{}".format(repo_url) else: url = repo_url return url @@ -400,50 +435,65 @@ def eval_repo_url(repo_url): def is_desired_pack(abs_pack_path, pack_name): # path has to exist. if not os.path.exists(abs_pack_path): - return (False, 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' % - (pack_name)) + return ( + False, + 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' % (pack_name), + ) # should not include reserved characters for character in PACK_RESERVED_CHARACTERS: if character in pack_name: - return (False, 'Pack name "%s" contains reserved character "%s"' % - (pack_name, character)) + return ( + False, + 'Pack name "%s" contains reserved character "%s"' + % (pack_name, character), + ) # must contain a manifest file. Empty file is ok for now. if not os.path.isfile(os.path.join(abs_pack_path, MANIFEST_FILE_NAME)): - return (False, 'Pack is missing a manifest file (%s).' % (MANIFEST_FILE_NAME)) + return (False, "Pack is missing a manifest file (%s)." % (MANIFEST_FILE_NAME)) - return (True, '') + return (True, "") def verify_pack_version(pack_metadata): """ Verify that the pack works with the currently running StackStorm version. """ - pack_name = pack_metadata.get('name', None) - required_stackstorm_version = pack_metadata.get('stackstorm_version', None) - supported_python_versions = pack_metadata.get('python_versions', None) + pack_name = pack_metadata.get("name", None) + required_stackstorm_version = pack_metadata.get("stackstorm_version", None) + supported_python_versions = pack_metadata.get("python_versions", None) # If stackstorm_version attribute is specified, verify that the pack works with currently # running version of StackStorm if required_stackstorm_version: - if not complex_semver_match(CURRENT_STACKSTORM_VERSION, required_stackstorm_version): - msg = ('Pack "%s" requires StackStorm "%s", but current version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % - (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION)) + if not complex_semver_match( + CURRENT_STACKSTORM_VERSION, required_stackstorm_version + ): + msg = ( + 'Pack "%s" requires StackStorm "%s", but current version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION) + ) raise ValueError(msg) if supported_python_versions: - if set(supported_python_versions) == set(['2']) and (not six.PY2): - msg = ('Pack "%s" requires Python 2.x, but current Python version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION)) + if set(supported_python_versions) == set(["2"]) and (not six.PY2): + msg = ( + 'Pack "%s" requires Python 2.x, but current Python version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, CURRENT_PYTHON_VERSION) + ) raise ValueError(msg) - elif set(supported_python_versions) == set(['3']) and (not six.PY3): - msg = ('Pack "%s" requires Python 3.x, but current Python version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION)) + elif set(supported_python_versions) == set(["3"]) and (not six.PY3): + msg = ( + 'Pack "%s" requires Python 3.x, but current Python version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, CURRENT_PYTHON_VERSION) + ) raise ValueError(msg) else: # Pack support Python 2.x and 3.x so no check is needed, or @@ -474,7 +524,7 @@ def get_valid_versions_for_repo(repo): valid_versions = [] for tag in repo.tags: - if tag.name.startswith('v') and re.match(PACK_VERSION_REGEX, tag.name[1:]): + if tag.name.startswith("v") and re.match(PACK_VERSION_REGEX, tag.name[1:]): # Note: We strip leading "v" from the version number valid_versions.append(tag.name[1:]) @@ -486,39 +536,38 @@ def get_pack_ref(pack_dir): Read pack reference from the metadata file and sanitize it. """ metadata = get_pack_metadata(pack_dir=pack_dir) - pack_ref = get_pack_ref_from_metadata(metadata=metadata, - pack_directory_name=None) + pack_ref = get_pack_ref_from_metadata(metadata=metadata, pack_directory_name=None) return pack_ref def get_and_set_proxy_config(): - https_proxy = os.environ.get('https_proxy', None) - http_proxy = os.environ.get('http_proxy', None) - proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None) - no_proxy = os.environ.get('no_proxy', None) + https_proxy = os.environ.get("https_proxy", None) + http_proxy = os.environ.get("http_proxy", None) + proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None) + no_proxy = os.environ.get("no_proxy", None) proxy_config = {} if http_proxy or https_proxy: - LOG.debug('Using proxy %s', http_proxy if http_proxy else https_proxy) + LOG.debug("Using proxy %s", http_proxy if http_proxy else https_proxy) proxy_config = { - 'https_proxy': https_proxy, - 'http_proxy': http_proxy, - 'proxy_ca_bundle_path': proxy_ca_bundle_path, - 'no_proxy': no_proxy + "https_proxy": https_proxy, + "http_proxy": http_proxy, + "proxy_ca_bundle_path": proxy_ca_bundle_path, + "no_proxy": no_proxy, } - if https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = https_proxy + if https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = https_proxy - if http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = http_proxy + if http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = http_proxy - if no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = no_proxy + if no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = no_proxy - if proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = no_proxy + if proxy_ca_bundle_path and not os.environ.get("proxy_ca_bundle_path", None): + os.environ["no_proxy"] = no_proxy return proxy_config diff --git a/st2common/st2common/util/param.py b/st2common/st2common/util/param.py index 93507fcd87..270c90e424 100644 --- a/st2common/st2common/util/param.py +++ b/st2common/st2common/util/param.py @@ -26,7 +26,11 @@ from st2common.util.jinja import is_jinja_expression from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX from st2common.constants.pack import PACK_CONFIG_CONTEXT_KV_PREFIX -from st2common.constants.keyvalue import DATASTORE_PARENT_SCOPE, SYSTEM_SCOPE, FULL_SYSTEM_SCOPE +from st2common.constants.keyvalue import ( + DATASTORE_PARENT_SCOPE, + SYSTEM_SCOPE, + FULL_SYSTEM_SCOPE, +) from st2common.constants.keyvalue import USER_SCOPE, FULL_USER_SCOPE from st2common.exceptions.param import ParamException from st2common.services.keyvalues import KeyValueLookup, UserKeyValueLookup @@ -39,23 +43,27 @@ ENV = jinja_utils.get_jinja_environment() __all__ = [ - 'render_live_params', - 'render_final_params', + "render_live_params", + "render_final_params", ] def _split_params(runner_parameters, action_parameters, mixed_params): def pf(params, skips): - result = {k: v for k, v in six.iteritems(mixed_params) - if k in params and k not in skips} + result = { + k: v + for k, v in six.iteritems(mixed_params) + if k in params and k not in skips + } return result + return (pf(runner_parameters, {}), pf(action_parameters, runner_parameters)) def _cast_params(rendered, parameter_schemas): - ''' + """ It's just here to make tests happy - ''' + """ casted_params = {} for k, v in six.iteritems(rendered): casted_params[k] = _cast(v, parameter_schemas[k] or {}) @@ -66,7 +74,7 @@ def _cast(v, parameter_schema): if v is None or not parameter_schema: return v - parameter_type = parameter_schema.get('type', None) + parameter_type = parameter_schema.get("type", None) if not parameter_type: return v @@ -78,23 +86,27 @@ def _cast(v, parameter_schema): def _create_graph(action_context, config): - ''' + """ Creates a generic directed graph for depencency tree and fills it with basic context variables - ''' + """ G = nx.DiGraph() system_keyvalue_context = {SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)} # If both 'user' and 'api_user' are specified, this prioritize 'api_user' - user = action_context['user'] if 'user' in action_context else None - user = action_context['api_user'] if 'api_user' in action_context else user + user = action_context["user"] if "user" in action_context else None + user = action_context["api_user"] if "api_user" in action_context else user if not user: # When no user is not specified, this selects system-user's scope by default. user = cfg.CONF.system_user.user - LOG.info('Unable to retrieve user / api_user value from action_context. Falling back ' - 'to and using system_user (%s).' % (user)) + LOG.info( + "Unable to retrieve user / api_user value from action_context. Falling back " + "to and using system_user (%s)." % (user) + ) - system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup(scope=FULL_USER_SCOPE, user=user) + system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup( + scope=FULL_USER_SCOPE, user=user + ) G.add_node(DATASTORE_PARENT_SCOPE, value=system_keyvalue_context) G.add_node(ACTION_CONTEXT_KV_PREFIX, value=action_context) G.add_node(PACK_CONFIG_CONTEXT_KV_PREFIX, value=config) @@ -102,9 +114,9 @@ def _create_graph(action_context, config): def _process(G, name, value): - ''' + """ Determines whether parameter is a template or a value. Adds graph nodes and edges accordingly. - ''' + """ # Jinja defaults to ascii parser in python 2.x unless you set utf-8 support on per module level # Instead we're just assuming every string to be a unicode string if isinstance(value, str): @@ -114,23 +126,21 @@ def _process(G, name, value): if isinstance(value, list) or isinstance(value, dict): complex_value_str = str(value) - is_jinja_expr = ( - jinja_utils.is_jinja_expression(value) or jinja_utils.is_jinja_expression( - complex_value_str - ) - ) + is_jinja_expr = jinja_utils.is_jinja_expression( + value + ) or jinja_utils.is_jinja_expression(complex_value_str) if is_jinja_expr: G.add_node(name, template=value) template_ast = ENV.parse(value) - LOG.debug('Template ast: %s', template_ast) + LOG.debug("Template ast: %s", template_ast) # Dependencies of the node represent jinja variables used in the template # We're connecting nodes with an edge for every depencency to traverse them # in the right order and also make sure that we don't have missing or cyclic # dependencies upfront. dependencies = meta.find_undeclared_variables(template_ast) - LOG.debug('Dependencies: %s', dependencies) + LOG.debug("Dependencies: %s", dependencies) if dependencies: for dependency in dependencies: G.add_edge(dependency, name) @@ -139,24 +149,24 @@ def _process(G, name, value): def _process_defaults(G, schemas): - ''' + """ Process dependencies for parameters default values in the order schemas are defined. - ''' + """ for schema in schemas: for name, value in six.iteritems(schema): absent = name not in G.node - is_none = G.node.get(name, {}).get('value') is None - immutable = value.get('immutable', False) + is_none = G.node.get(name, {}).get("value") is None + immutable = value.get("immutable", False) if absent or is_none or immutable: - _process(G, name, value.get('default')) + _process(G, name, value.get("default")) def _validate(G): - ''' + """ Validates dependency graph to ensure it has no missing or cyclic dependencies - ''' + """ for name in G.nodes(): - if 'value' not in G.node[name] and 'template' not in G.node[name]: + if "value" not in G.node[name] and "template" not in G.node[name]: msg = 'Dependency unsatisfied in variable "%s"' % name raise ParamException(msg) @@ -172,51 +182,52 @@ def _validate(G): variable_names.append(variable_name) - variable_names = ', '.join(sorted(variable_names)) - msg = ('Cyclic dependency found in the following variables: %s. Likely the variable is ' - 'referencing itself' % (variable_names)) + variable_names = ", ".join(sorted(variable_names)) + msg = ( + "Cyclic dependency found in the following variables: %s. Likely the variable is " + "referencing itself" % (variable_names) + ) raise ParamException(msg) def _render(node, render_context): - ''' + """ Render the node depending on its type - ''' - if 'template' in node: + """ + if "template" in node: complex_type = False - if isinstance(node['template'], list) or isinstance(node['template'], dict): - node['template'] = json.dumps(node['template']) + if isinstance(node["template"], list) or isinstance(node["template"], dict): + node["template"] = json.dumps(node["template"]) # Finds occurrences of "{{variable}}" and adds `to_complex` filter # so types are honored. If it doesn't follow that syntax then it's # rendered as a string. - node['template'] = re.sub( - r'"{{([A-z0-9_-]+)}}"', r'{{\1 | to_complex}}', - node['template'] + node["template"] = re.sub( + r'"{{([A-z0-9_-]+)}}"', r"{{\1 | to_complex}}", node["template"] ) - LOG.debug('Rendering complex type: %s', node['template']) + LOG.debug("Rendering complex type: %s", node["template"]) complex_type = True - LOG.debug('Rendering node: %s with context: %s', node, render_context) + LOG.debug("Rendering node: %s with context: %s", node, render_context) - result = ENV.from_string(str(node['template'])).render(render_context) + result = ENV.from_string(str(node["template"])).render(render_context) - LOG.debug('Render complete: %s', result) + LOG.debug("Render complete: %s", result) if complex_type: result = json.loads(result) - LOG.debug('Complex Type Rendered: %s', result) + LOG.debug("Complex Type Rendered: %s", result) return result - if 'value' in node: - return node['value'] + if "value" in node: + return node["value"] def _resolve_dependencies(G): - ''' + """ Traverse the dependency graph starting from resolved nodes - ''' + """ context = {} for name in nx.topological_sort(G): node = G.node[name] @@ -224,7 +235,7 @@ def _resolve_dependencies(G): context[name] = _render(node, context) except Exception as e: - LOG.debug('Failed to render %s: %s', name, e, exc_info=True) + LOG.debug("Failed to render %s: %s", name, e, exc_info=True) msg = 'Failed to render parameter "%s": %s' % (name, six.text_type(e)) raise ParamException(msg) @@ -232,9 +243,9 @@ def _resolve_dependencies(G): def _cast_params_from(params, context, schemas): - ''' + """ Pick a list of parameters from context and cast each of them according to the schemas provided - ''' + """ result = {} # First, cast only explicitly provided live parameters @@ -258,17 +269,19 @@ def _cast_params_from(params, context, schemas): for param_name, param_details in schema.items(): # Skip if the parameter have immutable set to true in schema - if param_details.get('immutable'): + if param_details.get("immutable"): continue # Skip if the parameter doesn't have a default, or if the # value in the context is identical to the default - if 'default' not in param_details or \ - param_details.get('default') == context[param_name]: + if ( + "default" not in param_details + or param_details.get("default") == context[param_name] + ): continue # Skip if the default value isn't a Jinja expression - if not is_jinja_expression(param_details.get('default')): + if not is_jinja_expression(param_details.get("default")): continue # Skip if the parameter is being overridden @@ -280,22 +293,29 @@ def _cast_params_from(params, context, schemas): return result -def render_live_params(runner_parameters, action_parameters, params, action_context, - additional_contexts=None): - ''' +def render_live_params( + runner_parameters, + action_parameters, + params, + action_context, + additional_contexts=None, +): + """ Renders list of parameters. Ensures that there's no cyclic or missing dependencies. Returns a dict of plain rendered parameters. - ''' + """ additional_contexts = additional_contexts or {} - pack = action_context.get('pack') - user = action_context.get('user') + pack = action_context.get("pack") + user = action_context.get("user") try: config = get_config(pack, user) except Exception as e: - LOG.info('Failed to retrieve config for pack %s and user %s: %s' % (pack, user, - six.text_type(e))) + LOG.info( + "Failed to retrieve config for pack %s and user %s: %s" + % (pack, user, six.text_type(e)) + ) config = {} G = _create_graph(action_context, config) @@ -310,18 +330,20 @@ def render_live_params(runner_parameters, action_parameters, params, action_cont _validate(G) context = _resolve_dependencies(G) - live_params = _cast_params_from(params, context, [action_parameters, runner_parameters]) + live_params = _cast_params_from( + params, context, [action_parameters, runner_parameters] + ) return live_params def render_final_params(runner_parameters, action_parameters, params, action_context): - ''' + """ Renders missing parameters required for action to execute. Treats parameters from the dict as plain values instead of trying to render them again. Returns dicts for action and runner parameters. - ''' - config = get_config(action_context.get('pack'), action_context.get('user')) + """ + config = get_config(action_context.get("pack"), action_context.get("user")) G = _create_graph(action_context, config) @@ -331,18 +353,29 @@ def render_final_params(runner_parameters, action_parameters, params, action_con _validate(G) context = _resolve_dependencies(G) - context = _cast_params_from(context, context, [action_parameters, runner_parameters]) + context = _cast_params_from( + context, context, [action_parameters, runner_parameters] + ) return _split_params(runner_parameters, action_parameters, context) -def get_finalized_params(runnertype_parameter_info, action_parameter_info, liveaction_parameters, - action_context): - ''' +def get_finalized_params( + runnertype_parameter_info, + action_parameter_info, + liveaction_parameters, + action_context, +): + """ Left here to keep tests running. Later we would need to split tests so they start testing each function separately. - ''' - params = render_live_params(runnertype_parameter_info, action_parameter_info, - liveaction_parameters, action_context) - return render_final_params(runnertype_parameter_info, action_parameter_info, params, - action_context) + """ + params = render_live_params( + runnertype_parameter_info, + action_parameter_info, + liveaction_parameters, + action_context, + ) + return render_final_params( + runnertype_parameter_info, action_parameter_info, params, action_context + ) diff --git a/st2common/st2common/util/payload.py b/st2common/st2common/util/payload.py index 92b36d55c0..b2dc2a74af 100644 --- a/st2common/st2common/util/payload.py +++ b/st2common/st2common/util/payload.py @@ -22,11 +22,8 @@ class PayloadLookup(object): - def __init__(self, payload, prefix=TRIGGER_PAYLOAD_PREFIX): - self.context = { - prefix: payload - } + self.context = {prefix: payload} for system_scope in SYSTEM_SCOPES: self.context[system_scope] = KeyValueLookup(scope=system_scope) diff --git a/st2common/st2common/util/queues.py b/st2common/st2common/util/queues.py index 526692155f..9fce3b20a7 100644 --- a/st2common/st2common/util/queues.py +++ b/st2common/st2common/util/queues.py @@ -36,7 +36,7 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix :rtype: ``str`` """ if not queue_name_base: - raise ValueError('Queue name base cannot be empty.') + raise ValueError("Queue name base cannot be empty.") if not queue_name_suffix: return queue_name_base @@ -46,8 +46,8 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix # Pick last 10 digits of uuid. Arbitrary but unique enough. Long queue names # might cause issues in RabbitMQ. u_hex = uuid.uuid4().hex - uuid_suffix = uuid.uuid4().hex[len(u_hex) - 10:] - queue_suffix = '%s-%s' % (queue_name_suffix, uuid_suffix) + uuid_suffix = uuid.uuid4().hex[len(u_hex) - 10 :] + queue_suffix = "%s-%s" % (queue_name_suffix, uuid_suffix) - queue_name = '%s.%s' % (queue_name_base, queue_suffix) + queue_name = "%s.%s" % (queue_name_base, queue_suffix) return queue_name diff --git a/st2common/st2common/util/reference.py b/st2common/st2common/util/reference.py index 3262eb603f..137a014d73 100644 --- a/st2common/st2common/util/reference.py +++ b/st2common/st2common/util/reference.py @@ -20,24 +20,25 @@ def get_ref_from_model(model): if model is None: - raise ValueError('Model has None value.') - model_id = getattr(model, 'id', None) + raise ValueError("Model has None value.") + model_id = getattr(model, "id", None) if model_id is None: - raise db.StackStormDBObjectMalformedError('model %s must contain id.' % str(model)) - reference = {'id': str(model_id), - 'name': getattr(model, 'name', None)} + raise db.StackStormDBObjectMalformedError( + "model %s must contain id." % str(model) + ) + reference = {"id": str(model_id), "name": getattr(model, "name", None)} return reference def get_model_from_ref(db_api, reference): if reference is None: - raise db.StackStormDBObjectNotFoundError('No reference supplied.') - model_id = reference.get('id', None) + raise db.StackStormDBObjectNotFoundError("No reference supplied.") + model_id = reference.get("id", None) if model_id is not None: return db_api.get_by_id(model_id) - model_name = reference.get('name', None) + model_name = reference.get("name", None) if model_name is None: - raise db.StackStormDBObjectNotFoundError('Both name and id are None.') + raise db.StackStormDBObjectNotFoundError("Both name and id are None.") return db_api.get_by_name(model_name) @@ -71,8 +72,10 @@ def get_resource_ref_from_model(model): name = model.name pack = model.pack except AttributeError: - raise Exception('Cannot build ResourceReference for model: %s. Name or pack missing.' - % model) + raise Exception( + "Cannot build ResourceReference for model: %s. Name or pack missing." + % model + ) return ResourceReference(name=name, pack=pack) diff --git a/st2common/st2common/util/sandboxing.py b/st2common/st2common/util/sandboxing.py index 02871f7472..9801f7d112 100644 --- a/st2common/st2common/util/sandboxing.py +++ b/st2common/st2common/util/sandboxing.py @@ -31,11 +31,11 @@ from st2common.content.utils import get_pack_base_path __all__ = [ - 'get_sandbox_python_binary_path', - 'get_sandbox_python_path', - 'get_sandbox_python_path_for_python_action', - 'get_sandbox_path', - 'get_sandbox_virtualenv_path', + "get_sandbox_python_binary_path", + "get_sandbox_python_path", + "get_sandbox_python_path_for_python_action", + "get_sandbox_path", + "get_sandbox_virtualenv_path", ] @@ -47,13 +47,13 @@ def get_sandbox_python_binary_path(pack=None): :type pack: ``str`` """ system_base_path = cfg.CONF.system.base_path - virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack) + virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack) if pack in SYSTEM_PACK_NAMES: # Use system python for "packs" and "core" actions python_path = sys.executable else: - python_path = os.path.join(virtualenv_path, 'bin/python') + python_path = os.path.join(virtualenv_path, "bin/python") return python_path @@ -70,19 +70,19 @@ def get_sandbox_path(virtualenv_path): """ sandbox_path = [] - parent_path = os.environ.get('PATH', '') + parent_path = os.environ.get("PATH", "") if not virtualenv_path: return parent_path - parent_path = parent_path.split(':') + parent_path = parent_path.split(":") parent_path = [path for path in parent_path if path] # Add virtualenv bin directory - virtualenv_bin_path = os.path.join(virtualenv_path, 'bin/') + virtualenv_bin_path = os.path.join(virtualenv_path, "bin/") sandbox_path.append(virtualenv_bin_path) sandbox_path.extend(parent_path) - sandbox_path = ':'.join(sandbox_path) + sandbox_path = ":".join(sandbox_path) return sandbox_path @@ -104,9 +104,9 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv= :type inherit_parent_virtualenv: ``str`` """ sandbox_python_path = [] - parent_python_path = os.environ.get('PYTHONPATH', '') + parent_python_path = os.environ.get("PYTHONPATH", "") - parent_python_path = parent_python_path.split(':') + parent_python_path = parent_python_path.split(":") parent_python_path = [path for path in parent_python_path if path] if inherit_from_parent: @@ -121,13 +121,14 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv= sandbox_python_path.append(site_packages_dir) - sandbox_python_path = ':'.join(sandbox_python_path) - sandbox_python_path = ':' + sandbox_python_path + sandbox_python_path = ":".join(sandbox_python_path) + sandbox_python_path = ":" + sandbox_python_path return sandbox_python_path -def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, - inherit_parent_virtualenv=True): +def get_sandbox_python_path_for_python_action( + pack, inherit_from_parent=True, inherit_parent_virtualenv=True +): """ Return sandbox PYTHONPATH for a particular Python runner action. @@ -136,30 +137,36 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, """ sandbox_python_path = get_sandbox_python_path( inherit_from_parent=inherit_from_parent, - inherit_parent_virtualenv=inherit_parent_virtualenv) + inherit_parent_virtualenv=inherit_parent_virtualenv, + ) pack_base_path = get_pack_base_path(pack_name=pack) virtualenv_path = get_sandbox_virtualenv_path(pack=pack) if virtualenv_path and os.path.isdir(virtualenv_path): - pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib') + pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib") virtualenv_directories = os.listdir(pack_virtualenv_lib_path) - virtualenv_directories = [dir_name for dir_name in virtualenv_directories if - fnmatch.fnmatch(dir_name, 'python*')] + virtualenv_directories = [ + dir_name + for dir_name in virtualenv_directories + if fnmatch.fnmatch(dir_name, "python*") + ] # Add the pack's lib directory (lib/python3.x) in front of the PYTHONPATH. - pack_actions_lib_paths = os.path.join(pack_base_path, 'actions', 'lib') - pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib') - pack_venv_lib_directory = os.path.join(pack_virtualenv_lib_path, virtualenv_directories[0]) + pack_actions_lib_paths = os.path.join(pack_base_path, "actions", "lib") + pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib") + pack_venv_lib_directory = os.path.join( + pack_virtualenv_lib_path, virtualenv_directories[0] + ) # Add the pack's site-packages directory (lib/python3.x/site-packages) # in front of the Python system site-packages This is important because # we want Python 3 compatible libraries to be used from the pack virtual # environment and not system ones. - pack_venv_site_packages_directory = os.path.join(pack_virtualenv_lib_path, - virtualenv_directories[0], - 'site-packages') + pack_venv_site_packages_directory = os.path.join( + pack_virtualenv_lib_path, virtualenv_directories[0], "site-packages" + ) full_sandbox_python_path = [ # NOTE: Order here is very important for imports to function correctly @@ -169,7 +176,7 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, sandbox_python_path, ] - sandbox_python_path = ':'.join(full_sandbox_python_path) + sandbox_python_path = ":".join(full_sandbox_python_path) return sandbox_python_path @@ -183,7 +190,7 @@ def get_sandbox_virtualenv_path(pack): virtualenv_path = None else: system_base_path = cfg.CONF.system.base_path - virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack) + virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack) return virtualenv_path @@ -195,8 +202,9 @@ def is_in_virtualenv(): """ # sys.real_prefix is for virtualenv # sys.base_prefix != sys.prefix is for venv - return (hasattr(sys, 'real_prefix') or - (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix)) + return hasattr(sys, "real_prefix") or ( + hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix + ) def get_virtualenv_prefix(): @@ -205,10 +213,10 @@ def get_virtualenv_prefix(): where we retrieved the virtualenv prefix from. The second element is the virtualenv prefix. """ - if hasattr(sys, 'real_prefix'): - return ('sys.real_prefix', sys.real_prefix) - elif hasattr(sys, 'base_prefix'): - return ('sys.base_prefix', sys.base_prefix) + if hasattr(sys, "real_prefix"): + return ("sys.real_prefix", sys.real_prefix) + elif hasattr(sys, "base_prefix"): + return ("sys.base_prefix", sys.base_prefix) return (None, None) @@ -216,9 +224,9 @@ def set_virtualenv_prefix(prefix_tuple): """ :return: Sets the virtualenv prefix given a tuple returned from get_virtualenv_prefix() """ - if prefix_tuple[0] == 'sys.real_prefix' and hasattr(sys, 'real_prefix'): + if prefix_tuple[0] == "sys.real_prefix" and hasattr(sys, "real_prefix"): sys.real_prefix = prefix_tuple[1] - elif prefix_tuple[0] == 'sys.base_prefix' and hasattr(sys, 'base_prefix'): + elif prefix_tuple[0] == "sys.base_prefix" and hasattr(sys, "base_prefix"): sys.base_prefix = prefix_tuple[1] @@ -226,7 +234,7 @@ def clear_virtualenv_prefix(): """ :return: Unsets / removes / resets the virtualenv prefix """ - if hasattr(sys, 'real_prefix'): + if hasattr(sys, "real_prefix"): del sys.real_prefix - if hasattr(sys, 'base_prefix'): + if hasattr(sys, "base_prefix"): sys.base_prefix = sys.prefix diff --git a/st2common/st2common/util/schema/__init__.py b/st2common/st2common/util/schema/__init__.py index a49f733a5a..8e18509cd1 100644 --- a/st2common/st2common/util/schema/__init__.py +++ b/st2common/st2common/util/schema/__init__.py @@ -27,19 +27,19 @@ from st2common.util.misc import deep_update __all__ = [ - 'get_validator', - 'get_draft_schema', - 'get_action_parameters_schema', - 'get_schema_for_action_parameters', - 'get_schema_for_resource_parameters', - 'is_property_type_single', - 'is_property_type_list', - 'is_property_type_anyof', - 'is_property_type_oneof', - 'is_property_nullable', - 'is_attribute_type_array', - 'is_attribute_type_object', - 'validate' + "get_validator", + "get_draft_schema", + "get_action_parameters_schema", + "get_schema_for_action_parameters", + "get_schema_for_resource_parameters", + "is_property_type_single", + "is_property_type_list", + "is_property_type_anyof", + "is_property_type_oneof", + "is_property_nullable", + "is_attribute_type_array", + "is_attribute_type_object", + "validate", ] # https://github.com/json-schema/json-schema/blob/master/draft-04/schema @@ -49,12 +49,13 @@ # and draft 3 version of required. PATH = os.path.join(os.path.dirname(os.path.realpath(__file__))) SCHEMAS = { - 'draft4': jsonify.load_file(os.path.join(PATH, 'draft4.json')), - 'custom': jsonify.load_file(os.path.join(PATH, 'custom.json')), - + "draft4": jsonify.load_file(os.path.join(PATH, "draft4.json")), + "custom": jsonify.load_file(os.path.join(PATH, "custom.json")), # Custom schema for action params which doesn't allow parameter "type" attribute to be array - 'action_params': jsonify.load_file(os.path.join(PATH, 'action_params.json')), - 'action_output_schema': jsonify.load_file(os.path.join(PATH, 'action_output_schema.json')) + "action_params": jsonify.load_file(os.path.join(PATH, "action_params.json")), + "action_output_schema": jsonify.load_file( + os.path.join(PATH, "action_output_schema.json") + ), } SCHEMA_ANY_TYPE = { @@ -64,23 +65,23 @@ {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } RUNNER_PARAM_OVERRIDABLE_ATTRS = [ - 'default', - 'description', - 'enum', - 'immutable', - 'required' + "default", + "description", + "enum", + "immutable", + "required", ] -def get_draft_schema(version='custom', additional_properties=False): +def get_draft_schema(version="custom", additional_properties=False): schema = copy.deepcopy(SCHEMAS[version]) - if additional_properties and 'additionalProperties' in schema: - del schema['additionalProperties'] + if additional_properties and "additionalProperties" in schema: + del schema["additionalProperties"] return schema @@ -89,8 +90,7 @@ def get_action_output_schema(additional_properties=True): Return a generic schema which is used for validating action output. """ return get_draft_schema( - version='action_output_schema', - additional_properties=additional_properties + version="action_output_schema", additional_properties=additional_properties ) @@ -98,81 +98,100 @@ def get_action_parameters_schema(additional_properties=False): """ Return a generic schema which is used for validating action parameters definition. """ - return get_draft_schema(version='action_params', additional_properties=additional_properties) + return get_draft_schema( + version="action_params", additional_properties=additional_properties + ) CustomValidator = create( - meta_schema=get_draft_schema(version='custom', additional_properties=True), + meta_schema=get_draft_schema(version="custom", additional_properties=True), validators={ - u"$ref": _validators.ref, - u"additionalItems": _validators.additionalItems, - u"additionalProperties": _validators.additionalProperties, - u"allOf": _validators.allOf_draft4, - u"anyOf": _validators.anyOf_draft4, - u"dependencies": _validators.dependencies, - u"enum": _validators.enum, - u"format": _validators.format, - u"items": _validators.items, - u"maxItems": _validators.maxItems, - u"maxLength": _validators.maxLength, - u"maxProperties": _validators.maxProperties_draft4, - u"maximum": _validators.maximum, - u"minItems": _validators.minItems, - u"minLength": _validators.minLength, - u"minProperties": _validators.minProperties_draft4, - u"minimum": _validators.minimum, - u"multipleOf": _validators.multipleOf, - u"not": _validators.not_draft4, - u"oneOf": _validators.oneOf_draft4, - u"pattern": _validators.pattern, - u"patternProperties": _validators.patternProperties, - u"properties": _validators.properties_draft3, - u"type": _validators.type_draft4, - u"uniqueItems": _validators.uniqueItems, + "$ref": _validators.ref, + "additionalItems": _validators.additionalItems, + "additionalProperties": _validators.additionalProperties, + "allOf": _validators.allOf_draft4, + "anyOf": _validators.anyOf_draft4, + "dependencies": _validators.dependencies, + "enum": _validators.enum, + "format": _validators.format, + "items": _validators.items, + "maxItems": _validators.maxItems, + "maxLength": _validators.maxLength, + "maxProperties": _validators.maxProperties_draft4, + "maximum": _validators.maximum, + "minItems": _validators.minItems, + "minLength": _validators.minLength, + "minProperties": _validators.minProperties_draft4, + "minimum": _validators.minimum, + "multipleOf": _validators.multipleOf, + "not": _validators.not_draft4, + "oneOf": _validators.oneOf_draft4, + "pattern": _validators.pattern, + "patternProperties": _validators.patternProperties, + "properties": _validators.properties_draft3, + "type": _validators.type_draft4, + "uniqueItems": _validators.uniqueItems, }, version="custom_validator", ) def is_property_type_single(property_schema): - return (isinstance(property_schema, dict) and - 'anyOf' not in list(property_schema.keys()) and - 'oneOf' not in list(property_schema.keys()) and - not isinstance(property_schema.get('type', 'string'), list)) + return ( + isinstance(property_schema, dict) + and "anyOf" not in list(property_schema.keys()) + and "oneOf" not in list(property_schema.keys()) + and not isinstance(property_schema.get("type", "string"), list) + ) def is_property_type_list(property_schema): - return (isinstance(property_schema, dict) and - isinstance(property_schema.get('type', 'string'), list)) + return isinstance(property_schema, dict) and isinstance( + property_schema.get("type", "string"), list + ) def is_property_type_anyof(property_schema): - return isinstance(property_schema, dict) and 'anyOf' in list(property_schema.keys()) + return isinstance(property_schema, dict) and "anyOf" in list(property_schema.keys()) def is_property_type_oneof(property_schema): - return isinstance(property_schema, dict) and 'oneOf' in list(property_schema.keys()) + return isinstance(property_schema, dict) and "oneOf" in list(property_schema.keys()) def is_property_nullable(property_type_schema): # For anyOf and oneOf, the property_schema is a list of types. if isinstance(property_type_schema, list): - return len([t for t in property_type_schema - if ((isinstance(t, six.string_types) and t == 'null') or - (isinstance(t, dict) and t.get('type', 'string') == 'null'))]) > 0 - - return (isinstance(property_type_schema, dict) and - property_type_schema.get('type', 'string') == 'null') + return ( + len( + [ + t + for t in property_type_schema + if ( + (isinstance(t, six.string_types) and t == "null") + or (isinstance(t, dict) and t.get("type", "string") == "null") + ) + ] + ) + > 0 + ) + + return ( + isinstance(property_type_schema, dict) + and property_type_schema.get("type", "string") == "null" + ) def is_attribute_type_array(attribute_type): - return (attribute_type == 'array' or - (isinstance(attribute_type, list) and 'array' in attribute_type)) + return attribute_type == "array" or ( + isinstance(attribute_type, list) and "array" in attribute_type + ) def is_attribute_type_object(attribute_type): - return (attribute_type == 'object' or - (isinstance(attribute_type, list) and 'object' in attribute_type)) + return attribute_type == "object" or ( + isinstance(attribute_type, list) and "object" in attribute_type + ) def assign_default_values(instance, schema): @@ -186,11 +205,11 @@ def assign_default_values(instance, schema): if not instance_is_dict and not instance_is_array: return instance - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - has_default_value = 'default' in property_data - default_value = property_data.get('default', None) + has_default_value = "default" in property_data + default_value = property_data.get("default", None) # Assign default value on the instance so the validation doesn't fail if requires is true # but the value is not provided @@ -203,29 +222,36 @@ def assign_default_values(instance, schema): instance[index][property_name] = default_value # Support for nested properties (array and object) - attribute_type = property_data.get('type', None) - schema_items = property_data.get('items', {}) + attribute_type = property_data.get("type", None) + schema_items = property_data.get("items", {}) # Array - if (is_attribute_type_array(attribute_type) and - schema_items and schema_items.get('properties', {})): + if ( + is_attribute_type_array(attribute_type) + and schema_items + and schema_items.get("properties", {}) + ): array_instance = instance.get(property_name, None) - array_schema = schema['properties'][property_name]['items'] + array_schema = schema["properties"][property_name]["items"] if array_instance is not None: # Note: We don't perform subschema assignment if no value is provided - instance[property_name] = assign_default_values(instance=array_instance, - schema=array_schema) + instance[property_name] = assign_default_values( + instance=array_instance, schema=array_schema + ) # Object - if is_attribute_type_object(attribute_type) and property_data.get('properties', {}): + if is_attribute_type_object(attribute_type) and property_data.get( + "properties", {} + ): object_instance = instance.get(property_name, None) - object_schema = schema['properties'][property_name] + object_schema = schema["properties"][property_name] if object_instance is not None: # Note: We don't perform subschema assignment if no value is provided - instance[property_name] = assign_default_values(instance=object_instance, - schema=object_schema) + instance[property_name] = assign_default_values( + instance=object_instance, schema=object_schema + ) return instance @@ -236,51 +262,70 @@ def modify_schema_allow_default_none(schema): defines a default value of None. """ schema = copy.deepcopy(schema) - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - is_optional = not property_data.get('required', False) - has_default_value = 'default' in property_data - default_value = property_data.get('default', None) - property_schema = schema['properties'][property_name] + is_optional = not property_data.get("required", False) + has_default_value = "default" in property_data + default_value = property_data.get("default", None) + property_schema = schema["properties"][property_name] if (has_default_value or is_optional) and default_value is None: # If property is anyOf and oneOf then it has to be process differently. - if (is_property_type_anyof(property_schema) and - not is_property_nullable(property_schema['anyOf'])): - property_schema['anyOf'].append({'type': 'null'}) - elif (is_property_type_oneof(property_schema) and - not is_property_nullable(property_schema['oneOf'])): - property_schema['oneOf'].append({'type': 'null'}) - elif (is_property_type_list(property_schema) and - not is_property_nullable(property_schema.get('type'))): - property_schema['type'].append('null') - elif (is_property_type_single(property_schema) and - not is_property_nullable(property_schema.get('type'))): - property_schema['type'] = [property_schema.get('type', 'string'), 'null'] + if is_property_type_anyof(property_schema) and not is_property_nullable( + property_schema["anyOf"] + ): + property_schema["anyOf"].append({"type": "null"}) + elif is_property_type_oneof(property_schema) and not is_property_nullable( + property_schema["oneOf"] + ): + property_schema["oneOf"].append({"type": "null"}) + elif is_property_type_list(property_schema) and not is_property_nullable( + property_schema.get("type") + ): + property_schema["type"].append("null") + elif is_property_type_single(property_schema) and not is_property_nullable( + property_schema.get("type") + ): + property_schema["type"] = [ + property_schema.get("type", "string"), + "null", + ] # Support for nested properties (array and object) - attribute_type = property_data.get('type', None) - schema_items = property_data.get('items', {}) + attribute_type = property_data.get("type", None) + schema_items = property_data.get("items", {}) # Array - if (is_attribute_type_array(attribute_type) and - schema_items and schema_items.get('properties', {})): + if ( + is_attribute_type_array(attribute_type) + and schema_items + and schema_items.get("properties", {}) + ): array_schema = schema_items array_schema = modify_schema_allow_default_none(schema=array_schema) - schema['properties'][property_name]['items'] = array_schema + schema["properties"][property_name]["items"] = array_schema # Object - if is_attribute_type_object(attribute_type) and property_data.get('properties', {}): + if is_attribute_type_object(attribute_type) and property_data.get( + "properties", {} + ): object_schema = property_data object_schema = modify_schema_allow_default_none(schema=object_schema) - schema['properties'][property_name] = object_schema + schema["properties"][property_name] = object_schema return schema -def validate(instance, schema, cls=None, use_default=True, allow_default_none=False, *args, - **kwargs): +def validate( + instance, + schema, + cls=None, + use_default=True, + allow_default_none=False, + *args, + **kwargs, +): """ Custom validate function which supports default arguments combined with the "required" property. @@ -292,13 +337,13 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa """ instance = copy.deepcopy(instance) - schema_type = schema.get('type', None) + schema_type = schema.get("type", None) instance_is_dict = isinstance(instance, dict) if use_default and allow_default_none: schema = modify_schema_allow_default_none(schema=schema) - if use_default and schema_type == 'object' and instance_is_dict: + if use_default and schema_type == "object" and instance_is_dict: instance = assign_default_values(instance=instance, schema=schema) # pylint: disable=assignment-from-no-return @@ -307,28 +352,30 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa return instance -VALIDATORS = { - 'draft4': jsonschema.Draft4Validator, - 'custom': CustomValidator -} +VALIDATORS = {"draft4": jsonschema.Draft4Validator, "custom": CustomValidator} -def get_validator(version='custom'): +def get_validator(version="custom"): validator = VALIDATORS[version] return validator -def validate_runner_parameter_attribute_override(action_ref, param_name, attr_name, - runner_param_attr_value, action_param_attr_value): +def validate_runner_parameter_attribute_override( + action_ref, param_name, attr_name, runner_param_attr_value, action_param_attr_value +): """ Validate that the provided parameter from the action schema can override the runner parameter. """ param_values_are_the_same = action_param_attr_value == runner_param_attr_value - if (attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS and not param_values_are_the_same): + if ( + attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS + and not param_values_are_the_same + ): raise InvalidActionParameterException( 'The attribute "%s" for the runner parameter "%s" in action "%s" ' - 'cannot be overridden.' % (attr_name, param_name, action_ref)) + "cannot be overridden." % (attr_name, param_name, action_ref) + ) return True @@ -341,7 +388,8 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None): """ if not runnertype_db: from st2common.util.action_db import get_runnertype_by_name - runnertype_db = get_runnertype_by_name(action_db.runner_type['name']) + + runnertype_db = get_runnertype_by_name(action_db.runner_type["name"]) # Note: We need to perform a deep merge because user can only specify a single parameter # attribute when overriding it in an action metadata. @@ -359,26 +407,31 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None): for attribute, value in six.iteritems(schema): runner_param_value = runnertype_db.runner_parameters[name].get(attribute) - validate_runner_parameter_attribute_override(action_ref=action_db.ref, - param_name=name, - attr_name=attribute, - runner_param_attr_value=runner_param_value, - action_param_attr_value=value) + validate_runner_parameter_attribute_override( + action_ref=action_db.ref, + param_name=name, + attr_name=attribute, + runner_param_attr_value=runner_param_value, + action_param_attr_value=value, + ) schema = get_schema_for_resource_parameters(parameters_schema=parameters_schema) if parameters_schema: - schema['title'] = action_db.name + schema["title"] = action_db.name if action_db.description: - schema['description'] = action_db.description + schema["description"] = action_db.description return schema -def get_schema_for_resource_parameters(parameters_schema, allow_additional_properties=False): +def get_schema_for_resource_parameters( + parameters_schema, allow_additional_properties=False +): """ Dynamically construct JSON schema for the provided resource from the parameters metadata. """ + def normalize(x): return {k: v if v else SCHEMA_ANY_TYPE for k, v in six.iteritems(x)} @@ -386,8 +439,8 @@ def normalize(x): properties = {} properties.update(normalize(parameters_schema)) if properties: - schema['type'] = 'object' - schema['properties'] = properties - schema['additionalProperties'] = allow_additional_properties + schema["type"] = "object" + schema["properties"] = properties + schema["additionalProperties"] = allow_additional_properties return schema diff --git a/st2common/st2common/util/secrets.py b/st2common/st2common/util/secrets.py index 2945ef0594..b863a93a61 100644 --- a/st2common/st2common/util/secrets.py +++ b/st2common/st2common/util/secrets.py @@ -65,7 +65,7 @@ def get_secret_parameters(parameters): """ secret_parameters = {} - parameters_type = parameters.get('type') + parameters_type = parameters.get("type") # If the parameter itself is secret, then skip all processing below it # and return the type of this parameter. # @@ -74,22 +74,22 @@ def get_secret_parameters(parameters): # **Important** that we do this check first, so in case this parameter # is an `object` or `array`, and the user wants the full thing # to be secret, that it is marked as secret. - if parameters.get('secret', False): + if parameters.get("secret", False): return parameters_type iterator = None - if parameters_type == 'object': + if parameters_type == "object": # if this is an object, then iterate over the properties within # the object # result = dict - iterator = six.iteritems(parameters.get('properties', {})) - elif parameters_type == 'array': + iterator = six.iteritems(parameters.get("properties", {})) + elif parameters_type == "array": # if this is an array, then iterate over the items definition as a single # property # result = list - iterator = enumerate([parameters.get('items', {})]) + iterator = enumerate([parameters.get("items", {})]) secret_parameters = [] - elif parameters_type in ['integer', 'number', 'boolean', 'null', 'string']: + elif parameters_type in ["integer", "number", "boolean", "null", "string"]: # if this a "plain old datatype", then iterate over the properties set # of the data type # result = string (property type) @@ -105,8 +105,8 @@ def get_secret_parameters(parameters): if not isinstance(options, dict): continue - parameter_type = options.get('type') - if options.get('secret', False): + parameter_type = options.get("type") + if options.get("secret", False): # If this parameter is secret, then add it our secret parameters # # **This causes the _full_ object / array tree to be secret @@ -121,7 +121,7 @@ def get_secret_parameters(parameters): secret_parameters[parameter] = parameter_type else: return parameter_type - elif parameter_type in ['object', 'array']: + elif parameter_type in ["object", "array"]: # otherwise recursively dive into the `object`/`array` and # find individual parameters marked as secret sub_params = get_secret_parameters(options) @@ -176,15 +176,17 @@ def mask_secret_parameters(parameters, secret_parameters, result=None): for secret_param, secret_sub_params in iterator: if is_dict: if secret_param in result: - result[secret_param] = mask_secret_parameters(parameters[secret_param], - secret_sub_params, - result=result[secret_param]) + result[secret_param] = mask_secret_parameters( + parameters[secret_param], + secret_sub_params, + result=result[secret_param], + ) elif is_list: # we're assuming lists contain the same data type for every element for idx, value in enumerate(result): - result[idx] = mask_secret_parameters(parameters[idx], - secret_sub_params, - result=result[idx]) + result[idx] = mask_secret_parameters( + parameters[idx], secret_sub_params, result=result[idx] + ) else: result[secret_param] = MASKED_ATTRIBUTE_VALUE @@ -204,8 +206,8 @@ def mask_inquiry_response(response, schema): """ result = fast_deepcopy(response) - for prop_name, prop_attrs in schema['properties'].items(): - if prop_attrs.get('secret') is True: + for prop_name, prop_attrs in schema["properties"].items(): + if prop_attrs.get("secret") is True: if prop_name in response: result[prop_name] = MASKED_ATTRIBUTE_VALUE diff --git a/st2common/st2common/util/service.py b/st2common/st2common/util/service.py index 6691e50268..e3c2dcb9f9 100644 --- a/st2common/st2common/util/service.py +++ b/st2common/st2common/util/service.py @@ -24,13 +24,13 @@ def retry_on_exceptions(exc): - LOG.warning('Evaluating retry on exception %s. %s', type(exc), str(exc)) + LOG.warning("Evaluating retry on exception %s. %s", type(exc), str(exc)) is_mongo_connection_error = isinstance(exc, pymongo.errors.ConnectionFailure) retrying = is_mongo_connection_error if retrying: - LOG.warning('Retrying on exception %s.', type(exc)) + LOG.warning("Retrying on exception %s.", type(exc)) return retrying diff --git a/st2common/st2common/util/shell.py b/st2common/st2common/util/shell.py index 5c4217594a..945ec39a5a 100644 --- a/st2common/st2common/util/shell.py +++ b/st2common/st2common/util/shell.py @@ -30,13 +30,7 @@ # subprocess functionality and run_command subprocess = concurrency.get_subprocess_module() -__all__ = [ - 'run_command', - 'kill_process', - - 'quote_unix', - 'quote_windows' -] +__all__ = ["run_command", "kill_process", "quote_unix", "quote_windows"] LOG = logging.getLogger(__name__) @@ -45,8 +39,15 @@ # pylint: disable=too-many-function-args -def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, - cwd=None, env=None): +def run_command( + cmd, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + cwd=None, + env=None, +): """ Run the provided command in a subprocess and wait until it completes. @@ -79,8 +80,15 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, if not env: env = os.environ.copy() - process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr, - env=env, cwd=cwd, shell=shell) + process = concurrency.subprocess_popen( + args=cmd, + stdin=stdin, + stdout=stdout, + stderr=stderr, + env=env, + cwd=cwd, + shell=shell, + ) stdout, stderr = process.communicate() exit_code = process.returncode @@ -100,15 +108,17 @@ def kill_process(process): :param process: Process object as returned by subprocess.Popen. :type process: ``object`` """ - kill_command = shlex.split('sudo pkill -TERM -s %s' % (process.pid)) + kill_command = shlex.split("sudo pkill -TERM -s %s" % (process.pid)) try: if six.PY3: - status = subprocess.call(kill_command, timeout=100) # pylint: disable=not-callable + status = subprocess.call( + kill_command, timeout=100 + ) # pylint: disable=not-callable else: status = subprocess.call(kill_command) # pylint: disable=not-callable except Exception: - LOG.exception('Unable to pkill process.') + LOG.exception("Unable to pkill process.") return status @@ -151,11 +161,12 @@ def on_parent_exit(signame): Based on https://gist.github.com/evansd/2346614 """ + def noop(): pass try: - libc = cdll['libc.so.6'] + libc = cdll["libc.so.6"] except OSError: # libc, can't be found (e.g. running on non-Unix system), we cant ensure signal will be # triggered @@ -173,5 +184,6 @@ def set_parent_exit_signal(): # http://linux.die.net/man/2/prctl result = prctl(PR_SET_PDEATHSIG, signum) if result != 0: - raise Exception('prctl failed with error code %s' % result) + raise Exception("prctl failed with error code %s" % result) + return set_parent_exit_signal diff --git a/st2common/st2common/util/spec_loader.py b/st2common/st2common/util/spec_loader.py index 8ab926330f..07889fa2d2 100644 --- a/st2common/st2common/util/spec_loader.py +++ b/st2common/st2common/util/spec_loader.py @@ -33,16 +33,13 @@ from st2common.rbac.types import PermissionType from st2common.util import isotime -__all__ = [ - 'load_spec', - 'generate_spec' -] +__all__ = ["load_spec", "generate_spec"] ARGUMENTS = { - 'DEFAULT_PACK_NAME': st2common.constants.pack.DEFAULT_PACK_NAME, - 'LIVEACTION_STATUSES': st2common.constants.action.LIVEACTION_STATUSES, - 'PERMISSION_TYPE': PermissionType, - 'ISO8601_UTC_REGEX': isotime.ISO8601_UTC_REGEX + "DEFAULT_PACK_NAME": st2common.constants.pack.DEFAULT_PACK_NAME, + "LIVEACTION_STATUSES": st2common.constants.action.LIVEACTION_STATUSES, + "PERMISSION_TYPE": PermissionType, + "ISO8601_UTC_REGEX": isotime.ISO8601_UTC_REGEX, } @@ -50,23 +47,35 @@ class UniqueKeyLoader(Loader): """ YAML loader which throws on a duplicate key. """ + def construct_mapping(self, node, deep=False): if not isinstance(node, MappingNode): - raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, - node.start_mark) + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) mapping = {} for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) try: hash(key) except TypeError as exc: - raise ConstructorError("while constructing a mapping", node.start_mark, - "found unacceptable key (%s)" % exc, key_node.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found unacceptable key (%s)" % exc, + key_node.start_mark, + ) # check for duplicate keys if key in mapping: - raise ConstructorError("while constructing a mapping", node.start_mark, - "found duplicate key", key_node.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found duplicate key", + key_node.start_mark, + ) value = self.construct_object(value_node, deep=deep) mapping[key] = value return mapping diff --git a/st2common/st2common/util/system_info.py b/st2common/st2common/util/system_info.py index a83bf5169f..b81d205907 100644 --- a/st2common/st2common/util/system_info.py +++ b/st2common/st2common/util/system_info.py @@ -17,22 +17,14 @@ import os import socket -__all__ = [ - 'get_host_info', - 'get_process_info' -] +__all__ = ["get_host_info", "get_process_info"] def get_host_info(): - host_info = { - 'hostname': socket.gethostname() - } + host_info = {"hostname": socket.gethostname()} return host_info def get_process_info(): - process_info = { - 'hostname': socket.gethostname(), - 'pid': os.getpid() - } + process_info = {"hostname": socket.gethostname(), "pid": os.getpid()} return process_info diff --git a/st2common/st2common/util/templating.py b/st2common/st2common/util/templating.py index 9dc25d917c..82e8e1c246 100644 --- a/st2common/st2common/util/templating.py +++ b/st2common/st2common/util/templating.py @@ -24,9 +24,9 @@ from st2common.services.keyvalues import UserKeyValueLookup __all__ = [ - 'render_template', - 'render_template_with_system_context', - 'render_template_with_system_and_user_context' + "render_template", + "render_template_with_system_context", + "render_template_with_system_and_user_context", ] @@ -74,7 +74,9 @@ def render_template_with_system_context(value, context=None, prefix=None): return rendered -def render_template_with_system_and_user_context(value, user, context=None, prefix=None): +def render_template_with_system_and_user_context( + value, user, context=None, prefix=None +): """ Render provided template with a default system context and user context for the provided user. @@ -95,7 +97,7 @@ def render_template_with_system_and_user_context(value, user, context=None, pref context = context or {} context[DATASTORE_PARENT_SCOPE] = { SYSTEM_SCOPE: KeyValueLookup(prefix=prefix, scope=FULL_SYSTEM_SCOPE), - USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE) + USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE), } rendered = render_template(value=value, context=context) diff --git a/st2common/st2common/util/types.py b/st2common/st2common/util/types.py index 5c25990a6e..ad70f078b9 100644 --- a/st2common/st2common/util/types.py +++ b/st2common/st2common/util/types.py @@ -20,17 +20,14 @@ from __future__ import absolute_import import collections -__all__ = [ - 'OrderedSet' -] +__all__ = ["OrderedSet"] class OrderedSet(collections.MutableSet): - def __init__(self, iterable=None): self.end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.map = {} # key --> [key, prev, next] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] if iterable is not None: self |= iterable @@ -68,15 +65,15 @@ def __reversed__(self): def pop(self, last=True): if not self: - raise KeyError('set is empty') + raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] self.discard(key) return key def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): diff --git a/st2common/st2common/util/uid.py b/st2common/st2common/util/uid.py index 07d04d7511..289184d59e 100644 --- a/st2common/st2common/util/uid.py +++ b/st2common/st2common/util/uid.py @@ -20,9 +20,7 @@ from __future__ import absolute_import from st2common.models.db.stormbase import UIDFieldMixin -__all__ = [ - 'parse_uid' -] +__all__ = ["parse_uid"] def parse_uid(uid): @@ -33,12 +31,12 @@ def parse_uid(uid): :rtype: ``tuple`` """ if UIDFieldMixin.UID_SEPARATOR not in uid: - raise ValueError('Invalid uid: %s' % (uid)) + raise ValueError("Invalid uid: %s" % (uid)) parsed = uid.split(UIDFieldMixin.UID_SEPARATOR) if len(parsed) < 2: - raise ValueError('Invalid or malformed uid: %s' % (uid)) + raise ValueError("Invalid or malformed uid: %s" % (uid)) resource_type = parsed[0] uid_remainder = parsed[1:] diff --git a/st2common/st2common/util/ujson.py b/st2common/st2common/util/ujson.py index cace243448..6c533fb30a 100644 --- a/st2common/st2common/util/ujson.py +++ b/st2common/st2common/util/ujson.py @@ -19,9 +19,7 @@ import ujson -__all__ = [ - 'fast_deepcopy' -] +__all__ = ["fast_deepcopy"] def fast_deepcopy(value, fall_back_to_deepcopy=True): diff --git a/st2common/st2common/util/url.py b/st2common/st2common/util/url.py index 9c3196f835..b4dd8fc137 100644 --- a/st2common/st2common/util/url.py +++ b/st2common/st2common/util/url.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'get_url_without_trailing_slash' -] +__all__ = ["get_url_without_trailing_slash"] def get_url_without_trailing_slash(value): @@ -27,5 +25,5 @@ def get_url_without_trailing_slash(value): :rtype: ``str`` """ - result = value[:-1] if value.endswith('/') else value + result = value[:-1] if value.endswith("/") else value return result diff --git a/st2common/st2common/util/versioning.py b/st2common/st2common/util/versioning.py index 121a93312a..89da24f174 100644 --- a/st2common/st2common/util/versioning.py +++ b/st2common/st2common/util/versioning.py @@ -25,12 +25,7 @@ from st2common import __version__ as stackstorm_version -__all__ = [ - 'get_stackstorm_version', - 'get_python_version', - - 'complex_semver_match' -] +__all__ = ["get_stackstorm_version", "get_python_version", "complex_semver_match"] def get_stackstorm_version(): @@ -38,8 +33,8 @@ def get_stackstorm_version(): Return a valid semver version string for the currently running StackStorm version. """ # Special handling for dev versions which are not valid semver identifiers - if 'dev' in stackstorm_version and stackstorm_version.count('.') == 1: - version = stackstorm_version.replace('dev', '.0') + if "dev" in stackstorm_version and stackstorm_version.count(".") == 1: + version = stackstorm_version.replace("dev", ".0") return version return stackstorm_version @@ -50,7 +45,7 @@ def get_python_version(): Return Python version used by this installation. """ version_info = sys.version_info - return '%s.%s.%s' % (version_info.major, version_info.minor, version_info.micro) + return "%s.%s.%s" % (version_info.major, version_info.minor, version_info.micro) def complex_semver_match(version, version_specifier): @@ -63,10 +58,10 @@ def complex_semver_match(version, version_specifier): :rtype: ``bool`` """ - if version_specifier == 'all': + if version_specifier == "all": return True - split_version_specifier = version_specifier.split(',') + split_version_specifier = version_specifier.split(",") if len(split_version_specifier) == 1: # No comma, we can do a simple comparision diff --git a/st2common/st2common/util/virtualenvs.py b/st2common/st2common/util/virtualenvs.py index db56e6fb20..7f408c9da3 100644 --- a/st2common/st2common/util/virtualenvs.py +++ b/st2common/st2common/util/virtualenvs.py @@ -36,16 +36,22 @@ from st2common.content.utils import get_packs_base_paths from st2common.content.utils import get_pack_directory -__all__ = [ - 'setup_pack_virtualenv' -] +__all__ = ["setup_pack_virtualenv"] LOG = logging.getLogger(__name__) -def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True, - include_setuptools=True, include_wheel=True, proxy_config=None, - no_download=True, force_owner_group=True): +def setup_pack_virtualenv( + pack_name, + update=False, + logger=None, + include_pip=True, + include_setuptools=True, + include_wheel=True, + proxy_config=None, + no_download=True, + force_owner_group=True, +): """ Setup virtual environment for the provided pack. @@ -68,7 +74,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True if not re.match(PACK_REF_WHITELIST_REGEX, pack_name): raise ValueError('Invalid pack name "%s"' % (pack_name)) - base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/') + base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, "virtualenvs/") virtualenv_path = os.path.join(base_virtualenvs_path, quote_unix(pack_name)) # Ensure pack directory exists in one of the search paths @@ -78,7 +84,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True if not pack_path: packs_base_paths = get_packs_base_paths() - search_paths = ', '.join(packs_base_paths) + search_paths = ", ".join(packs_base_paths) msg = 'Pack "%s" is not installed. Looked in: %s' % (pack_name, search_paths) raise Exception(msg) @@ -88,42 +94,64 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True remove_virtualenv(virtualenv_path=virtualenv_path, logger=logger) # 1. Create virtual environment - logger.debug('Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path)) - create_virtualenv(virtualenv_path=virtualenv_path, logger=logger, include_pip=include_pip, - include_setuptools=include_setuptools, include_wheel=include_wheel, - no_download=no_download) + logger.debug( + 'Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path) + ) + create_virtualenv( + virtualenv_path=virtualenv_path, + logger=logger, + include_pip=include_pip, + include_setuptools=include_setuptools, + include_wheel=include_wheel, + no_download=no_download, + ) # 2. Install base requirements which are common to all the packs - logger.debug('Installing base requirements') + logger.debug("Installing base requirements") for requirement in BASE_PACK_REQUIREMENTS: - install_requirement(virtualenv_path=virtualenv_path, requirement=requirement, - proxy_config=proxy_config, logger=logger) + install_requirement( + virtualenv_path=virtualenv_path, + requirement=requirement, + proxy_config=proxy_config, + logger=logger, + ) # 3. Install pack-specific requirements - requirements_file_path = os.path.join(pack_path, 'requirements.txt') + requirements_file_path = os.path.join(pack_path, "requirements.txt") has_requirements = os.path.isfile(requirements_file_path) if has_requirements: - logger.debug('Installing pack specific requirements from "%s"' % - (requirements_file_path)) - install_requirements(virtualenv_path=virtualenv_path, - requirements_file_path=requirements_file_path, - proxy_config=proxy_config, - logger=logger) + logger.debug( + 'Installing pack specific requirements from "%s"' % (requirements_file_path) + ) + install_requirements( + virtualenv_path=virtualenv_path, + requirements_file_path=requirements_file_path, + proxy_config=proxy_config, + logger=logger, + ) else: - logger.debug('No pack specific requirements found') + logger.debug("No pack specific requirements found") # 4. Set the owner group if force_owner_group: apply_pack_owner_group(pack_path=virtualenv_path) - action = 'updated' if update else 'created' - logger.debug('Virtualenv for pack "%s" successfully %s in "%s"' % - (pack_name, action, virtualenv_path)) - - -def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_setuptools=True, - include_wheel=True, no_download=True): + action = "updated" if update else "created" + logger.debug( + 'Virtualenv for pack "%s" successfully %s in "%s"' + % (pack_name, action, virtualenv_path) + ) + + +def create_virtualenv( + virtualenv_path, + logger=None, + include_pip=True, + include_setuptools=True, + include_wheel=True, + no_download=True, +): """ :param include_pip: Include pip binary and package in the newely created virtual environment. :type include_pip: ``bool`` @@ -145,7 +173,7 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se python_binary = cfg.CONF.actionrunner.python_binary virtualenv_binary = cfg.CONF.actionrunner.virtualenv_binary virtualenv_opts = cfg.CONF.actionrunner.virtualenv_opts or [] - virtualenv_opts += ['--verbose'] + virtualenv_opts += ["--verbose"] if not os.path.isfile(python_binary): raise Exception('Python binary "%s" doesn\'t exist' % (python_binary)) @@ -153,39 +181,44 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se if not os.path.isfile(virtualenv_binary): raise Exception('Virtualenv binary "%s" doesn\'t exist.' % (virtualenv_binary)) - logger.debug('Creating virtualenv in "%s" using Python binary "%s"' % - (virtualenv_path, python_binary)) + logger.debug( + 'Creating virtualenv in "%s" using Python binary "%s"' + % (virtualenv_path, python_binary) + ) cmd = [virtualenv_binary] - cmd.extend(['-p', python_binary]) + cmd.extend(["-p", python_binary]) cmd.extend(virtualenv_opts) if not include_pip: - cmd.append('--no-pip') + cmd.append("--no-pip") if not include_setuptools: - cmd.append('--no-setuptools') + cmd.append("--no-setuptools") if not include_wheel: - cmd.append('--no-wheel') + cmd.append("--no-wheel") if no_download: - cmd.append('--no-download') + cmd.append("--no-download") cmd.extend([virtualenv_path]) - logger.debug('Running command "%s" to create virtualenv.', ' '.join(cmd)) + logger.debug('Running command "%s" to create virtualenv.', " ".join(cmd)) try: exit_code, stdout, stderr = run_command(cmd=cmd) except OSError as e: - raise Exception('Error executing command %s. %s.' % (' '.join(cmd), - six.text_type(e))) + raise Exception( + "Error executing command %s. %s." % (" ".join(cmd), six.text_type(e)) + ) if exit_code != 0: - raise Exception('Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s' % - (virtualenv_path, stdout, stderr)) + raise Exception( + 'Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s' + % (virtualenv_path, stdout, stderr) + ) return True @@ -204,51 +237,60 @@ def remove_virtualenv(virtualenv_path, logger=None): try: shutil.rmtree(virtualenv_path) except Exception as e: - logger.error('Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e)) + logger.error( + 'Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e) + ) raise e return True -def install_requirements(virtualenv_path, requirements_file_path, proxy_config=None, logger=None): +def install_requirements( + virtualenv_path, requirements_file_path, proxy_config=None, logger=None +): """ Install requirements from a file. """ logger = logger or LOG - pip_path = os.path.join(virtualenv_path, 'bin/pip') + pip_path = os.path.join(virtualenv_path, "bin/pip") pip_opts = cfg.CONF.actionrunner.pip_opts or [] cmd = [pip_path] if proxy_config: - cert = proxy_config.get('proxy_ca_bundle_path', None) - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) + cert = proxy_config.get("proxy_ca_bundle_path", None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) if http_proxy: - cmd.extend(['--proxy', http_proxy]) + cmd.extend(["--proxy", http_proxy]) if https_proxy: - cmd.extend(['--proxy', https_proxy]) + cmd.extend(["--proxy", https_proxy]) if cert: - cmd.extend(['--cert', cert]) + cmd.extend(["--cert", cert]) - cmd.append('install') + cmd.append("install") cmd.extend(pip_opts) - cmd.extend(['-U', '-r', requirements_file_path]) + cmd.extend(["-U", "-r", requirements_file_path]) env = get_env_for_subprocess_command() - logger.debug('Installing requirements from file %s with command %s.', - requirements_file_path, ' '.join(cmd)) + logger.debug( + "Installing requirements from file %s with command %s.", + requirements_file_path, + " ".join(cmd), + ) exit_code, stdout, stderr = run_command(cmd=cmd, env=env) if exit_code != 0: stdout = to_ascii(stdout) stderr = to_ascii(stderr) - raise Exception('Failed to install requirements from "%s": %s (stderr: %s)' % - (requirements_file_path, stdout, stderr)) + raise Exception( + 'Failed to install requirements from "%s": %s (stderr: %s)' + % (requirements_file_path, stdout, stderr) + ) return True @@ -260,35 +302,37 @@ def install_requirement(virtualenv_path, requirement, proxy_config=None, logger= :param requirement: Requirement specifier. """ logger = logger or LOG - pip_path = os.path.join(virtualenv_path, 'bin/pip') + pip_path = os.path.join(virtualenv_path, "bin/pip") pip_opts = cfg.CONF.actionrunner.pip_opts or [] cmd = [pip_path] if proxy_config: - cert = proxy_config.get('proxy_ca_bundle_path', None) - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) + cert = proxy_config.get("proxy_ca_bundle_path", None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) if http_proxy: - cmd.extend(['--proxy', http_proxy]) + cmd.extend(["--proxy", http_proxy]) if https_proxy: - cmd.extend(['--proxy', https_proxy]) + cmd.extend(["--proxy", https_proxy]) if cert: - cmd.extend(['--cert', cert]) + cmd.extend(["--cert", cert]) - cmd.append('install') + cmd.append("install") cmd.extend(pip_opts) cmd.extend([requirement]) env = get_env_for_subprocess_command() - logger.debug('Installing requirement %s with command %s.', - requirement, ' '.join(cmd)) + logger.debug( + "Installing requirement %s with command %s.", requirement, " ".join(cmd) + ) exit_code, stdout, stderr = run_command(cmd=cmd, env=env) if exit_code != 0: - raise Exception('Failed to install requirement "%s": %s' % - (requirement, stdout)) + raise Exception( + 'Failed to install requirement "%s": %s' % (requirement, stdout) + ) return True @@ -302,7 +346,7 @@ def get_env_for_subprocess_command(): """ env = os.environ.copy() - if 'PYTHONPATH' in env: - del env['PYTHONPATH'] + if "PYTHONPATH" in env: + del env["PYTHONPATH"] return env diff --git a/st2common/st2common/util/wsgi.py b/st2common/st2common/util/wsgi.py index a3441e4bda..63ec6c6253 100644 --- a/st2common/st2common/util/wsgi.py +++ b/st2common/st2common/util/wsgi.py @@ -24,9 +24,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'shutdown_server_kill_pending_requests' -] +__all__ = ["shutdown_server_kill_pending_requests"] def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): @@ -46,7 +44,7 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): sock.close() active_requests = worker_pool.running() - LOG.info('Shutting down. Requests left: %s', active_requests) + LOG.info("Shutting down. Requests left: %s", active_requests) # Give active requests some time to finish if active_requests > 0: @@ -57,5 +55,5 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): for coro in running_corutines: eventlet.greenthread.kill(coro) - LOG.info('Exiting...') + LOG.info("Exiting...") raise SystemExit() diff --git a/st2common/st2common/validators/api/action.py b/st2common/st2common/validators/api/action.py index 1eb5dbfeb9..973e999fa6 100644 --- a/st2common/st2common/validators/api/action.py +++ b/st2common/st2common/validators/api/action.py @@ -26,10 +26,7 @@ from st2common.models.system.common import ResourceReference from six.moves import range -__all__ = [ - 'validate_action', - 'get_runner_model' -] +__all__ = ["validate_action", "get_runner_model"] LOG = logging.getLogger(__name__) @@ -49,14 +46,17 @@ def validate_action(action_api, runner_type_db=None): # Check if pack is valid. if not _is_valid_pack(action_api.pack): packs_base_paths = get_packs_base_paths() - packs_base_paths = ','.join(packs_base_paths) - msg = ('Content pack "%s" is not found or doesn\'t contain actions directory. ' - 'Searched in: %s' % - (action_api.pack, packs_base_paths)) + packs_base_paths = ",".join(packs_base_paths) + msg = ( + 'Content pack "%s" is not found or doesn\'t contain actions directory. ' + "Searched in: %s" % (action_api.pack, packs_base_paths) + ) raise ValueValidationException(msg) # Check if parameters defined are valid. - action_ref = ResourceReference.to_string_reference(pack=action_api.pack, name=action_api.name) + action_ref = ResourceReference.to_string_reference( + pack=action_api.pack, name=action_api.name + ) _validate_parameters(action_ref, action_api.parameters, runner_db.runner_parameters) @@ -66,15 +66,18 @@ def get_runner_model(action_api): try: runner_db = get_runnertype_by_name(action_api.runner_type) except StackStormDBObjectNotFoundError: - msg = ('RunnerType %s is not found. If you are using old and deprecated runner name, you ' - 'need to switch to a new one. For more information, please see ' - 'https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9' % (action_api.runner_type)) + msg = ( + "RunnerType %s is not found. If you are using old and deprecated runner name, you " + "need to switch to a new one. For more information, please see " + "https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9" + % (action_api.runner_type) + ) raise ValueValidationException(msg) return runner_db def _is_valid_pack(pack): - return check_pack_content_directory_exists(pack=pack, content_type='actions') + return check_pack_content_directory_exists(pack=pack, content_type="actions") def _validate_parameters(action_ref, action_params=None, runner_params=None): @@ -84,32 +87,44 @@ def _validate_parameters(action_ref, action_params=None, runner_params=None): if action_param in runner_params: for action_param_attr, value in six.iteritems(action_param_meta): util_schema.validate_runner_parameter_attribute_override( - action_ref, action_param, action_param_attr, - value, runner_params[action_param].get(action_param_attr)) - - if 'position' in action_param_meta: - pos = action_param_meta['position'] + action_ref, + action_param, + action_param_attr, + value, + runner_params[action_param].get(action_param_attr), + ) + + if "position" in action_param_meta: + pos = action_param_meta["position"] param = position_params.get(pos, None) if param: - msg = ('Parameters %s and %s have same position %d.' % (action_param, param, pos) + - ' Position values have to be unique.') + msg = ( + "Parameters %s and %s have same position %d." + % (action_param, param, pos) + + " Position values have to be unique." + ) raise ValueValidationException(msg) else: position_params[pos] = action_param - if 'immutable' in action_param_meta: + if "immutable" in action_param_meta: if action_param in runner_params: runner_param_meta = runner_params[action_param] - if 'immutable' in runner_param_meta: - msg = 'Param %s is declared immutable in runner. ' % action_param + \ - 'Cannot override in action.' + if "immutable" in runner_param_meta: + msg = ( + "Param %s is declared immutable in runner. " % action_param + + "Cannot override in action." + ) raise ValueValidationException(msg) - if 'default' not in action_param_meta and 'default' not in runner_param_meta: - msg = 'Immutable param %s requires a default value.' % action_param + if ( + "default" not in action_param_meta + and "default" not in runner_param_meta + ): + msg = "Immutable param %s requires a default value." % action_param raise ValueValidationException(msg) else: - if 'default' not in action_param_meta: - msg = 'Immutable param %s requires a default value.' % action_param + if "default" not in action_param_meta: + msg = "Immutable param %s requires a default value." % action_param raise ValueValidationException(msg) return _validate_position_values_contiguous(position_params) @@ -120,10 +135,10 @@ def _validate_position_values_contiguous(position_params): return True positions = sorted(position_params.keys()) - contiguous = (positions == list(range(min(positions), max(positions) + 1))) + contiguous = positions == list(range(min(positions), max(positions) + 1)) if not contiguous: - msg = 'Positions supplied %s for parameters are not contiguous.' % positions + msg = "Positions supplied %s for parameters are not contiguous." % positions raise ValueValidationException(msg) return True diff --git a/st2common/st2common/validators/api/misc.py b/st2common/st2common/validators/api/misc.py index b18ff05d21..215afc5501 100644 --- a/st2common/st2common/validators/api/misc.py +++ b/st2common/st2common/validators/api/misc.py @@ -17,9 +17,7 @@ from st2common.constants.pack import SYSTEM_PACK_NAME from st2common.exceptions.apivalidation import ValueValidationException -__all__ = [ - 'validate_not_part_of_system_pack' -] +__all__ = ["validate_not_part_of_system_pack"] def validate_not_part_of_system_pack(resource_db): @@ -32,10 +30,10 @@ def validate_not_part_of_system_pack(resource_db): :param resource_db: Resource database object to check. :type resource_db: ``object`` """ - pack = getattr(resource_db, 'pack', None) + pack = getattr(resource_db, "pack", None) if pack == SYSTEM_PACK_NAME: - msg = 'Resources belonging to system level packs can\'t be manipulated' + msg = "Resources belonging to system level packs can't be manipulated" raise ValueValidationException(msg) return resource_db diff --git a/st2common/st2common/validators/api/reactor.py b/st2common/st2common/validators/api/reactor.py index eb2cf1c814..0d84a66a99 100644 --- a/st2common/st2common/validators/api/reactor.py +++ b/st2common/st2common/validators/api/reactor.py @@ -29,10 +29,9 @@ from st2common.services import triggers __all__ = [ - 'validate_criteria', - - 'validate_trigger_parameters', - 'validate_trigger_payload' + "validate_criteria", + "validate_trigger_parameters", + "validate_trigger_payload", ] @@ -43,20 +42,30 @@ def validate_criteria(criteria): if not isinstance(criteria, dict): - raise ValueValidationException('Criteria should be a dict.') + raise ValueValidationException("Criteria should be a dict.") for key, value in six.iteritems(criteria): - operator = value.get('type', None) + operator = value.get("type", None) if operator is None: - raise ValueValidationException('Operator not specified for field: ' + key) + raise ValueValidationException("Operator not specified for field: " + key) if operator not in allowed_operators: - raise ValueValidationException('For field: ' + key + ', operator ' + operator + - ' not in list of allowed operators: ' + - str(list(allowed_operators.keys()))) - pattern = value.get('pattern', None) + raise ValueValidationException( + "For field: " + + key + + ", operator " + + operator + + " not in list of allowed operators: " + + str(list(allowed_operators.keys())) + ) + pattern = value.get("pattern", None) if pattern is None: - raise ValueValidationException('For field: ' + key + ', no pattern specified ' + - 'for operator ' + operator) + raise ValueValidationException( + "For field: " + + key + + ", no pattern specified " + + "for operator " + + operator + ) def validate_trigger_parameters(trigger_type_ref, parameters): @@ -77,27 +86,33 @@ def validate_trigger_parameters(trigger_type_ref, parameters): is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES if is_system_trigger: # System trigger - parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['parameters_schema'] + parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["parameters_schema"] else: trigger_type_db = triggers.get_trigger_type_db(trigger_type_ref) if not trigger_type_db: # Trigger doesn't exist in the database return None - parameters_schema = getattr(trigger_type_db, 'parameters_schema', {}) + parameters_schema = getattr(trigger_type_db, "parameters_schema", {}) if not parameters_schema: # Parameters schema not defined for the this trigger return None # We only validate non-system triggers if config option is set (enabled) if not is_system_trigger and not cfg.CONF.system.validate_trigger_parameters: - LOG.debug('Got non-system trigger "%s", but trigger parameter validation for non-system' - 'triggers is disabled, skipping validation.' % (trigger_type_ref)) + LOG.debug( + 'Got non-system trigger "%s", but trigger parameter validation for non-system' + "triggers is disabled, skipping validation." % (trigger_type_ref) + ) return None - cleaned = util_schema.validate(instance=parameters, schema=parameters_schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=parameters, + schema=parameters_schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) # Additional validation for CronTimer trigger # TODO: If we need to add more checks like this we should consider abstracting this out. @@ -110,7 +125,9 @@ def validate_trigger_parameters(trigger_type_ref, parameters): return cleaned -def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trigger=False): +def validate_trigger_payload( + trigger_type_ref, payload, throw_on_inexistent_trigger=False +): """ This function validates trigger payload parameters for system and user-defined triggers. @@ -128,8 +145,8 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig # NOTE: Due to the awful code in some other places we also need to support a scenario where # this variable is a dictionary and contains various TriggerDB object attributes. if isinstance(trigger_type_ref, dict): - if trigger_type_ref.get('type', None): - trigger_type_ref = trigger_type_ref['type'] + if trigger_type_ref.get("type", None): + trigger_type_ref = trigger_type_ref["type"] else: trigger_db = triggers.get_trigger_db_by_ref_or_dict(trigger_type_ref) @@ -143,16 +160,16 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES if is_system_trigger: # System trigger - payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['payload_schema'] + payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["payload_schema"] else: # We assume Trigger ref and not TriggerType ref is passed in if second # part (trigger name) is a valid UUID version 4 try: - trigger_uuid = uuid.UUID(trigger_type_ref.split('.')[-1]) + trigger_uuid = uuid.UUID(trigger_type_ref.split(".")[-1]) except ValueError: is_trigger_db = False else: - is_trigger_db = (trigger_uuid.version == 4) + is_trigger_db = trigger_uuid.version == 4 if is_trigger_db: trigger_db = triggers.get_trigger_db_by_ref(trigger_type_ref) @@ -165,25 +182,33 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig if not trigger_type_db: # Trigger doesn't exist in the database if throw_on_inexistent_trigger: - msg = ('Trigger type with reference "%s" doesn\'t exist in the database' % - (trigger_type_ref)) + msg = ( + 'Trigger type with reference "%s" doesn\'t exist in the database' + % (trigger_type_ref) + ) raise ValueError(msg) return None - payload_schema = getattr(trigger_type_db, 'payload_schema', {}) + payload_schema = getattr(trigger_type_db, "payload_schema", {}) if not payload_schema: # Payload schema not defined for the this trigger return None # We only validate non-system triggers if config option is set (enabled) if not is_system_trigger and not cfg.CONF.system.validate_trigger_payload: - LOG.debug('Got non-system trigger "%s", but trigger payload validation for non-system' - 'triggers is disabled, skipping validation.' % (trigger_type_ref)) + LOG.debug( + 'Got non-system trigger "%s", but trigger payload validation for non-system' + "triggers is disabled, skipping validation." % (trigger_type_ref) + ) return None - cleaned = util_schema.validate(instance=payload, schema=payload_schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=payload, + schema=payload_schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) return cleaned diff --git a/st2common/st2common/validators/workflow/base.py b/st2common/st2common/validators/workflow/base.py index 226a4668fb..3bf8e9fbd5 100644 --- a/st2common/st2common/validators/workflow/base.py +++ b/st2common/st2common/validators/workflow/base.py @@ -20,7 +20,6 @@ @six.add_metaclass(abc.ABCMeta) class WorkflowValidator(object): - @abc.abstractmethod def validate(self, definition): raise NotImplementedError diff --git a/st2common/tests/fixtures/mock_runner/mock_runner.py b/st2common/tests/fixtures/mock_runner/mock_runner.py index 9110e740f4..66295e8421 100644 --- a/st2common/tests/fixtures/mock_runner/mock_runner.py +++ b/st2common/tests/fixtures/mock_runner/mock_runner.py @@ -23,9 +23,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'get_runner' -] +__all__ = ["get_runner"] def get_runner(): @@ -36,6 +34,7 @@ class MockRunner(ActionRunner): """ Runner which does absolutely nothing. """ + KEYS_TO_TRANSFORM = [] def __init__(self, runner_id): @@ -47,9 +46,9 @@ def pre_run(self): def run(self, action_parameters): result = { - 'failed': False, - 'succeeded': True, - 'return_code': 0, + "failed": False, + "succeeded": True, + "return_code": 0, } status = LIVEACTION_STATUS_SUCCEEDED diff --git a/st2common/tests/fixtures/version_file.py b/st2common/tests/fixtures/version_file.py index 882f420538..b52f01d75c 100644 --- a/st2common/tests/fixtures/version_file.py +++ b/st2common/tests/fixtures/version_file.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '1.2.3' +__version__ = "1.2.3" diff --git a/st2common/tests/integration/test_rabbitmq_ssl_listener.py b/st2common/tests/integration/test_rabbitmq_ssl_listener.py index 9c1ddeef06..e64a22995d 100644 --- a/st2common/tests/integration/test_rabbitmq_ssl_listener.py +++ b/st2common/tests/integration/test_rabbitmq_ssl_listener.py @@ -27,12 +27,10 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'RabbitMQTLSListenerTestCase' -] +__all__ = ["RabbitMQTLSListenerTestCase"] -CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'ssl_certs/') -ST2_CI = (os.environ.get('ST2_CI', 'false').lower() == 'true') +CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "ssl_certs/") +ST2_CI = os.environ.get("ST2_CI", "false").lower() == "true" NON_SSL_LISTENER_PORT = 5672 SSL_LISTENER_PORT = 5671 @@ -40,42 +38,49 @@ # NOTE: We only run those tests on the CI provider because at the moment, local # vagrant dev VM doesn't expose RabbitMQ SSL listener by default -@unittest2.skipIf(not ST2_CI, - 'Skipping tests because ST2_CI environment variable is not set to "true"') +@unittest2.skipIf( + not ST2_CI, + 'Skipping tests because ST2_CI environment variable is not set to "true"', +) class RabbitMQTLSListenerTestCase(unittest2.TestCase): - def setUp(self): # Set default values - cfg.CONF.set_override(name='ssl', override=False, group='messaging') - cfg.CONF.set_override(name='ssl_keyfile', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override=None, group='messaging') + cfg.CONF.set_override(name="ssl", override=False, group="messaging") + cfg.CONF.set_override(name="ssl_keyfile", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_certfile", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_ca_certs", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_cert_reqs", override=None, group="messaging") def test_non_ssl_connection_on_ssl_listener_port_failure(self): - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg_1 = '[Errno 104]' # followed by: ' Connection reset by peer' or ' ECONNRESET' - expected_msg_2 = 'Socket closed' - expected_msg_3 = 'Server unexpectedly closed connection' + expected_msg_1 = ( + "[Errno 104]" # followed by: ' Connection reset by peer' or ' ECONNRESET' + ) + expected_msg_2 = "Socket closed" + expected_msg_3 = "Server unexpectedly closed connection" try: connection.connect() except Exception as e: self.assertFalse(connection.connected) self.assertIsInstance(e, (IOError, socket.error)) - self.assertTrue(expected_msg_1 in six.text_type(e) or - expected_msg_2 in six.text_type(e) or - expected_msg_3 in six.text_type(e)) + self.assertTrue( + expected_msg_1 in six.text_type(e) + or expected_msg_2 in six.text_type(e) + or expected_msg_3 in six.text_type(e) + ) else: - self.fail('Exception was not thrown') + self.fail("Exception was not thrown") if connection: connection.release() def test_ssl_connection_on_ssl_listener_success(self): # Using query param notation - urls = 'amqp://guest:guest@127.0.0.1:5671/?ssl=true' + urls = "amqp://guest:guest@127.0.0.1:5671/?ssl=true" connection = transport_utils.get_connection(urls=urls) try: @@ -86,9 +91,11 @@ def test_ssl_connection_on_ssl_listener_success(self): connection.release() # Using messaging.ssl config option - cfg.CONF.set_override(name='ssl', override=True, group='messaging') + cfg.CONF.set_override(name="ssl", override=True, group="messaging") - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -98,15 +105,21 @@ def test_ssl_connection_on_ssl_listener_success(self): connection.release() def test_ssl_connection_ca_certs_provided(self): - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") - cfg.CONF.set_override(name='ssl', override=True, group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override(name="ssl", override=True, group="messaging") + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) # 1. Validate server cert against a valid CA bundle (success) - cert required - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -117,35 +130,51 @@ def test_ssl_connection_ca_certs_provided(self): # 2. Validate server cert against other CA bundle (failure) # CA bundle which was not used to sign the server cert - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed' + expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) # 3. Validate server cert against other CA bundle (failure) - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='optional', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="optional", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed' + expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) # 4. Validate server cert against other CA bundle (failure) # We use invalid bundle but cert_reqs is none - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='none', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override(name="ssl_cert_reqs", override="none", group="messaging") + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -156,16 +185,28 @@ def test_ssl_connection_ca_certs_provided(self): def test_ssl_connect_client_side_cert_authentication(self): # 1. Success, valid client side cert provided - ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem') - ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'client/client_certificate.pem') - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') - - cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') - - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem") + ssl_certfile = os.path.join( + CERTS_FIXTURES_PATH, "client/client_certificate.pem" + ) + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") + + cfg.CONF.set_override( + name="ssl_keyfile", override=ssl_keyfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_certfile", override=ssl_certfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) + + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -175,16 +216,28 @@ def test_ssl_connect_client_side_cert_authentication(self): connection.release() # 2. Invalid client side cert provided - failure - ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem') - ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'server/server_certificate.pem') - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') - - cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') - - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') - - expected_msg = r'\[X509: KEY_VALUES_MISMATCH\] key values mismatch' + ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem") + ssl_certfile = os.path.join( + CERTS_FIXTURES_PATH, "server/server_certificate.pem" + ) + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") + + cfg.CONF.set_override( + name="ssl_keyfile", override=ssl_keyfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_certfile", override=ssl_certfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) + + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) + + expected_msg = r"\[X509: KEY_VALUES_MISMATCH\] key values mismatch" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) diff --git a/st2common/tests/integration/test_register_content_script.py b/st2common/tests/integration/test_register_content_script.py index 1d7ca955f9..8082a85371 100644 --- a/st2common/tests/integration/test_register_content_script.py +++ b/st2common/tests/integration/test_register_content_script.py @@ -26,15 +26,15 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -SCRIPT_PATH = os.path.join(BASE_DIR, '../../bin/st2-register-content') +SCRIPT_PATH = os.path.join(BASE_DIR, "../../bin/st2-register-content") SCRIPT_PATH = os.path.abspath(SCRIPT_PATH) -BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests.conf', '-v'] -BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ['--register-actions'] +BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, "--config-file=conf/st2.tests.conf", "-v"] +BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ["--register-actions"] PACKS_PATH = get_fixtures_packs_base_path() -PACKS_COUNT = len(glob.glob('%s/*/pack.yaml' % (PACKS_PATH))) -assert(PACKS_COUNT >= 2) +PACKS_COUNT = len(glob.glob("%s/*/pack.yaml" % (PACKS_PATH))) +assert PACKS_COUNT >= 2 class ContentRegisterScriptTestCase(IntegrationTestCase): @@ -43,27 +43,27 @@ def setUp(self): test_config.parse_args() def test_register_from_pack_success(self): - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 1 actions.', stderr) + self.assertIn("Registered 1 actions.", stderr) self.assertEqual(exit_code, 0) def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): # No fail on failure flag, should succeed - pack_dir = 'doesntexistblah' - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = "doesntexistblah" + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), - '--register-no-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), + "--register-no-fail-on-failure", ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, _ = run_command(cmd=cmd) @@ -71,9 +71,9 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): # Fail on failure, should fail opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), - '--register-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), + "--register-fail-on-failure", ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) @@ -82,30 +82,30 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): def test_register_from_pack_action_metadata_fails_validation(self): # No fail on failure flag, should succeed - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4') - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4") + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-no-fail-on-failure', - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-no-fail-on-failure", + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 actions.', stderr) + self.assertIn("Registered 0 actions.", stderr) self.assertEqual(exit_code, 0) # Fail on failure, should fail - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-fail-on-failure', - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-fail-on-failure", + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('object has no attribute \'get\'', stderr) + self.assertIn("object has no attribute 'get'", stderr) self.assertEqual(exit_code, 1) def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self): @@ -114,44 +114,58 @@ def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self): # Note: We want to use a different config which sets fixtures/packs_1/ # dir as packs_base_paths - cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v', - '--register-sensors'] + cmd = [ + sys.executable, + SCRIPT_PATH, + "--config-file=conf/st2.tests1.conf", + "-v", + "--register-sensors", + ] exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 sensors.', stderr, 'Actual stderr: %s' % (stderr)) + self.assertIn("Registered 0 sensors.", stderr, "Actual stderr: %s" % (stderr)) self.assertEqual(exit_code, 0) - cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v', - '--register-all', '--register-no-fail-on-failure'] + cmd = [ + sys.executable, + SCRIPT_PATH, + "--config-file=conf/st2.tests1.conf", + "-v", + "--register-all", + "--register-no-fail-on-failure", + ] exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 actions.', stderr) - self.assertIn('Registered 0 sensors.', stderr) - self.assertIn('Registered 0 rules.', stderr) + self.assertIn("Registered 0 actions.", stderr) + self.assertIn("Registered 0 sensors.", stderr) + self.assertIn("Registered 0 rules.", stderr) self.assertEqual(exit_code, 0) def test_register_all_and_register_setup_virtualenvs(self): # Verify that --register-all works in combinations with --register-setup-virtualenvs # Single pack - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") cmd = BASE_CMD_ARGS + [ - '--register-pack=%s' % (pack_dir), - '--register-all', - '--register-setup-virtualenvs', - '--register-no-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-all", + "--register-setup-virtualenvs", + "--register-no-fail-on-failure", ] exit_code, stdout, stderr = run_command(cmd=cmd) - self.assertIn('Registering actions', stderr, 'Actual stderr: %s' % (stderr)) - self.assertIn('Registering rules', stderr) - self.assertIn('Setup virtualenv for %s pack(s)' % ('1'), stderr) + self.assertIn("Registering actions", stderr, "Actual stderr: %s" % (stderr)) + self.assertIn("Registering rules", stderr) + self.assertIn("Setup virtualenv for %s pack(s)" % ("1"), stderr) self.assertEqual(exit_code, 0) def test_register_setup_virtualenvs(self): # Single pack - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") - cmd = BASE_CMD_ARGS + ['--register-pack=%s' % (pack_dir), '--register-setup-virtualenvs', - '--register-no-fail-on-failure'] + cmd = BASE_CMD_ARGS + [ + "--register-pack=%s" % (pack_dir), + "--register-setup-virtualenvs", + "--register-no-fail-on-failure", + ] exit_code, stdout, stderr = run_command(cmd=cmd) self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr) - self.assertIn('Setup virtualenv for 1 pack(s)', stderr) + self.assertIn("Setup virtualenv for 1 pack(s)", stderr) self.assertEqual(exit_code, 0) diff --git a/st2common/tests/integration/test_service_setup_log_level_filtering.py b/st2common/tests/integration/test_service_setup_log_level_filtering.py index ac3f90deaf..a03e90688a 100644 --- a/st2common/tests/integration/test_service_setup_log_level_filtering.py +++ b/st2common/tests/integration/test_service_setup_log_level_filtering.py @@ -25,36 +25,42 @@ from st2tests.base import IntegrationTestCase from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'ServiceSetupLogLevelFilteringTestCase' -] +__all__ = ["ServiceSetupLogLevelFilteringTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) FIXTURES_DIR = get_fixtures_base_path() -ST2_CONFIG_INFO_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.info_log_level.conf') +ST2_CONFIG_INFO_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.info_log_level.conf" +) ST2_CONFIG_INFO_LL_PATH = os.path.abspath(ST2_CONFIG_INFO_LL_PATH) -ST2_CONFIG_DEBUG_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.debug_log_level.conf') +ST2_CONFIG_DEBUG_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.debug_log_level.conf" +) ST2_CONFIG_DEBUG_LL_PATH = os.path.abspath(ST2_CONFIG_DEBUG_LL_PATH) -ST2_CONFIG_AUDIT_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.audit_log_level.conf') +ST2_CONFIG_AUDIT_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.audit_log_level.conf" +) ST2_CONFIG_AUDIT_LL_PATH = os.path.abspath(ST2_CONFIG_AUDIT_LL_PATH) -ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join(FIXTURES_DIR, - 'conf/st2.tests.api.system_debug_true.conf') +ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.system_debug_true.conf" +) ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.abspath(ST2_CONFIG_SYSTEM_DEBUG_PATH) -ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join(FIXTURES_DIR, - 'conf/st2.tests.api.system_debug_true_logging_debug.conf') +ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.system_debug_true_logging_debug.conf" +) PYTHON_BINARY = sys.executable -ST2API_BINARY = os.path.join(BASE_DIR, '../../../st2api/bin/st2api') +ST2API_BINARY = os.path.join(BASE_DIR, "../../../st2api/bin/st2api") ST2API_BINARY = os.path.abspath(ST2API_BINARY) -CMD = [PYTHON_BINARY, ST2API_BINARY, '--config-file'] +CMD = [PYTHON_BINARY, ST2API_BINARY, "--config-file"] class ServiceSetupLogLevelFilteringTestCase(IntegrationTestCase): @@ -68,11 +74,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertNotIn('DEBUG [-]', stdout) - self.assertNotIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertNotIn("DEBUG [-]", stdout) + self.assertNotIn("AUDIT [-]", stdout) # 2. DEBUG log level - audit messages should be included process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH) @@ -83,11 +89,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) # 3. AUDIT log level - audit messages should be included process = self._start_process(config_path=ST2_CONFIG_AUDIT_LL_PATH) @@ -98,11 +104,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertNotIn('INFO [-]', stdout) - self.assertNotIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertNotIn("INFO [-]", stdout) + self.assertNotIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) # 2. INFO log level but system.debug set to True process = self._start_process(config_path=ST2_CONFIG_SYSTEM_DEBUG_PATH) @@ -113,11 +119,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) def test_kombu_heartbeat_tick_log_messages_are_excluded(self): # 1. system.debug = True config option is set, verify heartbeat_tick message is not logged @@ -128,8 +134,8 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self): eventlet.sleep(5) process.send_signal(signal.SIGKILL) - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')) - self.assertNotIn('heartbeat_tick', stdout) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")) + self.assertNotIn("heartbeat_tick", stdout) # 2. system.debug = False, log level is set to debug process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH) @@ -139,14 +145,19 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self): eventlet.sleep(5) process.send_signal(signal.SIGKILL) - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')) - self.assertNotIn('heartbeat_tick', stdout) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")) + self.assertNotIn("heartbeat_tick", stdout) def _start_process(self, config_path): cmd = CMD + [config_path] - cwd = os.path.abspath(os.path.join(BASE_DIR, '../../../')) + cwd = os.path.abspath(os.path.join(BASE_DIR, "../../../")) cwd = os.path.abspath(cwd) - process = subprocess.Popen(cmd, cwd=cwd, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) return process diff --git a/st2common/tests/unit/base.py b/st2common/tests/unit/base.py index 6a22b139db..65948d1d11 100644 --- a/st2common/tests/unit/base.py +++ b/st2common/tests/unit/base.py @@ -24,13 +24,11 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'BaseDBModelCRUDTestCase', - - 'FakeModel', - 'FakeModelDB', - - 'ChangeRevFakeModel', - 'ChangeRevFakeModelDB' + "BaseDBModelCRUDTestCase", + "FakeModel", + "FakeModelDB", + "ChangeRevFakeModel", + "ChangeRevFakeModelDB", ] @@ -57,19 +55,26 @@ def test_crud_operations(self): self.assertEqual(getattr(retrieved_db, attribute_name), attribute_value) # 2. Test update - updated_attribute_value = 'updated-%s' % (str(time.time())) + updated_attribute_value = "updated-%s" % (str(time.time())) setattr(model_db, self.update_attribute_name, updated_attribute_value) saved_db = self.persistance_class.add_or_update(model_db) - self.assertEqual(getattr(saved_db, self.update_attribute_name), updated_attribute_value) + self.assertEqual( + getattr(saved_db, self.update_attribute_name), updated_attribute_value + ) retrieved_db = self.persistance_class.get_by_id(saved_db.id) self.assertEqual(saved_db.id, retrieved_db.id) - self.assertEqual(getattr(retrieved_db, self.update_attribute_name), updated_attribute_value) + self.assertEqual( + getattr(retrieved_db, self.update_attribute_name), updated_attribute_value + ) # 3. Test delete self.persistance_class.delete(model_db) - self.assertRaises(StackStormDBObjectNotFoundError, self.persistance_class.get_by_id, - model_db.id) + self.assertRaises( + StackStormDBObjectNotFoundError, + self.persistance_class.get_by_id, + model_db.id, + ) class FakeModelDB(stormbase.StormBaseDB): @@ -79,11 +84,11 @@ class FakeModelDB(stormbase.StormBaseDB): timestamp = mongoengine.DateTimeField() meta = { - 'indexes': [ - {'fields': ['index']}, - {'fields': ['category']}, - {'fields': ['timestamp']}, - {'fields': ['context.user']}, + "indexes": [ + {"fields": ["index"]}, + {"fields": ["category"]}, + {"fields": ["timestamp"]}, + {"fields": ["context.user"]}, ] } diff --git a/st2common/tests/unit/services/test_access.py b/st2common/tests/unit/services/test_access.py index 79e680b30d..4f7d8169b4 100644 --- a/st2common/tests/unit/services/test_access.py +++ b/st2common/tests/unit/services/test_access.py @@ -28,11 +28,10 @@ import st2tests.config as tests_config -USERNAME = 'manas' +USERNAME = "manas" class AccessServiceTest(DbTestCase): - @classmethod def setUpClass(cls): super(AccessServiceTest, cls).setUpClass() @@ -47,7 +46,7 @@ def test_create_token(self): def test_create_token_fail(self): try: access.create_token(None) - self.assertTrue(False, 'Create succeeded was expected to fail.') + self.assertTrue(False, "Create succeeded was expected to fail.") except ValueError: self.assertTrue(True) @@ -56,7 +55,7 @@ def test_delete_token(self): access.delete_token(token.token) try: token = Token.get(token.token) - self.assertTrue(False, 'Delete failed was expected to pass.') + self.assertTrue(False, "Delete failed was expected to pass.") except TokenNotFoundError: self.assertTrue(True) @@ -71,13 +70,17 @@ def test_create_token_ttl_ok(self): self.assertIsNotNone(token) self.assertIsNotNone(token.token) self.assertEqual(token.user, USERNAME) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertLess(isotime.parse(token.expiry), expected_expiry) def test_create_token_ttl_capped(self): ttl = cfg.CONF.auth.token_ttl + 10 - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) token = access.create_token(USERNAME, 10) self.assertIsNotNone(token) @@ -86,11 +89,13 @@ def test_create_token_ttl_capped(self): self.assertLess(isotime.parse(token.expiry), expected_expiry) def test_create_token_service_token_can_use_arbitrary_ttl(self): - ttl = (10000 * 24 * 24) + ttl = 10000 * 24 * 24 # Service token should support arbitrary TTL token = access.create_token(USERNAME, ttl=ttl, service=True) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertIsNotNone(token) @@ -98,5 +103,6 @@ def test_create_token_service_token_can_use_arbitrary_ttl(self): self.assertLess(isotime.parse(token.expiry), expected_expiry) # Non service token should throw on TTL which is too large - self.assertRaises(TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, - service=False) + self.assertRaises( + TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, service=False + ) diff --git a/st2common/tests/unit/services/test_action.py b/st2common/tests/unit/services/test_action.py index 7bda929cc0..ab8db72329 100644 --- a/st2common/tests/unit/services/test_action.py +++ b/st2common/tests/unit/services/test_action.py @@ -39,145 +39,126 @@ RUNNER = { - 'name': 'local-shell-script', - 'description': 'A runner to execute local command.', - 'enabled': True, - 'runner_parameters': { - 'hosts': {'type': 'string'}, - 'cmd': {'type': 'string'}, - 'sudo': {'type': 'boolean', 'default': False} + "name": "local-shell-script", + "description": "A runner to execute local command.", + "enabled": True, + "runner_parameters": { + "hosts": {"type": "string"}, + "cmd": {"type": "string"}, + "sudo": {"type": "boolean", "default": False}, }, - 'runner_module': 'remoterunner' + "runner_module": "remoterunner", } RUNNER_ACTION_CHAIN = { - 'name': 'action-chain', - 'description': 'AC runner.', - 'enabled': True, - 'runner_parameters': { - }, - 'runner_module': 'remoterunner' + "name": "action-chain", + "description": "AC runner.", + "enabled": True, + "runner_parameters": {}, + "runner_module": "remoterunner", } ACTION = { - 'name': 'my.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'arg_default_value': { - 'type': 'string', - 'default': 'abc' - }, - 'arg_default_type': { - } + "name": "my.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": { + "arg_default_value": {"type": "string", "default": "abc"}, + "arg_default_type": {}, }, - 'notify': { - 'on-complete': { - 'message': 'My awesome action is complete. Party time!!!', - 'routes': ['notify.slack'] + "notify": { + "on-complete": { + "message": "My awesome action is complete. Party time!!!", + "routes": ["notify.slack"], } - } + }, } ACTION_WORKFLOW = { - 'name': 'my.wf_action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'action-chain' + "name": "my.wf_action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "action-chain", } ACTION_OVR_PARAM = { - 'name': 'my.sudo.default.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'default': True - } - } + "name": "my.sudo.default.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"default": True}}, } ACTION_OVR_PARAM_MUTABLE = { - 'name': 'my.sudo.mutable.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'immutable': False - } - } + "name": "my.sudo.mutable.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"immutable": False}}, } ACTION_OVR_PARAM_IMMUTABLE = { - 'name': 'my.sudo.immutable.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'immutable': True - } - } + "name": "my.sudo.immutable.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"immutable": True}}, } ACTION_OVR_PARAM_BAD_ATTR = { - 'name': 'my.sudo.invalid.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'type': 'number' - } - } + "name": "my.sudo.invalid.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"type": "number"}}, } ACTION_OVR_PARAM_BAD_ATTR_NOOP = { - 'name': 'my.sudo.invalid.noop.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'type': 'boolean' - } - } + "name": "my.sudo.invalid.noop.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"type": "boolean"}}, } -PACK = 'default' -ACTION_REF = ResourceReference(name='my.action', pack=PACK).ref -ACTION_WORKFLOW_REF = ResourceReference(name='my.wf_action', pack=PACK).ref -ACTION_OVR_PARAM_REF = ResourceReference(name='my.sudo.default.action', pack=PACK).ref -ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference(name='my.sudo.mutable.action', pack=PACK).ref -ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference(name='my.sudo.immutable.action', pack=PACK).ref -ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference(name='my.sudo.invalid.action', pack=PACK).ref +PACK = "default" +ACTION_REF = ResourceReference(name="my.action", pack=PACK).ref +ACTION_WORKFLOW_REF = ResourceReference(name="my.wf_action", pack=PACK).ref +ACTION_OVR_PARAM_REF = ResourceReference(name="my.sudo.default.action", pack=PACK).ref +ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference( + name="my.sudo.mutable.action", pack=PACK +).ref +ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference( + name="my.sudo.immutable.action", pack=PACK +).ref +ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference( + name="my.sudo.invalid.action", pack=PACK +).ref ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF = ResourceReference( - name='my.sudo.invalid.noop.action', pack=PACK).ref + name="my.sudo.invalid.noop.action", pack=PACK +).ref -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class TestActionExecutionService(DbTestCase): - @classmethod def setUpClass(cls): super(TestActionExecutionService, cls).setUpClass() @@ -188,17 +169,21 @@ def setUpClass(cls): RunnerType.add_or_update(RunnerTypeAPI.to_model(runner_api)) cls.actions = { - ACTION['name']: ActionAPI(**ACTION), - ACTION_WORKFLOW['name']: ActionAPI(**ACTION_WORKFLOW), - ACTION_OVR_PARAM['name']: ActionAPI(**ACTION_OVR_PARAM), - ACTION_OVR_PARAM_MUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_MUTABLE), - ACTION_OVR_PARAM_IMMUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE), - ACTION_OVR_PARAM_BAD_ATTR['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR), - ACTION_OVR_PARAM_BAD_ATTR_NOOP['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR_NOOP) + ACTION["name"]: ActionAPI(**ACTION), + ACTION_WORKFLOW["name"]: ActionAPI(**ACTION_WORKFLOW), + ACTION_OVR_PARAM["name"]: ActionAPI(**ACTION_OVR_PARAM), + ACTION_OVR_PARAM_MUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_MUTABLE), + ACTION_OVR_PARAM_IMMUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE), + ACTION_OVR_PARAM_BAD_ATTR["name"]: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR), + ACTION_OVR_PARAM_BAD_ATTR_NOOP["name"]: ActionAPI( + **ACTION_OVR_PARAM_BAD_ATTR_NOOP + ), } - cls.actiondbs = {name: Action.add_or_update(ActionAPI.to_model(action)) - for name, action in six.iteritems(cls.actions)} + cls.actiondbs = { + name: Action.add_or_update(ActionAPI.to_model(action)) + for name, action in six.iteritems(cls.actions) + } cls.container = RunnerContainer() @@ -212,8 +197,8 @@ def tearDownClass(cls): super(TestActionExecutionService, cls).tearDownClass() def _submit_request(self, action_ref=ACTION_REF): - context = {'user': USERNAME} - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + context = {"user": USERNAME} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=action_ref, context=context, parameters=parameters) req, _ = action_service.request(req) ex = action_db.get_liveaction_by_id(str(req.id)) @@ -249,7 +234,7 @@ def _create_nested_executions(self, depth=2): root_liveaction_db = LiveAction.add_or_update(root_liveaction_db) root_ex = executions.create_execution_object(root_liveaction_db) - last_id = root_ex['id'] + last_id = root_ex["id"] # Create children to the specified depth for i in range(depth): @@ -264,11 +249,7 @@ def _create_nested_executions(self, depth=2): child_liveaction_db = LiveActionDB() child_liveaction_db.status = action_constants.LIVEACTION_STATUS_PAUSED child_liveaction_db.action = action - child_liveaction_db.context = { - "parent": { - "execution_id": last_id - } - } + child_liveaction_db.context = {"parent": {"execution_id": last_id}} child_liveaction_db = LiveAction.add_or_update(child_liveaction_db) parent_ex = executions.create_execution_object(child_liveaction_db) last_id = parent_ex.id @@ -277,104 +258,116 @@ def _create_nested_executions(self, depth=2): return (child_liveaction_db, root_liveaction_db) def test_req_non_workflow_action(self): - actiondb = self.actiondbs[ACTION['name']] + actiondb = self.actiondbs[ACTION["name"]] req, ex = self._submit_request(action_ref=ACTION_REF) self.assertIsNotNone(ex) self.assertEqual(ex.action_is_workflow, False) self.assertEqual(ex.id, req.id) - self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name])) - self.assertEqual(ex.context['user'], req.context['user']) + self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name])) + self.assertEqual(ex.context["user"], req.context["user"]) self.assertDictEqual(ex.parameters, req.parameters) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) self.assertIsNotNone(ex.notify) # mongoengine DateTimeField stores datetime only up to milliseconds - self.assertEqual(isotime.format(ex.start_timestamp, usec=False), - isotime.format(req.start_timestamp, usec=False)) + self.assertEqual( + isotime.format(ex.start_timestamp, usec=False), + isotime.format(req.start_timestamp, usec=False), + ) def test_req_workflow_action(self): - actiondb = self.actiondbs[ACTION_WORKFLOW['name']] + actiondb = self.actiondbs[ACTION_WORKFLOW["name"]] req, ex = self._submit_request(action_ref=ACTION_WORKFLOW_REF) self.assertIsNotNone(ex) self.assertEqual(ex.action_is_workflow, True) self.assertEqual(ex.id, req.id) - self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name])) - self.assertEqual(ex.context['user'], req.context['user']) + self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name])) + self.assertEqual(ex.context["user"], req.context["user"]) self.assertDictEqual(ex.parameters, req.parameters) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) def test_req_invalid_parameters(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': 123} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_value": 123} liveaction = LiveActionDB(action=ACTION_REF, parameters=parameters) - self.assertRaises(jsonschema.ValidationError, action_service.request, liveaction) + self.assertRaises( + jsonschema.ValidationError, action_service.request, liveaction + ) def test_req_optional_parameter_none_value(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': None} + parameters = { + "hosts": "127.0.0.1", + "cmd": "uname -a", + "arg_default_value": None, + } req = LiveActionDB(action=ACTION_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_optional_parameter_none_value_no_default(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_type': None} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_type": None} req = LiveActionDB(action=ACTION_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': False} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": False} req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter_type_attribute_value_changed(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_REF, parameters=parameters) with self.assertRaises(action_exc.InvalidActionParameterException) as ex_ctx: req, _ = action_service.request(req) - expected = ('The attribute "type" for the runner parameter "sudo" in ' - 'action "default.my.sudo.invalid.action" cannot be overridden.') + expected = ( + 'The attribute "type" for the runner parameter "sudo" in ' + 'action "default.my.sudo.invalid.action" cannot be overridden.' + ) self.assertEqual(str(ex_ctx.exception), expected) def test_req_override_runner_parameter_type_attribute_no_value_changed(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} - req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters) + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} + req = LiveActionDB( + action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters + ) req, _ = action_service.request(req) def test_req_override_runner_parameter_mutable(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True} req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter_immutable(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True} req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters) self.assertRaises(ValueError, action_service.request, req) def test_req_nonexistent_action(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} - action_ref = ResourceReference(name='i.action', pack='default').ref + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} + action_ref = ResourceReference(name="i.action", pack="default").ref ex = LiveActionDB(action=action_ref, parameters=parameters) self.assertRaises(ValueError, action_service.request, ex) def test_req_disabled_action(self): - actiondb = self.actiondbs[ACTION['name']] + actiondb = self.actiondbs[ACTION["name"]] actiondb.enabled = False Action.add_or_update(actiondb) try: - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} ex = LiveActionDB(action=ACTION_REF, parameters=parameters) self.assertRaises(ValueError, action_service.request, ex) except Exception as e: @@ -390,7 +383,9 @@ def test_req_cancellation(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -405,7 +400,9 @@ def test_req_cancellation_uncancelable_state(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to FAILED. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_FAILED, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_FAILED, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_FAILED) @@ -429,20 +426,20 @@ def test_req_pause_unsupported(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request pause. self.assertRaises( - runner_exc.InvalidActionRunnerOperationError, - self._submit_pause, - ex + runner_exc.InvalidActionRunnerOperationError, self._submit_pause, ex ) def test_req_pause(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -451,7 +448,9 @@ def test_req_pause(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -459,11 +458,11 @@ def test_req_pause(self): ex = self._submit_pause(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_pause_not_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -473,16 +472,14 @@ def test_req_pause_not_running(self): # Request pause. self.assertRaises( - runner_exc.UnexpectedActionExecutionStatusError, - self._submit_pause, - ex + runner_exc.UnexpectedActionExecutionStatusError, self._submit_pause, ex ) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_pause_already_pausing(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -491,7 +488,9 @@ def test_req_pause_already_pausing(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -500,12 +499,14 @@ def test_req_pause_already_pausing(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) # Request pause again. - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: ex = self._submit_pause(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) mocked.assert_not_called() finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_unsupported(self): req, ex = self._submit_request() @@ -514,20 +515,20 @@ def test_req_resume_unsupported(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request resume. self.assertRaises( - runner_exc.InvalidActionRunnerOperationError, - self._submit_resume, - ex + runner_exc.InvalidActionRunnerOperationError, self._submit_resume, ex ) def test_req_resume(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -536,7 +537,9 @@ def test_req_resume(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -545,7 +548,9 @@ def test_req_resume(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) # Update ex status to PAUSED. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_PAUSED, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_PAUSED, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSED) @@ -553,11 +558,11 @@ def test_req_resume(self): ex = self._submit_resume(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RESUMING) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_not_paused(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -566,7 +571,9 @@ def test_req_resume_not_paused(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -576,16 +583,14 @@ def test_req_resume_not_paused(self): # Request resume. self.assertRaises( - runner_exc.UnexpectedActionExecutionStatusError, - self._submit_resume, - ex + runner_exc.UnexpectedActionExecutionStatusError, self._submit_resume, ex ) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_already_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -594,25 +599,28 @@ def test_req_resume_already_running(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request resume. - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: ex = self._submit_resume(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) mocked.assert_not_called() finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_root_liveaction(self): - """Test that get_root_liveaction correctly retrieves the root liveaction - """ + """Test that get_root_liveaction correctly retrieves the root liveaction""" # Test a variety of depths for i in range(1, 7): child, expected_root = self._create_nested_executions(depth=i) actual_root = action_service.get_root_liveaction(child) - self.assertEqual(expected_root['id'], actual_root['id']) + self.assertEqual(expected_root["id"], actual_root["id"]) diff --git a/st2common/tests/unit/services/test_keyvalue.py b/st2common/tests/unit/services/test_keyvalue.py index a11a3bb11b..bd080719bb 100644 --- a/st2common/tests/unit/services/test_keyvalue.py +++ b/st2common/tests/unit/services/test_keyvalue.py @@ -22,17 +22,22 @@ class KeyValueServicesTest(unittest2.TestCase): - def test_get_key_reference_system_scope(self): - ref = get_key_reference(scope=SYSTEM_SCOPE, name='foo') - self.assertEqual(ref, 'foo') + ref = get_key_reference(scope=SYSTEM_SCOPE, name="foo") + self.assertEqual(ref, "foo") def test_get_key_reference_user_scope(self): - ref = get_key_reference(scope=USER_SCOPE, name='foo', user='stanley') - self.assertEqual(ref, 'stanley:foo') - self.assertRaises(InvalidUserException, get_key_reference, - scope=USER_SCOPE, name='foo', user='') + ref = get_key_reference(scope=USER_SCOPE, name="foo", user="stanley") + self.assertEqual(ref, "stanley:foo") + self.assertRaises( + InvalidUserException, + get_key_reference, + scope=USER_SCOPE, + name="foo", + user="", + ) def test_get_key_reference_invalid_scope_raises_exception(self): - self.assertRaises(InvalidScopeException, get_key_reference, - scope='sketchy', name='foo') + self.assertRaises( + InvalidScopeException, get_key_reference, scope="sketchy", name="foo" + ) diff --git a/st2common/tests/unit/services/test_policy.py b/st2common/tests/unit/services/test_policy.py index 69fb0624e6..128ce1defe 100644 --- a/st2common/tests/unit/services/test_policy.py +++ b/st2common/tests/unit/services/test_policy.py @@ -16,6 +16,7 @@ from __future__ import absolute_import import st2tests.config as tests_config + tests_config.parse_args() import st2common @@ -32,23 +33,22 @@ from st2tests import fixturesloader as fixtures -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', # wolfpack.action-1 - 'action2.yaml', # wolfpack.action-2 - 'local.yaml' # core.local + "actions": [ + "action1.yaml", # wolfpack.action-1 + "action2.yaml", # wolfpack.action-2 + "local.yaml", # core.local + ], + "policies": [ + "policy_2.yaml", # mock policy on wolfpack.action-1 + "policy_5.yaml", # concurrency policy on wolfpack.action-2 ], - 'policies': [ - 'policy_2.yaml', # mock policy on wolfpack.action-1 - 'policy_5.yaml' # concurrency policy on wolfpack.action-2 - ] } class PolicyServiceTestCase(st2tests.DbTestCase): - @classmethod def setUpClass(cls): super(PolicyServiceTestCase, cls).setUpClass() @@ -60,28 +60,39 @@ def setUpClass(cls): policies_registrar.register_policy_types(st2common) loader = fixtures.FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def setUp(self): super(PolicyServiceTestCase, self).setUp() - params = {'action': 'wolfpack.action-1', 'parameters': {'actionstr': 'foo-last'}} + params = { + "action": "wolfpack.action-1", + "parameters": {"actionstr": "foo-last"}, + } self.lv_ac_db_1 = action_db_models.LiveActionDB(**params) self.lv_ac_db_1, _ = action_service.request(self.lv_ac_db_1) - params = {'action': 'wolfpack.action-2', 'parameters': {'actionstr': 'foo-last'}} + params = { + "action": "wolfpack.action-2", + "parameters": {"actionstr": "foo-last"}, + } self.lv_ac_db_2 = action_db_models.LiveActionDB(**params) self.lv_ac_db_2, _ = action_service.request(self.lv_ac_db_2) - params = {'action': 'core.local', 'parameters': {'cmd': 'date'}} + params = {"action": "core.local", "parameters": {"cmd": "date"}} self.lv_ac_db_3 = action_db_models.LiveActionDB(**params) self.lv_ac_db_3, _ = action_service.request(self.lv_ac_db_3) def tearDown(self): - action_service.update_status(self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED) - action_service.update_status(self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED) - action_service.update_status(self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED + ) + action_service.update_status( + self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED + ) + action_service.update_status( + self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED + ) def test_action_has_policies(self): self.assertTrue(policy_service.has_policies(self.lv_ac_db_1)) @@ -93,7 +104,7 @@ def test_action_has_specific_policies(self): self.assertTrue( policy_service.has_policies( self.lv_ac_db_2, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK, ) ) @@ -101,6 +112,6 @@ def test_action_does_not_have_specific_policies(self): self.assertFalse( policy_service.has_policies( self.lv_ac_db_1, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK, ) ) diff --git a/st2common/tests/unit/services/test_synchronization.py b/st2common/tests/unit/services/test_synchronization.py index 86cf36042f..991e6b9036 100644 --- a/st2common/tests/unit/services/test_synchronization.py +++ b/st2common/tests/unit/services/test_synchronization.py @@ -39,13 +39,15 @@ def tearDownClass(cls): super(SynchronizationTest, cls).tearDownClass() def test_service_configured(self): - cfg.CONF.set_override(name='url', override='kazoo://127.0.0.1:2181', group='coordination') + cfg.CONF.set_override( + name="url", override="kazoo://127.0.0.1:2181", group="coordination" + ) self.assertTrue(coordination.configured()) - cfg.CONF.set_override(name='url', override='file:///tmp', group='coordination') + cfg.CONF.set_override(name="url", override="file:///tmp", group="coordination") self.assertFalse(coordination.configured()) - cfg.CONF.set_override(name='url', override='zake://', group='coordination') + cfg.CONF.set_override(name="url", override="zake://", group="coordination") self.assertFalse(coordination.configured()) def test_lock(self): diff --git a/st2common/tests/unit/services/test_trace.py b/st2common/tests/unit/services/test_trace.py index 06c9260586..807dc4251d 100644 --- a/st2common/tests/unit/services/test_trace.py +++ b/st2common/tests/unit/services/test_trace.py @@ -30,33 +30,37 @@ from st2tests import DbTestCase -FIXTURES_PACK = 'traces' - -TEST_MODELS = OrderedDict(( - ('executions', [ - 'traceable_execution.yaml', - 'rule_fired_execution.yaml', - 'execution_with_parent.yaml' - ]), - ('liveactions', [ - 'traceable_liveaction.yaml', - 'liveaction_with_parent.yaml' - ]), - ('traces', [ - 'trace_empty.yaml', - 'trace_multiple_components.yaml', - 'trace_one_each.yaml', - 'trace_one_each_dup.yaml', - 'trace_execution.yaml' - ]), - ('triggers', ['trigger1.yaml']), - ('triggerinstances', [ - 'action_trigger.yaml', - 'notify_trigger.yaml', - 'non_internal_trigger.yaml' - ]), - ('rules', ['rule1.yaml']), -)) +FIXTURES_PACK = "traces" + +TEST_MODELS = OrderedDict( + ( + ( + "executions", + [ + "traceable_execution.yaml", + "rule_fired_execution.yaml", + "execution_with_parent.yaml", + ], + ), + ("liveactions", ["traceable_liveaction.yaml", "liveaction_with_parent.yaml"]), + ( + "traces", + [ + "trace_empty.yaml", + "trace_multiple_components.yaml", + "trace_one_each.yaml", + "trace_one_each_dup.yaml", + "trace_execution.yaml", + ], + ), + ("triggers", ["trigger1.yaml"]), + ( + "triggerinstances", + ["action_trigger.yaml", "notify_trigger.yaml", "non_internal_trigger.yaml"], + ), + ("rules", ["rule1.yaml"]), + ) +) class DummyComponent(object): @@ -78,139 +82,184 @@ class TestTraceService(DbTestCase): @classmethod def setUpClass(cls): super(TestTraceService, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.trace1 = cls.models['traces']['trace_multiple_components.yaml'] - cls.trace2 = cls.models['traces']['trace_one_each.yaml'] - cls.trace3 = cls.models['traces']['trace_one_each_dup.yaml'] - cls.trace_empty = cls.models['traces']['trace_empty.yaml'] - cls.trace_execution = cls.models['traces']['trace_execution.yaml'] - - cls.action_trigger = cls.models['triggerinstances']['action_trigger.yaml'] - cls.notify_trigger = cls.models['triggerinstances']['notify_trigger.yaml'] - cls.non_internal_trigger = cls.models['triggerinstances']['non_internal_trigger.yaml'] - - cls.rule1 = cls.models['rules']['rule1.yaml'] - - cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml'] - cls.liveaction_with_parent = cls.models['liveactions']['liveaction_with_parent.yaml'] - cls.traceable_execution = cls.models['executions']['traceable_execution.yaml'] - cls.rule_fired_execution = cls.models['executions']['rule_fired_execution.yaml'] - cls.execution_with_parent = cls.models['executions']['execution_with_parent.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.trace1 = cls.models["traces"]["trace_multiple_components.yaml"] + cls.trace2 = cls.models["traces"]["trace_one_each.yaml"] + cls.trace3 = cls.models["traces"]["trace_one_each_dup.yaml"] + cls.trace_empty = cls.models["traces"]["trace_empty.yaml"] + cls.trace_execution = cls.models["traces"]["trace_execution.yaml"] + + cls.action_trigger = cls.models["triggerinstances"]["action_trigger.yaml"] + cls.notify_trigger = cls.models["triggerinstances"]["notify_trigger.yaml"] + cls.non_internal_trigger = cls.models["triggerinstances"][ + "non_internal_trigger.yaml" + ] + + cls.rule1 = cls.models["rules"]["rule1.yaml"] + + cls.traceable_liveaction = cls.models["liveactions"][ + "traceable_liveaction.yaml" + ] + cls.liveaction_with_parent = cls.models["liveactions"][ + "liveaction_with_parent.yaml" + ] + cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"] + cls.rule_fired_execution = cls.models["executions"]["rule_fired_execution.yaml"] + cls.execution_with_parent = cls.models["executions"][ + "execution_with_parent.yaml" + ] def test_get_trace_db_by_action_execution(self): - action_execution = DummyComponent(id_=self.trace1.action_executions[0].object_id) - trace_db = trace_service.get_trace_db_by_action_execution(action_execution=action_execution) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + action_execution = DummyComponent( + id_=self.trace1.action_executions[0].object_id + ) + trace_db = trace_service.get_trace_db_by_action_execution( + action_execution=action_execution + ) + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_db_by_action_execution_fail(self): - action_execution = DummyComponent(id_=self.trace2.action_executions[0].object_id) - self.assertRaises(UniqueTraceNotFoundException, - trace_service.get_trace_db_by_action_execution, - **{'action_execution': action_execution}) + action_execution = DummyComponent( + id_=self.trace2.action_executions[0].object_id + ) + self.assertRaises( + UniqueTraceNotFoundException, + trace_service.get_trace_db_by_action_execution, + **{"action_execution": action_execution}, + ) def test_get_trace_db_by_rule(self): rule = DummyComponent(id_=self.trace1.rules[0].object_id) trace_dbs = trace_service.get_trace_db_by_rule(rule=rule) - self.assertEqual(len(trace_dbs), 1, 'Expected 1 trace_db.') - self.assertEqual(trace_dbs[0].id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(len(trace_dbs), 1, "Expected 1 trace_db.") + self.assertEqual( + trace_dbs[0].id, self.trace1.id, "Incorrect trace_db returned." + ) def test_get_multiple_trace_db_by_rule(self): rule = DummyComponent(id_=self.trace2.rules[0].object_id) trace_dbs = trace_service.get_trace_db_by_rule(rule=rule) - self.assertEqual(len(trace_dbs), 2, 'Expected 2 trace_db.') + self.assertEqual(len(trace_dbs), 2, "Expected 2 trace_db.") result = [trace_db.id for trace_db in trace_dbs] - self.assertEqual(result, [self.trace2.id, self.trace3.id], 'Incorrect trace_dbs returned.') + self.assertEqual( + result, [self.trace2.id, self.trace3.id], "Incorrect trace_dbs returned." + ) def test_get_trace_db_by_trigger_instance(self): - trigger_instance = DummyComponent(id_=self.trace1.trigger_instances[0].object_id) - trace_db = trace_service.get_trace_db_by_trigger_instance(trigger_instance=trigger_instance) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + trigger_instance = DummyComponent( + id_=self.trace1.trigger_instances[0].object_id + ) + trace_db = trace_service.get_trace_db_by_trigger_instance( + trigger_instance=trigger_instance + ) + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_db_by_trigger_instance_fail(self): - trigger_instance = DummyComponent(id_=self.trace2.trigger_instances[0].object_id) - self.assertRaises(UniqueTraceNotFoundException, - trace_service.get_trace_db_by_trigger_instance, - **{'trigger_instance': trigger_instance}) + trigger_instance = DummyComponent( + id_=self.trace2.trigger_instances[0].object_id + ) + self.assertRaises( + UniqueTraceNotFoundException, + trace_service.get_trace_db_by_trigger_instance, + **{"trigger_instance": trigger_instance}, + ) def test_get_trace_by_dict(self): - trace_context = {'id_': str(self.trace1.id)} + trace_context = {"id_": str(self.trace1.id)} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = {'id_': str(bson.ObjectId())} - self.assertRaises(StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context) + trace_context = {"id_": str(bson.ObjectId())} + self.assertRaises( + StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context + ) - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_by_trace_context(self): - trace_context = TraceContext(**{'id_': str(self.trace1.id)}) + trace_context = TraceContext(**{"id_": str(self.trace1.id)}) trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = TraceContext(**{'trace_tag': self.trace1.trace_tag}) + trace_context = TraceContext(**{"trace_tag": self.trace1.trace_tag}) trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_ignore_trace_tag(self): - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context, ignore_trace_tag=True) - self.assertEqual(trace_db, None, 'Should be None.') + self.assertEqual(trace_db, None, "Should be None.") def test_get_trace_fail_empty_context(self): trace_context = {} self.assertRaises(ValueError, trace_service.get_trace, trace_context) def test_get_trace_fail_multi_match(self): - trace_context = {'trace_tag': self.trace2.trace_tag} - self.assertRaises(UniqueTraceNotFoundException, trace_service.get_trace, trace_context) + trace_context = {"trace_tag": self.trace2.trace_tag} + self.assertRaises( + UniqueTraceNotFoundException, trace_service.get_trace, trace_context + ) def test_get_trace_db_by_live_action_valid_id_context(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['trace_context'] = {'id_': str(self.trace_execution.id)} - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + traceable_liveaction.context["trace_context"] = { + "id_": str(self.trace_execution.id) + } + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace_execution.id) def test_get_trace_db_by_live_action_trace_tag_context(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['trace_context'] = { - 'trace_tag': str(self.trace_execution.trace_tag) + traceable_liveaction.context["trace_context"] = { + "trace_tag": str(self.trace_execution.trace_tag) } - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertTrue(created) - self.assertEqual(trace_db.id, None, 'Expected to be None') + self.assertEqual(trace_db.id, None, "Expected to be None") self.assertEqual(trace_db.trace_tag, str(self.trace_execution.trace_tag)) def test_get_trace_db_by_live_action_parent(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['parent'] = { - 'execution_id': str(self.trace1.action_executions[0].object_id) + traceable_liveaction.context["parent"] = { + "execution_id": str(self.trace1.action_executions[0].object_id) } - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace1.id) def test_get_trace_db_by_live_action_parent_fail(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['parent'] = { - 'execution_id': str(bson.ObjectId()) - } - self.assertRaises(StackStormDBObjectNotFoundError, - trace_service.get_trace_db_by_live_action, - traceable_liveaction) + traceable_liveaction.context["parent"] = {"execution_id": str(bson.ObjectId())} + self.assertRaises( + StackStormDBObjectNotFoundError, + trace_service.get_trace_db_by_live_action, + traceable_liveaction, + ) def test_get_trace_db_by_live_action_from_execution(self): traceable_liveaction = copy.copy(self.traceable_liveaction) # fixtures id value in liveaction is not persisted in DB. - traceable_liveaction.id = bson.ObjectId(self.traceable_execution.liveaction['id']) - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + traceable_liveaction.id = bson.ObjectId( + self.traceable_execution.liveaction["id"] + ) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace_execution.id) @@ -218,76 +267,119 @@ def test_get_trace_db_by_live_action_new_trace(self): traceable_liveaction = copy.copy(self.traceable_liveaction) # a liveaction without any associated ActionExecution traceable_liveaction.id = bson.ObjectId() - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertTrue(created) - self.assertEqual(trace_db.id, None, 'Should be None.') + self.assertEqual(trace_db.id, None, "Should be None.") def test_add_or_update_given_trace_context(self): - trace_context = {'id_': str(self.trace_empty.id)} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"id_": str(self.trace_empty.id)} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" trace_service.add_or_update_given_trace_context( trace_context, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) retrieved_trace_db = Trace.get_by_id(self.trace_empty.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) Trace.add_or_update(self.trace_empty) def test_add_or_update_given_trace_db(self): - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" to_save = copy.copy(self.trace_empty) to_save.id = None saved = trace_service.add_or_update_given_trace_db( to_save, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) retrieved_trace_db = Trace.get_by_id(saved.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) # Now add more TraceComponents and validated that they are added properly. saved = trace_service.add_or_update_given_trace_db( retrieved_trace_db, action_executions=[str(bson.ObjectId()), str(bson.ObjectId())], rules=[str(bson.ObjectId())], - trigger_instances=[str(bson.ObjectId()), str(bson.ObjectId()), str(bson.ObjectId())]) + trigger_instances=[ + str(bson.ObjectId()), + str(bson.ObjectId()), + str(bson.ObjectId()), + ], + ) retrieved_trace_db = Trace.get_by_id(saved.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 3, - 'Expected updated action_executions.') - self.assertEqual(len(retrieved_trace_db.rules), 2, 'Expected updated rules.') - self.assertEqual(len(retrieved_trace_db.trigger_instances), 4, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 3, + "Expected updated action_executions.", + ) + self.assertEqual(len(retrieved_trace_db.rules), 2, "Expected updated rules.") + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 4, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) @@ -295,179 +387,238 @@ def test_add_or_update_given_trace_db_fail(self): self.assertRaises(ValueError, trace_service.add_or_update_given_trace_db, None) def test_add_or_update_given_trace_context_new(self): - trace_context = {'trace_tag': 'awesome_test_trace'} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"trace_tag": "awesome_test_trace"} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" pre_add_or_update_traces = len(Trace.get_all()) trace_db = trace_service.add_or_update_given_trace_context( trace_context, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) post_add_or_update_traces = len(Trace.get_all()) - self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces, - 'Expected new Trace to be created.') + self.assertTrue( + post_add_or_update_traces > pre_add_or_update_traces, + "Expected new Trace to be created.", + ) retrieved_trace_db = Trace.get_by_id(trace_db.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) def test_add_or_update_given_trace_context_new_with_causals(self): - trace_context = {'trace_tag': 'causal_test_trace'} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"trace_tag": "causal_test_trace"} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" pre_add_or_update_traces = len(Trace.get_all()) trace_db = trace_service.add_or_update_given_trace_context( trace_context, - action_executions=[{'id': action_execution_id, - 'caused_by': {'id': '%s:%s' % (rule_id, trigger_instance_id), - 'type': 'rule'}}], - rules=[{'id': rule_id, - 'caused_by': {'id': trigger_instance_id, 'type': 'trigger-instance'}}], - trigger_instances=[trigger_instance_id]) + action_executions=[ + { + "id": action_execution_id, + "caused_by": { + "id": "%s:%s" % (rule_id, trigger_instance_id), + "type": "rule", + }, + } + ], + rules=[ + { + "id": rule_id, + "caused_by": { + "id": trigger_instance_id, + "type": "trigger-instance", + }, + } + ], + trigger_instances=[trigger_instance_id], + ) post_add_or_update_traces = len(Trace.get_all()) - self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces, - 'Expected new Trace to be created.') + self.assertTrue( + post_add_or_update_traces > pre_add_or_update_traces, + "Expected new Trace to be created.", + ) retrieved_trace_db = Trace.get_by_id(trace_db.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].caused_by, - {'id': '%s:%s' % (rule_id, trigger_instance_id), - 'type': 'rule'}, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].caused_by, - {'id': trigger_instance_id, 'type': 'trigger-instance'}, - 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].caused_by, {}, - 'Expected updated rules.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].caused_by, + {"id": "%s:%s" % (rule_id, trigger_instance_id), "type": "rule"}, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + self.assertEqual( + retrieved_trace_db.rules[0].caused_by, + {"id": trigger_instance_id, "type": "trigger-instance"}, + "Expected updated rules.", + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].caused_by, + {}, + "Expected updated rules.", + ) Trace.delete(retrieved_trace_db) def test_trace_component_for_trigger_instance(self): # action_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.action_trigger) + self.action_trigger + ) expected = { - 'id': str(self.action_trigger.id), - 'ref': self.action_trigger.trigger, - 'caused_by': { - 'type': 'action_execution', - 'id': self.action_trigger.payload['execution_id'] - } + "id": str(self.action_trigger.id), + "ref": self.action_trigger.trigger, + "caused_by": { + "type": "action_execution", + "id": self.action_trigger.payload["execution_id"], + }, } self.assertEqual(trace_component, expected) # notify_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.notify_trigger) + self.notify_trigger + ) expected = { - 'id': str(self.notify_trigger.id), - 'ref': self.notify_trigger.trigger, - 'caused_by': { - 'type': 'action_execution', - 'id': self.notify_trigger.payload['execution_id'] - } + "id": str(self.notify_trigger.id), + "ref": self.notify_trigger.trigger, + "caused_by": { + "type": "action_execution", + "id": self.notify_trigger.payload["execution_id"], + }, } self.assertEqual(trace_component, expected) # non_internal_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.non_internal_trigger) + self.non_internal_trigger + ) expected = { - 'id': str(self.non_internal_trigger.id), - 'ref': self.non_internal_trigger.trigger, - 'caused_by': {} + "id": str(self.non_internal_trigger.id), + "ref": self.non_internal_trigger.trigger, + "caused_by": {}, } self.assertEqual(trace_component, expected) def test_trace_component_for_rule(self): - trace_component = trace_service.get_trace_component_for_rule(self.rule1, - self.non_internal_trigger) + trace_component = trace_service.get_trace_component_for_rule( + self.rule1, self.non_internal_trigger + ) expected = { - 'id': str(self.rule1.id), - 'ref': self.rule1.ref, - 'caused_by': { - 'type': 'trigger_instance', - 'id': str(self.non_internal_trigger.id) - } + "id": str(self.rule1.id), + "ref": self.rule1.ref, + "caused_by": { + "type": "trigger_instance", + "id": str(self.non_internal_trigger.id), + }, } self.assertEqual(trace_component, expected) def test_trace_component_for_action_execution(self): # no cause trace_component = trace_service.get_trace_component_for_action_execution( - self.traceable_execution, - self.traceable_liveaction) + self.traceable_execution, self.traceable_liveaction + ) expected = { - 'id': str(self.traceable_execution.id), - 'ref': self.traceable_execution.action['ref'], - 'caused_by': {} + "id": str(self.traceable_execution.id), + "ref": self.traceable_execution.action["ref"], + "caused_by": {}, } self.assertEqual(trace_component, expected) # rule_fired_execution trace_component = trace_service.get_trace_component_for_action_execution( - self.rule_fired_execution, - self.traceable_liveaction) + self.rule_fired_execution, self.traceable_liveaction + ) expected = { - 'id': str(self.rule_fired_execution.id), - 'ref': self.rule_fired_execution.action['ref'], - 'caused_by': { - 'type': 'rule', - 'id': '%s:%s' % (self.rule_fired_execution.rule['id'], - self.rule_fired_execution.trigger_instance['id']) - } + "id": str(self.rule_fired_execution.id), + "ref": self.rule_fired_execution.action["ref"], + "caused_by": { + "type": "rule", + "id": "%s:%s" + % ( + self.rule_fired_execution.rule["id"], + self.rule_fired_execution.trigger_instance["id"], + ), + }, } self.assertEqual(trace_component, expected) # execution_with_parent trace_component = trace_service.get_trace_component_for_action_execution( - self.execution_with_parent, - self.liveaction_with_parent) + self.execution_with_parent, self.liveaction_with_parent + ) expected = { - 'id': str(self.execution_with_parent.id), - 'ref': self.execution_with_parent.action['ref'], - 'caused_by': { - 'type': 'action_execution', - 'id': self.liveaction_with_parent.context['parent']['execution_id'] - } + "id": str(self.execution_with_parent.id), + "ref": self.execution_with_parent.action["ref"], + "caused_by": { + "type": "action_execution", + "id": self.liveaction_with_parent.context["parent"]["execution_id"], + }, } self.assertEqual(trace_component, expected) class TestTraceContext(TestCase): - def test_str_method(self): - trace_context = TraceContext(id_='id', trace_tag='tag') + trace_context = TraceContext(id_="id", trace_tag="tag") self.assertTrue(str(trace_context)) - trace_context = TraceContext(trace_tag='tag') + trace_context = TraceContext(trace_tag="tag") self.assertTrue(str(trace_context)) - trace_context = TraceContext(id_='id') + trace_context = TraceContext(id_="id") self.assertTrue(str(trace_context)) diff --git a/st2common/tests/unit/services/test_trace_injection_action_services.py b/st2common/tests/unit/services/test_trace_injection_action_services.py index 8f9570d0e2..4b4fe0d177 100644 --- a/st2common/tests/unit/services/test_trace_injection_action_services.py +++ b/st2common/tests/unit/services/test_trace_injection_action_services.py @@ -21,13 +21,13 @@ from st2tests.fixturesloader import FixturesLoader from st2tests import DbTestCase -FIXTURES_PACK = 'traces' +FIXTURES_PACK = "traces" TEST_MODELS = { - 'executions': ['traceable_execution.yaml'], - 'liveactions': ['traceable_liveaction.yaml'], - 'actions': ['chain1.yaml'], - 'runners': ['actionchain.yaml'] + "executions": ["traceable_execution.yaml"], + "liveactions": ["traceable_liveaction.yaml"], + "actions": ["chain1.yaml"], + "runners": ["actionchain.yaml"], } @@ -41,44 +41,52 @@ class TraceInjectionTests(DbTestCase): @classmethod def setUpClass(cls): super(TraceInjectionTests, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) - cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml'] - cls.traceable_execution = cls.models['executions']['traceable_execution.yaml'] - cls.action = cls.models['actions']['chain1.yaml'] + cls.traceable_liveaction = cls.models["liveactions"][ + "traceable_liveaction.yaml" + ] + cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"] + cls.action = cls.models["actions"]["chain1.yaml"] def test_trace_provided(self): - self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'OohLaLaLa'} + self.traceable_liveaction["context"]["trace_context"] = { + "trace_tag": "OohLaLaLa" + } action_services.request(self.traceable_liveaction) traces = Trace.get_all() self.assertEqual(len(traces), 1) - self.assertEqual(len(traces[0]['action_executions']), 1) + self.assertEqual(len(traces[0]["action_executions"]), 1) # Let's use existing trace id in trace context. # We shouldn't create new trace object. trace_id = str(traces[0].id) - self.traceable_liveaction['context']['trace_context'] = {'id_': trace_id} + self.traceable_liveaction["context"]["trace_context"] = {"id_": trace_id} action_services.request(self.traceable_liveaction) traces = Trace.get_all() self.assertEqual(len(traces), 1) - self.assertEqual(len(traces[0]['action_executions']), 2) + self.assertEqual(len(traces[0]["action_executions"]), 2) def test_trace_tag_resuse(self): - self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'blank space'} + self.traceable_liveaction["context"]["trace_context"] = { + "trace_tag": "blank space" + } action_services.request(self.traceable_liveaction) # Let's use same trace tag again and we should see two trace objects in db. action_services.request(self.traceable_liveaction) - traces = Trace.query(**{'trace_tag': 'blank space'}) + traces = Trace.query(**{"trace_tag": "blank space"}) self.assertEqual(len(traces), 2) def test_invalid_trace_id_provided(self): liveactions = LiveAction.get_all() self.assertEqual(len(liveactions), 1) # fixtures loads it. - self.traceable_liveaction['context']['trace_context'] = {'id_': 'balleilaka'} + self.traceable_liveaction["context"]["trace_context"] = {"id_": "balleilaka"} - self.assertRaises(TraceNotFoundException, action_services.request, - self.traceable_liveaction) + self.assertRaises( + TraceNotFoundException, action_services.request, self.traceable_liveaction + ) # Make sure no liveactions are left behind liveactions = LiveAction.get_all() diff --git a/st2common/tests/unit/services/test_workflow.py b/st2common/tests/unit/services/test_workflow.py index 71cae679ba..23bd4aca60 100644 --- a/st2common/tests/unit/services/test_workflow.py +++ b/st2common/tests/unit/services/test_workflow.py @@ -25,6 +25,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -43,33 +44,35 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) -PACK_7 = 'dummy_pack_7' -PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + PACK_7 +PACK_7 = "dummy_pack_7" +PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + PACK_7 PACKS = [ TEST_PACK_PATH, PACK_7_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionServiceTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionServiceTest, cls).setUpClass() @@ -79,18 +82,17 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_request(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. @@ -99,7 +101,9 @@ def test_request(self): wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) # Check workflow execution is saved to the database. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) # Check required attributes. @@ -110,10 +114,12 @@ def test_request(self): self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED) def test_request_with_input(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters={'who': 'stan'}) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters={"who": "stan"} + ) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. @@ -122,7 +128,9 @@ def test_request_with_input(self): wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) # Check workflow execution is saved to the database. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) # Check required attributes. @@ -133,18 +141,16 @@ def test_request_with_input(self): self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED) # Check input and context. - expected_input = { - 'who': 'stan' - } + expected_input = {"who": "stan"} self.assertDictEqual(wf_ex_db.input, expected_input) def test_request_bad_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the action execution object with the bad action. ac_ex_db = ex_db_models.ActionExecutionDB( - action={'ref': 'mock.foobar'}, runner={'name': 'foobar'} + action={"ref": "mock.foobar"}, runner={"name": "foobar"} ) # Request the workflow execution. @@ -153,14 +159,16 @@ def test_request_bad_action(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_bad_action_ref(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-ref.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-inspection-action-ref.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -169,14 +177,16 @@ def test_request_wf_def_with_bad_action_ref(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_unregistered_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-db.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-inspection-action-db.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -185,15 +195,15 @@ def test_request_wf_def_with_unregistered_action(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_missing_required_action_param(self): - wf_name = 'fail-inspection-missing-required-action-param' - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') + wf_name = "fail-inspection-missing-required-action-param" + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -202,15 +212,15 @@ def test_request_wf_def_with_missing_required_action_param(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_unexpected_action_param(self): - wf_name = 'fail-inspection-unexpected-action-param' - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') + wf_name = "fail-inspection-unexpected-action-param" + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -219,44 +229,46 @@ def test_request_wf_def_with_unexpected_action_param(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_task_execution(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} - st2_ctx = {'execution_id': wf_ex_db.action_execution} + task_ctx = {"foo": "bar"} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'core.echo', 'input': {'message': 'Veni, vidi, vici.'}} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "core.echo", "input": {"message": "Veni, vidi, vici."}} + ], } workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Check task execution is saved to the database. - task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + task_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(task_ex_dbs), 1) # Check required attributes. @@ -267,42 +279,46 @@ def test_request_task_execution(self): self.assertEqual(task_ex_db.status, wf_statuses.RUNNING) # Check action execution for the task query with task execution ID. - ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(task_ex_db.id)) + ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(task_ex_db.id) + ) self.assertEqual(len(ac_ex_dbs), 1) # Check action execution for the task query with workflow execution ID. - ac_ex_dbs = ex_db_access.ActionExecution.query(workflow_execution=str(wf_ex_db.id)) + ac_ex_dbs = ex_db_access.ActionExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(ac_ex_dbs), 1) def test_request_task_execution_bad_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} - st2_ctx = {'execution_id': wf_ex_db.action_execution} + task_ctx = {"foo": "bar"} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'mock.echo', 'input': {'message': 'Veni, vidi, vici.'}} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "mock.echo", "input": {"message": "Veni, vidi, vici."}} + ], } self.assertRaises( @@ -310,14 +326,14 @@ def test_request_task_execution_bad_action(self): workflow_service.request_task_execution, wf_ex_db, st2_ctx, - task_ex_req + task_ex_req, ) def test_handle_action_execution_completion(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request and pre-process the workflow execution. @@ -327,111 +343,124 @@ def test_handle_action_execution_completion(self): wf_ex_db = self.prep_wf_ex(wf_ex_db) # Manually request task execution. - self.run_workflow_step(wf_ex_db, 'task1', 0, ctx={'foo': 'bar'}) + self.run_workflow_step(wf_ex_db, "task1", 0, ctx={"foo": "bar"}) # Check that a new task is executed. - self.assert_task_running('task2', 0) + self.assert_task_running("task2", 0) def test_evaluate_action_execution_delay(self): - base_task_ex_req = {'task_id': 'task1', 'task_name': 'task1', 'route': 0} + base_task_ex_req = {"task_id": "task1", "task_name": "task1", "route": 0} # No task delay. task_ex_req = copy.deepcopy(base_task_ex_req) - ac_ex_req = {'action': 'core.noop', 'input': None} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req) + ac_ex_req = {"action": "core.noop", "input": None} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req + ) self.assertIsNone(actual_delay) # Simple task delay. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - ac_ex_req = {'action': 'core.noop', 'input': None} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req) + task_ex_req["delay"] = 180 + ac_ex_req = {"action": "core.noop", "input": None} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req + ) self.assertEqual(actual_delay, 180) # Task delay for with items task and with no concurrency. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = None - ac_ex_req = {'action': 'core.noop', 'input': None, 'items_id': 0} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = None + ac_ex_req = {"action": "core.noop", "input": None, "items_id": 0} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertEqual(actual_delay, 180) # Task delay for with items task, with concurrency, and evaluate first item. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = 1 - ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 0} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = 1 + ac_ex_req = {"action": "core.noop", "input": None, "item_id": 0} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertEqual(actual_delay, 180) # Task delay for with items task, with concurrency, and evaluate later items. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = 1 - ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 1} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = 1 + ac_ex_req = {"action": "core.noop", "input": None, "item_id": 1} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertIsNone(actual_delay) def test_request_action_execution_render(self): # Manually create ConfigDB - output = 'Testing' - value = { - "config_item_one": output - } + output = "Testing" + value = {"config_item_one": output} config_db = pk_db_models.ConfigDB(pack=PACK_7, values=value) config = pk_db_access.Config.add_or_update(config_db) self.assertEqual(len(config), 3) - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'render_config_context.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "render_config_context.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Pass down appropriate st2 context to the task and action execution(s). - root_st2_ctx = wf_ex_db.context.get('st2', {}) + root_st2_ctx = wf_ex_db.context.get("st2", {}) st2_ctx = { - 'execution_id': wf_ex_db.action_execution, - 'user': root_st2_ctx.get('user'), - 'pack': root_st2_ctx.get('pack') + "execution_id": wf_ex_db.action_execution, + "user": root_st2_ctx.get("user"), + "pack": root_st2_ctx.get("pack"), } # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} + task_ctx = {"foo": "bar"} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'dummy_pack_7.render_config_context', 'input': None} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "dummy_pack_7.render_config_context", "input": None} + ], } workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Check task execution is saved to the database. - task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + task_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(task_ex_dbs), 1) workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Manually request action execution task_ex_db = task_ex_dbs[0] - action_ex_db = workflow_service.request_action_execution(wf_ex_db, task_ex_db, st2_ctx, - task_ex_req['actions'][0]) + action_ex_db = workflow_service.request_action_execution( + wf_ex_db, task_ex_db, st2_ctx, task_ex_req["actions"][0] + ) # Check required attributes. self.assertIsNotNone(str(action_ex_db.id)) self.assertEqual(task_ex_db.workflow_execution, str(wf_ex_db.id)) - expected_parameters = {'value1': output} + expected_parameters = {"value1": output} self.assertEqual(expected_parameters, action_ex_db.parameters) diff --git a/st2common/tests/unit/services/test_workflow_cancellation.py b/st2common/tests/unit/services/test_workflow_cancellation.py index 26455971f0..22694924a3 100644 --- a/st2common/tests/unit/services/test_workflow_cancellation.py +++ b/st2common/tests/unit/services/test_workflow_cancellation.py @@ -22,6 +22,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -35,39 +36,35 @@ TEST_FIXTURES = { - 'workflows': [ - 'sequential.yaml', - 'join.yaml' - ], - 'actions': [ - 'sequential.yaml', - 'join.yaml' - ] + "workflows": ["sequential.yaml", "join.yaml"], + "actions": ["sequential.yaml", "join.yaml"], } -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionCancellationTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionCancellationTest, cls).setUpClass() @@ -77,8 +74,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -86,8 +82,10 @@ def setUpClass(cls): def test_cancellation(self): # Manually create the liveaction and action execution objects without publishing. - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, TEST_FIXTURES['workflows'][0]) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, TEST_FIXTURES["workflows"][0] + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.create_request(lv_ac_db) # Request and pre-process the workflow execution. @@ -98,8 +96,8 @@ def test_cancellation(self): # Manually request task executions. task_route = 0 - self.run_workflow_step(wf_ex_db, 'task1', task_route) - self.assert_task_running('task2', task_route) + self.run_workflow_step(wf_ex_db, "task1", task_route) + self.assert_task_running("task2", task_route) # Cancel the workflow when there is still active task(s). wf_ex_db = wf_svc.request_cancellation(ac_ex_db) @@ -108,8 +106,8 @@ def test_cancellation(self): self.assertEqual(wf_ex_db.status, wf_statuses.CANCELING) # Manually complete the task and ensure workflow is canceled. - self.run_workflow_step(wf_ex_db, 'task2', task_route) - self.assert_task_not_started('task3', task_route) + self.run_workflow_step(wf_ex_db, "task2", task_route) + self.assert_task_not_started("task3", task_route) conductor, wf_ex_db = wf_svc.refresh_conductor(str(wf_ex_db.id)) self.assertEqual(conductor.get_workflow_status(), wf_statuses.CANCELED) self.assertEqual(wf_ex_db.status, wf_statuses.CANCELED) diff --git a/st2common/tests/unit/services/test_workflow_identify_orphans.py b/st2common/tests/unit/services/test_workflow_identify_orphans.py index d45ba1527f..306e22badd 100644 --- a/st2common/tests/unit/services/test_workflow_identify_orphans.py +++ b/st2common/tests/unit/services/test_workflow_identify_orphans.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -48,42 +49,51 @@ LOG = logging.getLogger(__name__) -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class WorkflowServiceIdentifyOrphansTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ ex_db_models.ActionExecutionDB, lv_db_models.LiveActionDB, wf_db_models.WorkflowExecutionDB, - wf_db_models.TaskExecutionDB + wf_db_models.TaskExecutionDB, ] @classmethod @@ -95,8 +105,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -119,8 +128,9 @@ def tearDown(self): def mock_workflow_records(self, completed=False, expired=True, log=True): status = ( - ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else - ac_const.LIVEACTION_STATUS_RUNNING + ac_const.LIVEACTION_STATUS_SUCCEEDED + if completed + else ac_const.LIVEACTION_STATUS_RUNNING ) # Identify start and end timestamp @@ -131,18 +141,24 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): end_timestamp = utc_now_dt if completed else None # Assign metadata. - action_ref = 'orquesta_tests.sequential' - runner = 'orquesta' - user = 'stanley' + action_ref = "orquesta_tests.sequential" + runner = "orquesta" + user = "stanley" # Create the WorkflowExecutionDB record first since the ID needs to be # included in the LiveActionDB and ActionExecutionDB records. - st2_ctx = {'st2': {'action_execution_id': '123', 'action': 'foobar', 'runner': 'orquesta'}} + st2_ctx = { + "st2": { + "action_execution_id": "123", + "action": "foobar", + "runner": "orquesta", + } + } wf_ex_db = wf_db_models.WorkflowExecutionDB( context=st2_ctx, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False) @@ -152,13 +168,10 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): workflow_execution=str(wf_ex_db.id), action=action_ref, action_is_workflow=True, - context={ - 'user': user, - 'workflow_execution': str(wf_ex_db.id) - }, + context={"user": user, "workflow_execution": str(wf_ex_db.id)}, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False) @@ -166,30 +179,20 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): # Create the ActionExecutionDB record. ac_ex_db = ex_db_models.ActionExecutionDB( workflow_execution=str(wf_ex_db.id), - action={ - 'runner_type': runner, - 'ref': action_ref - }, - runner={ - 'name': runner - }, - liveaction={ - 'id': str(lv_ac_db.id) - }, - context={ - 'user': user, - 'workflow_execution': str(wf_ex_db.id) - }, + action={"runner_type": runner, "ref": action_ref}, + runner={"name": runner}, + liveaction={"id": str(lv_ac_db.id)}, + context={"user": user, "workflow_execution": str(wf_ex_db.id)}, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) if log: - ac_ex_db.log = [{'status': 'running', 'timestamp': start_timestamp}] + ac_ex_db.log = [{"status": "running", "timestamp": start_timestamp}] if log and status in ac_const.LIVEACTION_COMPLETED_STATES: - ac_ex_db.log.append({'status': status, 'timestamp': end_timestamp}) + ac_ex_db.log.append({"status": status, "timestamp": end_timestamp}) ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False) @@ -199,14 +202,16 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): return wf_ex_db, lv_ac_db, ac_ex_db - def mock_task_records(self, parent, task_id, task_route=0, - completed=True, expired=False, log=True): + def mock_task_records( + self, parent, task_id, task_route=0, completed=True, expired=False, log=True + ): if not completed and expired: - raise ValueError('Task must be set completed=True if expired=True.') + raise ValueError("Task must be set completed=True if expired=True.") status = ( - ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else - ac_const.LIVEACTION_STATUS_RUNNING + ac_const.LIVEACTION_STATUS_SUCCEEDED + if completed + else ac_const.LIVEACTION_STATUS_RUNNING ) parent_wf_ex_db, parent_ac_ex_db = parent[0], parent[2] @@ -218,9 +223,9 @@ def mock_task_records(self, parent, task_id, task_route=0, end_timestamp = expiry_dt if expired else utc_now_dt # Assign metadata. - action_ref = 'core.local' - runner = 'local-shell-cmd' - user = 'stanley' + action_ref = "core.local" + runner = "local-shell-cmd" + user = "stanley" # Create the TaskExecutionDB record first since the ID needs to be # included in the LiveActionDB and ActionExecutionDB records. @@ -229,7 +234,7 @@ def mock_task_records(self, parent, task_id, task_route=0, task_id=task_id, task_route=0, status=status, - start_timestamp=parent_wf_ex_db.start_timestamp + start_timestamp=parent_wf_ex_db.start_timestamp, ) if status in ac_const.LIVEACTION_COMPLETED_STATES: @@ -239,18 +244,15 @@ def mock_task_records(self, parent, task_id, task_route=0, # Build context for LiveActionDB and ActionExecutionDB. context = { - 'user': user, - 'orquesta': { - 'task_id': tk_ex_db.task_id, - 'task_name': tk_ex_db.task_id, - 'workflow_execution_id': str(parent_wf_ex_db.id), - 'task_execution_id': str(tk_ex_db.id), - 'task_route': tk_ex_db.task_route + "user": user, + "orquesta": { + "task_id": tk_ex_db.task_id, + "task_name": tk_ex_db.task_id, + "workflow_execution_id": str(parent_wf_ex_db.id), + "task_execution_id": str(tk_ex_db.id), + "task_route": tk_ex_db.task_route, }, - 'parent': { - 'user': user, - 'execution_id': str(parent_ac_ex_db.id) - } + "parent": {"user": user, "execution_id": str(parent_ac_ex_db.id)}, } # Create the LiveActionDB record. @@ -262,7 +264,7 @@ def mock_task_records(self, parent, task_id, task_route=0, context=context, status=status, start_timestamp=tk_ex_db.start_timestamp, - end_timestamp=tk_ex_db.end_timestamp + end_timestamp=tk_ex_db.end_timestamp, ) lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False) @@ -271,27 +273,22 @@ def mock_task_records(self, parent, task_id, task_route=0, ac_ex_db = ex_db_models.ActionExecutionDB( workflow_execution=str(parent_wf_ex_db.id), task_execution=str(tk_ex_db.id), - action={ - 'runner_type': runner, - 'ref': action_ref - }, - runner={ - 'name': runner - }, - liveaction={ - 'id': str(lv_ac_db.id) - }, + action={"runner_type": runner, "ref": action_ref}, + runner={"name": runner}, + liveaction={"id": str(lv_ac_db.id)}, context=context, status=status, start_timestamp=tk_ex_db.start_timestamp, - end_timestamp=tk_ex_db.end_timestamp + end_timestamp=tk_ex_db.end_timestamp, ) if log: - ac_ex_db.log = [{'status': 'running', 'timestamp': tk_ex_db.start_timestamp}] + ac_ex_db.log = [ + {"status": "running", "timestamp": tk_ex_db.start_timestamp} + ] if log and status in ac_const.LIVEACTION_COMPLETED_STATES: - ac_ex_db.log.append({'status': status, 'timestamp': tk_ex_db.end_timestamp}) + ac_ex_db.log.append({"status": status, "timestamp": tk_ex_db.end_timestamp}) ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False) @@ -303,18 +300,18 @@ def test_no_orphans(self): # Workflow that is still running with task completed and not expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False) # Workflow that is still running with task running and not expired. wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False) + self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False) # Workflow that is completed and not expired. self.mock_workflow_records(completed=True, expired=False) # Workflow that is completed with task completed and not expired. wf_ex_set_5 = self.mock_workflow_records(completed=True, expired=False) - self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=False) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 0) @@ -339,33 +336,33 @@ def test_identify_orphans_with_no_task_executions(self): def test_identify_orphans_with_task_executions(self): # Workflow that is still running with task completed and expired. wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True) - self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True) + self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True) # Workflow that is still running with task completed and not expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False) # Workflow that is still running with task running and not expired. wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False) + self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False) # Workflow that is still running with multiple tasks and not expired. # One of the task completed passed expiry date but another task is still running. wf_ex_set_4 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_4, 'task1', completed=True, expired=True) - self.mock_task_records(wf_ex_set_4, 'task2', completed=False, expired=False) + self.mock_task_records(wf_ex_set_4, "task1", completed=True, expired=True) + self.mock_task_records(wf_ex_set_4, "task2", completed=False, expired=False) # Workflow that is still running with multiple tasks and not expired. # Both of the tasks are completed with one completed only recently. wf_ex_set_5 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=True) - self.mock_task_records(wf_ex_set_5, 'task2', completed=True, expired=False) + self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=True) + self.mock_task_records(wf_ex_set_5, "task2", completed=True, expired=False) # Workflow that is still running with multiple tasks and not expired. # One of the task completed recently and another task is still running. wf_ex_set_6 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_6, 'task1', completed=True, expired=False) - self.mock_task_records(wf_ex_set_6, 'task2', completed=False, expired=False) + self.mock_task_records(wf_ex_set_6, "task1", completed=True, expired=False) + self.mock_task_records(wf_ex_set_6, "task2", completed=False, expired=False) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 1) @@ -373,8 +370,10 @@ def test_identify_orphans_with_task_executions(self): def test_action_execution_with_missing_log_entries(self): # Workflow that is still running and expired. However the state change logs are missing. - wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True, log=False) - self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True) + wf_ex_set_1 = self.mock_workflow_records( + completed=False, expired=True, log=False + ) + self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 0) @@ -385,7 +384,7 @@ def test_garbage_collection(self): # Workflow that is still running with task completed and expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=True) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=True) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=True) # Ensure these workflows are identified as orphans. orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() diff --git a/st2common/tests/unit/services/test_workflow_rerun.py b/st2common/tests/unit/services/test_workflow_rerun.py index 6808991595..f5ff2bc487 100644 --- a/st2common/tests/unit/services/test_workflow_rerun.py +++ b/st2common/tests/unit/services/test_workflow_rerun.py @@ -24,6 +24,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from local_runner import local_shell_command_runner @@ -42,32 +43,38 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {}) -RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {}) +RUNNER_RESULT_SUCCEEDED = ( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + {"stdout": "foobar"}, + {}, +) @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionRerunTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionRerunTest, cls).setUpClass() @@ -77,18 +84,17 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def prep_wf_ex_for_rerun(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1) # Request the workflow execution. @@ -99,9 +105,12 @@ def prep_wf_ex_for_rerun(self): # Fail workflow execution. self.run_workflow_step( - wf_ex_db, 'task1', 0, + wf_ex_db, + "task1", + 0, expected_ac_ex_db_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_tk_ex_db_status=wf_statuses.FAILED) + expected_tk_ex_db_status=wf_statuses.FAILED, + ) # Check workflow status. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -115,20 +124,22 @@ def prep_wf_ex_for_rerun(self): return wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_request_rerun(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options) wf_ex_db = self.prep_wf_ex(wf_ex_db) @@ -138,7 +149,7 @@ def test_request_rerun(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete task1. - self.run_workflow_step(wf_ex_db, 'task1', 0) + self.run_workflow_step(wf_ex_db, "task1", 0) # Check workflow status and make sure it is still running. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -150,10 +161,10 @@ def test_request_rerun(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) def test_request_rerun_while_original_is_still_running(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1) # Request the workflow execution. @@ -168,16 +179,16 @@ def test_request_rerun_while_original_is_still_running(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -186,24 +197,26 @@ def test_request_rerun_while_original_is_still_running(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_request_rerun_again_while_prev_rerun_is_still_running(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options) wf_ex_db = self.prep_wf_ex(wf_ex_db) @@ -213,7 +226,7 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete task1. - self.run_workflow_step(wf_ex_db, 'task1', 0) + self.run_workflow_step(wf_ex_db, "task1", 0) # Check workflow status and make sure it is still running. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -225,16 +238,16 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db3, ac_ex_db3 = action_service.create_request(lv_ac_db3) # Request workflow execution rerun again. st2_ctx = self.mock_st2_context(ac_ex_db3, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -243,26 +256,28 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): workflow_service.request_rerun, ac_ex_db3, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_missing_workflow_execution_id(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun without workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - 'Unable to rerun workflow execution because ' - 'workflow_execution_id is not provided.' + "Unable to rerun workflow execution because " + "workflow_execution_id is not provided." ) self.assertRaisesRegexp( @@ -271,27 +286,28 @@ def test_request_rerun_with_missing_workflow_execution_id(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_nonexistent_workflow_execution(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = uuid.uuid4().hex[0:24] - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = uuid.uuid4().hex[0:24] + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it does not exist.$' + '^Unable to rerun workflow execution ".*" ' "because it does not exist.$" ) self.assertRaisesRegexp( @@ -300,12 +316,14 @@ def test_request_rerun_with_nonexistent_workflow_execution(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_workflow_execution_not_abended(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() @@ -315,16 +333,16 @@ def test_request_rerun_with_workflow_execution_not_abended(self): wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -333,29 +351,33 @@ def test_request_rerun_with_workflow_execution_not_abended(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_conductor_status_not_abended(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually set workflow conductor state to paused. - wf_ex_db.state['status'] = wf_statuses.PAUSED + wf_ex_db.state["status"] = wf_statuses.PAUSED wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} - expected_error = 'Unable to rerun workflow because it is not in a completed state.' + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} + expected_error = ( + "Unable to rerun workflow because it is not in a completed state." + ) self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, @@ -363,25 +385,29 @@ def test_request_rerun_with_conductor_status_not_abended(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_bad_task_name(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task5354']} - expected_error = '^Unable to rerun workflow because one or more tasks is not found: .*$' + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task5354"]} + expected_error = ( + "^Unable to rerun workflow because one or more tasks is not found: .*$" + ) self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, @@ -389,36 +415,40 @@ def test_request_rerun_with_bad_task_name(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_conductor_status_not_resuming(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'due to an unknown cause.' + '^Unable to rerun workflow execution ".*" ' "due to an unknown cause." ) - with mock.patch.object(conducting.WorkflowConductor, 'get_workflow_status', - mock.MagicMock(return_value=wf_statuses.FAILED)): + with mock.patch.object( + conducting.WorkflowConductor, + "get_workflow_status", + mock.MagicMock(return_value=wf_statuses.FAILED), + ): self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, expected_error, workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) diff --git a/st2common/tests/unit/services/test_workflow_service_retries.py b/st2common/tests/unit/services/test_workflow_service_retries.py index 35fafc1213..baa79c6954 100644 --- a/st2common/tests/unit/services/test_workflow_service_retries.py +++ b/st2common/tests/unit/services/test_workflow_service_retries.py @@ -27,6 +27,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -49,12 +50,14 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @@ -63,11 +66,11 @@ def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, **kwargs): - seq_len = len(wf_ex_db.state['sequence']) + seq_len = len(wf_ex_db.state["sequence"]) if seq_len > 0: - current_task_id = wf_ex_db.state['sequence'][seq_len - 1:][0]['id'] - temp_file_path = TEMP_DIR_PATH + '/' + current_task_id + current_task_id = wf_ex_db.state["sequence"][seq_len - 1 :][0]["id"] + temp_file_path = TEMP_DIR_PATH + "/" + current_task_id if os.path.exists(temp_file_path): os.remove(temp_file_path) @@ -77,31 +80,38 @@ def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, ** @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaServiceRetryTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ wf_db_models.WorkflowExecutionDB, wf_db_models.TaskExecutionDB, - ex_q_db_models.ActionExecutionSchedulingQueueItemDB + ex_q_db_models.ActionExecutionSchedulingQueueItemDB, ] @classmethod @@ -113,30 +123,38 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) @mock.patch.object( - coord_svc.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=[ - coordination.ToozConnectionError('foobar'), - coordination.ToozConnectionError('fubar'), - coord_svc.NoOpLock(name='noop')])) + coord_svc.NoOpDriver, + "get_lock", + mock.MagicMock( + side_effect=[ + coordination.ToozConnectionError("foobar"), + coordination.ToozConnectionError("fubar"), + coord_svc.NoOpLock(name="noop"), + ] + ), + ) def test_recover_from_coordinator_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 and expect acquiring lock returns a few connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) @@ -145,45 +163,60 @@ def test_recover_from_coordinator_connection_error(self): self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - coord_svc.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coord_svc.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) def test_retries_exhausted_from_coordinator_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 but retries exhaused with connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # The connection error should raise if retries are exhaused. self.assertRaises( coordination.ToozConnectionError, wf_svc.handle_action_execution_completion, - tk1_ac_ex_db + tk1_ac_ex_db, ) @mock.patch.object( - wf_svc, 'update_task_state', - mock.MagicMock(side_effect=[ - mongoengine.connection.MongoEngineConnectionError(), - mongoengine.connection.MongoEngineConnectionError(), - None])) + wf_svc, + "update_task_state", + mock.MagicMock( + side_effect=[ + mongoengine.connection.MongoEngineConnectionError(), + mongoengine.connection.MongoEngineConnectionError(), + None, + ] + ), + ) def test_recover_from_database_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 and expect acquiring lock returns a few connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) @@ -192,61 +225,71 @@ def test_recover_from_database_connection_error(self): self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - wf_svc, 'update_task_state', - mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError())) + wf_svc, + "update_task_state", + mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError()), + ) def test_retries_exhausted_from_database_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 but retries exhaused with connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # The connection error should raise if retries are exhaused. self.assertRaises( mongoengine.connection.MongoEngineConnectionError, wf_svc.handle_action_execution_completion, - tk1_ac_ex_db + tk1_ac_ex_db, ) @mock.patch.object( - wf_db_access.WorkflowExecution, 'update', - mock.MagicMock(side_effect=mock_wf_db_update_conflict)) + wf_db_access.WorkflowExecution, + "update", + mock.MagicMock(side_effect=mock_wf_db_update_conflict), + ) def test_recover_from_database_write_conflicts(self): # Create a temporary file which will be used to signal # which task(s) to mock the DB write conflict. - temp_file_path = TEMP_DIR_PATH + '/task4' + temp_file_path = TEMP_DIR_PATH + "/task4" if not os.path.exists(temp_file_path): - with open(temp_file_path, 'w'): + with open(temp_file_path, "w"): pass # Manually create the liveaction and action execution objects without publishing. - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'join.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "join.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Manually request task executions. task_route = 0 - self.run_workflow_step(wf_ex_db, 'task1', task_route) - self.assert_task_running('task2', task_route) - self.assert_task_running('task4', task_route) - self.run_workflow_step(wf_ex_db, 'task2', task_route) - self.assert_task_running('task3', task_route) - self.run_workflow_step(wf_ex_db, 'task4', task_route) - self.assert_task_running('task5', task_route) - self.run_workflow_step(wf_ex_db, 'task3', task_route) - self.assert_task_not_started('task6', task_route) - self.run_workflow_step(wf_ex_db, 'task5', task_route) - self.assert_task_running('task6', task_route) - self.run_workflow_step(wf_ex_db, 'task6', task_route) - self.assert_task_running('task7', task_route) - self.run_workflow_step(wf_ex_db, 'task7', task_route) + self.run_workflow_step(wf_ex_db, "task1", task_route) + self.assert_task_running("task2", task_route) + self.assert_task_running("task4", task_route) + self.run_workflow_step(wf_ex_db, "task2", task_route) + self.assert_task_running("task3", task_route) + self.run_workflow_step(wf_ex_db, "task4", task_route) + self.assert_task_running("task5", task_route) + self.run_workflow_step(wf_ex_db, "task3", task_route) + self.assert_task_not_started("task6", task_route) + self.run_workflow_step(wf_ex_db, "task5", task_route) + self.assert_task_running("task6", task_route) + self.run_workflow_step(wf_ex_db, "task6", task_route) + self.assert_task_running("task7", task_route) + self.run_workflow_step(wf_ex_db, "task7", task_route) self.assert_workflow_completed(str(wf_ex_db.id), status=wf_statuses.SUCCEEDED) # Ensure retry happened. diff --git a/st2common/tests/unit/test_action_alias_utils.py b/st2common/tests/unit/test_action_alias_utils.py index 33b78981a5..daad0fbe1e 100644 --- a/st2common/tests/unit/test_action_alias_utils.py +++ b/st2common/tests/unit/test_action_alias_utils.py @@ -14,281 +14,312 @@ # limitations under the License. from __future__ import absolute_import -from sre_parse import (parse, AT, AT_BEGINNING, AT_BEGINNING_STRING, AT_END, AT_END_STRING) +from sre_parse import ( + parse, + AT, + AT_BEGINNING, + AT_BEGINNING_STRING, + AT_END, + AT_END_STRING, +) from mock import Mock from unittest2 import TestCase from st2common.exceptions.content import ParseException from st2common.models.utils.action_alias_utils import ( - ActionAliasFormatParser, search_regex_tokens, - inject_immutable_parameters + ActionAliasFormatParser, + search_regex_tokens, + inject_immutable_parameters, ) class TestActionAliasParser(TestCase): def test_empty_string(self): - alias_format = '' - param_stream = '' + alias_format = "" + param_stream = "" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() self.assertEqual(extracted_values, {}) def test_arbitrary_pairs(self): # single-word param - alias_format = '' - param_stream = 'a=foobar1' + alias_format = "" + param_stream = "a=foobar1" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar1'}) + self.assertEqual(extracted_values, {"a": "foobar1"}) # multi-word double-quoted param - alias_format = 'foo' + alias_format = "foo" param_stream = 'foo a="foobar2 poonies bar"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'}) + self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"}) # multi-word single-quoted param - alias_format = 'foo' - param_stream = 'foo a=\'foobar2 poonies bar\'' + alias_format = "foo" + param_stream = "foo a='foobar2 poonies bar'" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'}) + self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"}) # JSON param - alias_format = 'foo' + alias_format = "foo" param_stream = 'foo a={"foobar2": "poonies"}' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"foobar2": "poonies"}'}) + self.assertEqual(extracted_values, {"a": '{"foobar2": "poonies"}'}) # Multiple mixed params - alias_format = '' + alias_format = "" param_stream = 'a=foobar1 b="boobar2 3 4" c=\'coobar3 4\' d={"a": "b"}' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar1', - 'b': 'boobar2 3 4', - 'c': 'coobar3 4', - 'd': '{"a": "b"}'}) + self.assertEqual( + extracted_values, + {"a": "foobar1", "b": "boobar2 3 4", "c": "coobar3 4", "d": '{"a": "b"}'}, + ) # Params along with a "normal" alias format - alias_format = '{{ captain }} is my captain' + alias_format = "{{ captain }} is my captain" param_stream = 'Malcolm Reynolds is my captain weirdo="River Tam"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'captain': 'Malcolm Reynolds', - 'weirdo': 'River Tam'}) + self.assertEqual( + extracted_values, {"captain": "Malcolm Reynolds", "weirdo": "River Tam"} + ) def test_simple_parsing(self): - alias_format = 'skip {{a}} more skip {{b}} and skip more.' - param_stream = 'skip a1 more skip b1 and skip more.' + alias_format = "skip {{a}} more skip {{b}} and skip more." + param_stream = "skip a1 more skip b1 and skip more." parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1", "b": "b1"}) def test_end_string_parsing(self): - alias_format = 'skip {{a}} more skip {{b}}' - param_stream = 'skip a1 more skip b1' + alias_format = "skip {{a}} more skip {{b}}" + param_stream = "skip a1 more skip b1" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1", "b": "b1"}) def test_spaced_parsing(self): - alias_format = 'skip {{a}} more skip {{b}} and skip more.' + alias_format = "skip {{a}} more skip {{b}} and skip more." param_stream = 'skip "a1 a2" more skip b1 and skip more.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1 a2", "b": "b1"}) def test_default_values(self): - alias_format = 'acl {{a}} {{b}} {{c}} {{d=1}}' + alias_format = "acl {{a}} {{b}} {{c}} {{d=1}}" param_stream = 'acl "a1 a2" "b1" "c1"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1', - 'c': 'c1', 'd': '1'}) + self.assertEqual( + extracted_values, {"a": "a1 a2", "b": "b1", "c": "c1", "d": "1"} + ) def test_spacing(self): - alias_format = 'acl {{a=test}}' - param_stream = 'acl' + alias_format = "acl {{a=test}}" + param_stream = "acl" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'test'}) + self.assertEqual(extracted_values, {"a": "test"}) def test_json_parsing(self): - alias_format = 'skip {{a}} more skip.' + alias_format = "skip {{a}} more skip." param_stream = 'skip {"a": "b", "c": "d"} more skip.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}'}) + self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}'}) def test_mixed_parsing(self): - alias_format = 'skip {{a}} more skip {{b}}.' + alias_format = "skip {{a}} more skip {{b}}." param_stream = 'skip {"a": "b", "c": "d"} more skip x.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}', - 'b': 'x'}) + self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}', "b": "x"}) def test_param_spaces(self): - alias_format = 's {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}' - param_stream = 's one more two more three more' + alias_format = "s {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}" + param_stream = "s one more two more three more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'one', 'b': 'two', - 'c': 'three', 'd': '99'}) + self.assertEqual( + extracted_values, {"a": "one", "b": "two", "c": "three", "d": "99"} + ) def test_enclosed_defaults(self): - alias_format = 'skip {{ a = value }} more' - param_stream = 'skip one more' + alias_format = "skip {{ a = value }} more" + param_stream = "skip one more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'one'}) + self.assertEqual(extracted_values, {"a": "one"}) - alias_format = 'skip {{ a = value }} more' - param_stream = 'skip more' + alias_format = "skip {{ a = value }} more" + param_stream = "skip more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'value'}) + self.assertEqual(extracted_values, {"a": "value"}) def test_template_defaults(self): - alias_format = 'two by two hands of {{ color = {{ colors.default_color }} }}' - param_stream = 'two by two hands of' + alias_format = "two by two hands of {{ color = {{ colors.default_color }} }}" + param_stream = "two by two hands of" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'color': '{{ colors.default_color }}'}) + self.assertEqual(extracted_values, {"color": "{{ colors.default_color }}"}) def test_key_value_combinations(self): # one-word value, single extra pair - alias_format = 'testing {{ a }}' - param_stream = 'testing value b=value2' + alias_format = "testing {{ a }}" + param_stream = "testing value b=value2" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'value', - 'b': 'value2'}) + self.assertEqual(extracted_values, {"a": "value", "b": "value2"}) # default value, single extra pair with quotes - alias_format = 'testing {{ a=new }}' + alias_format = "testing {{ a=new }}" param_stream = 'testing b="another value"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'new', - 'b': 'another value'}) + self.assertEqual(extracted_values, {"a": "new", "b": "another value"}) # multiple values and multiple extra pairs - alias_format = 'testing {{ b=abc }} {{ c=xyz }}' + alias_format = "testing {{ b=abc }} {{ c=xyz }}" param_stream = 'testing newvalue d={"1": "2"} e="long value"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'b': 'newvalue', - 'c': 'xyz', - 'd': '{"1": "2"}', - 'e': 'long value'}) + self.assertEqual( + extracted_values, + {"b": "newvalue", "c": "xyz", "d": '{"1": "2"}', "e": "long value"}, + ) def test_stream_is_none_with_all_default_values(self): - alias_format = 'skip {{d=test1}} more skip {{e=test1}}.' - param_stream = 'skip more skip' + alias_format = "skip {{d=test1}} more skip {{e=test1}}." + param_stream = "skip more skip" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'d': 'test1', 'e': 'test1'}) + self.assertEqual(extracted_values, {"d": "test1", "e": "test1"}) def test_stream_is_not_none_some_default_values(self): - alias_format = 'skip {{d=test}} more skip {{e=test}}' - param_stream = 'skip ponies more skip' + alias_format = "skip {{d=test}} more skip {{e=test}}" + param_stream = "skip ponies more skip" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'d': 'ponies', 'e': 'test'}) + self.assertEqual(extracted_values, {"d": "ponies", "e": "test"}) def test_stream_is_none_no_default_values(self): - alias_format = 'skip {{d}} more skip {{e}}.' + alias_format = "skip {{d}} more skip {{e}}." param_stream = None parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."' - self.assertRaisesRegexp(ParseException, expected_msg, - parser.get_extracted_param_value) + expected_msg = ( + 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."' + ) + self.assertRaisesRegexp( + ParseException, expected_msg, parser.get_extracted_param_value + ) def test_all_the_things(self): # this is the most insane example I could come up with - alias_format = "{{ p0='http' }} g {{ p1=p }} a " + \ - "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}" - param_stream = "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}" + alias_format = ( + "{{ p0='http' }} g {{ p1=p }} a " + + "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}" + ) + param_stream = ( + "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}" + ) parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'p0': 'http', 'p1': 'p', - 'url': 'http://google.com', - 'p2': '{{ execution.id }}', - 'p3': '{{ e.i }}', - 'p4': 'testing', 'p5': "{'a':'c'}"}) + self.assertEqual( + extracted_values, + { + "p0": "http", + "p1": "p", + "url": "http://google.com", + "p2": "{{ execution.id }}", + "p3": "{{ e.i }}", + "p4": "testing", + "p5": "{'a':'c'}", + }, + ) def test_command_doesnt_match_format_string(self): - alias_format = 'foo bar ponies' - param_stream = 'foo lulz ponies' + alias_format = "foo bar ponies" + param_stream = "foo lulz ponies" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"' - self.assertRaisesRegexp(ParseException, expected_msg, - parser.get_extracted_param_value) + expected_msg = ( + 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"' + ) + self.assertRaisesRegexp( + ParseException, expected_msg, parser.get_extracted_param_value + ) def test_ending_parameters_matching(self): - alias_format = 'foo bar' - param_stream = 'foo bar pony1=foo pony2=bar' + alias_format = "foo bar" + param_stream = "foo bar pony1=foo pony2=bar" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'pony1': 'foo', 'pony2': 'bar'}) + self.assertEqual(extracted_values, {"pony1": "foo", "pony2": "bar"}) def test_regex_beginning_anchors(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)' - param_stream = 'foo ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)" + param_stream = "foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_beginning_anchors_dont_match(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)' - param_stream = 'bar foo ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)" + param_stream = "bar foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "bar foo ASDF-1234" doesn't match format string '''\ - r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)"''' + expected_msg = ( + r"""Command "bar foo ASDF-1234" doesn't match format string """ + r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() self.assertEqual(e.msg, expected_msg) def test_regex_ending_anchors(self): - alias_format = r'foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'foo ASDF-1234' + alias_format = r"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_ending_anchors_dont_match(self): - alias_format = r'foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'foo ASDF-1234 bar' + alias_format = r"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "foo ASDF-1234 bar" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "foo ASDF-1234 bar" doesn't match format string '''\ - r'''"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + expected_msg = ( + r"""Command "foo ASDF-1234 bar" doesn't match format string """ + r'''"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() self.assertEqual(e.msg, expected_msg) def test_regex_beginning_and_ending_anchors(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+) bar\s*$' - param_stream = 'foo ASDF-1234 bar' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+) bar\s*$" + param_stream = "foo ASDF-1234 bar" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_beginning_and_ending_anchors_dont_match(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'bar ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "bar ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "bar ASDF-1234" doesn't match format string '''\ - r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + expected_msg = ( + r"""Command "bar ASDF-1234" doesn't match format string """ + r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() @@ -332,8 +363,8 @@ def test_immutable_parameters_are_injected(self): exec_params = [{"param1": "value1", "param2": "value2"}] inject_immutable_parameters(action_alias_db, exec_params, {}) self.assertEqual( - exec_params, - [{"param1": "value1", "param2": "value2", "env": "dev"}]) + exec_params, [{"param1": "value1", "param2": "value2", "env": "dev"}] + ) def test_immutable_parameters_with_jinja(self): action_alias_db = Mock() @@ -341,8 +372,8 @@ def test_immutable_parameters_with_jinja(self): exec_params = [{"param1": "value1", "param2": "value2"}] inject_immutable_parameters(action_alias_db, exec_params, {}) self.assertEqual( - exec_params, - [{"param1": "value1", "param2": "value2", "env": "dev1"}]) + exec_params, [{"param1": "value1", "param2": "value2", "env": "dev1"}] + ) def test_override_raises_error(self): action_alias_db = Mock() diff --git a/st2common/tests/unit/test_action_api_validator.py b/st2common/tests/unit/test_action_api_validator.py index 1cf16d3f14..5be1ca13ba 100644 --- a/st2common/tests/unit/test_action_api_validator.py +++ b/st2common/tests/unit/test_action_api_validator.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -29,66 +30,83 @@ from st2tests import DbTestCase from st2tests.fixtures.packs import executions as fixture -__all__ = [ - 'TestActionAPIValidator' -] +__all__ = ["TestActionAPIValidator"] class TestActionAPIValidator(DbTestCase): - @classmethod def setUpClass(cls): super(TestActionAPIValidator, cls).setUpClass() runners_registrar.register_runners() - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_runner_type_happy_case(self): - action_api_dict = fixture.ARTIFACTS['actions']['local'] + action_api_dict = fixture.ARTIFACTS["actions"]["local"] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) except: - self.fail('Exception validating action: %s' % json.dumps(action_api_dict)) + self.fail("Exception validating action: %s" % json.dumps(action_api_dict)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_runner_type_invalid_runner(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-invalid-runner'] + action_api_dict = fixture.ARTIFACTS["actions"]["action-with-invalid-runner"] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_override_immutable_runner_param(self): - action_api_dict = fixture.ARTIFACTS['actions']['remote-override-runner-immutable'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "remote-override-runner-immutable" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('Cannot override in action.', six.text_type(e)) + self.assertIn("Cannot override in action.", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_immutable(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-param-no-default'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-immutable-param-no-default" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('requires a default value.', six.text_type(e)) + self.assertIn("requires a default value.", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_immutable_no_default(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-runner-param-no-default'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-immutable-runner-param-no-default" + ] action_api = ActionAPI(**action_api_dict) # Runner param sudo is decalred immutable in action but no defualt value @@ -97,30 +115,44 @@ def test_validate_action_param_immutable_no_default(self): action_validator.validate_action(action_api) except ValueValidationException as e: print(e) - self.fail('Action validation should have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have passed. %s" % json.dumps(action_api_dict) + ) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_position_values_unique(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-unique-positions'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-with-non-unique-positions" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should have failed ' + - 'because position values are not unique.' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have failed " + + "because position values are not unique." + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('have same position', six.text_type(e)) + self.assertIn("have same position", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_position_values_contiguous(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-contiguous-positions'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-with-non-contiguous-positions" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should have failed ' + - 'because position values are not contiguous.' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have failed " + + "because position values are not contiguous." + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('are not contiguous', six.text_type(e)) + self.assertIn("are not contiguous", six.text_type(e)) diff --git a/st2common/tests/unit/test_action_db_utils.py b/st2common/tests/unit/test_action_db_utils.py index ba2dcef018..f7a114b85b 100644 --- a/st2common/tests/unit/test_action_db_utils.py +++ b/st2common/tests/unit/test_action_db_utils.py @@ -35,7 +35,7 @@ from st2tests.base import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionDBUtilsTestCase(DbTestCase): runnertype_db = None action_db = None @@ -48,26 +48,39 @@ def setUpClass(cls): def test_get_runnertype_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_id, - 'somedummyrunnerid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_runnertype_by_id, + "somedummyrunnerid", + ) # By name. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_name, - 'somedummyrunnername') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_runnertype_by_name, + "somedummyrunnername", + ) def test_get_runnertype_existing(self): # Lookup by id and verify name equals. - runner = action_db_utils.get_runnertype_by_id(ActionDBUtilsTestCase.runnertype_db.id) + runner = action_db_utils.get_runnertype_by_id( + ActionDBUtilsTestCase.runnertype_db.id + ) self.assertEqual(runner.name, ActionDBUtilsTestCase.runnertype_db.name) # Lookup by name and verify id equals. - runner = action_db_utils.get_runnertype_by_name(ActionDBUtilsTestCase.runnertype_db.name) + runner = action_db_utils.get_runnertype_by_name( + ActionDBUtilsTestCase.runnertype_db.name + ) self.assertEqual(runner.id, ActionDBUtilsTestCase.runnertype_db.id) def test_get_action_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_action_by_id, - 'somedummyactionid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_action_by_id, + "somedummyactionid", + ) # By ref. - action = action_db_utils.get_action_by_ref('packaintexist.somedummyactionname') + action = action_db_utils.get_action_by_ref("packaintexist.somedummyactionname") self.assertIsNone(action) def test_get_action_existing(self): @@ -77,50 +90,57 @@ def test_get_action_existing(self): # Lookup by reference as string. action_ref = ResourceReference.to_string_reference( pack=ActionDBUtilsTestCase.action_db.pack, - name=ActionDBUtilsTestCase.action_db.name) + name=ActionDBUtilsTestCase.action_db.name, + ) action = action_db_utils.get_action_by_ref(action_ref) self.assertEqual(action.id, ActionDBUtilsTestCase.action_db.id) def test_get_actionexec_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_liveaction_by_id, - 'somedummyactionexecid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_liveaction_by_id, + "somedummyactionexecid", + ) def test_get_actionexec_existing(self): - liveaction = action_db_utils.get_liveaction_by_id(ActionDBUtilsTestCase.liveaction_db.id) + liveaction = action_db_utils.get_liveaction_by_id( + ActionDBUtilsTestCase.liveaction_db.id + ) self.assertEqual(liveaction, ActionDBUtilsTestCase.liveaction_db) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_with_incorrect_output_schema(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params runner = mock.MagicMock() - runner.output_schema = { - "notaparam": { - "type": "boolean" - } - } + runner.output_schema = {"notaparam": {"type": "boolean"}} liveaction_db.runner = runner liveaction_db = LiveAction.add_or_update(liveaction_db) origliveaction_db = copy.copy(liveaction_db) now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) @@ -128,18 +148,19 @@ def test_update_liveaction_with_incorrect_output_schema(self): self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_status(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -147,24 +168,31 @@ def test_update_liveaction_status(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) # Update status, result, context, and end timestamp. now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) @@ -172,18 +200,19 @@ def test_update_liveaction_status(self): self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_canceled_liveaction(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -191,21 +220,25 @@ def test_update_canceled_liveaction(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) # Cancel liveaction. now = get_datetime_utc_now() - status = 'canceled' + status = "canceled" newliveaction_db = action_db_utils.update_liveaction_status( - status=status, end_timestamp=now, liveaction_id=liveaction_db.id) + status=status, end_timestamp=now, liveaction_id=liveaction_db.id + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) self.assertEqual(newliveaction_db.end_timestamp, now) @@ -213,31 +246,36 @@ def test_update_canceled_liveaction(self): # Since liveaction has already been canceled, check that anymore update of # status, result, context, and end timestamp are not processed. now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'canceled') + self.assertEqual(newliveaction_db.status, "canceled") self.assertNotEqual(newliveaction_db.result, result) self.assertNotEqual(newliveaction_db.context, context) self.assertNotEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_result_with_dotted_key(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -245,66 +283,79 @@ def test_update_liveaction_result_with_dotted_key(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) now = get_datetime_utc_now() - status = 'succeeded' - result = {'a': 1, 'b': True, 'a.b.c': 'abc'} - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = {"a": 1, "b": True, "a.b.c": "abc"} + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) - self.assertIn('a.b.c', list(result.keys())) + self.assertIn("a.b.c", list(result.keys())) self.assertDictEqual(newliveaction_db.result, result) self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_LiveAction_status_invalid(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) # Update by id. - self.assertRaises(ValueError, action_db_utils.update_liveaction_status, - status='mea culpa', liveaction_id=liveaction_db.id) + self.assertRaises( + ValueError, + action_db_utils.update_liveaction_status, + status="mea culpa", + liveaction_id=liveaction_db.id, + ) # Verify that state is not published. self.assertFalse(LiveActionPublisher.publish_state.called) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_same_liveaction_status(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'requested' + liveaction_db.status = "requested" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -312,141 +363,150 @@ def test_update_same_liveaction_status(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='requested', liveaction_id=liveaction_db.id) + status="requested", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'requested') + self.assertEqual(newliveaction_db.status, "requested") # Verify that state is not published. self.assertFalse(LiveActionPublisher.publish_state.called) def test_get_args(self): - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'runnerint': 555 - } - pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, ['20', '', 'foo', '', '', '', '', ''], - 'Positional args not parsed correctly.') - self.assertNotIn('actionint', named_args) - self.assertNotIn('actionstr', named_args) - self.assertEqual(named_args.get('runnerint'), 555) + params = {"actionstr": "foo", "actionint": 20, "runnerint": 555} + pos_args, named_args = action_db_utils.get_args( + params, ActionDBUtilsTestCase.action_db + ) + self.assertListEqual( + pos_args, + ["20", "", "foo", "", "", "", "", ""], + "Positional args not parsed correctly.", + ) + self.assertNotIn("actionint", named_args) + self.assertNotIn("actionstr", named_args) + self.assertEqual(named_args.get("runnerint"), 555) # Test serialization for different positional argument types and values # Test all the values provided params = { - 'actionint': 1, - 'actionfloat': 1.5, - 'actionstr': 'string value', - 'actionbool': True, - 'actionarray': ['foo', 'bar', 'baz', 'qux'], - 'actionlist': ['foo', 'bar', 'baz'], - 'actionobject': {'a': 1, 'b': '2'}, + "actionint": 1, + "actionfloat": 1.5, + "actionstr": "string value", + "actionbool": True, + "actionarray": ["foo", "bar", "baz", "qux"], + "actionlist": ["foo", "bar", "baz"], + "actionobject": {"a": 1, "b": "2"}, } expected_pos_args = [ - '1', - '1.5', - 'string value', - '1', - 'foo,bar,baz,qux', - 'foo,bar,baz', + "1", + "1.5", + "string value", + "1", + "foo,bar,baz,qux", + "foo,bar,baz", '{"a": 1, "b": "2"}', - '' + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) params = { - 'actionint': 1, - 'actionfloat': 1.5, - 'actionstr': 'string value', - 'actionbool': False, - 'actionarray': [], - 'actionlist': [], - 'actionobject': {'a': 1, 'b': '2'}, + "actionint": 1, + "actionfloat": 1.5, + "actionstr": "string value", + "actionbool": False, + "actionarray": [], + "actionlist": [], + "actionobject": {"a": 1, "b": "2"}, } expected_pos_args = [ - '1', - '1.5', - 'string value', - '0', - '', - '', + "1", + "1.5", + "string value", + "0", + "", + "", '{"a": 1, "b": "2"}', - '' + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) # Test none values params = { - 'actionint': None, - 'actionfloat': None, - 'actionstr': None, - 'actionbool': None, - 'actionarray': None, - 'actionlist': None, - 'actionobject': None, + "actionint": None, + "actionfloat": None, + "actionstr": None, + "actionbool": None, + "actionarray": None, + "actionlist": None, + "actionobject": None, } - expected_pos_args = [ - '', - '', - '', - '', - '', - '', - '', - '' - ] + expected_pos_args = ["", "", "", "", "", "", "", ""] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) # Test unicode values params = { - 'actionstr': 'bar č š hello đ č p ž Ž a 💩😁', - 'actionint': 20, - 'runnerint': 555 + "actionstr": "bar č š hello đ č p ž Ž a 💩😁", + "actionint": 20, + "runnerint": 555, } expected_pos_args = [ - '20', - '', - u'bar č š hello đ č p ž Ž a 💩😁', - '', - '', - '', - '', - '' + "20", + "", + "bar č š hello đ č p ž Ž a 💩😁", + "", + "", + "", + "", + "", ] - pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, 'Positional args not parsed correctly.') + pos_args, named_args = action_db_utils.get_args( + params, ActionDBUtilsTestCase.action_db + ) + self.assertListEqual( + pos_args, expected_pos_args, "Positional args not parsed correctly." + ) # Test arrays and lists with values of different types params = { - 'actionarray': [None, False, 1, 4.2e1, '1e3', 'foo'], - 'actionlist': [None, False, 1, 73e-2, '1e2', 'bar'] + "actionarray": [None, False, 1, 4.2e1, "1e3", "foo"], + "actionlist": [None, False, 1, 73e-2, "1e2", "bar"], } expected_pos_args = [ - '', - '', - '', - '', - 'None,False,1,42.0,1e3,foo', - 'None,False,1,0.73,1e2,bar', - '', - '' + "", + "", + "", + "", + "None,False,1,42.0,1e3,foo", + "None,False,1,0.73,1e2,bar", + "", + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) - self.assertNotIn('actionint', named_args) - self.assertNotIn('actionstr', named_args) - self.assertEqual(named_args.get('runnerint'), 555) + self.assertNotIn("actionint", named_args) + self.assertNotIn("actionstr", named_args) + self.assertEqual(named_args.get("runnerint"), 555) @classmethod def _setup_test_models(cls): @@ -456,63 +516,65 @@ def _setup_test_models(cls): @classmethod def setup_runner(cls): test_runner = { - 'name': 'test-runner', - 'description': 'A test runner.', - 'enabled': True, - 'runner_parameters': { - 'runnerstr': { - 'description': 'Foo str param.', - 'type': 'string', - 'default': 'defaultfoo' + "name": "test-runner", + "description": "A test runner.", + "enabled": True, + "runner_parameters": { + "runnerstr": { + "description": "Foo str param.", + "type": "string", + "default": "defaultfoo", }, - 'runnerint': { - 'description': 'Foo int param.', - 'type': 'number' + "runnerint": {"description": "Foo int param.", "type": "number"}, + "runnerdummy": { + "description": "Dummy param.", + "type": "string", + "default": "runnerdummy", }, - 'runnerdummy': { - 'description': 'Dummy param.', - 'type': 'string', - 'default': 'runnerdummy' - } }, - 'runner_module': 'tests.test_runner' + "runner_module": "tests.test_runner", } runnertype_api = RunnerTypeAPI(**test_runner) ActionDBUtilsTestCase.runnertype_db = RunnerType.add_or_update( - RunnerTypeAPI.to_model(runnertype_api)) + RunnerTypeAPI.to_model(runnertype_api) + ) @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def setup_action_models(cls): - pack = 'wolfpack' - name = 'action-1' + pack = "wolfpack" + name = "action-1" parameters = { - 'actionint': {'type': 'number', 'default': 10, 'position': 0}, - 'actionfloat': {'type': 'float', 'required': False, 'position': 1}, - 'actionstr': {'type': 'string', 'required': True, 'position': 2}, - 'actionbool': {'type': 'boolean', 'required': False, 'position': 3}, - 'actionarray': {'type': 'array', 'required': False, 'position': 4}, - 'actionlist': {'type': 'list', 'required': False, 'position': 5}, - 'actionobject': {'type': 'object', 'required': False, 'position': 6}, - 'actionnull': {'type': 'null', 'required': False, 'position': 7}, - - 'runnerdummy': {'type': 'string', 'default': 'actiondummy'} + "actionint": {"type": "number", "default": 10, "position": 0}, + "actionfloat": {"type": "float", "required": False, "position": 1}, + "actionstr": {"type": "string", "required": True, "position": 2}, + "actionbool": {"type": "boolean", "required": False, "position": 3}, + "actionarray": {"type": "array", "required": False, "position": 4}, + "actionlist": {"type": "list", "required": False, "position": 5}, + "actionobject": {"type": "object", "required": False, "position": 6}, + "actionnull": {"type": "null", "required": False, "position": 7}, + "runnerdummy": {"type": "string", "default": "actiondummy"}, } - action_db = ActionDB(pack=pack, name=name, description='awesomeness', - enabled=True, - ref=ResourceReference(name=name, pack=pack).ref, - entry_point='', runner_type={'name': 'test-runner'}, - parameters=parameters) + action_db = ActionDB( + pack=pack, + name=name, + description="awesomeness", + enabled=True, + ref=ResourceReference(name=name, pack=pack).ref, + entry_point="", + runner_type={"name": "test-runner"}, + parameters=parameters, + ) ActionDBUtilsTestCase.action_db = Action.add_or_update(action_db) liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ActionDBUtilsTestCase.action_db.ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params ActionDBUtilsTestCase.liveaction_db = LiveAction.add_or_update(liveaction_db) diff --git a/st2common/tests/unit/test_action_param_utils.py b/st2common/tests/unit/test_action_param_utils.py index 08a6654f21..5eecf018dc 100644 --- a/st2common/tests/unit/test_action_param_utils.py +++ b/st2common/tests/unit/test_action_param_utils.py @@ -28,23 +28,16 @@ TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action3.yaml' - ], - 'runners': [ - 'testrunner1.yaml', - 'testrunner3.yaml' - ] + "actions": ["action1.yaml", "action3.yaml"], + "runners": ["testrunner1.yaml", "testrunner3.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) class ActionParamsUtilsTest(DbTestCase): - @classmethod def setUpClass(cls): super(ActionParamsUtilsTest, cls).setUpClass() @@ -54,86 +47,105 @@ def setUpClass(cls): cls.runnertype_dbs = {} cls.action_dbs = {} - for _, fixture in six.iteritems(FIXTURES['runners']): + for _, fixture in six.iteritems(FIXTURES["runners"]): instance = RunnerTypeAPI(**fixture) runnertype_db = RunnerType.add_or_update(RunnerTypeAPI.to_model(instance)) cls.runnertype_dbs[runnertype_db.name] = runnertype_db - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) action_db = Action.add_or_update(ActionAPI.to_model(instance)) cls.action_dbs[action_db.name] = action_db def test_merge_action_runner_params_meta(self): required, optional, immutable = action_param_utils.get_params_view( - action_db=self.action_dbs['action-1'], - runner_db=self.runnertype_dbs['test-runner-1']) + action_db=self.action_dbs["action-1"], + runner_db=self.runnertype_dbs["test-runner-1"], + ) merged = {} merged.update(required) merged.update(optional) merged.update(immutable) consolidated = action_param_utils.get_params_view( - action_db=self.action_dbs['action-1'], - runner_db=self.runnertype_dbs['test-runner-1'], - merged_only=True) + action_db=self.action_dbs["action-1"], + runner_db=self.runnertype_dbs["test-runner-1"], + merged_only=True, + ) # Validate that merged_only view works. self.assertEqual(merged, consolidated) # Validate required params. - self.assertEqual(len(required), 1, 'Required should contain only one param.') - self.assertIn('actionstr', required, 'actionstr param is a required param.') - self.assertNotIn('actionstr', optional, 'actionstr should not be in optional parameters') - self.assertNotIn('actionstr', immutable, 'actionstr should not be in immutable parameters') - self.assertIn('actionstr', merged, 'actionstr should be in action parameters') + self.assertEqual(len(required), 1, "Required should contain only one param.") + self.assertIn("actionstr", required, "actionstr param is a required param.") + self.assertNotIn( + "actionstr", optional, "actionstr should not be in optional parameters" + ) + self.assertNotIn( + "actionstr", immutable, "actionstr should not be in immutable parameters" + ) + self.assertIn("actionstr", merged, "actionstr should be in action parameters") # Validate immutable params. - self.assertIn('runnerimmutable', immutable, 'runnerimmutable should be in immutable.') - self.assertIn('actionimmutable', immutable, 'actionimmutable should be in immutable.') + self.assertIn( + "runnerimmutable", immutable, "runnerimmutable should be in immutable." + ) + self.assertIn( + "actionimmutable", immutable, "actionimmutable should be in immutable." + ) # Validate optional params. for opt in optional: - self.assertIn(opt, merged, 'Optional %s should be in action parameters' % opt) - self.assertNotIn(opt, required, 'Optional %s should not be in required params' % opt) - self.assertNotIn(opt, immutable, 'Optional %s should not be in immutable params' % opt) + self.assertIn( + opt, merged, "Optional %s should be in action parameters" % opt + ) + self.assertNotIn( + opt, required, "Optional %s should not be in required params" % opt + ) + self.assertNotIn( + opt, immutable, "Optional %s should not be in immutable params" % opt + ) def test_merge_param_meta_values(self): runner_meta = copy.deepcopy( - self.runnertype_dbs['test-runner-1'].runner_parameters['runnerdummy']) - action_meta = copy.deepcopy(self.action_dbs['action-1'].parameters['runnerdummy']) - merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta, - runner_meta=runner_meta) + self.runnertype_dbs["test-runner-1"].runner_parameters["runnerdummy"] + ) + action_meta = copy.deepcopy( + self.action_dbs["action-1"].parameters["runnerdummy"] + ) + merged_meta = action_param_utils._merge_param_meta_values( + action_meta=action_meta, runner_meta=runner_meta + ) # Description is in runner meta but not in action meta. - self.assertEqual(merged_meta['description'], runner_meta['description']) + self.assertEqual(merged_meta["description"], runner_meta["description"]) # Default value is overridden in action. - self.assertEqual(merged_meta['default'], action_meta['default']) + self.assertEqual(merged_meta["default"], action_meta["default"]) # Immutability is set in action. - self.assertEqual(merged_meta['immutable'], action_meta['immutable']) + self.assertEqual(merged_meta["immutable"], action_meta["immutable"]) def test_merge_param_meta_require_override(self): - action_meta = { - 'required': False - } - runner_meta = { - 'required': True - } - merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta, - runner_meta=runner_meta) + action_meta = {"required": False} + runner_meta = {"required": True} + merged_meta = action_param_utils._merge_param_meta_values( + action_meta=action_meta, runner_meta=runner_meta + ) - self.assertEqual(merged_meta['required'], action_meta['required']) + self.assertEqual(merged_meta["required"], action_meta["required"]) def test_validate_action_inputs(self): requires, unexpected = action_param_utils.validate_action_parameters( - self.action_dbs['action-1'].ref, {'foo': 'bar'}) + self.action_dbs["action-1"].ref, {"foo": "bar"} + ) - self.assertListEqual(requires, ['actionstr']) - self.assertListEqual(unexpected, ['foo']) + self.assertListEqual(requires, ["actionstr"]) + self.assertListEqual(unexpected, ["foo"]) def test_validate_overridden_action_inputs(self): requires, unexpected = action_param_utils.validate_action_parameters( - self.action_dbs['action-3'].ref, {'k1': 'foo'}) + self.action_dbs["action-3"].ref, {"k1": "foo"} + ) self.assertListEqual(requires, []) self.assertListEqual(unexpected, []) diff --git a/st2common/tests/unit/test_action_system_models.py b/st2common/tests/unit/test_action_system_models.py index 8098759b56..c8812acf38 100644 --- a/st2common/tests/unit/test_action_system_models.py +++ b/st2common/tests/unit/test_action_system_models.py @@ -19,24 +19,30 @@ from st2common.models.system.action import RemoteAction from st2common.models.system.action import RemoteScriptAction -__all__ = [ - 'RemoteActionTestCase', - 'RemoteScriptActionTestCase' -] +__all__ = ["RemoteActionTestCase", "RemoteScriptActionTestCase"] class RemoteActionTestCase(unittest2.TestCase): def test_instantiation(self): - action = RemoteAction(name='name', action_exec_id='aeid', command='ls -la', - env_vars={'a': 1}, on_behalf_user='onbehalf', user='user', - hosts=['127.0.0.1'], parallel=False, sudo=True, timeout=10) - self.assertEqual(action.name, 'name') - self.assertEqual(action.action_exec_id, 'aeid') - self.assertEqual(action.command, 'ls -la') - self.assertEqual(action.env_vars, {'a': 1}) - self.assertEqual(action.on_behalf_user, 'onbehalf') - self.assertEqual(action.user, 'user') - self.assertEqual(action.hosts, ['127.0.0.1']) + action = RemoteAction( + name="name", + action_exec_id="aeid", + command="ls -la", + env_vars={"a": 1}, + on_behalf_user="onbehalf", + user="user", + hosts=["127.0.0.1"], + parallel=False, + sudo=True, + timeout=10, + ) + self.assertEqual(action.name, "name") + self.assertEqual(action.action_exec_id, "aeid") + self.assertEqual(action.command, "ls -la") + self.assertEqual(action.env_vars, {"a": 1}) + self.assertEqual(action.on_behalf_user, "onbehalf") + self.assertEqual(action.user, "user") + self.assertEqual(action.hosts, ["127.0.0.1"]) self.assertEqual(action.parallel, False) self.assertEqual(action.sudo, True) self.assertEqual(action.timeout, 10) @@ -44,26 +50,35 @@ def test_instantiation(self): class RemoteScriptActionTestCase(unittest2.TestCase): def test_instantiation(self): - action = RemoteScriptAction(name='name', action_exec_id='aeid', - script_local_path_abs='/tmp/sc/ma_script.sh', - script_local_libs_path_abs='/tmp/sc/libs', named_args=None, - positional_args=None, env_vars={'a': 1}, - on_behalf_user='onbehalf', user='user', - remote_dir='/home/mauser', hosts=['127.0.0.1'], - parallel=False, sudo=True, timeout=10) - self.assertEqual(action.name, 'name') - self.assertEqual(action.action_exec_id, 'aeid') - self.assertEqual(action.script_local_libs_path_abs, '/tmp/sc/libs') - self.assertEqual(action.env_vars, {'a': 1}) - self.assertEqual(action.on_behalf_user, 'onbehalf') - self.assertEqual(action.user, 'user') - self.assertEqual(action.remote_dir, '/home/mauser') - self.assertEqual(action.hosts, ['127.0.0.1']) + action = RemoteScriptAction( + name="name", + action_exec_id="aeid", + script_local_path_abs="/tmp/sc/ma_script.sh", + script_local_libs_path_abs="/tmp/sc/libs", + named_args=None, + positional_args=None, + env_vars={"a": 1}, + on_behalf_user="onbehalf", + user="user", + remote_dir="/home/mauser", + hosts=["127.0.0.1"], + parallel=False, + sudo=True, + timeout=10, + ) + self.assertEqual(action.name, "name") + self.assertEqual(action.action_exec_id, "aeid") + self.assertEqual(action.script_local_libs_path_abs, "/tmp/sc/libs") + self.assertEqual(action.env_vars, {"a": 1}) + self.assertEqual(action.on_behalf_user, "onbehalf") + self.assertEqual(action.user, "user") + self.assertEqual(action.remote_dir, "/home/mauser") + self.assertEqual(action.hosts, ["127.0.0.1"]) self.assertEqual(action.parallel, False) self.assertEqual(action.sudo, True) self.assertEqual(action.timeout, 10) - self.assertEqual(action.script_local_dir, '/tmp/sc') - self.assertEqual(action.script_name, 'ma_script.sh') - self.assertEqual(action.remote_script, '/home/mauser/ma_script.sh') - self.assertEqual(action.command, '/home/mauser/ma_script.sh') + self.assertEqual(action.script_local_dir, "/tmp/sc") + self.assertEqual(action.script_name, "ma_script.sh") + self.assertEqual(action.remote_script, "/home/mauser/ma_script.sh") + self.assertEqual(action.command, "/home/mauser/ma_script.sh") diff --git a/st2common/tests/unit/test_actionchain_schema.py b/st2common/tests/unit/test_actionchain_schema.py index 5c968c9a11..e5bba6c0e2 100644 --- a/st2common/tests/unit/test_actionchain_schema.py +++ b/st2common/tests/unit/test_actionchain_schema.py @@ -20,42 +20,48 @@ from st2common.models.system import actionchain from st2tests.fixturesloader import FixturesLoader -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'actionchains': ['chain1.yaml', 'malformedchain.yaml', 'no_default_chain.yaml', - 'chain_with_vars.yaml', 'chain_with_publish.yaml'] + "actionchains": [ + "chain1.yaml", + "malformedchain.yaml", + "no_default_chain.yaml", + "chain_with_vars.yaml", + "chain_with_publish.yaml", + ] } -FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES) -CHAIN_1 = FIXTURES['actionchains']['chain1.yaml'] -MALFORMED_CHAIN = FIXTURES['actionchains']['malformedchain.yaml'] -NO_DEFAULT_CHAIN = FIXTURES['actionchains']['no_default_chain.yaml'] -CHAIN_WITH_VARS = FIXTURES['actionchains']['chain_with_vars.yaml'] -CHAIN_WITH_PUBLISH = FIXTURES['actionchains']['chain_with_publish.yaml'] +FIXTURES = FixturesLoader().load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES +) +CHAIN_1 = FIXTURES["actionchains"]["chain1.yaml"] +MALFORMED_CHAIN = FIXTURES["actionchains"]["malformedchain.yaml"] +NO_DEFAULT_CHAIN = FIXTURES["actionchains"]["no_default_chain.yaml"] +CHAIN_WITH_VARS = FIXTURES["actionchains"]["chain_with_vars.yaml"] +CHAIN_WITH_PUBLISH = FIXTURES["actionchains"]["chain_with_publish.yaml"] class ActionChainSchemaTest(unittest2.TestCase): - def test_actionchain_schema_valid(self): chain = actionchain.ActionChain(**CHAIN_1) - self.assertEqual(len(chain.chain), len(CHAIN_1['chain'])) - self.assertEqual(chain.default, CHAIN_1['default']) + self.assertEqual(len(chain.chain), len(CHAIN_1["chain"])) + self.assertEqual(chain.default, CHAIN_1["default"]) def test_actionchain_no_default(self): chain = actionchain.ActionChain(**NO_DEFAULT_CHAIN) - self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN['chain'])) + self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN["chain"])) self.assertEqual(chain.default, None) def test_actionchain_with_vars(self): chain = actionchain.ActionChain(**CHAIN_WITH_VARS) - self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS['chain'])) - self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS['vars'])) + self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS["chain"])) + self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS["vars"])) def test_actionchain_with_publish(self): chain = actionchain.ActionChain(**CHAIN_WITH_PUBLISH) - self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH['chain'])) - self.assertEqual(len(chain.chain[0].publish), - len(CHAIN_WITH_PUBLISH['chain'][0]['publish'])) + self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH["chain"])) + self.assertEqual( + len(chain.chain[0].publish), len(CHAIN_WITH_PUBLISH["chain"][0]["publish"]) + ) def test_actionchain_schema_invalid(self): with self.assertRaises(ValidationError): diff --git a/st2common/tests/unit/test_aliasesregistrar.py b/st2common/tests/unit/test_aliasesregistrar.py index b827830594..4f17246dcf 100644 --- a/st2common/tests/unit/test_aliasesregistrar.py +++ b/st2common/tests/unit/test_aliasesregistrar.py @@ -22,22 +22,20 @@ from st2tests import DbTestCase from st2tests import fixturesloader -__all__ = [ - 'TestAliasRegistrar' -] +__all__ = ["TestAliasRegistrar"] -ALIASES_FIXTURE_PACK_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), - 'dummy_pack_1') -ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, 'aliases') +ALIASES_FIXTURE_PACK_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) +ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, "aliases") class TestAliasRegistrar(DbTestCase): - def test_alias_registration(self): count = aliasesregistrar.register_aliases(pack_dir=ALIASES_FIXTURE_PACK_PATH) # expect all files to contain be aliases self.assertEqual(count, len(os.listdir(ALIASES_FIXTURE_PATH))) action_alias_dbs = ActionAlias.get_all() - self.assertEqual(action_alias_dbs[0].metadata_file, 'aliases/alias1.yaml') + self.assertEqual(action_alias_dbs[0].metadata_file, "aliases/alias1.yaml") diff --git a/st2common/tests/unit/test_api_model_validation.py b/st2common/tests/unit/test_api_model_validation.py index d5f250482b..20eb98ce6c 100644 --- a/st2common/tests/unit/test_api_model_validation.py +++ b/st2common/tests/unit/test_api_model_validation.py @@ -18,196 +18,197 @@ from st2common.models.api.base import BaseAPI -__all__ = [ - 'APIModelValidationTestCase' -] +__all__ = ["APIModelValidationTestCase"] class MockAPIModel1(BaseAPI): model = None schema = { - 'title': 'MockAPIModel', - 'description': 'Test', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the action runner.', - 'type': ['string', 'null'], - 'default': None + "title": "MockAPIModel", + "description": "Test", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the action runner.", + "type": ["string", "null"], + "default": None, }, - 'name': { - 'description': 'The name of the action runner.', - 'type': 'string', - 'required': True + "name": { + "description": "The name of the action runner.", + "type": "string", + "required": True, }, - 'description': { - 'description': 'The description of the action runner.', - 'type': 'string' + "description": { + "description": "The description of the action runner.", + "type": "string", }, - 'enabled': { - 'type': 'boolean', - 'default': True - }, - 'parameters': { - 'type': 'object' - }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': 'unknown' + "enabled": {"type": "boolean", "default": True}, + "parameters": {"type": "object"}, + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": "unknown", }, - 'enabled': { - 'type': 'boolean', - 'default': True + "enabled": {"type": "boolean", "default": True}, + "description": { + "type": "string", + "description": "Description", + "required": False, }, - 'description': { - 'type': 'string', - 'description': 'Description', - 'required': False - } - } + }, }, - 'default': [] - } + "default": [], + }, }, - 'additionalProperties': False + "additionalProperties": False, } class MockAPIModel2(BaseAPI): model = None schema = { - 'title': 'MockAPIModel2', - 'description': 'Test', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the action runner.', - 'type': 'string', - 'default': None + "title": "MockAPIModel2", + "description": "Test", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the action runner.", + "type": "string", + "default": None, }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': None + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": None, }, - 'description': { - 'type': 'string', - 'required': True - } - } + "description": {"type": "string", "required": True}, + }, }, - 'default': [] + "default": [], }, - 'parameters': { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'name': { - 'type': 'string', - 'required': True - } + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "name": {"type": "string", "required": True}, }, - 'additionalProperties': False, - } + "additionalProperties": False, + }, }, - 'additionalProperties': False + "additionalProperties": False, } class APIModelValidationTestCase(unittest2.TestCase): def test_validate_default_values_are_set(self): # no "permission_grants" attribute - mock_model_api = MockAPIModel1(name='name') - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.name, 'name') - self.assertEqual(getattr(mock_model_api, 'enabled', None), None) - self.assertEqual(getattr(mock_model_api, 'permission_grants', None), None) + mock_model_api = MockAPIModel1(name="name") + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.name, "name") + self.assertEqual(getattr(mock_model_api, "enabled", None), None) + self.assertEqual(getattr(mock_model_api, "permission_grants", None), None) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.name, 'name') - self.assertEqual(getattr(mock_model_api, 'enabled', None), None) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.name, "name") + self.assertEqual(getattr(mock_model_api, "enabled", None), None) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.name, 'name') + self.assertEqual(mock_model_api_validated.name, "name") self.assertEqual(mock_model_api_validated.enabled, True) self.assertEqual(mock_model_api_validated.permission_grants, []) # "permission_grants" attribute present, but child missing - mock_model_api = MockAPIModel1(name='name', enabled=False, - permission_grants=[{}, {'description': 'test'}]) - self.assertEqual(mock_model_api.name, 'name') + mock_model_api = MockAPIModel1( + name="name", enabled=False, permission_grants=[{}, {"description": "test"}] + ) + self.assertEqual(mock_model_api.name, "name") self.assertEqual(mock_model_api.enabled, False) - self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}]) + self.assertEqual( + mock_model_api.permission_grants, [{}, {"description": "test"}] + ) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(mock_model_api.name, 'name') + self.assertEqual(mock_model_api.name, "name") self.assertEqual(mock_model_api.enabled, False) - self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}]) + self.assertEqual( + mock_model_api.permission_grants, [{}, {"description": "test"}] + ) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.name, 'name') + self.assertEqual(mock_model_api_validated.name, "name") self.assertEqual(mock_model_api_validated.enabled, False) - self.assertEqual(mock_model_api_validated.permission_grants, - [{'resource_uid': 'unknown', 'enabled': True}, - {'resource_uid': 'unknown', 'enabled': True, 'description': 'test'}]) + self.assertEqual( + mock_model_api_validated.permission_grants, + [ + {"resource_uid": "unknown", "enabled": True}, + {"resource_uid": "unknown", "enabled": True, "description": "test"}, + ], + ) def test_validate_nested_attribute_with_default_not_provided(self): mock_model_api = MockAPIModel2() - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset') + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual( + getattr(mock_model_api, "permission_grants", "notset"), "notset" + ) + self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset") mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset') + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual( + getattr(mock_model_api, "permission_grants", "notset"), "notset" + ) + self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset") # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) self.assertEqual(mock_model_api_validated.permission_grants, []) - self.assertEqual(getattr(mock_model_api_validated, 'parameters', 'notset'), 'notset') + self.assertEqual( + getattr(mock_model_api_validated, "parameters", "notset"), "notset" + ) def test_validate_allow_default_none_for_any_type(self): - mock_model_api = MockAPIModel2(permission_grants=[{'description': 'test'}], - parameters={'name': 'test'}) - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}]) - self.assertEqual(mock_model_api.parameters, {'name': 'test'}) + mock_model_api = MockAPIModel2( + permission_grants=[{"description": "test"}], parameters={"name": "test"} + ) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}]) + self.assertEqual(mock_model_api.parameters, {"name": "test"}) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}]) - self.assertEqual(mock_model_api.parameters, {'name': 'test'}) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}]) + self.assertEqual(mock_model_api.parameters, {"name": "test"}) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.permission_grants, - [{'description': 'test', 'resource_uid': None}]) - self.assertEqual(mock_model_api_validated.parameters, {'id': None, 'name': 'test'}) + self.assertEqual( + mock_model_api_validated.permission_grants, + [{"description": "test", "resource_uid": None}], + ) + self.assertEqual( + mock_model_api_validated.parameters, {"id": None, "name": "test"} + ) diff --git a/st2common/tests/unit/test_casts.py b/st2common/tests/unit/test_casts.py index 55e95ca781..62bf0ac4e8 100644 --- a/st2common/tests/unit/test_casts.py +++ b/st2common/tests/unit/test_casts.py @@ -23,19 +23,19 @@ class CastsTestCase(unittest2.TestCase): def test_cast_string(self): - cast_func = get_cast('string') + cast_func = get_cast("string") - value = 'test1' + value = "test1" result = cast_func(value) - self.assertEqual(result, 'test1') + self.assertEqual(result, "test1") - value = u'test2' + value = "test2" result = cast_func(value) - self.assertEqual(result, u'test2') + self.assertEqual(result, "test2") - value = '' + value = "" result = cast_func(value) - self.assertEqual(result, '') + self.assertEqual(result, "") # None should be preserved value = None @@ -48,7 +48,7 @@ def test_cast_string(self): self.assertRaisesRegexp(ValueError, expected_msg, cast_func, value) def test_cast_array(self): - cast_func = get_cast('array') + cast_func = get_cast("array") # Python literal value = str([1, 2, 3]) diff --git a/st2common/tests/unit/test_config_loader.py b/st2common/tests/unit/test_config_loader.py index f59e3efe4f..e1849d7868 100644 --- a/st2common/tests/unit/test_config_loader.py +++ b/st2common/tests/unit/test_config_loader.py @@ -24,9 +24,7 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'ContentPackConfigLoaderTestCase' -] +__all__ = ["ContentPackConfigLoaderTestCase"] class ContentPackConfigLoaderTestCase(CleanDbTestCase): @@ -37,7 +35,7 @@ def test_ensure_local_pack_config_feature_removed(self): # Test a scenario where all the values are loaded from pack local # config and pack global config (pack name.yaml) doesn't exist. # Test a scenario where no values are overridden in the datastore - loader = ContentPackConfigLoader(pack_name='dummy_pack_4') + loader = ContentPackConfigLoader(pack_name="dummy_pack_4") config = loader.get_config() expected_config = {} @@ -46,35 +44,39 @@ def test_ensure_local_pack_config_feature_removed(self): def test_get_config_some_values_overriden_in_datastore(self): # Test a scenario where some values are overriden in datastore via pack # flobal config - kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5', - key_name='api_secret', - value='some_api_secret', - secret=True, - user='joe') + kvp_db = set_datastore_value_for_config_key( + pack_name="dummy_pack_5", + key_name="api_secret", + value="some_api_secret", + secret=True, + user="joe", + ) # This is a secret so a value should be encrypted - self.assertTrue(kvp_db.value != 'some_api_secret') - self.assertTrue(len(kvp_db.value) > len('some_api_secret') * 2) + self.assertTrue(kvp_db.value != "some_api_secret") + self.assertTrue(len(kvp_db.value) > len("some_api_secret") * 2) self.assertTrue(kvp_db.secret) - kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5', - key_name='private_key_path', - value='some_private_key') - self.assertEqual(kvp_db.value, 'some_private_key') + kvp_db = set_datastore_value_for_config_key( + pack_name="dummy_pack_5", + key_name="private_key_path", + value="some_private_key", + ) + self.assertEqual(kvp_db.value, "some_private_key") self.assertFalse(kvp_db.secret) - loader = ContentPackConfigLoader(pack_name='dummy_pack_5', user='joe') + loader = ContentPackConfigLoader(pack_name="dummy_pack_5", user="joe") config = loader.get_config() # regions is provided in the pack global config # api_secret is dynamically loaded from the datastore for a particular user expected_config = { - 'api_key': 'some_api_key', - 'api_secret': 'some_api_secret', - 'regions': ['us-west-1'], - 'region': 'default-region-value', - 'private_key_path': 'some_private_key', - 'non_required_with_default_value': 'config value' + "api_key": "some_api_key", + "api_secret": "some_api_secret", + "regions": ["us-west-1"], + "region": "default-region-value", + "private_key_path": "some_private_key", + "non_required_with_default_value": "config value", } self.assertEqual(config, expected_config) @@ -82,26 +84,26 @@ def test_get_config_some_values_overriden_in_datastore(self): def test_get_config_default_value_from_config_schema_is_used(self): # No value is provided for "region" in the config, default value from config schema # should be used - loader = ContentPackConfigLoader(pack_name='dummy_pack_5') + loader = ContentPackConfigLoader(pack_name="dummy_pack_5") config = loader.get_config() - self.assertEqual(config['region'], 'default-region-value') + self.assertEqual(config["region"], "default-region-value") # Here a default value is specified in schema but an explicit value is provided in the # config - loader = ContentPackConfigLoader(pack_name='dummy_pack_1') + loader = ContentPackConfigLoader(pack_name="dummy_pack_1") config = loader.get_config() - self.assertEqual(config['region'], 'us-west-1') + self.assertEqual(config["region"], "us-west-1") # Config item attribute has required: false # Value is provided in the config - it should be used as provided - pack_name = 'dummy_pack_5' + pack_name = "dummy_pack_5" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['non_required_with_default_value'], 'config value') + self.assertEqual(config["non_required_with_default_value"], "config value") config_db = Config.get_by_pack(pack_name) - del config_db['values']['non_required_with_default_value'] + del config_db["values"]["non_required_with_default_value"] Config.add_or_update(config_db) # No value in the config - default value should be used @@ -111,10 +113,12 @@ def test_get_config_default_value_from_config_schema_is_used(self): # No config exists for that pack - default value should be used loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['non_required_with_default_value'], 'some default value') + self.assertEqual( + config["non_required_with_default_value"], "some default value" + ) def test_default_values_from_schema_are_used_when_no_config_exists(self): - pack_name = 'dummy_pack_5' + pack_name = "dummy_pack_5" config_db = Config.get_by_pack(pack_name) # Delete the existing config loaded in setUp @@ -122,37 +126,37 @@ def test_default_values_from_schema_are_used_when_no_config_exists(self): config_db.delete() # Verify config has been deleted from the database - self.assertRaises(StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name) + self.assertRaises( + StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name + ) loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['region'], 'default-region-value') + self.assertEqual(config["region"], "default-region-value") def test_default_values_are_used_when_default_values_are_falsey(self): - pack_name = 'dummy_pack_17' + pack_name = "dummy_pack_17" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() # 1. Default values are used - self.assertEqual(config['key_with_default_falsy_value_1'], False) - self.assertEqual(config['key_with_default_falsy_value_2'], None) - self.assertEqual(config['key_with_default_falsy_value_3'], {}) - self.assertEqual(config['key_with_default_falsy_value_4'], '') - self.assertEqual(config['key_with_default_falsy_value_5'], 0) - self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False) - self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], 0) + self.assertEqual(config["key_with_default_falsy_value_1"], False) + self.assertEqual(config["key_with_default_falsy_value_2"], None) + self.assertEqual(config["key_with_default_falsy_value_3"], {}) + self.assertEqual(config["key_with_default_falsy_value_4"], "") + self.assertEqual(config["key_with_default_falsy_value_5"], 0) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], 0) # 2. Default values are overwrriten with config values which are also falsey values = { - 'key_with_default_falsy_value_1': 0, - 'key_with_default_falsy_value_2': '', - 'key_with_default_falsy_value_3': False, - 'key_with_default_falsy_value_4': None, - 'key_with_default_falsy_value_5': {}, - 'key_with_default_falsy_value_6': { - 'key_2': False - } + "key_with_default_falsy_value_1": 0, + "key_with_default_falsy_value_2": "", + "key_with_default_falsy_value_3": False, + "key_with_default_falsy_value_4": None, + "key_with_default_falsy_value_5": {}, + "key_with_default_falsy_value_6": {"key_2": False}, } config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) @@ -160,301 +164,296 @@ def test_default_values_are_used_when_default_values_are_falsey(self): loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['key_with_default_falsy_value_1'], 0) - self.assertEqual(config['key_with_default_falsy_value_2'], '') - self.assertEqual(config['key_with_default_falsy_value_3'], False) - self.assertEqual(config['key_with_default_falsy_value_4'], None) - self.assertEqual(config['key_with_default_falsy_value_5'], {}) - self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False) - self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], False) + self.assertEqual(config["key_with_default_falsy_value_1"], 0) + self.assertEqual(config["key_with_default_falsy_value_2"], "") + self.assertEqual(config["key_with_default_falsy_value_3"], False) + self.assertEqual(config["key_with_default_falsy_value_4"], None) + self.assertEqual(config["key_with_default_falsy_value_5"], {}) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], False) def test_get_config_nested_schema_default_values_from_config_schema_are_used(self): # Special case for more complex config schemas with attributes ntesting. # Validate that the default values are also used for one level nested object properties. - pack_name = 'dummy_pack_schema_with_nested_object_1' + pack_name = "dummy_pack_schema_with_nested_object_1" # 1. None of the nested object values are provided loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.3', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'] - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.3", + "port": 8080, + "device_uids": ["a", "b", "c"], + }, } self.assertEqual(config, expected_config) # 2. Some of the nested object values are provided (host, port) - pack_name = 'dummy_pack_schema_with_nested_object_2' + pack_name = "dummy_pack_schema_with_nested_object_2" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.6', - 'port': 9090, - 'device_uids': ['a', 'b', 'c'] - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.6", + "port": 9090, + "device_uids": ["a", "b", "c"], + }, } self.assertEqual(config, expected_config) # 3. Nested attribute (auth_settings.token) references a non-secret datastore value - pack_name = 'dummy_pack_schema_with_nested_object_3' - - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='some_auth_settings_token') - self.assertEqual(kvp_db.value, 'some_auth_settings_token') + pack_name = "dummy_pack_schema_with_nested_object_3" + + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="some_auth_settings_token", + ) + self.assertEqual(kvp_db.value, "some_auth_settings_token") self.assertFalse(kvp_db.secret) loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.10', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'some_auth_settings_token' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.10", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "some_auth_settings_token", + }, } self.assertEqual(config, expected_config) # 4. Nested attribute (auth_settings.token) references a secret datastore value - pack_name = 'dummy_pack_schema_with_nested_object_4' - - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='joe_token_secret', - secret=True, - user='joe') - self.assertTrue(kvp_db.value != 'joe_token_secret') - self.assertTrue(len(kvp_db.value) > len('joe_token_secret') * 2) + pack_name = "dummy_pack_schema_with_nested_object_4" + + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="joe_token_secret", + secret=True, + user="joe", + ) + self.assertTrue(kvp_db.value != "joe_token_secret") + self.assertTrue(len(kvp_db.value) > len("joe_token_secret") * 2) self.assertTrue(kvp_db.secret) - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='alice_token_secret', - secret=True, - user='alice') - self.assertTrue(kvp_db.value != 'alice_token_secret') - self.assertTrue(len(kvp_db.value) > len('alice_token_secret') * 2) + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="alice_token_secret", + secret=True, + user="alice", + ) + self.assertTrue(kvp_db.value != "alice_token_secret") + self.assertTrue(len(kvp_db.value) > len("alice_token_secret") * 2) self.assertTrue(kvp_db.secret) - loader = ContentPackConfigLoader(pack_name=pack_name, user='joe') + loader = ContentPackConfigLoader(pack_name=pack_name, user="joe") config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.11', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'joe_token_secret' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.11", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "joe_token_secret", + }, } self.assertEqual(config, expected_config) - loader = ContentPackConfigLoader(pack_name=pack_name, user='alice') + loader = ContentPackConfigLoader(pack_name=pack_name, user="alice") config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.11', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'alice_token_secret' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.11", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "alice_token_secret", + }, } self.assertEqual(config, expected_config) - def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown(self): - pack_name = 'dummy_pack_schema_with_nested_object_5' + def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown( + self, + ): + pack_name = "dummy_pack_schema_with_nested_object_5" loader = ContentPackConfigLoader(pack_name=pack_name) # Render fails on top-level item - values = { - 'level0_key': '{{st2kvXX.invalid}}' - } + values = {"level0_key": "{{st2kvXX.invalid}}"} config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key "level0_key" with ' - 'value "{{st2kvXX.invalid}}" for pack ".*?" config: ' - ' ' - '\'st2kvXX\' is undefined') + expected_msg = ( + 'Failed to render dynamic configuration value for key "level0_key" with ' + 'value "{{st2kvXX.invalid}}" for pack ".*?" config: ' + " " + "'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on fist level item - values = { - 'level0_object': { - 'level1_key': '{{st2kvXX.invalid}}' - } - } + values = {"level0_object": {"level1_key": "{{st2kvXX.invalid}}"}} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on second level item values = { - 'level0_object': { - 'level1_object': { - 'level2_key': '{{st2kvXX.invalid}}' - } - } + "level0_object": {"level1_object": {"level2_key": "{{st2kvXX.invalid}}"}} } config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on list item - values = { - 'level0_object': [ - 'abc', - '{{st2kvXX.invalid}}' - ] - } + values = {"level0_object": ["abc", "{{st2kvXX.invalid}}"]} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.1" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.1" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on nested object in list item - values = { - 'level0_object': [ - {'level2_key': '{{st2kvXX.invalid}}'} - ] - } + values = {"level0_object": [{"level2_key": "{{st2kvXX.invalid}}"}]} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on invalid syntax - values = { - 'level0_key': '{{ this is some invalid Jinja }}' - } + values = {"level0_key": "{{ this is some invalid Jinja }}"} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_key" with value "{{ this is some invalid Jinja }}"' - ' for pack ".*?" config: ' - ' expected token \'end of print statement\', got \'Jinja\'') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_key" with value "{{ this is some invalid Jinja }}"' + " for pack \".*?\" config: " + " expected token 'end of print statement', got 'Jinja'" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() def test_get_config_dynamic_config_item(self): - pack_name = 'dummy_pack_schema_with_nested_object_6' + pack_name = "dummy_pack_schema_with_nested_object_6" loader = ContentPackConfigLoader(pack_name=pack_name) #################### # value in top level item - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - values = { - 'level0_key': '{{st2kv.system.k1}}' - } + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + values = {"level0_key": "{{st2kv.system.k1}}"} config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) config_rendered = loader.get_config() - self.assertEqual(config_rendered, {'level0_key': 'v1'}) + self.assertEqual(config_rendered, {"level0_key": "v1"}) config_db.delete() def test_get_config_dynamic_config_item_nested_dict(self): - pack_name = 'dummy_pack_schema_with_nested_object_7' + pack_name = "dummy_pack_schema_with_nested_object_7" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) #################### # values nested dictionaries values = { - 'level0_key': '{{st2kv.system.k0}}', - 'level0_object': { - 'level1_key': '{{st2kv.system.k1}}', - 'level1_object': { - 'level2_key': '{{st2kv.system.k2}}' - } - } + "level0_key": "{{st2kv.system.k0}}", + "level0_object": { + "level1_key": "{{st2kv.system.k1}}", + "level1_object": {"level2_key": "{{st2kv.system.k2}}"}, + }, } config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': 'v0', - 'level0_object': { - 'level1_key': 'v1', - 'level1_object': { - 'level2_key': 'v2' - } - } - }) + self.assertEqual( + config_rendered, + { + "level0_key": "v0", + "level0_object": { + "level1_key": "v1", + "level1_object": {"level2_key": "v2"}, + }, + }, + ) config_db.delete() def test_get_config_dynamic_config_item_list(self): - pack_name = 'dummy_pack_schema_with_nested_object_7' + pack_name = "dummy_pack_schema_with_nested_object_7" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) #################### # values in list values = { - 'level0_key': [ - 'a', - '{{st2kv.system.k0}}', - 'b', - '{{st2kv.system.k1}}', + "level0_key": [ + "a", + "{{st2kv.system.k0}}", + "b", + "{{st2kv.system.k1}}", ] } config_db = ConfigDB(pack=pack_name, values=values) @@ -462,44 +461,34 @@ def test_get_config_dynamic_config_item_list(self): config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': [ - 'a', - 'v0', - 'b', - 'v1' - ] - }) + self.assertEqual(config_rendered, {"level0_key": ["a", "v0", "b", "v1"]}) config_db.delete() def test_get_config_dynamic_config_item_nested_list(self): - pack_name = 'dummy_pack_schema_with_nested_object_8' + pack_name = "dummy_pack_schema_with_nested_object_8" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) #################### # values in objects embedded in lists and nested lists values = { - 'level0_key': [ - { - 'level1_key0': '{{st2kv.system.k0}}' - }, - '{{st2kv.system.k1}}', + "level0_key": [ + {"level1_key0": "{{st2kv.system.k0}}"}, + "{{st2kv.system.k1}}", [ - '{{st2kv.system.k0}}', - '{{st2kv.system.k1}}', - '{{st2kv.system.k2}}', + "{{st2kv.system.k0}}", + "{{st2kv.system.k1}}", + "{{st2kv.system.k2}}", ], { - 'level1_key2': [ - '{{st2kv.system.k2}}', + "level1_key2": [ + "{{st2kv.system.k2}}", ] - } + }, ] } config_db = ConfigDB(pack=pack_name, values=values) @@ -507,30 +496,30 @@ def test_get_config_dynamic_config_item_nested_list(self): config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': [ - { - 'level1_key0': 'v0' - }, - 'v1', - [ - 'v0', - 'v1', - 'v2', - ], - { - 'level1_key2': [ - 'v2', - ] - } - ] - }) + self.assertEqual( + config_rendered, + { + "level0_key": [ + {"level1_key0": "v0"}, + "v1", + [ + "v0", + "v1", + "v2", + ], + { + "level1_key2": [ + "v2", + ] + }, + ] + }, + ) config_db.delete() def test_empty_config_object_in_the_database(self): - pack_name = 'dummy_pack_empty_config' + pack_name = "dummy_pack_empty_config" config_db = ConfigDB(pack=pack_name) config_db = Config.add_or_update(config_db) diff --git a/st2common/tests/unit/test_config_parser.py b/st2common/tests/unit/test_config_parser.py index 6dc690b746..fde0385369 100644 --- a/st2common/tests/unit/test_config_parser.py +++ b/st2common/tests/unit/test_config_parser.py @@ -27,27 +27,27 @@ def setUp(self): tests_config.parse_args() def test_get_config_inexistent_pack(self): - parser = ContentPackConfigParser(pack_name='inexistent') + parser = ContentPackConfigParser(pack_name="inexistent") config = parser.get_config() self.assertEqual(config, None) def test_get_config_no_config(self): - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() self.assertEqual(config, None) def test_get_config_existing_config(self): - pack_name = 'dummy_pack_2' + pack_name = "dummy_pack_2" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() - self.assertEqual(config.config['section1']['key1'], 'value1') - self.assertEqual(config.config['section2']['key10'], 'value10') + self.assertEqual(config.config["section1"]["key1"], "value1") + self.assertEqual(config.config["section2"]["key10"], "value10") def test_get_config_for_unicode_char(self): - pack_name = 'dummy_pack_18' + pack_name = "dummy_pack_18" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() - self.assertEqual(config.config['section1']['key1'], u'测试') + self.assertEqual(config.config["section1"]["key1"], "测试") diff --git a/st2common/tests/unit/test_configs_registrar.py b/st2common/tests/unit/test_configs_registrar.py index 09d002eb6a..821cec75fa 100644 --- a/st2common/tests/unit/test_configs_registrar.py +++ b/st2common/tests/unit/test_configs_registrar.py @@ -30,15 +30,23 @@ from st2tests import fixturesloader -__all__ = [ - 'ConfigsRegistrarTestCase' -] - -PACK_1_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1') -PACK_6_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_6') -PACK_19_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_19') -PACK_11_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_11') -PACK_22_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_22') +__all__ = ["ConfigsRegistrarTestCase"] + +PACK_1_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) +PACK_6_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_6" +) +PACK_19_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_19" +) +PACK_11_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_11" +) +PACK_22_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_22" +) class ConfigsRegistrarTestCase(CleanDbTestCase): @@ -52,7 +60,7 @@ def test_register_configs_for_all_packs(self): registrar = ConfigsRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_1_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_1_PATH} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_from_packs(base_dirs=packs_base_paths) @@ -64,9 +72,9 @@ def test_register_configs_for_all_packs(self): self.assertEqual(len(config_dbs), 1) config_db = config_dbs[0] - self.assertEqual(config_db.values['api_key'], '{{st2kv.user.api_key}}') - self.assertEqual(config_db.values['api_secret'], SUPER_SECRET_PARAMETER) - self.assertEqual(config_db.values['region'], 'us-west-1') + self.assertEqual(config_db.values["api_key"], "{{st2kv.user.api_key}}") + self.assertEqual(config_db.values["api_secret"], SUPER_SECRET_PARAMETER) + self.assertEqual(config_db.values["region"], "us-west-1") def test_register_all_configs_invalid_config_no_config_schema(self): # verify_ configs is on, but ConfigSchema for the pack doesn't exist so @@ -81,7 +89,7 @@ def test_register_all_configs_invalid_config_no_config_schema(self): registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_from_packs(base_dirs=packs_base_paths) @@ -92,7 +100,9 @@ def test_register_all_configs_invalid_config_no_config_schema(self): self.assertEqual(len(pack_dbs), 1) self.assertEqual(len(config_dbs), 1) - def test_register_all_configs_with_config_schema_validation_validation_failure_1(self): + def test_register_all_configs_with_config_schema_validation_validation_failure_1( + self, + ): # Verify DB is empty pack_dbs = Pack.get_all() config_dbs = Config.get_all() @@ -100,28 +110,38 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_1 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_5', pack_dir=PACK_6_PATH) + registrar._register_pack(pack_name="dummy_pack_5", pack_dir=PACK_6_PATH) packs_base_paths = content_utils.get_packs_base_paths() if six.PY3: - expected_msg = ('Failed validating attribute "regions" in config for pack ' - '"dummy_pack_6" (.*?): 1000 is not of type \'array\'') + expected_msg = ( + 'Failed validating attribute "regions" in config for pack ' + "\"dummy_pack_6\" (.*?): 1000 is not of type 'array'" + ) else: - expected_msg = ('Failed validating attribute "regions" in config for pack ' - '"dummy_pack_6" (.*?): 1000 is not of type u\'array\'') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_2(self): + expected_msg = ( + 'Failed validating attribute "regions" in config for pack ' + "\"dummy_pack_6\" (.*?): 1000 is not of type u'array'" + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_2( + self, + ): # Verify DB is empty pack_dbs = Pack.get_all() config_dbs = Config.get_all() @@ -129,30 +149,40 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_2 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_19': PACK_19_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_19": PACK_19_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_19', pack_dir=PACK_19_PATH) + registrar._register_pack(pack_name="dummy_pack_19", pack_dir=PACK_19_PATH) packs_base_paths = content_utils.get_packs_base_paths() if six.PY3: - expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack ' - '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type ' - '\'string\'') + expected_msg = ( + 'Failed validating attribute "instances.0.alias" in config for pack ' + "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type " + "'string'" + ) else: - expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack ' - '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type ' - 'u\'string\'') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_3(self): + expected_msg = ( + 'Failed validating attribute "instances.0.alias" in config for pack ' + "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type " + "u'string'" + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_3( + self, + ): # This test checks for values containing "decrypt_kv" jinja filter in the config # object where keys have "secret: True" set in the schema. @@ -163,26 +193,34 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_3 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_11': PACK_11_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_11": PACK_11_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_11', pack_dir=PACK_11_PATH) + registrar._register_pack(pack_name="dummy_pack_11", pack_dir=PACK_11_PATH) packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_4(self): + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_4( + self, + ): # This test checks for default values containing "decrypt_kv" jinja filter for # keys which have "secret: True" set. @@ -193,21 +231,27 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_4 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_22': PACK_22_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_22": PACK_22_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_22', pack_dir=PACK_22_PATH) + registrar._register_pack(pack_name="dummy_pack_22", pack_dir=PACK_22_PATH) packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) diff --git a/st2common/tests/unit/test_connection_retry_wrapper.py b/st2common/tests/unit/test_connection_retry_wrapper.py index 8c75ff4955..831ac8c22e 100644 --- a/st2common/tests/unit/test_connection_retry_wrapper.py +++ b/st2common/tests/unit/test_connection_retry_wrapper.py @@ -21,19 +21,18 @@ class TestClusterRetryContext(unittest.TestCase): - def test_single_node_cluster_retry(self): retry_context = ClusterRetryContext(cluster_size=1) should_stop, wait = retry_context.test_should_stop() - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") self.assertEqual(wait, 10) should_stop, wait = retry_context.test_should_stop() - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") self.assertEqual(wait, 10) should_stop, wait = retry_context.test_should_stop() - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) def test_should_stop_second_channel_open_error_should_be_non_fatal(self): @@ -58,10 +57,10 @@ def test_multiple_node_cluster_retry(self): for i in range(last_index + 1): should_stop, wait = retry_context.test_should_stop() if i == last_index: - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) else: - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") # on cluster boundaries the wait is longer. Short wait when switching # to a different server within a cluster. if (i + 1) % cluster_size == 0: @@ -72,5 +71,5 @@ def test_multiple_node_cluster_retry(self): def test_zero_node_cluster_retry(self): retry_context = ClusterRetryContext(cluster_size=0) should_stop, wait = retry_context.test_should_stop() - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) diff --git a/st2common/tests/unit/test_content_loader.py b/st2common/tests/unit/test_content_loader.py index c20afda87a..8b8e650afb 100644 --- a/st2common/tests/unit/test_content_loader.py +++ b/st2common/tests/unit/test_content_loader.py @@ -23,64 +23,81 @@ from st2common.content.loader import LOG CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) class ContentLoaderTest(unittest2.TestCase): def test_get_sensors(self): - packs_base_path = os.path.join(RESOURCES_DIR, 'packs/') + packs_base_path = os.path.join(RESOURCES_DIR, "packs/") loader = ContentPackLoader() - pack_sensors = loader.get_content(base_dirs=[packs_base_path], content_type='sensors') - self.assertIsNotNone(pack_sensors.get('pack1', None)) + pack_sensors = loader.get_content( + base_dirs=[packs_base_path], content_type="sensors" + ) + self.assertIsNotNone(pack_sensors.get("pack1", None)) def test_get_sensors_pack_missing_sensors(self): loader = ContentPackLoader() - fail_pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2') + fail_pack_path = os.path.join(RESOURCES_DIR, "packs/pack2") self.assertTrue(os.path.exists(fail_pack_path)) self.assertEqual(loader._get_sensors(fail_pack_path), None) def test_invalid_content_type(self): - packs_base_path = os.path.join(RESOURCES_DIR, 'packs/') + packs_base_path = os.path.join(RESOURCES_DIR, "packs/") loader = ContentPackLoader() - self.assertRaises(ValueError, loader.get_content, base_dirs=[packs_base_path], - content_type='stuff') + self.assertRaises( + ValueError, + loader.get_content, + base_dirs=[packs_base_path], + content_type="stuff", + ) def test_get_content_multiple_directories(self): - packs_base_path_1 = os.path.join(RESOURCES_DIR, 'packs/') - packs_base_path_2 = os.path.join(RESOURCES_DIR, 'packs2/') + packs_base_path_1 = os.path.join(RESOURCES_DIR, "packs/") + packs_base_path_2 = os.path.join(RESOURCES_DIR, "packs2/") base_dirs = [packs_base_path_1, packs_base_path_2] LOG.warning = Mock() loader = ContentPackLoader() - sensors = loader.get_content(base_dirs=base_dirs, content_type='sensors') - self.assertIn('pack1', sensors) # from packs/ - self.assertIn('pack3', sensors) # from packs2/ + sensors = loader.get_content(base_dirs=base_dirs, content_type="sensors") + self.assertIn("pack1", sensors) # from packs/ + self.assertIn("pack3", sensors) # from packs2/ # Assert that a warning is emitted when a duplicated pack is found - expected_msg = ('Pack "pack1" already found in ' - '"%s/packs/", ignoring content from ' - '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR)) + expected_msg = ( + 'Pack "pack1" already found in ' + '"%s/packs/", ignoring content from ' + '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR) + ) LOG.warning.assert_called_once_with(expected_msg) def test_get_content_from_pack_success(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack1') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack1") - sensors = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors') - self.assertTrue(sensors.endswith('packs/pack1/sensors')) + sensors = loader.get_content_from_pack( + pack_dir=pack_path, content_type="sensors" + ) + self.assertTrue(sensors.endswith("packs/pack1/sensors")) def test_get_content_from_pack_directory_doesnt_exist(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack100') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack100") - message_regex = 'Directory .*? doesn\'t exist' - self.assertRaisesRegexp(ValueError, message_regex, loader.get_content_from_pack, - pack_dir=pack_path, content_type='sensors') + message_regex = "Directory .*? doesn't exist" + self.assertRaisesRegexp( + ValueError, + message_regex, + loader.get_content_from_pack, + pack_dir=pack_path, + content_type="sensors", + ) def test_get_content_from_pack_no_sensors(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack2") - result = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors') + result = loader.get_content_from_pack( + pack_dir=pack_path, content_type="sensors" + ) self.assertEqual(result, None) diff --git a/st2common/tests/unit/test_content_utils.py b/st2common/tests/unit/test_content_utils.py index 703c75aa70..523114a613 100644 --- a/st2common/tests/unit/test_content_utils.py +++ b/st2common/tests/unit/test_content_utils.py @@ -39,205 +39,260 @@ def setUpClass(cls): tests_config.parse_args() def test_get_pack_base_paths(self): - cfg.CONF.content.system_packs_base_path = '' - cfg.CONF.content.packs_base_paths = '/opt/path1' + cfg.CONF.content.system_packs_base_path = "" + cfg.CONF.content.packs_base_paths = "/opt/path1" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1']) + self.assertEqual(result, ["/opt/path1"]) # Multiple paths, no trailing colon - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2' + cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple paths, trailing colon - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:' + cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2:" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple same paths - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2' + cfg.CONF.content.packs_base_paths = ( + "/opt/path1:/opt/path2:/opt/path1:/opt/path2" + ) result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Assert system path is always first - cfg.CONF.content.system_packs_base_path = '/opt/system' - cfg.CONF.content.packs_base_paths = '/opt/path2:/opt/path1' + cfg.CONF.content.system_packs_base_path = "/opt/system" + cfg.CONF.content.packs_base_paths = "/opt/path2:/opt/path1" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/system', '/opt/path2', '/opt/path1']) + self.assertEqual(result, ["/opt/system", "/opt/path2", "/opt/path1"]) # More scenarios orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" - names = [ - 'test_pack_1', - 'test_pack_2', - 'ma_pack' - ] + names = ["test_pack_1", "test_pack_2", "ma_pack"] for name in names: actual = get_pack_base_path(pack_name=name) - expected = os.path.join(cfg.CONF.content.system_packs_base_path, - name) + expected = os.path.join(cfg.CONF.content.system_packs_base_path, name) self.assertEqual(actual, expected) cfg.CONF.content.system_packs_base_path = orig_path def test_get_aliases_base_paths(self): - cfg.CONF.content.aliases_base_paths = '/opt/path1' + cfg.CONF.content.aliases_base_paths = "/opt/path1" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1']) + self.assertEqual(result, ["/opt/path1"]) # Multiple paths, no trailing colon - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2' + cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple paths, trailing colon - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:' + cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2:" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple same paths - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2' + cfg.CONF.content.aliases_base_paths = ( + "/opt/path1:/opt/path2:/opt/path1:/opt/path2" + ) result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) def test_get_pack_resource_file_abs_path(self): # Mock the packs path to point to the fixtures directory cfg.CONF.content.packs_base_paths = get_fixtures_packs_base_path() # Invalid resource type - expected_msg = 'Invalid resource type: fooo' - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='fooo', - file_path='test.py') + expected_msg = "Invalid resource type: fooo" + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="fooo", + file_path="test.py", + ) # Invalid paths (directory traversal and absolute paths) - file_paths = ['/tmp/foo.py', '../foo.py', '/etc/passwd', '../../foo.py', - '/opt/stackstorm/packs/invalid_pack/actions/my_action.py', - '../../foo.py'] + file_paths = [ + "/tmp/foo.py", + "../foo.py", + "/etc/passwd", + "../../foo.py", + "/opt/stackstorm/packs/invalid_pack/actions/my_action.py", + "../../foo.py", + ] for file_path in file_paths: # action resource_type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack actions directory (.*). For example "my_action.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='action', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack actions directory (.*). For example "my_action.py"\.' + % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="action", + file_path=file_path, + ) # sensor resource_type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack sensors directory (.*). For example "my_sensor.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='sensor', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack sensors directory (.*). For example "my_sensor.py"\.' + % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="sensor", + file_path=file_path, + ) # no resource type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack directory (.*). For example "my_action.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_file_abs_path, - pack_ref='dummy_pack_1', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack directory (.*). For example "my_action.py"\.' % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_file_abs_path, + pack_ref="dummy_pack_1", + file_path=file_path, + ) # Valid paths - file_paths = ['foo.py', 'a/foo.py', 'a/b/foo.py'] + file_paths = ["foo.py", "a/foo.py", "a/b/foo.py"] for file_path in file_paths: - expected = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_1/actions', file_path) - result = get_pack_resource_file_abs_path(pack_ref='dummy_pack_1', - resource_type='action', - file_path=file_path) + expected = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_1/actions", file_path + ) + result = get_pack_resource_file_abs_path( + pack_ref="dummy_pack_1", resource_type="action", file_path=file_path + ) self.assertEqual(result, expected) def test_get_entry_point_absolute_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" acutal_path = get_entry_point_abs_path( - pack='foo', - entry_point='/tests/packs/foo/bar.py') - self.assertEqual(acutal_path, '/tests/packs/foo/bar.py', 'Entry point path doesn\'t match.') + pack="foo", entry_point="/tests/packs/foo/bar.py" + ) + self.assertEqual( + acutal_path, "/tests/packs/foo/bar.py", "Entry point path doesn't match." + ) cfg.CONF.content.system_packs_base_path = orig_path def test_get_entry_point_absolute_path_empty(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' - acutal_path = get_entry_point_abs_path(pack='foo', entry_point=None) - self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.') - acutal_path = get_entry_point_abs_path(pack='foo', entry_point='') - self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.') + cfg.CONF.content.system_packs_base_path = "/tests/packs" + acutal_path = get_entry_point_abs_path(pack="foo", entry_point=None) + self.assertEqual(acutal_path, None, "Entry point path doesn't match.") + acutal_path = get_entry_point_abs_path(pack="foo", entry_point="") + self.assertEqual(acutal_path, None, "Entry point path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_entry_point_relative_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' - acutal_path = get_entry_point_abs_path(pack='foo', entry_point='foo/bar.py') - expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions', - 'foo/bar.py') - self.assertEqual(acutal_path, expected_path, 'Entry point path doesn\'t match.') + cfg.CONF.content.system_packs_base_path = "/tests/packs" + acutal_path = get_entry_point_abs_path(pack="foo", entry_point="foo/bar.py") + expected_path = os.path.join( + cfg.CONF.content.system_packs_base_path, "foo", "actions", "foo/bar.py" + ) + self.assertEqual(acutal_path, expected_path, "Entry point path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_action_libs_abs_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" # entry point relative. - acutal_path = get_action_libs_abs_path(pack='foo', entry_point='foo/bar.py') - expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions', - os.path.join('foo', ACTION_LIBS_DIR)) - self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.') + acutal_path = get_action_libs_abs_path(pack="foo", entry_point="foo/bar.py") + expected_path = os.path.join( + cfg.CONF.content.system_packs_base_path, + "foo", + "actions", + os.path.join("foo", ACTION_LIBS_DIR), + ) + self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.") # entry point absolute. acutal_path = get_action_libs_abs_path( - pack='foo', - entry_point='/tests/packs/foo/tmp/foo.py') - expected_path = os.path.join('/tests/packs/foo/tmp', ACTION_LIBS_DIR) - self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.') + pack="foo", entry_point="/tests/packs/foo/tmp/foo.py" + ) + expected_path = os.path.join("/tests/packs/foo/tmp", ACTION_LIBS_DIR) + self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_relative_path_to_pack_file(self): packs_base_paths = get_fixtures_packs_base_path() - pack_ref = 'dummy_pack_1' + pack_ref = "dummy_pack_1" # 1. Valid paths - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/pack.yaml') + file_path = os.path.join(packs_base_paths, "dummy_pack_1/pack.yaml") result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'pack.yaml') + self.assertEqual(result, "pack.yaml") - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/action.meta.yaml') + file_path = os.path.join( + packs_base_paths, "dummy_pack_1/actions/action.meta.yaml" + ) result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/action.meta.yaml') + self.assertEqual(result, "actions/action.meta.yaml") - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/lib/foo.py') + file_path = os.path.join(packs_base_paths, "dummy_pack_1/actions/lib/foo.py") result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/lib/foo.py') + self.assertEqual(result, "actions/lib/foo.py") # Already relative - file_path = 'actions/lib/foo2.py' + file_path = "actions/lib/foo2.py" result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/lib/foo2.py') + self.assertEqual(result, "actions/lib/foo2.py") # 2. Invalid path - outside pack directory - expected_msg = r'file_path (.*?) is not located inside the pack directory (.*?)' - - file_path = os.path.join(packs_base_paths, 'dummy_pack_2/actions/lib/foo.py') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = '/tmp/foo/bar.py' - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = os.path.join(packs_base_paths, '../dummy_pack_1/pack.yaml') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = os.path.join(packs_base_paths, '../../dummy_pack_1/pack.yaml') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) + expected_msg = r"file_path (.*?) is not located inside the pack directory (.*?)" + + file_path = os.path.join(packs_base_paths, "dummy_pack_2/actions/lib/foo.py") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = "/tmp/foo/bar.py" + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = os.path.join(packs_base_paths, "../dummy_pack_1/pack.yaml") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = os.path.join(packs_base_paths, "../../dummy_pack_1/pack.yaml") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) diff --git a/st2common/tests/unit/test_crypto_utils.py b/st2common/tests/unit/test_crypto_utils.py index 3bd63ecefe..5f8f07fa69 100644 --- a/st2common/tests/unit/test_crypto_utils.py +++ b/st2common/tests/unit/test_crypto_utils.py @@ -40,37 +40,32 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'CryptoUtilsTestCase', - 'CryptoUtilsKeyczarCompatibilityTestCase' -] +__all__ = ["CryptoUtilsTestCase", "CryptoUtilsKeyczarCompatibilityTestCase"] -KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'keyczar_keys/') +KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "keyczar_keys/") class CryptoUtilsTestCase(TestCase): - @classmethod def setUpClass(cls): super(CryptoUtilsTestCase, cls).setUpClass() CryptoUtilsTestCase.test_crypto_key = AESKey.generate() def test_symmetric_encrypt_decrypt_short_string_needs_to_be_padded(self): - original = u'a' + original = "a" crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original) plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto) self.assertEqual(plain, original) def test_symmetric_encrypt_decrypt_utf8_character(self): values = [ - u'£', - u'£££', - u'££££££', - u'č š hello đ č p ž Ž', - u'hello 💩', - u'💩💩💩💩💩' - u'💩💩💩', - u'💩😁' + "£", + "£££", + "££££££", + "č š hello đ č p ž Ž", + "hello 💩", + "💩💩💩💩💩" "💩💩💩", + "💩😁", ] for index, original in enumerate(values): @@ -81,13 +76,13 @@ def test_symmetric_encrypt_decrypt_utf8_character(self): self.assertEqual(index, (len(values) - 1)) def test_symmetric_encrypt_decrypt(self): - original = 'secret' + original = "secret" crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original) plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto) self.assertEqual(plain, original) def test_encrypt_output_is_diff_due_to_diff_IV(self): - original = 'Kami is a little boy.' + original = "Kami is a little boy." cryptos = set() for _ in range(0, 10000): @@ -97,7 +92,7 @@ def test_encrypt_output_is_diff_due_to_diff_IV(self): def test_decrypt_ciphertext_is_too_short(self): aes_key = AESKey.generate() - plaintext = 'hello world ponies 1' + plaintext = "hello world ponies 1" encrypted = cryptography_symmetric_encrypt(aes_key, plaintext) # Verify original non manipulated value can be decrypted @@ -117,13 +112,18 @@ def test_decrypt_ciphertext_is_too_short(self): encrypted_malformed = binascii.hexlify(encrypted_malformed) # Verify corrupted value results in an excpetion - expected_msg = 'Invalid or malformed ciphertext' - self.assertRaisesRegexp(ValueError, expected_msg, cryptography_symmetric_decrypt, - aes_key, encrypted_malformed) + expected_msg = "Invalid or malformed ciphertext" + self.assertRaisesRegexp( + ValueError, + expected_msg, + cryptography_symmetric_decrypt, + aes_key, + encrypted_malformed, + ) def test_exception_is_thrown_on_invalid_hmac_signature(self): aes_key = AESKey.generate() - plaintext = 'hello world ponies 2' + plaintext = "hello world ponies 2" encrypted = cryptography_symmetric_encrypt(aes_key, plaintext) # Verify original non manipulated value can be decrypted @@ -133,13 +133,18 @@ def test_exception_is_thrown_on_invalid_hmac_signature(self): # Corrupt the HMAC signature (last part is the HMAC signature) encrypted_malformed = binascii.unhexlify(encrypted) encrypted_malformed = encrypted_malformed[:-3] - encrypted_malformed += b'abc' + encrypted_malformed += b"abc" encrypted_malformed = binascii.hexlify(encrypted_malformed) # Verify corrupted value results in an excpetion - expected_msg = 'Signature did not match digest' - self.assertRaisesRegexp(InvalidSignature, expected_msg, cryptography_symmetric_decrypt, - aes_key, encrypted_malformed) + expected_msg = "Signature did not match digest" + self.assertRaisesRegexp( + InvalidSignature, + expected_msg, + cryptography_symmetric_decrypt, + aes_key, + encrypted_malformed, + ) class CryptoUtilsKeyczarCompatibilityTestCase(TestCase): @@ -150,44 +155,69 @@ class CryptoUtilsKeyczarCompatibilityTestCase(TestCase): def test_aes_key_class(self): # 1. Unsupported mode - expected_msg = 'Unsupported mode: EBC' - self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a', - hmac_key_string='b', hmac_key_size=128, mode='EBC') + expected_msg = "Unsupported mode: EBC" + self.assertRaisesRegexp( + ValueError, + expected_msg, + AESKey, + aes_key_string="a", + hmac_key_string="b", + hmac_key_size=128, + mode="EBC", + ) # 2. AES key is too small - expected_msg = 'Unsafe key size: 64' - self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a', - hmac_key_string='b', hmac_key_size=128, mode='CBC', size=64) + expected_msg = "Unsafe key size: 64" + self.assertRaisesRegexp( + ValueError, + expected_msg, + AESKey, + aes_key_string="a", + hmac_key_string="b", + hmac_key_size=128, + mode="CBC", + size=64, + ) def test_loading_keys_from_keyczar_formatted_key_files(self): - key_path = os.path.join(KEY_FIXTURES_PATH, 'one.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "one.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, 'lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc') + self.assertEqual( + aes_key.hmac_key_string, "lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual( + aes_key.aes_key_string, "vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58" + ) + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 256) - key_path = os.path.join(KEY_FIXTURES_PATH, 'two.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "two.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, '92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s') + self.assertEqual( + aes_key.hmac_key_string, "92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual( + aes_key.aes_key_string, "fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs" + ) + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 256) - key_path = os.path.join(KEY_FIXTURES_PATH, 'five.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "five.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, 'GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM') + self.assertEqual( + aes_key.hmac_key_string, "GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'EeBcUcbH14tL0w_fF5siEw') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual(aes_key.aes_key_string, "EeBcUcbH14tL0w_fF5siEw") + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 128) def test_key_generation_file_format_is_fully_keyczar_compatible(self): @@ -197,13 +227,13 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self): json_parsed = json.loads(key_json) expected = { - 'hmacKey': { - 'hmacKeyString': aes_key.hmac_key_string, - 'size': aes_key.hmac_key_size + "hmacKey": { + "hmacKeyString": aes_key.hmac_key_string, + "size": aes_key.hmac_key_size, }, - 'aesKeyString': aes_key.aes_key_string, - 'mode': aes_key.mode, - 'size': aes_key.size + "aesKeyString": aes_key.aes_key_string, + "mode": aes_key.mode, + "size": aes_key.size, } self.assertEqual(json_parsed, expected) @@ -211,15 +241,14 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self): def test_symmetric_encrypt_decrypt_cryptography(self): key = AESKey.generate() plaintexts = [ - 'a b c', - 'ab', - 'hello foo', - 'hell', - 'bar5' - 'hello hello bar bar hello', - 'a', - '', - 'c' + "a b c", + "ab", + "hello foo", + "hell", + "bar5" "hello hello bar bar hello", + "a", + "", + "c", ] for plaintext in plaintexts: @@ -228,13 +257,13 @@ def test_symmetric_encrypt_decrypt_cryptography(self): self.assertEqual(decrypted, plaintext) - @unittest2.skipIf(six.PY3, 'keyczar doesn\'t work under Python 3') + @unittest2.skipIf(six.PY3, "keyczar doesn't work under Python 3") def test_symmetric_encrypt_decrypt_roundtrips_1(self): encrypt_keys = [ AESKey.generate(), AESKey.generate(), AESKey.generate(), - AESKey.generate() + AESKey.generate(), ] # Verify all keys are unique @@ -248,7 +277,7 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self): self.assertEqual(len(aes_key_strings), 4) self.assertEqual(len(hmac_key_strings), 4) - plaintext = 'hello world test dummy 8 9 5 1 bar2' + plaintext = "hello world test dummy 8 9 5 1 bar2" # Verify that round trips work and that cryptography based primitives are fully compatible # with keyczar format @@ -261,14 +290,19 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self): self.assertNotEqual(data_enc_keyczar, data_enc_cryptography) data_dec_keyczar_keyczar = keyczar_symmetric_decrypt(key, data_enc_keyczar) - data_dec_keyczar_cryptography = keyczar_symmetric_decrypt(key, data_enc_cryptography) + data_dec_keyczar_cryptography = keyczar_symmetric_decrypt( + key, data_enc_cryptography + ) self.assertEqual(data_dec_keyczar_keyczar, plaintext) self.assertEqual(data_dec_keyczar_cryptography, plaintext) - data_dec_cryptography_cryptography = cryptography_symmetric_decrypt(key, - data_enc_cryptography) - data_dec_cryptography_keyczar = cryptography_symmetric_decrypt(key, data_enc_keyczar) + data_dec_cryptography_cryptography = cryptography_symmetric_decrypt( + key, data_enc_cryptography + ) + data_dec_cryptography_keyczar = cryptography_symmetric_decrypt( + key, data_enc_keyczar + ) self.assertEqual(data_dec_cryptography_cryptography, plaintext) self.assertEqual(data_dec_cryptography_keyczar, plaintext) diff --git a/st2common/tests/unit/test_datastore.py b/st2common/tests/unit/test_datastore.py index 30d3c7dc76..1e3dc86d30 100644 --- a/st2common/tests/unit/test_datastore.py +++ b/st2common/tests/unit/test_datastore.py @@ -28,12 +28,10 @@ from st2tests import DbTestCase from st2tests import config -__all__ = [ - 'DatastoreServiceTestCase' -] +__all__ = ["DatastoreServiceTestCase"] CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) class DatastoreServiceTestCase(DbTestCase): @@ -41,9 +39,9 @@ def setUp(self): super(DatastoreServiceTestCase, self).setUp() config.parse_args() - self._datastore_service = BaseDatastoreService(logger=mock.Mock(), - pack_name='core', - class_name='TestSensor') + self._datastore_service = BaseDatastoreService( + logger=mock.Mock(), pack_name="core", class_name="TestSensor" + ) self._datastore_service.get_api_client = mock.Mock() def test_datastore_operations_list_values(self): @@ -53,14 +51,14 @@ def test_datastore_operations_list_values(self): self._set_mock_api_client(mock_api_client) self._datastore_service.list_values(local=True, prefix=None) - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:') - self._datastore_service.list_values(local=True, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:ponies') + mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:") + self._datastore_service.list_values(local=True, prefix="ponies") + mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:ponies") self._datastore_service.list_values(local=False, prefix=None) mock_api_client.keys.get_all.assert_called_with(prefix=None) - self._datastore_service.list_values(local=False, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='ponies') + self._datastore_service.list_values(local=False, prefix="ponies") + mock_api_client.keys.get_all.assert_called_with(prefix="ponies") # No values in the datastore mock_api_client = mock.Mock() @@ -74,11 +72,11 @@ def test_datastore_operations_list_values(self): # Values in the datastore kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" kvp2 = KeyValuePair() - kvp2.name = 'test2' - kvp2.value = 'bar' + kvp2.name = "test2" + kvp2.value = "bar" mock_return_value = [kvp1, kvp2] mock_api_client.keys.get_all.return_value = mock_return_value self._set_mock_api_client(mock_api_client) @@ -90,12 +88,12 @@ def test_datastore_operations_list_values(self): def test_datastore_operations_get_value(self): mock_api_client = mock.Mock() kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" mock_api_client.keys.get_by_id.return_value = kvp1 self._set_mock_api_client(mock_api_client) - value = self._datastore_service.get_value(name='test1', local=False) + value = self._datastore_service.get_value(name="test1", local=False) self.assertEqual(value, kvp1.value) def test_datastore_operations_set_value(self): @@ -103,10 +101,12 @@ def test_datastore_operations_set_value(self): mock_api_client.keys.update.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.set_value(name='test1', value='foo', local=False) + value = self._datastore_service.set_value( + name="test1", value="foo", local=False + ) self.assertTrue(value) - kvp = mock_api_client.keys.update.call_args[1]['instance'] - self.assertEqual(kvp.value, 'foo') + kvp = mock_api_client.keys.update.call_args[1]["instance"] + self.assertEqual(kvp.value, "foo") self.assertEqual(kvp.scope, SYSTEM_SCOPE) def test_datastore_operations_delete_value(self): @@ -114,53 +114,69 @@ def test_datastore_operations_delete_value(self): mock_api_client.keys.delete.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.delete_value(name='test', local=False) + value = self._datastore_service.delete_value(name="test", local=False) self.assertTrue(value) def test_datastore_operations_set_encrypted_value(self): mock_api_client = mock.Mock() mock_api_client.keys.update.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.set_value(name='test1', value='foo', local=False, - encrypt=True) + value = self._datastore_service.set_value( + name="test1", value="foo", local=False, encrypt=True + ) self.assertTrue(value) - kvp = mock_api_client.keys.update.call_args[1]['instance'] - self.assertEqual(kvp.value, 'foo') + kvp = mock_api_client.keys.update.call_args[1]["instance"] + self.assertEqual(kvp.value, "foo") self.assertTrue(kvp.secret) self.assertEqual(kvp.scope, SYSTEM_SCOPE) def test_datastore_unsupported_scope(self): - self.assertRaises(ValueError, self._datastore_service.get_value, name='test1', - scope='NOT_SYSTEM') - self.assertRaises(ValueError, self._datastore_service.set_value, name='test1', - value='foo', scope='NOT_SYSTEM') - self.assertRaises(ValueError, self._datastore_service.delete_value, name='test1', - scope='NOT_SYSTEM') + self.assertRaises( + ValueError, + self._datastore_service.get_value, + name="test1", + scope="NOT_SYSTEM", + ) + self.assertRaises( + ValueError, + self._datastore_service.set_value, + name="test1", + value="foo", + scope="NOT_SYSTEM", + ) + self.assertRaises( + ValueError, + self._datastore_service.delete_value, + name="test1", + scope="NOT_SYSTEM", + ) def test_datastore_get_exception(self): mock_api_client = mock.Mock() mock_api_client.keys.get_by_id.side_effect = ValueError("Exception test") self._set_mock_api_client(mock_api_client) - value = self._datastore_service.get_value(name='test1') + value = self._datastore_service.get_value(name="test1") self.assertEqual(value, None) def test_datastore_delete_exception(self): mock_api_client = mock.Mock() mock_api_client.keys.delete.side_effect = ValueError("Exception test") self._set_mock_api_client(mock_api_client) - delete_success = self._datastore_service.delete_value(name='test1') + delete_success = self._datastore_service.delete_value(name="test1") self.assertEqual(delete_success, False) def test_datastore_token_timeout(self): - datastore_service = SensorDatastoreService(logger=mock.Mock(), - pack_name='core', - class_name='TestSensor', - api_username='sensor_service') + datastore_service = SensorDatastoreService( + logger=mock.Mock(), + pack_name="core", + class_name="TestSensor", + api_username="sensor_service", + ) mock_api_client = mock.Mock() kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" mock_api_client.keys.get_by_id.return_value = kvp1 token_expire_time = get_datetime_utc_now() - timedelta(seconds=5) @@ -170,10 +186,9 @@ def test_datastore_token_timeout(self): self._set_mock_api_client(mock_api_client) with mock.patch( - 'st2common.services.datastore.Client', - return_value=mock_api_client + "st2common.services.datastore.Client", return_value=mock_api_client ) as datastore_client: - value = datastore_service.get_value(name='test1', local=False) + value = datastore_service.get_value(name="test1", local=False) self.assertTrue(datastore_client.called) self.assertEqual(value, kvp1.value) self.assertGreater(datastore_service._token_expire, token_expire_time) diff --git a/st2common/tests/unit/test_date_utils.py b/st2common/tests/unit/test_date_utils.py index 1b1d3b465c..d453edb8f7 100644 --- a/st2common/tests/unit/test_date_utils.py +++ b/st2common/tests/unit/test_date_utils.py @@ -25,44 +25,44 @@ class DateUtilsTestCase(unittest2.TestCase): def test_get_datetime_utc_now(self): date = date_utils.get_datetime_utc_now() - self.assertEqual(date.tzinfo.tzname(None), 'UTC') + self.assertEqual(date.tzinfo.tzname(None), "UTC") def test_add_utc_tz(self): dt = datetime.datetime.utcnow() self.assertIsNone(dt.tzinfo) dt = date_utils.add_utc_tz(dt) self.assertIsNotNone(dt.tzinfo) - self.assertEqual(dt.tzinfo.tzname(None), 'UTC') + self.assertEqual(dt.tzinfo.tzname(None), "UTC") def test_convert_to_utc(self): date_without_tz = datetime.datetime.utcnow() self.assertEqual(date_without_tz.tzinfo, None) result = date_utils.convert_to_utc(date_without_tz) - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(result.tzinfo.tzname(None), "UTC") date_with_pdt_tz = datetime.datetime(2015, 10, 28, 10, 0, 0, 0) - date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone('US/Pacific')) - self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), 'US/Pacific') + date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone("US/Pacific")) + self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), "US/Pacific") result = date_utils.convert_to_utc(date_with_pdt_tz) - self.assertEqual(str(result), '2015-10-28 17:53:00+00:00') - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(str(result), "2015-10-28 17:53:00+00:00") + self.assertEqual(result.tzinfo.tzname(None), "UTC") def test_parse(self): - date_str_without_tz = 'January 1st, 2014 10:00:00' + date_str_without_tz = "January 1st, 2014 10:00:00" result = date_utils.parse(value=date_str_without_tz) - self.assertEqual(str(result), '2014-01-01 10:00:00+00:00') - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(str(result), "2014-01-01 10:00:00+00:00") + self.assertEqual(result.tzinfo.tzname(None), "UTC") # preserve original tz - date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00' + date_str_with_tz = "January 1st, 2014 10:00:00 +07:00" result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=True) - self.assertEqual(str(result), '2014-01-01 10:00:00+07:00') + self.assertEqual(str(result), "2014-01-01 10:00:00+07:00") self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=7)) # convert to utc - date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00' + date_str_with_tz = "January 1st, 2014 10:00:00 +07:00" result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=False) - self.assertEqual(str(result), '2014-01-01 03:00:00+00:00') + self.assertEqual(str(result), "2014-01-01 03:00:00+00:00") self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=0)) - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(result.tzinfo.tzname(None), "UTC") diff --git a/st2common/tests/unit/test_db.py b/st2common/tests/unit/test_db.py index 756c0a105e..da0157127e 100644 --- a/st2common/tests/unit/test_db.py +++ b/st2common/tests/unit/test_db.py @@ -18,6 +18,7 @@ # NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. # See https://github.com/StackStorm/st2/pull/4834 for details from st2common.util.monkey_patch import monkey_patch + monkey_patch() import ssl @@ -52,47 +53,50 @@ __all__ = [ - 'DbConnectionTestCase', - 'DbConnectionTestCase', - 'ReactorModelTestCase', - 'ActionModelTestCase', - 'KeyValuePairModelTestCase' + "DbConnectionTestCase", + "DbConnectionTestCase", + "ReactorModelTestCase", + "ActionModelTestCase", + "KeyValuePairModelTestCase", ] SKIP_DELETE = False -DUMMY_DESCRIPTION = 'Sample Description.' +DUMMY_DESCRIPTION = "Sample Description." class DbIndexNameTestCase(TestCase): """ Test which verifies that model index name are not longer than the specified limit. """ + LIMIT = 65 def test_index_name_length(self): - db_name = 'st2' + db_name = "st2" for model in ALL_MODELS: collection_name = model._get_collection_name() - model_indexes = model._meta['index_specs'] + model_indexes = model._meta["index_specs"] for index_specs in model_indexes: - index_name = index_specs.get('name', None) + index_name = index_specs.get("name", None) if index_name: # Custom index name defined by the developer index_field_name = index_name else: # No explicit index name specified, one is auto-generated using # .. schema - index_fields = dict(index_specs['fields']).keys() - index_field_name = '.'.join(index_fields) + index_fields = dict(index_specs["fields"]).keys() + index_field_name = ".".join(index_fields) - index_name = '%s.%s.%s' % (db_name, collection_name, index_field_name) + index_name = "%s.%s.%s" % (db_name, collection_name, index_field_name) if len(index_name) > self.LIMIT: - self.fail('Index name "%s" for model "%s" is longer than %s characters. ' - 'Please manually define name for this index so it\'s shorter than ' - 'that' % (index_name, model.__name__, self.LIMIT)) + self.fail( + 'Index name "%s" for model "%s" is longer than %s characters. ' + "Please manually define name for this index so it's shorter than " + "that" % (index_name, model.__name__, self.LIMIT) + ) class DbConnectionTestCase(DbTestCase): @@ -111,210 +115,293 @@ def test_check_connect(self): """ client = mongoengine.connection.get_connection() - expected_str = "host=['%s:%s']" % (cfg.CONF.database.host, cfg.CONF.database.port) - self.assertIn(expected_str, str(client), 'Not connected to desired host.') + expected_str = "host=['%s:%s']" % ( + cfg.CONF.database.host, + cfg.CONF.database.port, + ) + self.assertIn(expected_str, str(client), "Not connected to desired host.") def test_get_ssl_kwargs(self): # 1. No SSL kwargs provided ssl_kwargs = _get_ssl_kwargs() - self.assertEqual(ssl_kwargs, {'ssl': False}) + self.assertEqual(ssl_kwargs, {"ssl": False}) # 2. ssl kwarg provided ssl_kwargs = _get_ssl_kwargs(ssl=True) - self.assertEqual(ssl_kwargs, {'ssl': True, 'ssl_match_hostname': True}) + self.assertEqual(ssl_kwargs, {"ssl": True, "ssl_match_hostname": True}) # 2. authentication_mechanism kwarg provided - ssl_kwargs = _get_ssl_kwargs(authentication_mechanism='MONGODB-X509') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_match_hostname': True, - 'authentication_mechanism': 'MONGODB-X509' - }) + ssl_kwargs = _get_ssl_kwargs(authentication_mechanism="MONGODB-X509") + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_match_hostname": True, + "authentication_mechanism": "MONGODB-X509", + }, + ) # 3. ssl_keyfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_keyfile': '/tmp/keyfile', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_keyfile": "/tmp/keyfile", "ssl_match_hostname": True}, + ) # 4. ssl_certfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_certfile': '/tmp/certfile', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_certfile": "/tmp/certfile", "ssl_match_hostname": True}, + ) # 5. ssl_ca_certs provided - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_ca_certs": "/tmp/ca_certs", "ssl_match_hostname": True}, + ) # 6. ssl_ca_certs and ssl_cert_reqs combinations - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_NONE, - 'ssl_match_hostname': True - }) - - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_OPTIONAL, - 'ssl_match_hostname': True - }) - - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_REQUIRED, - 'ssl_match_hostname': True - }) - - @mock.patch('st2common.models.db.mongoengine') + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none") + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_NONE, + "ssl_match_hostname": True, + }, + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional" + ) + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_OPTIONAL, + "ssl_match_hostname": True, + }, + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required" + ) + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_REQUIRED, + "ssl_match_hostname": True, + }, + ) + + @mock.patch("st2common.models.db.mongoengine") def test_db_setup(self, mock_mongoengine): - db_setup(db_name='name', db_host='host', db_port=12345, username='username', - password='password', authentication_mechanism='MONGODB-X509') + db_setup( + db_name="name", + db_host="host", + db_port=12345, + username="username", + password="password", + authentication_mechanism="MONGODB-X509", + ) call_args = mock_mongoengine.connection.connect.call_args_list[0][0] call_kwargs = mock_mongoengine.connection.connect.call_args_list[0][1] - self.assertEqual(call_args, ('name',)) - self.assertEqual(call_kwargs, { - 'host': 'host', - 'port': 12345, - 'username': 'username', - 'password': 'password', - 'tz_aware': True, - 'authentication_mechanism': 'MONGODB-X509', - 'ssl': True, - 'ssl_match_hostname': True, - 'connectTimeoutMS': 3000, - 'serverSelectionTimeoutMS': 3000 - }) - - @mock.patch('st2common.models.db.mongoengine') - @mock.patch('st2common.models.db.LOG') + self.assertEqual(call_args, ("name",)) + self.assertEqual( + call_kwargs, + { + "host": "host", + "port": 12345, + "username": "username", + "password": "password", + "tz_aware": True, + "authentication_mechanism": "MONGODB-X509", + "ssl": True, + "ssl_match_hostname": True, + "connectTimeoutMS": 3000, + "serverSelectionTimeoutMS": 3000, + }, + ) + + @mock.patch("st2common.models.db.mongoengine") + @mock.patch("st2common.models.db.LOG") def test_db_setup_connecting_info_logging(self, mock_log, mock_mongoengine): # Verify that password is not included in the log message - db_name = 'st2' - db_port = '27017' - username = 'user_st2' - password = 'pass_st2' + db_name = "st2" + db_port = "27017" + username = "user_st2" + password = "pass_st2" # 1. Password provided as separate argument - db_host = 'localhost' - username = 'user_st2' - password = 'pass_st2' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".' + db_host = "localhost" + username = "user_st2" + password = "pass_st2" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".' + ) actual_message = mock_log.info.call_args_list[0][0][0] self.assertEqual(expected_message, actual_message) # Check for helpful error messages if the connection is successful - expected_log_message = ('Successfully connected to database "st2" @ "localhost:27017" as ' - 'user "user_st2".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "localhost:27017" as ' + 'user "user_st2".' + ) actual_log_message = mock_log.info.call_args_list[1][0][0] self.assertEqual(expected_log_message, actual_log_message) # 2. Password provided as part of uri string (single host) - db_host = 'mongodb://user_st22:pass_st22@127.0.0.2:5555' + db_host = "mongodb://user_st22:pass_st22@127.0.0.2:5555" username = None password = None - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".' + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".' + ) actual_message = mock_log.info.call_args_list[2][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st22".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st22".' + ) actual_log_message = mock_log.info.call_args_list[3][0][0] self.assertEqual(expected_log_message, actual_log_message) # 3. Password provided as part of uri string (single host) - username # provided as argument has precedence - db_host = 'mongodb://user_st210:pass_st23@127.0.0.2:5555' - username = 'user_st23' + db_host = "mongodb://user_st210:pass_st23@127.0.0.2:5555" + username = "user_st23" password = None - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".' + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".' + ) actual_message = mock_log.info.call_args_list[4][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st23".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st23".' + ) actual_log_message = mock_log.info.call_args_list[5][0][0] self.assertEqual(expected_log_message, actual_log_message) # 4. Just host provided in the url string - db_host = 'mongodb://127.0.0.2:5555' - username = 'user_st24' - password = 'foobar' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".' + db_host = "mongodb://127.0.0.2:5555" + username = "user_st24" + password = "foobar" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".' + ) actual_message = mock_log.info.call_args_list[6][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st24".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st24".' + ) actual_log_message = mock_log.info.call_args_list[7][0][0] self.assertEqual(expected_log_message, actual_log_message) # 5. Multiple hosts specified as part of connection uri - db_host = 'mongodb://user6:pass6@host1,host2,host3' + db_host = "mongodb://user6:pass6@host1,host2,host3" username = None - password = 'foobar' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = ('Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 ' - '(replica set)" as user "user6".') + password = "foobar" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 ' + '(replica set)" as user "user6".' + ) actual_message = mock_log.info.call_args_list[8][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ ' - '"host1:27017,host2:27017,host3:27017 ' - '(replica set)" as user "user6".') + expected_log_message = ( + 'Successfully connected to database "st2" @ ' + '"host1:27017,host2:27017,host3:27017 ' + '(replica set)" as user "user6".' + ) actual_log_message = mock_log.info.call_args_list[9][0][0] self.assertEqual(expected_log_message, actual_log_message) # 6. Check for error message when failing to establish a connection mock_connect = mock.Mock() - mock_connect.admin.command = mock.Mock(side_effect=ConnectionFailure('Failed to connect')) + mock_connect.admin.command = mock.Mock( + side_effect=ConnectionFailure("Failed to connect") + ) mock_mongoengine.connection.connect.return_value = mock_connect - db_host = 'mongodb://localhost:9797' - username = 'user_st2' - password = 'pass_st2' - - expected_msg = 'Failed to connect' - self.assertRaisesRegexp(ConnectionFailure, expected_msg, db_setup, - db_name=db_name, db_host=db_host, db_port=db_port, - username=username, password=password) - - expected_message = 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".' + db_host = "mongodb://localhost:9797" + username = "user_st2" + password = "pass_st2" + + expected_msg = "Failed to connect" + self.assertRaisesRegexp( + ConnectionFailure, + expected_msg, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".' + ) actual_message = mock_log.info.call_args_list[10][0][0] self.assertEqual(expected_message, actual_message) - expected_message = ('Failed to connect to database "st2" @ "localhost:9797" as user ' - '"user_st2": Failed to connect') + expected_message = ( + 'Failed to connect to database "st2" @ "localhost:9797" as user ' + '"user_st2": Failed to connect' + ) actual_message = mock_log.error.call_args_list[0][0][0] self.assertEqual(expected_message, actual_message) @@ -323,29 +410,43 @@ def test_db_connect_server_selection_timeout_ssl_on_non_ssl_listener(self): # and propagating the error disconnect() - db_name = 'st2' - db_host = 'localhost' + db_name = "st2" + db_host = "localhost" db_port = 27017 - cfg.CONF.set_override(name='connection_timeout', group='database', override=1000) + cfg.CONF.set_override( + name="connection_timeout", group="database", override=1000 + ) start = time.time() - self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host, - db_port=db_port, ssl=True) + self.assertRaises( + ServerSelectionTimeoutError, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + ssl=True, + ) end = time.time() - diff = (end - start) + diff = end - start self.assertTrue(diff >= 1) disconnect() - cfg.CONF.set_override(name='connection_timeout', group='database', override=400) + cfg.CONF.set_override(name="connection_timeout", group="database", override=400) start = time.time() - self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host, - db_port=db_port, ssl=True) + self.assertRaises( + ServerSelectionTimeoutError, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + ssl=True, + ) end = time.time() - diff = (end - start) + diff = end - start self.assertTrue(diff >= 0.4) @@ -364,60 +465,63 @@ def test_cleanup(self): self.assertNotIn(cfg.CONF.database.db_name, connection.database_names()) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ReactorModelTestCase(DbTestCase): - def test_triggertype_crud(self): saved = ReactorModelTestCase._create_save_triggertype() retrieved = TriggerType.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same triggertype was not returned.') + self.assertEqual( + saved.name, retrieved.name, "Same triggertype was not returned." + ) # test update - self.assertEqual(retrieved.description, '') + self.assertEqual(retrieved.description, "") retrieved.description = DUMMY_DESCRIPTION saved = TriggerType.add_or_update(retrieved) retrieved = TriggerType.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed." + ) # cleanup ReactorModelTestCase._delete([retrieved]) try: retrieved = TriggerType.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_trigger_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() saved = ReactorModelTestCase._create_save_trigger(triggertype) retrieved = Trigger.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same trigger was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same trigger was not returned.") # test update - self.assertEqual(retrieved.description, '') + self.assertEqual(retrieved.description, "") retrieved.description = DUMMY_DESCRIPTION saved = Trigger.add_or_update(retrieved) retrieved = Trigger.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed." + ) # cleanup ReactorModelTestCase._delete([retrieved, triggertype]) try: retrieved = Trigger.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_triggerinstance_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() trigger = ReactorModelTestCase._create_save_trigger(triggertype) saved = ReactorModelTestCase._create_save_triggerinstance(trigger) retrieved = TriggerInstance.get_by_id(saved.id) - self.assertIsNotNone(retrieved, 'No triggerinstance created.') + self.assertIsNotNone(retrieved, "No triggerinstance created.") ReactorModelTestCase._delete([retrieved, trigger, triggertype]) try: retrieved = TriggerInstance.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_rule_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() @@ -426,20 +530,22 @@ def test_rule_crud(self): action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) retrieved = Rule.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, 'Same rule was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same rule was not returned.") # test update self.assertEqual(retrieved.enabled, True) retrieved.enabled = False saved = Rule.add_or_update(retrieved) retrieved = Rule.get_by_id(saved.id) - self.assertEqual(retrieved.enabled, False, 'Update to rule failed.') + self.assertEqual(retrieved.enabled, False, "Update to rule failed.") # cleanup - ReactorModelTestCase._delete([retrieved, trigger, action, runnertype, triggertype]) + ReactorModelTestCase._delete( + [retrieved, trigger, action, runnertype, triggertype] + ) try: retrieved = Rule.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_rule_lookup(self): triggertype = ReactorModelTestCase._create_save_triggertype() @@ -447,10 +553,12 @@ def test_rule_lookup(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger)) - self.assertEqual(1, len(retrievedrules), 'No rules found.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger) + ) + self.assertEqual(1, len(retrievedrules), "No rules found.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_rule_lookup_enabled(self): @@ -459,12 +567,12 @@ def test_rule_lookup_enabled(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger), - enabled=True) - self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger), enabled=True + ) + self.assertEqual(1, len(retrievedrules), "Error looking up enabled rules.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, - 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_rule_lookup_disabled(self): @@ -473,49 +581,64 @@ def test_rule_lookup_disabled(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action, False) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger), - enabled=False) - self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger), enabled=False + ) + self.assertEqual(1, len(retrievedrules), "Error looking up enabled rules.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_trigger_lookup(self): triggertype = ReactorModelTestCase._create_save_triggertype() saved = ReactorModelTestCase._create_save_trigger(triggertype) retrievedtriggers = Trigger.query(name=saved.name) - self.assertEqual(1, len(retrievedtriggers), 'No triggers found.') + self.assertEqual(1, len(retrievedtriggers), "No triggers found.") for retrievedtrigger in retrievedtriggers: - self.assertEqual(saved.id, retrievedtrigger.id, - 'Incorrect trigger returned.') + self.assertEqual( + saved.id, retrievedtrigger.id, "Incorrect trigger returned." + ) ReactorModelTestCase._delete([saved, triggertype]) @staticmethod def _create_save_triggertype(): - created = TriggerTypeDB(pack='dummy_pack_1', name='triggertype-1', description='', - payload_schema={}, parameters_schema={}) + created = TriggerTypeDB( + pack="dummy_pack_1", + name="triggertype-1", + description="", + payload_schema={}, + parameters_schema={}, + ) return Trigger.add_or_update(created) @staticmethod def _create_save_trigger(triggertype): - created = TriggerDB(pack='dummy_pack_1', name='trigger-1', description='', - type=triggertype.get_reference().ref, parameters={}) + created = TriggerDB( + pack="dummy_pack_1", + name="trigger-1", + description="", + type=triggertype.get_reference().ref, + parameters={}, + ) return Trigger.add_or_update(created) @staticmethod def _create_save_triggerinstance(trigger): - created = TriggerInstanceDB(trigger=trigger.get_reference().ref, payload={}, - occurrence_time=date_utils.get_datetime_utc_now(), - status=TRIGGER_INSTANCE_PROCESSED) + created = TriggerInstanceDB( + trigger=trigger.get_reference().ref, + payload={}, + occurrence_time=date_utils.get_datetime_utc_now(), + status=TRIGGER_INSTANCE_PROCESSED, + ) return TriggerInstance.add_or_update(created) @staticmethod def _create_save_rule(trigger, action=None, enabled=True): - name = 'rule-1' - pack = 'default' + name = "rule-1" + pack = "default" ref = ResourceReference.to_string_reference(name=name, pack=pack) created = RuleDB(name=name, pack=pack, ref=ref) - created.description = '' + created.description = "" created.enabled = enabled created.trigger = reference.get_str_resource_ref_from_model(trigger) created.criteria = {} @@ -547,44 +670,21 @@ def _delete(model_objects): "description": "awesomeness", "type": "object", "properties": { - "r1": { - "type": "object", - "properties": { - "r1a": { - "type": "string" - } - } - }, - "r2": { - "type": "string", - "required": True - }, - "p1": { - "type": "string", - "required": True - }, - "p2": { - "type": "number", - "default": 2868 - }, - "p3": { - "type": "boolean", - "default": False - }, - "p4": { - "type": "string", - "secret": True - } + "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}}, + "r2": {"type": "string", "required": True}, + "p1": {"type": "string", "required": True}, + "p2": {"type": "number", "default": 2868}, + "p3": {"type": "boolean", "default": False}, + "p4": {"type": "string", "secret": True}, }, - "additionalProperties": False + "additionalProperties": False, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionModelTestCase(DbTestCase): - def tearDown(self): - runnertype = RunnerType.get_by_name('python') + runnertype = RunnerType.get_by_name("python") self._delete([runnertype]) super(ActionModelTestCase, self).tearDown() @@ -592,15 +692,16 @@ def test_action_crud(self): runnertype = self._create_save_runnertype(metadata=False) saved = self._create_save_action(runnertype, metadata=False) retrieved = Action.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same Action was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same Action was not returned.") # test update - self.assertEqual(retrieved.description, 'awesomeness') + self.assertEqual(retrieved.description, "awesomeness") retrieved.description = DUMMY_DESCRIPTION saved = Action.add_or_update(retrieved) retrieved = Action.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to action failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to action failed." + ) # cleanup self._delete([retrieved]) @@ -608,14 +709,14 @@ def test_action_crud(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_action_with_notify_crud(self): runnertype = self._create_save_runnertype(metadata=False) saved = self._create_save_action(runnertype, metadata=False) # Update action with notification settings - on_complete = NotificationSubSchema(message='Action complete.') + on_complete = NotificationSubSchema(message="Action complete.") saved.notify = NotificationSchema(on_complete=on_complete) saved = Action.add_or_update(saved) @@ -635,7 +736,7 @@ def test_action_with_notify_crud(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_parameter_schema(self): runnertype = self._create_save_runnertype(metadata=True) @@ -650,13 +751,30 @@ def test_parameter_schema(self): # use schema to validate parameters jsonschema.validate({"r2": "abc", "p1": "def"}, schema, validator) - jsonschema.validate({"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - '{"r2": "abc", "p1": "def"}', schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - {"r2": "abc"}, schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - {"r2": "abc", "p1": "def", "r1": 123}, schema, validator) + jsonschema.validate( + {"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + '{"r2": "abc", "p1": "def"}', + schema, + validator, + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + {"r2": "abc"}, + schema, + validator, + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + {"r2": "abc", "p1": "def", "r1": 123}, + schema, + validator, + ) # cleanup self._delete([retrieved]) @@ -664,7 +782,7 @@ def test_parameter_schema(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(self): # Test that the runner and action parameters are correctly deep merged when building @@ -673,54 +791,55 @@ def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(sel self._create_save_runnertype(metadata=True) action_db = mock.Mock() - action_db.runner_type = {'name': 'python'} - action_db.parameters = {'r1': {'immutable': True}} + action_db.runner_type = {"name": "python"} + action_db.parameters = {"r1": {"immutable": True}} schema = util_schema.get_schema_for_action_parameters(action_db=action_db) expected = { - u'type': u'object', - u'properties': { - u'r1a': { - u'type': u'string' - } - }, - 'immutable': True + "type": "object", + "properties": {"r1a": {"type": "string"}}, + "immutable": True, } - self.assertEqual(schema['properties']['r1'], expected) + self.assertEqual(schema["properties"]["r1"], expected) @staticmethod def _create_save_runnertype(metadata=False): - created = RunnerTypeDB(name='python') - created.description = '' + created = RunnerTypeDB(name="python") + created.description = "" created.enabled = True if not metadata: - created.runner_parameters = {'r1': None, 'r2': None} + created.runner_parameters = {"r1": None, "r2": None} else: created.runner_parameters = { - 'r1': {'type': 'object', 'properties': {'r1a': {'type': 'string'}}}, - 'r2': {'type': 'string', 'required': True} + "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}}, + "r2": {"type": "string", "required": True}, } - created.runner_module = 'nomodule' + created.runner_module = "nomodule" return RunnerType.add_or_update(created) @staticmethod def _create_save_action(runnertype, metadata=False): - name = 'action-1' - pack = 'wolfpack' + name = "action-1" + pack = "wolfpack" ref = ResourceReference(pack=pack, name=name).ref - created = ActionDB(name=name, description='awesomeness', enabled=True, - entry_point='/tmp/action.py', pack=pack, - ref=ref, - runner_type={'name': runnertype.name}) + created = ActionDB( + name=name, + description="awesomeness", + enabled=True, + entry_point="/tmp/action.py", + pack=pack, + ref=ref, + runner_type={"name": runnertype.name}, + ) if not metadata: - created.parameters = {'p1': None, 'p2': None, 'p3': None, 'p4': None} + created.parameters = {"p1": None, "p2": None, "p3": None, "p4": None} else: created.parameters = { - 'p1': {'type': 'string', 'required': True}, - 'p2': {'type': 'number', 'default': 2868}, - 'p3': {'type': 'boolean', 'default': False}, - 'p4': {'type': 'string', 'secret': True} + "p1": {"type": "string", "required": True}, + "p2": {"type": "number", "default": 2868}, + "p3": {"type": "boolean", "default": False}, + "p4": {"type": "string", "secret": True}, } return Action.add_or_update(created) @@ -738,20 +857,19 @@ def _delete(model_objects): class KeyValuePairModelTestCase(DbTestCase): - def test_kvp_crud(self): saved = KeyValuePairModelTestCase._create_save_kvp() retrieved = KeyValuePair.get_by_name(saved.name) - self.assertEqual(saved.id, retrieved.id, - 'Same KeyValuePair was not returned.') + self.assertEqual(saved.id, retrieved.id, "Same KeyValuePair was not returned.") # test update - self.assertEqual(retrieved.value, '0123456789ABCDEF') - retrieved.value = 'ABCDEF0123456789' + self.assertEqual(retrieved.value, "0123456789ABCDEF") + retrieved.value = "ABCDEF0123456789" saved = KeyValuePair.add_or_update(retrieved) retrieved = KeyValuePair.get_by_name(saved.name) - self.assertEqual(retrieved.value, 'ABCDEF0123456789', - 'Update of key value failed') + self.assertEqual( + retrieved.value, "ABCDEF0123456789", "Update of key value failed" + ) # cleanup KeyValuePairModelTestCase._delete([retrieved]) @@ -759,11 +877,11 @@ def test_kvp_crud(self): retrieved = KeyValuePair.get_by_name(saved.name) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_kvp(): - created = KeyValuePairDB(name='token', value='0123456789ABCDEF') + created = KeyValuePairDB(name="token", value="0123456789ABCDEF") return KeyValuePair.add_or_update(created) @staticmethod diff --git a/st2common/tests/unit/test_db_action_state.py b/st2common/tests/unit/test_db_action_state.py index 3251898e29..47b9d170bd 100644 --- a/st2common/tests/unit/test_db_action_state.py +++ b/st2common/tests/unit/test_db_action_state.py @@ -34,13 +34,13 @@ def test_state_crud(self): retrieved = ActionExecutionState.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_actionstate(): created = ActionExecutionStateDB() - created.query_context = {'id': 'some_external_service_id'} - created.query_module = 'dummy.modules.query1' + created.query_context = {"id": "some_external_service_id"} + created.query_module = "dummy.modules.query1" created.execution_id = bson.ObjectId() return ActionExecutionState.add_or_update(created) diff --git a/st2common/tests/unit/test_db_auth.py b/st2common/tests/unit/test_db_auth.py index 9cf35bc737..b159580505 100644 --- a/st2common/tests/unit/test_db_auth.py +++ b/st2common/tests/unit/test_db_auth.py @@ -26,44 +26,35 @@ from tests.unit.base import BaseDBModelCRUDTestCase -__all__ = [ - 'UserDBModelCRUDTestCase' -] +__all__ = ["UserDBModelCRUDTestCase"] class UserDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = UserDB persistance_class = User model_class_kwargs = { - 'name': 'pony', - 'is_service': False, - 'nicknames': { - 'pony1': 'ponyA' - } + "name": "pony", + "is_service": False, + "nicknames": {"pony1": "ponyA"}, } - update_attribute_name = 'name' + update_attribute_name = "name" class TokenDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = TokenDB persistance_class = Token model_class_kwargs = { - 'user': 'pony', - 'token': 'token-token-token-token', - 'expiry': get_datetime_utc_now(), - 'metadata': { - 'service': 'action-runner' - } + "user": "pony", + "token": "token-token-token-token", + "expiry": get_datetime_utc_now(), + "metadata": {"service": "action-runner"}, } - skip_check_attribute_names = ['expiry'] - update_attribute_name = 'user' + skip_check_attribute_names = ["expiry"] + update_attribute_name = "user" class ApiKeyDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = ApiKeyDB persistance_class = ApiKey - model_class_kwargs = { - 'user': 'pony', - 'key_hash': 'token-token-token-token' - } - update_attribute_name = 'user' + model_class_kwargs = {"user": "pony", "key_hash": "token-token-token-token"} + update_attribute_name = "user" diff --git a/st2common/tests/unit/test_db_base.py b/st2common/tests/unit/test_db_base.py index 0c77c336bf..6849643243 100644 --- a/st2common/tests/unit/test_db_base.py +++ b/st2common/tests/unit/test_db_base.py @@ -27,11 +27,11 @@ class FakeRuleSpecDB(mongoengine.EmbeddedDocument): def __str__(self): result = [] - result.append('ActionExecutionSpecDB@') - result.append('test') + result.append("ActionExecutionSpecDB@") + result.append("test") result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) class FakeModel(stormbase.StormBaseDB): @@ -52,30 +52,43 @@ class FakeRuleModel(stormbase.StormBaseDB): class TestBaseModel(DbTestCase): - def test_print(self): - instance = FakeModel(name='seesaw', boolean_field=True, - datetime_field=date_utils.get_datetime_utc_now(), - description=u'fun!', dict_field={'a': 1}, - integer_field=68, list_field=['abc']) - - expected = ('FakeModel(boolean_field=True, datetime_field="%s", description="fun!", ' - 'dict_field={\'a\': 1}, id=None, integer_field=68, list_field=[\'abc\'], ' - 'name="seesaw")' % str(instance.datetime_field)) + instance = FakeModel( + name="seesaw", + boolean_field=True, + datetime_field=date_utils.get_datetime_utc_now(), + description="fun!", + dict_field={"a": 1}, + integer_field=68, + list_field=["abc"], + ) + + expected = ( + 'FakeModel(boolean_field=True, datetime_field="%s", description="fun!", ' + "dict_field={'a': 1}, id=None, integer_field=68, list_field=['abc'], " + 'name="seesaw")' % str(instance.datetime_field) + ) self.assertEqual(str(instance), expected) def test_rule_print(self): - instance = FakeRuleModel(name='seesaw', boolean_field=True, - datetime_field=date_utils.get_datetime_utc_now(), - description=u'fun!', dict_field={'a': 1}, - integer_field=68, list_field=['abc'], - embedded_doc_field={'ref': '1234', 'parameters': {'b': 2}}) - - expected = ('FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", ' - 'dict_field={\'a\': 1}, embedded_doc_field=ActionExecutionSpecDB@test(' - 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, ' - 'list_field=[\'abc\'], ' - 'name="seesaw")' % str(instance.datetime_field)) + instance = FakeRuleModel( + name="seesaw", + boolean_field=True, + datetime_field=date_utils.get_datetime_utc_now(), + description="fun!", + dict_field={"a": 1}, + integer_field=68, + list_field=["abc"], + embedded_doc_field={"ref": "1234", "parameters": {"b": 2}}, + ) + + expected = ( + 'FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", ' + "dict_field={'a': 1}, embedded_doc_field=ActionExecutionSpecDB@test(" + 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, ' + "list_field=['abc'], " + 'name="seesaw")' % str(instance.datetime_field) + ) self.assertEqual(str(instance), expected) diff --git a/st2common/tests/unit/test_db_execution.py b/st2common/tests/unit/test_db_execution.py index 62478ee13e..e94ccb3d94 100644 --- a/st2common/tests/unit/test_db_execution.py +++ b/st2common/tests/unit/test_db_execution.py @@ -27,79 +27,71 @@ INQUIRY_RESULT = { - 'users': [], - 'roles': [], - 'route': 'developers', - 'ttl': 1440, - 'response': { - 'secondfactor': 'supersecretvalue' - }, - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': 'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "users": [], + "roles": [], + "route": "developers", + "ttl": 1440, + "response": {"secondfactor": "supersecretvalue"}, + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } + }, + }, } INQUIRY_LIVEACTION = { - 'parameters': { - 'route': 'developers', - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': u'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "parameters": { + "route": "developers", + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } + }, + }, }, - 'action': 'core.ask' + "action": "core.ask", } RESPOND_LIVEACTION = { - 'parameters': { - 'response': { - 'secondfactor': 'omgsupersecret', + "parameters": { + "response": { + "secondfactor": "omgsupersecret", } }, - 'action': 'st2.inquiry.respond' + "action": "st2.inquiry.respond", } ACTIONEXECUTIONS = { "execution_1": { - 'action': {'uid': 'action:core:ask'}, - 'status': 'succeeded', - 'runner': {'name': 'inquirer'}, - 'liveaction': INQUIRY_LIVEACTION, - 'result': INQUIRY_RESULT + "action": {"uid": "action:core:ask"}, + "status": "succeeded", + "runner": {"name": "inquirer"}, + "liveaction": INQUIRY_LIVEACTION, + "result": INQUIRY_RESULT, }, "execution_2": { - 'action': {'uid': 'action:st2:inquiry.respond'}, - 'status': 'succeeded', - 'runner': {'name': 'python-script'}, - 'liveaction': RESPOND_LIVEACTION, - 'result': { - 'exit_code': 0, - 'result': None, - 'stderr': '', - 'stdout': '' - } - } + "action": {"uid": "action:st2:inquiry.respond"}, + "status": "succeeded", + "runner": {"name": "python-script"}, + "liveaction": RESPOND_LIVEACTION, + "result": {"exit_code": 0, "result": None, "stderr": "", "stdout": ""}, + }, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionExecutionModelTest(DbTestCase): - def setUp(self): self.executions = {} @@ -107,16 +99,17 @@ def setUp(self): for name, execution in ACTIONEXECUTIONS.items(): created = ActionExecutionDB() - created.action = execution['action'] - created.status = execution['status'] - created.runner = execution['runner'] - created.liveaction = execution['liveaction'] - created.result = execution['result'] + created.action = execution["action"] + created.status = execution["status"] + created.runner = execution["runner"] + created.liveaction = execution["liveaction"] + created.result = execution["result"] saved = ActionExecutionModelTest._save_execution(created) retrieved = ActionExecution.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same action was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same action was not returned." + ) self.executions[name] = retrieved @@ -128,15 +121,16 @@ def tearDown(self): retrieved = ActionExecution.get_by_id(execution.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_update_execution(self): - """Test ActionExecutionDb update - """ - self.assertIsNone(self.executions['execution_1'].end_timestamp) - self.executions['execution_1'].end_timestamp = date_utils.get_datetime_utc_now() - updated = ActionExecution.add_or_update(self.executions['execution_1']) - self.assertTrue(updated.end_timestamp == self.executions['execution_1'].end_timestamp) + """Test ActionExecutionDb update""" + self.assertIsNone(self.executions["execution_1"].end_timestamp) + self.executions["execution_1"].end_timestamp = date_utils.get_datetime_utc_now() + updated = ActionExecution.add_or_update(self.executions["execution_1"]) + self.assertTrue( + updated.end_timestamp == self.executions["execution_1"].end_timestamp + ) def test_execution_inquiry_secrets(self): """Corner case test for Inquiry responses that contain secrets. @@ -148,13 +142,15 @@ def test_execution_inquiry_secrets(self): """ # Test Inquiry response masking is done properly within this model - masked = self.executions['execution_1'].mask_secrets( - self.executions['execution_1'].to_serializable_dict() + masked = self.executions["execution_1"].mask_secrets( + self.executions["execution_1"].to_serializable_dict() + ) + self.assertEqual( + masked["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE ) - self.assertEqual(masked['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE) self.assertEqual( - self.executions['execution_1'].result['response']['secondfactor'], - "supersecretvalue" + self.executions["execution_1"].result["response"]["secondfactor"], + "supersecretvalue", ) def test_execution_inquiry_response_action(self): @@ -164,10 +160,10 @@ def test_execution_inquiry_response_action(self): so we mask all response values. This test ensures this happens. """ - masked = self.executions['execution_2'].mask_secrets( - self.executions['execution_2'].to_serializable_dict() + masked = self.executions["execution_2"].mask_secrets( + self.executions["execution_2"].to_serializable_dict() ) - for value in masked['parameters']['response'].values(): + for value in masked["parameters"]["response"].values(): self.assertEqual(value, MASKED_ATTRIBUTE_VALUE) @staticmethod diff --git a/st2common/tests/unit/test_db_fields.py b/st2common/tests/unit/test_db_fields.py index eceb70d4c0..86fd3bc6fb 100644 --- a/st2common/tests/unit/test_db_fields.py +++ b/st2common/tests/unit/test_db_fields.py @@ -37,12 +37,12 @@ def test_round_trip_conversion(self): datetime_values = [ datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500), datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=0), - datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999) + datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999), ] datetime_values = [ date_utils.add_utc_tz(datetime_values[0]), date_utils.add_utc_tz(datetime_values[1]), - date_utils.add_utc_tz(datetime_values[2]) + date_utils.add_utc_tz(datetime_values[2]), ] microsecond_values = [] @@ -69,7 +69,7 @@ def test_round_trip_conversion(self): expected_value = datetime_values[index] self.assertEqual(actual_value, expected_value) - @mock.patch('st2common.fields.LongField.__get__') + @mock.patch("st2common.fields.LongField.__get__") def test_get_(self, mock_get): field = ComplexDateTimeField() @@ -79,7 +79,9 @@ def test_get_(self, mock_get): # Already a datetime mock_get.return_value = date_utils.get_datetime_utc_now() - self.assertEqual(field.__get__(instance=None, owner=None), mock_get.return_value) + self.assertEqual( + field.__get__(instance=None, owner=None), mock_get.return_value + ) # Microseconds dt = datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500) diff --git a/st2common/tests/unit/test_db_liveaction.py b/st2common/tests/unit/test_db_liveaction.py index 7c8b6aa35f..605aa759f6 100644 --- a/st2common/tests/unit/test_db_liveaction.py +++ b/st2common/tests/unit/test_db_liveaction.py @@ -26,19 +26,19 @@ from st2tests import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class LiveActionModelTest(DbTestCase): - def test_liveaction_crud_no_notify(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) self.assertEqual(retrieved.notify, None) # Test update @@ -52,80 +52,81 @@ def test_liveaction_crud_no_notify(self): retrieved = LiveAction.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_liveaction_create_with_notify_on_complete_only(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} notify_db = NotificationSchema() notify_sub_schema = NotificationSubSchema() - notify_sub_schema.message = 'Action complete.' - notify_sub_schema.data = { - 'foo': 'bar', - 'bar': 1, - 'baz': {'k1': 'v1'} - } + notify_sub_schema.message = "Action complete." + notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}} notify_db.on_complete = notify_sub_schema created.notify = notify_db saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. - self.assertEqual(notify_sub_schema.message, retrieved.notify.on_complete.message) + self.assertEqual( + notify_sub_schema.message, retrieved.notify.on_complete.message + ) self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_complete.data) - self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_complete.routes) + self.assertListEqual( + notify_sub_schema.routes, retrieved.notify.on_complete.routes + ) self.assertEqual(retrieved.notify.on_success, None) self.assertEqual(retrieved.notify.on_failure, None) def test_liveaction_create_with_notify_on_success_only(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} notify_db = NotificationSchema() notify_sub_schema = NotificationSubSchema() - notify_sub_schema.message = 'Action succeeded.' - notify_sub_schema.data = { - 'foo': 'bar', - 'bar': 1, - 'baz': {'k1': 'v1'} - } + notify_sub_schema.message = "Action succeeded." + notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}} notify_db.on_success = notify_sub_schema created.notify = notify_db saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. - self.assertEqual(notify_sub_schema.message, - retrieved.notify.on_success.message) + self.assertEqual(notify_sub_schema.message, retrieved.notify.on_success.message) self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_success.data) - self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_success.routes) + self.assertListEqual( + notify_sub_schema.routes, retrieved.notify.on_success.routes + ) self.assertEqual(retrieved.notify.on_failure, None) self.assertEqual(retrieved.notify.on_complete, None) def test_liveaction_create_with_notify_both_on_success_and_on_error(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - created.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + created.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. self.assertEqual(on_success.message, retrieved.notify.on_success.message) self.assertEqual(on_failure.message, retrieved.notify.on_failure.message) diff --git a/st2common/tests/unit/test_db_marker.py b/st2common/tests/unit/test_db_marker.py index 72dc879697..b9cd879ea3 100644 --- a/st2common/tests/unit/test_db_marker.py +++ b/st2common/tests/unit/test_db_marker.py @@ -26,26 +26,27 @@ class DumperMarkerModelTest(DbTestCase): def test_dumper_marker_crud(self): saved = DumperMarkerModelTest._create_save_dumper_marker() retrieved = DumperMarker.get_by_id(saved.id) - self.assertEqual(saved.marker, retrieved.marker, - 'Same marker was not returned.') + self.assertEqual( + saved.marker, retrieved.marker, "Same marker was not returned." + ) # test update time_now = date_utils.get_datetime_utc_now() retrieved.updated_at = time_now saved = DumperMarker.add_or_update(retrieved) retrieved = DumperMarker.get_by_id(saved.id) - self.assertEqual(retrieved.updated_at, time_now, 'Update to marker failed.') + self.assertEqual(retrieved.updated_at, time_now, "Update to marker failed.") # cleanup DumperMarkerModelTest._delete([retrieved]) try: retrieved = DumperMarker.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_dumper_marker(): created = DumperMarkerDB() - created.marker = '2015-06-11T00:35:15.260439Z' + created.marker = "2015-06-11T00:35:15.260439Z" created.updated_at = date_utils.get_datetime_utc_now() return DumperMarker.add_or_update(created) diff --git a/st2common/tests/unit/test_db_model_uids.py b/st2common/tests/unit/test_db_model_uids.py index 3f5ec1ca6c..2dd3bfb87d 100644 --- a/st2common/tests/unit/test_db_model_uids.py +++ b/st2common/tests/unit/test_db_model_uids.py @@ -30,72 +30,80 @@ from st2common.models.db.policy import PolicyDB from st2common.models.db.auth import ApiKeyDB -__all__ = [ - 'DBModelUIDFieldTestCase' -] +__all__ = ["DBModelUIDFieldTestCase"] class DBModelUIDFieldTestCase(unittest2.TestCase): def test_get_uid(self): - pack_db = PackDB(ref='ma_pack') - self.assertEqual(pack_db.get_uid(), 'pack:ma_pack') + pack_db = PackDB(ref="ma_pack") + self.assertEqual(pack_db.get_uid(), "pack:ma_pack") self.assertTrue(pack_db.has_valid_uid()) - sensor_type_db = SensorTypeDB(name='sname', pack='spack') - self.assertEqual(sensor_type_db.get_uid(), 'sensor_type:spack:sname') + sensor_type_db = SensorTypeDB(name="sname", pack="spack") + self.assertEqual(sensor_type_db.get_uid(), "sensor_type:spack:sname") self.assertTrue(sensor_type_db.has_valid_uid()) - action_db = ActionDB(name='aname', pack='apack', runner_type={}) - self.assertEqual(action_db.get_uid(), 'action:apack:aname') + action_db = ActionDB(name="aname", pack="apack", runner_type={}) + self.assertEqual(action_db.get_uid(), "action:apack:aname") self.assertTrue(action_db.has_valid_uid()) - rule_db = RuleDB(name='rname', pack='rpack') - self.assertEqual(rule_db.get_uid(), 'rule:rpack:rname') + rule_db = RuleDB(name="rname", pack="rpack") + self.assertEqual(rule_db.get_uid(), "rule:rpack:rname") self.assertTrue(rule_db.has_valid_uid()) - trigger_type_db = TriggerTypeDB(name='ttname', pack='ttpack') - self.assertEqual(trigger_type_db.get_uid(), 'trigger_type:ttpack:ttname') + trigger_type_db = TriggerTypeDB(name="ttname", pack="ttpack") + self.assertEqual(trigger_type_db.get_uid(), "trigger_type:ttpack:ttname") self.assertTrue(trigger_type_db.has_valid_uid()) - trigger_db = TriggerDB(name='tname', pack='tpack') - self.assertTrue(trigger_db.get_uid().startswith('trigger:tpack:tname:')) + trigger_db = TriggerDB(name="tname", pack="tpack") + self.assertTrue(trigger_db.get_uid().startswith("trigger:tpack:tname:")) # Verify that same set of parameters always results in the same hash - parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}} + parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}} paramers_hash = json.dumps(parameters, sort_keys=True) paramers_hash = hashlib.md5(paramers_hash.encode()).hexdigest() - parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = {'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = {'b': u'unicode', 'c': [1, 2, 3], 'd': {'h': 2, 'g': 1}, 'a': 1} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"b": "unicode", "c": [1, 2, 3], "d": {"h": 2, "g": 1}, "a": 1} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = OrderedDict({'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1}) - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = OrderedDict( + {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1} + ) + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - policy_type_db = PolicyTypeDB(resource_type='action', name='concurrency') - self.assertEqual(policy_type_db.get_uid(), 'policy_type:action:concurrency') + policy_type_db = PolicyTypeDB(resource_type="action", name="concurrency") + self.assertEqual(policy_type_db.get_uid(), "policy_type:action:concurrency") self.assertTrue(policy_type_db.has_valid_uid()) - policy_db = PolicyDB(pack='dummy', name='policy1') - self.assertEqual(policy_db.get_uid(), 'policy:dummy:policy1') + policy_db = PolicyDB(pack="dummy", name="policy1") + self.assertEqual(policy_db.get_uid(), "policy:dummy:policy1") - api_key_db = ApiKeyDB(key_hash='valid') - self.assertEqual(api_key_db.get_uid(), 'api_key:valid') + api_key_db = ApiKeyDB(key_hash="valid") + self.assertEqual(api_key_db.get_uid(), "api_key:valid") self.assertTrue(api_key_db.has_valid_uid()) api_key_db = ApiKeyDB() - self.assertEqual(api_key_db.get_uid(), 'api_key:') + self.assertEqual(api_key_db.get_uid(), "api_key:") self.assertFalse(api_key_db.has_valid_uid()) diff --git a/st2common/tests/unit/test_db_pack.py b/st2common/tests/unit/test_db_pack.py index c8df8b5a28..d5b5af00f4 100644 --- a/st2common/tests/unit/test_db_pack.py +++ b/st2common/tests/unit/test_db_pack.py @@ -26,21 +26,21 @@ class PackDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = PackDB persistance_class = Pack model_class_kwargs = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen', - 'path': '/opt/stackstorm/packs/yolo_ci/' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", + "path": "/opt/stackstorm/packs/yolo_ci/", } - update_attribute_name = 'author' + update_attribute_name = "author" def test_path_none(self): PackDBModelCRUDTestCase.model_class_kwargs = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", } super(PackDBModelCRUDTestCase, self).test_crud_operations() diff --git a/st2common/tests/unit/test_db_policy.py b/st2common/tests/unit/test_db_policy.py index 9364c61074..95b682e4a4 100644 --- a/st2common/tests/unit/test_db_policy.py +++ b/st2common/tests/unit/test_db_policy.py @@ -24,64 +24,113 @@ class PolicyTypeReferenceTest(unittest2.TestCase): - def test_is_reference(self): - self.assertTrue(PolicyTypeReference.is_reference('action.concurrency')) - self.assertFalse(PolicyTypeReference.is_reference('concurrency')) - self.assertFalse(PolicyTypeReference.is_reference('')) + self.assertTrue(PolicyTypeReference.is_reference("action.concurrency")) + self.assertFalse(PolicyTypeReference.is_reference("concurrency")) + self.assertFalse(PolicyTypeReference.is_reference("")) self.assertFalse(PolicyTypeReference.is_reference(None)) def test_validate_resource_type(self): - self.assertEqual(PolicyTypeReference.validate_resource_type('action'), 'action') - self.assertRaises(ValueError, PolicyTypeReference.validate_resource_type, 'action.test') + self.assertEqual(PolicyTypeReference.validate_resource_type("action"), "action") + self.assertRaises( + ValueError, PolicyTypeReference.validate_resource_type, "action.test" + ) def test_get_resource_type(self): - self.assertEqual(PolicyTypeReference.get_resource_type('action.concurrency'), 'action') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '.abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, 'abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, None) + self.assertEqual( + PolicyTypeReference.get_resource_type("action.concurrency"), "action" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, ".abc" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, "abc" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, "" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, None + ) def test_get_name(self): - self.assertEqual(PolicyTypeReference.get_name('action.concurrency'), 'concurrency') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '.abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, 'abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '') + self.assertEqual( + PolicyTypeReference.get_name("action.concurrency"), "concurrency" + ) + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, ".abc") + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "abc") + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "") self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, None) def test_to_string_reference(self): - ref = PolicyTypeReference.to_string_reference(resource_type='action', name='concurrency') - self.assertEqual(ref, 'action.concurrency') - - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action.test', name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type=None, name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='', name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action', name=None) - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action', name='') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type=None, name=None) - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='', name='') + ref = PolicyTypeReference.to_string_reference( + resource_type="action", name="concurrency" + ) + self.assertEqual(ref, "action.concurrency") + + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action.test", + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type=None, + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="", + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action", + name=None, + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action", + name="", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type=None, + name=None, + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="", + name="", + ) def test_from_string_reference(self): - ref = PolicyTypeReference.from_string_reference('action.concurrency') - self.assertEqual(ref.resource_type, 'action') - self.assertEqual(ref.name, 'concurrency') - self.assertEqual(ref.ref, 'action.concurrency') - - ref = PolicyTypeReference.from_string_reference('action.concurrency.targeted') - self.assertEqual(ref.resource_type, 'action') - self.assertEqual(ref.name, 'concurrency.targeted') - self.assertEqual(ref.ref, 'action.concurrency.targeted') - - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '.test') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, None) + ref = PolicyTypeReference.from_string_reference("action.concurrency") + self.assertEqual(ref.resource_type, "action") + self.assertEqual(ref.name, "concurrency") + self.assertEqual(ref.ref, "action.concurrency") + + ref = PolicyTypeReference.from_string_reference("action.concurrency.targeted") + self.assertEqual(ref.resource_type, "action") + self.assertEqual(ref.name, "concurrency.targeted") + self.assertEqual(ref.ref, "action.concurrency.targeted") + + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, ".test" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, "" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, None + ) class PolicyTypeTest(DbModelTestCase): @@ -89,34 +138,26 @@ class PolicyTypeTest(DbModelTestCase): @staticmethod def _create_instance(): - parameters = { - 'threshold': { - 'type': 'integer', - 'required': True - } - } - - instance = PolicyTypeDB(name='concurrency', - description='TBD', - enabled=None, - ref=None, - resource_type='action', - module='st2action.policies.concurrency', - parameters=parameters) + parameters = {"threshold": {"type": "integer", "required": True}} + + instance = PolicyTypeDB( + name="concurrency", + description="TBD", + enabled=None, + ref=None, + resource_type="action", + module="st2action.policies.concurrency", + parameters=parameters, + ) return instance def test_crud(self): instance = self._create_instance() - defaults = { - 'ref': 'action.concurrency', - 'enabled': True - } + defaults = {"ref": "action.concurrency", "enabled": True} - updates = { - 'description': 'Limits the concurrent executions for the action.' - } + updates = {"description": "Limits the concurrent executions for the action."} self._assert_crud(instance, defaults=defaults, updates=updates) @@ -130,16 +171,16 @@ class PolicyTest(DbModelTestCase): @staticmethod def _create_instance(): - instance = PolicyDB(pack=None, - name='local.concurrency', - description='TBD', - enabled=None, - ref=None, - resource_ref='core.local', - policy_type='action.concurrency', - parameters={ - 'threshold': 25 - }) + instance = PolicyDB( + pack=None, + name="local.concurrency", + description="TBD", + enabled=None, + ref=None, + resource_ref="core.local", + policy_type="action.concurrency", + parameters={"threshold": 25}, + ) return instance @@ -147,13 +188,13 @@ def test_crud(self): instance = self._create_instance() defaults = { - 'pack': pack_constants.DEFAULT_PACK_NAME, - 'ref': '%s.local.concurrency' % pack_constants.DEFAULT_PACK_NAME, - 'enabled': True + "pack": pack_constants.DEFAULT_PACK_NAME, + "ref": "%s.local.concurrency" % pack_constants.DEFAULT_PACK_NAME, + "enabled": True, } updates = { - 'description': 'Limits the concurrent executions for the action "core.local".' + "description": 'Limits the concurrent executions for the action "core.local".' } self._assert_crud(instance, defaults=defaults, updates=updates) @@ -164,7 +205,7 @@ def test_ref(self): self.assertIsNotNone(ref) self.assertEqual(ref.pack, instance.pack) self.assertEqual(ref.name, instance.name) - self.assertEqual(ref.ref, instance.pack + '.' + instance.name) + self.assertEqual(ref.ref, instance.pack + "." + instance.name) self.assertEqual(ref.ref, instance.ref) def test_unique_key(self): diff --git a/st2common/tests/unit/test_db_rbac.py b/st2common/tests/unit/test_db_rbac.py index 62b9763272..d9c3fcc958 100644 --- a/st2common/tests/unit/test_db_rbac.py +++ b/st2common/tests/unit/test_db_rbac.py @@ -28,10 +28,10 @@ __all__ = [ - 'RoleDBModelCRUDTestCase', - 'UserRoleAssignmentDBModelCRUDTestCase', - 'PermissionGrantDBModelCRUDTestCase', - 'GroupToRoleMappingDBModelCRUDTestCase' + "RoleDBModelCRUDTestCase", + "UserRoleAssignmentDBModelCRUDTestCase", + "PermissionGrantDBModelCRUDTestCase", + "GroupToRoleMappingDBModelCRUDTestCase", ] @@ -39,44 +39,44 @@ class RoleDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = RoleDB persistance_class = Role model_class_kwargs = { - 'name': 'role_one', - 'description': None, - 'system': False, - 'permission_grants': [] + "name": "role_one", + "description": None, + "system": False, + "permission_grants": [], } - update_attribute_name = 'name' + update_attribute_name = "name" class UserRoleAssignmentDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = UserRoleAssignmentDB persistance_class = UserRoleAssignment model_class_kwargs = { - 'user': 'user_one', - 'role': 'role_one', - 'source': 'source_one', - 'is_remote': True + "user": "user_one", + "role": "role_one", + "source": "source_one", + "is_remote": True, } - update_attribute_name = 'role' + update_attribute_name = "role" class PermissionGrantDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = PermissionGrantDB persistance_class = PermissionGrant model_class_kwargs = { - 'resource_uid': 'pack:core', - 'resource_type': 'pack', - 'permission_types': [] + "resource_uid": "pack:core", + "resource_type": "pack", + "permission_types": [], } - update_attribute_name = 'resource_uid' + update_attribute_name = "resource_uid" class GroupToRoleMappingDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = GroupToRoleMappingDB persistance_class = GroupToRoleMapping model_class_kwargs = { - 'group': 'some group', - 'roles': ['role_one', 'role_two'], - 'description': 'desc', - 'enabled': True + "group": "some group", + "roles": ["role_one", "role_two"], + "description": "desc", + "enabled": True, } - update_attribute_name = 'group' + update_attribute_name = "group" diff --git a/st2common/tests/unit/test_db_rule_enforcement.py b/st2common/tests/unit/test_db_rule_enforcement.py index 734a34ffc3..5cececffa0 100644 --- a/st2common/tests/unit/test_db_rule_enforcement.py +++ b/st2common/tests/unit/test_db_rule_enforcement.py @@ -28,19 +28,19 @@ SKIP_DELETE = False -__all__ = [ - 'RuleEnforcementModelTest' -] +__all__ = ["RuleEnforcementModelTest"] -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RuleEnforcementModelTest(DbTestCase): - def test_ruleenforcment_crud(self): saved = RuleEnforcementModelTest._create_save_rule_enforcement() retrieved = RuleEnforcement.get_by_id(saved.id) - self.assertEqual(saved.rule.ref, retrieved.rule.ref, - 'Same rule enforcement was not returned.') + self.assertEqual( + saved.rule.ref, + retrieved.rule.ref, + "Same rule enforcement was not returned.", + ) self.assertIsNotNone(retrieved.enforced_at) # test update RULE_ID = str(bson.ObjectId()) @@ -48,73 +48,82 @@ def test_ruleenforcment_crud(self): retrieved.rule.id = RULE_ID saved = RuleEnforcement.add_or_update(retrieved) retrieved = RuleEnforcement.get_by_id(saved.id) - self.assertEqual(retrieved.rule.id, RULE_ID, - 'Update to rule enforcement failed.') + self.assertEqual( + retrieved.rule.id, RULE_ID, "Update to rule enforcement failed." + ) # cleanup RuleEnforcementModelTest._delete([retrieved]) try: retrieved = RuleEnforcement.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after delete.') + self.assertIsNone(retrieved, "managed to retrieve after delete.") def test_status_set_to_failed_for_objects_which_predate_status_field(self): - rule = { - 'ref': 'foo_pack.foo_rule', - 'uid': 'rule:foo_pack:foo_rule' - } + rule = {"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"} # 1. No status field explicitly set and no failure reason - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId())) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED) # 2. No status field, with failure reason, status should be set to failed - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - failure_reason='so much fail') + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + failure_reason="so much fail", + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) # 3. Explcit status field - succeeded + failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, - failure_reason='so much fail') + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, + failure_reason="so much fail", + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) # 4. Explcit status field - succeeded + no failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_SUCCEEDED) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED) # 5. Explcit status field - failed + no failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_FAILED) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_FAILED, + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) @staticmethod def _create_save_rule_enforcement(): - created = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule={'ref': 'foo_pack.foo_rule', - 'uid': 'rule:foo_pack:foo_rule'}, - execution_id=str(bson.ObjectId())) + created = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule={"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"}, + execution_id=str(bson.ObjectId()), + ) return RuleEnforcement.add_or_update(created) @staticmethod diff --git a/st2common/tests/unit/test_db_task.py b/st2common/tests/unit/test_db_task.py index 60285f1366..bc0d3e2382 100644 --- a/st2common/tests/unit/test_db_task.py +++ b/st2common/tests/unit/test_db_task.py @@ -27,19 +27,18 @@ from st2common.util import date as date_utils -@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) class TaskExecutionModelTest(st2tests.DbTestCase): - def test_task_execution_crud(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Test create created = wf_db_access.TaskExecution.add_or_update(initial) @@ -61,7 +60,7 @@ def test_task_execution_crud(self): self.assertDictEqual(created.context, retrieved.context) # Test update - status = 'running' + status = "running" retrieved = wf_db_access.TaskExecution.update(retrieved, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -79,8 +78,8 @@ def test_task_execution_crud(self): self.assertDictEqual(updated.context, retrieved.context) # Test add or update - retrieved.result = {'output': 'fubar'} - retrieved.status = 'succeeded' + retrieved.result = {"output": "fubar"} + retrieved.status = "succeeded" retrieved.end_timestamp = date_utils.get_datetime_utc_now() retrieved = wf_db_access.TaskExecution.add_or_update(retrieved) updated = wf_db_access.TaskExecution.get_by_id(doc_id) @@ -105,20 +104,20 @@ def test_task_execution_crud(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) def test_task_execution_crud_set_itemized_true(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 initial.itemized = True - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Test create created = wf_db_access.TaskExecution.add_or_update(initial) @@ -140,7 +139,7 @@ def test_task_execution_crud_set_itemized_true(self): self.assertDictEqual(created.context, retrieved.context) # Test update - status = 'running' + status = "running" retrieved = wf_db_access.TaskExecution.update(retrieved, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -158,8 +157,8 @@ def test_task_execution_crud_set_itemized_true(self): self.assertDictEqual(updated.context, retrieved.context) # Test add or update - retrieved.result = {'output': 'fubar'} - retrieved.status = 'succeeded' + retrieved.result = {"output": "fubar"} + retrieved.status = "succeeded" retrieved.end_timestamp = date_utils.get_datetime_utc_now() retrieved = wf_db_access.TaskExecution.add_or_update(retrieved) updated = wf_db_access.TaskExecution.get_by_id(doc_id) @@ -184,19 +183,19 @@ def test_task_execution_crud_set_itemized_true(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) def test_task_execution_write_conflict(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Prep record created = wf_db_access.TaskExecution.add_or_update(initial) @@ -208,7 +207,7 @@ def test_task_execution_write_conflict(self): retrieved2 = wf_db_access.TaskExecution.get_by_id(doc_id) # Test update on instance 1, expect success - status = 'running' + status = "running" retrieved1 = wf_db_access.TaskExecution.update(retrieved1, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -230,7 +229,7 @@ def test_task_execution_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, wf_db_access.TaskExecution.update, retrieved2, - status='pausing' + status="pausing", ) # Test delete @@ -239,5 +238,5 @@ def test_task_execution_write_conflict(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) diff --git a/st2common/tests/unit/test_db_trace.py b/st2common/tests/unit/test_db_trace.py index b9e2ec9c8a..1e0f884472 100644 --- a/st2common/tests/unit/test_db_trace.py +++ b/st2common/tests/unit/test_db_trace.py @@ -24,85 +24,103 @@ class TraceDBTest(CleanDbTestCase): - def test_get(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(4)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(5)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(5)], + ) retrieved = Trace.get(id=saved.id) - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") def test_query(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(4)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(5)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(5)], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") # Add another trace with same trace_tag and confirm that we support. # This is most likley an anti-pattern for the trace_tag but it is an unknown. saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(2)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(3)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(3)], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 2, 'Should have 2 traces.') + self.assertEqual(len(retrieved), 2, "Should have 2 traces.") def test_update(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[], - rules=[], - trigger_instances=[]) + trace_tag="test_trace", action_executions=[], rules=[], trigger_instances=[] + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") no_action_executions = 4 no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", id_=retrieved[0].id, - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") # validate update - self.assertEqual(len(retrieved[0].action_executions), no_action_executions, - 'Failed to update action_executions.') - self.assertEqual(len(retrieved[0].rules), no_rules, 'Failed to update rules.') - self.assertEqual(len(retrieved[0].trigger_instances), no_trigger_instances, - 'Failed to update trigger_instances.') + self.assertEqual( + len(retrieved[0].action_executions), + no_action_executions, + "Failed to update action_executions.", + ) + self.assertEqual(len(retrieved[0].rules), no_rules, "Failed to update rules.") + self.assertEqual( + len(retrieved[0].trigger_instances), + no_trigger_instances, + "Failed to update trigger_instances.", + ) def test_update_via_list_push(self): no_action_executions = 4 no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + trace_tag="test_trace", + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) # push updates Trace.push_action_execution( - saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId()))) + saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId())) + ) Trace.push_rule(saved, rule=TraceComponentDB(object_id=str(bson.ObjectId()))) Trace.push_trigger_instance( - saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId()))) + saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId())) + ) retrieved = Trace.get(id=saved.id) - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") self.assertEqual(len(retrieved.action_executions), no_action_executions + 1) self.assertEqual(len(retrieved.rules), no_rules + 1) self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances + 1) @@ -112,33 +130,48 @@ def test_update_via_list_push_components(self): no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + trace_tag="test_trace", + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) retrieved = Trace.push_components( saved, - action_executions=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_action_executions)], - rules=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_rules)], - trigger_instances=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_trigger_instances)]) - - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + action_executions=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_action_executions) + ], + rules=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_rules) + ], + trigger_instances=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_trigger_instances) + ], + ) + + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") self.assertEqual(len(retrieved.action_executions), no_action_executions * 2) self.assertEqual(len(retrieved.rules), no_rules * 2) self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances * 2) @staticmethod - def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None, - trigger_instances=None): + def _create_save_trace( + trace_tag, id_=None, action_executions=None, rules=None, trigger_instances=None + ): if action_executions is None: action_executions = [] - action_executions = [TraceComponentDB(object_id=action_execution) - for action_execution in action_executions] + action_executions = [ + TraceComponentDB(object_id=action_execution) + for action_execution in action_executions + ] if rules is None: rules = [] @@ -146,12 +179,16 @@ def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None, if trigger_instances is None: trigger_instances = [] - trigger_instances = [TraceComponentDB(object_id=trigger_instance) - for trigger_instance in trigger_instances] - - created = TraceDB(id=id_, - trace_tag=trace_tag, - trigger_instances=trigger_instances, - rules=rules, - action_executions=action_executions) + trigger_instances = [ + TraceComponentDB(object_id=trigger_instance) + for trigger_instance in trigger_instances + ] + + created = TraceDB( + id=id_, + trace_tag=trace_tag, + trigger_instances=trigger_instances, + rules=rules, + action_executions=action_executions, + ) return Trace.add_or_update(created) diff --git a/st2common/tests/unit/test_db_uid_mixin.py b/st2common/tests/unit/test_db_uid_mixin.py index e3283e6f91..b7a6a25108 100644 --- a/st2common/tests/unit/test_db_uid_mixin.py +++ b/st2common/tests/unit/test_db_uid_mixin.py @@ -23,28 +23,41 @@ class UIDMixinTestCase(CleanDbTestCase): def test_get_uid(self): - pack_1_db = PackDB(ref='test_pack') - pack_2_db = PackDB(ref='examples') + pack_1_db = PackDB(ref="test_pack") + pack_2_db = PackDB(ref="examples") - self.assertEqual(pack_1_db.get_uid(), 'pack:test_pack') - self.assertEqual(pack_2_db.get_uid(), 'pack:examples') + self.assertEqual(pack_1_db.get_uid(), "pack:test_pack") + self.assertEqual(pack_2_db.get_uid(), "pack:examples") - action_1_db = ActionDB(pack='examples', name='my_action', ref='examples.my_action') - action_2_db = ActionDB(pack='core', name='local', ref='core.local') - self.assertEqual(action_1_db.get_uid(), 'action:examples:my_action') - self.assertEqual(action_2_db.get_uid(), 'action:core:local') + action_1_db = ActionDB( + pack="examples", name="my_action", ref="examples.my_action" + ) + action_2_db = ActionDB(pack="core", name="local", ref="core.local") + self.assertEqual(action_1_db.get_uid(), "action:examples:my_action") + self.assertEqual(action_2_db.get_uid(), "action:core:local") def test_uid_is_populated_on_save(self): - pack_1_db = PackDB(ref='test_pack', name='test', description='foo', version='1.0.0', - author='dev', email='test@example.com') + pack_1_db = PackDB( + ref="test_pack", + name="test", + description="foo", + version="1.0.0", + author="dev", + email="test@example.com", + ) pack_1_db = Pack.add_or_update(pack_1_db) pack_1_db.reload() - self.assertEqual(pack_1_db.uid, 'pack:test_pack') + self.assertEqual(pack_1_db.uid, "pack:test_pack") - action_1_db = ActionDB(name='local', pack='core', ref='core.local', entry_point='', - runner_type={'name': 'local-shell-cmd'}) + action_1_db = ActionDB( + name="local", + pack="core", + ref="core.local", + entry_point="", + runner_type={"name": "local-shell-cmd"}, + ) action_1_db = Action.add_or_update(action_1_db) action_1_db.reload() - self.assertEqual(action_1_db.uid, 'action:core:local') + self.assertEqual(action_1_db.uid, "action:core:local") diff --git a/st2common/tests/unit/test_db_workflow.py b/st2common/tests/unit/test_db_workflow.py index 1f7ce38a4a..e434d0f9d6 100644 --- a/st2common/tests/unit/test_db_workflow.py +++ b/st2common/tests/unit/test_db_workflow.py @@ -26,14 +26,13 @@ from st2common.exceptions import db as db_exc -@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) class WorkflowExecutionModelTest(st2tests.DbTestCase): - def test_workflow_execution_crud(self): initial = wf_db_models.WorkflowExecutionDB() initial.action_execution = uuid.uuid4().hex - initial.graph = {'var1': 'foobar'} - initial.status = 'requested' + initial.graph = {"var1": "foobar"} + initial.status = "requested" # Test create created = wf_db_access.WorkflowExecution.add_or_update(initial) @@ -47,9 +46,11 @@ def test_workflow_execution_crud(self): self.assertEqual(created.status, retrieved.status) # Test update - graph = {'var1': 'fubar'} - status = 'running' - retrieved = wf_db_access.WorkflowExecution.update(retrieved, graph=graph, status=status) + graph = {"var1": "fubar"} + status = "running" + retrieved = wf_db_access.WorkflowExecution.update( + retrieved, graph=graph, status=status + ) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved.rev, updated.rev) @@ -58,7 +59,7 @@ def test_workflow_execution_crud(self): self.assertEqual(retrieved.status, updated.status) # Test add or update - retrieved.graph = {'var2': 'fubar'} + retrieved.graph = {"var2": "fubar"} retrieved = wf_db_access.WorkflowExecution.add_or_update(retrieved) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -73,14 +74,14 @@ def test_workflow_execution_crud(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.WorkflowExecution.get_by_id, - doc_id + doc_id, ) def test_workflow_execution_write_conflict(self): initial = wf_db_models.WorkflowExecutionDB() initial.action_execution = uuid.uuid4().hex - initial.graph = {'var1': 'foobar'} - initial.status = 'requested' + initial.graph = {"var1": "foobar"} + initial.status = "requested" # Prep record created = wf_db_access.WorkflowExecution.add_or_update(initial) @@ -92,9 +93,11 @@ def test_workflow_execution_write_conflict(self): retrieved2 = wf_db_access.WorkflowExecution.get_by_id(doc_id) # Test update on instance 1, expect success - graph = {'var1': 'fubar'} - status = 'running' - retrieved1 = wf_db_access.WorkflowExecution.update(retrieved1, graph=graph, status=status) + graph = {"var1": "fubar"} + status = "running" + retrieved1 = wf_db_access.WorkflowExecution.update( + retrieved1, graph=graph, status=status + ) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved1.rev, updated.rev) @@ -107,7 +110,7 @@ def test_workflow_execution_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, wf_db_access.WorkflowExecution.update, retrieved2, - graph={'var2': 'fubar'} + graph={"var2": "fubar"}, ) # Test delete @@ -116,5 +119,5 @@ def test_workflow_execution_write_conflict(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.WorkflowExecution.get_by_id, - doc_id + doc_id, ) diff --git a/st2common/tests/unit/test_dist_utils.py b/st2common/tests/unit/test_dist_utils.py index 901f8abd44..1b01d4ff48 100644 --- a/st2common/tests/unit/test_dist_utils.py +++ b/st2common/tests/unit/test_dist_utils.py @@ -21,7 +21,7 @@ import unittest2 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -SCRIPTS_PATH = os.path.join(BASE_DIR, '../../../scripts/') +SCRIPTS_PATH = os.path.join(BASE_DIR, "../../../scripts/") # Add scripts/ which contain main dist_utils.py to PYTHONPATH sys.path.insert(0, SCRIPTS_PATH) @@ -32,21 +32,21 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -__all__ = [ - 'DistUtilsTestCase' -] +__all__ = ["DistUtilsTestCase"] -REQUIREMENTS_PATH_1 = os.path.join(BASE_DIR, '../fixtures/requirements-used-for-tests.txt') -REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, '../../../requirements.txt') -VERSION_FILE_PATH = os.path.join(BASE_DIR, '../fixtures/version_file.py') +REQUIREMENTS_PATH_1 = os.path.join( + BASE_DIR, "../fixtures/requirements-used-for-tests.txt" +) +REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, "../../../requirements.txt") +VERSION_FILE_PATH = os.path.join(BASE_DIR, "../fixtures/version_file.py") class DistUtilsTestCase(unittest2.TestCase): def setUp(self): super(DistUtilsTestCase, self).setUp() - if 'pip' in sys.modules: - del sys.modules['pip'] + if "pip" in sys.modules: + del sys.modules["pip"] def tearDown(self): super(DistUtilsTestCase, self).tearDown() @@ -54,15 +54,15 @@ def tearDown(self): def test_check_pip_is_installed_success(self): self.assertTrue(check_pip_is_installed()) - @mock.patch('sys.exit') + @mock.patch("sys.exit") def test_check_pip_is_installed_failure(self, mock_sys_exit): if six.PY3: - module_name = 'builtins.__import__' + module_name = "builtins.__import__" else: - module_name = '__builtin__.__import__' + module_name = "__builtin__.__import__" with mock.patch(module_name) as mock_import: - mock_import.side_effect = ImportError('not found') + mock_import.side_effect = ImportError("not found") self.assertEqual(mock_sys_exit.call_count, 0) check_pip_is_installed() @@ -72,12 +72,12 @@ def test_check_pip_is_installed_failure(self, mock_sys_exit): def test_check_pip_version_success(self): self.assertTrue(check_pip_version()) - @mock.patch('sys.exit') + @mock.patch("sys.exit") def test_check_pip_version_failure(self, mock_sys_exit): mock_pip = mock.Mock() - mock_pip.__version__ = '0.0.0' - sys.modules['pip'] = mock_pip + mock_pip.__version__ = "0.0.0" + sys.modules["pip"] = mock_pip self.assertEqual(mock_sys_exit.call_count, 0) check_pip_version() @@ -86,50 +86,50 @@ def test_check_pip_version_failure(self, mock_sys_exit): def test_get_version_string(self): version = get_version_string(VERSION_FILE_PATH) - self.assertEqual(version, '1.2.3') + self.assertEqual(version, "1.2.3") def test_apply_vagrant_workaround(self): - with mock.patch('os.link') as _: - os.environ['USER'] = 'stanley' + with mock.patch("os.link") as _: + os.environ["USER"] = "stanley" apply_vagrant_workaround() self.assertTrue(os.link) - with mock.patch('os.link') as _: - os.environ['USER'] = 'vagrant' + with mock.patch("os.link") as _: + os.environ["USER"] = "vagrant" apply_vagrant_workaround() - self.assertFalse(getattr(os, 'link', None)) + self.assertFalse(getattr(os, "link", None)) def test_fetch_requirements(self): expected_reqs = [ - 'RandomWords', - 'amqp==2.5.1', - 'argcomplete', - 'bcrypt==3.1.6', - 'flex==6.14.0', - 'logshipper', - 'orquesta', - 'st2-auth-backend-flat-file', - 'logshipper-editable', - 'python_runner', - 'SomePackageHq', - 'SomePackageSvn', - 'gitpython==2.1.11', - 'ose-timer==0.7.5', - 'oslo.config<1.13,>=1.12.1', - 'requests[security]<2.22.0,>=2.21.0', - 'retrying==1.3.3', - 'zake==0.2.2' + "RandomWords", + "amqp==2.5.1", + "argcomplete", + "bcrypt==3.1.6", + "flex==6.14.0", + "logshipper", + "orquesta", + "st2-auth-backend-flat-file", + "logshipper-editable", + "python_runner", + "SomePackageHq", + "SomePackageSvn", + "gitpython==2.1.11", + "ose-timer==0.7.5", + "oslo.config<1.13,>=1.12.1", + "requests[security]<2.22.0,>=2.21.0", + "retrying==1.3.3", + "zake==0.2.2", ] expected_links = [ - 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper', - 'git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta', # NOQA - 'git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file', # NOQA - 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable', - 'git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner', # NOQA - 'hg+https://hg.repo/some_pkg.git#egg=SomePackageHq', - 'svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn' + "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper", + "git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta", # NOQA + "git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file", # NOQA + "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable", + "git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner", # NOQA + "hg+https://hg.repo/some_pkg.git#egg=SomePackageHq", + "svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn", ] reqs, links = fetch_requirements(REQUIREMENTS_PATH_1) diff --git a/st2common/tests/unit/test_exceptions_workflow.py b/st2common/tests/unit/test_exceptions_workflow.py index 9e37f6c5d9..a9fbcc549f 100644 --- a/st2common/tests/unit/test_exceptions_workflow.py +++ b/st2common/tests/unit/test_exceptions_workflow.py @@ -26,7 +26,6 @@ class WorkflowExceptionTest(unittest2.TestCase): - def test_retry_on_transient_db_errors(self): instance = wf_db_models.WorkflowExecutionDB() exc = db_exc.StackStormDBObjectWriteConflictError(instance) @@ -34,13 +33,13 @@ def test_retry_on_transient_db_errors(self): def test_do_not_retry_on_transient_db_errors(self): instance = wf_db_models.WorkflowExecutionDB() - exc = db_exc.StackStormDBObjectConflictError('foobar', '1234', instance) + exc = db_exc.StackStormDBObjectConflictError("foobar", "1234", instance) self.assertFalse(wf_exc.retry_on_transient_db_errors(exc)) self.assertFalse(wf_exc.retry_on_transient_db_errors(NotImplementedError())) self.assertFalse(wf_exc.retry_on_transient_db_errors(Exception())) def test_retry_on_connection_errors(self): - exc = coordination.ToozConnectionError('foobar') + exc = coordination.ToozConnectionError("foobar") self.assertTrue(wf_exc.retry_on_connection_errors(exc)) exc = mongoengine.connection.MongoEngineConnectionError() diff --git a/st2common/tests/unit/test_executions.py b/st2common/tests/unit/test_executions.py index 59353379ac..0be1ca7c9d 100644 --- a/st2common/tests/unit/test_executions.py +++ b/st2common/tests/unit/test_executions.py @@ -29,94 +29,117 @@ class TestActionExecutionHistoryModel(DbTestCase): - def setUp(self): super(TestActionExecutionHistoryModel, self).setUp() # Fake execution record for action liveactions triggered by workflow runner. self.fake_history_subtasks = [ { - 'id': str(bson.ObjectId()), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1']), - 'status': fixture.ARTIFACTS['liveactions']['task1']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['task1']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['task1']['end_timestamp'] + "id": str(bson.ObjectId()), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]), + "status": fixture.ARTIFACTS["liveactions"]["task1"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][ + "end_timestamp" + ], }, { - 'id': str(bson.ObjectId()), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task2']), - 'status': fixture.ARTIFACTS['liveactions']['task2']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['task2']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['task2']['end_timestamp'] - } + "id": str(bson.ObjectId()), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task2"]), + "status": fixture.ARTIFACTS["liveactions"]["task2"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][ + "end_timestamp" + ], + }, ] # Fake execution record for a workflow action execution triggered by rule. self.fake_history_workflow = { - 'id': str(bson.ObjectId()), - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']), - 'children': [task['id'] for task in self.fake_history_subtasks], - 'status': fixture.ARTIFACTS['liveactions']['workflow']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['end_timestamp'] + "id": str(bson.ObjectId()), + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["workflow"]), + "children": [task["id"] for task in self.fake_history_subtasks], + "status": fixture.ARTIFACTS["liveactions"]["workflow"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][ + "end_timestamp" + ], } # Assign parent to the execution records for the subtasks. for task in self.fake_history_subtasks: - task['parent'] = self.fake_history_workflow['id'] + task["parent"] = self.fake_history_workflow["id"] def test_model_complete(self): # Create API object. obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_workflow)) - self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(obj.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(obj.action, self.fake_history_workflow['action']) - self.assertDictEqual(obj.runner, self.fake_history_workflow['runner']) - self.assertEqual(obj.liveaction, self.fake_history_workflow['liveaction']) - self.assertIsNone(getattr(obj, 'parent', None)) - self.assertListEqual(obj.children, self.fake_history_workflow['children']) + self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + obj.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + obj.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(obj.action, self.fake_history_workflow["action"]) + self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"]) + self.assertEqual(obj.liveaction, self.fake_history_workflow["liveaction"]) + self.assertIsNone(getattr(obj, "parent", None)) + self.assertListEqual(obj.children, self.fake_history_workflow["children"]) # Convert API object to DB model. model = ActionExecutionAPI.to_model(obj) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(model.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(model.action, self.fake_history_workflow['action']) - self.assertDictEqual(model.runner, self.fake_history_workflow['runner']) - doc = copy.deepcopy(self.fake_history_workflow['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + model.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + model.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(model.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(model.action, self.fake_history_workflow["action"]) + self.assertDictEqual(model.runner, self.fake_history_workflow["runner"]) + doc = copy.deepcopy(self.fake_history_workflow["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertIsNone(getattr(model, 'parent', None)) - self.assertListEqual(model.children, self.fake_history_workflow['children']) + self.assertIsNone(getattr(model, "parent", None)) + self.assertListEqual(model.children, self.fake_history_workflow["children"]) # Convert DB model to API object. obj = ActionExecutionAPI.from_model(model) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(obj.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(obj.action, self.fake_history_workflow['action']) - self.assertDictEqual(obj.runner, self.fake_history_workflow['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_workflow['liveaction']) - self.assertIsNone(getattr(obj, 'parent', None)) - self.assertListEqual(obj.children, self.fake_history_workflow['children']) + self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + obj.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + obj.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(obj.action, self.fake_history_workflow["action"]) + self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"]) + self.assertDictEqual(obj.liveaction, self.fake_history_workflow["liveaction"]) + self.assertIsNone(getattr(obj, "parent", None)) + self.assertListEqual(obj.children, self.fake_history_workflow["children"]) def test_crud_complete(self): # Create the DB record. @@ -124,18 +147,22 @@ def test_crud_complete(self): ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) model = ActionExecution.get_by_id(obj.id) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(model.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(model.action, self.fake_history_workflow['action']) - self.assertDictEqual(model.runner, self.fake_history_workflow['runner']) - doc = copy.deepcopy(self.fake_history_workflow['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + model.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + model.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(model.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(model.action, self.fake_history_workflow["action"]) + self.assertDictEqual(model.runner, self.fake_history_workflow["runner"]) + doc = copy.deepcopy(self.fake_history_workflow["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertIsNone(getattr(model, 'parent', None)) - self.assertListEqual(model.children, self.fake_history_workflow['children']) + self.assertIsNone(getattr(model, "parent", None)) + self.assertListEqual(model.children, self.fake_history_workflow["children"]) # Update the DB record. children = [str(bson.ObjectId()), str(bson.ObjectId())] @@ -146,20 +173,24 @@ def test_crud_complete(self): # Delete the DB record. ActionExecution.delete(model) - self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id) + self.assertRaises( + StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id + ) def test_model_partial(self): # Create API object. obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_subtasks[0])) - self.assertIsNone(getattr(obj, 'trigger', None)) - self.assertIsNone(getattr(obj, 'trigger_type', None)) - self.assertIsNone(getattr(obj, 'trigger_instance', None)) - self.assertIsNone(getattr(obj, 'rule', None)) - self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction']) - self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent']) - self.assertIsNone(getattr(obj, 'children', None)) + self.assertIsNone(getattr(obj, "trigger", None)) + self.assertIsNone(getattr(obj, "trigger_type", None)) + self.assertIsNone(getattr(obj, "trigger_instance", None)) + self.assertIsNone(getattr(obj, "rule", None)) + self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"]) + self.assertDictEqual( + obj.liveaction, self.fake_history_subtasks[0]["liveaction"] + ) + self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"]) + self.assertIsNone(getattr(obj, "children", None)) # Convert API object to DB model. model = ActionExecutionAPI.to_model(obj) @@ -168,28 +199,30 @@ def test_model_partial(self): self.assertDictEqual(model.trigger_type, {}) self.assertDictEqual(model.trigger_instance, {}) self.assertDictEqual(model.rule, {}) - self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner']) - doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"]) + doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent']) + self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"]) self.assertListEqual(model.children, []) # Convert DB model to API object. obj = ActionExecutionAPI.from_model(model) self.assertEqual(str(model.id), obj.id) - self.assertIsNone(getattr(obj, 'trigger', None)) - self.assertIsNone(getattr(obj, 'trigger_type', None)) - self.assertIsNone(getattr(obj, 'trigger_instance', None)) - self.assertIsNone(getattr(obj, 'rule', None)) - self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction']) - self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent']) - self.assertIsNone(getattr(obj, 'children', None)) + self.assertIsNone(getattr(obj, "trigger", None)) + self.assertIsNone(getattr(obj, "trigger_type", None)) + self.assertIsNone(getattr(obj, "trigger_instance", None)) + self.assertIsNone(getattr(obj, "rule", None)) + self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"]) + self.assertDictEqual( + obj.liveaction, self.fake_history_subtasks[0]["liveaction"] + ) + self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"]) + self.assertIsNone(getattr(obj, "children", None)) def test_crud_partial(self): # Create the DB record. @@ -201,13 +234,13 @@ def test_crud_partial(self): self.assertDictEqual(model.trigger_type, {}) self.assertDictEqual(model.trigger_instance, {}) self.assertDictEqual(model.rule, {}) - self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner']) - doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"]) + doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent']) + self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"]) self.assertListEqual(model.children, []) # Update the DB record. @@ -219,23 +252,25 @@ def test_crud_partial(self): # Delete the DB record. ActionExecution.delete(model) - self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id) + self.assertRaises( + StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id + ) def test_datetime_range(self): base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(60): timestamp = base + datetime.timedelta(seconds=i) doc = copy.deepcopy(self.fake_history_subtasks[0]) - doc['id'] = str(bson.ObjectId()) - doc['start_timestamp'] = isotime.format(timestamp) + doc["id"] = str(bson.ObjectId()) + doc["start_timestamp"] = isotime.format(timestamp) obj = ActionExecutionAPI(**doc) ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" objs = ActionExecution.query(start_timestamp=dt_range) self.assertEqual(len(objs), 10) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" objs = ActionExecution.query(start_timestamp=dt_range) self.assertEqual(len(objs), 10) @@ -244,19 +279,19 @@ def test_sort_by_start_timestamp(self): for i in range(60): timestamp = base + datetime.timedelta(seconds=i) doc = copy.deepcopy(self.fake_history_subtasks[0]) - doc['id'] = str(bson.ObjectId()) - doc['start_timestamp'] = isotime.format(timestamp) + doc["id"] = str(bson.ObjectId()) + doc["start_timestamp"] = isotime.format(timestamp) obj = ActionExecutionAPI(**doc) ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' - objs = ActionExecution.query(start_timestamp=dt_range, - order_by=['start_timestamp']) - self.assertLess(objs[0]['start_timestamp'], - objs[9]['start_timestamp']) + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" + objs = ActionExecution.query( + start_timestamp=dt_range, order_by=["start_timestamp"] + ) + self.assertLess(objs[0]["start_timestamp"], objs[9]["start_timestamp"]) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' - objs = ActionExecution.query(start_timestamp=dt_range, - order_by=['-start_timestamp']) - self.assertLess(objs[9]['start_timestamp'], - objs[0]['start_timestamp']) + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" + objs = ActionExecution.query( + start_timestamp=dt_range, order_by=["-start_timestamp"] + ) + self.assertLess(objs[9]["start_timestamp"], objs[0]["start_timestamp"]) diff --git a/st2common/tests/unit/test_executions_util.py b/st2common/tests/unit/test_executions_util.py index 9177493188..f7702614d3 100644 --- a/st2common/tests/unit/test_executions_util.py +++ b/st2common/tests/unit/test_executions_util.py @@ -35,25 +35,28 @@ import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'liveactions': ['liveaction1.yaml', 'parentliveaction.yaml', 'childliveaction.yaml', - 'successful_liveaction.yaml'], - 'actions': ['local.yaml'], - 'executions': ['execution1.yaml'], - 'runners': ['run-local.yaml'], - 'triggertypes': ['triggertype2.yaml'], - 'rules': ['rule3.yaml'], - 'triggers': ['trigger2.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml'] + "liveactions": [ + "liveaction1.yaml", + "parentliveaction.yaml", + "childliveaction.yaml", + "successful_liveaction.yaml", + ], + "actions": ["local.yaml"], + "executions": ["execution1.yaml"], + "runners": ["run-local.yaml"], + "triggertypes": ["triggertype2.yaml"], + "rules": ["rule3.yaml"], + "triggers": ["trigger2.yaml"], + "triggerinstances": ["trigger_instance_1.yaml"], } -DYNAMIC_FIXTURES = { - 'liveactions': ['liveaction3.yaml'] -} +DYNAMIC_FIXTURES = {"liveactions": ["liveaction3.yaml"]} class ExecutionsUtilTestCase(CleanDbTestCase): @@ -63,118 +66,144 @@ def __init__(self, *args, **kwargs): def setUp(self): super(ExecutionsUtilTestCase, self).setUp() - self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES) - self.FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK, - fixtures_dict=DYNAMIC_FIXTURES) + self.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + self.FIXTURES = FixturesLoader().load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict=DYNAMIC_FIXTURES + ) def test_execution_creation_manual_action_run(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] pre_creation_timestamp = date_utils.get_datetime_utc_now() executions_util.create_execution_object(liveaction) post_creation_timestamp = date_utils.get_datetime_utc_now() - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertDictEqual(execution.trigger, {}) self.assertDictEqual(execution.trigger_type, {}) self.assertDictEqual(execution.trigger_instance, {}) self.assertDictEqual(execution.rule, {}) - action = action_utils.get_action_by_ref('core.local') + action = action_utils.get_action_by_ref("core.local") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(execution.liveaction['id'], str(liveaction.id)) + self.assertEqual(execution.liveaction["id"], str(liveaction.id)) self.assertEqual(len(execution.log), 1) - self.assertEqual(execution.log[0]['status'], liveaction.status) - self.assertGreater(execution.log[0]['timestamp'], pre_creation_timestamp) - self.assertLess(execution.log[0]['timestamp'], post_creation_timestamp) + self.assertEqual(execution.log[0]["status"], liveaction.status) + self.assertGreater(execution.log[0]["timestamp"], pre_creation_timestamp) + self.assertLess(execution.log[0]["timestamp"], post_creation_timestamp) def test_execution_creation_action_triggered_by_rule(self): # Wait for the action execution to complete and then confirm outcome. - trigger_type = self.MODELS['triggertypes']['triggertype2.yaml'] - trigger = self.MODELS['triggers']['trigger2.yaml'] - trigger_instance = self.MODELS['triggerinstances']['trigger_instance_1.yaml'] - test_liveaction = self.FIXTURES['liveactions']['liveaction3.yaml'] - rule = self.MODELS['rules']['rule3.yaml'] + trigger_type = self.MODELS["triggertypes"]["triggertype2.yaml"] + trigger = self.MODELS["triggers"]["trigger2.yaml"] + trigger_instance = self.MODELS["triggerinstances"]["trigger_instance_1.yaml"] + test_liveaction = self.FIXTURES["liveactions"]["liveaction3.yaml"] + rule = self.MODELS["rules"]["rule3.yaml"] # Setup LiveAction to point to right rule and trigger_instance. # XXX: We need support for dynamic fixtures. - test_liveaction['context']['rule']['id'] = str(rule.id) - test_liveaction['context']['trigger_instance']['id'] = str(trigger_instance.id) + test_liveaction["context"]["rule"]["id"] = str(rule.id) + test_liveaction["context"]["trigger_instance"]["id"] = str(trigger_instance.id) test_liveaction_api = LiveActionAPI(**test_liveaction) - test_liveaction = LiveAction.add_or_update(LiveActionAPI.to_model(test_liveaction_api)) - liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id)) + test_liveaction = LiveAction.add_or_update( + LiveActionAPI.to_model(test_liveaction_api) + ) + liveaction = LiveAction.get( + context__trigger_instance__id=str(trigger_instance.id) + ) self.assertIsNotNone(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) executions_util.create_execution_object(liveaction) - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger))) - self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))) - self.assertDictEqual(execution.trigger_instance, - vars(TriggerInstanceAPI.from_model(trigger_instance))) + self.assertDictEqual( + execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)) + ) + self.assertDictEqual( + execution.trigger_instance, + vars(TriggerInstanceAPI.from_model(trigger_instance)), + ) self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule))) action = action_utils.get_action_by_ref(liveaction.action) self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(execution.liveaction['id'], str(liveaction.id)) + self.assertEqual(execution.liveaction["id"], str(liveaction.id)) def test_execution_creation_with_web_url(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction) - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertIsNotNone(execution.web_url) execution_id = str(execution.id) - self.assertIn(('history/%s/general' % execution_id), execution.web_url) + self.assertIn(("history/%s/general" % execution_id), execution.web_url) def test_execution_creation_chains(self): - childliveaction = self.MODELS['liveactions']['childliveaction.yaml'] + childliveaction = self.MODELS["liveactions"]["childliveaction.yaml"] child_exec = executions_util.create_execution_object(childliveaction) - parent_execution_id = childliveaction.context['parent']['execution_id'] + parent_execution_id = childliveaction.context["parent"]["execution_id"] parent_execution = ActionExecution.get_by_id(parent_execution_id) child_execs = parent_execution.children self.assertIn(str(child_exec.id), child_execs) def test_execution_update(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction) - liveaction.status = 'running' + liveaction.status = "running" pre_update_timestamp = date_utils.get_datetime_utc_now() executions_util.update_execution(liveaction) post_update_timestamp = date_utils.get_datetime_utc_now() - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertEqual(len(execution.log), 2) - self.assertEqual(execution.log[1]['status'], liveaction.status) - self.assertGreater(execution.log[1]['timestamp'], pre_update_timestamp) - self.assertLess(execution.log[1]['timestamp'], post_update_timestamp) - - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) - @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) + self.assertEqual(execution.log[1]["status"], liveaction.status) + self.assertGreater(execution.log[1]["timestamp"], pre_update_timestamp) + self.assertLess(execution.log[1]["timestamp"], post_update_timestamp) + + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) + @mock.patch.object( + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_abandon_executions(self): - liveaction_db = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction_db = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction_db) execution_db = executions_util.abandon_execution_if_incomplete( - liveaction_id=str(liveaction_db.id)) + liveaction_id=str(liveaction_db.id) + ) - self.assertEqual(execution_db.status, 'abandoned') + self.assertEqual(execution_db.status, "abandoned") runners_utils.invoke_post_run.assert_called_once() - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) - @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) + @mock.patch.object( + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_abandon_executions_on_complete(self): - liveaction_db = self.MODELS['liveactions']['successful_liveaction.yaml'] + liveaction_db = self.MODELS["liveactions"]["successful_liveaction.yaml"] executions_util.create_execution_object(liveaction_db) - expected_msg = r'LiveAction %s already in a completed state %s\.' % \ - (str(liveaction_db.id), liveaction_db.status) - - self.assertRaisesRegexp(ValueError, expected_msg, - executions_util.abandon_execution_if_incomplete, - liveaction_id=str(liveaction_db.id)) + expected_msg = r"LiveAction %s already in a completed state %s\." % ( + str(liveaction_db.id), + liveaction_db.status, + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + executions_util.abandon_execution_if_incomplete, + liveaction_id=str(liveaction_db.id), + ) runners_utils.invoke_post_run.assert_not_called() @@ -184,12 +213,20 @@ def _get_action_execution(self, **kwargs): # descendants test section -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } @@ -200,75 +237,90 @@ def __init__(self, *args, **kwargs): def setUp(self): super(ExecutionsUtilDescendantsTestCase, self).setUp() - self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + self.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_get_all_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) # verify sort order for idx in range(len(all_descendants) - 1): - self.assertLess(all_descendants[idx].start_timestamp, - all_descendants[idx + 1].start_timestamp) + self.assertLess( + all_descendants[idx].start_timestamp, + all_descendants[idx + 1].start_timestamp, + ) def test_get_all_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] + root_execution = self.MODELS["executions"]["root_execution.yaml"] all_descendants = executions_util.get_descendants(str(root_execution.id)) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_1_level_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - descendant_depth=1, - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), descendant_depth=1, result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # All children of root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.parent == str(root_execution.id)] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.parent == str(root_execution.id) + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) # verify sort order for idx in range(len(all_descendants) - 1): - self.assertLess(all_descendants[idx].start_timestamp, - all_descendants[idx + 1].start_timestamp) + self.assertLess( + all_descendants[idx].start_timestamp, + all_descendants[idx + 1].start_timestamp, + ) def test_get_2_level_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - descendant_depth=2, - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), descendant_depth=2, result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # All children of root_execution - root_execution = self.MODELS['executions']['root_execution.yaml'] + root_execution = self.MODELS["executions"]["root_execution.yaml"] expected_ids = [] traverse = [(child_id, 1) for child_id in root_execution.children] while traverse: @@ -282,7 +334,7 @@ def test_get_2_level_descendants_sorted(self): self.assertListEqual(all_descendants_ids, expected_ids) def _get_action_execution(self, ae_id): - for _, execution in six.iteritems(self.MODELS['executions']): + for _, execution in six.iteritems(self.MODELS["executions"]): if str(execution.id) == ae_id: return execution return None diff --git a/st2common/tests/unit/test_greenpooldispatch.py b/st2common/tests/unit/test_greenpooldispatch.py index 84c411d140..45cc568759 100644 --- a/st2common/tests/unit/test_greenpooldispatch.py +++ b/st2common/tests/unit/test_greenpooldispatch.py @@ -23,7 +23,6 @@ class TestGreenPoolDispatch(TestCase): - def test_dispatch_simple(self): dispatcher = BufferedDispatcher(dispatch_pool_size=10) mock_handler = mock.MagicMock() @@ -34,13 +33,17 @@ def test_dispatch_simple(self): while mock_handler.call_count < 10: eventlet.sleep(0.01) dispatcher.shutdown() - call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list] + call_args_list = [ + (args[0][0], args[0][1]) for args in mock_handler.call_args_list + ] self.assertItemsEqual(expected, call_args_list) def test_dispatch_starved(self): - dispatcher = BufferedDispatcher(dispatch_pool_size=2, - monitor_thread_empty_q_sleep_time=0.01, - monitor_thread_no_workers_sleep_time=0.01) + dispatcher = BufferedDispatcher( + dispatch_pool_size=2, + monitor_thread_empty_q_sleep_time=0.01, + monitor_thread_no_workers_sleep_time=0.01, + ) mock_handler = mock.MagicMock() expected = [] for i in range(10): @@ -49,5 +52,7 @@ def test_dispatch_starved(self): while mock_handler.call_count < 10: eventlet.sleep(0.01) dispatcher.shutdown() - call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list] + call_args_list = [ + (args[0][0], args[0][1]) for args in mock_handler.call_args_list + ] self.assertItemsEqual(expected, call_args_list) diff --git a/st2common/tests/unit/test_hash.py b/st2common/tests/unit/test_hash.py index 7211879ff6..234d4969da 100644 --- a/st2common/tests/unit/test_hash.py +++ b/st2common/tests/unit/test_hash.py @@ -22,15 +22,14 @@ class TestHashWithApiKeys(unittest2.TestCase): - def test_hash_repeatability(self): api_key = auth_utils.generate_api_key() hash1 = hash_utils.hash(api_key) hash2 = hash_utils.hash(api_key) - self.assertEqual(hash1, hash2, 'Expected a repeated hash.') + self.assertEqual(hash1, hash2, "Expected a repeated hash.") def test_hash_uniqueness(self): count = 10000 api_keys = [auth_utils.generate_api_key() for _ in range(count)] hashes = set([hash_utils.hash(api_key) for api_key in api_keys]) - self.assertEqual(len(hashes), count, 'Expected all unique hashes.') + self.assertEqual(len(hashes), count, "Expected all unique hashes.") diff --git a/st2common/tests/unit/test_ip_utils.py b/st2common/tests/unit/test_ip_utils.py index a33c220d71..cd1339be73 100644 --- a/st2common/tests/unit/test_ip_utils.py +++ b/st2common/tests/unit/test_ip_utils.py @@ -20,73 +20,72 @@ class IPUtilsTests(unittest2.TestCase): - def test_host_port_split(self): # Simple IPv4 - host_str = '1.2.3.4' + host_str = "1.2.3.4" host, port = split_host_port(host_str) self.assertEqual(host, host_str) self.assertEqual(port, None) # Simple IPv4 with port - host_str = '1.2.3.4:55' + host_str = "1.2.3.4:55" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 55) # Simple IPv6 - host_str = 'fec2::10' + host_str = "fec2::10" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, None) # IPv6 with square brackets no port - host_str = '[fec2::10]' + host_str = "[fec2::10]" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, None) # IPv6 with square brackets with port - host_str = '[fec2::10]:55' + host_str = "[fec2::10]:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, 55) # IPv4 inside bracket - host_str = '[1.2.3.4]' + host_str = "[1.2.3.4]" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, None) # IPv4 inside bracket and port - host_str = '[1.2.3.4]:55' + host_str = "[1.2.3.4]:55" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 55) # Hostname inside bracket - host_str = '[st2build001]:55' + host_str = "[st2build001]:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, 55) # Simple hostname - host_str = 'st2build001' + host_str = "st2build001" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, None) # Simple hostname with port - host_str = 'st2build001:55' + host_str = "st2build001:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, 55) # No-bracket invalid port - host_str = 'st2build001:abc' + host_str = "st2build001:abc" self.assertRaises(Exception, split_host_port, host_str) # Bracket invalid port - host_str = '[fec2::10]:abc' + host_str = "[fec2::10]:abc" self.assertRaises(Exception, split_host_port, host_str) diff --git a/st2common/tests/unit/test_isotime_utils.py b/st2common/tests/unit/test_isotime_utils.py index 5ec5495ca9..34d785031b 100644 --- a/st2common/tests/unit/test_isotime_utils.py +++ b/st2common/tests/unit/test_isotime_utils.py @@ -24,50 +24,54 @@ class IsoTimeUtilsTestCase(unittest.TestCase): def test_validate(self): - self.assertTrue(isotime.validate('2000-01-01 12:00:00Z')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+0000')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+00:00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000Z')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+0000')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00:00')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00Z')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000Z')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00+00:00')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000+00:00')) - self.assertTrue(isotime.validate('2015-02-10T21:21:53.399Z')) - self.assertFalse(isotime.validate('2000-01-01', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00+00:00Z', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00.000000', raise_exception=False)) - self.assertFalse(isotime.validate('Epic!', raise_exception=False)) + self.assertTrue(isotime.validate("2000-01-01 12:00:00Z")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+0000")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+00:00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000Z")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+0000")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00:00")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00Z")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000Z")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00+00:00")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000+00:00")) + self.assertTrue(isotime.validate("2015-02-10T21:21:53.399Z")) + self.assertFalse(isotime.validate("2000-01-01", raise_exception=False)) + self.assertFalse(isotime.validate("2000-01-01T12:00:00", raise_exception=False)) + self.assertFalse( + isotime.validate("2000-01-01T12:00:00+00:00Z", raise_exception=False) + ) + self.assertFalse( + isotime.validate("2000-01-01T12:00:00.000000", raise_exception=False) + ) + self.assertFalse(isotime.validate("Epic!", raise_exception=False)) self.assertFalse(isotime.validate(object(), raise_exception=False)) - self.assertRaises(ValueError, isotime.validate, 'Epic!', True) + self.assertRaises(ValueError, isotime.validate, "Epic!", True) def test_parse(self): dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12)) - self.assertEqual(isotime.parse('2000-01-01 12:00:00Z'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+0000'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000Z'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+0000'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00Z'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000Z'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000Z'), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00Z"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+0000"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000Z"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+0000"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00Z"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000Z"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000Z"), dt) def test_format(self): dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12)) - dt_str_usec_offset = '2000-01-01T12:00:00.000000+00:00' - dt_str_usec = '2000-01-01T12:00:00.000000Z' - dt_str_offset = '2000-01-01T12:00:00+00:00' - dt_str = '2000-01-01T12:00:00Z' - dt_unicode = u'2000-01-01T12:00:00Z' + dt_str_usec_offset = "2000-01-01T12:00:00.000000+00:00" + dt_str_usec = "2000-01-01T12:00:00.000000Z" + dt_str_offset = "2000-01-01T12:00:00+00:00" + dt_str = "2000-01-01T12:00:00Z" + dt_unicode = "2000-01-01T12:00:00Z" # datetime object self.assertEqual(isotime.format(dt, usec=True, offset=True), dt_str_usec_offset) @@ -75,16 +79,22 @@ def test_format(self): self.assertEqual(isotime.format(dt, usec=False, offset=True), dt_str_offset) self.assertEqual(isotime.format(dt, usec=False, offset=False), dt_str) self.assertEqual(isotime.format(dt_str, usec=False, offset=False), dt_str) - self.assertEqual(isotime.format(dt_unicode, usec=False, offset=False), dt_unicode) + self.assertEqual( + isotime.format(dt_unicode, usec=False, offset=False), dt_unicode + ) # unix timestamp (epoch) dt = 1557390483 - self.assertEqual(isotime.format(dt, usec=True, offset=True), - '2019-05-09T08:28:03.000000+00:00') - self.assertEqual(isotime.format(dt, usec=False, offset=False), - '2019-05-09T08:28:03Z') - self.assertEqual(isotime.format(dt, usec=False, offset=True), - '2019-05-09T08:28:03+00:00') + self.assertEqual( + isotime.format(dt, usec=True, offset=True), + "2019-05-09T08:28:03.000000+00:00", + ) + self.assertEqual( + isotime.format(dt, usec=False, offset=False), "2019-05-09T08:28:03Z" + ) + self.assertEqual( + isotime.format(dt, usec=False, offset=True), "2019-05-09T08:28:03+00:00" + ) def test_format_tz_naive(self): dt1 = datetime.datetime.utcnow() @@ -99,6 +109,8 @@ def test_format_tz_aware(self): def test_format_sec_truncated(self): dt1 = date.add_utc_tz(datetime.datetime.utcnow()) dt2 = isotime.parse(isotime.format(dt1, usec=False)) - dt3 = datetime.datetime(dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second) + dt3 = datetime.datetime( + dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second + ) self.assertLess(dt2, dt1) self.assertEqual(dt2, date.add_utc_tz(dt3)) diff --git a/st2common/tests/unit/test_jinja_render_crypto_filters.py b/st2common/tests/unit/test_jinja_render_crypto_filters.py index 1a026e83ed..f58edb1309 100644 --- a/st2common/tests/unit/test_jinja_render_crypto_filters.py +++ b/st2common/tests/unit/test_jinja_render_crypto_filters.py @@ -38,72 +38,101 @@ def setUp(self): crypto_key_path = cfg.CONF.keyvalue.encryption_key_path crypto_key = read_crypto_key(key_path=crypto_key_path) - self.secret = 'Build a wall' - self.secret_value = symmetric_encrypt(encrypt_key=crypto_key, plaintext=self.secret) + self.secret = "Build a wall" + self.secret_value = symmetric_encrypt( + encrypt_key=crypto_key, plaintext=self.secret + ) self.env = jinja_utils.get_jinja_environment() def test_filter_decrypt_kv(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value=self.secret_value, - scope=FULL_SYSTEM_SCOPE, - secret=True)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="k8", value=self.secret_value, scope=FULL_SYSTEM_SCOPE, secret=True + ) + ) context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) - template = '{{st2kv.system.k8 | decrypt_kv}}' + template = "{{st2kv.system.k8 | decrypt_kv}}" actual = self.env.from_string(template).render(context) self.assertEqual(actual, self.secret) def test_filter_decrypt_kv_datastore_value_doesnt_exist(self): context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) - template = '{{st2kv.system.doesnt_exist | decrypt_kv}}' + template = "{{st2kv.system.doesnt_exist | decrypt_kv}}" - expected_msg = ('Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or ' - 'it contains an empty string') - self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render, - context) + expected_msg = ( + 'Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or ' + "it contains an empty string" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, self.env.from_string(template).render, context + ) def test_filter_decrypt_kv_with_user_scope_value(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k8', value=self.secret_value, - scope=FULL_USER_SCOPE, - secret=True)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="stanley:k8", + value=self.secret_value, + scope=FULL_USER_SCOPE, + secret=True, + ) + ) context = {} - context.update({USER_SCOPE: UserKeyValueLookup(user='stanley', scope=USER_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) + context.update( + {USER_SCOPE: UserKeyValueLookup(user="stanley", scope=USER_SCOPE)} + ) + context.update( + { + DATASTORE_PARENT_SCOPE: { + USER_SCOPE: UserKeyValueLookup( + user="stanley", scope=FULL_USER_SCOPE + ) + } } - }) + ) - template = '{{st2kv.user.k8 | decrypt_kv}}' + template = "{{st2kv.user.k8 | decrypt_kv}}" actual = self.env.from_string(template).render(context) self.assertEqual(actual, self.secret) def test_filter_decrypt_kv_with_user_scope_value_datastore_value_doesnt_exist(self): context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + USER_SCOPE: UserKeyValueLookup( + user="stanley", scope=FULL_USER_SCOPE + ) + } } - }) + ) - template = '{{st2kv.user.doesnt_exist | decrypt_kv}}' + template = "{{st2kv.user.doesnt_exist | decrypt_kv}}" - expected_msg = ('Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or ' - 'it contains an empty string') - self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render, - context) + expected_msg = ( + 'Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or ' + "it contains an empty string" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, self.env.from_string(template).render, context + ) diff --git a/st2common/tests/unit/test_jinja_render_data_filters.py b/st2common/tests/unit/test_jinja_render_data_filters.py index fd923e870f..44d2f296f9 100644 --- a/st2common/tests/unit/test_jinja_render_data_filters.py +++ b/st2common/tests/unit/test_jinja_render_data_filters.py @@ -24,77 +24,68 @@ class JinjaUtilsDataFilterTestCase(unittest2.TestCase): - def test_filter_from_json_string(self): env = jinja_utils.get_jinja_environment() - expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} obj_json_str = '{"a": "b", "c": {"d": "e", "f": 1, "g": true}}' - template = '{{k1 | from_json_string}}' + template = "{{k1 | from_json_string}}" - obj_str = env.from_string(template).render({'k1': obj_json_str}) + obj_str = env.from_string(template).render({"k1": obj_json_str}) obj = eval(obj_str) self.assertDictEqual(obj, expected_obj) # With KeyValueLookup object env = jinja_utils.get_jinja_environment() obj_json_str = '["a", "b", "c"]' - expected_obj = ['a', 'b', 'c'] + expected_obj = ["a", "b", "c"] - template = '{{ k1 | from_json_string}}' + template = "{{ k1 | from_json_string}}" - lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='a') - lookup._value_cache['a'] = obj_json_str - obj_str = env.from_string(template).render({'k1': lookup}) + lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="a") + lookup._value_cache["a"] = obj_json_str + obj_str = env.from_string(template).render({"k1": lookup}) obj = eval(obj_str) self.assertEqual(obj, expected_obj) def test_filter_from_yaml_string(self): env = jinja_utils.get_jinja_environment() - expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} - obj_yaml_str = ("---\n" - "a: b\n" - "c:\n" - " d: e\n" - " f: 1\n" - " g: true\n") - - template = '{{k1 | from_yaml_string}}' - obj_str = env.from_string(template).render({'k1': obj_yaml_str}) + expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} + obj_yaml_str = "---\n" "a: b\n" "c:\n" " d: e\n" " f: 1\n" " g: true\n" + + template = "{{k1 | from_yaml_string}}" + obj_str = env.from_string(template).render({"k1": obj_yaml_str}) obj = eval(obj_str) self.assertDictEqual(obj, expected_obj) # With KeyValueLookup object env = jinja_utils.get_jinja_environment() - obj_yaml_str = ("---\n" - "- a\n" - "- b\n" - "- c\n") - expected_obj = ['a', 'b', 'c'] + obj_yaml_str = "---\n" "- a\n" "- b\n" "- c\n" + expected_obj = ["a", "b", "c"] - template = '{{ k1 | from_yaml_string }}' + template = "{{ k1 | from_yaml_string }}" - lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='b') - lookup._value_cache['b'] = obj_yaml_str - obj_str = env.from_string(template).render({'k1': lookup}) + lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="b") + lookup._value_cache["b"] = obj_yaml_str + obj_str = env.from_string(template).render({"k1": lookup}) obj = eval(obj_str) self.assertEqual(obj, expected_obj) def test_filter_to_json_string(self): env = jinja_utils.get_jinja_environment() - obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} - template = '{{k1 | to_json_string}}' + template = "{{k1 | to_json_string}}" - obj_json_str = env.from_string(template).render({'k1': obj}) + obj_json_str = env.from_string(template).render({"k1": obj}) actual_obj = json.loads(obj_json_str) self.assertDictEqual(obj, actual_obj) def test_filter_to_yaml_string(self): env = jinja_utils.get_jinja_environment() - obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} - template = '{{k1 | to_yaml_string}}' - obj_yaml_str = env.from_string(template).render({'k1': obj}) + template = "{{k1 | to_yaml_string}}" + obj_yaml_str = env.from_string(template).render({"k1": obj}) actual_obj = yaml.safe_load(obj_yaml_str) self.assertDictEqual(obj, actual_obj) diff --git a/st2common/tests/unit/test_jinja_render_json_escape_filters.py b/st2common/tests/unit/test_jinja_render_json_escape_filters.py index 82534100c5..48fef776c1 100644 --- a/st2common/tests/unit/test_jinja_render_json_escape_filters.py +++ b/st2common/tests/unit/test_jinja_render_json_escape_filters.py @@ -21,52 +21,51 @@ class JinjaUtilsJsonEscapeTestCase(unittest2.TestCase): - def test_doublequotes(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo """ bar'}) + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": 'foo """ bar'}) expected = 'foo \\"\\"\\" bar' self.assertEqual(actual, expected) def test_backslashes(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': r'foo \ bar'}) - expected = 'foo \\\\ bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": r"foo \ bar"}) + expected = "foo \\\\ bar" self.assertEqual(actual, expected) def test_backspace(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \b bar'}) - expected = 'foo \\b bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \b bar"}) + expected = "foo \\b bar" self.assertEqual(actual, expected) def test_formfeed(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \f bar'}) - expected = 'foo \\f bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \f bar"}) + expected = "foo \\f bar" self.assertEqual(actual, expected) def test_newline(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \n bar'}) - expected = 'foo \\n bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \n bar"}) + expected = "foo \\n bar" self.assertEqual(actual, expected) def test_carriagereturn(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \r bar'}) - expected = 'foo \\r bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \r bar"}) + expected = "foo \\r bar" self.assertEqual(actual, expected) def test_tab(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \t bar'}) - expected = 'foo \\t bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \t bar"}) + expected = "foo \\t bar" self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py index fd199ebf64..934aa04de8 100644 --- a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py +++ b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py @@ -21,49 +21,58 @@ class JinjaUtilsJsonpathQueryTestCase(unittest2.TestCase): - def test_jsonpath_query_static(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } template = '{{ obj | jsonpath_query("people[*].first") }}' - actual_str = env.from_string(template).render({'obj': obj}) + actual_str = env.from_string(template).render({"obj": obj}) actual = eval(actual_str) - expected = ['James', 'Jacob', 'Jayden'] + expected = ["James", "Jacob", "Jayden"] self.assertEqual(actual, expected) def test_jsonpath_query_dynamic(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } query = "people[*].last" - template = '{{ obj | jsonpath_query(query) }}' - actual_str = env.from_string(template).render({'obj': obj, - 'query': query}) + template = "{{ obj | jsonpath_query(query) }}" + actual_str = env.from_string(template).render({"obj": obj, "query": query}) actual = eval(actual_str) - expected = ['d', 'e', 'f'] + expected = ["d", "e", "f"] self.assertEqual(actual, expected) def test_jsonpath_query_no_results(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } query = "query_returns_no_results" - template = '{{ obj | jsonpath_query(query) }}' - actual_str = env.from_string(template).render({'obj': obj, - 'query': query}) + template = "{{ obj | jsonpath_query(query) }}" + actual_str = env.from_string(template).render({"obj": obj, "query": query}) actual = eval(actual_str) expected = None self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_jinja_render_path_filters.py b/st2common/tests/unit/test_jinja_render_path_filters.py index 504b6454bb..23507bbbc1 100644 --- a/st2common/tests/unit/test_jinja_render_path_filters.py +++ b/st2common/tests/unit/test_jinja_render_path_filters.py @@ -21,29 +21,28 @@ class JinjaUtilsPathFilterTestCase(unittest2.TestCase): - def test_basename(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | basename}}' - actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'}) - self.assertEqual(actual, 'file.txt') + template = "{{k1 | basename}}" + actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"}) + self.assertEqual(actual, "file.txt") - actual = env.from_string(template).render({'k1': '/some/path/to/dir'}) - self.assertEqual(actual, 'dir') + actual = env.from_string(template).render({"k1": "/some/path/to/dir"}) + self.assertEqual(actual, "dir") - actual = env.from_string(template).render({'k1': '/some/path/to/dir/'}) - self.assertEqual(actual, '') + actual = env.from_string(template).render({"k1": "/some/path/to/dir/"}) + self.assertEqual(actual, "") def test_dirname(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | dirname}}' - actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'}) - self.assertEqual(actual, '/some/path/to') + template = "{{k1 | dirname}}" + actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"}) + self.assertEqual(actual, "/some/path/to") - actual = env.from_string(template).render({'k1': '/some/path/to/dir'}) - self.assertEqual(actual, '/some/path/to') + actual = env.from_string(template).render({"k1": "/some/path/to/dir"}) + self.assertEqual(actual, "/some/path/to") - actual = env.from_string(template).render({'k1': '/some/path/to/dir/'}) - self.assertEqual(actual, '/some/path/to/dir') + actual = env.from_string(template).render({"k1": "/some/path/to/dir/"}) + self.assertEqual(actual, "/some/path/to/dir") diff --git a/st2common/tests/unit/test_jinja_render_regex_filters.py b/st2common/tests/unit/test_jinja_render_regex_filters.py index 081d068682..df2e347779 100644 --- a/st2common/tests/unit/test_jinja_render_regex_filters.py +++ b/st2common/tests/unit/test_jinja_render_regex_filters.py @@ -20,54 +20,53 @@ class JinjaUtilsRegexFilterTestCase(unittest2.TestCase): - def test_filters_regex_match(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_match("x")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_match("y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'False' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "False" self.assertEqual(actual, expected) template = '{{k1 | regex_match("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}' - actual = env.from_string(template).render({'k1': 'v0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "v0.10.1"}) + expected = "True" self.assertEqual(actual, expected) def test_filters_regex_replace(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_replace("x", "y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'yyz' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "yyz" self.assertEqual(actual, expected) template = '{{k1 | regex_replace("(blue|white|red)", "color")}}' - actual = env.from_string(template).render({'k1': 'blue socks and red shoes'}) - expected = 'color socks and color shoes' + actual = env.from_string(template).render({"k1": "blue socks and red shoes"}) + expected = "color socks and color shoes" self.assertEqual(actual, expected) def test_filters_regex_search(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_search("x")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_search("y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_search("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}' - actual = env.from_string(template).render({'k1': 'v0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "v0.10.1"}) + expected = "True" self.assertEqual(actual, expected) def test_filters_regex_substring(self): @@ -76,29 +75,31 @@ def test_filters_regex_substring(self): # Normal (match) template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))")}}' actual = env.from_string(template).render( - {'input_str': 'My address is 123 Somewhere Ave. See you soon!'} + {"input_str": "My address is 123 Somewhere Ave. See you soon!"} ) - expected = '123 Somewhere Ave' + expected = "123 Somewhere Ave" self.assertEqual(actual, expected) # Selecting second match explicitly template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}' actual = env.from_string(template).render( - {'input_str': 'Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave.'} + { + "input_str": "Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave." + } ) - expected = '123 Somewhere Ave' + expected = "123 Somewhere Ave" self.assertEqual(actual, expected) # Selecting second match explicitly, but doesn't exist template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}' with self.assertRaises(IndexError): actual = env.from_string(template).render( - {'input_str': 'Your address is 567 Elsewhere Dr.'} + {"input_str": "Your address is 567 Elsewhere Dr."} ) # No match template = r'{{input_str | regex_substring("([0-3]{3} \w+ (?:Ave|St|Dr))")}}' with self.assertRaises(IndexError): actual = env.from_string(template).render( - {'input_str': 'My address is 986 Somewhere Ave. See you soon!'} + {"input_str": "My address is 986 Somewhere Ave. See you soon!"} ) diff --git a/st2common/tests/unit/test_jinja_render_time_filters.py b/st2common/tests/unit/test_jinja_render_time_filters.py index 5151cec695..2cf002a0e3 100644 --- a/st2common/tests/unit/test_jinja_render_time_filters.py +++ b/st2common/tests/unit/test_jinja_render_time_filters.py @@ -20,16 +20,16 @@ class JinjaUtilsTimeFilterTestCase(unittest2.TestCase): - def test_to_human_time_filter(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | to_human_time_from_seconds}}' - actual = env.from_string(template).render({'k1': 12345}) - self.assertEqual(actual, '3h25m45s') + template = "{{k1 | to_human_time_from_seconds}}" + actual = env.from_string(template).render({"k1": 12345}) + self.assertEqual(actual, "3h25m45s") - actual = env.from_string(template).render({'k1': 0}) - self.assertEqual(actual, '0s') + actual = env.from_string(template).render({"k1": 0}) + self.assertEqual(actual, "0s") - self.assertRaises(AssertionError, env.from_string(template).render, - {'k1': 'stuff'}) + self.assertRaises( + AssertionError, env.from_string(template).render, {"k1": "stuff"} + ) diff --git a/st2common/tests/unit/test_jinja_render_version_filters.py b/st2common/tests/unit/test_jinja_render_version_filters.py index 9cbacd7dcb..41b2b23670 100644 --- a/st2common/tests/unit/test_jinja_render_version_filters.py +++ b/st2common/tests/unit/test_jinja_render_version_filters.py @@ -21,134 +21,133 @@ class JinjaUtilsVersionsFilterTestCase(unittest2.TestCase): - def test_version_compare(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = '-1' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "-1" self.assertEqual(actual, expected) template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '1' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "1" self.assertEqual(actual, expected) template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = '0' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "0" self.assertEqual(actual, expected) def test_version_more_than(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "True" self.assertEqual(actual, expected) template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "False" self.assertEqual(actual, expected) def test_version_less_than(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "True" self.assertEqual(actual, expected) template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "False" self.assertEqual(actual, expected) def test_version_equal(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "True" self.assertEqual(actual, expected) def test_version_match(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_match(">0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '0.1.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.1.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_match("<0.10.0")}}' - actual = env.from_string(template).render({'version': '0.1.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.1.0"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '1.1.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "1.1.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_match("==0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) def test_version_bump_major(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_major}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '1.0.0' + template = "{{version | version_bump_major}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "1.0.0" self.assertEqual(actual, expected) def test_version_bump_minor(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_minor}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.11.0' + template = "{{version | version_bump_minor}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.11.0" self.assertEqual(actual, expected) def test_version_bump_patch(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_patch}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.10.2' + template = "{{version | version_bump_patch}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.10.2" self.assertEqual(actual, expected) def test_version_strip_patch(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_strip_patch}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.10' + template = "{{version | version_strip_patch}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.10" self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_json_schema.py b/st2common/tests/unit/test_json_schema.py index 42b94efb70..892e2604ed 100644 --- a/st2common/tests/unit/test_json_schema.py +++ b/st2common/tests/unit/test_json_schema.py @@ -20,158 +20,127 @@ from st2common.util import schema as util_schema TEST_SCHEMA_1 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", }, - 'arg_optional_no_type': { - 'description': 'Bar' + "arg_optional_no_type": {"description": "Bar"}, + "arg_optional_multi_type": { + "description": "Mirror mirror", + "type": ["string", "boolean", "number"], }, - 'arg_optional_multi_type': { - 'description': 'Mirror mirror', - 'type': ['string', 'boolean', 'number'] + "arg_optional_multi_type_none": { + "description": "Mirror mirror on the wall", + "type": ["string", "boolean", "number", "null"], }, - 'arg_optional_multi_type_none': { - 'description': 'Mirror mirror on the wall', - 'type': ['string', 'boolean', 'number', 'null'] + "arg_optional_type_array": { + "description": "Who" "s the fairest?", + "type": "array", }, - 'arg_optional_type_array': { - 'description': 'Who''s the fairest?', - 'type': 'array' + "arg_optional_type_object": { + "description": "Who" "s the fairest of them?", + "type": "object", }, - 'arg_optional_type_object': { - 'description': 'Who''s the fairest of them?', - 'type': 'object' + "arg_optional_multi_collection_type": { + "description": "Who" "s the fairest of them all?", + "type": ["array", "object"], }, - 'arg_optional_multi_collection_type': { - 'description': 'Who''s the fairest of them all?', - 'type': ['array', 'object'] - } - } + }, } TEST_SCHEMA_2 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_required_default': { - 'default': 'date', - 'description': 'Foo', - 'required': True, - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_required_default": { + "default": "date", + "description": "Foo", + "required": True, + "type": "string", } - } + }, } TEST_SCHEMA_3 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "type": "string", }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'type': 'string' + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "type": "string", }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'type': 'string' - } - } + "arg_optional_no_default": {"description": "Foo", "type": "string"}, + }, } TEST_SCHEMA_4 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_no_default": { + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default_anyof_none': { - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'}, - {'type': 'null'} - ] - } - } + "arg_optional_no_default_anyof_none": { + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}], + }, + }, } TEST_SCHEMA_5 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_no_default": { + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default_oneof_none': { - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'}, - {'type': 'null'} - ] - } - } + "arg_optional_no_default_oneof_none": { + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}], + }, + }, } @@ -181,192 +150,265 @@ def test_use_default_value(self): instance = {} validator = util_schema.get_validator() - expected_msg = '\'arg_required_no_default\' is a required property' - self.assertRaisesRegexp(ValidationError, expected_msg, util_schema.validate, - instance=instance, schema=TEST_SCHEMA_1, cls=validator, - use_default=True) + expected_msg = "'arg_required_no_default' is a required property" + self.assertRaisesRegexp( + ValidationError, + expected_msg, + util_schema.validate, + instance=instance, + schema=TEST_SCHEMA_1, + cls=validator, + use_default=True, + ) # No default, value provided - instance = {'arg_required_no_default': 'foo'} - util_schema.validate(instance=instance, schema=TEST_SCHEMA_1, cls=validator, - use_default=True) + instance = {"arg_required_no_default": "foo"} + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_1, cls=validator, use_default=True + ) # default value provided, no value, should pass instance = {} validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator, - use_default=True) + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True + ) # default value provided, value provided, should pass - instance = {'arg_required_default': 'foo'} + instance = {"arg_required_default": "foo"} validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator, - use_default=True) + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True + ) def test_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_3, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_3, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_3, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_3, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_anyof_type_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_4, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_4, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_anyof_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None, - 'arg_optional_no_default_anyof_none': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, + "arg_optional_no_default_anyof_none": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_4, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_4, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_oneof_type_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_5, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_5, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_oneof_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None, - 'arg_optional_no_default_oneof_none': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, + "arg_optional_no_default_oneof_none": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_5, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_5, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_is_property_type_single(self): - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertTrue(util_schema.is_property_type_single(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertTrue(util_schema.is_property_type_single(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_single(multi_typed_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_single(anyof_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_single(oneof_property)) def test_is_property_type_anyof(self): - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertTrue(util_schema.is_property_type_anyof(anyof_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_anyof(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_anyof(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_anyof(multi_typed_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_anyof(oneof_property)) def test_is_property_type_oneof(self): - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertTrue(util_schema.is_property_type_oneof(oneof_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_oneof(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_oneof(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_oneof(multi_typed_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_oneof(anyof_property)) def test_is_property_type_list(self): - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertTrue(util_schema.is_property_type_list(multi_typed_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_list(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_list(untyped_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_list(anyof_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_list(oneof_property)) def test_is_property_nullable(self): - multi_typed_prop_nullable = TEST_SCHEMA_1['properties']['arg_optional_multi_type_none'] - self.assertTrue(util_schema.is_property_nullable(multi_typed_prop_nullable.get('type'))) - - anyof_property_nullable = TEST_SCHEMA_4['properties']['arg_optional_no_default_anyof_none'] - self.assertTrue(util_schema.is_property_nullable(anyof_property_nullable.get('anyOf'))) - - oneof_property_nullable = TEST_SCHEMA_5['properties']['arg_optional_no_default_oneof_none'] - self.assertTrue(util_schema.is_property_nullable(oneof_property_nullable.get('oneOf'))) - - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + multi_typed_prop_nullable = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_type_none" + ] + self.assertTrue( + util_schema.is_property_nullable(multi_typed_prop_nullable.get("type")) + ) + + anyof_property_nullable = TEST_SCHEMA_4["properties"][ + "arg_optional_no_default_anyof_none" + ] + self.assertTrue( + util_schema.is_property_nullable(anyof_property_nullable.get("anyOf")) + ) + + oneof_property_nullable = TEST_SCHEMA_5["properties"][ + "arg_optional_no_default_oneof_none" + ] + self.assertTrue( + util_schema.is_property_nullable(oneof_property_nullable.get("oneOf")) + ) + + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_nullable(typed_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_property_nullable(multi_typed_property.get('type'))) + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_property_nullable(multi_typed_property.get("type")) + ) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_no_default'] - self.assertFalse(util_schema.is_property_nullable(anyof_property.get('anyOf'))) + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_no_default"] + self.assertFalse(util_schema.is_property_nullable(anyof_property.get("anyOf"))) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_no_default'] - self.assertFalse(util_schema.is_property_nullable(oneof_property.get('oneOf'))) + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_no_default"] + self.assertFalse(util_schema.is_property_nullable(oneof_property.get("oneOf"))) def test_is_attribute_type_array(self): - multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type'] - self.assertTrue(util_schema.is_attribute_type_array(multi_coll_typed_prop.get('type'))) - - array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array'] - self.assertTrue(util_schema.is_attribute_type_array(array_type_property.get('type'))) - - multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_attribute_type_array(multi_non_coll_prop.get('type'))) - - object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object'] - self.assertFalse(util_schema.is_attribute_type_array(object_type_property.get('type'))) + multi_coll_typed_prop = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_collection_type" + ] + self.assertTrue( + util_schema.is_attribute_type_array(multi_coll_typed_prop.get("type")) + ) + + array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"] + self.assertTrue( + util_schema.is_attribute_type_array(array_type_property.get("type")) + ) + + multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_attribute_type_array(multi_non_coll_prop.get("type")) + ) + + object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"] + self.assertFalse( + util_schema.is_attribute_type_array(object_type_property.get("type")) + ) def test_is_attribute_type_object(self): - multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type'] - self.assertTrue(util_schema.is_attribute_type_object(multi_coll_typed_prop.get('type'))) - - object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object'] - self.assertTrue(util_schema.is_attribute_type_object(object_type_property.get('type'))) - - multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_attribute_type_object(multi_non_coll_prop.get('type'))) - - array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array'] - self.assertFalse(util_schema.is_attribute_type_object(array_type_property.get('type'))) + multi_coll_typed_prop = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_collection_type" + ] + self.assertTrue( + util_schema.is_attribute_type_object(multi_coll_typed_prop.get("type")) + ) + + object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"] + self.assertTrue( + util_schema.is_attribute_type_object(object_type_property.get("type")) + ) + + multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_attribute_type_object(multi_non_coll_prop.get("type")) + ) + + array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"] + self.assertFalse( + util_schema.is_attribute_type_object(array_type_property.get("type")) + ) diff --git a/st2common/tests/unit/test_jsonify.py b/st2common/tests/unit/test_jsonify.py index 801d912333..1feaac96b0 100644 --- a/st2common/tests/unit/test_jsonify.py +++ b/st2common/tests/unit/test_jsonify.py @@ -20,33 +20,32 @@ class JsonifyTests(unittest2.TestCase): - def test_none_object(self): obj = None self.assertIsNone(jsonify.json_loads(obj)) def test_no_keys(self): - obj = {'foo': '{"bar": "baz"}'} + obj = {"foo": '{"bar": "baz"}'} transformed_obj = jsonify.json_loads(obj) - self.assertTrue(transformed_obj['foo']['bar'] == 'baz') + self.assertTrue(transformed_obj["foo"]["bar"] == "baz") def test_no_json_value(self): - obj = {'foo': 'bar'} + obj = {"foo": "bar"} transformed_obj = jsonify.json_loads(obj) - self.assertTrue(transformed_obj['foo'] == 'bar') + self.assertTrue(transformed_obj["foo"] == "bar") def test_happy_case(self): - obj = {'foo': '{"bar": "baz"}', 'yo': 'bibimbao'} - transformed_obj = jsonify.json_loads(obj, ['yo']) - self.assertTrue(transformed_obj['yo'] == 'bibimbao') + obj = {"foo": '{"bar": "baz"}', "yo": "bibimbao"} + transformed_obj = jsonify.json_loads(obj, ["yo"]) + self.assertTrue(transformed_obj["yo"] == "bibimbao") def test_try_loads(self): # The function json.loads will fail and the function should return the original value. - values = ['abc', 123, True, object()] + values = ["abc", 123, True, object()] for value in values: self.assertEqual(jsonify.try_loads(value), value) # The function json.loads succeed. d = '{"a": 1, "b": true}' - expected = {'a': 1, 'b': True} + expected = {"a": 1, "b": True} self.assertDictEqual(jsonify.try_loads(d), expected) diff --git a/st2common/tests/unit/test_keyvalue_lookup.py b/st2common/tests/unit/test_keyvalue_lookup.py index f37cc04dc9..afcd76901a 100644 --- a/st2common/tests/unit/test_keyvalue_lookup.py +++ b/st2common/tests/unit/test_keyvalue_lookup.py @@ -24,23 +24,29 @@ class TestKeyValueLookup(CleanDbTestCase): def test_lookup_with_key_prefix(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='some:prefix:stanley:k5', value='v5', - scope=FULL_USER_SCOPE)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="some:prefix:stanley:k5", value="v5", scope=FULL_USER_SCOPE + ) + ) # No prefix provided, should return None - lookup = UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) - self.assertEqual(str(lookup.k5), '') + lookup = UserKeyValueLookup(user="stanley", scope=FULL_USER_SCOPE) + self.assertEqual(str(lookup.k5), "") # Prefix provided - lookup = UserKeyValueLookup(prefix='some:prefix', user='stanley', scope=FULL_USER_SCOPE) - self.assertEqual(str(lookup.k5), 'v5') + lookup = UserKeyValueLookup( + prefix="some:prefix", user="stanley", scope=FULL_USER_SCOPE + ) + self.assertEqual(str(lookup.k5), "v5") def test_non_hierarchical_lookup(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='k3', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k4', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="k3", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:k4", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() self.assertEqual(str(lookup.k1), k1.value) @@ -49,108 +55,119 @@ def test_non_hierarchical_lookup(self): # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.k4), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + self.assertEqual(str(lookup.k4), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.k4), k4.value) def test_hierarchical_lookup_dotted(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() self.assertEqual(str(lookup.a.b), k1.value) self.assertEqual(str(lookup.a.b.c), k2.value) self.assertEqual(str(lookup.b.c), k3.value) - self.assertEqual(str(lookup.a), '') + self.assertEqual(str(lookup.a), "") # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.r.i.p), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + self.assertEqual(str(lookup.r.i.p), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.r.i.p), k4.value) def test_hierarchical_lookup_dict(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() - self.assertEqual(str(lookup['a']['b']), k1.value) - self.assertEqual(str(lookup['a']['b']['c']), k2.value) - self.assertEqual(str(lookup['b']['c']), k3.value) - self.assertEqual(str(lookup['a']), '') + self.assertEqual(str(lookup["a"]["b"]), k1.value) + self.assertEqual(str(lookup["a"]["b"]["c"]), k2.value) + self.assertEqual(str(lookup["b"]["c"]), k3.value) + self.assertEqual(str(lookup["a"]), "") # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup['r']['i']['p']), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup['r']['i']['p']), k4.value) + self.assertEqual(str(lookup["r"]["i"]["p"]), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup["r"]["i"]["p"]), k4.value) def test_lookups_older_scope_names_backward_compatibility(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1', - scope=FULL_SYSTEM_SCOPE)) + k1 = KeyValuePair.add_or_update( + KeyValuePairDB(name="a.b", value="v1", scope=FULL_SYSTEM_SCOPE) + ) lookup = KeyValueLookup(scope=SYSTEM_SCOPE) - self.assertEqual(str(lookup['a']['b']), k1.value) + self.assertEqual(str(lookup["a"]["b"]), k1.value) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) - user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup['r']['i']['p']), k2.value) + k2 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) + user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup["r"]["i"]["p"]), k2.value) def test_user_scope_lookups_dot_in_user(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='first.last:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) - lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='first.last') - self.assertEqual(str(lookup.r.i.p), 'v4') - self.assertEqual(str(lookup['r']['i']['p']), 'v4') + KeyValuePair.add_or_update( + KeyValuePairDB(name="first.last:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) + lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="first.last") + self.assertEqual(str(lookup.r.i.p), "v4") + self.assertEqual(str(lookup["r"]["i"]["p"]), "v4") def test_user_scope_lookups_user_sep_in_name(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r:i:p', value='v4', - scope=FULL_USER_SCOPE)) - lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r:i:p", value="v4", scope=FULL_USER_SCOPE) + ) + lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") # This is the only way to lookup because USER_SEPARATOR (':') cannot be a part of # variable name in Python. - self.assertEqual(str(lookup['r:i:p']), 'v4') + self.assertEqual(str(lookup["r:i:p"]), "v4") def test_missing_key_lookup(self): lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.missing_key), '') - self.assertTrue(lookup.missing_key, 'Should be not none.') + self.assertEqual(str(lookup.missing_key), "") + self.assertTrue(lookup.missing_key, "Should be not none.") - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup.missing_key), '') - self.assertTrue(user_lookup.missing_key, 'Should be not none.') + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup.missing_key), "") + self.assertTrue(user_lookup.missing_key, "Should be not none.") def test_secret_lookup(self): - secret_value = '0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3' + \ - '64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638' - k1 = KeyValuePair.add_or_update(KeyValuePairDB( - name='k1', value=secret_value, - secret=True) + secret_value = ( + "0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3" + + "64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638" + ) + k1 = KeyValuePair.add_or_update( + KeyValuePairDB(name="k1", value=secret_value, secret=True) ) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB( - name='stanley:k3', value=secret_value, scope=FULL_USER_SCOPE, - secret=True) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) + k3 = KeyValuePair.add_or_update( + KeyValuePairDB( + name="stanley:k3", + value=secret_value, + scope=FULL_USER_SCOPE, + secret=True, + ) ) lookup = KeyValueLookup() self.assertEqual(str(lookup.k1), k1.value) self.assertEqual(str(lookup.k2), k2.value) - self.assertEqual(str(lookup.k3), '') + self.assertEqual(str(lookup.k3), "") - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.k3), k3.value) def test_lookup_cast(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='count', value='5.5')) + KeyValuePair.add_or_update(KeyValuePairDB(name="count", value="5.5")) lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.count), '5.5') + self.assertEqual(str(lookup.count), "5.5") self.assertEqual(float(lookup.count), 5.5) self.assertEqual(int(lookup.count), 5) diff --git a/st2common/tests/unit/test_keyvalue_system_model.py b/st2common/tests/unit/test_keyvalue_system_model.py index ff834f2d6d..a8ea10822b 100644 --- a/st2common/tests/unit/test_keyvalue_system_model.py +++ b/st2common/tests/unit/test_keyvalue_system_model.py @@ -21,15 +21,19 @@ class UserKeyReferenceSystemModelTest(unittest2.TestCase): - def test_to_string_reference(self): - key_ref = UserKeyReference.to_string_reference(user='stanley', name='foo') - self.assertEqual(key_ref, 'stanley:foo') - self.assertRaises(ValueError, UserKeyReference.to_string_reference, user=None, name='foo') + key_ref = UserKeyReference.to_string_reference(user="stanley", name="foo") + self.assertEqual(key_ref, "stanley:foo") + self.assertRaises( + ValueError, UserKeyReference.to_string_reference, user=None, name="foo" + ) def test_from_string_reference(self): - user, name = UserKeyReference.from_string_reference('stanley:foo') - self.assertEqual(user, 'stanley') - self.assertEqual(name, 'foo') - self.assertRaises(InvalidUserKeyReferenceError, UserKeyReference.from_string_reference, - 'this_key_has_no_sep') + user, name = UserKeyReference.from_string_reference("stanley:foo") + self.assertEqual(user, "stanley") + self.assertEqual(name, "foo") + self.assertRaises( + InvalidUserKeyReferenceError, + UserKeyReference.from_string_reference, + "this_key_has_no_sep", + ) diff --git a/st2common/tests/unit/test_logger.py b/st2common/tests/unit/test_logger.py index 30d18e9f89..79158b8b7f 100644 --- a/st2common/tests/unit/test_logger.py +++ b/st2common/tests/unit/test_logger.py @@ -36,13 +36,13 @@ import st2tests.config as tests_config CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) -CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, 'logging.conf') +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) +CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, "logging.conf") MOCK_MASKED_ATTRIBUTES_BLACKLIST = [ - 'blacklisted_1', - 'blacklisted_2', - 'blacklisted_3', + "blacklisted_1", + "blacklisted_2", + "blacklisted_3", ] @@ -69,9 +69,8 @@ def setUp(self): self.cfg_fd, self.cfg_path = tempfile.mkstemp() self.info_log_fd, self.info_log_path = tempfile.mkstemp() self.audit_log_fd, self.audit_log_path = tempfile.mkstemp() - with open(self.cfg_path, 'a') as f: - f.write(self.config_text.format(self.info_log_path, - self.audit_log_path)) + with open(self.cfg_path, "a") as f: + f.write(self.config_text.format(self.info_log_path, self.audit_log_path)) def tearDown(self): self._remove_tempfile(self.cfg_fd, self.cfg_path) @@ -84,7 +83,7 @@ def _remove_tempfile(self, fd, path): os.unlink(path) def test_logger_setup_failure(self): - config_file = '/tmp/abc123' + config_file = "/tmp/abc123" self.assertFalse(os.path.exists(config_file)) self.assertRaises(Exception, logging.setup, config_file) @@ -146,7 +145,7 @@ def test_format(self): formatter = ConsoleLogFormatter() # No extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message @@ -155,94 +154,109 @@ def test_format(self): self.assertEqual(message, mock_message) # Some extra attributes - mock_message = 'test message 2' + mock_message = "test message 2" record = MockRecord() record.msg = mock_message # Add "extra" attributes record._user_id = 1 - record._value = 'bar' - record.ignored = 'foo' # this one is ignored since it doesnt have a prefix + record._value = "bar" + record.ignored = "foo" # this one is ignored since it doesnt have a prefix message = formatter.format(record=record) - expected = 'test message 2 (value=\'bar\',user_id=1)' + expected = "test message 2 (value='bar',user_id=1)" self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_blacklisted_attributes_are_masked(self): formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._foo1 = "bar" message = formatter.format(record=record) - expected = ("test message 1 (blacklisted_1='********',blacklisted_2='********'," - "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," - "foo1='bar')") + expected = ( + "test message 1 (blacklisted_1='********',blacklisted_2='********'," + "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," + "foo1='bar')" + ) self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_custom_blacklist_attributes_are_masked(self): - cfg.CONF.set_override(group='log', name='mask_secrets_blacklist', - override=['blacklisted_4', 'blacklisted_5']) + cfg.CONF.set_override( + group="log", + name="mask_secrets_blacklist", + override=["blacklisted_4", "blacklisted_5"], + ) formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._blacklisted_4 = 'fowa' - record._blacklisted_5 = 'fiva' - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._blacklisted_4 = "fowa" + record._blacklisted_5 = "fiva" + record._foo1 = "bar" message = formatter.format(record=record) - expected = ("test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********'," - "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," - "blacklisted_4='********',blacklisted_5='********')") + expected = ( + "test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********'," + "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," + "blacklisted_4='********',blacklisted_5='********')" + ) self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_secret_action_parameters_are_masked(self): formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" parameters = { - 'parameter1': { - 'type': 'string', - 'required': False - }, - 'parameter2': { - 'type': 'string', - 'required': False, - 'secret': True - } + "parameter1": {"type": "string", "required": False}, + "parameter2": {"type": "string", "required": False, "secret": True}, } - mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters) + mock_action_db = ActionDB( + pack="testpack", name="test.action", parameters=parameters + ) action = mock_action_db.to_serializable_dict() - parameters = { - 'parameter1': 'value1', - 'parameter2': 'value2' - } - mock_action_execution_db = ActionExecutionDB(action=action, parameters=parameters) + parameters = {"parameter1": "value1", "parameter2": "value2"} + mock_action_execution_db = ActionExecutionDB( + action=action, parameters=parameters + ) record = MockRecord() record.msg = mock_message @@ -250,97 +264,94 @@ def test_format_secret_action_parameters_are_masked(self): # Add "extra" attributes record._action_execution_db = mock_action_execution_db - expected_msg_part = (r"'parameters': {u?'parameter1': u?'value1', " - r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}") + expected_msg_part = ( + r"'parameters': {u?'parameter1': u?'value1', " + r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}" + ) message = formatter.format(record=record) - self.assertIn('test message 1', message) + self.assertIn("test message 1", message) self.assertRegexpMatches(message, expected_msg_part) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_rule(self): expected_result = { - 'description': 'Test description', - 'tags': [], - 'type': { - 'ref': 'standard', - 'parameters': {}}, - 'enabled': True, - 'trigger': 'test tigger', - 'metadata_file': None, - 'context': {}, - 'criteria': {}, - 'action': { - 'ref': '1234', - 'parameters': {'b': 2}}, - 'uid': 'rule:testpack:test.action', - 'pack': 'testpack', - 'ref': 'testpack.test.action', - 'id': None, - 'name': 'test.action' + "description": "Test description", + "tags": [], + "type": {"ref": "standard", "parameters": {}}, + "enabled": True, + "trigger": "test tigger", + "metadata_file": None, + "context": {}, + "criteria": {}, + "action": {"ref": "1234", "parameters": {"b": 2}}, + "uid": "rule:testpack:test.action", + "pack": "testpack", + "ref": "testpack.test.action", + "id": None, + "name": "test.action", } - mock_rule_db = RuleDB(pack='testpack', - name='test.action', - description='Test description', - trigger='test tigger', - action={'ref': '1234', 'parameters': {'b': 2}}) + mock_rule_db = RuleDB( + pack="testpack", + name="test.action", + description="Test description", + trigger="test tigger", + action={"ref": "1234", "parameters": {"b": 2}}, + ) result = mock_rule_db.to_serializable_dict() self.assertEqual(expected_result, result) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) - @mock.patch('st2common.models.db.rule.RuleDB._get_referenced_action_model') - def test_format_secret_rule_parameters_are_masked(self, mock__get_referenced_action_model): + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) + @mock.patch("st2common.models.db.rule.RuleDB._get_referenced_action_model") + def test_format_secret_rule_parameters_are_masked( + self, mock__get_referenced_action_model + ): expected_result = { - 'description': 'Test description', - 'tags': [], - 'type': { - 'ref': 'standard', - 'parameters': {}}, - 'enabled': True, - 'trigger': 'test tigger', - 'metadata_file': None, - 'context': {}, - 'criteria': {}, - 'action': { - 'ref': '1234', - 'parameters': { - 'parameter1': 'value1', - 'parameter2': '********' - }}, - 'uid': 'rule:testpack:test.action', - 'pack': 'testpack', - 'ref': 'testpack.test.action', - 'id': None, - 'name': 'test.action' + "description": "Test description", + "tags": [], + "type": {"ref": "standard", "parameters": {}}, + "enabled": True, + "trigger": "test tigger", + "metadata_file": None, + "context": {}, + "criteria": {}, + "action": { + "ref": "1234", + "parameters": {"parameter1": "value1", "parameter2": "********"}, + }, + "uid": "rule:testpack:test.action", + "pack": "testpack", + "ref": "testpack.test.action", + "id": None, + "name": "test.action", } parameters = { - 'parameter1': { - 'type': 'string', - 'required': False - }, - 'parameter2': { - 'type': 'string', - 'required': False, - 'secret': True - } + "parameter1": {"type": "string", "required": False}, + "parameter2": {"type": "string", "required": False, "secret": True}, } - mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters) + mock_action_db = ActionDB( + pack="testpack", name="test.action", parameters=parameters + ) mock__get_referenced_action_model.return_value = mock_action_db - cfg.CONF.set_override(group='log', name='mask_secrets', - override=True) - mock_rule_db = RuleDB(pack='testpack', - name='test.action', - description='Test description', - trigger='test tigger', - action={'ref': '1234', - 'parameters': { - 'parameter1': 'value1', - 'parameter2': 'value2' - }}) + cfg.CONF.set_override(group="log", name="mask_secrets", override=True) + mock_rule_db = RuleDB( + pack="testpack", + name="test.action", + description="Test description", + trigger="test tigger", + action={ + "ref": "1234", + "parameters": {"parameter1": "value1", "parameter2": "value2"}, + }, + ) result = mock_rule_db.to_serializable_dict(True) @@ -355,11 +366,18 @@ def setUpClass(cls): def test_format(self): formatter = GelfLogFormatter() - expected_keys = ['version', 'host', 'short_message', 'full_message', - 'timestamp', 'timestamp_f', 'level'] + expected_keys = [ + "version", + "host", + "short_message", + "full_message", + "timestamp", + "timestamp_f", + "level", + ] # No extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message @@ -370,19 +388,19 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertEqual(parsed['full_message'], mock_message) + self.assertEqual(parsed["short_message"], mock_message) + self.assertEqual(parsed["full_message"], mock_message) # Some extra attributes - mock_message = 'test message 2' + mock_message = "test message 2" record = MockRecord() record.msg = mock_message # Add "extra" attributes record._user_id = 1 - record._value = 'bar' - record.ignored = 'foo' # this one is ignored since it doesnt have a prefix + record._value = "bar" + record.ignored = "foo" # this one is ignored since it doesnt have a prefix record.created = 1234.5678 message = formatter.format(record=record) @@ -391,16 +409,16 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertEqual(parsed['full_message'], mock_message) - self.assertEqual(parsed['_user_id'], 1) - self.assertEqual(parsed['_value'], 'bar') - self.assertEqual(parsed['timestamp'], 1234) - self.assertEqual(parsed['timestamp_f'], 1234.5678) - self.assertNotIn('ignored', parsed) + self.assertEqual(parsed["short_message"], mock_message) + self.assertEqual(parsed["full_message"], mock_message) + self.assertEqual(parsed["_user_id"], 1) + self.assertEqual(parsed["_value"], "bar") + self.assertEqual(parsed["timestamp"], 1234) + self.assertEqual(parsed["timestamp_f"], 1234.5678) + self.assertNotIn("ignored", parsed) # Record with an exception - mock_exception = Exception('mock exception bar') + mock_exception = Exception("mock exception bar") try: raise mock_exception @@ -408,7 +426,7 @@ def test_format(self): mock_exc_info = sys.exc_info() # Some extra attributes - mock_message = 'test message 3' + mock_message = "test message 3" record = MockRecord() record.msg = mock_message @@ -420,69 +438,77 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertIn(mock_message, parsed['full_message']) - self.assertIn('Traceback', parsed['full_message']) - self.assertIn('_exception', parsed) - self.assertIn('_traceback', parsed) + self.assertEqual(parsed["short_message"], mock_message) + self.assertIn(mock_message, parsed["full_message"]) + self.assertIn("Traceback", parsed["full_message"]) + self.assertIn("_exception", parsed) + self.assertIn("_traceback", parsed) def test_extra_object_serialization(self): class MyClass1(object): def __repr__(self): - return 'repr' + return "repr" class MyClass2(object): def to_dict(self): - return 'to_dict' + return "to_dict" class MyClass3(object): def to_serializable_dict(self, mask_secrets=False): - return 'to_serializable_dict' + return "to_serializable_dict" formatter = GelfLogFormatter() record = MockRecord() - record.msg = 'message' + record.msg = "message" record._obj1 = MyClass1() record._obj2 = MyClass2() record._obj3 = MyClass3() message = formatter.format(record=record) parsed = json.loads(message) - self.assertEqual(parsed['_obj1'], 'repr') - self.assertEqual(parsed['_obj2'], 'to_dict') - self.assertEqual(parsed['_obj3'], 'to_serializable_dict') - - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + self.assertEqual(parsed["_obj1"], "repr") + self.assertEqual(parsed["_obj2"], "to_dict") + self.assertEqual(parsed["_obj3"], "to_serializable_dict") + + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_blacklisted_attributes_are_masked(self): formatter = GelfLogFormatter() # Some extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._foo1 = "bar" message = formatter.format(record=record) parsed = json.loads(message) - self.assertEqual(parsed['_blacklisted_1'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_2'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_3']['key1'], 'val1') - self.assertEqual(parsed['_blacklisted_3']['blacklisted_1'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_3']['key3'], 'val3') - self.assertEqual(parsed['_foo1'], 'bar') + self.assertEqual(parsed["_blacklisted_1"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(parsed["_blacklisted_2"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(parsed["_blacklisted_3"]["key1"], "val1") + self.assertEqual( + parsed["_blacklisted_3"]["blacklisted_1"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(parsed["_blacklisted_3"]["key3"], "val3") + self.assertEqual(parsed["_foo1"], "bar") # Assert that the original dict is left unmodified - self.assertEqual(record._blacklisted_1, 'test value 1') - self.assertEqual(record._blacklisted_2, 'test value 2') - self.assertEqual(record._blacklisted_3['key1'], 'val1') - self.assertEqual(record._blacklisted_3['blacklisted_1'], 'val2') - self.assertEqual(record._blacklisted_3['key3'], 'val3') + self.assertEqual(record._blacklisted_1, "test value 1") + self.assertEqual(record._blacklisted_2, "test value 2") + self.assertEqual(record._blacklisted_3["key1"], "val1") + self.assertEqual(record._blacklisted_3["blacklisted_1"], "val2") + self.assertEqual(record._blacklisted_3["key3"], "val3") diff --git a/st2common/tests/unit/test_logging.py b/st2common/tests/unit/test_logging.py index 7dc4fc1b6d..ebb75b6f1d 100644 --- a/st2common/tests/unit/test_logging.py +++ b/st2common/tests/unit/test_logging.py @@ -21,25 +21,29 @@ from python_runner import python_runner from st2common import runners -__all__ = [ - 'LoggingMiscUtilsTestCase' -] +__all__ = ["LoggingMiscUtilsTestCase"] class LoggingMiscUtilsTestCase(unittest2.TestCase): def test_get_logger_name_for_module(self): logger_name = get_logger_name_for_module(sensormanager) - self.assertEqual(logger_name, 'st2reactor.cmd.sensormanager') + self.assertEqual(logger_name, "st2reactor.cmd.sensormanager") logger_name = get_logger_name_for_module(python_runner) - result = logger_name.endswith('contrib.runners.python_runner.python_runner.python_runner') + result = logger_name.endswith( + "contrib.runners.python_runner.python_runner.python_runner" + ) self.assertTrue(result) - logger_name = get_logger_name_for_module(python_runner, exclude_module_name=True) - self.assertTrue(logger_name.endswith('contrib.runners.python_runner.python_runner')) + logger_name = get_logger_name_for_module( + python_runner, exclude_module_name=True + ) + self.assertTrue( + logger_name.endswith("contrib.runners.python_runner.python_runner") + ) logger_name = get_logger_name_for_module(runners) - self.assertEqual(logger_name, 'st2common.runners.__init__') + self.assertEqual(logger_name, "st2common.runners.__init__") logger_name = get_logger_name_for_module(runners, exclude_module_name=True) - self.assertEqual(logger_name, 'st2common.runners') + self.assertEqual(logger_name, "st2common.runners") diff --git a/st2common/tests/unit/test_logging_middleware.py b/st2common/tests/unit/test_logging_middleware.py index 8a59177beb..b7d34de0bc 100644 --- a/st2common/tests/unit/test_logging_middleware.py +++ b/st2common/tests/unit/test_logging_middleware.py @@ -21,18 +21,15 @@ from st2common.middleware.logging import LoggingMiddleware from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE -__all__ = [ - 'LoggingMiddlewareTestCase' -] +__all__ = ["LoggingMiddlewareTestCase"] class LoggingMiddlewareTestCase(unittest2.TestCase): - @mock.patch('st2common.middleware.logging.LOG') - @mock.patch('st2common.middleware.logging.Request') + @mock.patch("st2common.middleware.logging.LOG") + @mock.patch("st2common.middleware.logging.Request") def test_secret_parameters_are_masked_in_log_message(self, mock_request, mock_log): - def app(environ, custom_start_response): - custom_start_response(status='200 OK', headers=[('Content-Length', 100)]) + custom_start_response(status="200 OK", headers=[("Content-Length", 100)]) return [None] router = mock.Mock() @@ -40,35 +37,38 @@ def app(environ, custom_start_response): router.match.return_value = (endpoint, None) middleware = LoggingMiddleware(app=app, router=router) - cfg.CONF.set_override(group='log', name='mask_secrets_blacklist', - override=['blacklisted_4', 'blacklisted_5']) + cfg.CONF.set_override( + group="log", + name="mask_secrets_blacklist", + override=["blacklisted_4", "blacklisted_5"], + ) environ = {} mock_request.return_value.GET.dict_of_lists.return_value = { - 'foo': 'bar', - 'bar': 'baz', - 'x-auth-token': 'secret', - 'st2-api-key': 'secret', - 'password': 'secret', - 'st2_auth_token': 'secret', - 'token': 'secret', - 'blacklisted_4': 'super secret', - 'blacklisted_5': 'super secret', + "foo": "bar", + "bar": "baz", + "x-auth-token": "secret", + "st2-api-key": "secret", + "password": "secret", + "st2_auth_token": "secret", + "token": "secret", + "blacklisted_4": "super secret", + "blacklisted_5": "super secret", } middleware(environ=environ, start_response=mock.Mock()) expected_query = { - 'foo': 'bar', - 'bar': 'baz', - 'x-auth-token': MASKED_ATTRIBUTE_VALUE, - 'st2-api-key': MASKED_ATTRIBUTE_VALUE, - 'password': MASKED_ATTRIBUTE_VALUE, - 'token': MASKED_ATTRIBUTE_VALUE, - 'st2_auth_token': MASKED_ATTRIBUTE_VALUE, - 'blacklisted_4': MASKED_ATTRIBUTE_VALUE, - 'blacklisted_5': MASKED_ATTRIBUTE_VALUE, + "foo": "bar", + "bar": "baz", + "x-auth-token": MASKED_ATTRIBUTE_VALUE, + "st2-api-key": MASKED_ATTRIBUTE_VALUE, + "password": MASKED_ATTRIBUTE_VALUE, + "token": MASKED_ATTRIBUTE_VALUE, + "st2_auth_token": MASKED_ATTRIBUTE_VALUE, + "blacklisted_4": MASKED_ATTRIBUTE_VALUE, + "blacklisted_5": MASKED_ATTRIBUTE_VALUE, } call_kwargs = mock_log.info.call_args_list[0][1] - query = call_kwargs['extra']['query'] + query = call_kwargs["extra"]["query"] self.assertEqual(query, expected_query) diff --git a/st2common/tests/unit/test_metrics.py b/st2common/tests/unit/test_metrics.py index 084db97c62..4b0df66aa1 100644 --- a/st2common/tests/unit/test_metrics.py +++ b/st2common/tests/unit/test_metrics.py @@ -29,16 +29,16 @@ from st2common.util.date import get_datetime_utc_now __all__ = [ - 'TestBaseMetricsDriver', - 'TestStatsDMetricsDriver', - 'TestCounterContextManager', - 'TestTimerContextManager', - 'TestCounterWithTimerContextManager' + "TestBaseMetricsDriver", + "TestStatsDMetricsDriver", + "TestCounterContextManager", + "TestTimerContextManager", + "TestCounterWithTimerContextManager", ] -cfg.CONF.set_override('driver', 'noop', group='metrics') -cfg.CONF.set_override('host', '127.0.0.1', group='metrics') -cfg.CONF.set_override('port', 8080, group='metrics') +cfg.CONF.set_override("driver", "noop", group="metrics") +cfg.CONF.set_override("host", "127.0.0.1", group="metrics") +cfg.CONF.set_override("port", 8080, group="metrics") class TestBaseMetricsDriver(unittest2.TestCase): @@ -48,45 +48,43 @@ def setUp(self): self._driver = base.BaseMetricsDriver() def test_time(self): - self._driver.time('test', 10) + self._driver.time("test", 10) def test_inc_counter(self): - self._driver.inc_counter('test') + self._driver.inc_counter("test") def test_dec_timer(self): - self._driver.dec_counter('test') + self._driver.dec_counter("test") class TestStatsDMetricsDriver(unittest2.TestCase): _driver = None - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def setUp(self, statsd): - cfg.CONF.set_override(name='prefix', group='metrics', override=None) + cfg.CONF.set_override(name="prefix", group="metrics", override=None) self._driver = StatsdDriver() statsd.Connection.set_defaults.assert_called_once_with( - host=cfg.CONF.metrics.host, - port=cfg.CONF.metrics.port, - sample_rate=1.0 + host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port, sample_rate=1.0 ) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_time(self, statsd): mock_timer = MagicMock() - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10) self._driver.time(*params) - statsd.Timer('').send.assert_called_with('st2.test', 10) + statsd.Timer("").send.assert_called_with("st2.test", 10) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_time_with_float(self, statsd): mock_timer = MagicMock() - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10.5) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10.5) self._driver.time(*params) - statsd.Timer().send.assert_called_with('st2.test', 10.5) + statsd.Timer().send.assert_called_with("st2.test", 10.5) def test_time_with_invalid_key(self): params = (2, 2) @@ -94,21 +92,21 @@ def test_time_with_invalid_key(self): self._driver.time(*params) def test_time_with_invalid_time(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.time(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_counter_with_default_amount(self, statsd): - key = 'test' + key = "test" mock_counter = MagicMock() statsd.Counter(key).increment.side_effect = mock_counter self._driver.inc_counter(key) mock_counter.assert_called_once_with(delta=1) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_counter_with_amount(self, statsd): - params = ('test', 2) + params = ("test", 2) mock_counter = MagicMock() statsd.Counter(params[0]).increment.side_effect = mock_counter self._driver.inc_counter(*params) @@ -120,21 +118,21 @@ def test_inc_timer_with_invalid_key(self): self._driver.inc_counter(*params) def test_inc_timer_with_invalid_amount(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.inc_counter(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_timer_with_default_amount(self, statsd): - key = 'test' + key = "test" mock_counter = MagicMock() statsd.Counter().decrement.side_effect = mock_counter self._driver.dec_counter(key) mock_counter.assert_called_once_with(delta=1) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_timer_with_amount(self, statsd): - params = ('test', 2) + params = ("test", 2) mock_counter = MagicMock() statsd.Counter().decrement.side_effect = mock_counter self._driver.dec_counter(*params) @@ -146,41 +144,41 @@ def test_dec_timer_with_invalid_key(self): self._driver.dec_counter(*params) def test_dec_timer_with_invalid_amount(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.dec_counter(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_set_gauge_success(self, statsd): - params = ('key', 100) + params = ("key", 100) mock_gauge = MagicMock() statsd.Gauge().send.side_effect = mock_gauge self._driver.set_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_gauge_success(self, statsd): - params = ('key1',) + params = ("key1",) mock_gauge = MagicMock() statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key2', 100) + params = ("key2", 100) mock_gauge = MagicMock() statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_gauge_success(self, statsd): - params = ('key1',) + params = ("key1",) mock_gauge = MagicMock() statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key2', 100) + params = ("key2", 100) mock_gauge = MagicMock() statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) @@ -188,71 +186,71 @@ def test_dec_gauge_success(self, statsd): def test_get_full_key_name(self): # No prefix specified in the config - cfg.CONF.set_override(name='prefix', group='metrics', override=None) + cfg.CONF.set_override(name="prefix", group="metrics", override=None) - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.api.requests") # Prefix is defined in the config - cfg.CONF.set_override(name='prefix', group='metrics', override='staging') + cfg.CONF.set_override(name="prefix", group="metrics", override="staging") - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.staging.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.staging.api.requests") - cfg.CONF.set_override(name='prefix', group='metrics', override='prod') + cfg.CONF.set_override(name="prefix", group="metrics", override="prod") - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.prod.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.prod.api.requests") - @patch('st2common.metrics.drivers.statsd_driver.LOG') - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.LOG") + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_driver_socket_exceptions_are_not_fatal(self, statsd, mock_log): # Socket errors such as DNS resolution errors should be considered non fatal and ignored mock_logger = mock.Mock() StatsdDriver.logger = mock_logger # 1. timer - mock_timer = MagicMock(side_effect=socket.error('error 1')) - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10) + mock_timer = MagicMock(side_effect=socket.error("error 1")) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10) self._driver.time(*params) - statsd.Timer('').send.assert_called_with('st2.test', 10) + statsd.Timer("").send.assert_called_with("st2.test", 10) # 2. counter - key = 'test' - mock_counter = MagicMock(side_effect=socket.error('error 2')) + key = "test" + mock_counter = MagicMock(side_effect=socket.error("error 2")) statsd.Counter(key).increment.side_effect = mock_counter self._driver.inc_counter(key) mock_counter.assert_called_once_with(delta=1) - key = 'test' - mock_counter = MagicMock(side_effect=socket.error('error 3')) + key = "test" + mock_counter = MagicMock(side_effect=socket.error("error 3")) statsd.Counter(key).decrement.side_effect = mock_counter self._driver.dec_counter(key) mock_counter.assert_called_once_with(delta=1) # 3. gauge - params = ('key', 100) - mock_gauge = MagicMock(side_effect=socket.error('error 4')) + params = ("key", 100) + mock_gauge = MagicMock(side_effect=socket.error("error 4")) statsd.Gauge().send.side_effect = mock_gauge self._driver.set_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - params = ('key1',) - mock_gauge = MagicMock(side_effect=socket.error('error 5')) + params = ("key1",) + mock_gauge = MagicMock(side_effect=socket.error("error 5")) statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key1',) - mock_gauge = MagicMock(side_effect=socket.error('error 6')) + params = ("key1",) + mock_gauge = MagicMock(side_effect=socket.error("error 6")) statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) mock_gauge.assert_called_once_with(None, 1) class TestCounterContextManager(unittest2.TestCase): - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.METRICS") def test_counter(self, metrics_patch): test_key = "test_key" with base.Counter(test_key): @@ -261,8 +259,8 @@ def test_counter(self, metrics_patch): class TestTimerContextManager(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -272,7 +270,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" with base.Timer(test_key) as timer: @@ -280,23 +278,19 @@ def test_time(self, metrics_patch, datetime_patch): metrics_patch.time.assert_not_called() timer.send_time() metrics_patch.time.assert_called_with( - test_key, - (end_time - middle_time).total_seconds() + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_toes" timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) metrics_patch.time.assert_called_with( - test_key, - (end_time - start_time).total_seconds() + test_key, (end_time - start_time).total_seconds() ) @@ -306,46 +300,44 @@ def setUp(self): self.middle_time = self.start_time + timedelta(seconds=1) self.end_time = self.middle_time + timedelta(seconds=1) - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): datetime_patch.side_effect = [ self.start_time, self.middle_time, self.middle_time, self.middle_time, - self.end_time + self.end_time, ] test_key = "test_key" with base.CounterWithTimer(test_key) as timer: self.assertIsInstance(timer._start_time, datetime) metrics_patch.time.assert_not_called() timer.send_time() - metrics_patch.time.assert_called_with(test_key, - (self.end_time - self.middle_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (self.end_time - self.middle_time).total_seconds() ) second_test_key = "lakshmi_has_a_nose" timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (self.end_time - self.middle_time).total_seconds() + second_test_key, (self.end_time - self.middle_time).total_seconds() ) time_delta = timer.get_time_delta() self.assertEqual( time_delta.total_seconds(), - (self.end_time - self.middle_time).total_seconds() + (self.end_time - self.middle_time).total_seconds(), ) metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() metrics_patch.time.assert_called_with( - test_key, - (self.end_time - self.start_time).total_seconds() + test_key, (self.end_time - self.start_time).total_seconds() ) class TestCounterWithTimerDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -355,7 +347,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" @@ -364,32 +356,30 @@ def _get_tested(metrics_counter_with_timer=None): self.assertIsInstance(metrics_counter_with_timer._start_time, datetime) metrics_patch.time.assert_not_called() metrics_counter_with_timer.send_time() - metrics_patch.time.assert_called_with(test_key, - (end_time - middle_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_a_nose" metrics_counter_with_timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = metrics_counter_with_timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() _get_tested() - metrics_patch.time.assert_called_with(test_key, - (end_time - start_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (end_time - start_time).total_seconds() ) class TestCounterDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.METRICS") def test_counter(self, metrics_patch): test_key = "test_key" @@ -397,12 +387,13 @@ def test_counter(self, metrics_patch): def _get_tested(): metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() + _get_tested() class TestTimerDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -412,7 +403,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" @@ -422,22 +413,19 @@ def _get_tested(metrics_timer=None): metrics_patch.time.assert_not_called() metrics_timer.send_time() metrics_patch.time.assert_called_with( - test_key, - (end_time - middle_time).total_seconds() + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_toes" metrics_timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = metrics_timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) + _get_tested() metrics_patch.time.assert_called_with( - test_key, - (end_time - start_time).total_seconds() + test_key, (end_time - start_time).total_seconds() ) diff --git a/st2common/tests/unit/test_misc_utils.py b/st2common/tests/unit/test_misc_utils.py index d7008e921c..f05615573b 100644 --- a/st2common/tests/unit/test_misc_utils.py +++ b/st2common/tests/unit/test_misc_utils.py @@ -24,71 +24,61 @@ from st2common.util.misc import sanitize_output from st2common.util.ujson import fast_deepcopy -__all__ = [ - 'MiscUtilTestCase' -] +__all__ = ["MiscUtilTestCase"] class MiscUtilTestCase(unittest2.TestCase): def test_rstrip_last_char(self): - self.assertEqual(rstrip_last_char(None, '\n'), None) - self.assertEqual(rstrip_last_char('stuff', None), 'stuff') - self.assertEqual(rstrip_last_char('', '\n'), '') - self.assertEqual(rstrip_last_char('foo', '\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\n', '\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\n\n', '\n'), 'foo\n') - self.assertEqual(rstrip_last_char('foo\r', '\r'), 'foo') - self.assertEqual(rstrip_last_char('foo\r\r', '\r'), 'foo\r') - self.assertEqual(rstrip_last_char('foo\r\n', '\r\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\r\r\n', '\r\n'), 'foo\r') - self.assertEqual(rstrip_last_char('foo\n\r', '\r\n'), 'foo\n\r') + self.assertEqual(rstrip_last_char(None, "\n"), None) + self.assertEqual(rstrip_last_char("stuff", None), "stuff") + self.assertEqual(rstrip_last_char("", "\n"), "") + self.assertEqual(rstrip_last_char("foo", "\n"), "foo") + self.assertEqual(rstrip_last_char("foo\n", "\n"), "foo") + self.assertEqual(rstrip_last_char("foo\n\n", "\n"), "foo\n") + self.assertEqual(rstrip_last_char("foo\r", "\r"), "foo") + self.assertEqual(rstrip_last_char("foo\r\r", "\r"), "foo\r") + self.assertEqual(rstrip_last_char("foo\r\n", "\r\n"), "foo") + self.assertEqual(rstrip_last_char("foo\r\r\n", "\r\n"), "foo\r") + self.assertEqual(rstrip_last_char("foo\n\r", "\r\n"), "foo\n\r") def test_strip_shell_chars(self): self.assertEqual(strip_shell_chars(None), None) - self.assertEqual(strip_shell_chars('foo'), 'foo') - self.assertEqual(strip_shell_chars('foo\r'), 'foo') - self.assertEqual(strip_shell_chars('fo\ro\r'), 'fo\ro') - self.assertEqual(strip_shell_chars('foo\n'), 'foo') - self.assertEqual(strip_shell_chars('fo\no\n'), 'fo\no') - self.assertEqual(strip_shell_chars('foo\r\n'), 'foo') - self.assertEqual(strip_shell_chars('fo\no\r\n'), 'fo\no') - self.assertEqual(strip_shell_chars('foo\r\n\r\n'), 'foo\r\n') + self.assertEqual(strip_shell_chars("foo"), "foo") + self.assertEqual(strip_shell_chars("foo\r"), "foo") + self.assertEqual(strip_shell_chars("fo\ro\r"), "fo\ro") + self.assertEqual(strip_shell_chars("foo\n"), "foo") + self.assertEqual(strip_shell_chars("fo\no\n"), "fo\no") + self.assertEqual(strip_shell_chars("foo\r\n"), "foo") + self.assertEqual(strip_shell_chars("fo\no\r\n"), "fo\no") + self.assertEqual(strip_shell_chars("foo\r\n\r\n"), "foo\r\n") def test_lowercase_value(self): - value = 'TEST' - expected_value = 'test' + value = "TEST" + expected_value = "test" self.assertEqual(expected_value, lowercase_value(value=value)) - value = ['testA', 'TESTb', 'TESTC'] - expected_value = ['testa', 'testb', 'testc'] + value = ["testA", "TESTb", "TESTC"] + expected_value = ["testa", "testb", "testc"] self.assertEqual(expected_value, lowercase_value(value=value)) - value = { - 'testA': 'testB', - 'testC': 'TESTD', - 'TESTE': 'TESTE' - } - expected_value = { - 'testa': 'testb', - 'testc': 'testd', - 'teste': 'teste' - } + value = {"testA": "testB", "testC": "TESTD", "TESTE": "TESTE"} + expected_value = {"testa": "testb", "testc": "testd", "teste": "teste"} self.assertEqual(expected_value, lowercase_value(value=value)) def test_fast_deepcopy_success(self): values = [ - 'a', - u'٩(̾●̮̮̃̾•̃̾)۶', + "a", + "٩(̾●̮̮̃̾•̃̾)۶", 1, - [1, 2, '3', 'b'], - {'a': 1, 'b': '3333', 'c': 'd'}, + [1, 2, "3", "b"], + {"a": 1, "b": "3333", "c": "d"}, ] expected_values = [ - 'a', - u'٩(̾●̮̮̃̾•̃̾)۶', + "a", + "٩(̾●̮̮̃̾•̃̾)۶", 1, - [1, 2, '3', 'b'], - {'a': 1, 'b': '3333', 'c': 'd'}, + [1, 2, "3", "b"], + {"a": 1, "b": "3333", "c": "d"}, ] for value, expected_value in zip(values, expected_values): @@ -99,18 +89,18 @@ def test_fast_deepcopy_success(self): def test_sanitize_output_use_pyt_false(self): # pty is not used, \r\n shouldn't be replaced with \n input_strs = [ - 'foo', - 'foo\n', - 'foo\r\n', - 'foo\nbar\nbaz\n', - 'foo\r\nbar\r\nbaz\r\n', + "foo", + "foo\n", + "foo\r\n", + "foo\nbar\nbaz\n", + "foo\r\nbar\r\nbaz\r\n", ] expected = [ - 'foo', - 'foo', - 'foo', - 'foo\nbar\nbaz', - 'foo\r\nbar\r\nbaz', + "foo", + "foo", + "foo", + "foo\nbar\nbaz", + "foo\r\nbar\r\nbaz", ] for input_str, expected_output in zip(input_strs, expected): @@ -120,18 +110,18 @@ def test_sanitize_output_use_pyt_false(self): def test_sanitize_output_use_pyt_true(self): # pty is used, \r\n should be replaced with \n input_strs = [ - 'foo', - 'foo\n', - 'foo\r\n', - 'foo\nbar\nbaz\n', - 'foo\r\nbar\r\nbaz\r\n', + "foo", + "foo\n", + "foo\r\n", + "foo\nbar\nbaz\n", + "foo\r\nbar\r\nbaz\r\n", ] expected = [ - 'foo', - 'foo', - 'foo', - 'foo\nbar\nbaz', - 'foo\nbar\nbaz', + "foo", + "foo", + "foo", + "foo\nbar\nbaz", + "foo\nbar\nbaz", ] for input_str, expected_output in zip(input_strs, expected): diff --git a/st2common/tests/unit/test_model_utils_profiling.py b/st2common/tests/unit/test_model_utils_profiling.py index 2225e39a7e..db37039c80 100644 --- a/st2common/tests/unit/test_model_utils_profiling.py +++ b/st2common/tests/unit/test_model_utils_profiling.py @@ -28,31 +28,37 @@ def setUp(self): super(MongoDBProfilingTestCase, self).setUp() disable_profiling() - @mock.patch('st2common.models.utils.profiling.LOG') + @mock.patch("st2common.models.utils.profiling.LOG") def test_logging_profiling_is_disabled(self, mock_log): disable_profiling() - queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1) + queryset = User.query( + name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1 + ) result = log_query_and_profile_data_for_queryset(queryset=queryset) self.assertEqual(queryset, result) call_args_list = mock_log.debug.call_args_list self.assertItemsEqual(call_args_list, []) - @mock.patch('st2common.models.utils.profiling.LOG') + @mock.patch("st2common.models.utils.profiling.LOG") def test_logging_profiling_is_enabled(self, mock_log): enable_profiling() - queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1) + queryset = User.query( + name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1 + ) result = log_query_and_profile_data_for_queryset(queryset=queryset) call_args_list = mock_log.debug.call_args_list call_args = call_args_list[0][0] call_kwargs = call_args_list[0][1] - expected_result = ("db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})" - ".sort({aa: 1, bb: -1}).limit(1);") + expected_result = ( + "db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})" + ".sort({aa: 1, bb: -1}).limit(1);" + ) self.assertEqual(queryset, result) self.assertIn(expected_result, call_args[0]) - self.assertIn('mongo_query', call_kwargs['extra']) - self.assertIn('mongo_shell_query', call_kwargs['extra']) + self.assertIn("mongo_query", call_kwargs["extra"]) + self.assertIn("mongo_shell_query", call_kwargs["extra"]) def test_logging_profiling_is_enabled_non_queryset_object(self): enable_profiling() diff --git a/st2common/tests/unit/test_mongoescape.py b/st2common/tests/unit/test_mongoescape.py index 05e3b7962f..0ad12e2823 100644 --- a/st2common/tests/unit/test_mongoescape.py +++ b/st2common/tests/unit/test_mongoescape.py @@ -21,68 +21,70 @@ class TestMongoEscape(unittest.TestCase): def test_unnested(self): - field = {'k1.k1.k1': 'v1', 'k2$': 'v2', '$k3.': 'v3'} + field = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"} escaped = mongoescape.escape_chars(field) - self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': 'v1', - u'k2\uff04': 'v2', - u'\uff04k3\uff0e': 'v3'}, 'Escaping failed.') + self.assertEqual( + escaped, + {"k1\uff0ek1\uff0ek1": "v1", "k2\uff04": "v2", "\uff04k3\uff0e": "v3"}, + "Escaping failed.", + ) unescaped = mongoescape.unescape_chars(escaped) - self.assertEqual(unescaped, field, 'Unescaping failed.') + self.assertEqual(unescaped, field, "Unescaping failed.") def test_nested(self): - nested_field = {'nk1.nk1.nk1': 'v1', 'nk2$': 'v2', '$nk3.': 'v3'} - field = {'k1.k1.k1': nested_field, 'k2$': 'v2', '$k3.': 'v3'} + nested_field = {"nk1.nk1.nk1": "v1", "nk2$": "v2", "$nk3.": "v3"} + field = {"k1.k1.k1": nested_field, "k2$": "v2", "$k3.": "v3"} escaped = mongoescape.escape_chars(field) - self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': {u'\uff04nk3\uff0e': 'v3', - u'nk1\uff0enk1\uff0enk1': 'v1', - u'nk2\uff04': 'v2'}, - u'k2\uff04': 'v2', - u'\uff04k3\uff0e': 'v3'}, 'un-escaping failed.') + self.assertEqual( + escaped, + { + "k1\uff0ek1\uff0ek1": { + "\uff04nk3\uff0e": "v3", + "nk1\uff0enk1\uff0enk1": "v1", + "nk2\uff04": "v2", + }, + "k2\uff04": "v2", + "\uff04k3\uff0e": "v3", + }, + "un-escaping failed.", + ) unescaped = mongoescape.unescape_chars(escaped) - self.assertEqual(unescaped, field, 'Unescaping failed.') + self.assertEqual(unescaped, field, "Unescaping failed.") def test_unescaping_of_rule_criteria(self): # Verify that dot escaped in rule criteria is correctly escaped. # Note: In the past we used different character to escape dot in the # rule criteria. - escaped = { - u'k1\u2024k1\u2024k1': 'v1', - u'k2$': 'v2', - u'$k3\u2024': 'v3' - } - unescaped = { - 'k1.k1.k1': 'v1', - 'k2$': 'v2', - '$k3.': 'v3' - } + escaped = {"k1\u2024k1\u2024k1": "v1", "k2$": "v2", "$k3\u2024": "v3"} + unescaped = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"} result = mongoescape.unescape_chars(escaped) self.assertEqual(result, unescaped) def test_original_value(self): - field = {'k1.k2.k3': 'v1'} + field = {"k1.k2.k3": "v1"} escaped = mongoescape.escape_chars(field) - self.assertIn('k1.k2.k3', list(field.keys())) - self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys())) + self.assertIn("k1.k2.k3", list(field.keys())) + self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys())) unescaped = mongoescape.unescape_chars(escaped) - self.assertIn('k1.k2.k3', list(unescaped.keys())) - self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys())) + self.assertIn("k1.k2.k3", list(unescaped.keys())) + self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys())) def test_complex(self): field = { - 'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}], - 'k3': [{'l5.l6': '789'}], - 'k4.k5': [1, 2, 3], - 'k6': ['a', 'b'] + "k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}], + "k3": [{"l5.l6": "789"}], + "k4.k5": [1, 2, 3], + "k6": ["a", "b"], } expected = { - u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}], - 'k3': [{u'l5\uff0el6': '789'}], - u'k4\uff0ek5': [1, 2, 3], - 'k6': ['a', 'b'] + "k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}], + "k3": [{"l5\uff0el6": "789"}], + "k4\uff0ek5": [1, 2, 3], + "k6": ["a", "b"], } escaped = mongoescape.escape_chars(field) @@ -93,17 +95,17 @@ def test_complex(self): def test_complex_list(self): field = [ - {'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}]}, - {'k3': [{'l5.l6': '789'}]}, - {'k4.k5': [1, 2, 3]}, - {'k6': ['a', 'b']} + {"k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}]}, + {"k3": [{"l5.l6": "789"}]}, + {"k4.k5": [1, 2, 3]}, + {"k6": ["a", "b"]}, ] expected = [ - {u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}]}, - {'k3': [{u'l5\uff0el6': '789'}]}, - {u'k4\uff0ek5': [1, 2, 3]}, - {'k6': ['a', 'b']} + {"k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}]}, + {"k3": [{"l5\uff0el6": "789"}]}, + {"k4\uff0ek5": [1, 2, 3]}, + {"k6": ["a", "b"]}, ] escaped = mongoescape.escape_chars(field) diff --git a/st2common/tests/unit/test_notification_helper.py b/st2common/tests/unit/test_notification_helper.py index d169dd5a5f..9c00ea4771 100644 --- a/st2common/tests/unit/test_notification_helper.py +++ b/st2common/tests/unit/test_notification_helper.py @@ -20,7 +20,6 @@ class NotificationsHelperTestCase(unittest2.TestCase): - def test_model_transformations(self): notify = {} @@ -31,42 +30,56 @@ def test_model_transformations(self): notify_api = NotificationsHelper.from_model(notify_model) self.assertEqual(notify_api, {}) - notify['on-complete'] = { - 'message': 'Action completed.', - 'routes': [ - '66' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - 'baz': [1, 2, 3] - } + notify["on-complete"] = { + "message": "Action completed.", + "routes": ["66"], + "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]}, } - notify['on-success'] = { - 'message': 'Action succeeded.', - 'routes': [ - '100' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - } + notify["on-success"] = { + "message": "Action succeeded.", + "routes": ["100"], + "data": { + "foo": "{{foo}}", + "bar": 1, + }, } notify_model = NotificationsHelper.to_model(notify) - self.assertEqual(notify['on-complete']['message'], notify_model.on_complete.message) - self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data) - self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes) - self.assertEqual(notify['on-success']['message'], notify_model.on_success.message) - self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data) - self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes) + self.assertEqual( + notify["on-complete"]["message"], notify_model.on_complete.message + ) + self.assertDictEqual( + notify["on-complete"]["data"], notify_model.on_complete.data + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_model.on_complete.routes + ) + self.assertEqual( + notify["on-success"]["message"], notify_model.on_success.message + ) + self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data) + self.assertListEqual( + notify["on-success"]["routes"], notify_model.on_success.routes + ) notify_api = NotificationsHelper.from_model(notify_model) - self.assertEqual(notify['on-complete']['message'], notify_api['on-complete']['message']) - self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data']) - self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes']) - self.assertEqual(notify['on-success']['message'], notify_api['on-success']['message']) - self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data']) - self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes']) + self.assertEqual( + notify["on-complete"]["message"], notify_api["on-complete"]["message"] + ) + self.assertDictEqual( + notify["on-complete"]["data"], notify_api["on-complete"]["data"] + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_api["on-complete"]["routes"] + ) + self.assertEqual( + notify["on-success"]["message"], notify_api["on-success"]["message"] + ) + self.assertDictEqual( + notify["on-success"]["data"], notify_api["on-success"]["data"] + ) + self.assertListEqual( + notify["on-success"]["routes"], notify_api["on-success"]["routes"] + ) def test_model_transformations_missing_fields(self): notify = {} @@ -78,33 +91,39 @@ def test_model_transformations_missing_fields(self): notify_api = NotificationsHelper.from_model(notify_model) self.assertEqual(notify_api, {}) - notify['on-complete'] = { - 'routes': [ - '66' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - 'baz': [1, 2, 3] - } + notify["on-complete"] = { + "routes": ["66"], + "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]}, } - notify['on-success'] = { - 'routes': [ - '100' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - } + notify["on-success"] = { + "routes": ["100"], + "data": { + "foo": "{{foo}}", + "bar": 1, + }, } notify_model = NotificationsHelper.to_model(notify) - self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data) - self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes) - self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data) - self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes) + self.assertDictEqual( + notify["on-complete"]["data"], notify_model.on_complete.data + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_model.on_complete.routes + ) + self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data) + self.assertListEqual( + notify["on-success"]["routes"], notify_model.on_success.routes + ) notify_api = NotificationsHelper.from_model(notify_model) - self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data']) - self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes']) - self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data']) - self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes']) + self.assertDictEqual( + notify["on-complete"]["data"], notify_api["on-complete"]["data"] + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_api["on-complete"]["routes"] + ) + self.assertDictEqual( + notify["on-success"]["data"], notify_api["on-success"]["data"] + ) + self.assertListEqual( + notify["on-success"]["routes"], notify_api["on-success"]["routes"] + ) diff --git a/st2common/tests/unit/test_operators.py b/st2common/tests/unit/test_operators.py index 48f693af30..5917e4277c 100644 --- a/st2common/tests/unit/test_operators.py +++ b/st2common/tests/unit/test_operators.py @@ -44,6 +44,7 @@ class ListOfDictsStrictEqualTest(unittest2.TestCase): We should test our comparison functions, even if they're only used in our other tests. """ + def test_empty_lists(self): self.assertTrue(list_of_dicts_strict_equal([], [])) @@ -54,65 +55,105 @@ def test_multiple_empty_dicts(self): self.assertTrue(list_of_dicts_strict_equal([{}, {}], [{}, {}])) def test_simple_dicts(self): - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 1}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 2}, - ])) + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 1}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 2}, + ], + ) + ) def test_lists_of_different_lengths(self): - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - ])) + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + ], + ) + ) def test_less_simple_dicts(self): - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - {'a': 1}, - ], [ - {'a': 1}, - {'a': 1}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'a': 1}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - {'a': 1}, - ])) + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"a": 1}, + ], + [ + {"a": 1}, + {"a": 1}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"a": 1}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + {"a": 1}, + ], + ) + ) class SearchOperatorTest(unittest2.TestCase): @@ -120,774 +161,850 @@ class SearchOperatorTest(unittest2.TestCase): # parser. As such, its tests are much more complex than other commands, so we # pull its tests out into their own test case. def test_search_with_weird_condition(self): - op = operators.get_operator('search') + op = operators.get_operator("search") with self.assertRaises(operators.UnrecognizedConditionError): - op([], [], 'weird', None) + op([], [], "weird", None) def test_search_any_true(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) - return (len(called_function_args) < 3) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) + return len(called_function_args) < 3 payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'any', record_function_args) + result = op(payload, criteria_pattern, "any", record_function_args) self.assertTrue(result) - self.assertTrue(list_of_dicts_strict_equal(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ])) + self.assertTrue( + list_of_dicts_strict_equal( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) + ) def test_search_any_false(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return (len(called_function_args) % 2) == 0 payload = [ { - 'field_name': "Status", - 'to_value': "Denied", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Denied", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'any', record_function_args) + result = op(payload, criteria_pattern, "any", record_function_args) self.assertFalse(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Denied", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Denied", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Denied", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Denied", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) def test_search_all_false(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return (len(called_function_args) % 2) == 0 payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'all', record_function_args) + result = op(payload, criteria_pattern, "all", record_function_args) self.assertFalse(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) def test_search_all_true(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return True payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Signed off by", - 'to_value': "Approved", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Signed off by", + "to_value": "Approved", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "startswith", - 'pattern': "S", + "item.field_name": { + "type": "startswith", + "pattern": "S", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'all', record_function_args) + result = op(payload, criteria_pattern, "all", record_function_args) self.assertTrue(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "startswith", - 'pattern': "S", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "startswith", - 'pattern': "S", - }, - 'payload_lookup': { - 'field_name': "Signed off by", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Signed off by", - 'to_value': "Approved", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "startswith", + "pattern": "S", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "startswith", + "pattern": "S", + }, + "payload_lookup": { + "field_name": "Signed off by", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Signed off by", + "to_value": "Approved", + }, + }, + ], + ) class OperatorTest(unittest2.TestCase): def test_matchwildcard(self): - op = operators.get_operator('matchwildcard') - self.assertTrue(op('v1', 'v1'), 'Failed matchwildcard.') + op = operators.get_operator("matchwildcard") + self.assertTrue(op("v1", "v1"), "Failed matchwildcard.") - self.assertFalse(op('test foo test', 'foo'), 'Passed matchwildcard.') - self.assertTrue(op('test foo test', '*foo*'), 'Failed matchwildcard.') - self.assertTrue(op('bar', 'b*r'), 'Failed matchwildcard.') - self.assertTrue(op('bar', 'b?r'), 'Failed matchwildcard.') + self.assertFalse(op("test foo test", "foo"), "Passed matchwildcard.") + self.assertTrue(op("test foo test", "*foo*"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b*r"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'bar', 'b?r'), 'Failed matchwildcard.') - self.assertTrue(op('bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(b'bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(u'bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(u'bar', u'b?r'), 'Failed matchwildcard.') + self.assertTrue(op(b"bar", "b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op(b"bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.") - self.assertFalse(op('1', None), 'Passed matchwildcard with None as criteria_pattern.') + self.assertFalse( + op("1", None), "Passed matchwildcard with None as criteria_pattern." + ) def test_matchregex(self): - op = operators.get_operator('matchregex') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') + op = operators.get_operator("matchregex") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") # Multi line string, make sure re.DOTALL is used - string = '''ponies + string = """ponies moar foo bar yeah! - ''' - self.assertTrue(op(string, '.*bar.*'), 'Failed matchregex.') + """ + self.assertTrue(op(string, ".*bar.*"), "Failed matchregex.") - string = 'foo\r\nponies\nbar\nfooooo' - self.assertTrue(op(string, '.*ponies.*'), 'Failed matchregex.') + string = "foo\r\nponies\nbar\nfooooo" + self.assertTrue(op(string, ".*ponies.*"), "Failed matchregex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'foo ponies bar', '.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op('foo ponies bar', b'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(b'foo ponies bar', b'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(b'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(u'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.') + self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.") + self.assertTrue(op("foo ponies bar", b".*ponies.*"), "Failed matchregex.") + self.assertTrue(op(b"foo ponies bar", b".*ponies.*"), "Failed matchregex.") + self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.") + self.assertTrue(op("foo ponies bar", ".*ponies.*"), "Failed matchregex.") def test_iregex(self): - op = operators.get_operator('iregex') - self.assertTrue(op('V1', 'v1$'), 'Failed iregex.') + op = operators.get_operator("iregex") + self.assertTrue(op("V1", "v1$"), "Failed iregex.") - string = 'fooPONIESbarfooooo' - self.assertTrue(op(string, 'ponies'), 'Failed iregex.') + string = "fooPONIESbarfooooo" + self.assertTrue(op(string, "ponies"), "Failed iregex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'fooPONIESbarfooooo', 'ponies'), 'Failed iregex.') - self.assertTrue(op('fooPONIESbarfooooo', b'ponies'), 'Failed iregex.') - self.assertTrue(op(b'fooPONIESbarfooooo', b'ponies'), 'Failed iregex.') - self.assertTrue(op(b'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.') - self.assertTrue(op(u'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.') + self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.") + self.assertTrue(op("fooPONIESbarfooooo", b"ponies"), "Failed iregex.") + self.assertTrue(op(b"fooPONIESbarfooooo", b"ponies"), "Failed iregex.") + self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.") + self.assertTrue(op("fooPONIESbarfooooo", "ponies"), "Failed iregex.") def test_iregex_fail(self): - op = operators.get_operator('iregex') - self.assertFalse(op('V1_foo', 'v1$'), 'Passed iregex.') - self.assertFalse(op('1', None), 'Passed iregex with None as criteria_pattern.') + op = operators.get_operator("iregex") + self.assertFalse(op("V1_foo", "v1$"), "Passed iregex.") + self.assertFalse(op("1", None), "Passed iregex with None as criteria_pattern.") def test_regex(self): - op = operators.get_operator('regex') - self.assertTrue(op('v1', 'v1$'), 'Failed regex.') + op = operators.get_operator("regex") + self.assertTrue(op("v1", "v1$"), "Failed regex.") - string = 'fooponiesbarfooooo' - self.assertTrue(op(string, 'ponies'), 'Failed regex.') + string = "fooponiesbarfooooo" + self.assertTrue(op(string, "ponies"), "Failed regex.") # Example with | modifier - string = 'apple ponies oranges' - self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.') + string = "apple ponies oranges" + self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.") - string = 'apple unicorns oranges' - self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.') + string = "apple unicorns oranges" + self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'apples unicorns oranges', '(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op('apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(b'apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(b'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(u'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.') - - string = 'apple unicorns oranges' - self.assertFalse(op(string, '(pikachu|snorlax|charmander)'), 'Passed regex.') + self.assertTrue( + op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op("apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op(b"apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op("apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + + string = "apple unicorns oranges" + self.assertFalse(op(string, "(pikachu|snorlax|charmander)"), "Passed regex.") def test_regex_fail(self): - op = operators.get_operator('regex') - self.assertFalse(op('v1_foo', 'v1$'), 'Passed regex.') + op = operators.get_operator("regex") + self.assertFalse(op("v1_foo", "v1$"), "Passed regex.") - string = 'fooPONIESbarfooooo' - self.assertFalse(op(string, 'ponies'), 'Passed regex.') + string = "fooPONIESbarfooooo" + self.assertFalse(op(string, "ponies"), "Passed regex.") - self.assertFalse(op('1', None), 'Passed regex with None as criteria_pattern.') + self.assertFalse(op("1", None), "Passed regex with None as criteria_pattern.") def test_matchregex_case_variants(self): - op = operators.get_operator('MATCHREGEX') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') - op = operators.get_operator('MATCHregex') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') + op = operators.get_operator("MATCHREGEX") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") + op = operators.get_operator("MATCHregex") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") def test_matchregex_fail(self): - op = operators.get_operator('matchregex') - self.assertFalse(op('v1_foo', 'v1$'), 'Passed matchregex.') - self.assertFalse(op('1', None), 'Passed matchregex with None as criteria_pattern.') + op = operators.get_operator("matchregex") + self.assertFalse(op("v1_foo", "v1$"), "Passed matchregex.") + self.assertFalse( + op("1", None), "Passed matchregex with None as criteria_pattern." + ) def test_equals_numeric(self): - op = operators.get_operator('equals') - self.assertTrue(op(1, 1), 'Failed equals.') + op = operators.get_operator("equals") + self.assertTrue(op(1, 1), "Failed equals.") def test_equals_string(self): - op = operators.get_operator('equals') - self.assertTrue(op('1', '1'), 'Failed equals.') - self.assertTrue(op('', ''), 'Failed equals.') + op = operators.get_operator("equals") + self.assertTrue(op("1", "1"), "Failed equals.") + self.assertTrue(op("", ""), "Failed equals.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'1', '1'), 'Failed equals.') - self.assertTrue(op('1', b'1'), 'Failed equals.') - self.assertTrue(op(b'1', b'1'), 'Failed equals.') - self.assertTrue(op(b'1', u'1'), 'Failed equals.') - self.assertTrue(op(u'1', u'1'), 'Failed equals.') + self.assertTrue(op(b"1", "1"), "Failed equals.") + self.assertTrue(op("1", b"1"), "Failed equals.") + self.assertTrue(op(b"1", b"1"), "Failed equals.") + self.assertTrue(op(b"1", "1"), "Failed equals.") + self.assertTrue(op("1", "1"), "Failed equals.") def test_equals_fail(self): - op = operators.get_operator('equals') - self.assertFalse(op('1', '2'), 'Passed equals.') - self.assertFalse(op('1', None), 'Passed equals with None as criteria_pattern.') + op = operators.get_operator("equals") + self.assertFalse(op("1", "2"), "Passed equals.") + self.assertFalse(op("1", None), "Passed equals with None as criteria_pattern.") def test_nequals(self): - op = operators.get_operator('nequals') - self.assertTrue(op('foo', 'bar')) - self.assertTrue(op('foo', 'foo1')) - self.assertTrue(op('foo', 'FOO')) - self.assertTrue(op('True', True)) - self.assertTrue(op('None', None)) - - self.assertFalse(op('True', 'True')) + op = operators.get_operator("nequals") + self.assertTrue(op("foo", "bar")) + self.assertTrue(op("foo", "foo1")) + self.assertTrue(op("foo", "FOO")) + self.assertTrue(op("True", True)) + self.assertTrue(op("None", None)) + + self.assertFalse(op("True", "True")) self.assertFalse(op(None, None)) def test_iequals(self): - op = operators.get_operator('iequals') - self.assertTrue(op('ABC', 'ABC'), 'Failed iequals.') - self.assertTrue(op('ABC', 'abc'), 'Failed iequals.') - self.assertTrue(op('AbC', 'aBc'), 'Failed iequals.') + op = operators.get_operator("iequals") + self.assertTrue(op("ABC", "ABC"), "Failed iequals.") + self.assertTrue(op("ABC", "abc"), "Failed iequals.") + self.assertTrue(op("AbC", "aBc"), "Failed iequals.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'AbC', 'aBc'), 'Failed iequals.') - self.assertTrue(op('AbC', b'aBc'), 'Failed iequals.') - self.assertTrue(op(b'AbC', b'aBc'), 'Failed iequals.') - self.assertTrue(op(b'AbC', u'aBc'), 'Failed iequals.') - self.assertTrue(op(u'AbC', u'aBc'), 'Failed iequals.') + self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.") + self.assertTrue(op("AbC", b"aBc"), "Failed iequals.") + self.assertTrue(op(b"AbC", b"aBc"), "Failed iequals.") + self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.") + self.assertTrue(op("AbC", "aBc"), "Failed iequals.") def test_iequals_fail(self): - op = operators.get_operator('iequals') - self.assertFalse(op('ABC', 'BCA'), 'Passed iequals.') - self.assertFalse(op('1', None), 'Passed iequals with None as criteria_pattern.') + op = operators.get_operator("iequals") + self.assertFalse(op("ABC", "BCA"), "Passed iequals.") + self.assertFalse(op("1", None), "Passed iequals with None as criteria_pattern.") def test_contains(self): - op = operators.get_operator('contains') - self.assertTrue(op('hasystack needle haystack', 'needle')) - self.assertTrue(op('needle', 'needle')) - self.assertTrue(op('needlehaystack', 'needle')) - self.assertTrue(op('needle haystack', 'needle')) - self.assertTrue(op('haystackneedle', 'needle')) - self.assertTrue(op('haystack needle', 'needle')) + op = operators.get_operator("contains") + self.assertTrue(op("hasystack needle haystack", "needle")) + self.assertTrue(op("needle", "needle")) + self.assertTrue(op("needlehaystack", "needle")) + self.assertTrue(op("needle haystack", "needle")) + self.assertTrue(op("haystackneedle", "needle")) + self.assertTrue(op("haystack needle", "needle")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'needle')) - self.assertTrue(op('haystack needle', b'needle')) - self.assertTrue(op(b'haystack needle', b'needle')) - self.assertTrue(op(b'haystack needle', u'needle')) - self.assertTrue(op(u'haystack needle', b'needle')) + self.assertTrue(op(b"haystack needle", "needle")) + self.assertTrue(op("haystack needle", b"needle")) + self.assertTrue(op(b"haystack needle", b"needle")) + self.assertTrue(op(b"haystack needle", "needle")) + self.assertTrue(op("haystack needle", b"needle")) def test_contains_fail(self): - op = operators.get_operator('contains') - self.assertFalse(op('hasystack needl haystack', 'needle')) - self.assertFalse(op('needla', 'needle')) - self.assertFalse(op('1', None), 'Passed contains with None as criteria_pattern.') + op = operators.get_operator("contains") + self.assertFalse(op("hasystack needl haystack", "needle")) + self.assertFalse(op("needla", "needle")) + self.assertFalse( + op("1", None), "Passed contains with None as criteria_pattern." + ) def test_icontains(self): - op = operators.get_operator('icontains') - self.assertTrue(op('hasystack nEEdle haystack', 'needle')) - self.assertTrue(op('neeDle', 'NeedlE')) - self.assertTrue(op('needlehaystack', 'needle')) - self.assertTrue(op('NEEDLE haystack', 'NEEDLE')) - self.assertTrue(op('haystackNEEDLE', 'needle')) - self.assertTrue(op('haystack needle', 'NEEDLE')) + op = operators.get_operator("icontains") + self.assertTrue(op("hasystack nEEdle haystack", "needle")) + self.assertTrue(op("neeDle", "NeedlE")) + self.assertTrue(op("needlehaystack", "needle")) + self.assertTrue(op("NEEDLE haystack", "NEEDLE")) + self.assertTrue(op("haystackNEEDLE", "needle")) + self.assertTrue(op("haystack needle", "NEEDLE")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'NEEDLE')) - self.assertTrue(op('haystack needle', b'NEEDLE')) - self.assertTrue(op(b'haystack needle', b'NEEDLE')) - self.assertTrue(op(b'haystack needle', u'NEEDLE')) - self.assertTrue(op(u'haystack needle', b'NEEDLE')) + self.assertTrue(op(b"haystack needle", "NEEDLE")) + self.assertTrue(op("haystack needle", b"NEEDLE")) + self.assertTrue(op(b"haystack needle", b"NEEDLE")) + self.assertTrue(op(b"haystack needle", "NEEDLE")) + self.assertTrue(op("haystack needle", b"NEEDLE")) def test_icontains_fail(self): - op = operators.get_operator('icontains') - self.assertFalse(op('hasystack needl haystack', 'needle')) - self.assertFalse(op('needla', 'needle')) - self.assertFalse(op('1', None), 'Passed icontains with None as criteria_pattern.') + op = operators.get_operator("icontains") + self.assertFalse(op("hasystack needl haystack", "needle")) + self.assertFalse(op("needla", "needle")) + self.assertFalse( + op("1", None), "Passed icontains with None as criteria_pattern." + ) def test_ncontains(self): - op = operators.get_operator('ncontains') - self.assertTrue(op('hasystack needle haystack', 'foo')) - self.assertTrue(op('needle', 'foo')) - self.assertTrue(op('needlehaystack', 'needlex')) - self.assertTrue(op('needle haystack', 'needlex')) - self.assertTrue(op('haystackneedle', 'needlex')) - self.assertTrue(op('haystack needle', 'needlex')) + op = operators.get_operator("ncontains") + self.assertTrue(op("hasystack needle haystack", "foo")) + self.assertTrue(op("needle", "foo")) + self.assertTrue(op("needlehaystack", "needlex")) + self.assertTrue(op("needle haystack", "needlex")) + self.assertTrue(op("haystackneedle", "needlex")) + self.assertTrue(op("haystack needle", "needlex")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'needlex')) - self.assertTrue(op('haystack needle', b'needlex')) - self.assertTrue(op(b'haystack needle', b'needlex')) - self.assertTrue(op(b'haystack needle', u'needlex')) - self.assertTrue(op(u'haystack needle', b'needlex')) + self.assertTrue(op(b"haystack needle", "needlex")) + self.assertTrue(op("haystack needle", b"needlex")) + self.assertTrue(op(b"haystack needle", b"needlex")) + self.assertTrue(op(b"haystack needle", "needlex")) + self.assertTrue(op("haystack needle", b"needlex")) def test_ncontains_fail(self): - op = operators.get_operator('ncontains') - self.assertFalse(op('hasystack needle haystack', 'needle')) - self.assertFalse(op('needla', 'needla')) - self.assertFalse(op('1', None), 'Passed ncontains with None as criteria_pattern.') + op = operators.get_operator("ncontains") + self.assertFalse(op("hasystack needle haystack", "needle")) + self.assertFalse(op("needla", "needla")) + self.assertFalse( + op("1", None), "Passed ncontains with None as criteria_pattern." + ) def test_incontains(self): - op = operators.get_operator('incontains') - self.assertTrue(op('hasystack needle haystack', 'FOO')) - self.assertTrue(op('needle', 'FOO')) - self.assertTrue(op('needlehaystack', 'needlex')) - self.assertTrue(op('needle haystack', 'needlex')) - self.assertTrue(op('haystackneedle', 'needlex')) - self.assertTrue(op('haystack needle', 'needlex')) + op = operators.get_operator("incontains") + self.assertTrue(op("hasystack needle haystack", "FOO")) + self.assertTrue(op("needle", "FOO")) + self.assertTrue(op("needlehaystack", "needlex")) + self.assertTrue(op("needle haystack", "needlex")) + self.assertTrue(op("haystackneedle", "needlex")) + self.assertTrue(op("haystack needle", "needlex")) def test_incontains_fail(self): - op = operators.get_operator('incontains') - self.assertFalse(op('hasystack needle haystack', 'nEeDle')) - self.assertFalse(op('needlA', 'needlA')) - self.assertFalse(op('1', None), 'Passed incontains with None as criteria_pattern.') + op = operators.get_operator("incontains") + self.assertFalse(op("hasystack needle haystack", "nEeDle")) + self.assertFalse(op("needlA", "needlA")) + self.assertFalse( + op("1", None), "Passed incontains with None as criteria_pattern." + ) def test_startswith(self): - op = operators.get_operator('startswith') - self.assertTrue(op('hasystack needle haystack', 'hasystack')) - self.assertTrue(op('a hasystack needle haystack', 'a ')) + op = operators.get_operator("startswith") + self.assertTrue(op("hasystack needle haystack", "hasystack")) + self.assertTrue(op("a hasystack needle haystack", "a ")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'haystack')) - self.assertTrue(op('haystack needle', b'haystack')) - self.assertTrue(op(b'haystack needle', b'haystack')) - self.assertTrue(op(b'haystack needle', u'haystack')) - self.assertTrue(op(u'haystack needle', b'haystack')) + self.assertTrue(op(b"haystack needle", "haystack")) + self.assertTrue(op("haystack needle", b"haystack")) + self.assertTrue(op(b"haystack needle", b"haystack")) + self.assertTrue(op(b"haystack needle", "haystack")) + self.assertTrue(op("haystack needle", b"haystack")) def test_startswith_fail(self): - op = operators.get_operator('startswith') - self.assertFalse(op('hasystack needle haystack', 'needle')) - self.assertFalse(op('a hasystack needle haystack', 'haystack')) - self.assertFalse(op('1', None), 'Passed startswith with None as criteria_pattern.') + op = operators.get_operator("startswith") + self.assertFalse(op("hasystack needle haystack", "needle")) + self.assertFalse(op("a hasystack needle haystack", "haystack")) + self.assertFalse( + op("1", None), "Passed startswith with None as criteria_pattern." + ) def test_istartswith(self): - op = operators.get_operator('istartswith') - self.assertTrue(op('haystack needle haystack', 'HAYstack')) - self.assertTrue(op('HAYSTACK needle haystack', 'haystack')) + op = operators.get_operator("istartswith") + self.assertTrue(op("haystack needle haystack", "HAYstack")) + self.assertTrue(op("HAYSTACK needle haystack", "haystack")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'HAYSTACK needle haystack', 'haystack')) - self.assertTrue(op('HAYSTACK needle haystack', b'haystack')) - self.assertTrue(op(b'HAYSTACK needle haystack', b'haystack')) - self.assertTrue(op(b'HAYSTACK needle haystack', u'haystack')) - self.assertTrue(op(u'HAYSTACK needle haystack', b'haystack')) + self.assertTrue(op(b"HAYSTACK needle haystack", "haystack")) + self.assertTrue(op("HAYSTACK needle haystack", b"haystack")) + self.assertTrue(op(b"HAYSTACK needle haystack", b"haystack")) + self.assertTrue(op(b"HAYSTACK needle haystack", "haystack")) + self.assertTrue(op("HAYSTACK needle haystack", b"haystack")) def test_istartswith_fail(self): - op = operators.get_operator('istartswith') - self.assertFalse(op('hasystack needle haystack', 'NEEDLE')) - self.assertFalse(op('a hasystack needle haystack', 'haystack')) - self.assertFalse(op('1', None), 'Passed istartswith with None as criteria_pattern.') + op = operators.get_operator("istartswith") + self.assertFalse(op("hasystack needle haystack", "NEEDLE")) + self.assertFalse(op("a hasystack needle haystack", "haystack")) + self.assertFalse( + op("1", None), "Passed istartswith with None as criteria_pattern." + ) def test_endswith(self): - op = operators.get_operator('endswith') - self.assertTrue(op('hasystack needle haystackend', 'haystackend')) - self.assertTrue(op('a hasystack needle haystack b', 'b')) + op = operators.get_operator("endswith") + self.assertTrue(op("hasystack needle haystackend", "haystackend")) + self.assertTrue(op("a hasystack needle haystack b", "b")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'a hasystack needle haystack b', 'b')) - self.assertTrue(op('a hasystack needle haystack b', b'b')) - self.assertTrue(op(b'a hasystack needle haystack b', b'b')) - self.assertTrue(op(b'a hasystack needle haystack b', u'b')) - self.assertTrue(op(u'a hasystack needle haystack b', b'b')) + self.assertTrue(op(b"a hasystack needle haystack b", "b")) + self.assertTrue(op("a hasystack needle haystack b", b"b")) + self.assertTrue(op(b"a hasystack needle haystack b", b"b")) + self.assertTrue(op(b"a hasystack needle haystack b", "b")) + self.assertTrue(op("a hasystack needle haystack b", b"b")) def test_endswith_fail(self): - op = operators.get_operator('endswith') - self.assertFalse(op('hasystack needle haystackend', 'haystack')) - self.assertFalse(op('a hasystack needle haystack', 'a')) - self.assertFalse(op('1', None), 'Passed endswith with None as criteria_pattern.') + op = operators.get_operator("endswith") + self.assertFalse(op("hasystack needle haystackend", "haystack")) + self.assertFalse(op("a hasystack needle haystack", "a")) + self.assertFalse( + op("1", None), "Passed endswith with None as criteria_pattern." + ) def test_iendswith(self): - op = operators.get_operator('iendswith') - self.assertTrue(op('haystack needle haystackEND', 'HAYstackend')) - self.assertTrue(op('HAYSTACK needle haystackend', 'haystackEND')) + op = operators.get_operator("iendswith") + self.assertTrue(op("haystack needle haystackEND", "HAYstackend")) + self.assertTrue(op("HAYSTACK needle haystackend", "haystackEND")) def test_iendswith_fail(self): - op = operators.get_operator('iendswith') - self.assertFalse(op('hasystack needle haystack', 'NEEDLE')) - self.assertFalse(op('a hasystack needle haystack', 'a ')) - self.assertFalse(op('1', None), 'Passed iendswith with None as criteria_pattern.') + op = operators.get_operator("iendswith") + self.assertFalse(op("hasystack needle haystack", "NEEDLE")) + self.assertFalse(op("a hasystack needle haystack", "a ")) + self.assertFalse( + op("1", None), "Passed iendswith with None as criteria_pattern." + ) def test_lt(self): - op = operators.get_operator('lessthan') - self.assertTrue(op(1, 2), 'Failed lessthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op(1, 2), "Failed lessthan.") def test_lt_char(self): - op = operators.get_operator('lessthan') - self.assertTrue(op('a', 'b'), 'Failed lessthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op("a", "b"), "Failed lessthan.") def test_lt_fail(self): - op = operators.get_operator('lessthan') - self.assertFalse(op(1, 1), 'Passed lessthan.') - self.assertFalse(op('1', None), 'Passed lessthan with None as criteria_pattern.') + op = operators.get_operator("lessthan") + self.assertFalse(op(1, 1), "Passed lessthan.") + self.assertFalse( + op("1", None), "Passed lessthan with None as criteria_pattern." + ) def test_gt(self): - op = operators.get_operator('greaterthan') - self.assertTrue(op(2, 1), 'Failed greaterthan.') + op = operators.get_operator("greaterthan") + self.assertTrue(op(2, 1), "Failed greaterthan.") def test_gt_str(self): - op = operators.get_operator('lessthan') - self.assertTrue(op('aba', 'bcb'), 'Failed greaterthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op("aba", "bcb"), "Failed greaterthan.") def test_gt_fail(self): - op = operators.get_operator('greaterthan') - self.assertFalse(op(2, 3), 'Passed greaterthan.') - self.assertFalse(op('1', None), 'Passed greaterthan with None as criteria_pattern.') + op = operators.get_operator("greaterthan") + self.assertFalse(op(2, 3), "Passed greaterthan.") + self.assertFalse( + op("1", None), "Passed greaterthan with None as criteria_pattern." + ) def test_timediff_lt(self): - op = operators.get_operator('timediff_lt') - self.assertTrue(op(date_utils.get_datetime_utc_now().isoformat(), 10), - 'Failed test_timediff_lt.') + op = operators.get_operator("timediff_lt") + self.assertTrue( + op(date_utils.get_datetime_utc_now().isoformat(), 10), + "Failed test_timediff_lt.", + ) def test_timediff_lt_fail(self): - op = operators.get_operator('timediff_lt') - self.assertFalse(op('2014-07-01T00:01:01.000000', 10), - 'Passed test_timediff_lt.') - self.assertFalse(op('2014-07-01T00:01:01.000000', None), - 'Passed test_timediff_lt with None as criteria_pattern.') + op = operators.get_operator("timediff_lt") + self.assertFalse( + op("2014-07-01T00:01:01.000000", 10), "Passed test_timediff_lt." + ) + self.assertFalse( + op("2014-07-01T00:01:01.000000", None), + "Passed test_timediff_lt with None as criteria_pattern.", + ) def test_timediff_gt(self): - op = operators.get_operator('timediff_gt') - self.assertTrue(op('2014-07-01T00:01:01.000000', 1), - 'Failed test_timediff_gt.') + op = operators.get_operator("timediff_gt") + self.assertTrue(op("2014-07-01T00:01:01.000000", 1), "Failed test_timediff_gt.") def test_timediff_gt_fail(self): - op = operators.get_operator('timediff_gt') - self.assertFalse(op(date_utils.get_datetime_utc_now().isoformat(), 10), - 'Passed test_timediff_gt.') - self.assertFalse(op('2014-07-01T00:01:01.000000', None), - 'Passed test_timediff_gt with None as criteria_pattern.') + op = operators.get_operator("timediff_gt") + self.assertFalse( + op(date_utils.get_datetime_utc_now().isoformat(), 10), + "Passed test_timediff_gt.", + ) + self.assertFalse( + op("2014-07-01T00:01:01.000000", None), + "Passed test_timediff_gt with None as criteria_pattern.", + ) def test_exists(self): - op = operators.get_operator('exists') - self.assertTrue(op(False, None), 'Should return True') - self.assertTrue(op(1, None), 'Should return True') - self.assertTrue(op('foo', None), 'Should return True') - self.assertFalse(op(None, None), 'Should return False') + op = operators.get_operator("exists") + self.assertTrue(op(False, None), "Should return True") + self.assertTrue(op(1, None), "Should return True") + self.assertTrue(op("foo", None), "Should return True") + self.assertFalse(op(None, None), "Should return False") def test_nexists(self): - op = operators.get_operator('nexists') - self.assertFalse(op(False, None), 'Should return False') - self.assertFalse(op(1, None), 'Should return False') - self.assertFalse(op('foo', None), 'Should return False') - self.assertTrue(op(None, None), 'Should return True') + op = operators.get_operator("nexists") + self.assertFalse(op(False, None), "Should return False") + self.assertFalse(op(1, None), "Should return False") + self.assertFalse(op("foo", None), "Should return False") + self.assertTrue(op(None, None), "Should return True") def test_inside(self): - op = operators.get_operator('inside') - self.assertFalse(op('a', None), 'Should return False') - self.assertFalse(op('a', 'bcd'), 'Should return False') - self.assertTrue(op('a', 'abc'), 'Should return True') + op = operators.get_operator("inside") + self.assertFalse(op("a", None), "Should return False") + self.assertFalse(op("a", "bcd"), "Should return False") + self.assertTrue(op("a", "abc"), "Should return True") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'a', 'abc'), 'Should return True') - self.assertTrue(op('a', b'abc'), 'Should return True') - self.assertTrue(op(b'a', b'abc'), 'Should return True') + self.assertTrue(op(b"a", "abc"), "Should return True") + self.assertTrue(op("a", b"abc"), "Should return True") + self.assertTrue(op(b"a", b"abc"), "Should return True") def test_ninside(self): - op = operators.get_operator('ninside') - self.assertFalse(op('a', None), 'Should return False') - self.assertFalse(op('a', 'abc'), 'Should return False') - self.assertTrue(op('a', 'bcd'), 'Should return True') + op = operators.get_operator("ninside") + self.assertFalse(op("a", None), "Should return False") + self.assertFalse(op("a", "abc"), "Should return False") + self.assertTrue(op("a", "bcd"), "Should return True") class GetOperatorsTest(unittest2.TestCase): def test_get_operator(self): - self.assertTrue(operators.get_operator('equals')) - self.assertTrue(operators.get_operator('EQUALS')) + self.assertTrue(operators.get_operator("equals")) + self.assertTrue(operators.get_operator("EQUALS")) def test_get_operator_returns_same_operator_with_different_cases(self): - equals = operators.get_operator('equals') - EQUALS = operators.get_operator('EQUALS') - Equals = operators.get_operator('Equals') + equals = operators.get_operator("equals") + EQUALS = operators.get_operator("EQUALS") + Equals = operators.get_operator("Equals") self.assertEqual(equals, EQUALS) self.assertEqual(equals, Equals) def test_get_operator_with_nonexistent_operator(self): with self.assertRaises(Exception): - operators.get_operator('weird') + operators.get_operator("weird") def test_get_allowed_operators(self): # This test will need to change as operators are deprecated diff --git a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py index 355e680c85..630b14ed37 100644 --- a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py +++ b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py @@ -23,111 +23,117 @@ from st2common.exceptions.content import ParseException from st2common.models.db.actionalias import ActionAliasDB -__all__ = [ - 'PackActionAliasUnitTestUtils' -] +__all__ = ["PackActionAliasUnitTestUtils"] -PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/pack_dir_name_doesnt_match_ref') +PACK_PATH_1 = os.path.join( + get_fixtures_base_path(), "packs/pack_dir_name_doesnt_match_ref" +) class PackActionAliasUnitTestUtils(BaseActionAliasTestCase): - action_alias_name = 'mock' + action_alias_name = "mock" mock_get_action_alias_db_by_name = True def test_assertExtractedParametersMatch_success(self): format_string = self.action_alias_db.formats[0] - command = 'show last 3 metrics for my.host' - expected_parameters = { - 'count': '3', - 'server': 'my.host' - } - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + command = "show last 3 metrics for my.host" + expected_parameters = {"count": "3", "server": "my.host"} + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) format_string = self.action_alias_db.formats[0] - command = 'show last 10 metrics for my.host.example' - expected_parameters = { - 'count': '10', - 'server': 'my.host.example' - } - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + command = "show last 10 metrics for my.host.example" + expected_parameters = {"count": "10", "server": "my.host.example"} + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) def test_assertExtractedParametersMatch_command_doesnt_match_format_string(self): format_string = self.action_alias_db.formats[0] - command = 'show last foo' + command = "show last foo" expected_parameters = {} - expected_msg = ('Command "show last foo" doesn\'t match format string ' - '"show last {{count}} metrics for {{server}}"') - - self.assertRaisesRegexp(ParseException, expected_msg, - self.assertExtractedParametersMatch, - format_string=format_string, - command=command, - parameters=expected_parameters) + expected_msg = ( + 'Command "show last foo" doesn\'t match format string ' + '"show last {{count}} metrics for {{server}}"' + ) + + self.assertRaisesRegexp( + ParseException, + expected_msg, + self.assertExtractedParametersMatch, + format_string=format_string, + command=command, + parameters=expected_parameters, + ) def test_assertCommandMatchesExactlyOneFormatString(self): # Matches single format string - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}} baz' - ] - command = 'foo bar a test=1' - self.assertCommandMatchesExactlyOneFormatString(format_strings=format_strings, - command=command) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}} baz"] + command = "foo bar a test=1" + self.assertCommandMatchesExactlyOneFormatString( + format_strings=format_strings, command=command + ) # Matches multiple format strings - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}}' - ] - command = 'foo bar a test=1' - - expected_msg = ('Command "foo bar a test=1" matched multiple format ' - 'strings: foo bar {{bar}}, foo bar {{baz}}') - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertCommandMatchesExactlyOneFormatString, - format_strings=format_strings, - command=command) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"] + command = "foo bar a test=1" + + expected_msg = ( + 'Command "foo bar a test=1" matched multiple format ' + "strings: foo bar {{bar}}, foo bar {{baz}}" + ) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertCommandMatchesExactlyOneFormatString, + format_strings=format_strings, + command=command, + ) # Doesn't matches any format strings - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}}' - ] - command = 'does not match foo' - - expected_msg = ('Command "does not match foo" didn\'t match any of the provided format ' - 'strings') - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertCommandMatchesExactlyOneFormatString, - format_strings=format_strings, - command=command) - - @mock.patch.object(BaseActionAliasTestCase, '_get_base_pack_path', - mock.Mock(return_value=PACK_PATH_1)) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"] + command = "does not match foo" + + expected_msg = ( + 'Command "does not match foo" didn\'t match any of the provided format ' + "strings" + ) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertCommandMatchesExactlyOneFormatString, + format_strings=format_strings, + command=command, + ) + + @mock.patch.object( + BaseActionAliasTestCase, + "_get_base_pack_path", + mock.Mock(return_value=PACK_PATH_1), + ) def test_base_class_works_when_pack_directory_name_doesnt_match_pack_name(self): # Verify that the alias can still be succesfuly loaded from disk even if the pack directory # name doesn't match "pack" resource attribute (aka pack ref) self.mock_get_action_alias_db_by_name = False - action_alias_db = self._get_action_alias_db_by_name(name='alias1') - self.assertEqual(action_alias_db.name, 'alias1') - self.assertEqual(action_alias_db.pack, 'pack_name_not_the_same_as_dir_name') + action_alias_db = self._get_action_alias_db_by_name(name="alias1") + self.assertEqual(action_alias_db.name, "alias1") + self.assertEqual(action_alias_db.pack, "pack_name_not_the_same_as_dir_name") # Note: We mock the original method to make testing of all the edge cases easier def _get_action_alias_db_by_name(self, name): if not self.mock_get_action_alias_db_by_name: - return super(PackActionAliasUnitTestUtils, self)._get_action_alias_db_by_name(name) + return super( + PackActionAliasUnitTestUtils, self + )._get_action_alias_db_by_name(name) values = { - 'name': self.action_alias_name, - 'pack': 'mock', - 'formats': [ - 'show last {{count}} metrics for {{server}}', - ] + "name": self.action_alias_name, + "pack": "mock", + "formats": [ + "show last {{count}} metrics for {{server}}", + ], } action_alias_db = ActionAliasDB(**values) return action_alias_db diff --git a/st2common/tests/unit/test_pack_management.py b/st2common/tests/unit/test_pack_management.py index abc0498489..b350c7d98f 100644 --- a/st2common/tests/unit/test_pack_management.py +++ b/st2common/tests/unit/test_pack_management.py @@ -21,37 +21,35 @@ import unittest2 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -PACK_ACTIONS_DIR = os.path.join(BASE_DIR, '../../../contrib/packs/actions') +PACK_ACTIONS_DIR = os.path.join(BASE_DIR, "../../../contrib/packs/actions") PACK_ACTIONS_DIR = os.path.abspath(PACK_ACTIONS_DIR) sys.path.insert(0, PACK_ACTIONS_DIR) from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from st2common.util.pack_management import eval_repo_url -__all__ = [ - 'InstallPackTestCase' -] +__all__ = ["InstallPackTestCase"] class InstallPackTestCase(unittest2.TestCase): - def test_eval_repo(self): - result = eval_repo_url('stackstorm/st2contrib') - self.assertEqual(result, 'https://github.com/stackstorm/st2contrib') + result = eval_repo_url("stackstorm/st2contrib") + self.assertEqual(result, "https://github.com/stackstorm/st2contrib") - result = eval_repo_url('git@github.com:StackStorm/st2contrib.git') - self.assertEqual(result, 'git@github.com:StackStorm/st2contrib.git') + result = eval_repo_url("git@github.com:StackStorm/st2contrib.git") + self.assertEqual(result, "git@github.com:StackStorm/st2contrib.git") - result = eval_repo_url('gitlab@gitlab.com:StackStorm/st2contrib.git') - self.assertEqual(result, 'gitlab@gitlab.com:StackStorm/st2contrib.git') + result = eval_repo_url("gitlab@gitlab.com:StackStorm/st2contrib.git") + self.assertEqual(result, "gitlab@gitlab.com:StackStorm/st2contrib.git") - repo_url = 'https://github.com/StackStorm/st2contrib.git' + repo_url = "https://github.com/StackStorm/st2contrib.git" result = eval_repo_url(repo_url) self.assertEqual(result, repo_url) - repo_url = 'https://git-wip-us.apache.org/repos/asf/libcloud.git' + repo_url = "https://git-wip-us.apache.org/repos/asf/libcloud.git" result = eval_repo_url(repo_url) self.assertEqual(result, repo_url) diff --git a/st2common/tests/unit/test_param_utils.py b/st2common/tests/unit/test_param_utils.py index 695d17f448..c2e5810815 100644 --- a/st2common/tests/unit/test_param_utils.py +++ b/st2common/tests/unit/test_param_utils.py @@ -36,30 +36,31 @@ from st2tests.fixturesloader import FixturesLoader -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'actions': ['action_4_action_context_param.yaml', 'action_system_default.yaml'], - 'runners': ['testrunner1.yaml'] + "actions": ["action_4_action_context_param.yaml", "action_system_default.yaml"], + "runners": ["testrunner1.yaml"], } -FIXTURES = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) +FIXTURES = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ParamsUtilsTest(DbTestCase): - action_db = FIXTURES['actions']['action_4_action_context_param.yaml'] - action_system_default_db = FIXTURES['actions']['action_system_default.yaml'] - runnertype_db = FIXTURES['runners']['testrunner1.yaml'] + action_db = FIXTURES["actions"]["action_4_action_context_param.yaml"] + action_system_default_db = FIXTURES["actions"]["action_system_default.yaml"] + runnertype_db = FIXTURES["runners"]["testrunner1.yaml"] def test_get_finalized_params(self): params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555, - 'runnerimmutable': 'failed_override', - 'actionimmutable': 'failed_override' + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, + "runnerimmutable": "failed_override", + "actionimmutable": "failed_override", } liveaction_db = self._get_liveaction_model(params) @@ -67,289 +68,320 @@ def test_get_finalized_params(self): ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - liveaction_db.context) + liveaction_db.context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that a runner param can be overridden by action param default. - self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy') + self.assertEqual(runner_params.get("runnerdummy"), "actiondummy") # Assert that a runner param default can be overridden by 'falsey' action param default, # (timeout: 0 case). - self.assertEqual(runner_params.get('runnerdefaultint'), 0) + self.assertEqual(runner_params.get("runnerdefaultint"), 0) # Assert that an immutable param cannot be overridden by action param or execution param. - self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable') + self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') + self.assertEqual(action_params.get("actionstr"), "foo") # Assert that a param that is provided in action exec that isn't in action or runner params # isn't in resolved params. - self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None) + self.assertEqual( + action_params.get("some_key_that_aint_exist_in_action_or_runner"), None + ) # Assert that an immutable param cannot be overridden by execution param. - self.assertEqual(action_params.get('actionimmutable'), 'actionimmutable') + self.assertEqual(action_params.get("actionimmutable"), "actionimmutable") # Assert that an action context param is set correctly. - self.assertEqual(action_params.get('action_api_user'), 'noob') + self.assertEqual(action_params.get("action_api_user"), "noob") # Assert that none of runner params are present in action_params. for k in action_params: - self.assertNotIn(k, runner_params, 'Param ' + k + ' is a runner param.') + self.assertNotIn(k, runner_params, "Param " + k + " is a runner param.") def test_get_finalized_params_system_values(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='actionstr', value='foo')) - KeyValuePair.add_or_update(KeyValuePairDB(name='actionnumber', value='1.0')) - params = { - 'runnerint': 555 - } + KeyValuePair.add_or_update(KeyValuePairDB(name="actionstr", value="foo")) + KeyValuePair.add_or_update(KeyValuePairDB(name="actionnumber", value="1.0")) + params = {"runnerint": 555} liveaction_db = self._get_liveaction_model(params) runner_params, action_params = param_utils.get_finalized_params( ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_system_default_db.parameters, liveaction_db.parameters, - liveaction_db.context) + liveaction_db.context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that an immutable param cannot be overridden by action param or execution param. - self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable') + self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') - self.assertEqual(action_params.get('actionnumber'), 1.0) + self.assertEqual(action_params.get("actionstr"), "foo") + self.assertEqual(action_params.get("actionnumber"), 1.0) def test_get_finalized_params_action_immutable(self): params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555, - 'actionimmutable': 'failed_override' + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, + "actionimmutable": "failed_override", } liveaction_db = self._get_liveaction_model(params) - action_context = {'api_user': None} + action_context = {"api_user": None} runner_params, action_params = param_utils.get_finalized_params( ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - action_context) + action_context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that a runner param can be overridden by action param default. - self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy') + self.assertEqual(runner_params.get("runnerdummy"), "actiondummy") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') + self.assertEqual(action_params.get("actionstr"), "foo") # Assert that a param that is provided in action exec that isn't in action or runner params # isn't in resolved params. - self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None) + self.assertEqual( + action_params.get("some_key_that_aint_exist_in_action_or_runner"), None + ) def test_get_finalized_params_empty(self): params = {} runner_param_info = {} action_param_info = {} - action_context = {'user': None} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) self.assertEqual(r_runner_params, params) self.assertEqual(r_action_params, params) def test_get_finalized_params_none(self): - params = { - 'r1': None, - 'a1': None - } - runner_param_info = {'r1': {}} - action_param_info = {'a1': {}} - action_context = {'api_user': None} + params = {"r1": None, "a1": None} + runner_param_info = {"r1": {}} + action_param_info = {"a1": {}} + action_context = {"api_user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None}) - self.assertEqual(r_action_params, {'a1': None}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None}) + self.assertEqual(r_action_params, {"a1": None}) def test_get_finalized_params_no_cast(self): params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': True, - 'a2': '{{r1}} {{a1}}', - 'a3': '{{action_context.api_user}}' - } - runner_param_info = {'r1': {}, 'r2': {}} - action_param_info = {'a1': {}, 'a2': {}, 'a3': {}} - action_context = {'api_user': 'noob'} + "r1": "{{r2}}", + "r2": 1, + "a1": True, + "a2": "{{r1}} {{a1}}", + "a3": "{{action_context.api_user}}", + } + runner_param_info = {"r1": {}, "r2": {}} + action_param_info = {"a1": {}, "a2": {}, "a3": {}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'1', 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': u'1 True', 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "1", "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": "1 True", "a3": "noob"}) def test_get_finalized_params_with_cast(self): # Note : In this test runner_params.r1 has a string value. However per runner_param_info the # type is an integer. The expected type is considered and cast is performed accordingly. params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': True, - 'a2': '{{a1}}', - 'a3': '{{action_context.api_user}}' + "r1": "{{r2}}", + "r2": 1, + "a1": True, + "a2": "{{a1}}", + "a3": "{{action_context.api_user}}", } - runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'boolean'}, 'a3': {}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}} + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "boolean"}, + "a3": {}, + } + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': True, 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": 1, "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": True, "a3": "noob"}) def test_get_finalized_params_with_cast_overriden(self): params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': '{{r1}}', - 'a2': '{{r1}}', - 'a3': '{{r1}}' + "r1": "{{r2}}", + "r2": 1, + "a1": "{{r1}}", + "a2": "{{r1}}", + "a3": "{{r1}}", } - runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'string'}, - 'a3': {'type': 'integer'}, 'r1': {'type': 'string'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}} + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "string"}, + "a3": {"type": "integer"}, + "r1": {"type": "string"}, + } + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1}) - self.assertEqual(r_action_params, {'a1': 1, 'a2': u'1', 'a3': 1}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": 1, "r2": 1}) + self.assertEqual(r_action_params, {"a1": 1, "a2": "1", "a3": 1}) def test_get_finalized_params_cross_talk_no_cast(self): params = { - 'r1': '{{a1}}', - 'r2': 1, - 'a1': True, - 'a2': '{{r1}} {{a1}}', - 'a3': '{{action_context.api_user}}' - } - runner_param_info = {'r1': {}, 'r2': {}} - action_param_info = {'a1': {}, 'a2': {}, 'a3': {}} - action_context = {'api_user': 'noob'} + "r1": "{{a1}}", + "r2": 1, + "a1": True, + "a2": "{{r1}} {{a1}}", + "a3": "{{action_context.api_user}}", + } + runner_param_info = {"r1": {}, "r2": {}} + action_param_info = {"a1": {}, "a2": {}, "a3": {}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'True', 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': u'True True', 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "True", "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": "True True", "a3": "noob"}) def test_get_finalized_params_cross_talk_with_cast(self): params = { - 'r1': '{{a1}}', - 'r2': 1, - 'r3': 1, - 'a1': True, - 'a2': '{{r1}},{{a1}},{{a3}},{{r3}}', - 'a3': '{{a1}}' + "r1": "{{a1}}", + "r2": 1, + "r3": 1, + "a1": True, + "a2": "{{r1}},{{a1}},{{a3}},{{r3}}", + "a3": "{{a1}}", } - runner_param_info = {'r1': {'type': 'boolean'}, 'r2': {'type': 'integer'}, 'r3': {}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'array'}, 'a3': {}} - action_context = {'user': None} + runner_param_info = { + "r1": {"type": "boolean"}, + "r2": {"type": "integer"}, + "r3": {}, + } + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "array"}, + "a3": {}, + } + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': True, 'r2': 1, 'r3': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': (True, True, True, 1), 'a3': u'True'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": True, "r2": 1, "r3": 1}) + self.assertEqual( + r_action_params, {"a1": True, "a2": (True, True, True, 1), "a3": "True"} + ) def test_get_finalized_params_order(self): - params = { - 'r1': 'p1', - 'r2': 'p2', - 'r3': 'p3', - 'a1': 'p4', - 'a2': 'p5' - } - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}} - action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + params = {"r1": "p1", "r2": "p2", "r3": "p3", "a1": "p4", "a2": "p5"} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}} + action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'p1', 'r2': u'p2', 'r3': u'p3'}) - self.assertEqual(r_action_params, {'a1': u'p4', 'a2': u'p5'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "p1", "r2": "p2", "r3": "p3"}) + self.assertEqual(r_action_params, {"a1": "p4", "a2": "p5"}) params = {} - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}} - action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}} + action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'}) - self.assertEqual(r_action_params, {'a1': None, 'a2': u'a2'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"}) + self.assertEqual(r_action_params, {"a1": None, "a2": "a2"}) params = {} - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {}} - action_param_info = {'r1': {}, 'r2': {}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {}} + action_param_info = {"r1": {}, "r2": {}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"}) def test_get_finalized_params_non_existent_template_key_in_action_context(self): params = { - 'r1': 'foo', - 'r2': 2, - 'a1': 'i love tests', - 'a2': '{{action_context.lorem_ipsum}}' - } - runner_param_info = {'r1': {'type': 'string'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'string'}, 'a2': {'type': 'string'}} - action_context = {'api_user': 'noob', 'source_channel': 'reddit'} + "r1": "foo", + "r2": 2, + "a1": "i love tests", + "a2": "{{action_context.lorem_ipsum}}", + } + runner_param_info = {"r1": {"type": "string"}, "r2": {"type": "integer"}} + action_param_info = {"a1": {"type": "string"}, "a2": {"type": "string"}} + action_context = {"api_user": "noob", "source_channel": "reddit"} try: r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.fail('This should have thrown because we are trying to deref a key in ' + - 'action context that ain\'t exist.') + runner_param_info, action_param_info, params, action_context + ) + self.fail( + "This should have thrown because we are trying to deref a key in " + + "action context that ain't exist." + ) except ParamException as e: - error_msg = 'Failed to render parameter "a2": \'dict object\' ' + \ - 'has no attribute \'lorem_ipsum\'' + error_msg = ( + "Failed to render parameter \"a2\": 'dict object' " + + "has no attribute 'lorem_ipsum'" + ) self.assertIn(error_msg, six.text_type(e)) pass def test_unicode_value_casting(self): - rendered = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'} - parameter_schemas = {'a1': {'type': 'string'}} + rendered = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"} + parameter_schemas = {"a1": {"type": "string"}} - result = param_utils._cast_params(rendered=rendered, - parameter_schemas=parameter_schemas) + result = param_utils._cast_params( + rendered=rendered, parameter_schemas=parameter_schemas + ) if six.PY3: - expected = { - 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2') - } + expected = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")} else: expected = { - 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc' - u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2') + "a1": ( + "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc" + "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2" + ) } self.assertEqual(result, expected) def test_get_finalized_params_with_casting_unicode_values(self): - params = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'} + params = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"} runner_param_info = {} - action_param_info = {'a1': {'type': 'string'}} + action_param_info = {"a1": {"type": "string"}} - action_context = {'user': None} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) if six.PY3: - expected_action_params = { - 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2') - } + expected_action_params = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")} else: expected_action_params = { - 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc' - u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2') + "a1": ( + "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc" + "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2" + ) } self.assertEqual(r_runner_params, {}) @@ -359,59 +391,53 @@ def test_get_finalized_params_with_dict(self): # Note : In this test runner_params.r1 has a string value. However per runner_param_info the # type is an integer. The expected type is considered and cast is performed accordingly. params = { - 'r1': '{{r2}}', - 'r2': {'r2.1': 1}, - 'a1': True, - 'a2': '{{a1}}', - 'a3': { - 'test': '{{a1}}', - 'test1': '{{a4}}', - 'test2': '{{a5}}', + "r1": "{{r2}}", + "r2": {"r2.1": 1}, + "a1": True, + "a2": "{{a1}}", + "a3": { + "test": "{{a1}}", + "test1": "{{a4}}", + "test2": "{{a5}}", }, - 'a4': 3, - 'a5': ['1', '{{a1}}'] + "a4": 3, + "a5": ["1", "{{a1}}"], } - runner_param_info = {'r1': {'type': 'object'}, 'r2': {'type': 'object'}} + runner_param_info = {"r1": {"type": "object"}, "r2": {"type": "object"}} action_param_info = { - 'a1': { - 'type': 'boolean', + "a1": { + "type": "boolean", }, - 'a2': { - 'type': 'boolean', + "a2": { + "type": "boolean", }, - 'a3': { - 'type': 'object', + "a3": { + "type": "object", }, - 'a4': { - 'type': 'integer', + "a4": { + "type": "integer", }, - 'a5': { - 'type': 'array', + "a5": { + "type": "array", }, } r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, {'user': None}) - self.assertEqual( - r_runner_params, {'r1': {'r2.1': 1}, 'r2': {'r2.1': 1}}) + runner_param_info, action_param_info, params, {"user": None} + ) + self.assertEqual(r_runner_params, {"r1": {"r2.1": 1}, "r2": {"r2.1": 1}}) self.assertEqual( r_action_params, { - 'a1': True, - 'a2': True, - 'a3': { - 'test': True, - 'test1': 3, - 'test2': [ - '1', - True - ], + "a1": True, + "a2": True, + "a3": { + "test": True, + "test1": 3, + "test2": ["1", True], }, - 'a4': 3, - 'a5': [ - '1', - True - ], - } + "a4": 3, + "a5": ["1", True], + }, ) def test_get_finalized_params_with_list(self): @@ -419,183 +445,177 @@ def test_get_finalized_params_with_list(self): # type is an integer. The expected type is considered and cast is performed accordingly. self.maxDiff = None params = { - 'r1': '{{r2}}', - 'r2': ['1', '2'], - 'a1': True, - 'a2': 'Test', - 'a3': 'Test2', - 'a4': '{{a1}}', - 'a5': ['{{a2}}', '{{a3}}'], - 'a6': [ - ['{{r2}}', '{{a2}}'], - ['{{a3}}', '{{a1}}'], + "r1": "{{r2}}", + "r2": ["1", "2"], + "a1": True, + "a2": "Test", + "a3": "Test2", + "a4": "{{a1}}", + "a5": ["{{a2}}", "{{a3}}"], + "a6": [ + ["{{r2}}", "{{a2}}"], + ["{{a3}}", "{{a1}}"], [ - '{{a7}}', - 'This should be rendered as a string {{a1}}', - '{{a1}} This, too, should be rendered as a string {{a1}}', - ] + "{{a7}}", + "This should be rendered as a string {{a1}}", + "{{a1}} This, too, should be rendered as a string {{a1}}", + ], ], - 'a7': 5, + "a7": 5, } - runner_param_info = {'r1': {'type': 'array'}, 'r2': {'type': 'array'}} + runner_param_info = {"r1": {"type": "array"}, "r2": {"type": "array"}} action_param_info = { - 'a1': {'type': 'boolean'}, - 'a2': {'type': 'string'}, - 'a3': {'type': 'string'}, - 'a4': {'type': 'boolean'}, - 'a5': {'type': 'array'}, - 'a6': {'type': 'array'}, - 'a7': {'type': 'integer'}, + "a1": {"type": "boolean"}, + "a2": {"type": "string"}, + "a3": {"type": "string"}, + "a4": {"type": "boolean"}, + "a5": {"type": "array"}, + "a6": {"type": "array"}, + "a7": {"type": "integer"}, } r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, {'user': None}) - self.assertEqual(r_runner_params, {'r1': ['1', '2'], 'r2': ['1', '2']}) + runner_param_info, action_param_info, params, {"user": None} + ) + self.assertEqual(r_runner_params, {"r1": ["1", "2"], "r2": ["1", "2"]}) self.assertEqual( r_action_params, { - 'a1': True, - 'a2': 'Test', - 'a3': 'Test2', - 'a4': True, - 'a5': ['Test', 'Test2'], - 'a6': [ - [['1', '2'], 'Test'], - ['Test2', True], + "a1": True, + "a2": "Test", + "a3": "Test2", + "a4": True, + "a5": ["Test", "Test2"], + "a6": [ + [["1", "2"], "Test"], + ["Test2", True], [ 5, - u'This should be rendered as a string True', - u'True This, too, should be rendered as a string True' - ] + "This should be rendered as a string True", + "True This, too, should be rendered as a string True", + ], ], - 'a7': 5, - } + "a7": 5, + }, ) def test_get_finalized_params_with_cyclic_dependency(self): - params = {'r1': '{{r2}}', 'r2': '{{r1}}'} - runner_param_info = {'r1': {}, 'r2': {}} + params = {"r1": "{{r2}}", "r2": "{{r1}}"} + runner_param_info = {"r1": {}, "r2": {}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Cyclic') == 0 + test_pass = six.text_type(e).find("Cyclic") == 0 self.assertTrue(test_pass) def test_get_finalized_params_with_missing_dependency(self): - params = {'r1': '{{r3}}', 'r2': '{{r3}}'} - runner_param_info = {'r1': {}, 'r2': {}} + params = {"r1": "{{r3}}", "r2": "{{r3}}"} + runner_param_info = {"r1": {}, "r2": {}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Dependency') == 0 + test_pass = six.text_type(e).find("Dependency") == 0 self.assertTrue(test_pass) params = {} - runner_param_info = {'r1': {'default': '{{r3}}'}, 'r2': {'default': '{{r3}}'}} + runner_param_info = {"r1": {"default": "{{r3}}"}, "r2": {"default": "{{r3}}"}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Dependency') == 0 + test_pass = six.text_type(e).find("Dependency") == 0 self.assertTrue(test_pass) def test_get_finalized_params_no_double_rendering(self): - params = { - 'r1': '{{ action_context.h1 }}{{ action_context.h2 }}' - } - runner_param_info = {'r1': {}} + params = {"r1": "{{ action_context.h1 }}{{ action_context.h2 }}"} + runner_param_info = {"r1": {}} action_param_info = {} - action_context = { - 'h1': '{', - 'h2': '{ missing }}', - 'user': None - } + action_context = {"h1": "{", "h2": "{ missing }}", "user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': '{{ missing }}'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "{{ missing }}"}) self.assertEqual(r_action_params, {}) def test_get_finalized_params_jinja_filters(self): - params = {'cmd': 'echo {{"1.6.0" | version_bump_minor}}'} - runner_param_info = {'r1': {}} - action_param_info = {'cmd': {}} - action_context = {'user': None} + params = {"cmd": 'echo {{"1.6.0" | version_bump_minor}}'} + runner_param_info = {"r1": {}} + action_param_info = {"cmd": {}} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) - self.assertEqual(r_action_params['cmd'], "echo 1.7.0") + self.assertEqual(r_action_params["cmd"], "echo 1.7.0") def test_get_finalized_params_param_rendering_failure(self): - params = {'cmd': '{{a2.foo}}', 'a2': 'test'} - action_param_info = {'cmd': {}, 'a2': {}} + params = {"cmd": "{{a2.foo}}", "a2": "test"} + action_param_info = {"cmd": {}, "a2": {}} expected_msg = 'Failed to render parameter "cmd": .*' - self.assertRaisesRegexp(ParamException, - expected_msg, - param_utils.get_finalized_params, - runnertype_parameter_info={}, - action_parameter_info=action_param_info, - liveaction_parameters=params, - action_context={'user': None}) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.get_finalized_params, + runnertype_parameter_info={}, + action_parameter_info=action_param_info, + liveaction_parameters=params, + action_context={"user": None}, + ) def test_get_finalized_param_object_contains_template_notation_in_the_value(self): - runner_param_info = {'r1': {}} + runner_param_info = {"r1": {}} action_param_info = { - 'params': { - 'type': 'object', - 'default': { - 'host': '{{host}}', - 'port': '{{port}}', - 'path': '/bar'} + "params": { + "type": "object", + "default": {"host": "{{host}}", "port": "{{port}}", "path": "/bar"}, } } - params = { - 'host': 'lolcathost', - 'port': 5555 - } - action_context = {'user': None} + params = {"host": "lolcathost", "port": 5555} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) - expected_params = { - 'host': 'lolcathost', - 'port': 5555, - 'path': '/bar' - } - self.assertEqual(r_action_params['params'], expected_params) + expected_params = {"host": "lolcathost", "port": 5555, "path": "/bar"} + self.assertEqual(r_action_params["params"], expected_params) def test_cast_param_referenced_action_doesnt_exist(self): # Make sure the function throws if the action doesnt exist expected_msg = 'Action with ref "foo.doesntexist" doesn\'t exist' - self.assertRaisesRegexp(ValueError, expected_msg, action_param_utils.cast_params, - action_ref='foo.doesntexist', params={}) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action_param_utils.cast_params, + action_ref="foo.doesntexist", + params={}, + ) def test_get_finalized_params_with_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: config_loader().get_config.return_value = { - 'generic_config_param': 'So generic' + "generic_config_param": "So generic" } params = { - 'config_param': '{{config_context.generic_config_param}}', + "config_param": "{{config_context.generic_config_param}}", } liveaction_db = self._get_liveaction_model(params, True) @@ -603,369 +623,327 @@ def test_get_finalized_params_with_config(self): ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - liveaction_db.context) - self.assertEqual( - action_params.get('config_param'), - 'So generic' + liveaction_db.context, ) + self.assertEqual(action_params.get("config_param"), "So generic") def test_get_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: - mock_config_return = { - 'generic_config_param': 'So generic' - } + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: + mock_config_return = {"generic_config_param": "So generic"} config_loader().get_config.return_value = mock_config_return self.assertEqual(get_config(None, None), {}) - self.assertEqual(get_config('pack', None), {}) - self.assertEqual(get_config(None, 'user'), {}) - self.assertEqual( - get_config('pack', 'user'), mock_config_return - ) + self.assertEqual(get_config("pack", None), {}) + self.assertEqual(get_config(None, "user"), {}) + self.assertEqual(get_config("pack", "user"), mock_config_return) - config_loader.assert_called_with(pack_name='pack', user='user') + config_loader.assert_called_with(pack_name="pack", user="user") config_loader().get_config.assert_called_once() def _get_liveaction_model(self, params, with_config_context=False): - status = 'initializing' + status = "initializing" start_timestamp = date_utils.get_datetime_utc_now() - action_ref = ResourceReference(name=ParamsUtilsTest.action_db.name, - pack=ParamsUtilsTest.action_db.pack).ref - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=params) + action_ref = ResourceReference( + name=ParamsUtilsTest.action_db.name, pack=ParamsUtilsTest.action_db.pack + ).ref + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=params, + ) liveaction_db.context = { - 'api_user': 'noob', - 'source_channel': 'reddit', + "api_user": "noob", + "source_channel": "reddit", } if with_config_context: - liveaction_db.context.update( - { - 'pack': 'generic', - 'user': 'st2admin' - } - ) + liveaction_db.context.update({"pack": "generic", "user": "st2admin"}) return liveaction_db def test_get_value_from_datastore_through_render_live_params(self): # Register datastore value to be refered by this test-case register_kwargs = [ - {'name': 'test_key', 'value': 'foo'}, - {'name': 'user1:test_key', 'value': 'bar', 'scope': FULL_USER_SCOPE}, - {'name': '%s:test_key' % cfg.CONF.system_user.user, 'value': 'baz', - 'scope': FULL_USER_SCOPE}, + {"name": "test_key", "value": "foo"}, + {"name": "user1:test_key", "value": "bar", "scope": FULL_USER_SCOPE}, + { + "name": "%s:test_key" % cfg.CONF.system_user.user, + "value": "baz", + "scope": FULL_USER_SCOPE, + }, ] for kwargs in register_kwargs: KeyValuePair.add_or_update(KeyValuePairDB(**kwargs)) # Assert that datastore value can be got via the Jinja expression from individual scopes. - context = {'user': 'user1'} + context = {"user": "user1"} param = { - 'system_value': {'default': '{{ st2kv.system.test_key }}'}, - 'user_value': {'default': '{{ st2kv.user.test_key }}'}, + "system_value": {"default": "{{ st2kv.system.test_key }}"}, + "user_value": {"default": "{{ st2kv.user.test_key }}"}, } - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['system_value'], 'foo') - self.assertEqual(live_params['user_value'], 'bar') + self.assertEqual(live_params["system_value"], "foo") + self.assertEqual(live_params["user_value"], "bar") # Assert that datastore value in the user-scope that is registered by user1 # cannot be got by the operation of user2. - context = {'user': 'user2'} - param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}} - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + context = {"user": "user2"} + param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}} + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['user_value'], '') + self.assertEqual(live_params["user_value"], "") # Assert that system-user's scope is selected when user and api_user parameter specified context = {} - param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}} - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}} + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['user_value'], 'baz') + self.assertEqual(live_params["user_value"], "baz") def test_get_live_params_with_additional_context(self): - runner_param_info = { - 'r1': { - 'default': 'some' - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' - } - } - params = { - 'r3': 'lolcathost', - 'r1': '{{ additional.stuff }}' - } - action_context = {'user': None} - additional_contexts = { - 'additional': { - 'stuff': 'generic' - } - } + runner_param_info = {"r1": {"default": "some"}} + action_param_info = {"r2": {"default": "{{ r1 }}"}} + params = {"r3": "lolcathost", "r1": "{{ additional.stuff }}"} + action_context = {"user": None} + additional_contexts = {"additional": {"stuff": "generic"}} live_params = param_utils.render_live_params( - runner_param_info, action_param_info, params, action_context, additional_contexts) + runner_param_info, + action_param_info, + params, + action_context, + additional_contexts, + ) - expected_params = { - 'r1': 'generic', - 'r2': 'generic', - 'r3': 'lolcathost' - } + expected_params = {"r1": "generic", "r2": "generic", "r3": "lolcathost"} self.assertEqual(live_params, expected_params) def test_cyclic_dependency_friendly_error_message(self): runner_param_info = { - 'r1': { - 'default': 'some', - 'cyclic': 'cyclic value', - 'morecyclic': 'cyclic value' - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' + "r1": { + "default": "some", + "cyclic": "cyclic value", + "morecyclic": "cyclic value", } } + action_param_info = {"r2": {"default": "{{ r1 }}"}} params = { - 'r3': 'lolcathost', - 'cyclic': '{{ cyclic }}', - 'morecyclic': '{{ morecyclic }}' + "r3": "lolcathost", + "cyclic": "{{ cyclic }}", + "morecyclic": "{{ morecyclic }}", } - action_context = {'user': None} + action_context = {"user": None} - expected_msg = 'Cyclic dependency found in the following variables: cyclic, morecyclic' - self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params, - runner_param_info, action_param_info, params, action_context) + expected_msg = ( + "Cyclic dependency found in the following variables: cyclic, morecyclic" + ) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.render_live_params, + runner_param_info, + action_param_info, + params, + action_context, + ) def test_unsatisfied_dependency_friendly_error_message(self): runner_param_info = { - 'r1': { - 'default': 'some', - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' + "r1": { + "default": "some", } } + action_param_info = {"r2": {"default": "{{ r1 }}"}} params = { - 'r3': 'lolcathost', - 'r4': '{{ variable_not_defined }}', + "r3": "lolcathost", + "r4": "{{ variable_not_defined }}", } - action_context = {'user': None} + action_context = {"user": None} expected_msg = 'Dependency unsatisfied in variable "variable_not_defined"' - self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params, - runner_param_info, action_param_info, params, action_context) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.render_live_params, + runner_param_info, + action_param_info, + params, + action_context, + ) def test_add_default_templates_to_live_params(self): - """Test addition of template values in defaults to live params - """ + """Test addition of template values in defaults to live params""" # Ensure parameter is skipped if the parameter has immutable set to true in schema schemas = [ { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer', - 'immutable': True + "templateparam": { + "default": "{{ 3 | int }}", + "type": "integer", + "immutable": True, } } ] - context = { - 'templateparam': '3' - } + context = {"templateparam": "3"} result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Test with no live params, and two parameters - one should make it through because # it was a template, and the other shouldn't because its default wasn't a template - schemas = [ - { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer' - } - } - ] - context = { - 'templateparam': '3' - } + schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}] + context = {"templateparam": "3"} result = param_utils._cast_params_from({}, context, schemas) - self.assertEqual(result, {'templateparam': 3}) + self.assertEqual(result, {"templateparam": 3}) # Ensure parameter is skipped if the value in context is identical to default - schemas = [ - { - 'nottemplateparam': { - 'default': '4', - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"default": "4", "type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Ensure parameter is skipped if the parameter doesn't have a default - schemas = [ - { - 'nottemplateparam': { - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Skip if the default value isn't a Jinja expression - schemas = [ - { - 'nottemplateparam': { - 'default': '5', - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"default": "5", "type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Ensure parameter is skipped if the parameter is being overridden - schemas = [ - { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer' - } - } - ] + schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}] context = { - 'templateparam': '4', + "templateparam": "4", } - result = param_utils._cast_params_from({'templateparam': '4'}, context, schemas) - self.assertEqual(result, {'templateparam': 4}) + result = param_utils._cast_params_from({"templateparam": "4"}, context, schemas) + self.assertEqual(result, {"templateparam": 4}) def test_render_final_params_and_shell_script_action_command_strings(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False - }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) diff --git a/st2common/tests/unit/test_paramiko_command_action_model.py b/st2common/tests/unit/test_paramiko_command_action_model.py index 2ce7bbfed3..0d023d4f8a 100644 --- a/st2common/tests/unit/test_paramiko_command_action_model.py +++ b/st2common/tests/unit/test_paramiko_command_action_model.py @@ -18,76 +18,84 @@ from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction -__all__ = [ - 'ParamikoRemoteCommandActionTestCase' -] +__all__ = ["ParamikoRemoteCommandActionTestCase"] class ParamikoRemoteCommandActionTestCase(unittest2.TestCase): - def test_get_command_string_no_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') - ex = 'cd /tmp && echo boo bah baz' + "echo boo bah baz" + ) + ex = "cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # With sudo cmd_action.sudo = True - ex = 'sudo -E -- bash -c \'cd /tmp && echo boo bah baz\'' + ex = "sudo -E -- bash -c 'cd /tmp && echo boo bah baz'" self.assertEqual(cmd_action.get_full_command_string(), ex) # Executing a path command requires user to provide an escaped input. # E.g. st2 run core.remote hosts=localhost cmd='"/tmp/space stuff.sh"' cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - '"/t/space stuff.sh"') + '"/t/space stuff.sh"' + ) ex = 'cd /tmp && "/t/space stuff.sh"' self.assertEqual(cmd_action.get_full_command_string(), ex) # sudo_password provided cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.sudo = True - cmd_action.sudo_password = 'sudo pass' + cmd_action.sudo_password = "sudo pass" - ex = ('set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- ' - 'bash -c \'cd /tmp && echo boo bah baz\'') + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- " + "bash -c 'cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) def test_get_command_string_with_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') - cmd_action.env_vars = {'FOO': 'BAR', 'BAR': 'BEET CAFE'} - ex = 'export BAR=\'BEET CAFE\' ' + \ - 'FOO=BAR' + \ - ' && cd /tmp && echo boo bah baz' + "echo boo bah baz" + ) + cmd_action.env_vars = {"FOO": "BAR", "BAR": "BEET CAFE"} + ex = "export BAR='BEET CAFE' " + "FOO=BAR" + " && cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # With sudo cmd_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=BAR ' + \ - 'BAR=\'"\'"\'BEET CAFE\'"\'"\'' + \ - ' && cd /tmp && echo boo bah baz\'' - ex = 'sudo -E -- bash -c ' + \ - '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \ - 'FOO=BAR' + \ - ' && cd /tmp && echo boo bah baz\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO=BAR " + + "BAR='\"'\"'BEET CAFE'\"'\"'" + + " && cd /tmp && echo boo bah baz'" + ) + ex = ( + "sudo -E -- bash -c " + + "'export BAR='\"'\"'BEET CAFE'\"'\"' " + + "FOO=BAR" + + " && cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) # with sudo_password cmd_action.sudo = True - cmd_action.sudo_password = 'sudo pass' - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \ - 'FOO=BAR HISTFILE=/dev/null HISTSIZE=0' + \ - ' && cd /tmp && echo boo bah baz\'' + cmd_action.sudo_password = "sudo pass" + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export BAR='\"'\"'BEET CAFE'\"'\"' " + + "FOO=BAR HISTFILE=/dev/null HISTSIZE=0" + + " && cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) def test_get_command_string_no_user(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.user = None - ex = 'cd /tmp && echo boo bah baz' + ex = "cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # Executing a path command requires user to provide an escaped input. @@ -99,25 +107,28 @@ def test_get_command_string_no_user(self): def test_get_command_string_no_user_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.user = None - cmd_action.env_vars = {'FOO': 'BAR'} - ex = 'export FOO=BAR && cd /tmp && echo boo bah baz' + cmd_action.env_vars = {"FOO": "BAR"} + ex = "export FOO=BAR && cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) @staticmethod def _get_test_command_action(command): - cmd_action = ParamikoRemoteCommandAction('fixtures.remote_command', - '55ce39d532ed3543aecbe71d', - command=command, - env_vars={}, - on_behalf_user='svetlana', - user='estee', - password=None, - private_key='---PRIVATE-KEY---', - hosts='127.0.0.1', - parallel=True, - sudo=False, - timeout=None, - cwd='/tmp') + cmd_action = ParamikoRemoteCommandAction( + "fixtures.remote_command", + "55ce39d532ed3543aecbe71d", + command=command, + env_vars={}, + on_behalf_user="svetlana", + user="estee", + password=None, + private_key="---PRIVATE-KEY---", + hosts="127.0.0.1", + parallel=True, + sudo=False, + timeout=None, + cwd="/tmp", + ) return cmd_action diff --git a/st2common/tests/unit/test_paramiko_script_action_model.py b/st2common/tests/unit/test_paramiko_script_action_model.py index e05350e46d..3efae1053f 100644 --- a/st2common/tests/unit/test_paramiko_script_action_model.py +++ b/st2common/tests/unit/test_paramiko_script_action_model.py @@ -18,75 +18,81 @@ from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction -__all__ = [ - 'ParamikoRemoteScriptActionTestCase' -] +__all__ = ["ParamikoRemoteScriptActionTestCase"] class ParamikoRemoteScriptActionTestCase(unittest2.TestCase): - def test_get_command_string_no_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - ex = 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\'' + ex = "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # with sudo password script_action.sudo = True - script_action.sudo_password = 'sudo pass' - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + script_action.sudo_password = "sudo pass" + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_with_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() script_action.env_vars = { - 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d', - 'FOO': 'BAR BAZ BOOZ' + "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d", + "FOO": "BAR BAZ BOOZ", } - ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\'' + ex = ( + "export FOO='BAR BAZ BOOZ' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'" + ) self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # with sudo password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' HISTFILE=/dev/null HISTSIZE=0 ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' HISTFILE=/dev/null HISTSIZE=0 " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_no_script_args_no_env_args(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() script_action.named_args = {} script_action.positional_args = [] - ex = 'cd /tmp && /tmp/remote_script.sh' + ex = "cd /tmp && /tmp/remote_script.sh" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && /tmp/remote_script.sh\'' + ex = "sudo -E -- bash -c " + "'cd /tmp && /tmp/remote_script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_no_script_args_with_env_args(self): @@ -94,88 +100,100 @@ def test_get_command_string_no_script_args_with_env_args(self): script_action.named_args = {} script_action.positional_args = [] script_action.env_vars = { - 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d', - 'FOO': 'BAR BAZ BOOZ' + "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d", + "FOO": "BAR BAZ BOOZ", } - ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && /tmp/remote_script.sh' + ex = ( + "export FOO='BAR BAZ BOOZ' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && /tmp/remote_script.sh" + ) self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh'" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_script_path_shell_injection_safe(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - test_path = '/tmp/remote script.sh' + test_path = "/tmp/remote script.sh" script_action.remote_script = test_path script_action.named_args = {} script_action.positional_args = [] - ex = 'cd /tmp && \'/tmp/remote script.sh\'' + ex = "cd /tmp && '/tmp/remote script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = "sudo -E -- bash -c " + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" self.assertEqual(script_action.get_full_command_string(), ex) # With sudo_password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_script_path_shell_injection_safe_with_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - test_path = '/tmp/remote script.sh' + test_path = "/tmp/remote script.sh" script_action.remote_script = test_path script_action.named_args = {} script_action.positional_args = [] - script_action.env_vars = {'FOO': 'BAR'} - ex = 'export FOO=BAR && cd /tmp && \'/tmp/remote script.sh\'' + script_action.env_vars = {"FOO": "BAR"} + ex = "export FOO=BAR && cd /tmp && '/tmp/remote script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=BAR && ' + \ - 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO=BAR && " + + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # With sudo_password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && ' + \ - 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && " + + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) @staticmethod def _get_test_script_action(): - local_script_path = '/opt/stackstorm/packs/fixtures/actions/remote_script.sh' - script_action = ParamikoRemoteScriptAction('fixtures.remote_script', - '55ce39d532ed3543aecbe71d', - local_script_path, - '/opt/stackstorm/packs/fixtures/actions/lib/', - named_args={'song': 'b s'}, - positional_args=['taylor swift'], - env_vars={}, - on_behalf_user='stanley', - user='vagrant', - private_key='/home/vagrant/.ssh/stanley_rsa', - remote_dir='/tmp', - hosts=['127.0.0.1'], - parallel=True, - sudo=False, - timeout=60, cwd='/tmp') + local_script_path = "/opt/stackstorm/packs/fixtures/actions/remote_script.sh" + script_action = ParamikoRemoteScriptAction( + "fixtures.remote_script", + "55ce39d532ed3543aecbe71d", + local_script_path, + "/opt/stackstorm/packs/fixtures/actions/lib/", + named_args={"song": "b s"}, + positional_args=["taylor swift"], + env_vars={}, + on_behalf_user="stanley", + user="vagrant", + private_key="/home/vagrant/.ssh/stanley_rsa", + remote_dir="/tmp", + hosts=["127.0.0.1"], + parallel=True, + sudo=False, + timeout=60, + cwd="/tmp", + ) return script_action diff --git a/st2common/tests/unit/test_persistence.py b/st2common/tests/unit/test_persistence.py index 14f25731ff..6fce36c18d 100644 --- a/st2common/tests/unit/test_persistence.py +++ b/st2common/tests/unit/test_persistence.py @@ -27,7 +27,6 @@ class TestPersistence(DbTestCase): - @classmethod def setUpClass(cls): super(TestPersistence, cls).setUpClass() @@ -38,7 +37,7 @@ def tearDown(self): super(TestPersistence, self).tearDown() def test_crud(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get(name=obj1.name) self.assertIsNotNone(obj2) @@ -59,16 +58,16 @@ def test_crud(self): self.assertIsNone(obj2) def test_count(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) self.assertEqual(self.access.count(), 2) def test_get_all(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) objs = self.access.get_all() self.assertIsNotNone(objs) @@ -76,33 +75,35 @@ def test_get_all(self): self.assertListEqual(list(objs), [obj1, obj2]) def test_query_by_id(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get_by_id(str(obj1.id)) self.assertIsNotNone(obj2) self.assertEqual(obj1.id, obj2.id) self.assertEqual(obj1.name, obj2.name) self.assertDictEqual(obj1.context, obj2.context) - self.assertRaises(StackStormDBObjectNotFoundError, - self.access.get_by_id, str(bson.ObjectId())) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access.get_by_id, str(bson.ObjectId()) + ) def test_query_by_name(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get_by_name(obj1.name) self.assertIsNotNone(obj2) self.assertEqual(obj1.id, obj2.id) self.assertEqual(obj1.name, obj2.name) self.assertDictEqual(obj1.context, obj2.context) - self.assertRaises(StackStormDBObjectNotFoundError, self.access.get_by_name, - uuid.uuid4().hex) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access.get_by_name, uuid.uuid4().hex + ) def test_query_filter(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) - objs = self.access.query(context__user='system') + objs = self.access.query(context__user="system") self.assertIsNotNone(objs) self.assertGreater(len(objs), 0) self.assertEqual(obj1.id, objs[0].id) @@ -113,17 +114,17 @@ def test_null_filter(self): obj1 = FakeModelDB(name=uuid.uuid4().hex) obj1 = self.access.add_or_update(obj1) - objs = self.access.query(index='null') + objs = self.access.query(index="null") self.assertEqual(len(objs), 1) self.assertEqual(obj1.id, objs[0].id) self.assertEqual(obj1.name, objs[0].name) - self.assertIsNone(getattr(obj1, 'index', None)) + self.assertIsNone(getattr(obj1, "index", None)) objs = self.access.query(index=None) self.assertEqual(len(objs), 1) self.assertEqual(obj1.id, objs[0].id) self.assertEqual(obj1.name, objs[0].name) - self.assertIsNone(getattr(obj1, 'index', None)) + self.assertIsNone(getattr(obj1, "index", None)) def test_datetime_range(self): base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) @@ -132,12 +133,12 @@ def test_datetime_range(self): obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp) self.access.add_or_update(obj) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" objs = self.access.query(timestamp=dt_range) self.assertEqual(len(objs), 10) self.assertLess(objs[0].timestamp, objs[9].timestamp) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" objs = self.access.query(timestamp=dt_range) self.assertEqual(len(objs), 10) self.assertLess(objs[9].timestamp, objs[0].timestamp) @@ -146,52 +147,61 @@ def test_pagination(self): count = 100 page_size = 25 pages = int(count / page_size) - users = ['Peter', 'Susan', 'Edmund', 'Lucy'] + users = ["Peter", "Susan", "Edmund", "Lucy"] for user in users: - context = {'user': user} + context = {"user": user} for i in range(count): - self.access.add_or_update(FakeModelDB(name=uuid.uuid4().hex, - context=context, index=i)) + self.access.add_or_update( + FakeModelDB(name=uuid.uuid4().hex, context=context, index=i) + ) self.assertEqual(self.access.count(), len(users) * count) for user in users: for i in range(pages): offset = i * page_size - objs = self.access.query(context__user=user, order_by=['index'], - offset=offset, limit=page_size) + objs = self.access.query( + context__user=user, + order_by=["index"], + offset=offset, + limit=page_size, + ) self.assertEqual(len(objs), page_size) for j in range(page_size): - self.assertEqual(objs[j].context['user'], user) + self.assertEqual(objs[j].context["user"], user) self.assertEqual(objs[j].index, (i * page_size) + j) def test_sort_multiple(self): count = 60 base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' if i % 2 else 'type2' + category = "type1" if i % 2 else "type2" timestamp = base + datetime.timedelta(seconds=i) - obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp, category=category) + obj = FakeModelDB( + name=uuid.uuid4().hex, timestamp=timestamp, category=category + ) self.access.add_or_update(obj) - objs = self.access.query(order_by=['category', 'timestamp']) + objs = self.access.query(order_by=["category", "timestamp"]) self.assertEqual(len(objs), count) for i in range(count): - category = 'type1' if i < count / 2 else 'type2' + category = "type1" if i < count / 2 else "type2" self.assertEqual(objs[i].category, category) self.assertLess(objs[0].timestamp, objs[(int(count / 2)) - 1].timestamp) - self.assertLess(objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp) + self.assertLess( + objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp + ) self.assertLess(objs[int(count / 2)].timestamp, objs[count - 1].timestamp) def test_escaped_field(self): - context = {'a.b.c': 'abc'} + context = {"a.b.c": "abc"} obj1 = FakeModelDB(name=uuid.uuid4().hex, context=context) obj2 = self.access.add_or_update(obj1) # Check that the original dict has not been altered. - self.assertIn('a.b.c', list(context.keys())) - self.assertNotIn('a\uff0eb\uff0ec', list(context.keys())) + self.assertIn("a.b.c", list(context.keys())) + self.assertNotIn("a\uff0eb\uff0ec", list(context.keys())) # Check to_python has run and context is not left escaped. self.assertDictEqual(obj2.context, context) @@ -206,26 +216,26 @@ def test_query_only_fields(self): count = 5 ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' - obj = FakeModelDB(name='test-%s' % (i), timestamp=ts, category=category) + category = "type1" + obj = FakeModelDB(name="test-%s" % (i), timestamp=ts, category=category) self.access.add_or_update(obj) model_dbs = FakeModel.query() - self.assertEqual(model_dbs[0].name, 'test-0') + self.assertEqual(model_dbs[0].name, "test-0") self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") # only id - model_dbs = FakeModel.query(only_fields=['id']) + model_dbs = FakeModel.query(only_fields=["id"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) self.assertEqual(model_dbs[0].category, None) # only name - note: id is always included - model_dbs = FakeModel.query(only_fields=['name']) + model_dbs = FakeModel.query(only_fields=["name"]) self.assertTrue(model_dbs[0].id) - self.assertEqual(model_dbs[0].name, 'test-0') + self.assertEqual(model_dbs[0].name, "test-0") self.assertEqual(model_dbs[0].timestamp, None) self.assertEqual(model_dbs[0].category, None) @@ -233,28 +243,28 @@ def test_query_exclude_fields(self): count = 5 ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' - obj = FakeModelDB(name='test-2-%s' % (i), timestamp=ts, category=category) + category = "type1" + obj = FakeModelDB(name="test-2-%s" % (i), timestamp=ts, category=category) self.access.add_or_update(obj) model_dbs = FakeModel.query() - self.assertEqual(model_dbs[0].name, 'test-2-0') + self.assertEqual(model_dbs[0].name, "test-2-0") self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name']) + model_dbs = FakeModel.query(exclude_fields=["name"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp']) + model_dbs = FakeModel.query(exclude_fields=["name", "timestamp"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp', 'category']) + model_dbs = FakeModel.query(exclude_fields=["name", "timestamp", "category"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) diff --git a/st2common/tests/unit/test_persistence_change_revision.py b/st2common/tests/unit/test_persistence_change_revision.py index f9e31e1c73..c268fa86b5 100644 --- a/st2common/tests/unit/test_persistence_change_revision.py +++ b/st2common/tests/unit/test_persistence_change_revision.py @@ -24,7 +24,6 @@ class TestChangeRevision(DbTestCase): - @classmethod def setUpClass(cls): super(TestChangeRevision, cls).setUpClass() @@ -35,7 +34,7 @@ def tearDown(self): super(TestChangeRevision, self).tearDown() def test_crud(self): - initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) # Test create created = self.access.add_or_update(initial) @@ -47,14 +46,14 @@ def test_crud(self): self.assertDictEqual(created.context, retrieved.context) # Test update - retrieved = self.access.update(retrieved, context={'a': 2}) + retrieved = self.access.update(retrieved, context={"a": 2}) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved.rev, updated.rev) self.assertDictEqual(retrieved.context, updated.context) # Test add or update - retrieved.context = {'a': 1, 'b': 2} + retrieved.context = {"a": 1, "b": 2} retrieved = self.access.add_or_update(retrieved) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -65,13 +64,11 @@ def test_crud(self): created.delete() self.assertRaises( - db_exc.StackStormDBObjectNotFoundError, - self.access.get_by_id, - doc_id + db_exc.StackStormDBObjectNotFoundError, self.access.get_by_id, doc_id ) def test_write_conflict(self): - initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) # Prep record created = self.access.add_or_update(initial) @@ -83,7 +80,7 @@ def test_write_conflict(self): retrieved2 = self.access.get_by_id(doc_id) # Test update on instance 1, expect success - retrieved1 = self.access.update(retrieved1, context={'a': 2}) + retrieved1 = self.access.update(retrieved1, context={"a": 2}) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved1.rev, updated.rev) @@ -94,5 +91,5 @@ def test_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, self.access.update, retrieved2, - context={'a': 1, 'b': 2} + context={"a": 1, "b": 2}, ) diff --git a/st2common/tests/unit/test_plugin_loader.py b/st2common/tests/unit/test_plugin_loader.py index 4b78b6b4cc..4641b66e9c 100644 --- a/st2common/tests/unit/test_plugin_loader.py +++ b/st2common/tests/unit/test_plugin_loader.py @@ -24,8 +24,8 @@ import st2common.util.loader as plugin_loader -PLUGIN_FOLDER = 'loadableplugin' -SRC_RELATIVE = os.path.join('../resources', PLUGIN_FOLDER) +PLUGIN_FOLDER = "loadableplugin" +SRC_RELATIVE = os.path.join("../resources", PLUGIN_FOLDER) SRC_ROOT = os.path.join(os.path.abspath(os.path.dirname(__file__)), SRC_RELATIVE) @@ -51,64 +51,71 @@ def tearDown(self): sys.path = LoaderTest.sys_path def test_module_load_from_file(self): - plugin_path = os.path.join(SRC_ROOT, 'plugin/standaloneplugin.py') + plugin_path = os.path.join(SRC_ROOT, "plugin/standaloneplugin.py") plugin_classes = plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_path) + LoaderTest.DummyPlugin, plugin_path + ) # Even though there are two classes in that file, only one # matches the specs of DummyPlugin class. self.assertEqual(1, len(plugin_classes)) # Validate sys.path now contains the plugin directory. - self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, 'plugin')), sys.path) + self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, "plugin")), sys.path) # Validate the individual plugins for plugin_class in plugin_classes: try: plugin_instance = plugin_class() ret_val = plugin_instance.do_work() - self.assertIsNotNone(ret_val, 'Should be non-null.') + self.assertIsNotNone(ret_val, "Should be non-null.") except: pass def test_module_load_from_file_fail(self): try: - plugin_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py') + plugin_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py") plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_path) - self.assertTrue(False, 'Import error is expected.') + self.assertTrue(False, "Import error is expected.") except ImportError: self.assertTrue(True) def test_syspath_unchanged_load_multiple_plugins(self): - plugin_1_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py') + plugin_1_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py") try: - plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_1_path) + plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_1_path) except ImportError: pass old_sys_path = copy.copy(sys.path) - plugin_2_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin2.py') + plugin_2_path = os.path.join(SRC_ROOT, "plugin/sampleplugin2.py") try: - plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_2_path) + plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_2_path) except ImportError: pass - self.assertEqual(old_sys_path, sys.path, 'Should be equal.') + self.assertEqual(old_sys_path, sys.path, "Should be equal.") def test_register_plugin_class_class_doesnt_exist(self): - file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py') + file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py") expected_msg = 'doesn\'t expose class named "SamplePluginNotExists"' - self.assertRaisesRegexp(Exception, expected_msg, - plugin_loader.register_plugin_class, - base_class=LoaderTest.DummyPlugin, - file_path=file_path, - class_name='SamplePluginNotExists') + self.assertRaisesRegexp( + Exception, + expected_msg, + plugin_loader.register_plugin_class, + base_class=LoaderTest.DummyPlugin, + file_path=file_path, + class_name="SamplePluginNotExists", + ) def test_register_plugin_class_abstract_method_not_implemented(self): - file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py') - - expected_msg = 'doesn\'t implement required "do_work" method from the base class' - self.assertRaisesRegexp(plugin_loader.IncompatiblePluginException, expected_msg, - plugin_loader.register_plugin_class, - base_class=LoaderTest.DummyPlugin, - file_path=file_path, - class_name='SamplePlugin') + file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py") + + expected_msg = ( + 'doesn\'t implement required "do_work" method from the base class' + ) + self.assertRaisesRegexp( + plugin_loader.IncompatiblePluginException, + expected_msg, + plugin_loader.register_plugin_class, + base_class=LoaderTest.DummyPlugin, + file_path=file_path, + class_name="SamplePlugin", + ) diff --git a/st2common/tests/unit/test_policies.py b/st2common/tests/unit/test_policies.py index 5491e7482b..f6dd9a47de 100644 --- a/st2common/tests/unit/test_policies.py +++ b/st2common/tests/unit/test_policies.py @@ -19,55 +19,43 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'PolicyTestCase' -] +__all__ = ["PolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'runners': [ - 'testrunner1.yaml' - ], - 'actions': [ - 'action1.yaml' - ], - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } class PolicyTestCase(DbTestCase): - @classmethod def setUpClass(cls): super(PolicyTestCase, cls).setUpClass() loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def test_get_by_ref(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") self.assertIsNotNone(policy_db) - self.assertEqual(policy_db.pack, 'wolfpack') - self.assertEqual(policy_db.name, 'action-1.concurrency') + self.assertEqual(policy_db.pack, "wolfpack") + self.assertEqual(policy_db.name, "action-1.concurrency") policy_type_db = PolicyType.get_by_ref(policy_db.policy_type) self.assertIsNotNone(policy_type_db) - self.assertEqual(policy_type_db.resource_type, 'action') - self.assertEqual(policy_type_db.name, 'concurrency') + self.assertEqual(policy_type_db.resource_type, "action") + self.assertEqual(policy_type_db.name, "concurrency") def test_get_driver(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - policy = get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + policy = get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) self.assertIsInstance(policy, ResourcePolicyApplicator) self.assertEqual(policy._policy_ref, policy_db.ref) self.assertEqual(policy._policy_type, policy_db.policy_type) - self.assertTrue(hasattr(policy, 'threshold')) + self.assertTrue(hasattr(policy, "threshold")) self.assertEqual(policy.threshold, 3) diff --git a/st2common/tests/unit/test_policies_registrar.py b/st2common/tests/unit/test_policies_registrar.py index b46515e08a..85c1d34490 100644 --- a/st2common/tests/unit/test_policies_registrar.py +++ b/st2common/tests/unit/test_policies_registrar.py @@ -29,9 +29,7 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PoliciesRegistrarTestCase' -] +__all__ = ["PoliciesRegistrarTestCase"] class PoliciesRegistrarTestCase(CleanDbTestCase): @@ -44,13 +42,13 @@ def setUp(self): def test_register_policy_types(self): self.assertEqual(register_policy_types(st2tests), 2) - type1 = PolicyType.get_by_ref('action.concurrency') - self.assertEqual(type1.name, 'concurrency') - self.assertEqual(type1.resource_type, 'action') + type1 = PolicyType.get_by_ref("action.concurrency") + self.assertEqual(type1.name, "concurrency") + self.assertEqual(type1.resource_type, "action") - type2 = PolicyType.get_by_ref('action.mock_policy_error') - self.assertEqual(type2.name, 'mock_policy_error') - self.assertEqual(type2.resource_type, 'action') + type2 = PolicyType.get_by_ref("action.mock_policy_error") + self.assertEqual(type2.name, "mock_policy_error") + self.assertEqual(type2.resource_type, "action") def test_register_all_policies(self): policies_dbs = Policy.get_all() @@ -64,38 +62,29 @@ def test_register_all_policies(self): policies = { policies_db.name: { - 'pack': policies_db.pack, - 'type': policies_db.policy_type, - 'parameters': policies_db.parameters + "pack": policies_db.pack, + "type": policies_db.policy_type, + "parameters": policies_db.parameters, } for policies_db in policies_dbs } expected_policies = { - 'test_policy_1': { - 'pack': 'dummy_pack_1', - 'type': 'action.concurrency', - 'parameters': { - 'action': 'delay', - 'threshold': 3 - } + "test_policy_1": { + "pack": "dummy_pack_1", + "type": "action.concurrency", + "parameters": {"action": "delay", "threshold": 3}, }, - 'test_policy_3': { - 'pack': 'dummy_pack_1', - 'type': 'action.retry', - 'parameters': { - 'retry_on': 'timeout', - 'max_retry_count': 5 - } + "test_policy_3": { + "pack": "dummy_pack_1", + "type": "action.retry", + "parameters": {"retry_on": "timeout", "max_retry_count": 5}, + }, + "sequential.retry_on_failure": { + "pack": "orquesta_tests", + "type": "action.retry", + "parameters": {"retry_on": "failure", "max_retry_count": 1}, }, - 'sequential.retry_on_failure': { - 'pack': 'orquesta_tests', - 'type': 'action.retry', - 'parameters': { - 'retry_on': 'failure', - 'max_retry_count': 1 - } - } } self.assertEqual(len(expected_policies), count) @@ -103,39 +92,49 @@ def test_register_all_policies(self): self.assertDictEqual(expected_policies, policies) def test_register_policies_from_pack(self): - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") self.assertEqual(register_policies(pack_dir=pack_dir), 2) - p1 = Policy.get_by_ref('dummy_pack_1.test_policy_1') - self.assertEqual(p1.name, 'test_policy_1') - self.assertEqual(p1.pack, 'dummy_pack_1') - self.assertEqual(p1.resource_ref, 'dummy_pack_1.local') - self.assertEqual(p1.policy_type, 'action.concurrency') + p1 = Policy.get_by_ref("dummy_pack_1.test_policy_1") + self.assertEqual(p1.name, "test_policy_1") + self.assertEqual(p1.pack, "dummy_pack_1") + self.assertEqual(p1.resource_ref, "dummy_pack_1.local") + self.assertEqual(p1.policy_type, "action.concurrency") # Verify that a default value for parameter "action" which isn't provided in the file is set - self.assertEqual(p1.parameters['action'], 'delay') - self.assertEqual(p1.metadata_file, 'policies/policy_1.yaml') + self.assertEqual(p1.parameters["action"], "delay") + self.assertEqual(p1.metadata_file, "policies/policy_1.yaml") - p2 = Policy.get_by_ref('dummy_pack_1.test_policy_2') + p2 = Policy.get_by_ref("dummy_pack_1.test_policy_2") self.assertEqual(p2, None) def test_register_policy_invalid_policy_type_references(self): # Policy references an invalid (inexistent) policy type registrar = PolicyRegistrar() - policy_path = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_1/policies/policy_2.yaml') + policy_path = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_1/policies/policy_2.yaml" + ) expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist' - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_policy, - pack='dummy_pack_1', policy=policy_path) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_policy, + pack="dummy_pack_1", + policy=policy_path, + ) def test_make_sure_policy_parameters_are_validated_during_register(self): # Policy where specified parameters fail schema validation registrar = PolicyRegistrar() - policy_path = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_2/policies/policy_3.yaml') - - expected_msg = '100 is greater than the maximum of 5' - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_policy, - pack='dummy_pack_2', - policy=policy_path) + policy_path = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_2/policies/policy_3.yaml" + ) + + expected_msg = "100 is greater than the maximum of 5" + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_policy, + pack="dummy_pack_2", + policy=policy_path, + ) diff --git a/st2common/tests/unit/test_purge_executions.py b/st2common/tests/unit/test_purge_executions.py index 5362cc753e..64ee4cfa67 100644 --- a/st2common/tests/unit/test_purge_executions.py +++ b/st2common/tests/unit/test_purge_executions.py @@ -34,18 +34,10 @@ LOG = logging.getLogger(__name__) -TEST_FIXTURES = { - 'executions': [ - 'execution1.yaml' - ], - 'liveactions': [ - 'liveaction4.yaml' - ] -} +TEST_FIXTURES = {"executions": ["execution1.yaml"], "liveactions": ["liveaction4.yaml"]} class TestPurgeExecutions(CleanDbTestCase): - @classmethod def setUpClass(cls): CleanDbTestCase.setUpClass() @@ -54,114 +46,128 @@ def setUpClass(cls): def setUp(self): super(TestPurgeExecutions, self).setUp() fixtures_loader = FixturesLoader() - self.models = fixtures_loader.load_models(fixtures_pack='generic', - fixtures_dict=TEST_FIXTURES) + self.models = fixtures_loader.load_models( + fixtures_pack="generic", fixtures_dict=TEST_FIXTURES + ) def test_no_timestamp_doesnt_delete_things(self): now = date_utils.get_datetime_utc_now() - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) - expected_msg = 'Specify a valid timestamp' - self.assertRaisesRegexp(ValueError, expected_msg, purge_executions, - logger=LOG, timestamp=None) + expected_msg = "Specify a valid timestamp" + self.assertRaisesRegexp( + ValueError, expected_msg, purge_executions, logger=LOG, timestamp=None + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) def test_purge_executions_with_action_ref(self): now = date_utils.get_datetime_utc_now() - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) # Invalid action reference, nothing should be deleted - purge_executions(logger=LOG, action_ref='core.localzzz', timestamp=now - timedelta(days=10)) + purge_executions( + logger=LOG, action_ref="core.localzzz", timestamp=now - timedelta(days=10) + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) - purge_executions(logger=LOG, action_ref='core.local', timestamp=now - timedelta(days=10)) + purge_executions( + logger=LOG, action_ref="core.local", timestamp=now - timedelta(days=10) + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 0) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 0) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 0) def test_purge_executions_with_timestamp(self): now = date_utils.get_datetime_utc_now() # Write one execution after cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) # Write one execution before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=22) - exec_model['end_timestamp'] = now - timedelta(days=21) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=22) + exec_model["end_timestamp"] = now - timedelta(days=21) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 2) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 6) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 6) purge_executions(logger=LOG, timestamp=now - timedelta(days=20)) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) def test_liveaction_gets_deleted(self): @@ -169,19 +175,19 @@ def test_liveaction_gets_deleted(self): start_ts = now - timedelta(days=15) end_ts = now - timedelta(days=14) - liveaction_model = copy.deepcopy(self.models['liveactions']['liveaction4.yaml']) - liveaction_model['start_timestamp'] = start_ts - liveaction_model['end_timestamp'] = end_ts - liveaction_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED + liveaction_model = copy.deepcopy(self.models["liveactions"]["liveaction4.yaml"]) + liveaction_model["start_timestamp"] = start_ts + liveaction_model["end_timestamp"] = end_ts + liveaction_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED liveaction = LiveAction.add_or_update(liveaction_model) # Write one execution before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['end_timestamp'] = end_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() - exec_model['liveaction']['id'] = str(liveaction.id) + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["end_timestamp"] = end_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() + exec_model["liveaction"]["id"] = str(liveaction.id) ActionExecution.add_or_update(exec_model) liveactions = LiveAction.get_all() @@ -201,110 +207,143 @@ def test_purge_incomplete(self): start_ts = now - timedelta(days=15) # Write executions before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_RUNNING - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_RUNNING + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_DELAYED - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_DELAYED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_CANCELING - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_CANCELING + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_REQUESTED - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_REQUESTED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) self.assertEqual(len(ActionExecution.get_all()), 5) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 5) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 5) # Incompleted executions shouldnt be purged - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False + ) self.assertEqual(len(ActionExecution.get_all()), 5) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 5) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 5) - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True + ) self.assertEqual(len(ActionExecution.get_all()), 0) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 0) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 0) - @mock.patch('st2common.garbage_collection.executions.LiveAction') - @mock.patch('st2common.garbage_collection.executions.ActionExecution') - def test_purge_executions_whole_model_is_not_loaded_in_memory(self, mock_ActionExecution, - mock_LiveAction): + @mock.patch("st2common.garbage_collection.executions.LiveAction") + @mock.patch("st2common.garbage_collection.executions.ActionExecution") + def test_purge_executions_whole_model_is_not_loaded_in_memory( + self, mock_ActionExecution, mock_LiveAction + ): # Verify that whole execution objects are not loaded in memory and we just retrieve the # id field self.assertEqual(mock_ActionExecution.query.call_count, 0) self.assertEqual(mock_LiveAction.query.call_count, 0) now = date_utils.get_datetime_utc_now() - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True + ) self.assertEqual(mock_ActionExecution.query.call_count, 2) self.assertEqual(mock_LiveAction.query.call_count, 1) - self.assertEqual(mock_ActionExecution.query.call_args_list[0][1]['only_fields'], ['id']) - self.assertTrue(mock_ActionExecution.query.call_args_list[0][1]['no_dereference']) - self.assertEqual(mock_ActionExecution.query.call_args_list[1][1]['only_fields'], ['id']) - self.assertTrue(mock_ActionExecution.query.call_args_list[1][1]['no_dereference']) - self.assertEqual(mock_LiveAction.query.call_args_list[0][1]['only_fields'], ['id']) - self.assertTrue(mock_LiveAction.query.call_args_list[0][1]['no_dereference']) - - def _insert_mock_stdout_and_stderr_objects_for_execution(self, execution_id, count=5): + self.assertEqual( + mock_ActionExecution.query.call_args_list[0][1]["only_fields"], ["id"] + ) + self.assertTrue( + mock_ActionExecution.query.call_args_list[0][1]["no_dereference"] + ) + self.assertEqual( + mock_ActionExecution.query.call_args_list[1][1]["only_fields"], ["id"] + ) + self.assertTrue( + mock_ActionExecution.query.call_args_list[1][1]["no_dereference"] + ) + self.assertEqual( + mock_LiveAction.query.call_args_list[0][1]["only_fields"], ["id"] + ) + self.assertTrue(mock_LiveAction.query.call_args_list[0][1]["no_dereference"]) + + def _insert_mock_stdout_and_stderr_objects_for_execution( + self, execution_id, count=5 + ): execution_id = str(execution_id) stdout_dbs, stderr_dbs = [], [] for i in range(0, count): - stdout_db = ActionExecutionOutputDB(execution_id=execution_id, - action_ref='dummy.pack', - runner_ref='dummy', - output_type='stdout', - data='stdout %s' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=execution_id, + action_ref="dummy.pack", + runner_ref="dummy", + output_type="stdout", + data="stdout %s" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=execution_id, - action_ref='dummy.pack', - runner_ref='dummy', - output_type='stderr', - data='stderr%s' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=execution_id, + action_ref="dummy.pack", + runner_ref="dummy", + output_type="stderr", + data="stderr%s" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) return stdout_dbs, stderr_dbs diff --git a/st2common/tests/unit/test_purge_trigger_instances.py b/st2common/tests/unit/test_purge_trigger_instances.py index 2cc9f6ffed..515c4040c3 100644 --- a/st2common/tests/unit/test_purge_trigger_instances.py +++ b/st2common/tests/unit/test_purge_trigger_instances.py @@ -28,7 +28,6 @@ class TestPurgeTriggerInstances(CleanDbTestCase): - @classmethod def setUpClass(cls): CleanDbTestCase.setUpClass() @@ -40,32 +39,42 @@ def setUp(self): def test_no_timestamp_doesnt_delete(self): now = date_utils.get_datetime_utc_now() - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=20), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=20), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) self.assertEqual(len(TriggerInstance.get_all()), 1) - expected_msg = 'Specify a valid timestamp' - self.assertRaisesRegexp(ValueError, expected_msg, - purge_trigger_instances, - logger=LOG, timestamp=None) + expected_msg = "Specify a valid timestamp" + self.assertRaisesRegexp( + ValueError, + expected_msg, + purge_trigger_instances, + logger=LOG, + timestamp=None, + ) self.assertEqual(len(TriggerInstance.get_all()), 1) def test_purge(self): now = date_utils.get_datetime_utc_now() - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=20), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=20), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=5), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=5), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) self.assertEqual(len(TriggerInstance.get_all()), 2) diff --git a/st2common/tests/unit/test_queue_consumer.py b/st2common/tests/unit/test_queue_consumer.py index 4f54c325aa..463eb0def4 100644 --- a/st2common/tests/unit/test_queue_consumer.py +++ b/st2common/tests/unit/test_queue_consumer.py @@ -23,8 +23,8 @@ from tests.unit.base import FakeModelDB -FAKE_XCHG = Exchange('st2.tests', type='topic') -FAKE_WORK_Q = Queue('st2.tests.unit', FAKE_XCHG) +FAKE_XCHG = Exchange("st2.tests", type="topic") +FAKE_WORK_Q = Queue("st2.tests.unit", FAKE_XCHG) class FakeMessageHandler(consumers.MessageHandler): @@ -39,15 +39,14 @@ def get_handler(): class QueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock()) def test_process_message(self): payload = FakeModelDB() handler = get_handler() handler._queue_consumer._process_message(payload) FakeMessageHandler.process.assert_called_once_with(payload) - @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock()) def test_process_message_wrong_payload_type(self): payload = 100 handler = get_handler() @@ -72,8 +71,7 @@ def get_staged_handler(): class StagedQueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeStagedMessageHandler, 'pre_ack_process', mock.MagicMock()) + @mock.patch.object(FakeStagedMessageHandler, "pre_ack_process", mock.MagicMock()) def test_process_message_pre_ack(self): payload = FakeModelDB() handler = get_staged_handler() @@ -82,15 +80,16 @@ def test_process_message_pre_ack(self): FakeStagedMessageHandler.pre_ack_process.assert_called_once_with(payload) self.assertTrue(mock_message.ack.called) - @mock.patch.object(BufferedDispatcher, 'dispatch', mock.MagicMock()) - @mock.patch.object(FakeStagedMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(BufferedDispatcher, "dispatch", mock.MagicMock()) + @mock.patch.object(FakeStagedMessageHandler, "process", mock.MagicMock()) def test_process_message(self): payload = FakeModelDB() handler = get_staged_handler() mock_message = mock.MagicMock() handler._queue_consumer.process(payload, mock_message) BufferedDispatcher.dispatch.assert_called_once_with( - handler._queue_consumer._process_message, payload) + handler._queue_consumer._process_message, payload + ) handler._queue_consumer._process_message(payload) FakeStagedMessageHandler.process.assert_called_once_with(payload) self.assertTrue(mock_message.ack.called) @@ -104,13 +103,10 @@ def test_process_message_wrong_payload_type(self): class FakeVariableMessageHandler(consumers.VariableMessageHandler): - def __init__(self, connection, queues): super(FakeVariableMessageHandler, self).__init__(connection, queues) - self.message_types = { - FakeModelDB: self.handle_fake_model - } + self.message_types = {FakeModelDB: self.handle_fake_model} def process(self, message): handler_function = self.message_types.get(type(message)) @@ -125,15 +121,16 @@ def get_variable_messages_handler(): class VariableMessageQueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeVariableMessageHandler, 'handle_fake_model', mock.MagicMock()) + @mock.patch.object( + FakeVariableMessageHandler, "handle_fake_model", mock.MagicMock() + ) def test_process_message(self): payload = FakeModelDB() handler = get_variable_messages_handler() handler._queue_consumer._process_message(payload) FakeVariableMessageHandler.handle_fake_model.assert_called_once_with(payload) - @mock.patch.object(FakeVariableMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeVariableMessageHandler, "process", mock.MagicMock()) def test_process_message_wrong_payload_type(self): payload = 100 handler = get_variable_messages_handler() diff --git a/st2common/tests/unit/test_queue_utils.py b/st2common/tests/unit/test_queue_utils.py index 52ad7a60dc..db77fc01c2 100644 --- a/st2common/tests/unit/test_queue_utils.py +++ b/st2common/tests/unit/test_queue_utils.py @@ -22,31 +22,42 @@ class TestQueueUtils(TestCase): - def test_get_queue_name(self): - self.assertRaises(ValueError, - queue_utils.get_queue_name, - queue_name_base=None, queue_name_suffix=None) - self.assertRaises(ValueError, - queue_utils.get_queue_name, - queue_name_base='', queue_name_suffix=None) - self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch', - queue_name_suffix=None), - 'st2.test.watch') - self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch', - queue_name_suffix=''), - 'st2.test.watch') + self.assertRaises( + ValueError, + queue_utils.get_queue_name, + queue_name_base=None, + queue_name_suffix=None, + ) + self.assertRaises( + ValueError, + queue_utils.get_queue_name, + queue_name_base="", + queue_name_suffix=None, + ) + self.assertEqual( + queue_utils.get_queue_name( + queue_name_base="st2.test.watch", queue_name_suffix=None + ), + "st2.test.watch", + ) + self.assertEqual( + queue_utils.get_queue_name( + queue_name_base="st2.test.watch", queue_name_suffix="" + ), + "st2.test.watch", + ) queue_name = queue_utils.get_queue_name( - queue_name_base='st2.test.watch', - queue_name_suffix='foo', - add_random_uuid_to_suffix=True + queue_name_base="st2.test.watch", + queue_name_suffix="foo", + add_random_uuid_to_suffix=True, ) - pattern = re.compile(r'st2.test.watch.foo-\w') + pattern = re.compile(r"st2.test.watch.foo-\w") self.assertTrue(re.match(pattern, queue_name)) queue_name = queue_utils.get_queue_name( - queue_name_base='st2.test.watch', - queue_name_suffix='foo', - add_random_uuid_to_suffix=False + queue_name_base="st2.test.watch", + queue_name_suffix="foo", + add_random_uuid_to_suffix=False, ) - self.assertEqual(queue_name, 'st2.test.watch.foo') + self.assertEqual(queue_name, "st2.test.watch.foo") diff --git a/st2common/tests/unit/test_rbac_types.py b/st2common/tests/unit/test_rbac_types.py index d9d0a1dae8..03b5350cc9 100644 --- a/st2common/tests/unit/test_rbac_types.py +++ b/st2common/tests/unit/test_rbac_types.py @@ -22,158 +22,274 @@ class RBACPermissionTypeTestCase(TestCase): - def test_get_valid_permission_for_resource_type(self): - valid_action_permissions = PermissionType.get_valid_permissions_for_resource_type( - resource_type=ResourceType.ACTION) + valid_action_permissions = ( + PermissionType.get_valid_permissions_for_resource_type( + resource_type=ResourceType.ACTION + ) + ) for name in valid_action_permissions: - self.assertTrue(name.startswith(ResourceType.ACTION + '_')) + self.assertTrue(name.startswith(ResourceType.ACTION + "_")) valid_rule_permissions = PermissionType.get_valid_permissions_for_resource_type( - resource_type=ResourceType.RULE) + resource_type=ResourceType.RULE + ) for name in valid_rule_permissions: - self.assertTrue(name.startswith(ResourceType.RULE + '_')) + self.assertTrue(name.startswith(ResourceType.RULE + "_")) def test_get_resource_type(self): - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_LIST), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_VIEW), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_CREATE), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_MODIFY), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_DELETE), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_ALL), - SystemType.PACK) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_LIST), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_VIEW), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_ALL), - SystemType.SENSOR_TYPE) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_LIST), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_VIEW), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_CREATE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_MODIFY), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_DELETE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_ALL), - SystemType.ACTION) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_LIST), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_STOP), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_ALL), - SystemType.EXECUTION) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_LIST), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_VIEW), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_CREATE), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_MODIFY), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_DELETE), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ALL), - SystemType.RULE) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST), - SystemType.RULE_ENFORCEMENT) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW), - SystemType.RULE_ENFORCEMENT) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW), - SystemType.KEY_VALUE_PAIR) - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET), - SystemType.KEY_VALUE_PAIR) - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE), - SystemType.KEY_VALUE_PAIR) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL), - SystemType.WEBHOOK) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_LIST), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_VIEW), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_CREATE), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_DELETE), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_ALL), - SystemType.API_KEY) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_LIST), SystemType.PACK + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_VIEW), SystemType.PACK + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_CREATE), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_MODIFY), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_DELETE), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_ALL), SystemType.PACK + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_LIST), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_VIEW), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_ALL), + SystemType.SENSOR_TYPE, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_LIST), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_VIEW), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_CREATE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_MODIFY), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_DELETE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_ALL), + SystemType.ACTION, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_LIST), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_STOP), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_ALL), + SystemType.EXECUTION, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_LIST), SystemType.RULE + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_VIEW), SystemType.RULE + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_CREATE), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_MODIFY), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_DELETE), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ALL), SystemType.RULE + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST), + SystemType.RULE_ENFORCEMENT, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW), + SystemType.RULE_ENFORCEMENT, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW), + SystemType.KEY_VALUE_PAIR, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET), + SystemType.KEY_VALUE_PAIR, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE), + SystemType.KEY_VALUE_PAIR, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL), + SystemType.WEBHOOK, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_LIST), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_VIEW), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_CREATE), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_DELETE), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_ALL), + SystemType.API_KEY, + ) def test_get_permission_type(self): - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='view'), - PermissionType.ACTION_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='all'), - PermissionType.ACTION_ALL) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='execute'), - PermissionType.ACTION_EXECUTE) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE, - permission_name='view'), - PermissionType.RULE_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE, - permission_name='delete'), - PermissionType.RULE_DELETE) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='view'), - PermissionType.SENSOR_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='all'), - PermissionType.SENSOR_ALL) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='modify'), - PermissionType.SENSOR_MODIFY) - self.assertEqual( - PermissionType.get_permission_type(resource_type=ResourceType.RULE_ENFORCEMENT, - permission_name='view'), - PermissionType.RULE_ENFORCEMENT_VIEW) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="view" + ), + PermissionType.ACTION_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="all" + ), + PermissionType.ACTION_ALL, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="execute" + ), + PermissionType.ACTION_EXECUTE, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE, permission_name="view" + ), + PermissionType.RULE_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE, permission_name="delete" + ), + PermissionType.RULE_DELETE, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="view" + ), + PermissionType.SENSOR_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="all" + ), + PermissionType.SENSOR_ALL, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="modify" + ), + PermissionType.SENSOR_MODIFY, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE_ENFORCEMENT, permission_name="view" + ), + PermissionType.RULE_ENFORCEMENT_VIEW, + ) def test_get_permission_name(self): - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_LIST), - 'list') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_CREATE), - 'create') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_DELETE), - 'delete') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_ALL), - 'all') - self.assertEqual(PermissionType.get_permission_name(PermissionType.PACK_ALL), - 'all') - self.assertEqual(PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY), - 'modify') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE), - 'execute') - self.assertEqual(PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST), - 'list') + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_LIST), "list" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_CREATE), "create" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_DELETE), "delete" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_ALL), "all" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.PACK_ALL), "all" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY), "modify" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE), "execute" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST), + "list", + ) diff --git a/st2common/tests/unit/test_reference.py b/st2common/tests/unit/test_reference.py index f39800c2dd..ced486a867 100644 --- a/st2common/tests/unit/test_reference.py +++ b/st2common/tests/unit/test_reference.py @@ -26,35 +26,34 @@ from st2tests import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ReferenceTest(DbTestCase): __model = None __ref = None @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def setUpClass(cls): super(ReferenceTest, cls).setUpClass() - trigger = TriggerDB(pack='dummy_pack_1', name='trigger-1') + trigger = TriggerDB(pack="dummy_pack_1", name="trigger-1") cls.__model = Trigger.add_or_update(trigger) - cls.__ref = {'id': str(cls.__model.id), - 'name': cls.__model.name} + cls.__ref = {"id": str(cls.__model.id), "name": cls.__model.name} @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def tearDownClass(cls): Trigger.delete(cls.__model) super(ReferenceTest, cls).tearDownClass() def test_to_reference(self): ref = reference.get_ref_from_model(self.__model) - self.assertEqual(ref, self.__ref, 'Failed to generated equivalent ref.') + self.assertEqual(ref, self.__ref, "Failed to generated equivalent ref.") def test_to_reference_no_model(self): try: reference.get_ref_from_model(None) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except ValueError: self.assertTrue(True) @@ -63,37 +62,37 @@ def test_to_reference_no_model_id(self): model = copy.copy(self.__model) model.id = None reference.get_ref_from_model(model) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectMalformedError: self.assertTrue(True) def test_to_model_with_id(self): model = reference.get_model_from_ref(Trigger, self.__ref) - self.assertEqual(model, self.__model, 'Failed to return correct model.') + self.assertEqual(model, self.__model, "Failed to return correct model.") def test_to_model_with_name(self): ref = copy.copy(self.__ref) - ref['id'] = None + ref["id"] = None model = reference.get_model_from_ref(Trigger, ref) - self.assertEqual(model, self.__model, 'Failed to return correct model.') + self.assertEqual(model, self.__model, "Failed to return correct model.") def test_to_model_no_name_no_id(self): try: reference.get_model_from_ref(Trigger, {}) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectNotFoundError: self.assertTrue(True) def test_to_model_unknown_id(self): try: - reference.get_model_from_ref(Trigger, {'id': '1'}) - self.assertTrue(False, 'Exception expected.') + reference.get_model_from_ref(Trigger, {"id": "1"}) + self.assertTrue(False, "Exception expected.") except mongoengine.ValidationError: self.assertTrue(True) def test_to_model_unknown_name(self): try: - reference.get_model_from_ref(Trigger, {'name': 'unknown'}) - self.assertTrue(False, 'Exception expected.') + reference.get_model_from_ref(Trigger, {"name": "unknown"}) + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectNotFoundError: self.assertTrue(True) diff --git a/st2common/tests/unit/test_register_internal_trigger.py b/st2common/tests/unit/test_register_internal_trigger.py index dd4959611f..3d33e32548 100644 --- a/st2common/tests/unit/test_register_internal_trigger.py +++ b/st2common/tests/unit/test_register_internal_trigger.py @@ -20,7 +20,6 @@ class TestRegisterInternalTriggers(DbTestCase): - def test_register_internal_trigger_types(self): registered_trigger_types_db = register_internal_trigger_types() for trigger_type_db in registered_trigger_types_db: @@ -31,4 +30,6 @@ def _validate_shadow_trigger(self, trigger_type_db): return trigger_type_ref = trigger_type_db.get_reference().ref triggers = Trigger.query(type=trigger_type_ref) - self.assertTrue(len(triggers) > 0, 'Shadow trigger not created for %s.' % trigger_type_ref) + self.assertTrue( + len(triggers) > 0, "Shadow trigger not created for %s." % trigger_type_ref + ) diff --git a/st2common/tests/unit/test_resource_reference.py b/st2common/tests/unit/test_resource_reference.py index 04dfdc9357..95533022ed 100644 --- a/st2common/tests/unit/test_resource_reference.py +++ b/st2common/tests/unit/test_resource_reference.py @@ -22,45 +22,64 @@ class ResourceReferenceTestCase(unittest2.TestCase): def test_resource_reference_success(self): - value = 'pack1.name1' + value = "pack1.name1" ref = ResourceReference.from_string_reference(ref=value) - self.assertEqual(ref.pack, 'pack1') - self.assertEqual(ref.name, 'name1') + self.assertEqual(ref.pack, "pack1") + self.assertEqual(ref.name, "name1") self.assertEqual(ref.ref, value) - ref = ResourceReference(pack='pack1', name='name1') - self.assertEqual(ref.ref, 'pack1.name1') + ref = ResourceReference(pack="pack1", name="name1") + self.assertEqual(ref.ref, "pack1.name1") - ref = ResourceReference(pack='pack1', name='name1.name2') - self.assertEqual(ref.ref, 'pack1.name1.name2') + ref = ResourceReference(pack="pack1", name="name1.name2") + self.assertEqual(ref.ref, "pack1.name1.name2") def test_resource_reference_failure(self): - self.assertRaises(InvalidResourceReferenceError, - ResourceReference.from_string_reference, - ref='blah') + self.assertRaises( + InvalidResourceReferenceError, + ResourceReference.from_string_reference, + ref="blah", + ) - self.assertRaises(InvalidResourceReferenceError, - ResourceReference.from_string_reference, - ref=None) + self.assertRaises( + InvalidResourceReferenceError, + ResourceReference.from_string_reference, + ref=None, + ) def test_to_string_reference(self): - ref = ResourceReference.to_string_reference(pack='mapack', name='moname') - self.assertEqual(ref, 'mapack.moname') + ref = ResourceReference.to_string_reference(pack="mapack", name="moname") + self.assertEqual(ref, "mapack.moname") expected_msg = r'Pack name should not contain "\."' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack='pack.invalid', name='bar') + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack="pack.invalid", + name="bar", + ) - expected_msg = 'Both pack and name needed for building' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack='pack', name=None) + expected_msg = "Both pack and name needed for building" + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack="pack", + name=None, + ) - expected_msg = 'Both pack and name needed for building' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack=None, name='name') + expected_msg = "Both pack and name needed for building" + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack=None, + name="name", + ) def test_is_resource_reference(self): - self.assertTrue(ResourceReference.is_resource_reference('foo.bar')) - self.assertTrue(ResourceReference.is_resource_reference('foo.bar.ponies')) - self.assertFalse(ResourceReference.is_resource_reference('foo')) + self.assertTrue(ResourceReference.is_resource_reference("foo.bar")) + self.assertTrue(ResourceReference.is_resource_reference("foo.bar.ponies")) + self.assertFalse(ResourceReference.is_resource_reference("foo")) diff --git a/st2common/tests/unit/test_resource_registrar.py b/st2common/tests/unit/test_resource_registrar.py index 9850785f21..2a1c61ad6a 100644 --- a/st2common/tests/unit/test_resource_registrar.py +++ b/st2common/tests/unit/test_resource_registrar.py @@ -30,23 +30,21 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'ResourceRegistrarTestCase' -] - -PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_1') -PACK_PATH_6 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_6') -PACK_PATH_7 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_7') -PACK_PATH_8 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_8') -PACK_PATH_9 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_9') -PACK_PATH_10 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_10') -PACK_PATH_12 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_12') -PACK_PATH_13 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_13') -PACK_PATH_14 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_14') -PACK_PATH_17 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_17') -PACK_PATH_18 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_18') -PACK_PATH_20 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_20') -PACK_PATH_21 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_21') +__all__ = ["ResourceRegistrarTestCase"] + +PACK_PATH_1 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_1") +PACK_PATH_6 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_6") +PACK_PATH_7 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_7") +PACK_PATH_8 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_8") +PACK_PATH_9 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_9") +PACK_PATH_10 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_10") +PACK_PATH_12 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_12") +PACK_PATH_13 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_13") +PACK_PATH_14 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_14") +PACK_PATH_17 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_17") +PACK_PATH_18 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_18") +PACK_PATH_20 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_20") +PACK_PATH_21 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_21") class ResourceRegistrarTestCase(CleanDbTestCase): @@ -60,7 +58,7 @@ def test_register_packs(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_PATH_1} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_PATH_1} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) @@ -74,20 +72,20 @@ def test_register_packs(self): pack_db = pack_dbs[0] config_schema_db = config_schema_dbs[0] - self.assertEqual(pack_db.name, 'dummy_pack_1') + self.assertEqual(pack_db.name, "dummy_pack_1") self.assertEqual(len(pack_db.contributors), 2) - self.assertEqual(pack_db.contributors[0], 'John Doe1 ') - self.assertEqual(pack_db.contributors[1], 'John Doe2 ') - self.assertIn('api_key', config_schema_db.attributes) - self.assertIn('api_secret', config_schema_db.attributes) + self.assertEqual(pack_db.contributors[0], "John Doe1 ") + self.assertEqual(pack_db.contributors[1], "John Doe2 ") + self.assertIn("api_key", config_schema_db.attributes) + self.assertIn("api_secret", config_schema_db.attributes) # Verify pack_db.files is correct and doesn't contain excluded files (*.pyc, .git/*, etc.) # Note: We can't test that .git/* files are excluded since git doesn't allow you to add # .git directory to existing repo index :/ excluded_files = [ - '__init__.pyc', - 'actions/dummy1.pyc', - 'actions/dummy2.pyc', + "__init__.pyc", + "actions/dummy1.pyc", + "actions/dummy2.pyc", ] for excluded_file in excluded_files: @@ -100,14 +98,14 @@ def test_register_pack_arbitrary_properties_are_allowed(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() registrar._pack_loader.get_packs.return_value = { - 'dummy_pack_20': PACK_PATH_20, + "dummy_pack_20": PACK_PATH_20, } packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Ref is provided - pack_db = Pack.get_by_name('dummy_pack_20') - self.assertEqual(pack_db.ref, 'dummy_pack_20_ref') + pack_db = Pack.get_by_name("dummy_pack_20") + self.assertEqual(pack_db.ref, "dummy_pack_20_ref") self.assertEqual(len(pack_db.contributors), 0) def test_register_pack_pack_ref(self): @@ -119,53 +117,74 @@ def test_register_pack_pack_ref(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() registrar._pack_loader.get_packs.return_value = { - 'dummy_pack_1': PACK_PATH_1, - 'dummy_pack_6': PACK_PATH_6 + "dummy_pack_1": PACK_PATH_1, + "dummy_pack_6": PACK_PATH_6, } packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Ref is provided - pack_db = Pack.get_by_name('dummy_pack_6') - self.assertEqual(pack_db.ref, 'dummy_pack_6_ref') + pack_db = Pack.get_by_name("dummy_pack_6") + self.assertEqual(pack_db.ref, "dummy_pack_6_ref") self.assertEqual(len(pack_db.contributors), 0) # Ref is not provided, directory name should be used - pack_db = Pack.get_by_name('dummy_pack_1') - self.assertEqual(pack_db.ref, 'dummy_pack_1') + pack_db = Pack.get_by_name("dummy_pack_1") + self.assertEqual(pack_db.ref, "dummy_pack_1") # "ref" is not provided, but "name" is registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_7) - pack_db = Pack.get_by_name('dummy_pack_7_name') - self.assertEqual(pack_db.ref, 'dummy_pack_7_name') + pack_db = Pack.get_by_name("dummy_pack_7_name") + self.assertEqual(pack_db.ref, "dummy_pack_7_name") # "ref" is not provided and "name" contains invalid characters - expected_msg = 'contains invalid characters' - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_8) + expected_msg = "contains invalid characters" + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_8, + ) def test_register_pack_invalid_ref_name_friendly_error_message(self): registrar = ResourceRegistrar(use_pack_cache=False) # Invalid ref - expected_msg = (r'Pack ref / name can only contain valid word characters .*?,' - ' dashes are not allowed.') - self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_13) + expected_msg = ( + r"Pack ref / name can only contain valid word characters .*?," + " dashes are not allowed." + ) + self.assertRaisesRegexp( + ValidationError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_13, + ) try: registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_13) except ValidationError as e: - self.assertIn("'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e)) + self.assertIn( + "'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e) + ) else: - self.fail('Exception not thrown') + self.fail("Exception not thrown") # Pack ref not provided and name doesn't contain valid characters - expected_msg = (r'Pack name "dummy pack 14" contains invalid characters and "ref" ' - 'attribute is not available. You either need to add') - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_14) + expected_msg = ( + r'Pack name "dummy pack 14" contains invalid characters and "ref" ' + "attribute is not available. You either need to add" + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_14, + ) def test_register_pack_pack_stackstorm_version_and_future_parameters(self): # Verify DB is empty @@ -174,53 +193,74 @@ def test_register_pack_pack_stackstorm_version_and_future_parameters(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_9': PACK_PATH_9} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_9": PACK_PATH_9} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Dependencies, stackstorm_version and future values - pack_db = Pack.get_by_name('dummy_pack_9_deps') - self.assertEqual(pack_db.dependencies, ['core=0.2.0']) - self.assertEqual(pack_db.stackstorm_version, '>=1.6.0, <2.2.0') - self.assertEqual(pack_db.system, {'centos': {'foo': '>= 1.0'}}) - self.assertEqual(pack_db.python_versions, ['2', '3']) + pack_db = Pack.get_by_name("dummy_pack_9_deps") + self.assertEqual(pack_db.dependencies, ["core=0.2.0"]) + self.assertEqual(pack_db.stackstorm_version, ">=1.6.0, <2.2.0") + self.assertEqual(pack_db.system, {"centos": {"foo": ">= 1.0"}}) + self.assertEqual(pack_db.python_versions, ["2", "3"]) # Note: We only store parameters which are defined in the schema, all other custom user # defined attributes are ignored - self.assertTrue(not hasattr(pack_db, 'future')) - self.assertTrue(not hasattr(pack_db, 'this')) + self.assertTrue(not hasattr(pack_db, "future")) + self.assertTrue(not hasattr(pack_db, "this")) # Wrong characters in the required st2 version expected_msg = "'wrongstackstormversion' does not match" - self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_10) + self.assertRaisesRegexp( + ValidationError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_10, + ) def test_register_pack_empty_and_invalid_config_schema(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_17': PACK_PATH_17} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_17": PACK_PATH_17} packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.' - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + expected_msg = ( + 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) def test_register_pack_invalid_config_schema_invalid_attribute(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_18': PACK_PATH_18} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_18": PACK_PATH_18} packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = r'Additional properties are not allowed \(\'invalid\' was unexpected\)' - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + expected_msg = ( + r"Additional properties are not allowed \(\'invalid\' was unexpected\)" + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) def test_register_pack_invalid_python_versions_attribute(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_21': PACK_PATH_21} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_21": PACK_PATH_21} packs_base_paths = content_utils.get_packs_base_paths() expected_msg = r"'4' is not one of \['2', '3'\]" - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) diff --git a/st2common/tests/unit/test_runners_base.py b/st2common/tests/unit/test_runners_base.py index 34ede41adf..7490b40cd6 100644 --- a/st2common/tests/unit/test_runners_base.py +++ b/st2common/tests/unit/test_runners_base.py @@ -23,11 +23,12 @@ class RunnersLoaderUtilsTestCase(DbTestCase): def test_get_runner_success(self): - runner = get_runner('local-shell-cmd') + runner = get_runner("local-shell-cmd") self.assertTrue(runner) - self.assertEqual(runner.__class__.__name__, 'LocalShellCommandRunner') + self.assertEqual(runner.__class__.__name__, "LocalShellCommandRunner") def test_get_runner_failure_not_found(self): - expected_msg = 'Failed to find runner invalid-name-not-found.*' - self.assertRaisesRegexp(ActionRunnerCreateError, expected_msg, - get_runner, 'invalid-name-not-found') + expected_msg = "Failed to find runner invalid-name-not-found.*" + self.assertRaisesRegexp( + ActionRunnerCreateError, expected_msg, get_runner, "invalid-name-not-found" + ) diff --git a/st2common/tests/unit/test_runners_utils.py b/st2common/tests/unit/test_runners_utils.py index dc98848223..bc6acfcf7e 100644 --- a/st2common/tests/unit/test_runners_utils.py +++ b/st2common/tests/unit/test_runners_utils.py @@ -24,16 +24,17 @@ from st2tests import config as tests_config + tests_config.parse_args() -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'liveactions': ['liveaction1.yaml'], - 'actions': ['local.yaml'], - 'executions': ['execution1.yaml'], - 'runners': ['run-local.yaml'] + "liveactions": ["liveaction1.yaml"], + "actions": ["local.yaml"], + "executions": ["execution1.yaml"], + "runners": ["run-local.yaml"], } @@ -48,15 +49,16 @@ def setUp(self): loader = fixturesloader.FixturesLoader() self.models = loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES ) - self.liveaction_db = self.models['liveactions']['liveaction1.yaml'] + self.liveaction_db = self.models["liveactions"]["liveaction1.yaml"] exe_svc.create_execution_object(self.liveaction_db) self.action_db = action_db_utils.get_action_by_ref(self.liveaction_db.action) - @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None)) + @mock.patch.object( + action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None) + ) def test_invoke_post_run_action_provided(self): utils.invoke_post_run(self.liveaction_db, action_db=self.action_db) action_db_utils.get_action_by_ref.assert_not_called() @@ -64,8 +66,12 @@ def test_invoke_post_run_action_provided(self): def test_invoke_post_run_action_exists(self): utils.invoke_post_run(self.liveaction_db) - @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None)) - @mock.patch.object(action_db_utils, 'get_runnertype_by_name', mock.MagicMock(return_value=None)) + @mock.patch.object( + action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None) + ) + @mock.patch.object( + action_db_utils, "get_runnertype_by_name", mock.MagicMock(return_value=None) + ) def test_invoke_post_run_action_does_not_exist(self): utils.invoke_post_run(self.liveaction_db) action_db_utils.get_action_by_ref.assert_called_once() diff --git a/st2common/tests/unit/test_sensor_type_utils.py b/st2common/tests/unit/test_sensor_type_utils.py index 08269ebcf2..657054c453 100644 --- a/st2common/tests/unit/test_sensor_type_utils.py +++ b/st2common/tests/unit/test_sensor_type_utils.py @@ -22,59 +22,67 @@ class SensorTypeUtilsTestCase(unittest2.TestCase): - def test_to_sensor_db_model_no_trigger_types(self): sensor_meta = { - 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py', - 'class_name': 'JIRASensor', - 'pack': 'jira' + "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py", + "class_name": "JIRASensor", + "pack": "jira", } sensor_api = SensorTypeAPI(**sensor_meta) sensor_model = SensorTypeAPI.to_model(sensor_api) - self.assertEqual(sensor_model.name, sensor_meta['class_name']) - self.assertEqual(sensor_model.pack, sensor_meta['pack']) - self.assertEqual(sensor_model.artifact_uri, sensor_meta['artifact_uri']) + self.assertEqual(sensor_model.name, sensor_meta["class_name"]) + self.assertEqual(sensor_model.pack, sensor_meta["pack"]) + self.assertEqual(sensor_model.artifact_uri, sensor_meta["artifact_uri"]) self.assertListEqual(sensor_model.trigger_types, []) - @mock.patch.object(sensor_type_utils, 'create_trigger_types', mock.MagicMock( - return_value=['mock.trigger_ref'])) + @mock.patch.object( + sensor_type_utils, + "create_trigger_types", + mock.MagicMock(return_value=["mock.trigger_ref"]), + ) def test_to_sensor_db_model_with_trigger_types(self): sensor_meta = { - 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py', - 'class_name': 'JIRASensor', - 'pack': 'jira', - 'trigger_types': [{'pack': 'jira', 'name': 'issue_created', 'parameters': {}}] + "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py", + "class_name": "JIRASensor", + "pack": "jira", + "trigger_types": [ + {"pack": "jira", "name": "issue_created", "parameters": {}} + ], } sensor_api = SensorTypeAPI(**sensor_meta) sensor_model = SensorTypeAPI.to_model(sensor_api) - self.assertListEqual(sensor_model.trigger_types, ['mock.trigger_ref']) + self.assertListEqual(sensor_model.trigger_types, ["mock.trigger_ref"]) def test_get_sensor_entry_point(self): # System packs - file_path = 'file:///data/st/st2reactor/st2reactor/' + \ - 'contrib/sensors/st2_generic_webhook_sensor.py' - class_name = 'St2GenericWebhooksSensor' + file_path = ( + "file:///data/st/st2reactor/st2reactor/" + + "contrib/sensors/st2_generic_webhook_sensor.py" + ) + class_name = "St2GenericWebhooksSensor" - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'core'} + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "core"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) self.assertEqual(entry_point, class_name) # Non system packs - file_path = 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py' - class_name = 'JIRASensor' - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'jira'} + file_path = "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py" + class_name = "JIRASensor" + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "jira"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) - self.assertEqual(entry_point, 'sensors.jira_sensor.JIRASensor') + self.assertEqual(entry_point, "sensors.jira_sensor.JIRASensor") - file_path = 'file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py' - class_name = 'DockerSensor' - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'docker'} + file_path = ( + "file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py" + ) + class_name = "DockerSensor" + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "docker"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) - self.assertEqual(entry_point, 'sensors.docker_container_sensor.DockerSensor') + self.assertEqual(entry_point, "sensors.docker_container_sensor.DockerSensor") diff --git a/st2common/tests/unit/test_sensor_watcher.py b/st2common/tests/unit/test_sensor_watcher.py index 65f61965df..2379f81562 100644 --- a/st2common/tests/unit/test_sensor_watcher.py +++ b/st2common/tests/unit/test_sensor_watcher.py @@ -22,39 +22,44 @@ from st2common.models.db.sensor import SensorTypeDB from st2common.transport.publishers import PoolPublisher -MOCK_SENSOR_DB = SensorTypeDB(name='foo', pack='test') +MOCK_SENSOR_DB = SensorTypeDB(name="foo", pack="test") class SensorWatcherTests(unittest2.TestCase): - - @mock.patch.object(Message, 'ack', mock.MagicMock()) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(Message, "ack", mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def test_assert_handlers_called(self): handler_vars = { - 'create_handler_called': False, - 'update_handler_called': False, - 'delete_handler_called': False + "create_handler_called": False, + "update_handler_called": False, + "delete_handler_called": False, } def create_handler(sensor_db): - handler_vars['create_handler_called'] = True + handler_vars["create_handler_called"] = True def update_handler(sensor_db): - handler_vars['update_handler_called'] = True + handler_vars["update_handler_called"] = True def delete_handler(sensor_db): - handler_vars['delete_handler_called'] = True + handler_vars["delete_handler_called"] = True sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler) - message = Message(None, delivery_info={'routing_key': 'create'}) + message = Message(None, delivery_info={"routing_key": "create"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['create_handler_called'], 'create handler should be called.') + self.assertTrue( + handler_vars["create_handler_called"], "create handler should be called." + ) - message = Message(None, delivery_info={'routing_key': 'update'}) + message = Message(None, delivery_info={"routing_key": "update"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['update_handler_called'], 'update handler should be called.') + self.assertTrue( + handler_vars["update_handler_called"], "update handler should be called." + ) - message = Message(None, delivery_info={'routing_key': 'delete'}) + message = Message(None, delivery_info={"routing_key": "delete"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['delete_handler_called'], 'delete handler should be called.') + self.assertTrue( + handler_vars["delete_handler_called"], "delete handler should be called." + ) diff --git a/st2common/tests/unit/test_service_setup.py b/st2common/tests/unit/test_service_setup.py index 4000f6ce81..b1358f295d 100644 --- a/st2common/tests/unit/test_service_setup.py +++ b/st2common/tests/unit/test_service_setup.py @@ -31,9 +31,7 @@ from st2tests.base import CleanFilesTestCase from st2tests import config -__all__ = [ - 'ServiceSetupTestCase' -] +__all__ = ["ServiceSetupTestCase"] MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL = """ [loggers] @@ -61,11 +59,11 @@ datefmt= """.strip() -MOCK_DEFAULT_CONFIG_FILE_PATH = '/etc/st2/st2.conf-test-patched' +MOCK_DEFAULT_CONFIG_FILE_PATH = "/etc/st2/st2.conf-test-patched" def mock_get_logging_config_path(): - return '' + return "" class ServiceSetupTestCase(CleanFilesTestCase): @@ -78,19 +76,24 @@ def test_no_logging_config_found(self): else: expected_msg = "No section: .*" - self.assertRaisesRegexp(Exception, expected_msg, - service_setup.setup, service='api', - config=config, - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + self.assertRaisesRegexp( + Exception, + expected_msg, + service_setup.setup, + service="api", + config=config, + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) def test_invalid_log_level_friendly_error_message(self): _, mock_logging_config_path = tempfile.mkstemp() self.to_delete_files.append(mock_logging_config_path) - with open(mock_logging_config_path, 'w') as fp: + with open(mock_logging_config_path, "w") as fp: fp.write(MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL) def mock_get_logging_config_path(): @@ -99,21 +102,28 @@ def mock_get_logging_config_path(): config.get_logging_config_path = mock_get_logging_config_path if six.PY3: - expected_msg = 'ValueError: Unknown level: \'invalid_log_level\'' + expected_msg = "ValueError: Unknown level: 'invalid_log_level'" exc_type = ValueError else: - expected_msg = 'Invalid log level selected. Log level names need to be all uppercase' + expected_msg = ( + "Invalid log level selected. Log level names need to be all uppercase" + ) exc_type = KeyError - self.assertRaisesRegexp(exc_type, expected_msg, - service_setup.setup, service='api', - config=config, - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) - - @mock.patch('kombu.Queue.declare') + self.assertRaisesRegexp( + exc_type, + expected_msg, + service_setup.setup, + service="api", + config=config, + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) + + @mock.patch("kombu.Queue.declare") def test_register_exchanges_predeclare_queues(self, mock_declare): # Verify that queues are correctly pre-declared self.assertEqual(mock_declare.call_count, 0) @@ -121,34 +131,50 @@ def test_register_exchanges_predeclare_queues(self, mock_declare): register_exchanges() self.assertEqual(mock_declare.call_count, len(QUEUES)) - @mock.patch('st2common.constants.system.DEFAULT_CONFIG_FILE_PATH', - MOCK_DEFAULT_CONFIG_FILE_PATH) - @mock.patch('st2common.config.DEFAULT_CONFIG_FILE_PATH', MOCK_DEFAULT_CONFIG_FILE_PATH) + @mock.patch( + "st2common.constants.system.DEFAULT_CONFIG_FILE_PATH", + MOCK_DEFAULT_CONFIG_FILE_PATH, + ) + @mock.patch( + "st2common.config.DEFAULT_CONFIG_FILE_PATH", MOCK_DEFAULT_CONFIG_FILE_PATH + ) def test_service_setup_default_st2_conf_config_is_used(self): st2common_config.get_logging_config_path = mock_get_logging_config_path cfg.CONF.reset() # 1. DEFAULT_CONFIG_FILE_PATH config path should be used by default (/etc/st2/st2.conf) - expected_msg = 'Failed to find some config files: %s' % (MOCK_DEFAULT_CONFIG_FILE_PATH) - self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup, - service='api', - config=st2common_config, - config_args=['--debug'], - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + expected_msg = "Failed to find some config files: %s" % ( + MOCK_DEFAULT_CONFIG_FILE_PATH + ) + self.assertRaisesRegexp( + ConfigFilesNotFoundError, + expected_msg, + service_setup.setup, + service="api", + config=st2common_config, + config_args=["--debug"], + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) cfg.CONF.reset() # 2. --config-file should still override default config file path option - config_file_path = '/etc/st2/config.override.test' - expected_msg = 'Failed to find some config files: %s' % (config_file_path) - self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup, - service='api', - config=st2common_config, - config_args=['--config-file', config_file_path], - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + config_file_path = "/etc/st2/config.override.test" + expected_msg = "Failed to find some config files: %s" % (config_file_path) + self.assertRaisesRegexp( + ConfigFilesNotFoundError, + expected_msg, + service_setup.setup, + service="api", + config=st2common_config, + config_args=["--config-file", config_file_path], + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) diff --git a/st2common/tests/unit/test_shell_action_system_model.py b/st2common/tests/unit/test_shell_action_system_model.py index 6fdc7d1716..76609ab953 100644 --- a/st2common/tests/unit/test_shell_action_system_model.py +++ b/st2common/tests/unit/test_shell_action_system_model.py @@ -32,90 +32,87 @@ from local_runner.local_shell_script_runner import LocalShellScriptRunner CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures')) +FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures")) LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0] -__all__ = [ - 'ShellCommandActionTestCase', - 'ShellScriptActionTestCase' -] +__all__ = ["ShellCommandActionTestCase", "ShellScriptActionTestCase"] class ShellCommandActionTestCase(unittest2.TestCase): def setUp(self): self._base_kwargs = { - 'name': 'test action', - 'action_exec_id': '1', - 'command': 'ls -la', - 'env_vars': {}, - 'timeout': None + "name": "test action", + "action_exec_id": "1", + "command": "ls -la", + "env_vars": {}, + "timeout": None, } def test_user_argument(self): # User is the same as logged user, no sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'ls -la') + self.assertEqual(command, "ls -la") # User is different, sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c \'ls -la\'') + self.assertEqual(command, "sudo -E -H -u mauser -- bash -c 'ls -la'") # sudo with password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -H -u mauser -- bash -c \'ls -la\'' + expected_command = "sudo -S -E -H -u mauser -- bash -c 'ls -la'" self.assertEqual(command, expected_command) # sudo is used, it doesn't matter what user is specified since the # command should run as root kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' + kwargs["sudo"] = True + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -- bash -c \'ls -la\'') + self.assertEqual(command, "sudo -E -- bash -c 'ls -la'") # sudo with passwd kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = True + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -- bash -c \'ls -la\'' + expected_command = "sudo -S -E -- bash -c 'ls -la'" self.assertEqual(command, expected_command) class ShellScriptActionTestCase(unittest2.TestCase): def setUp(self): self._base_kwargs = { - 'name': 'test action', - 'action_exec_id': '1', - 'script_local_path_abs': '/tmp/foo.sh', - 'named_args': {}, - 'positional_args': [], - 'env_vars': {}, - 'timeout': None + "name": "test action", + "action_exec_id": "1", + "script_local_path_abs": "/tmp/foo.sh", + "named_args": {}, + "positional_args": [], + "env_vars": {}, + "timeout": None, } def _get_fixture(self, name): - path = os.path.join(FIXTURES_DIR, 'local_runner', name) + path = os.path.join(FIXTURES_DIR, "local_runner", name) - with open(path, 'r') as fp: + with open(path, "r") as fp: content = fp.read().strip() return content @@ -123,371 +120,374 @@ def _get_fixture(self, name): def test_user_argument(self): # User is the same as logged user, no sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh') + self.assertEqual(command, "/tmp/foo.sh") # User is different, sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["user"] = "mauser" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c /tmp/foo.sh') + self.assertEqual(command, "sudo -E -H -u mauser -- bash -c /tmp/foo.sh") # sudo with password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh' + expected_command = "sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) # complex sudo password which needs escaping kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = '$udo p\'as"sss' + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "$udo p'as\"sss" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = ('sudo -S -E -H ' - '-u mauser -- bash -c /tmp/foo.sh') + expected_command = "sudo -S -E -H " "-u mauser -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) command = action.get_sanitized_full_command_string() - expected_command = ('echo -e \'%s\n\' | sudo -S -E -H ' - '-u mauser -- bash -c /tmp/foo.sh' % (MASKED_ATTRIBUTE_VALUE)) + expected_command = ( + "echo -e '%s\n' | sudo -S -E -H " + "-u mauser -- bash -c /tmp/foo.sh" % (MASKED_ATTRIBUTE_VALUE) + ) self.assertEqual(command, expected_command) # sudo is used, it doesn't matter what user is specified since the # command should run as root kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = True + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -- bash -c /tmp/foo.sh' + expected_command = "sudo -S -E -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) def test_command_construction_with_parameters(self): # same user, named args, no positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh key1=value1 key2=value2') + self.assertEqual(command, "/tmp/foo.sh key1=value1 key2=value2") # same user, named args, no positional args, sudo password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = True + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -S -E -- bash -c ' - '\'/tmp/foo.sh key1=value1 key2=value2\'') + expected = "sudo -S -E -- bash -c " "'/tmp/foo.sh key1=value1 key2=value2'" self.assertEqual(command, expected) # different user, named args, no positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = 'sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2\'' + expected = ( + "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2'" + ) self.assertEqual(command, expected) # different user, named args, no positional args, sudo password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -S -E -H -u mauser -- bash -c ' - '\'/tmp/foo.sh key1=value1 key2=value2\'') + expected = ( + "sudo -S -E -H -u mauser -- bash -c " + "'/tmp/foo.sh key1=value1 key2=value2'" + ) self.assertEqual(command, expected) # same user, positional args, no named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {} - kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia', 'foo\nbar'] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {} + kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia", "foo\nbar"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh ein zwei drei \'mamma mia\' \'foo\nbar\'') + self.assertEqual(command, "/tmp/foo.sh ein zwei drei 'mamma mia' 'foo\nbar'") # different user, named args, positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = {} - kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia'] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = {} + kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - ex = ('sudo -E -H -u mauser -- ' - 'bash -c \'/tmp/foo.sh ein zwei drei \'"\'"\'mamma mia\'"\'"\'\'') + ex = ( + "sudo -E -H -u mauser -- " + "bash -c '/tmp/foo.sh ein zwei drei '\"'\"'mamma mia'\"'\"''" + ) self.assertEqual(command, ex) # same user, positional and named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2'), - ('key3', 'value 3') - ]) - - kwargs['positional_args'] = ['ein', 'zwei', 'drei'] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")] + ) + + kwargs["positional_args"] = ["ein", "zwei", "drei"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - exp = '/tmp/foo.sh key1=value1 key2=value2 key3=\'value 3\' ein zwei drei' + exp = "/tmp/foo.sh key1=value1 key2=value2 key3='value 3' ein zwei drei" self.assertEqual(command, exp) # different user, positional and named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2'), - ('key3', 'value 3') - ]) - kwargs['positional_args'] = ['ein', 'zwei', 'drei'] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict( + [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")] + ) + kwargs["positional_args"] = ["ein", "zwei", "drei"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2 ' - 'key3=\'"\'"\'value 3\'"\'"\' ein zwei drei\'') + expected = ( + "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2 " + "key3='\"'\"'value 3'\"'\"' ein zwei drei'" + ) self.assertEqual(command, expected) def test_named_parameter_escaping(self): # no sudo kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value foo bar'), - ('key2', 'value "bar" foo'), - ('key3', 'date ; whoami'), - ('key4', '"date ; whoami"'), - ]) + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [ + ("key1", "value foo bar"), + ("key2", 'value "bar" foo'), + ("key3", "date ; whoami"), + ("key4", '"date ; whoami"'), + ] + ) action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = self._get_fixture('escaping_test_command_1.txt') + expected = self._get_fixture("escaping_test_command_1.txt") self.assertEqual(command, expected) # sudo kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value foo bar'), - ('key2', 'value "bar" foo'), - ('key3', 'date ; whoami'), - ('key4', '"date ; whoami"'), - ]) + kwargs["sudo"] = True + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [ + ("key1", "value foo bar"), + ("key2", 'value "bar" foo'), + ("key3", "date ; whoami"), + ("key4", '"date ; whoami"'), + ] + ) action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = self._get_fixture('escaping_test_command_2.txt') + expected = self._get_fixture("escaping_test_command_2.txt") self.assertEqual(command, expected) def test_various_ascii_parameters(self): kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {'foo1': 'bar1', 'foo2': 'bar2'} - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {"foo1": "bar1", "foo2": "bar2"} + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, u"/tmp/foo.sh foo1=bar1 foo2=bar2") + self.assertEqual(command, "/tmp/foo.sh foo1=bar1 foo2=bar2") def test_unicode_parameter_specifing(self): kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {u'foo': u'bar'} - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {"foo": "bar"} + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, u"/tmp/foo.sh 'foo'='bar'") + self.assertEqual(command, "/tmp/foo.sh 'foo'='bar'") def test_command_construction_correct_default_parameter_values_are_used(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, - }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} - action_db = ActionDB(pack='dummy', name='action') + action_db = ActionDB(pack="dummy", name="action") - runner = LocalShellScriptRunner('id') + runner = LocalShellScriptRunner("id") runner.runner_parameters = {} runner.action = action_db # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo' + expected = "/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo" self.assertEqual(command_string, expected) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob' + expected = "/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob" self.assertEqual(command_string, expected) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc' + expected = "/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc" self.assertEqual(command_string, expected) diff --git a/st2common/tests/unit/test_state_publisher.py b/st2common/tests/unit/test_state_publisher.py index 1fa87b8487..99dbabda7f 100644 --- a/st2common/tests/unit/test_state_publisher.py +++ b/st2common/tests/unit/test_state_publisher.py @@ -27,7 +27,7 @@ from st2tests import DbTestCase -FAKE_STATE_MGMT_XCHG = kombu.Exchange('st2.fake.state', type='topic') +FAKE_STATE_MGMT_XCHG = kombu.Exchange("st2.fake.state", type="topic") class FakeModelPublisher(publishers.StatePublisherMixin): @@ -57,7 +57,7 @@ def _get_publisher(cls): def publish_state(cls, model_object): publisher = cls._get_publisher() if publisher: - publisher.publish_state(model_object, getattr(model_object, 'state', None)) + publisher.publish_state(model_object, getattr(model_object, "state", None)) @classmethod def _get_by_object(cls, object): @@ -65,7 +65,6 @@ def _get_by_object(cls, object): class StatePublisherTest(DbTestCase): - @classmethod def setUpClass(cls): super(StatePublisherTest, cls).setUpClass() @@ -75,13 +74,13 @@ def tearDown(self): FakeModelDB.drop_collection() super(StatePublisherTest, self).tearDown() - @mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) def test_publish(self): - instance = FakeModelDB(state='faked') + instance = FakeModelDB(state="faked") self.access.publish_state(instance) - publishers.PoolPublisher.publish.assert_called_with(instance, - FAKE_STATE_MGMT_XCHG, - instance.state) + publishers.PoolPublisher.publish.assert_called_with( + instance, FAKE_STATE_MGMT_XCHG, instance.state + ) def test_publish_unset(self): instance = FakeModelDB() @@ -92,5 +91,5 @@ def test_publish_none(self): self.assertRaises(Exception, self.access.publish_state, instance) def test_publish_empty_str(self): - instance = FakeModelDB(state='') + instance = FakeModelDB(state="") self.assertRaises(Exception, self.access.publish_state, instance) diff --git a/st2common/tests/unit/test_stream_generator.py b/st2common/tests/unit/test_stream_generator.py index 9c44db4657..a184220b80 100644 --- a/st2common/tests/unit/test_stream_generator.py +++ b/st2common/tests/unit/test_stream_generator.py @@ -20,7 +20,6 @@ class MockBody(object): - def __init__(self, id): self.id = id self.status = "succeeded" @@ -32,8 +31,7 @@ def __init__(self, id): EVENTS = [(INCLUDE, MockBody("notend")), (END_EVENT, MockBody(END_ID))] -class MockQueue(): - +class MockQueue: def __init__(self): self.items = EVENTS @@ -47,7 +45,6 @@ def put(self, event): class MockListener(listener.BaseListener): - def __init__(self, *args, **kwargs): super(MockListener, self).__init__(*args, **kwargs) @@ -56,19 +53,19 @@ def get_consumers(self, consumer, channel): class TestStream(unittest2.TestCase): - - @mock.patch('st2common.stream.listener.BaseListener._get_action_ref_for_body') - @mock.patch('eventlet.Queue') - def test_generator(self, mock_queue, - get_action_ref_for_body): + @mock.patch("st2common.stream.listener.BaseListener._get_action_ref_for_body") + @mock.patch("eventlet.Queue") + def test_generator(self, mock_queue, get_action_ref_for_body): get_action_ref_for_body.return_value = None mock_queue.return_value = MockQueue() mock_consumer = MockListener(connection=None) mock_consumer._stopped = False - app_iter = mock_consumer.generator(events=INCLUDE, + app_iter = mock_consumer.generator( + events=INCLUDE, end_event=END_EVENT, end_statuses=["succeeded"], - end_execution_id=END_ID) - events = EVENTS.append('') + end_execution_id=END_ID, + ) + events = EVENTS.append("") for index, val in enumerate(app_iter): self.assertEquals(val, events[index]) diff --git a/st2common/tests/unit/test_system_info.py b/st2common/tests/unit/test_system_info.py index e7ddb20bef..c840a7aa8b 100644 --- a/st2common/tests/unit/test_system_info.py +++ b/st2common/tests/unit/test_system_info.py @@ -23,8 +23,7 @@ class TestLogger(unittest.TestCase): - def test_process_info(self): process_info = system_info.get_process_info() - self.assertEqual(process_info['hostname'], socket.gethostname()) - self.assertEqual(process_info['pid'], os.getpid()) + self.assertEqual(process_info["hostname"], socket.gethostname()) + self.assertEqual(process_info["pid"], os.getpid()) diff --git a/st2common/tests/unit/test_tags.py b/st2common/tests/unit/test_tags.py index 3ffc59b50a..6230cedea6 100644 --- a/st2common/tests/unit/test_tags.py +++ b/st2common/tests/unit/test_tags.py @@ -28,53 +28,69 @@ class TaggedModel(stormbase.StormFoundationDB, stormbase.TagsMixin): class TestTags(DbTestCase): - def test_simple_count(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='tag1', value='v1'), - stormbase.TagField(name='tag2', value='v2')] + instance.tags = [ + stormbase.TagField(name="tag1", value="v1"), + stormbase.TagField(name="tag2", value="v2"), + ] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) def test_simple_value(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='tag1', value='v1')] + instance.tags = [stormbase.TagField(name="tag1", value="v1")] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) saved_tag = saved.tags[0] retrieved_tag = retrieved.tags[0] - self.assertEqual(saved_tag.name, retrieved_tag.name, 'Failed to retrieve tag.') - self.assertEqual(saved_tag.value, retrieved_tag.value, 'Failed to retrieve tag.') + self.assertEqual(saved_tag.name, retrieved_tag.name, "Failed to retrieve tag.") + self.assertEqual( + saved_tag.value, retrieved_tag.value, "Failed to retrieve tag." + ) def test_tag_max_size_restriction(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name=self._gen_random_string(), - value=self._gen_random_string())] + instance.tags = [ + stormbase.TagField( + name=self._gen_random_string(), value=self._gen_random_string() + ) + ] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) def test_name_exceeds_max_size(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name=self._gen_random_string(1025), - value='v1')] + instance.tags = [ + stormbase.TagField(name=self._gen_random_string(1025), value="v1") + ] try: instance.save() - self.assertTrue(False, 'Expected save to fail') + self.assertTrue(False, "Expected save to fail") except ValidationError: pass def test_value_exceeds_max_size(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='n1', - value=self._gen_random_string(1025))] + instance.tags = [ + stormbase.TagField(name="n1", value=self._gen_random_string(1025)) + ] try: instance.save() - self.assertTrue(False, 'Expected save to fail') + self.assertTrue(False, "Expected save to fail") except ValidationError: pass - def _gen_random_string(self, size=1024, chars=string.ascii_lowercase + string.digits): - return ''.join([random.choice(chars) for _ in range(size)]) + def _gen_random_string( + self, size=1024, chars=string.ascii_lowercase + string.digits + ): + return "".join([random.choice(chars) for _ in range(size)]) diff --git a/st2common/tests/unit/test_time_jinja_filters.py b/st2common/tests/unit/test_time_jinja_filters.py index c61473cdfd..5a343a5c29 100644 --- a/st2common/tests/unit/test_time_jinja_filters.py +++ b/st2common/tests/unit/test_time_jinja_filters.py @@ -20,14 +20,16 @@ class TestTimeJinjaFilters(TestCase): - def test_to_human_time_from_seconds(self): - self.assertEqual('0s', time.to_human_time_from_seconds(seconds=0)) - self.assertEqual('0.1\u03BCs', time.to_human_time_from_seconds(seconds=0.1)) - self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56)) - self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56.2)) - self.assertEqual('7m36s', time.to_human_time_from_seconds(seconds=456)) - self.assertEqual('1h16m0s', time.to_human_time_from_seconds(seconds=4560)) - self.assertEqual('1y12d16h36m37s', time.to_human_time_from_seconds(seconds=45678997)) - self.assertRaises(AssertionError, time.to_human_time_from_seconds, - seconds='stuff') + self.assertEqual("0s", time.to_human_time_from_seconds(seconds=0)) + self.assertEqual("0.1\u03BCs", time.to_human_time_from_seconds(seconds=0.1)) + self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56)) + self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56.2)) + self.assertEqual("7m36s", time.to_human_time_from_seconds(seconds=456)) + self.assertEqual("1h16m0s", time.to_human_time_from_seconds(seconds=4560)) + self.assertEqual( + "1y12d16h36m37s", time.to_human_time_from_seconds(seconds=45678997) + ) + self.assertRaises( + AssertionError, time.to_human_time_from_seconds, seconds="stuff" + ) diff --git a/st2common/tests/unit/test_transport.py b/st2common/tests/unit/test_transport.py index 9e4d4789b2..75e35ae2c9 100644 --- a/st2common/tests/unit/test_transport.py +++ b/st2common/tests/unit/test_transport.py @@ -19,9 +19,7 @@ from st2common.transport.utils import _get_ssl_kwargs -__all__ = [ - 'TransportUtilsTestCase' -] +__all__ = ["TransportUtilsTestCase"] class TransportUtilsTestCase(unittest2.TestCase): @@ -32,49 +30,39 @@ def test_get_ssl_kwargs(self): # 2. ssl kwarg provided ssl_kwargs = _get_ssl_kwargs(ssl=True) - self.assertEqual(ssl_kwargs, { - 'ssl': True - }) + self.assertEqual(ssl_kwargs, {"ssl": True}) # 3. ssl_keyfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'keyfile': '/tmp/keyfile' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile") + self.assertEqual(ssl_kwargs, {"ssl": True, "keyfile": "/tmp/keyfile"}) # 4. ssl_certfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'certfile': '/tmp/certfile' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile") + self.assertEqual(ssl_kwargs, {"ssl": True, "certfile": "/tmp/certfile"}) # 5. ssl_ca_certs provided - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs") + self.assertEqual(ssl_kwargs, {"ssl": True, "ca_certs": "/tmp/ca_certs"}) # 6. ssl_ca_certs and ssl_cert_reqs combinations - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_NONE - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_NONE}, + ) - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_OPTIONAL - }) + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional" + ) + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_OPTIONAL}, + ) - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_REQUIRED - }) + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required" + ) + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_REQUIRED}, + ) diff --git a/st2common/tests/unit/test_trigger_services.py b/st2common/tests/unit/test_trigger_services.py index 6f66a5f55b..b843526bc9 100644 --- a/st2common/tests/unit/test_trigger_services.py +++ b/st2common/tests/unit/test_trigger_services.py @@ -18,124 +18,147 @@ from st2common.models.api.rule import RuleAPI from st2common.models.system.common import ResourceReference from st2common.models.db.trigger import TriggerDB -from st2common.persistence.trigger import (Trigger, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerType import st2common.services.triggers as trigger_service from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', parameters={}, - type='dummy_pack_1.trigger-type-test.name') +MOCK_TRIGGER = TriggerDB( + pack="dummy_pack_1", + name="trigger-test.name", + parameters={}, + type="dummy_pack_1.trigger-type-test.name", +) class TriggerServiceTests(CleanDbTestCase): - def test_create_trigger_db_from_rule(self): - test_fixtures = { - 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_3.yaml'] - } + test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_3.yaml"]} loader = FixturesLoader() - fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures) - rules = fixtures['rules'] + fixtures = loader.load_fixtures( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + rules = fixtures["rules"] trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_1.yaml'])) + RuleAPI(**rules["cron_timer_rule_1.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_1) trigger_db = Trigger.get_by_id(trigger_db_ret_1.id) - self.assertDictEqual(trigger_db.parameters, - rules['cron_timer_rule_1.yaml']['trigger']['parameters']) + self.assertDictEqual( + trigger_db.parameters, + rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"], + ) trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_3.yaml'])) + RuleAPI(**rules["cron_timer_rule_3.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_2) self.assertTrue(trigger_db_ret_2.id != trigger_db_ret_1.id) def test_create_trigger_db_from_rule_duplicate(self): - test_fixtures = { - 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_2.yaml'] - } + test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_2.yaml"]} loader = FixturesLoader() - fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures) - rules = fixtures['rules'] + fixtures = loader.load_fixtures( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + rules = fixtures["rules"] trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_1.yaml'])) + RuleAPI(**rules["cron_timer_rule_1.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_1) trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_2.yaml'])) + RuleAPI(**rules["cron_timer_rule_2.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_2) - self.assertEqual(trigger_db_ret_1, trigger_db_ret_2, 'Should reuse same trigger.') + self.assertEqual( + trigger_db_ret_1, trigger_db_ret_2, "Should reuse same trigger." + ) trigger_db = Trigger.get_by_id(trigger_db_ret_1.id) - self.assertDictEqual(trigger_db.parameters, - rules['cron_timer_rule_1.yaml']['trigger']['parameters']) + self.assertDictEqual( + trigger_db.parameters, + rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"], + ) def test_create_or_update_trigger_db_simple_triggers(self): - test_fixtures = { - 'triggertypes': ['triggertype1.yaml'] - } + test_fixtures = {"triggertypes": ["triggertype1.yaml"]} loader = FixturesLoader() - fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures) - triggertypes = fixtures['triggertypes'] + fixtures = loader.save_fixtures_to_db( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + triggertypes = fixtures["triggertypes"] trigger_type_ref = ResourceReference.to_string_reference( - name=triggertypes['triggertype1.yaml']['name'], - pack=triggertypes['triggertype1.yaml']['pack']) + name=triggertypes["triggertype1.yaml"]["name"], + pack=triggertypes["triggertype1.yaml"]["pack"], + ) trigger = { - 'name': triggertypes['triggertype1.yaml']['name'], - 'pack': triggertypes['triggertype1.yaml']['pack'], - 'type': trigger_type_ref + "name": triggertypes["triggertype1.yaml"]["name"], + "pack": triggertypes["triggertype1.yaml"]["pack"], + "type": trigger_type_ref, } trigger_service.create_or_update_trigger_db(trigger) triggers = Trigger.get_all() - self.assertTrue(len(triggers) == 1, 'Only one trigger should be created.') - self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name']) + self.assertTrue(len(triggers) == 1, "Only one trigger should be created.") + self.assertTrue( + triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"] + ) # Try adding duplicate trigger_service.create_or_update_trigger_db(trigger) triggers = Trigger.get_all() - self.assertTrue(len(triggers) == 1, 'Only one trigger should be present.') - self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name']) + self.assertTrue(len(triggers) == 1, "Only one trigger should be present.") + self.assertTrue( + triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"] + ) def test_exception_thrown_when_rule_creation_no_trigger_yes_triggertype(self): - test_fixtures = { - 'triggertypes': ['triggertype1.yaml'] - } + test_fixtures = {"triggertypes": ["triggertype1.yaml"]} loader = FixturesLoader() - fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures) - triggertypes = fixtures['triggertypes'] + fixtures = loader.save_fixtures_to_db( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + triggertypes = fixtures["triggertypes"] trigger_type_ref = ResourceReference.to_string_reference( - name=triggertypes['triggertype1.yaml']['name'], - pack=triggertypes['triggertype1.yaml']['pack']) + name=triggertypes["triggertype1.yaml"]["name"], + pack=triggertypes["triggertype1.yaml"]["pack"], + ) rule = { - 'name': 'fancyrule', - 'trigger': { - 'type': trigger_type_ref - }, - 'criteria': { - - }, - 'action': { - 'ref': 'core.local', - 'parameters': { - 'cmd': 'date' - } - } + "name": "fancyrule", + "trigger": {"type": trigger_type_ref}, + "criteria": {}, + "action": {"ref": "core.local", "parameters": {"cmd": "date"}}, } rule_api = RuleAPI(**rule) - self.assertRaises(TriggerDoesNotExistException, - trigger_service.create_trigger_db_from_rule, rule_api) + self.assertRaises( + TriggerDoesNotExistException, + trigger_service.create_trigger_db_from_rule, + rule_api, + ) def test_get_trigger_db_given_type_and_params(self): # Add dummy triggers - trigger_1 = TriggerDB(pack='testpack', name='testtrigger1', type='testpack.testtrigger1') + trigger_1 = TriggerDB( + pack="testpack", name="testtrigger1", type="testpack.testtrigger1" + ) - trigger_2 = TriggerDB(pack='testpack', name='testtrigger2', type='testpack.testtrigger2') + trigger_2 = TriggerDB( + pack="testpack", name="testtrigger2", type="testpack.testtrigger2" + ) - trigger_3 = TriggerDB(pack='testpack', name='testtrigger3', type='testpack.testtrigger3') + trigger_3 = TriggerDB( + pack="testpack", name="testtrigger3", type="testpack.testtrigger3" + ) - trigger_4 = TriggerDB(pack='testpack', name='testtrigger4', type='testpack.testtrigger4', - parameters={'ponies': 'unicorn'}) + trigger_4 = TriggerDB( + pack="testpack", + name="testtrigger4", + type="testpack.testtrigger4", + parameters={"ponies": "unicorn"}, + ) Trigger.add_or_update(trigger_1) Trigger.add_or_update(trigger_2) @@ -143,64 +166,73 @@ def test_get_trigger_db_given_type_and_params(self): Trigger.add_or_update(trigger_4) # Trigger with no parameters, parameters={} in db - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters={} + ) self.assertEqual(trigger_db, trigger_1) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters=None + ) self.assertEqual(trigger_db, trigger_1) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters={'fo': 'bar'}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters={"fo": "bar"} + ) self.assertEqual(trigger_db, None) # Trigger with no parameters, no parameters attribute in the db - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters={} + ) self.assertEqual(trigger_db, trigger_2) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters=None + ) self.assertEqual(trigger_db, trigger_2) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters={'fo': 'bar'}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters={"fo": "bar"} + ) self.assertEqual(trigger_db, None) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_3.type, parameters={} + ) self.assertEqual(trigger_db, trigger_3) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_3.type, parameters=None + ) self.assertEqual(trigger_db, trigger_3) # Trigger with parameters trigger_db = trigger_service.get_trigger_db_given_type_and_params( - type=trigger_4.type, - parameters=trigger_4.parameters) + type=trigger_4.type, parameters=trigger_4.parameters + ) self.assertEqual(trigger_db, trigger_4) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_4.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_4.type, parameters=None + ) self.assertEqual(trigger_db, None) def test_add_trigger_type_no_params(self): # Trigger type with no params should create a trigger with same name as trigger type. trig_type = { - 'name': 'myawesometriggertype', - 'pack': 'dummy_pack_1', - 'description': 'Words cannot describe how awesome I am.', - 'parameters_schema': {}, - 'payload_schema': {} + "name": "myawesometriggertype", + "pack": "dummy_pack_1", + "description": "Words cannot describe how awesome I am.", + "parameters_schema": {}, + "payload_schema": {}, } trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type]) trigger_type, trigger = trigtype_dbs[0] trigtype_db = TriggerType.get_by_id(trigger_type.id) - self.assertEqual(trigtype_db.pack, 'dummy_pack_1') - self.assertEqual(trigtype_db.name, trig_type.get('name')) + self.assertEqual(trigtype_db.pack, "dummy_pack_1") + self.assertEqual(trigtype_db.name, trig_type.get("name")) self.assertIsNotNone(trigger) self.assertEqual(trigger.name, trigtype_db.name) @@ -210,35 +242,34 @@ def test_add_trigger_type_no_params(self): self.assertTrue(len(triggers) == 1) def test_add_trigger_type_with_params(self): - MOCK_TRIGGER.type = 'system.test' + MOCK_TRIGGER.type = "system.test" # Trigger type with params should not create a trigger. PARAMETERS_SCHEMA = { "type": "object", - "properties": { - "url": {"type": "string"} - }, - "required": ['url'], - "additionalProperties": False + "properties": {"url": {"type": "string"}}, + "required": ["url"], + "additionalProperties": False, } trig_type = { - 'name': 'myawesometriggertype2', - 'pack': 'my_pack_1', - 'description': 'Words cannot describe how awesome I am.', - 'parameters_schema': PARAMETERS_SCHEMA, - 'payload_schema': {} + "name": "myawesometriggertype2", + "pack": "my_pack_1", + "description": "Words cannot describe how awesome I am.", + "parameters_schema": PARAMETERS_SCHEMA, + "payload_schema": {}, } trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type]) trigger_type, trigger = trigtype_dbs[0] trigtype_db = TriggerType.get_by_id(trigger_type.id) - self.assertEqual(trigtype_db.pack, 'my_pack_1') - self.assertEqual(trigtype_db.name, trig_type.get('name')) + self.assertEqual(trigtype_db.pack, "my_pack_1") + self.assertEqual(trigtype_db.name, trig_type.get("name")) self.assertEqual(trigger, None) def test_add_trigger_type(self): """ This sensor has misconfigured trigger type. We shouldn't explode. """ + class FailTestSensor(object): started = False @@ -252,12 +283,12 @@ def stop(self): pass def get_trigger_types(self): - return [ - {'description': 'Ain\'t got no name'} - ] + return [{"description": "Ain't got no name"}] try: trigger_service.add_trigger_models(FailTestSensor().get_trigger_types()) - self.assertTrue(False, 'Trigger type doesn\'t have \'name\' field. Should have thrown.') + self.assertTrue( + False, "Trigger type doesn't have 'name' field. Should have thrown." + ) except Exception: self.assertTrue(True) diff --git a/st2common/tests/unit/test_triggers_registrar.py b/st2common/tests/unit/test_triggers_registrar.py index 53595d2867..5ceda4f851 100644 --- a/st2common/tests/unit/test_triggers_registrar.py +++ b/st2common/tests/unit/test_triggers_registrar.py @@ -22,9 +22,7 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'TriggersRegistrarTestCase' -] +__all__ = ["TriggersRegistrarTestCase"] class TriggersRegistrarTestCase(CleanDbTestCase): @@ -44,7 +42,7 @@ def test_register_all_triggers(self): def test_register_triggers_from_pack(self): base_path = get_fixtures_packs_base_path() - pack_dir = os.path.join(base_path, 'dummy_pack_1') + pack_dir = os.path.join(base_path, "dummy_pack_1") trigger_type_dbs = TriggerType.get_all() self.assertEqual(len(trigger_type_dbs), 0) @@ -58,12 +56,12 @@ def test_register_triggers_from_pack(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(trigger_type_dbs[0].name, 'event_handler') - self.assertEqual(trigger_type_dbs[0].pack, 'dummy_pack_1') - self.assertEqual(trigger_dbs[0].name, 'event_handler') - self.assertEqual(trigger_dbs[0].pack, 'dummy_pack_1') - self.assertEqual(trigger_dbs[0].type, 'dummy_pack_1.event_handler') + self.assertEqual(trigger_type_dbs[0].name, "event_handler") + self.assertEqual(trigger_type_dbs[0].pack, "dummy_pack_1") + self.assertEqual(trigger_dbs[0].name, "event_handler") + self.assertEqual(trigger_dbs[0].pack, "dummy_pack_1") + self.assertEqual(trigger_dbs[0].type, "dummy_pack_1.event_handler") - self.assertEqual(trigger_type_dbs[1].name, 'head_sha_monitor') - self.assertEqual(trigger_type_dbs[1].pack, 'dummy_pack_1') - self.assertEqual(trigger_type_dbs[1].payload_schema['type'], 'object') + self.assertEqual(trigger_type_dbs[1].name, "head_sha_monitor") + self.assertEqual(trigger_type_dbs[1].pack, "dummy_pack_1") + self.assertEqual(trigger_type_dbs[1].payload_schema["type"], "object") diff --git a/st2common/tests/unit/test_unit_testing_mocks.py b/st2common/tests/unit/test_unit_testing_mocks.py index ce63dd7834..742ca85da1 100644 --- a/st2common/tests/unit/test_unit_testing_mocks.py +++ b/st2common/tests/unit/test_unit_testing_mocks.py @@ -23,9 +23,9 @@ from st2tests.mocks.action import MockActionService __all__ = [ - 'BaseSensorTestCaseTestCase', - 'MockSensorServiceTestCase', - 'MockActionServiceTestCase' + "BaseSensorTestCaseTestCase", + "MockSensorServiceTestCase", + "MockActionServiceTestCase", ] @@ -37,36 +37,38 @@ class BaseMockResourceServiceTestCase(object): class TestCase(unittest2.TestCase): def test_get_user_info(self): result = self.mock_service.get_user_info() - self.assertEqual(result['username'], 'admin') - self.assertEqual(result['rbac']['roles'], ['admin']) + self.assertEqual(result["username"], "admin") + self.assertEqual(result["rbac"]["roles"], ["admin"]) def test_list_set_get_delete_values(self): # list_values, set_value result = self.mock_service.list_values() self.assertSequenceEqual(result, []) - self.mock_service.set_value(name='t1.local', value='test1', local=True) - self.mock_service.set_value(name='t1.global', value='test1', local=False) + self.mock_service.set_value(name="t1.local", value="test1", local=True) + self.mock_service.set_value(name="t1.global", value="test1", local=False) result = self.mock_service.list_values(local=True) self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, 'dummy.test:t1.local') + self.assertEqual(result[0].name, "dummy.test:t1.local") result = self.mock_service.list_values(local=False) - self.assertEqual(result[0].name, 'dummy.test:t1.local') - self.assertEqual(result[1].name, 't1.global') + self.assertEqual(result[0].name, "dummy.test:t1.local") + self.assertEqual(result[1].name, "t1.global") self.assertEqual(len(result), 2) # get_value - self.assertEqual(self.mock_service.get_value('inexistent'), None) - self.assertEqual(self.mock_service.get_value(name='t1.local', local=True), 'test1') + self.assertEqual(self.mock_service.get_value("inexistent"), None) + self.assertEqual( + self.mock_service.get_value(name="t1.local", local=True), "test1" + ) # delete_value self.assertEqual(len(self.mock_service.list_values(local=True)), 1) - self.assertEqual(self.mock_service.delete_value('inexistent'), False) + self.assertEqual(self.mock_service.delete_value("inexistent"), False) self.assertEqual(len(self.mock_service.list_values(local=True)), 1) - self.assertEqual(self.mock_service.delete_value('t1.local'), True) + self.assertEqual(self.mock_service.delete_value("t1.local"), True) self.assertEqual(len(self.mock_service.list_values(local=True)), 0) @@ -77,47 +79,50 @@ def test_dispatch_and_assertTriggerDispatched(self): sensor_service = self.sensor_service expected_msg = 'Trigger "nope" hasn\'t been dispatched' - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertTriggerDispatched, trigger='nope') + self.assertRaisesRegexp( + AssertionError, expected_msg, self.assertTriggerDispatched, trigger="nope" + ) - sensor_service.dispatch(trigger='test1', payload={'a': 'b'}) - result = self.assertTriggerDispatched(trigger='test1') + sensor_service.dispatch(trigger="test1", payload={"a": "b"}) + result = self.assertTriggerDispatched(trigger="test1") self.assertTrue(result) - result = self.assertTriggerDispatched(trigger='test1', payload={'a': 'b'}) + result = self.assertTriggerDispatched(trigger="test1", payload={"a": "b"}) self.assertTrue(result) expected_msg = 'Trigger "test1" hasn\'t been dispatched' - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertTriggerDispatched, - trigger='test1', - payload={'a': 'c'}) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertTriggerDispatched, + trigger="test1", + payload={"a": "c"}, + ) class MockSensorServiceTestCase(BaseMockResourceServiceTestCase.TestCase): - def setUp(self): - mock_sensor_wrapper = MockSensorWrapper(pack='dummy', class_name='test') + mock_sensor_wrapper = MockSensorWrapper(pack="dummy", class_name="test") self.mock_service = MockSensorService(sensor_wrapper=mock_sensor_wrapper) def test_get_logger(self): sensor_service = self.mock_service - logger = sensor_service.get_logger('test') - logger.info('test info') - logger.debug('test debug') + logger = sensor_service.get_logger("test") + logger.info("test info") + logger.debug("test debug") self.assertEqual(len(logger.method_calls), 2) method_name, method_args, method_kwargs = tuple(logger.method_calls[0]) - self.assertEqual(method_name, 'info') - self.assertEqual(method_args, ('test info',)) + self.assertEqual(method_name, "info") + self.assertEqual(method_args, ("test info",)) self.assertEqual(method_kwargs, {}) method_name, method_args, method_kwargs = tuple(logger.method_calls[1]) - self.assertEqual(method_name, 'debug') - self.assertEqual(method_args, ('test debug',)) + self.assertEqual(method_name, "debug") + self.assertEqual(method_args, ("test debug",)) self.assertEqual(method_kwargs, {}) class MockActionServiceTestCase(BaseMockResourceServiceTestCase.TestCase): def setUp(self): - mock_action_wrapper = MockActionWrapper(pack='dummy', class_name='test') + mock_action_wrapper = MockActionWrapper(pack="dummy", class_name="test") self.mock_service = MockActionService(action_wrapper=mock_action_wrapper) diff --git a/st2common/tests/unit/test_util_actionalias_helpstrings.py b/st2common/tests/unit/test_util_actionalias_helpstrings.py index a7726dd177..e543bd471a 100644 --- a/st2common/tests/unit/test_util_actionalias_helpstrings.py +++ b/st2common/tests/unit/test_util_actionalias_helpstrings.py @@ -25,62 +25,101 @@ ALIASES = [ - MemoryActionAliasDB(name="kyle_reese", ref="terminator.1", - pack="the80s", enabled=True, - formats=["Come with me if you want to live"] + MemoryActionAliasDB( + name="kyle_reese", + ref="terminator.1", + pack="the80s", + enabled=True, + formats=["Come with me if you want to live"], ), - MemoryActionAliasDB(name="terminator", ref="terminator.2", - pack="the80s", enabled=True, - formats=["I need your {{item}}, your {{item2}}" - " and your {{vehicle}}"] + MemoryActionAliasDB( + name="terminator", + ref="terminator.2", + pack="the80s", + enabled=True, + formats=["I need your {{item}}, your {{item2}}" " and your {{vehicle}}"], ), - MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.3", - pack="the80s", enabled=True, - formats=[{'display': 'Number 5 is {{status}}', - 'representation': ['Number 5 is {{status=alive}}']}, - 'Hey, laser lips, your mama was a snow blower.'] + MemoryActionAliasDB( + name="johnny_five_alive", + ref="short_circuit.3", + pack="the80s", + enabled=True, + formats=[ + { + "display": "Number 5 is {{status}}", + "representation": ["Number 5 is {{status=alive}}"], + }, + "Hey, laser lips, your mama was a snow blower.", + ], ), - MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.4", - pack="the80s", enabled=True, - formats=["How do I feel? I feel... {{status}}!"] + MemoryActionAliasDB( + name="i_feel_alive", + ref="short_circuit.4", + pack="the80s", + enabled=True, + formats=["How do I feel? I feel... {{status}}!"], ), - MemoryActionAliasDB(name='andy', ref='the_goonies.1', - pack="the80s", enabled=True, - formats=[{'display': 'Watch this.'}] + MemoryActionAliasDB( + name="andy", + ref="the_goonies.1", + pack="the80s", + enabled=True, + formats=[{"display": "Watch this."}], ), - MemoryActionAliasDB(name='andy', ref='the_goonies.5', - pack="the80s", enabled=True, - formats=[{'display': "He's just like his {{relation}}."}] + MemoryActionAliasDB( + name="andy", + ref="the_goonies.5", + pack="the80s", + enabled=True, + formats=[{"display": "He's just like his {{relation}}."}], ), - MemoryActionAliasDB(name='data', ref='the_goonies.6', - pack="the80s", enabled=True, - formats=[{'representation': "That's okay daddy. You can't hug a {{object}}."}] + MemoryActionAliasDB( + name="data", + ref="the_goonies.6", + pack="the80s", + enabled=True, + formats=[{"representation": "That's okay daddy. You can't hug a {{object}}."}], ), - MemoryActionAliasDB(name='mr_wang', ref='the_goonies.7', - pack="the80s", enabled=True, - formats=[{'representation': 'You are my greatest invention.'}] + MemoryActionAliasDB( + name="mr_wang", + ref="the_goonies.7", + pack="the80s", + enabled=True, + formats=[{"representation": "You are my greatest invention."}], ), - MemoryActionAliasDB(name="Ferris", ref="ferris_buellers_day_off.8", - pack="the80s", enabled=True, - formats=["Life moves pretty fast.", - "If you don't stop and look around once in a while, you could miss it."] + MemoryActionAliasDB( + name="Ferris", + ref="ferris_buellers_day_off.8", + pack="the80s", + enabled=True, + formats=[ + "Life moves pretty fast.", + "If you don't stop and look around once in a while, you could miss it.", + ], ), - MemoryActionAliasDB(name="economics.teacher", ref="ferris_buellers_day_off.10", - pack="the80s", enabled=False, - formats=["Bueller?... Bueller?... Bueller? "] + MemoryActionAliasDB( + name="economics.teacher", + ref="ferris_buellers_day_off.10", + pack="the80s", + enabled=False, + formats=["Bueller?... Bueller?... Bueller? "], + ), + MemoryActionAliasDB( + name="spengler", + ref="ghostbusters.10", + pack="the80s", + enabled=True, + formats=["{{choice}} cross the {{target}}"], ), - MemoryActionAliasDB(name="spengler", ref="ghostbusters.10", - pack="the80s", enabled=True, - formats=["{{choice}} cross the {{target}}"] - ) ] -@mock.patch.object(MemoryActionAliasDB, 'get_uid') +@mock.patch.object(MemoryActionAliasDB, "get_uid") class ActionAliasTestCase(unittest2.TestCase): - ''' + """ Test scenarios must consist of 80s movie quotes. - ''' + """ + def check_data_structure(self, result): tmp = list(result.keys()) tmp.sort() @@ -93,7 +132,9 @@ def test_filtering_no_arg(self, mock): result = generate_helpstring_result(ALIASES) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -115,7 +156,9 @@ def test_filtering_match(self, mock): result = generate_helpstring_result(ALIASES, "you") self.check_data_structure(result) self.check_available_count(result, 4) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 4) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -123,12 +166,16 @@ def test_pack_empty_string(self, mock): result = generate_helpstring_result(ALIASES, "", "") self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") def test_pack_no_match(self, mock): - result = generate_helpstring_result(ALIASES, "", "you_will_not_find_this_string") + result = generate_helpstring_result( + ALIASES, "", "you_will_not_find_this_string" + ) self.check_data_structure(result) self.check_available_count(result, 0) self.assertEqual(result.get("helpstrings"), []) @@ -137,7 +184,9 @@ def test_pack_match(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s") self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -153,7 +202,9 @@ def test_limit_neg_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", -3) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -161,7 +212,9 @@ def test_limit_pos_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 30) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -169,7 +222,9 @@ def test_limit_in_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 3) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 3) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -185,7 +240,9 @@ def test_offset_negative_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 0, -1) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -199,6 +256,8 @@ def test_offset_in_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 0, 6) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 4) self.assertEqual(the80s[0].get("display"), "He's just like his {{relation}}.") diff --git a/st2common/tests/unit/test_util_actionalias_matching.py b/st2common/tests/unit/test_util_actionalias_matching.py index c22ccab3e6..082fa40b98 100644 --- a/st2common/tests/unit/test_util_actionalias_matching.py +++ b/st2common/tests/unit/test_util_actionalias_matching.py @@ -24,89 +24,130 @@ MemoryActionAliasDB = ActionAliasDB -@mock.patch.object(MemoryActionAliasDB, 'get_uid') +@mock.patch.object(MemoryActionAliasDB, "get_uid") class ActionAliasTestCase(unittest2.TestCase): - ''' + """ Test scenarios must consist of 80s movie quotes. - ''' + """ + def test_list_format_strings_from_aliases(self, mock): ALIASES = [ - MemoryActionAliasDB(name="kyle_reese", ref="terminator.1", - formats=["Come with me if you want to live"]), - MemoryActionAliasDB(name="terminator", ref="terminator.2", - formats=["I need your {{item}}, your {{item2}}" - " and your {{vehicle}}"]) + MemoryActionAliasDB( + name="kyle_reese", + ref="terminator.1", + formats=["Come with me if you want to live"], + ), + MemoryActionAliasDB( + name="terminator", + ref="terminator.2", + formats=[ + "I need your {{item}}, your {{item2}}" " and your {{vehicle}}" + ], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], "Come with me if you want to live") - self.assertEqual(result[1]['display'], - "I need your {{item}}, your {{item2}} and" - " your {{vehicle}}") + self.assertEqual(result[0]["display"], "Come with me if you want to live") + self.assertEqual( + result[1]["display"], + "I need your {{item}}, your {{item2}} and" " your {{vehicle}}", + ) def test_list_format_strings_from_aliases_with_display(self, mock): ALIASES = [ - MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.1", formats=[ - {'display': 'Number 5 is {{status}}', - 'representation': ['Number 5 is {{status=alive}}']}, - 'Hey, laser lips, your mama was a snow blower.']), - MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.2", - formats=["How do I feel? I feel... {{status}}!"]) + MemoryActionAliasDB( + name="johnny_five_alive", + ref="short_circuit.1", + formats=[ + { + "display": "Number 5 is {{status}}", + "representation": ["Number 5 is {{status=alive}}"], + }, + "Hey, laser lips, your mama was a snow blower.", + ], + ), + MemoryActionAliasDB( + name="i_feel_alive", + ref="short_circuit.2", + formats=["How do I feel? I feel... {{status}}!"], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 3) - self.assertEqual(result[0]['display'], "Number 5 is {{status}}") - self.assertEqual(result[0]['representation'], "Number 5 is {{status=alive}}") - self.assertEqual(result[1]['display'], "Hey, laser lips, your mama was a snow blower.") - self.assertEqual(result[1]['representation'], - "Hey, laser lips, your mama was a snow blower.") - self.assertEqual(result[2]['display'], "How do I feel? I feel... {{status}}!") - self.assertEqual(result[2]['representation'], "How do I feel? I feel... {{status}}!") + self.assertEqual(result[0]["display"], "Number 5 is {{status}}") + self.assertEqual(result[0]["representation"], "Number 5 is {{status=alive}}") + self.assertEqual( + result[1]["display"], "Hey, laser lips, your mama was a snow blower." + ) + self.assertEqual( + result[1]["representation"], "Hey, laser lips, your mama was a snow blower." + ) + self.assertEqual(result[2]["display"], "How do I feel? I feel... {{status}}!") + self.assertEqual( + result[2]["representation"], "How do I feel? I feel... {{status}}!" + ) def test_list_format_strings_from_aliases_with_display_only(self, mock): ALIASES = [ - MemoryActionAliasDB(name='andy', - ref='the_goonies.1', formats=[{'display': 'Watch this.'}]), - MemoryActionAliasDB(name='andy', ref='the_goonies.2', - formats=[{'display': "He's just like his {{relation}}."}]) + MemoryActionAliasDB( + name="andy", ref="the_goonies.1", formats=[{"display": "Watch this."}] + ), + MemoryActionAliasDB( + name="andy", + ref="the_goonies.2", + formats=[{"display": "He's just like his {{relation}}."}], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], 'Watch this.') - self.assertEqual(result[0]['representation'], '') - self.assertEqual(result[1]['display'], "He's just like his {{relation}}.") - self.assertEqual(result[1]['representation'], '') + self.assertEqual(result[0]["display"], "Watch this.") + self.assertEqual(result[0]["representation"], "") + self.assertEqual(result[1]["display"], "He's just like his {{relation}}.") + self.assertEqual(result[1]["representation"], "") def test_list_format_strings_from_aliases_with_representation_only(self, mock): ALIASES = [ - MemoryActionAliasDB(name='data', ref='the_goonies.1', formats=[ - {'representation': "That's okay daddy. You can't hug a {{object}}."}]), - MemoryActionAliasDB(name='mr_wang', ref='the_goonies.2', formats=[ - {'representation': 'You are my greatest invention.'}]) + MemoryActionAliasDB( + name="data", + ref="the_goonies.1", + formats=[ + {"representation": "That's okay daddy. You can't hug a {{object}}."} + ], + ), + MemoryActionAliasDB( + name="mr_wang", + ref="the_goonies.2", + formats=[{"representation": "You are my greatest invention."}], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], None) - self.assertEqual(result[0]['representation'], - "That's okay daddy. You can't hug a {{object}}.") - self.assertEqual(result[1]['display'], None) - self.assertEqual(result[1]['representation'], 'You are my greatest invention.') + self.assertEqual(result[0]["display"], None) + self.assertEqual( + result[0]["representation"], + "That's okay daddy. You can't hug a {{object}}.", + ) + self.assertEqual(result[1]["display"], None) + self.assertEqual(result[1]["representation"], "You are my greatest invention.") def test_normalise_alias_format_string(self, mock): result = matching.normalise_alias_format_string( - 'Quite an experience to live in fear, isn\'t it?') + "Quite an experience to live in fear, isn't it?" + ) self.assertEqual([result[0]], result[1]) self.assertEqual(result[0], "Quite an experience to live in fear, isn't it?") def test_normalise_alias_format_string_error(self, mock): alias_list = ["Quite an experience to live in fear, isn't it?"] - expected_msg = ("alias_format '%s' is neither a dictionary or string type." - % repr(alias_list)) + expected_msg = ( + "alias_format '%s' is neither a dictionary or string type." + % repr(alias_list) + ) with self.assertRaises(TypeError) as cm: matching.normalise_alias_format_string(alias_list) @@ -115,13 +156,16 @@ def test_normalise_alias_format_string_error(self, mock): def test_matching(self, mock): ALIASES = [ - MemoryActionAliasDB(name="spengler", ref="ghostbusters.1", - formats=["{{choice}} cross the {{target}}"]), + MemoryActionAliasDB( + name="spengler", + ref="ghostbusters.1", + formats=["{{choice}} cross the {{target}}"], + ), ] COMMAND = "Don't cross the streams" match = matching.match_command_to_alias(COMMAND, ALIASES) self.assertEqual(len(match), 1) - self.assertEqual(match[0]['alias'].ref, "ghostbusters.1") - self.assertEqual(match[0]['representation'], "{{choice}} cross the {{target}}") + self.assertEqual(match[0]["alias"].ref, "ghostbusters.1") + self.assertEqual(match[0]["representation"], "{{choice}} cross the {{target}}") # we need some more complex scenarios in here. diff --git a/st2common/tests/unit/test_util_api.py b/st2common/tests/unit/test_util_api.py index bc0e385df1..2333939b13 100644 --- a/st2common/tests/unit/test_util_api.py +++ b/st2common/tests/unit/test_util_api.py @@ -23,24 +23,25 @@ from st2common.util.api import get_full_public_api_url from st2tests.config import parse_args from six.moves import zip + parse_args() class APIUtilsTestCase(unittest2.TestCase): def test_get_base_public_api_url(self): values = [ - 'http://foo.bar.com', - 'http://foo.bar.com/', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080/', - 'http://localhost:8080/', + "http://foo.bar.com", + "http://foo.bar.com/", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080/", + "http://localhost:8080/", ] expected = [ - 'http://foo.bar.com', - 'http://foo.bar.com', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080', - 'http://localhost:8080', + "http://foo.bar.com", + "http://foo.bar.com", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080", + "http://localhost:8080", ] for mock_value, expected_result in zip(values, expected): @@ -50,18 +51,18 @@ def test_get_base_public_api_url(self): def test_get_full_public_api_url(self): values = [ - 'http://foo.bar.com', - 'http://foo.bar.com/', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080/', - 'http://localhost:8080/', + "http://foo.bar.com", + "http://foo.bar.com/", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080/", + "http://localhost:8080/", ] expected = [ - 'http://foo.bar.com/' + DEFAULT_API_VERSION, - 'http://foo.bar.com/' + DEFAULT_API_VERSION, - 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION, - 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION, - 'http://localhost:8080/' + DEFAULT_API_VERSION, + "http://foo.bar.com/" + DEFAULT_API_VERSION, + "http://foo.bar.com/" + DEFAULT_API_VERSION, + "http://foo.bar.com:8080/" + DEFAULT_API_VERSION, + "http://foo.bar.com:8080/" + DEFAULT_API_VERSION, + "http://localhost:8080/" + DEFAULT_API_VERSION, ] for mock_value, expected_result in zip(values, expected): diff --git a/st2common/tests/unit/test_util_compat.py b/st2common/tests/unit/test_util_compat.py index 0e1ac9efe7..74face7ea6 100644 --- a/st2common/tests/unit/test_util_compat.py +++ b/st2common/tests/unit/test_util_compat.py @@ -19,18 +19,16 @@ from st2common.util.compat import to_ascii -__all__ = [ - 'CompatUtilsTestCase' -] +__all__ = ["CompatUtilsTestCase"] class CompatUtilsTestCase(unittest2.TestCase): def test_to_ascii(self): expected_values = [ - ('already ascii', 'already ascii'), - (u'foo', 'foo'), - ('٩(̾●̮̮̃̾•̃̾)۶', '()'), - ('\xd9\xa9', '') + ("already ascii", "already ascii"), + ("foo", "foo"), + ("٩(̾●̮̮̃̾•̃̾)۶", "()"), + ("\xd9\xa9", ""), ] for input_value, expected_value in expected_values: diff --git a/st2common/tests/unit/test_util_db.py b/st2common/tests/unit/test_util_db.py index dd230e6ae1..f94a2fe39a 100644 --- a/st2common/tests/unit/test_util_db.py +++ b/st2common/tests/unit/test_util_db.py @@ -22,88 +22,73 @@ class DatabaseUtilTestCase(unittest2.TestCase): - def test_noop_mongodb_to_python_types(self): - data = [ - 123, - 999.99, - True, - [10, 20, 30], - {'a': 1, 'b': 2}, - None - ] + data = [123, 999.99, True, [10, 20, 30], {"a": 1, "b": 2}, None] for item in data: self.assertEqual(db_util.mongodb_to_python_types(item), item) def test_mongodb_basedict_to_dict(self): - data = {'a': 1, 'b': 2} + data = {"a": 1, "b": 2} - obj = mongoengine.base.datastructures.BaseDict(data, None, 'foobar') + obj = mongoengine.base.datastructures.BaseDict(data, None, "foobar") self.assertDictEqual(db_util.mongodb_to_python_types(obj), data) def test_mongodb_baselist_to_list(self): data = [2, 4, 6] - obj = mongoengine.base.datastructures.BaseList(data, None, 'foobar') + obj = mongoengine.base.datastructures.BaseList(data, None, "foobar") self.assertListEqual(db_util.mongodb_to_python_types(obj), data) def test_nested_mongdb_to_python_types(self): data = { - 'a': mongoengine.base.datastructures.BaseList([1, 2, 3], None, 'a'), - 'b': mongoengine.base.datastructures.BaseDict({'a': 1, 'b': 2}, None, 'b'), - 'c': { - 'd': mongoengine.base.datastructures.BaseList([4, 5, 6], None, 'd'), - 'e': mongoengine.base.datastructures.BaseDict({'c': 3, 'd': 4}, None, 'e') + "a": mongoengine.base.datastructures.BaseList([1, 2, 3], None, "a"), + "b": mongoengine.base.datastructures.BaseDict({"a": 1, "b": 2}, None, "b"), + "c": { + "d": mongoengine.base.datastructures.BaseList([4, 5, 6], None, "d"), + "e": mongoengine.base.datastructures.BaseDict( + {"c": 3, "d": 4}, None, "e" + ), }, - 'f': mongoengine.base.datastructures.BaseList( + "f": mongoengine.base.datastructures.BaseList( [ - mongoengine.base.datastructures.BaseDict({'e': 5}, None, 'f1'), - mongoengine.base.datastructures.BaseDict({'f': 6}, None, 'f2') + mongoengine.base.datastructures.BaseDict({"e": 5}, None, "f1"), + mongoengine.base.datastructures.BaseDict({"f": 6}, None, "f2"), ], None, - 'f' + "f", ), - 'g': mongoengine.base.datastructures.BaseDict( + "g": mongoengine.base.datastructures.BaseDict( { - 'h': mongoengine.base.datastructures.BaseList( + "h": mongoengine.base.datastructures.BaseList( [ - mongoengine.base.datastructures.BaseDict({'g': 7}, None, 'h1'), - mongoengine.base.datastructures.BaseDict({'h': 8}, None, 'h2') + mongoengine.base.datastructures.BaseDict( + {"g": 7}, None, "h1" + ), + mongoengine.base.datastructures.BaseDict( + {"h": 8}, None, "h2" + ), ], None, - 'h' + "h", + ), + "i": mongoengine.base.datastructures.BaseDict( + {"j": 9, "k": 10}, None, "i" ), - 'i': mongoengine.base.datastructures.BaseDict({'j': 9, 'k': 10}, None, 'i') }, None, - 'g' + "g", ), } expected = { - 'a': [1, 2, 3], - 'b': {'a': 1, 'b': 2}, - 'c': { - 'd': [4, 5, 6], - 'e': {'c': 3, 'd': 4} - }, - 'f': [ - {'e': 5}, - {'f': 6} - ], - 'g': { - 'h': [ - {'g': 7}, - {'h': 8} - ], - 'i': { - 'j': 9, - 'k': 10 - } - } + "a": [1, 2, 3], + "b": {"a": 1, "b": 2}, + "c": {"d": [4, 5, 6], "e": {"c": 3, "d": 4}}, + "f": [{"e": 5}, {"f": 6}], + "g": {"h": [{"g": 7}, {"h": 8}], "i": {"j": 9, "k": 10}}, } self.assertDictEqual(db_util.mongodb_to_python_types(data), expected) diff --git a/st2common/tests/unit/test_util_file_system.py b/st2common/tests/unit/test_util_file_system.py index ea46a0b943..a1af0c957a 100644 --- a/st2common/tests/unit/test_util_file_system.py +++ b/st2common/tests/unit/test_util_file_system.py @@ -22,30 +22,32 @@ from st2common.util.file_system import get_file_list CURRENT_DIR = os.path.dirname(__file__) -ST2TESTS_DIR = os.path.join(CURRENT_DIR, '../../../st2tests/st2tests') +ST2TESTS_DIR = os.path.join(CURRENT_DIR, "../../../st2tests/st2tests") class FileSystemUtilsTestCase(unittest2.TestCase): def test_get_file_list(self): # Standard exclude pattern - directory = os.path.join(ST2TESTS_DIR, 'policies') + directory = os.path.join(ST2TESTS_DIR, "policies") expected = [ - 'mock_exception.py', - 'concurrency.py', - '__init__.py', - 'meta/mock_exception.yaml', - 'meta/concurrency.yaml', - 'meta/__init__.py' + "mock_exception.py", + "concurrency.py", + "__init__.py", + "meta/mock_exception.yaml", + "meta/concurrency.yaml", + "meta/__init__.py", ] - result = get_file_list(directory=directory, exclude_patterns=['*.pyc']) + result = get_file_list(directory=directory, exclude_patterns=["*.pyc"]) self.assertItemsEqual(expected, result) # Custom exclude pattern expected = [ - 'mock_exception.py', - 'concurrency.py', - '__init__.py', - 'meta/__init__.py' + "mock_exception.py", + "concurrency.py", + "__init__.py", + "meta/__init__.py", ] - result = get_file_list(directory=directory, exclude_patterns=['*.pyc', '*.yaml']) + result = get_file_list( + directory=directory, exclude_patterns=["*.pyc", "*.yaml"] + ) self.assertItemsEqual(expected, result) diff --git a/st2common/tests/unit/test_util_http.py b/st2common/tests/unit/test_util_http.py index 2bfbc22f04..a97aa8c7f1 100644 --- a/st2common/tests/unit/test_util_http.py +++ b/st2common/tests/unit/test_util_http.py @@ -19,24 +19,22 @@ from st2common.util.http import parse_content_type_header from six.moves import zip -__all__ = [ - 'HTTPUtilTestCase' -] +__all__ = ["HTTPUtilTestCase"] class HTTPUtilTestCase(unittest2.TestCase): def test_parse_content_type_header(self): values = [ - 'application/json', - 'foo/bar', - 'application/json; charset=utf-8', - 'application/json; charset=utf-8; foo=bar', + "application/json", + "foo/bar", + "application/json; charset=utf-8", + "application/json; charset=utf-8; foo=bar", ] expected_results = [ - ('application/json', {}), - ('foo/bar', {}), - ('application/json', {'charset': 'utf-8'}), - ('application/json', {'charset': 'utf-8', 'foo': 'bar'}) + ("application/json", {}), + ("foo/bar", {}), + ("application/json", {"charset": "utf-8"}), + ("application/json", {"charset": "utf-8", "foo": "bar"}), ] for value, expected_result in zip(values, expected_results): diff --git a/st2common/tests/unit/test_util_jinja.py b/st2common/tests/unit/test_util_jinja.py index 1b56adc0e9..127570f54b 100644 --- a/st2common/tests/unit/test_util_jinja.py +++ b/st2common/tests/unit/test_util_jinja.py @@ -21,97 +21,95 @@ class JinjaUtilsRenderTestCase(unittest2.TestCase): - def test_render_values(self): actual = jinja_utils.render_values( - mapping={'k1': '{{a}}', 'k2': '{{b}}'}, - context={'a': 'v1', 'b': 'v2'}) - expected = {'k2': 'v2', 'k1': 'v1'} + mapping={"k1": "{{a}}", "k2": "{{b}}"}, context={"a": "v1", "b": "v2"} + ) + expected = {"k2": "v2", "k1": "v1"} self.assertEqual(actual, expected) def test_render_values_skip_missing(self): actual = jinja_utils.render_values( - mapping={'k1': '{{a}}', 'k2': '{{b}}', 'k3': '{{c}}'}, - context={'a': 'v1', 'b': 'v2'}, - allow_undefined=True) - expected = {'k2': 'v2', 'k1': 'v1', 'k3': ''} + mapping={"k1": "{{a}}", "k2": "{{b}}", "k3": "{{c}}"}, + context={"a": "v1", "b": "v2"}, + allow_undefined=True, + ) + expected = {"k2": "v2", "k1": "v1", "k3": ""} self.assertEqual(actual, expected) def test_render_values_ascii_and_unicode_values(self): - mapping = { - u'k_ascii': '{{a}}', - u'k_unicode': '{{b}}', - u'k_ascii_unicode': '{{c}}'} + mapping = {"k_ascii": "{{a}}", "k_unicode": "{{b}}", "k_ascii_unicode": "{{c}}"} context = { - 'a': u'some ascii value', - 'b': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'c': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ' + "a": "some ascii value", + "b": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "c": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ", } expected = { - 'k_ascii': u'some ascii value', - 'k_unicode': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'k_ascii_unicode': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ' + "k_ascii": "some ascii value", + "k_unicode": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "k_ascii_unicode": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ", } actual = jinja_utils.render_values( - mapping=mapping, - context=context, - allow_undefined=True) + mapping=mapping, context=context, allow_undefined=True + ) self.assertEqual(actual, expected) def test_convert_str_to_raw(self): - jinja_expr = '{{foobar}}' - expected_raw_block = '{% raw %}{{foobar}}{% endraw %}' - self.assertEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + jinja_expr = "{{foobar}}" + expected_raw_block = "{% raw %}{{foobar}}{% endraw %}" + self.assertEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) - jinja_block_expr = '{% for item in items %}foobar{% end for %}' - expected_raw_block = '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}' + jinja_block_expr = "{% for item in items %}foobar{% end for %}" + expected_raw_block = ( + "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}" + ) self.assertEqual( - expected_raw_block, - jinja_utils.convert_jinja_to_raw_block(jinja_block_expr) + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_block_expr) ) def test_convert_list_to_raw(self): jinja_expr = [ - 'foobar', - '{{foo}}', - '{{bar}}', - '{% for item in items %}foobar{% end for %}', - {'foobar': '{{foobar}}'} + "foobar", + "{{foo}}", + "{{bar}}", + "{% for item in items %}foobar{% end for %}", + {"foobar": "{{foobar}}"}, ] expected_raw_block = [ - 'foobar', - '{% raw %}{{foo}}{% endraw %}', - '{% raw %}{{bar}}{% endraw %}', - '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}', - {'foobar': '{% raw %}{{foobar}}{% endraw %}'} + "foobar", + "{% raw %}{{foo}}{% endraw %}", + "{% raw %}{{bar}}{% endraw %}", + "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}", + {"foobar": "{% raw %}{{foobar}}{% endraw %}"}, ] - self.assertListEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + self.assertListEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) def test_convert_dict_to_raw(self): jinja_expr = { - 'var1': 'foobar', - 'var2': ['{{foo}}', '{{bar}}'], - 'var3': {'foobar': '{{foobar}}'}, - 'var4': {'foobar': '{% for item in items %}foobar{% end for %}'} + "var1": "foobar", + "var2": ["{{foo}}", "{{bar}}"], + "var3": {"foobar": "{{foobar}}"}, + "var4": {"foobar": "{% for item in items %}foobar{% end for %}"}, } expected_raw_block = { - 'var1': 'foobar', - 'var2': [ - '{% raw %}{{foo}}{% endraw %}', - '{% raw %}{{bar}}{% endraw %}' - ], - 'var3': { - 'foobar': '{% raw %}{{foobar}}{% endraw %}' + "var1": "foobar", + "var2": ["{% raw %}{{foo}}{% endraw %}", "{% raw %}{{bar}}{% endraw %}"], + "var3": {"foobar": "{% raw %}{{foobar}}{% endraw %}"}, + "var4": { + "foobar": "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}" }, - 'var4': { - 'foobar': '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}' - } } - self.assertDictEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + self.assertDictEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) diff --git a/st2common/tests/unit/test_util_keyvalue.py b/st2common/tests/unit/test_util_keyvalue.py index 07f061e60a..5a8c15a3f7 100644 --- a/st2common/tests/unit/test_util_keyvalue.py +++ b/st2common/tests/unit/test_util_keyvalue.py @@ -18,14 +18,19 @@ import unittest2 from st2common.util import keyvalue as kv_utl -from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE, - ALL_SCOPE, DATASTORE_PARENT_SCOPE, - DATASTORE_SCOPE_SEPARATOR) +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + USER_SCOPE, + ALL_SCOPE, + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, +) from st2common.exceptions.rbac import AccessDeniedError from st2common.models.db import auth as auth_db -USER = 'stanley' +USER = "stanley" class TestKeyValueUtil(unittest2.TestCase): @@ -38,48 +43,26 @@ def test_validate_scope(self): kv_utl._validate_scope(scope) def test_validate_scope_with_invalid_scope(self): - scope = 'INVALID_SCOPE' + scope = "INVALID_SCOPE" self.assertRaises(ValueError, kv_utl._validate_scope, scope) def test_validate_decrypt_query_parameter(self): test_params = [ - [ - False, - USER_SCOPE, - False, - {} - ], - [ - True, - USER_SCOPE, - False, - {} - ], - [ - True, - FULL_SYSTEM_SCOPE, - True, - {} - ], + [False, USER_SCOPE, False, {}], + [True, USER_SCOPE, False, {}], + [True, FULL_SYSTEM_SCOPE, True, {}], ] for params in test_params: kv_utl._validate_decrypt_query_parameter(*params) def test_validate_decrypt_query_parameter_access_denied(self): - test_params = [ - [ - True, - FULL_SYSTEM_SCOPE, - False, - {} - ] - ] + test_params = [[True, FULL_SYSTEM_SCOPE, False, {}]] for params in test_params: assert_params = [ AccessDeniedError, - kv_utl._validate_decrypt_query_parameter + kv_utl._validate_decrypt_query_parameter, ] assert_params.extend(params) @@ -88,81 +71,58 @@ def test_validate_decrypt_query_parameter_access_denied(self): def test_get_datastore_full_scope(self): self.assertEqual( kv_utl.get_datastore_full_scope(USER_SCOPE), - DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE]) + DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE]), ) def test_get_datastore_full_scope_all_scope(self): - self.assertEqual( - kv_utl.get_datastore_full_scope(ALL_SCOPE), - ALL_SCOPE - ) + self.assertEqual(kv_utl.get_datastore_full_scope(ALL_SCOPE), ALL_SCOPE) def test_get_datastore_full_scope_datastore_parent_scope(self): self.assertEqual( kv_utl.get_datastore_full_scope(DATASTORE_PARENT_SCOPE), - DATASTORE_PARENT_SCOPE + DATASTORE_PARENT_SCOPE, ) def test_derive_scope_and_key(self): - key = 'test' + key = "test" scope = USER_SCOPE result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_USER_SCOPE, 'user:%s' % key), - result - ) + self.assertEqual((FULL_USER_SCOPE, "user:%s" % key), result) def test_derive_scope_and_key_without_scope(self): - key = 'test' + key = "test" scope = None result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_USER_SCOPE, 'None:%s' % key), - result - ) + self.assertEqual((FULL_USER_SCOPE, "None:%s" % key), result) def test_derive_scope_and_key_system_key(self): - key = 'system.test' + key = "system.test" scope = None result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_SYSTEM_SCOPE, key.split('.')[1]), - result - ) + self.assertEqual((FULL_SYSTEM_SCOPE, key.split(".")[1]), result) - @mock.patch('st2common.util.keyvalue.KeyValuePair') - @mock.patch('st2common.util.keyvalue.deserialize_key_value') + @mock.patch("st2common.util.keyvalue.KeyValuePair") + @mock.patch("st2common.util.keyvalue.deserialize_key_value") def test_get_key(self, deseralize_key_value, KeyValuePair): - key, value = ('Lindsay', 'Lohan') + key, value = ("Lindsay", "Lohan") decrypt = False KeyValuePair.get_by_scope_and_name().value = value deseralize_key_value.return_value = value - result = kv_utl.get_key(key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt) + result = kv_utl.get_key( + key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt + ) self.assertEqual(result, value) KeyValuePair.get_by_scope_and_name.assert_called_with( - FULL_USER_SCOPE, - 'stanley:%s' % key - ) - deseralize_key_value.assert_called_once_with( - value, - decrypt + FULL_USER_SCOPE, "stanley:%s" % key ) + deseralize_key_value.assert_called_once_with(value, decrypt) def test_get_key_invalid_input(self): - self.assertRaises( - TypeError, - kv_utl.get_key, - key=1 - ) - self.assertRaises( - TypeError, - kv_utl.get_key, - key='test', - decrypt='yep' - ) + self.assertRaises(TypeError, kv_utl.get_key, key=1) + self.assertRaises(TypeError, kv_utl.get_key, key="test", decrypt="yep") diff --git a/st2common/tests/unit/test_util_output_schema.py b/st2common/tests/unit/test_util_output_schema.py index d3ef387a26..af9570d4fa 100644 --- a/st2common/tests/unit/test_util_output_schema.py +++ b/st2common/tests/unit/test_util_output_schema.py @@ -19,58 +19,46 @@ from st2common.constants.action import ( LIVEACTION_STATUS_SUCCEEDED, - LIVEACTION_STATUS_FAILED + LIVEACTION_STATUS_FAILED, ) ACTION_RESULT = { - 'output': { - 'output_1': 'Bobby', - 'output_2': 5, - 'deep_output': { - 'deep_item_1': 'Jindal', + "output": { + "output_1": "Bobby", + "output_2": 5, + "deep_output": { + "deep_item_1": "Jindal", }, } } RUNNER_SCHEMA = { - 'output': { - 'type': 'object' - }, - 'error': { - 'type': 'array' - }, + "output": {"type": "object"}, + "error": {"type": "array"}, } ACTION_SCHEMA = { - 'output_1': { - 'type': 'string' - }, - 'output_2': { - 'type': 'integer' - }, - 'deep_output': { - 'type': 'object', - 'parameters': { - 'deep_item_1': { - 'type': 'string', + "output_1": {"type": "string"}, + "output_2": {"type": "integer"}, + "deep_output": { + "type": "object", + "parameters": { + "deep_item_1": { + "type": "string", }, }, }, } RUNNER_SCHEMA_FAIL = { - 'not_a_key_you_have': { - 'type': 'string' - }, + "not_a_key_you_have": {"type": "string"}, } ACTION_SCHEMA_FAIL = { - 'not_a_key_you_have': { - 'type': 'string' - }, + "not_a_key_you_have": {"type": "string"}, } -OUTPUT_KEY = 'output' +OUTPUT_KEY = "output" class OutputSchemaTestCase(unittest2.TestCase): @@ -96,7 +84,7 @@ def test_invalid_runner_schema(self): ) expected_result = { - 'error': ( + "error": ( "Additional properties are not allowed ('output' was unexpected)" "\n\nFailed validating 'additionalProperties' in schema:\n {'addi" "tionalProperties': False,\n 'properties': {'not_a_key_you_have': " @@ -104,7 +92,7 @@ def test_invalid_runner_schema(self): "output': {'deep_output': {'deep_item_1': 'Jindal'},\n " "'output_1': 'Bobby',\n 'output_2': 5}}" ), - 'message': 'Error validating output. See error output for more details.' + "message": "Error validating output. See error output for more details.", } self.assertEqual(result, expected_result) @@ -120,12 +108,12 @@ def test_invalid_action_schema(self): ) expected_result = { - 'error': "Additional properties are not allowed", - 'message': u'Error validating output. See error output for more details.' + "error": "Additional properties are not allowed", + "message": "Error validating output. See error output for more details.", } # To avoid random failures (especially in python3) this assert cant be # exact since the parameters can be ordered differently per execution. - self.assertIn(expected_result['error'], result['error']) - self.assertEqual(result['message'], expected_result['message']) + self.assertIn(expected_result["error"], result["error"]) + self.assertEqual(result["message"], expected_result["message"]) self.assertEqual(status, LIVEACTION_STATUS_FAILED) diff --git a/st2common/tests/unit/test_util_pack.py b/st2common/tests/unit/test_util_pack.py index 20522b8b18..0b476b7336 100644 --- a/st2common/tests/unit/test_util_pack.py +++ b/st2common/tests/unit/test_util_pack.py @@ -22,59 +22,47 @@ class PackUtilsTestCase(unittest2.TestCase): - def test_get_pack_common_libs_path_for_pack_db(self): pack_model_args = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen', - 'path': '/opt/stackstorm/packs/yolo_ci/' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", + "path": "/opt/stackstorm/packs/yolo_ci/", } pack_db = PackDB(**pack_model_args) lib_path = get_pack_common_libs_path_for_pack_db(pack_db) - self.assertEqual('/opt/stackstorm/packs/yolo_ci/lib', lib_path) + self.assertEqual("/opt/stackstorm/packs/yolo_ci/lib", lib_path) def test_get_pack_common_libs_path_for_pack_db_no_path_in_pack_db(self): pack_model_args = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", } pack_db = PackDB(**pack_model_args) lib_path = get_pack_common_libs_path_for_pack_db(pack_db) self.assertEqual(None, lib_path) def test_get_pack_warnings_python2_only(self): - pack_metadata = { - 'python_versions': ['2'], - 'name': 'Pack2' - } + pack_metadata = {"python_versions": ["2"], "name": "Pack2"} warning = get_pack_warnings(pack_metadata) self.assertTrue("DEPRECATION WARNING" in warning) def test_get_pack_warnings_python3_only(self): - pack_metadata = { - 'python_versions': ['3'], - 'name': 'Pack3' - } + pack_metadata = {"python_versions": ["3"], "name": "Pack3"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) def test_get_pack_warnings_python2_and_3(self): - pack_metadata = { - 'python_versions': ['2', '3'], - 'name': 'Pack23' - } + pack_metadata = {"python_versions": ["2", "3"], "name": "Pack23"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) def test_get_pack_warnings_no_python(self): - pack_metadata = { - 'name': 'PackNone' - } + pack_metadata = {"name": "PackNone"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) diff --git a/st2common/tests/unit/test_util_payload.py b/st2common/tests/unit/test_util_payload.py index 207d4c1766..2621e3de91 100644 --- a/st2common/tests/unit/test_util_payload.py +++ b/st2common/tests/unit/test_util_payload.py @@ -19,27 +19,31 @@ from st2common.util.payload import PayloadLookup -__all__ = [ - 'PayloadLookupTestCase' -] +__all__ = ["PayloadLookupTestCase"] class PayloadLookupTestCase(unittest2.TestCase): @classmethod def setUpClass(cls): - cls.payload = PayloadLookup({ - 'pikachu': "Has no ears", - 'charmander': "Plays with fire", - }) + cls.payload = PayloadLookup( + { + "pikachu": "Has no ears", + "charmander": "Plays with fire", + } + ) super(PayloadLookupTestCase, cls).setUpClass() def test_get_key(self): - self.assertEqual(self.payload.get_value('trigger.pikachu'), ["Has no ears"]) - self.assertEqual(self.payload.get_value('trigger.charmander'), ["Plays with fire"]) + self.assertEqual(self.payload.get_value("trigger.pikachu"), ["Has no ears"]) + self.assertEqual( + self.payload.get_value("trigger.charmander"), ["Plays with fire"] + ) def test_explicitly_get_multiple_keys(self): - self.assertEqual(self.payload.get_value('trigger.pikachu[*]'), ["Has no ears"]) - self.assertEqual(self.payload.get_value('trigger.charmander[*]'), ["Plays with fire"]) + self.assertEqual(self.payload.get_value("trigger.pikachu[*]"), ["Has no ears"]) + self.assertEqual( + self.payload.get_value("trigger.charmander[*]"), ["Plays with fire"] + ) def test_get_nonexistent_key(self): - self.assertIsNone(self.payload.get_value('trigger.squirtle')) + self.assertIsNone(self.payload.get_value("trigger.squirtle")) diff --git a/st2common/tests/unit/test_util_sandboxing.py b/st2common/tests/unit/test_util_sandboxing.py index 5f387e0067..3926c9f74c 100644 --- a/st2common/tests/unit/test_util_sandboxing.py +++ b/st2common/tests/unit/test_util_sandboxing.py @@ -32,9 +32,7 @@ import st2tests.config as tests_config -__all__ = [ - 'SandboxingUtilsTestCase' -] +__all__ = ["SandboxingUtilsTestCase"] class SandboxingUtilsTestCase(unittest.TestCase): @@ -69,8 +67,10 @@ def assertEndsWith(self, string, ending_substr, msg=None): def test_get_sandbox_python_binary_path(self): # Non-system content pack, should use pack specific virtualenv binary - result = get_sandbox_python_binary_path(pack='mapack') - expected = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/mapack/bin/python') + result = get_sandbox_python_binary_path(pack="mapack") + expected = os.path.join( + cfg.CONF.system.base_path, "virtualenvs/mapack/bin/python" + ) self.assertEqual(result, expected) # System content pack, should use current process (system) python binary @@ -78,159 +78,190 @@ def test_get_sandbox_python_binary_path(self): self.assertEqual(result, sys.executable) def test_get_sandbox_path(self): - virtualenv_path = '/home/venv/test' + virtualenv_path = "/home/venv/test" # Mock the current PATH value - with mock.patch.dict(os.environ, {'PATH': '/home/path1:/home/path2:/home/path3:'}): + with mock.patch.dict( + os.environ, {"PATH": "/home/path1:/home/path2:/home/path3:"} + ): result = get_sandbox_path(virtualenv_path=virtualenv_path) - self.assertEqual(result, f'{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3') + self.assertEqual( + result, f"{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3" + ) - @mock.patch('st2common.util.sandboxing.get_python_lib') + @mock.patch("st2common.util.sandboxing.get_python_lib") def test_get_sandbox_python_path(self, mock_get_python_lib): # No inheritance - python_path = get_sandbox_python_path(inherit_from_parent=False, - inherit_parent_virtualenv=False) - self.assertEqual(python_path, ':') + python_path = get_sandbox_python_path( + inherit_from_parent=False, inherit_parent_virtualenv=False + ) + self.assertEqual(python_path, ":") # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") # Inherit from current process and from virtualenv (not running inside virtualenv) clear_virtualenv_prefix() - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") # Inherit from current process and from virtualenv (running inside virtualenv) - sys.real_prefix = '/usr' - mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest' + sys.real_prefix = "/usr" + mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest" - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=True) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=True + ) - self.assertEqual(python_path, f':/data/test1:/data/test2:{sys.prefix}/virtualenvtest') + self.assertEqual( + python_path, f":/data/test1:/data/test2:{sys.prefix}/virtualenvtest" + ) - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_no_inheritance(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_no_inheritance( + self, mock_get_python_lib + ): # No inheritance - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=False, - inherit_parent_virtualenv=False) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=False, + inherit_parent_virtualenv=False, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 3) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only( + self, mock_get_python_lib + ): # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=False) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=False, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 6) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # And the rest of the paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv( + self, mock_get_python_lib + ): # Inherit from current process and from virtualenv (not running inside virtualenv) clear_virtualenv_prefix() # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=True) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=True, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 6) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # And the rest of the paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") # Inherit from current process and from virtualenv (running inside virtualenv) - sys.real_prefix = '/usr' - mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest' + sys.real_prefix = "/usr" + mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest" # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=True) - - actual_path = python_path.strip(':').split(':') + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=True, + ) + + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 7) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # The paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") # And the parent virtualenv - self.assertEqual(actual_path[6], f'{sys.prefix}/virtualenvtest') + self.assertEqual(actual_path[6], f"{sys.prefix}/virtualenvtest") diff --git a/st2common/tests/unit/test_util_secrets.py b/st2common/tests/unit/test_util_secrets.py index f49f8f76a9..8c77c34f49 100644 --- a/st2common/tests/unit/test_util_secrets.py +++ b/st2common/tests/unit/test_util_secrets.py @@ -22,38 +22,30 @@ ################################################################################ TEST_FLAT_SCHEMA = { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string', - 'secret': False + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", + "secret": False, }, - 'arg_optional_no_type_secret': { - 'description': 'Bar', - 'secret': True - }, - 'arg_optional_type_array': { - 'description': 'Who''s the fairest?', - 'type': 'array' - }, - 'arg_optional_type_object': { - 'description': 'Who''s the fairest of them?', - 'type': 'object' + "arg_optional_no_type_secret": {"description": "Bar", "secret": True}, + "arg_optional_type_array": {"description": "Who" "s the fairest?", "type": "array"}, + "arg_optional_type_object": { + "description": "Who" "s the fairest of them?", + "type": "object", }, } -TEST_FLAT_SECRET_PARAMS = { - 'arg_optional_no_type_secret': None -} +TEST_FLAT_SECRET_PARAMS = {"arg_optional_no_type_secret": None} ################################################################################ TEST_NO_SECRETS_SCHEMA = { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string', - 'secret': False + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", + "secret": False, } } @@ -62,497 +54,397 @@ ################################################################################ TEST_NESTED_OBJECTS_SCHEMA = { - 'arg_string': { - 'description': 'Junk', - 'type': 'string', + "arg_string": { + "description": "Junk", + "type": "string", }, - 'arg_optional_object': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_optional_object": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, } TEST_NESTED_OBJECTS_SECRET_PARAMS = { - 'arg_optional_object': { - 'arg_nested_secret': 'string', - 'arg_nested_object': { - 'arg_double_nested_secret': 'string', - } + "arg_optional_object": { + "arg_nested_secret": "string", + "arg_nested_object": { + "arg_double_nested_secret": "string", + }, } } ################################################################################ TEST_ARRAY_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'down', - 'type': 'string', - 'secret': True - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "down", "type": "string", "secret": True}, } } -TEST_ARRAY_SECRET_PARAMS = { - 'arg_optional_array': [ - 'string' - ] -} +TEST_ARRAY_SECRET_PARAMS = {"arg_optional_array": ["string"]} ################################################################################ TEST_ROOT_ARRAY_SCHEMA = { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } - } + "description": "Mirror", + "type": "array", + "items": { + "description": "down", + "type": "object", + "properties": {"secret_field_in_object": {"type": "string", "secret": True}}, + }, } -TEST_ROOT_ARRAY_SECRET_PARAMS = [ - { - 'secret_field_in_object': 'string' - } -] +TEST_ROOT_ARRAY_SECRET_PARAMS = [{"secret_field_in_object": "string"}] ################################################################################ TEST_ROOT_OBJECT_SCHEMA = { - 'description': 'root', - 'type': 'object', - 'properties': { - 'arg_level_one': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } + "description": "root", + "type": "object", + "properties": { + "arg_level_one": { + "description": "down", + "type": "object", + "properties": { + "secret_field_in_object": {"type": "string", "secret": True} + }, } - } + }, } -TEST_ROOT_OBJECT_SECRET_PARAMS = { - 'arg_level_one': - { - 'secret_field_in_object': 'string' - } -} +TEST_ROOT_OBJECT_SECRET_PARAMS = {"arg_level_one": {"secret_field_in_object": "string"}} ################################################################################ TEST_NESTED_ARRAYS_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, }, - 'arg_optional_double_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_optional_double_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, + }, }, - 'arg_optional_tripple_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_optional_tripple_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, + }, + }, + }, + "arg_optional_quad_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, + }, }, - 'arg_optional_quad_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } - } - } } TEST_NESTED_ARRAYS_SECRET_PARAMS = { - 'arg_optional_array': [ - 'string' - ], - 'arg_optional_double_array': [ - [ - 'string' - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'string' - ] - ] - ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'string' - ] - ] - ] - ] + "arg_optional_array": ["string"], + "arg_optional_double_array": [["string"]], + "arg_optional_tripple_array": [[["string"]]], + "arg_optional_quad_array": [[[["string"]]]], } ################################################################################ TEST_NESTED_OBJECT_WITH_ARRAY_SCHEMA = { - 'arg_optional_object_with_array': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } + "arg_optional_object_with_array": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, } - } + }, } } TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ - 'string' - ] - } + "arg_optional_object_with_array": {"arg_nested_array": ["string"]} } ################################################################################ TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA = { - 'arg_optional_object_with_double_array': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_double_nested_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_optional_object_with_double_array": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_double_nested_array": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, } - } + }, } } TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ - [ - 'string' - ] - ] - } + "arg_optional_object_with_double_array": {"arg_double_nested_array": [["string"]]} } ################################################################################ TEST_NESTED_ARRAY_WITH_OBJECT_SCHEMA = { - 'arg_optional_array_with_object': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True + "arg_optional_array_with_object": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, } - } - } + }, + }, } } TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': 'string' - } - ] + "arg_optional_array_with_object": [{"arg_nested_secret": "string"}] } ################################################################################ TEST_SECRET_ARRAY_SCHEMA = { - 'arg_secret_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, + "arg_secret_array": { + "description": "Mirror", + "type": "array", + "secret": True, } } -TEST_SECRET_ARRAY_SECRET_PARAMS = { - 'arg_secret_array': 'array' -} +TEST_SECRET_ARRAY_SECRET_PARAMS = {"arg_secret_array": "array"} ################################################################################ TEST_SECRET_OBJECT_SCHEMA = { - 'arg_secret_object': { - 'type': 'object', - 'secret': True, + "arg_secret_object": { + "type": "object", + "secret": True, } } -TEST_SECRET_OBJECT_SECRET_PARAMS = { - 'arg_secret_object': 'object' -} +TEST_SECRET_OBJECT_SECRET_PARAMS = {"arg_secret_object": "object"} ################################################################################ TEST_SECRET_ROOT_ARRAY_SCHEMA = { - 'description': 'secret array', - 'type': 'array', - 'secret': True, - 'items': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } - } + "description": "secret array", + "type": "array", + "secret": True, + "items": { + "description": "down", + "type": "object", + "properties": {"secret_field_in_object": {"type": "string", "secret": True}}, + }, } -TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = 'array' +TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = "array" ################################################################################ TEST_SECRET_ROOT_OBJECT_SCHEMA = { - 'description': 'secret object', - 'type': 'object', - 'secret': True, - 'proeprteis': { - 'arg_level_one': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } + "description": "secret object", + "type": "object", + "secret": True, + "proeprteis": { + "arg_level_one": { + "description": "down", + "type": "object", + "properties": { + "secret_field_in_object": {"type": "string", "secret": True} + }, } - } + }, } -TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = 'object' +TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = "object" ################################################################################ TEST_SECRET_NESTED_OBJECTS_SCHEMA = { - 'arg_object': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_object": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "secret": True, + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, }, - 'arg_secret_object': { - 'description': 'Mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_secret_object": { + "description": "Mirror", + "type": "object", + "secret": True, + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "secret": True, + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, } TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS = { - 'arg_object': { - 'arg_nested_secret': 'string', - 'arg_nested_object': 'object' - }, - 'arg_secret_object': 'object' + "arg_object": {"arg_nested_secret": "string", "arg_nested_object": "object"}, + "arg_secret_object": "object", } ################################################################################ TEST_SECRET_NESTED_ARRAYS_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, - 'items': { - 'description': 'Deep down', - 'type': 'string' - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "secret": True, + "items": {"description": "Deep down", "type": "string"}, }, - 'arg_optional_double_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } + "arg_optional_double_array": { + "description": "Mirror", + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, }, - 'arg_optional_tripple_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } - } + "arg_optional_tripple_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, + }, + }, + "arg_optional_quad_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, + }, + }, }, - 'arg_optional_quad_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } - } - } - } } TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS = { - 'arg_optional_array': 'array', - 'arg_optional_double_array': 'array', - 'arg_optional_tripple_array': [ - 'array' - ], - 'arg_optional_quad_array': [ - [ - 'array' - ] - ] + "arg_optional_array": "array", + "arg_optional_double_array": "array", + "arg_optional_tripple_array": ["array"], + "arg_optional_quad_array": [["array"]], } ################################################################################ class SecretUtilsTestCase(unittest2.TestCase): - def test_get_secret_parameters_flat(self): result = secrets.get_secret_parameters(TEST_FLAT_SCHEMA) self.assertEqual(TEST_FLAT_SECRET_PARAMS, result) @@ -586,7 +478,9 @@ def test_get_secret_parameters_nested_object_with_array(self): self.assertEqual(TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS, result) def test_get_secret_parameters_nested_object_with_double_array(self): - result = secrets.get_secret_parameters(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA) + result = secrets.get_secret_parameters( + TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA + ) self.assertEqual(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS, result) def test_get_secret_parameters_nested_array_with_object(self): @@ -621,178 +515,128 @@ def test_get_secret_parameters_secret_nested_objects(self): def test_mask_secret_parameters_flat(self): parameters = { - 'arg_required_no_default': 'test', - 'arg_optional_no_type_secret': None + "arg_required_no_default": "test", + "arg_optional_no_type_secret": None, } - result = secrets.mask_secret_parameters(parameters, - TEST_FLAT_SECRET_PARAMS) + result = secrets.mask_secret_parameters(parameters, TEST_FLAT_SECRET_PARAMS) expected = { - 'arg_required_no_default': 'test', - 'arg_optional_no_type_secret': MASKED_ATTRIBUTE_VALUE + "arg_required_no_default": "test", + "arg_optional_no_type_secret": MASKED_ATTRIBUTE_VALUE, } self.assertEqual(expected, result) def test_mask_secret_parameters_no_secrets(self): - parameters = {'arg_required_no_default': 'junk'} - result = secrets.mask_secret_parameters(parameters, - TEST_NO_SECRETS_SECRET_PARAMS) - expected = { - 'arg_required_no_default': 'junk' - } + parameters = {"arg_required_no_default": "junk"} + result = secrets.mask_secret_parameters( + parameters, TEST_NO_SECRETS_SECRET_PARAMS + ) + expected = {"arg_required_no_default": "junk"} self.assertEqual(expected, result) def test_mask_secret_parameters_nested_objects(self): parameters = { - 'arg_optional_object': { - 'arg_nested_secret': 'nested Secret', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } + "arg_optional_object": { + "arg_nested_secret": "nested Secret", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECTS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECTS_SECRET_PARAMS + ) expected = { - 'arg_optional_object': { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE, - 'arg_nested_object': { - 'arg_double_nested_secret': MASKED_ATTRIBUTE_VALUE, - } + "arg_optional_object": { + "arg_nested_secret": MASKED_ATTRIBUTE_VALUE, + "arg_nested_object": { + "arg_double_nested_secret": MASKED_ATTRIBUTE_VALUE, + }, } } self.assertEqual(expected, result) def test_mask_secret_parameters_array(self): parameters = { - 'arg_optional_array': [ - '$ecret $tring 1', - '$ecret $tring 2', - '$ecret $tring 3' + "arg_optional_array": [ + "$ecret $tring 1", + "$ecret $tring 2", + "$ecret $tring 3", ] } - result = secrets.mask_secret_parameters(parameters, - TEST_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters(parameters, TEST_ARRAY_SECRET_PARAMS) expected = { - 'arg_optional_array': [ + "arg_optional_array": [ + MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE ] } self.assertEqual(expected, result) def test_mask_secret_parameters_root_array(self): parameters = [ - { - 'secret_field_in_object': 'Secret $tr!ng' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 2' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 3' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 4' - } + {"secret_field_in_object": "Secret $tr!ng"}, + {"secret_field_in_object": "Secret $tr!ng 2"}, + {"secret_field_in_object": "Secret $tr!ng 3"}, + {"secret_field_in_object": "Secret $tr!ng 4"}, ] - result = secrets.mask_secret_parameters(parameters, TEST_ROOT_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_ROOT_ARRAY_SECRET_PARAMS + ) expected = [ - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - } + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, ] self.assertEqual(expected, result) def test_mask_secret_parameters_root_object(self): - parameters = { - 'arg_level_one': - { - 'secret_field_in_object': 'Secret $tr!ng' - } - } + parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}} - result = secrets.mask_secret_parameters(parameters, TEST_ROOT_OBJECT_SECRET_PARAMS) - expected = { - 'arg_level_one': - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - } - } + result = secrets.mask_secret_parameters( + parameters, TEST_ROOT_OBJECT_SECRET_PARAMS + ) + expected = {"arg_level_one": {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}} self.assertEqual(expected, result) def test_mask_secret_parameters_nested_arrays(self): parameters = { - 'arg_optional_array': [ - 'secret 1', - 'secret 2', - 'secret 3', + "arg_optional_array": [ + "secret 1", + "secret 2", + "secret 3", ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 7', - 'secret 8', - 'secret 9', - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'secret 10', - 'secret 11' - ], - [ - 'secret 12', - 'secret 13', - 'secret 14' - ] + "secret 7", + "secret 8", + "secret 9", ], - [ - [ - 'secret 15', - 'secret 16' - ] - ] ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'secret 17', - 'secret 18' - ], - [ - 'secret 19' - ] - ] - ] - ] + "arg_optional_tripple_array": [ + [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]], + [["secret 15", "secret 16"]], + ], + "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]], } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_ARRAYS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_ARRAYS_SECRET_PARAMS + ) expected = { - 'arg_optional_array': [ + "arg_optional_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, @@ -802,58 +646,46 @@ def test_mask_secret_parameters_nested_arrays(self): MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - ] + ], ], - 'arg_optional_tripple_array': [ + "arg_optional_tripple_array": [ [ + [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE], [ MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ], - [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ] + ], ], - [ - [ - MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ] - ] + [[MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE]], ], - 'arg_optional_quad_array': [ + "arg_optional_quad_array": [ [ [ - [ - MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ], - [ - MASKED_ATTRIBUTE_VALUE - ] + [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE], + [MASKED_ATTRIBUTE_VALUE], ] ] - ] + ], } self.assertEqual(expected, result) def test_mask_secret_parameters_nested_object_with_array(self): parameters = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ - 'secret array value 1', - 'secret array value 2', - 'secret array value 3', + "arg_optional_object_with_array": { + "arg_nested_array": [ + "secret array value 1", + "secret array value 2", + "secret array value 3", ] } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS + ) expected = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ + "arg_optional_object_with_array": { + "arg_nested_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, @@ -864,36 +696,33 @@ def test_mask_secret_parameters_nested_object_with_array(self): def test_mask_secret_parameters_nested_object_with_double_array(self): parameters = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ + "arg_optional_object_with_double_array": { + "arg_double_nested_array": [ + ["secret 1", "secret 2", "secret 3"], [ - 'secret 1', - 'secret 2', - 'secret 3' + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 7", + "secret 8", + "secret 9", + "secret 10", ], - [ - 'secret 7', - 'secret 8', - 'secret 9', - 'secret 10', - ] ] } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS + ) expected = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ + "arg_optional_object_with_double_array": { + "arg_double_nested_array": [ [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE + MASKED_ATTRIBUTE_VALUE, ], [ MASKED_ATTRIBUTE_VALUE, @@ -905,7 +734,7 @@ def test_mask_secret_parameters_nested_object_with_double_array(self): MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - ] + ], ] } } @@ -913,187 +742,132 @@ def test_mask_secret_parameters_nested_object_with_double_array(self): def test_mask_secret_parameters_nested_array_with_object(self): parameters = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': 'secret 1' - }, - { - 'arg_nested_secret': 'secret 2' - }, - { - 'arg_nested_secret': 'secret 3' - } + "arg_optional_array_with_object": [ + {"arg_nested_secret": "secret 1"}, + {"arg_nested_secret": "secret 2"}, + {"arg_nested_secret": "secret 3"}, ] } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS + ) expected = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - }, - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - }, - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - } + "arg_optional_array_with_object": [ + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, ] } self.assertEqual(expected, result) def test_mask_secret_parameters_secret_array(self): - parameters = { - 'arg_secret_array': [ - "abc", - 123, - True - ] - } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ARRAY_SECRET_PARAMS) - expected = { - 'arg_secret_array': MASKED_ATTRIBUTE_VALUE - } + parameters = {"arg_secret_array": ["abc", 123, True]} + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ARRAY_SECRET_PARAMS + ) + expected = {"arg_secret_array": MASKED_ATTRIBUTE_VALUE} self.assertEqual(expected, result) def test_mask_secret_parameters_secret_object(self): parameters = { - 'arg_secret_object': - { + "arg_secret_object": { "abc": 123, "key": "value", "bool": True, "array": ["x", "y", "z"], - "obj": - { - "x": "deep" - } + "obj": {"x": "deep"}, } } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_OBJECT_SECRET_PARAMS) - expected = { - 'arg_secret_object': MASKED_ATTRIBUTE_VALUE - } + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_OBJECT_SECRET_PARAMS + ) + expected = {"arg_secret_object": MASKED_ATTRIBUTE_VALUE} self.assertEqual(expected, result) def test_mask_secret_parameters_secret_root_array(self): - parameters = [ - "abc", - 123, - True - ] - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS) + parameters = ["abc", 123, True] + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS + ) expected = MASKED_ATTRIBUTE_VALUE self.assertEqual(expected, result) def test_mask_secret_parameters_secret_root_object(self): - parameters = { - 'arg_level_one': - { - 'secret_field_in_object': 'Secret $tr!ng' - } - } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS) + parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}} + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS + ) expected = MASKED_ATTRIBUTE_VALUE self.assertEqual(expected, result) def test_mask_secret_parameters_secret_nested_arrays(self): parameters = { - 'arg_optional_array': [ - 'secret 1', - 'secret 2', - 'secret 3', + "arg_optional_array": [ + "secret 1", + "secret 2", + "secret 3", ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 7', - 'secret 8', - 'secret 9', - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'secret 10', - 'secret 11' - ], - [ - 'secret 12', - 'secret 13', - 'secret 14' - ] + "secret 7", + "secret 8", + "secret 9", ], - [ - [ - 'secret 15', - 'secret 16' - ] - ] ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'secret 17', - 'secret 18' - ], - [ - 'secret 19' - ] - ] - ] - ] + "arg_optional_tripple_array": [ + [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]], + [["secret 15", "secret 16"]], + ], + "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]], } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS + ) expected = { - 'arg_optional_array': MASKED_ATTRIBUTE_VALUE, - 'arg_optional_double_array': MASKED_ATTRIBUTE_VALUE, - 'arg_optional_tripple_array': [ + "arg_optional_array": MASKED_ATTRIBUTE_VALUE, + "arg_optional_double_array": MASKED_ATTRIBUTE_VALUE, + "arg_optional_tripple_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, ], - 'arg_optional_quad_array': [ + "arg_optional_quad_array": [ [ MASKED_ATTRIBUTE_VALUE, ] - ] + ], } self.assertEqual(expected, result) def test_mask_secret_parameters_secret_nested_objects(self): parameters = { - 'arg_object': { - 'arg_nested_secret': 'nested Secret', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } + "arg_object": { + "arg_nested_secret": "nested Secret", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, + }, + "arg_secret_object": { + "arg_nested_secret": "secret data", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, }, - 'arg_secret_object': { - 'arg_nested_secret': 'secret data', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } - } } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS + ) expected = { - 'arg_object': { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE, - 'arg_nested_object': MASKED_ATTRIBUTE_VALUE, + "arg_object": { + "arg_nested_secret": MASKED_ATTRIBUTE_VALUE, + "arg_nested_object": MASKED_ATTRIBUTE_VALUE, }, - 'arg_secret_object': MASKED_ATTRIBUTE_VALUE, + "arg_secret_object": MASKED_ATTRIBUTE_VALUE, } self.assertEqual(expected, result) diff --git a/st2common/tests/unit/test_util_shell.py b/st2common/tests/unit/test_util_shell.py index 86c37f2ad1..4a2a00e343 100644 --- a/st2common/tests/unit/test_util_shell.py +++ b/st2common/tests/unit/test_util_shell.py @@ -23,38 +23,26 @@ class ShellUtilsTestCase(unittest2.TestCase): def test_quote_unix(self): - arguments = [ - 'foo', - 'foo bar', - 'foo1 bar1', - '"foo"', - '"foo" "bar"', - "'foo bar'" - ] + arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"] expected_values = [ """ foo """, - """ 'foo bar' """, - """ 'foo1 bar1' """, - """ '"foo"' """, - """ '"foo" "bar"' """, - """ ''"'"'foo bar'"'"'' - """ + """, ] for argument, expected_value in zip(arguments, expected_values): @@ -63,38 +51,26 @@ def test_quote_unix(self): self.assertEqual(actual_value, expected_value.strip()) def test_quote_windows(self): - arguments = [ - 'foo', - 'foo bar', - 'foo1 bar1', - '"foo"', - '"foo" "bar"', - "'foo bar'" - ] + arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"] expected_values = [ """ foo """, - """ "foo bar" """, - """ "foo1 bar1" """, - """ \\"foo\\" """, - """ "\\"foo\\" \\"bar\\"" """, - """ "'foo bar'" - """ + """, ] for argument, expected_value in zip(arguments, expected_values): diff --git a/st2common/tests/unit/test_util_templating.py b/st2common/tests/unit/test_util_templating.py index 1756590bc1..c6cd539849 100644 --- a/st2common/tests/unit/test_util_templating.py +++ b/st2common/tests/unit/test_util_templating.py @@ -26,41 +26,45 @@ def setUp(self): super(TemplatingUtilsTestCase, self).setUp() # Insert mock DB objects - kvp_1_db = KeyValuePairDB(name='key1', value='valuea') + kvp_1_db = KeyValuePairDB(name="key1", value="valuea") kvp_1_db = KeyValuePair.add_or_update(kvp_1_db) - kvp_2_db = KeyValuePairDB(name='key2', value='valueb') + kvp_2_db = KeyValuePairDB(name="key2", value="valueb") kvp_2_db = KeyValuePair.add_or_update(kvp_2_db) - kvp_3_db = KeyValuePairDB(name='stanley:key1', value='valuestanley1', scope=FULL_USER_SCOPE) + kvp_3_db = KeyValuePairDB( + name="stanley:key1", value="valuestanley1", scope=FULL_USER_SCOPE + ) kvp_3_db = KeyValuePair.add_or_update(kvp_3_db) - kvp_4_db = KeyValuePairDB(name='joe:key1', value='valuejoe1', scope=FULL_USER_SCOPE) + kvp_4_db = KeyValuePairDB( + name="joe:key1", value="valuejoe1", scope=FULL_USER_SCOPE + ) kvp_4_db = KeyValuePair.add_or_update(kvp_4_db) def test_render_template_with_system_and_user_context(self): # 1. No reference to the user inside the template - template = '{{st2kv.system.key1}}' - user = 'stanley' + template = "{{st2kv.system.key1}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuea') + self.assertEqual(result, "valuea") - template = '{{st2kv.system.key2}}' - user = 'stanley' + template = "{{st2kv.system.key2}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valueb') + self.assertEqual(result, "valueb") # 2. Reference to the user inside the template - template = '{{st2kv.user.key1}}' - user = 'stanley' + template = "{{st2kv.user.key1}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuestanley1') + self.assertEqual(result, "valuestanley1") - template = '{{st2kv.user.key1}}' - user = 'joe' + template = "{{st2kv.user.key1}}" + user = "joe" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuejoe1') + self.assertEqual(result, "valuejoe1") diff --git a/st2common/tests/unit/test_util_types.py b/st2common/tests/unit/test_util_types.py index 1213eb69d1..8b7ef78864 100644 --- a/st2common/tests/unit/test_util_types.py +++ b/st2common/tests/unit/test_util_types.py @@ -17,9 +17,7 @@ from st2common.util.types import OrderedSet -__all__ = [ - 'OrderedTestTypeTestCase' -] +__all__ = ["OrderedTestTypeTestCase"] class OrderedTestTypeTestCase(unittest2.TestCase): diff --git a/st2common/tests/unit/test_util_url.py b/st2common/tests/unit/test_util_url.py index 551aed3e8c..8b23619593 100644 --- a/st2common/tests/unit/test_util_url.py +++ b/st2common/tests/unit/test_util_url.py @@ -23,16 +23,16 @@ class URLUtilsTestCase(unittest2.TestCase): def test_get_url_without_trailing_slash(self): values = [ - 'http://localhost:1818/foo/bar/', - 'http://localhost:1818/foo/bar', - 'http://localhost:1818/', - 'http://localhost:1818', + "http://localhost:1818/foo/bar/", + "http://localhost:1818/foo/bar", + "http://localhost:1818/", + "http://localhost:1818", ] expected = [ - 'http://localhost:1818/foo/bar', - 'http://localhost:1818/foo/bar', - 'http://localhost:1818', - 'http://localhost:1818', + "http://localhost:1818/foo/bar", + "http://localhost:1818/foo/bar", + "http://localhost:1818", + "http://localhost:1818", ] for value, expected_result in zip(values, expected): diff --git a/st2common/tests/unit/test_versioning_utils.py b/st2common/tests/unit/test_versioning_utils.py index 73d118aa89..de7bbbfeaf 100644 --- a/st2common/tests/unit/test_versioning_utils.py +++ b/st2common/tests/unit/test_versioning_utils.py @@ -23,40 +23,40 @@ class VersioningUtilsTestCase(unittest2.TestCase): def test_complex_semver_match(self): # Positive test case - self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.0.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.1.9', '>=1.6.0, <2.2.0')) + self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.0.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.1.9", ">=1.6.0, <2.2.0")) - self.assertTrue(complex_semver_match('1.6.0', 'all')) - self.assertTrue(complex_semver_match('1.6.1', 'all')) - self.assertTrue(complex_semver_match('2.0.0', 'all')) - self.assertTrue(complex_semver_match('2.1.0', 'all')) + self.assertTrue(complex_semver_match("1.6.0", "all")) + self.assertTrue(complex_semver_match("1.6.1", "all")) + self.assertTrue(complex_semver_match("2.0.0", "all")) + self.assertTrue(complex_semver_match("2.1.0", "all")) - self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0')) - self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0')) - self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0')) + self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0")) + self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0")) + self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0")) # Negative test case - self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('2.2.1', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('2.3.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('3.0.0', '>=1.6.0, <2.2.0')) + self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("2.2.1", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("2.3.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("3.0.0", ">=1.6.0, <2.2.0")) - self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0')) - self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0')) - self.assertFalse(complex_semver_match('1.5.9', '>=1.6.0')) + self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0")) + self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0")) + self.assertFalse(complex_semver_match("1.5.9", ">=1.6.0")) def test_normalize_pack_version(self): # Already a valid semver version string - self.assertEqual(normalize_pack_version('0.2.0'), '0.2.0') - self.assertEqual(normalize_pack_version('0.2.1'), '0.2.1') - self.assertEqual(normalize_pack_version('1.2.1'), '1.2.1') + self.assertEqual(normalize_pack_version("0.2.0"), "0.2.0") + self.assertEqual(normalize_pack_version("0.2.1"), "0.2.1") + self.assertEqual(normalize_pack_version("1.2.1"), "1.2.1") # Not a valid semver version string - self.assertEqual(normalize_pack_version('0.2'), '0.2.0') - self.assertEqual(normalize_pack_version('0.3'), '0.3.0') - self.assertEqual(normalize_pack_version('1.3'), '1.3.0') - self.assertEqual(normalize_pack_version('2.0'), '2.0.0') + self.assertEqual(normalize_pack_version("0.2"), "0.2.0") + self.assertEqual(normalize_pack_version("0.3"), "0.3.0") + self.assertEqual(normalize_pack_version("1.3"), "1.3.0") + self.assertEqual(normalize_pack_version("2.0"), "2.0.0") diff --git a/st2common/tests/unit/test_virtualenvs.py b/st2common/tests/unit/test_virtualenvs.py index 90c0f4e989..439801f67a 100644 --- a/st2common/tests/unit/test_virtualenvs.py +++ b/st2common/tests/unit/test_virtualenvs.py @@ -30,30 +30,28 @@ from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'VirtualenvUtilsTestCase' -] +__all__ = ["VirtualenvUtilsTestCase"] # Note: We set base requirements to an empty list to speed up the tests -@mock.patch('st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS', []) +@mock.patch("st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS", []) class VirtualenvUtilsTestCase(CleanFilesTestCase): def setUp(self): super(VirtualenvUtilsTestCase, self).setUp() config.parse_args() dir_path = tempfile.mkdtemp() - cfg.CONF.set_override(name='base_path', override=dir_path, group='system') + cfg.CONF.set_override(name="base_path", override=dir_path, group="system") self.base_path = dir_path - self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/') + self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/") # Make sure dir is deleted on tearDown self.to_delete_directories.append(self.base_path) def test_setup_pack_virtualenv_doesnt_exist_yet(self): # Test a fresh virtualenv creation - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist @@ -61,58 +59,81 @@ def test_setup_pack_virtualenv_doesnt_exist_yet(self): # Create virtualenv # Note: This pack has no requirements - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_pack_virtualenv_already_exists(self): # Test a scenario where virtualenv already exists - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist self.assertFalse(os.path.exists(pack_virtualenv_dir)) # Create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) # Re-create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify virtrualenv is still there self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_virtualenv_update(self): # Test a virtualenv update with pack which has requirements.txt - pack_name = 'dummy_pack_2' + pack_name = "dummy_pack_2" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist self.assertFalse(os.path.exists(pack_virtualenv_dir)) # Create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) # Update it - setup_pack_virtualenv(pack_name=pack_name, update=True, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=True, + include_setuptools=False, + include_wheel=False, + ) # Verify virtrualenv is still there self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_virtualenv_invalid_dependency_in_requirements_file(self): - pack_name = 'pack_invalid_requirements' + pack_name = "pack_invalid_requirements" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist @@ -120,182 +141,240 @@ def test_setup_virtualenv_invalid_dependency_in_requirements_file(self): # Try to create virtualenv, assert that it fails try: - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_setuptools=False, + include_wheel=False, + ) except Exception as e: - self.assertIn('Failed to install requirements from', six.text_type(e)) - self.assertTrue('No matching distribution found for someinvalidname' in - six.text_type(e)) + self.assertIn("Failed to install requirements from", six.text_type(e)) + self.assertTrue( + "No matching distribution found for someinvalidname" in six.text_type(e) + ) else: - self.fail('Exception not thrown') - - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + self.fail("Exception not thrown") + + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_without_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" install_requirement(pack_virtualenv_dir, requirement, proxy_config=None) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_http_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' - proxy_config = { - 'http_proxy': 'http://192.168.1.5:8080' - } + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" + proxy_config = {"http_proxy": "http://192.168.1.5:8080"} install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'http://192.168.1.5:8080', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "http://192.168.1.5:8080", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_https_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', - 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem' + "https_proxy": "https://192.168.1.5:8080", + "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem", } install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - '--cert', '/etc/ssl/certs/mitmproxy-ca.pem', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "--cert", + "/etc/ssl/certs/mitmproxy-ca.pem", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_https_proxy_no_cert(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', + "https_proxy": "https://192.168.1.5:8080", } install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_without_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' - install_requirements(pack_virtualenv_dir, requirements_file_path, proxy_config=None) + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=None + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_http_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' - proxy_config = { - 'http_proxy': 'http://192.168.1.5:8080' - } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) + proxy_config = {"http_proxy": "http://192.168.1.5:8080"} + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'http://192.168.1.5:8080', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "http://192.168.1.5:8080", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_https_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', - 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem' + "https_proxy": "https://192.168.1.5:8080", + "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem", } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - '--cert', '/etc/ssl/certs/mitmproxy-ca.pem', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "--cert", + "/etc/ssl/certs/mitmproxy-ca.pem", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_https_proxy_no_cert(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', + "https_proxy": "https://192.168.1.5:8080", } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) def assertVirtualenvExists(self, virtualenv_dir): self.assertTrue(os.path.exists(virtualenv_dir)) self.assertTrue(os.path.isdir(virtualenv_dir)) - self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, 'bin/'))) + self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, "bin/"))) return True diff --git a/st2exporter/dist_utils.py b/st2exporter/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2exporter/dist_utils.py +++ b/st2exporter/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2exporter/setup.py b/st2exporter/setup.py index bfd01f7061..afaae79cac 100644 --- a/st2exporter/setup.py +++ b/st2exporter/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2exporter import __version__ -ST2_COMPONENT = 'st2exporter' +ST2_COMPONENT = "st2exporter" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2exporter' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2exporter"], ) diff --git a/st2exporter/st2exporter/cmd/st2exporter_starter.py b/st2exporter/st2exporter/cmd/st2exporter_starter.py index c5ce157e24..2b86ef2707 100644 --- a/st2exporter/st2exporter/cmd/st2exporter_starter.py +++ b/st2exporter/st2exporter/cmd/st2exporter_starter.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -25,26 +26,29 @@ from st2exporter import config from st2exporter import worker -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup(): - common_setup(service='exporter', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True) + common_setup( + service="exporter", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + ) def _run_worker(): - LOG.info('(PID=%s) Exporter started.', os.getpid()) + LOG.info("(PID=%s) Exporter started.", os.getpid()) export_worker = worker.get_worker() try: export_worker.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Exporter stopped.', os.getpid()) + LOG.info("(PID=%s) Exporter stopped.", os.getpid()) export_worker.shutdown() except: return 1 @@ -62,7 +66,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Exporter quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Exporter quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2exporter/st2exporter/config.py b/st2exporter/st2exporter/config.py index 456b09e365..83f4f45d5d 100644 --- a/st2exporter/st2exporter/config.py +++ b/st2exporter/st2exporter/config.py @@ -31,8 +31,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def get_logging_config_path(): @@ -51,16 +54,20 @@ def _register_common_opts(): def _register_app_opts(): dump_opts = [ cfg.StrOpt( - 'dump_dir', default='/opt/stackstorm/exports/', - help='Directory to dump data to.') + "dump_dir", + default="/opt/stackstorm/exports/", + help="Directory to dump data to.", + ) ] - CONF.register_opts(dump_opts, group='exporter') + CONF.register_opts(dump_opts, group="exporter") logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.exporter.conf', - help='location of the logging.exporter.conf file') + "logging", + default="/etc/st2/logging.exporter.conf", + help="location of the logging.exporter.conf file", + ) ] - CONF.register_opts(logging_opts, group='exporter') + CONF.register_opts(logging_opts, group="exporter") diff --git a/st2exporter/st2exporter/exporter/dumper.py b/st2exporter/st2exporter/exporter/dumper.py index 2059557420..12fbeb4f83 100644 --- a/st2exporter/st2exporter/exporter/dumper.py +++ b/st2exporter/st2exporter/exporter/dumper.py @@ -26,40 +26,43 @@ from st2common.util import date as date_utils from st2common.util import isotime -__all__ = [ - 'Dumper' -] +__all__ = ["Dumper"] -ALLOWED_EXTENSIONS = ['json'] +ALLOWED_EXTENSIONS = ["json"] -CONVERTERS = { - 'json': JsonConverter -} +CONVERTERS = {"json": JsonConverter} LOG = logging.getLogger(__name__) class Dumper(object): - - def __init__(self, queue, export_dir, file_format='json', - file_prefix='st2-executions-', - batch_size=1000, sleep_interval=60, - max_files_per_sleep=5, - file_writer=None): + def __init__( + self, + queue, + export_dir, + file_format="json", + file_prefix="st2-executions-", + batch_size=1000, + sleep_interval=60, + max_files_per_sleep=5, + file_writer=None, + ): if not queue: - raise Exception('Need a queue to consume data from.') + raise Exception("Need a queue to consume data from.") if not export_dir: - raise Exception('Export dir needed to dump files to.') + raise Exception("Export dir needed to dump files to.") self._export_dir = export_dir if not os.path.exists(self._export_dir): - raise Exception('Dir path %s does not exist. Create one before using exporter.' % - self._export_dir) + raise Exception( + "Dir path %s does not exist. Create one before using exporter." + % self._export_dir + ) self._file_format = file_format.lower() if self._file_format not in ALLOWED_EXTENSIONS: - raise ValueError('Disallowed extension %s.' % file_format) + raise ValueError("Disallowed extension %s." % file_format) self._file_prefix = file_prefix self._batch_size = batch_size @@ -99,8 +102,8 @@ def _get_batch(self): else: executions_to_write.append(item) - LOG.debug('Returning %d items in batch.', len(executions_to_write)) - LOG.debug('Remaining items in queue: %d', self._queue.qsize()) + LOG.debug("Returning %d items in batch.", len(executions_to_write)) + LOG.debug("Remaining items in queue: %d", self._queue.qsize()) return executions_to_write def _flush(self): @@ -111,7 +114,7 @@ def _flush(self): try: self._write_to_disk() except: - LOG.error('Failed writing data to disk.') + LOG.error("Failed writing data to disk.") def _write_to_disk(self): count = 0 @@ -128,7 +131,7 @@ def _write_to_disk(self): self._update_marker(batch) count += 1 except: - LOG.exception('Writing batch to disk failed.') + LOG.exception("Writing batch to disk failed.") return count def _create_date_folder(self): @@ -139,7 +142,7 @@ def _create_date_folder(self): try: os.makedirs(folder_path) except: - LOG.exception('Unable to create sub-folder %s for export.', folder_name) + LOG.exception("Unable to create sub-folder %s for export.", folder_name) raise def _write_batch_to_disk(self, batch): @@ -147,42 +150,44 @@ def _write_batch_to_disk(self, batch): self._file_writer.write_text(doc_to_write, self._get_file_name()) def _get_file_name(self): - timestring = date_utils.get_datetime_utc_now().strftime('%Y-%m-%dT%H:%M:%S.%fZ') - file_name = self._file_prefix + timestring + '.' + self._file_format + timestring = date_utils.get_datetime_utc_now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") + file_name = self._file_prefix + timestring + "." + self._file_format file_name = os.path.join(self._export_dir, self._get_date_folder(), file_name) return file_name def _get_date_folder(self): - return date_utils.get_datetime_utc_now().strftime('%Y-%m-%d') + return date_utils.get_datetime_utc_now().strftime("%Y-%m-%d") def _update_marker(self, batch): timestamps = [isotime.parse(item.end_timestamp) for item in batch] new_marker = max(timestamps) if self._persisted_marker and self._persisted_marker > new_marker: - LOG.warn('Older executions are being exported. Perhaps out of order messages.') + LOG.warn( + "Older executions are being exported. Perhaps out of order messages." + ) try: self._write_marker_to_db(new_marker) except: - LOG.exception('Failed persisting dumper marker to db.') + LOG.exception("Failed persisting dumper marker to db.") else: self._persisted_marker = new_marker return self._persisted_marker def _write_marker_to_db(self, new_marker): - LOG.info('Updating marker in db to: %s', new_marker) + LOG.info("Updating marker in db to: %s", new_marker) markers = DumperMarker.get_all() if len(markers) > 1: - LOG.exception('More than one dumper marker found. Using first found one.') + LOG.exception("More than one dumper marker found. Using first found one.") marker = isotime.format(new_marker, offset=False) updated_at = date_utils.get_datetime_utc_now() if markers: - marker_id = markers[0]['id'] + marker_id = markers[0]["id"] else: marker_id = None diff --git a/st2exporter/st2exporter/exporter/file_writer.py b/st2exporter/st2exporter/exporter/file_writer.py index ec7e4d876c..49b5b4d63a 100644 --- a/st2exporter/st2exporter/exporter/file_writer.py +++ b/st2exporter/st2exporter/exporter/file_writer.py @@ -18,15 +18,11 @@ import abc import six -__all__ = [ - 'FileWriter', - 'TextFileWriter' -] +__all__ = ["FileWriter", "TextFileWriter"] @six.add_metaclass(abc.ABCMeta) class FileWriter(object): - @abc.abstractmethod def write(self, data, file_path, replace=False): """ @@ -40,13 +36,13 @@ class TextFileWriter(FileWriter): def write_text(self, text_data, file_path, replace=False, compressed=False): if compressed: - return Exception('Compression not supported.') + return Exception("Compression not supported.") self.write(text_data, file_path, replace=replace) def write(self, data, file_path, replace=False): if os.path.exists(file_path) and not replace: - raise Exception('File %s already exists.' % file_path) + raise Exception("File %s already exists." % file_path) - with open(file_path, 'w') as f: + with open(file_path, "w") as f: f.write(data) diff --git a/st2exporter/st2exporter/exporter/json_converter.py b/st2exporter/st2exporter/exporter/json_converter.py index a288197d41..ba7e95c0a5 100644 --- a/st2exporter/st2exporter/exporter/json_converter.py +++ b/st2exporter/st2exporter/exporter/json_converter.py @@ -15,15 +15,12 @@ from st2common.util.jsonify import json_encode -__all__ = [ - 'JsonConverter' -] +__all__ = ["JsonConverter"] class JsonConverter(object): - def convert(self, items_list): if not isinstance(items_list, list): - raise ValueError('Items to be converted should be a list.') + raise ValueError("Items to be converted should be a list.") json_doc = json_encode(items_list) return json_doc diff --git a/st2exporter/st2exporter/worker.py b/st2exporter/st2exporter/worker.py index 13273fd587..a5557ee41f 100644 --- a/st2exporter/st2exporter/worker.py +++ b/st2exporter/st2exporter/worker.py @@ -18,8 +18,11 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_CANCELED) +from st2common.constants.action import ( + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, + LIVEACTION_STATUS_CANCELED, +) from st2common.models.api.execution import ActionExecutionAPI from st2common.models.db.execution import ActionExecutionDB from st2common.persistence.execution import ActionExecution @@ -30,13 +33,13 @@ from st2exporter.exporter.dumper import Dumper from st2common.transport.queues import EXPORTER_WORK_QUEUE -__all__ = [ - 'ExecutionsExporter', - 'get_worker' -] +__all__ = ["ExecutionsExporter", "get_worker"] -COMPLETION_STATUSES = [LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_CANCELED] +COMPLETION_STATUSES = [ + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, + LIVEACTION_STATUS_CANCELED, +] LOG = logging.getLogger(__name__) @@ -46,18 +49,21 @@ class ExecutionsExporter(consumers.MessageHandler): def __init__(self, connection, queues): super(ExecutionsExporter, self).__init__(connection, queues) self.pending_executions = queue.Queue() - self._dumper = Dumper(queue=self.pending_executions, - export_dir=cfg.CONF.exporter.dump_dir) + self._dumper = Dumper( + queue=self.pending_executions, export_dir=cfg.CONF.exporter.dump_dir + ) self._consumer_thread = None def start(self, wait=False): - LOG.info('Bootstrapping executions from db...') + LOG.info("Bootstrapping executions from db...") try: self._bootstrap() except: - LOG.exception('Unable to bootstrap executions from db. Aborting.') + LOG.exception("Unable to bootstrap executions from db. Aborting.") raise - self._consumer_thread = eventlet.spawn(super(ExecutionsExporter, self).start, wait=True) + self._consumer_thread = eventlet.spawn( + super(ExecutionsExporter, self).start, wait=True + ) self._dumper.start() if wait: self.wait() @@ -71,7 +77,7 @@ def shutdown(self): super(ExecutionsExporter, self).shutdown() def process(self, execution): - LOG.debug('Got execution from queue: %s', execution) + LOG.debug("Got execution from queue: %s", execution) if execution.status not in COMPLETION_STATUSES: return execution_api = ActionExecutionAPI.from_model(execution, mask_secrets=True) @@ -80,21 +86,23 @@ def process(self, execution): def _bootstrap(self): marker = self._get_export_marker_from_db() - LOG.info('Using marker %s...' % marker) + LOG.info("Using marker %s..." % marker) missed_executions = self._get_missed_executions_from_db(export_marker=marker) - LOG.info('Found %d executions not exported yet...', len(missed_executions)) + LOG.info("Found %d executions not exported yet...", len(missed_executions)) for missed_execution in missed_executions: if missed_execution.status not in COMPLETION_STATUSES: continue - execution_api = ActionExecutionAPI.from_model(missed_execution, mask_secrets=True) + execution_api = ActionExecutionAPI.from_model( + missed_execution, mask_secrets=True + ) try: - LOG.debug('Missed execution %s', execution_api) + LOG.debug("Missed execution %s", execution_api) self.pending_executions.put_nowait(execution_api) except: - LOG.exception('Failed adding execution to in-memory queue.') + LOG.exception("Failed adding execution to in-memory queue.") continue - LOG.info('Bootstrapped executions...') + LOG.info("Bootstrapped executions...") def _get_export_marker_from_db(self): try: @@ -114,8 +122,8 @@ def _get_missed_executions_from_db(self, export_marker=None): # XXX: Should adapt this query to get only executions with status # in COMPLETION_STATUSES. - filters = {'end_timestamp__gt': export_marker} - LOG.info('Querying for executions with filters: %s', filters) + filters = {"end_timestamp__gt": export_marker} + LOG.info("Querying for executions with filters: %s", filters) return ActionExecution.query(**filters) def _get_all_executions_from_db(self): diff --git a/st2exporter/tests/integration/test_dumper_integration.py b/st2exporter/tests/integration/test_dumper_integration.py index bdb87b1249..0de7b91ed0 100644 --- a/st2exporter/tests/integration/test_dumper_integration.py +++ b/st2exporter/tests/integration/test_dumper_integration.py @@ -28,21 +28,30 @@ from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestDumper(DbTestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - loaded_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + loaded_executions = loaded_fixtures["executions"] execution_apis = [] for execution in loaded_executions.values(): execution_apis.append(ActionExecutionAPI(**execution)) @@ -54,31 +63,45 @@ def get_queue(self): executions_queue.put(execution) return executions_queue - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_marker_to_db(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) marker_db = dumper._write_marker_to_db(max_timestamp) persisted_marker = marker_db.marker self.assertIsInstance(persisted_marker, six.string_types) self.assertEqual(isotime.parse(persisted_marker), max_timestamp) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_marker_to_db_marker_exists(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) first_marker_db = dumper._write_marker_to_db(max_timestamp) - second_marker_db = dumper._write_marker_to_db(max_timestamp + datetime.timedelta(hours=1)) + second_marker_db = dumper._write_marker_to_db( + max_timestamp + datetime.timedelta(hours=1) + ) markers = DumperMarker.get_all() self.assertEqual(len(markers), 1) final_marker_id = markers[0].id diff --git a/st2exporter/tests/integration/test_export_worker.py b/st2exporter/tests/integration/test_export_worker.py index 8b0caf7d86..9237aab0e8 100644 --- a/st2exporter/tests/integration/test_export_worker.py +++ b/st2exporter/tests/integration/test_export_worker.py @@ -27,75 +27,92 @@ from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader import st2tests.config as tests_config + tests_config.parse_args() -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestExportWorker(DbTestCase): - @classmethod def setUpClass(cls): super(TestExportWorker, cls).setUpClass() fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - TestExportWorker.saved_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + TestExportWorker.saved_executions = loaded_fixtures["executions"] - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_marker_from_db(self): marker_dt = date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=5) - marker_db = DumperMarkerDB(marker=isotime.format(marker_dt, offset=False), - updated_at=date_utils.get_datetime_utc_now()) + marker_db = DumperMarkerDB( + marker=isotime.format(marker_dt, offset=False), + updated_at=date_utils.get_datetime_utc_now(), + ) DumperMarker.add_or_update(marker_db) exec_exporter = ExecutionsExporter(None, None) export_marker = exec_exporter._get_export_marker_from_db() self.assertEqual(export_marker, date_utils.add_utc_tz(marker_dt)) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_missed_executions_from_db_no_marker(self): exec_exporter = ExecutionsExporter(None, None) all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None) self.assertEqual(len(all_execs), len(self.saved_executions.values())) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_missed_executions_from_db_with_marker(self): exec_exporter = ExecutionsExporter(None, None) all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None) min_timestamp = min([item.end_timestamp for item in all_execs]) marker = min_timestamp + datetime.timedelta(seconds=1) - execs_greater_than_marker = [item for item in all_execs if item.end_timestamp > marker] + execs_greater_than_marker = [ + item for item in all_execs if item.end_timestamp > marker + ] all_execs = exec_exporter._get_missed_executions_from_db(export_marker=marker) self.assertTrue(len(all_execs) > 0) self.assertTrue(len(all_execs) == len(execs_greater_than_marker)) for item in all_execs: self.assertTrue(item.end_timestamp > marker) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_bootstrap(self): exec_exporter = ExecutionsExporter(None, None) exec_exporter._bootstrap() - self.assertEqual(exec_exporter.pending_executions.qsize(), len(self.saved_executions)) + self.assertEqual( + exec_exporter.pending_executions.qsize(), len(self.saved_executions) + ) count = 0 while count < exec_exporter.pending_executions.qsize(): - self.assertIsInstance(exec_exporter.pending_executions.get(), ActionExecutionAPI) + self.assertIsInstance( + exec_exporter.pending_executions.get(), ActionExecutionAPI + ) count += 1 - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_process(self): some_execution = list(self.saved_executions.values())[5] exec_exporter = ExecutionsExporter(None, None) self.assertEqual(exec_exporter.pending_executions.qsize(), 0) exec_exporter.process(some_execution) self.assertEqual(exec_exporter.pending_executions.qsize(), 1) - some_execution.status = 'scheduled' + some_execution.status = "scheduled" exec_exporter.process(some_execution) self.assertEqual(exec_exporter.pending_executions.qsize(), 1) diff --git a/st2exporter/tests/unit/test_dumper.py b/st2exporter/tests/unit/test_dumper.py index 98e42e60f1..0ddec72e3b 100644 --- a/st2exporter/tests/unit/test_dumper.py +++ b/st2exporter/tests/unit/test_dumper.py @@ -28,21 +28,30 @@ from st2tests.fixturesloader import FixturesLoader from st2common.util import date as date_utils -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestDumper(EventletTestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - loaded_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + loaded_executions = loaded_fixtures["executions"] execution_apis = [] for execution in loaded_executions.values(): execution_apis.append(ActionExecutionAPI(**execution)) @@ -54,81 +63,101 @@ def get_queue(self): executions_queue.put(execution) return executions_queue - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_batch_batch_size_greater_than_actual(self): executions_queue = self.get_queue() qsize = executions_queue.qsize() self.assertTrue(qsize > 0) - dumper = Dumper(queue=executions_queue, batch_size=2 * qsize, - export_dir='/tmp') + dumper = Dumper(queue=executions_queue, batch_size=2 * qsize, export_dir="/tmp") batch = dumper._get_batch() self.assertEqual(len(batch), qsize) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_batch_batch_size_lesser_than_actual(self): executions_queue = self.get_queue() qsize = executions_queue.qsize() self.assertTrue(qsize > 0) expected_batch_size = int(qsize / 2) - dumper = Dumper(queue=executions_queue, - batch_size=expected_batch_size, - export_dir='/tmp') + dumper = Dumper( + queue=executions_queue, batch_size=expected_batch_size, export_dir="/tmp" + ) batch = dumper._get_batch() self.assertEqual(len(batch), expected_batch_size) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_file_name(self): - dumper = Dumper(queue=self.get_queue(), - export_dir='/tmp', - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=self.get_queue(), + export_dir="/tmp", + file_prefix="st2-stuff-", + file_format="json", + ) file_name = dumper._get_file_name() - export_date = date_utils.get_datetime_utc_now().strftime('%Y-%m-%d') - self.assertTrue(file_name.startswith('/tmp/' + export_date + '/st2-stuff-')) - self.assertTrue(file_name.endswith('json')) + export_date = date_utils.get_datetime_utc_now().strftime("%Y-%m-%d") + self.assertTrue(file_name.startswith("/tmp/" + export_date + "/st2-stuff-")) + self.assertTrue(file_name.endswith("json")) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_to_disk_empty_queue(self): - dumper = Dumper(queue=queue.Queue(), - export_dir='/tmp', - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=queue.Queue(), + export_dir="/tmp", + file_prefix="st2-stuff-", + file_format="json", + ) # We just make sure this doesn't blow up. ret = dumper._write_to_disk() self.assertEqual(ret, 0) - @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_update_marker', mock.MagicMock(return_value=None)) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_update_marker", mock.MagicMock(return_value=None)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_to_disk(self): executions_queue = self.get_queue() max_files_per_sleep = 5 - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=1, max_files_per_sleep=max_files_per_sleep, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=1, + max_files_per_sleep=max_files_per_sleep, + file_prefix="st2-stuff-", + file_format="json", + ) # We just make sure this doesn't blow up. ret = dumper._write_to_disk() self.assertEqual(ret, max_files_per_sleep) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True)) def test_start_stop_dumper(self): executions_queue = self.get_queue() sleep_interval = 0.01 - dumper = Dumper(queue=executions_queue, sleep_interval=sleep_interval, - export_dir='/tmp', batch_size=1, max_files_per_sleep=5, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + sleep_interval=sleep_interval, + export_dir="/tmp", + batch_size=1, + max_files_per_sleep=5, + file_prefix="st2-stuff-", + file_format="json", + ) dumper.start() # Call stop after at least one batch was written to disk. eventlet.sleep(10 * sleep_interval) dumper.stop() - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True)) def test_update_marker(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) # Batch 1 batch = self.execution_apis[0:5] new_marker = dumper._update_marker(batch) @@ -145,15 +174,21 @@ def test_update_marker(self): self.assertEqual(new_marker, max_timestamp) dumper._write_marker_to_db.assert_called_with(new_marker) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True)) def test_update_marker_out_of_order_batch(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) # set dumper persisted timestamp to something less than min timestamp in the batch diff --git a/st2exporter/tests/unit/test_json_converter.py b/st2exporter/tests/unit/test_json_converter.py index ce2f484bca..07f82a8bf0 100644 --- a/st2exporter/tests/unit/test_json_converter.py +++ b/st2exporter/tests/unit/test_json_converter.py @@ -20,34 +20,43 @@ from st2tests.fixturesloader import FixturesLoader from st2exporter.exporter.json_converter import JsonConverter -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestJsonConverter(unittest2.TestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_convert(self): - executions_list = list(self.loaded_fixtures['executions'].values()) + executions_list = list(self.loaded_fixtures["executions"].values()) converter = JsonConverter() converted_doc = converter.convert(executions_list) - self.assertTrue(type(converted_doc), 'string') + self.assertTrue(type(converted_doc), "string") reversed_doc = json.loads(converted_doc) self.assertListEqual(executions_list, reversed_doc) def test_convert_non_list(self): - executions_dict = self.loaded_fixtures['executions'] + executions_dict = self.loaded_fixtures["executions"] converter = JsonConverter() try: converter.convert(executions_dict) - self.fail('Should have thrown exception.') + self.fail("Should have thrown exception.") except ValueError: pass diff --git a/st2reactor/dist_utils.py b/st2reactor/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2reactor/dist_utils.py +++ b/st2reactor/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2reactor/setup.py b/st2reactor/setup.py index 0379240b8f..adb3e7accc 100644 --- a/st2reactor/setup.py +++ b/st2reactor/setup.py @@ -23,9 +23,9 @@ from dist_utils import apply_vagrant_workaround from st2reactor import __version__ -ST2_COMPONENT = 'st2reactor' +ST2_COMPONENT = "st2reactor" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,23 +33,25 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2-rule-tester', - 'bin/st2-trigger-refire', - 'bin/st2rulesengine', - 'bin/st2sensorcontainer', - 'bin/st2garbagecollector', - 'bin/st2timersengine', - ] + "bin/st2-rule-tester", + "bin/st2-trigger-refire", + "bin/st2rulesengine", + "bin/st2sensorcontainer", + "bin/st2garbagecollector", + "bin/st2timersengine", + ], ) diff --git a/st2reactor/st2reactor/__init__.py b/st2reactor/st2reactor/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2reactor/st2reactor/__init__.py +++ b/st2reactor/st2reactor/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2reactor/st2reactor/cmd/garbagecollector.py b/st2reactor/st2reactor/cmd/garbagecollector.py index ab3c64409b..b4be4dfa8b 100644 --- a/st2reactor/st2reactor/cmd/garbagecollector.py +++ b/st2reactor/st2reactor/cmd/garbagecollector.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -31,9 +32,7 @@ from st2reactor.garbage_collector import config from st2reactor.garbage_collector.base import GarbageCollectorService -__all__ = [ - 'main' -] +__all__ = ["main"] LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__]) @@ -41,14 +40,17 @@ def _setup(): - capabilities = { - 'name': 'garbagecollector', - 'type': 'passive' - } - common_setup(service='garbagecollector', config=config, setup_db=True, - register_mq_exchanges=True, register_signal_handlers=True, - register_runners=False, service_registry=True, - capabilities=capabilities) + capabilities = {"name": "garbagecollector", "type": "passive"} + common_setup( + service="garbagecollector", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -61,13 +63,14 @@ def main(): collection_interval = cfg.CONF.garbagecollector.collection_interval sleep_delay = cfg.CONF.garbagecollector.sleep_delay - garbage_collector = GarbageCollectorService(collection_interval=collection_interval, - sleep_delay=sleep_delay) + garbage_collector = GarbageCollectorService( + collection_interval=collection_interval, sleep_delay=sleep_delay + ) exit_code = garbage_collector.run() except SystemExit as exit_code: return exit_code except: - LOG.exception('(PID:%s) GarbageCollector quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) GarbageCollector quit due to exception.", os.getpid()) return FAILURE_EXIT_CODE finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/rule_tester.py b/st2reactor/st2reactor/cmd/rule_tester.py index 926a27a4ff..b346168cb5 100644 --- a/st2reactor/st2reactor/cmd/rule_tester.py +++ b/st2reactor/st2reactor/cmd/rule_tester.py @@ -25,23 +25,27 @@ from st2common.script_setup import teardown as common_teardown from st2reactor.rules.tester import RuleTester -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('rule', default=None, - help='Path to the file containing rule definition.'), - cfg.StrOpt('rule-ref', default=None, - help='Ref of the rule.'), - cfg.StrOpt('trigger-instance', default=None, - help='Path to the file containing trigger instance definition'), - cfg.StrOpt('trigger-instance-id', default=None, - help='Id of the Trigger Instance to use for validation.') + cfg.StrOpt( + "rule", default=None, help="Path to the file containing rule definition." + ), + cfg.StrOpt("rule-ref", default=None, help="Ref of the rule."), + cfg.StrOpt( + "trigger-instance", + default=None, + help="Path to the file containing trigger instance definition", + ), + cfg.StrOpt( + "trigger-instance-id", + default=None, + help="Id of the Trigger Instance to use for validation.", + ), ] do_register_cli_opts(cli_opts) @@ -51,17 +55,19 @@ def main(): common_setup(config=config, setup_db=True, register_mq_exchanges=False) try: - tester = RuleTester(rule_file_path=cfg.CONF.rule, - rule_ref=cfg.CONF.rule_ref, - trigger_instance_file_path=cfg.CONF.trigger_instance, - trigger_instance_id=cfg.CONF.trigger_instance_id) + tester = RuleTester( + rule_file_path=cfg.CONF.rule, + rule_ref=cfg.CONF.rule_ref, + trigger_instance_file_path=cfg.CONF.trigger_instance, + trigger_instance_id=cfg.CONF.trigger_instance_id, + ) matches = tester.evaluate() finally: common_teardown() if matches: - LOG.info('=== RULE MATCHES ===') + LOG.info("=== RULE MATCHES ===") sys.exit(0) else: - LOG.info('=== RULE DOES NOT MATCH ===') + LOG.info("=== RULE DOES NOT MATCH ===") sys.exit(1) diff --git a/st2reactor/st2reactor/cmd/rulesengine.py b/st2reactor/st2reactor/cmd/rulesengine.py index f372cc252e..895fbe42d9 100644 --- a/st2reactor/st2reactor/cmd/rulesengine.py +++ b/st2reactor/st2reactor/cmd/rulesengine.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -34,13 +35,18 @@ def _setup(): - capabilities = { - 'name': 'rulesengine', - 'type': 'passive' - } - common_setup(service='rulesengine', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=True, - register_runners=False, service_registry=True, capabilities=capabilities) + capabilities = {"name": "rulesengine", "type": "passive"} + common_setup( + service="rulesengine", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -48,7 +54,7 @@ def _teardown(): def _run_worker(): - LOG.info('(PID=%s) RulesEngine started.', os.getpid()) + LOG.info("(PID=%s) RulesEngine started.", os.getpid()) rules_engine_worker = worker.get_worker() @@ -56,10 +62,10 @@ def _run_worker(): rules_engine_worker.start() return rules_engine_worker.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) RulesEngine stopped.', os.getpid()) + LOG.info("(PID=%s) RulesEngine stopped.", os.getpid()) rules_engine_worker.shutdown() except: - LOG.exception('(PID:%s) RulesEngine quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) RulesEngine quit due to exception.", os.getpid()) return 1 return 0 @@ -72,7 +78,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) RulesEngine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) RulesEngine quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/sensormanager.py b/st2reactor/st2reactor/cmd/sensormanager.py index df2be8e7ac..f3d27afb5b 100644 --- a/st2reactor/st2reactor/cmd/sensormanager.py +++ b/st2reactor/st2reactor/cmd/sensormanager.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -33,9 +34,7 @@ from st2reactor.container.manager import SensorContainerManager from st2reactor.container.partitioner_lookup import get_sensors_partitioner -__all__ = [ - 'main' -] +__all__ = ["main"] LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__]) @@ -43,13 +42,17 @@ def _setup(): - capabilities = { - 'name': 'sensorcontainer', - 'type': 'passive' - } - common_setup(service='sensorcontainer', config=config, setup_db=True, - register_mq_exchanges=True, register_signal_handlers=True, - register_runners=False, service_registry=True, capabilities=capabilities) + capabilities = {"name": "sensorcontainer", "type": "passive"} + common_setup( + service="sensorcontainer", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -60,16 +63,21 @@ def main(): try: _setup() - single_sensor_mode = (cfg.CONF.single_sensor_mode or - cfg.CONF.sensorcontainer.single_sensor_mode) + single_sensor_mode = ( + cfg.CONF.single_sensor_mode or cfg.CONF.sensorcontainer.single_sensor_mode + ) if single_sensor_mode and not cfg.CONF.sensor_ref: - raise ValueError('--sensor-ref argument must be provided when running in single ' - 'sensor mode') + raise ValueError( + "--sensor-ref argument must be provided when running in single " + "sensor mode" + ) sensors_partitioner = get_sensors_partitioner() - container_manager = SensorContainerManager(sensors_partitioner=sensors_partitioner, - single_sensor_mode=single_sensor_mode) + container_manager = SensorContainerManager( + sensors_partitioner=sensors_partitioner, + single_sensor_mode=single_sensor_mode, + ) return container_manager.run_sensors() except SystemExit as exit_code: return exit_code @@ -77,7 +85,7 @@ def main(): LOG.exception(e) return 1 except: - LOG.exception('(PID:%s) SensorContainer quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) SensorContainer quit due to exception.", os.getpid()) return FAILURE_EXIT_CODE finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/timersengine.py b/st2reactor/st2reactor/cmd/timersengine.py index 0b0cc4b5dd..9b4edd52b5 100644 --- a/st2reactor/st2reactor/cmd/timersengine.py +++ b/st2reactor/st2reactor/cmd/timersengine.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -38,12 +39,16 @@ def _setup(): - capabilities = { - 'name': 'timerengine', - 'type': 'passive' - } - common_setup(service='timer_engine', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "timerengine", "type": "passive"} + common_setup( + service="timer_engine", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -55,14 +60,16 @@ def _kickoff_timer(timer): def _run_worker(): - LOG.info('(PID=%s) TimerEngine started.', os.getpid()) + LOG.info("(PID=%s) TimerEngine started.", os.getpid()) timer = None try: timer_thread = None if cfg.CONF.timer.enable or cfg.CONF.timersengine.enable: - local_tz = cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone + local_tz = ( + cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone + ) timer = St2Timer(local_timezone=local_tz) timer_thread = concurrency.spawn(_kickoff_timer, timer) LOG.info(TIMER_ENABLED_LOG_LINE) @@ -70,9 +77,9 @@ def _run_worker(): else: LOG.info(TIMER_DISABLED_LOG_LINE) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) TimerEngine stopped.', os.getpid()) + LOG.info("(PID=%s) TimerEngine stopped.", os.getpid()) except: - LOG.exception('(PID:%s) TimerEngine quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) TimerEngine quit due to exception.", os.getpid()) return 1 finally: if timer: @@ -88,7 +95,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) TimerEngine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) TimerEngine quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/trigger_re_fire.py b/st2reactor/st2reactor/cmd/trigger_re_fire.py index 4f2c8f9ca1..8282a5decf 100644 --- a/st2reactor/st2reactor/cmd/trigger_re_fire.py +++ b/st2reactor/st2reactor/cmd/trigger_re_fire.py @@ -27,24 +27,23 @@ from st2common.persistence.trigger import TriggerInstance from st2common.transport.reactor import TriggerDispatcher -__all__ = [ - 'main' -] +__all__ = ["main"] CONF = cfg.CONF def _parse_config(): cli_opts = [ - cfg.BoolOpt('verbose', - short='v', - default=False, - help='Print more verbose output'), - cfg.StrOpt('trigger-instance-id', - short='t', - required=True, - dest='trigger_instance_id', - help='Id of trigger instance'), + cfg.BoolOpt( + "verbose", short="v", default=False, help="Print more verbose output" + ), + cfg.StrOpt( + "trigger-instance-id", + short="t", + required=True, + dest="trigger_instance_id", + help="Id of trigger instance", + ), ] CONF.register_cli_opts(cli_opts) st2cfg.register_opts(ignore_errors=False) @@ -54,22 +53,17 @@ def _parse_config(): def _setup_logging(): logging_config = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'default': { - 'format': '%(asctime)s %(levelname)s %(name)s %(message)s' - }, + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(asctime)s %(levelname)s %(name)s %(message)s"}, }, - 'handlers': { - 'console': { - '()': std_logging.StreamHandler, - 'formatter': 'default' - } + "handlers": { + "console": {"()": std_logging.StreamHandler, "formatter": "default"} }, - 'root': { - 'handlers': ['console'], - 'level': 'DEBUG', + "root": { + "handlers": ["console"], + "level": "DEBUG", }, } std_logging.config.dictConfig(logging_config) @@ -82,8 +76,9 @@ def _setup_db(): def _refire_trigger_instance(trigger_instance_id, log_): trigger_instance = TriggerInstance.get_by_id(trigger_instance_id) trigger_dispatcher = TriggerDispatcher(log_) - trigger_dispatcher.dispatch(trigger=trigger_instance.trigger, - payload=trigger_instance.payload) + trigger_dispatcher.dispatch( + trigger=trigger_instance.trigger, payload=trigger_instance.payload + ) def main(): @@ -94,7 +89,8 @@ def main(): else: output = pprint.pprint _setup_db() - _refire_trigger_instance(trigger_instance_id=CONF.trigger_instance_id, - log_=logging.getLogger(__name__)) - output('Trigger re-fired') + _refire_trigger_instance( + trigger_instance_id=CONF.trigger_instance_id, log_=logging.getLogger(__name__) + ) + output("Trigger re-fired") db_teardown() diff --git a/st2reactor/st2reactor/container/hash_partitioner.py b/st2reactor/st2reactor/container/hash_partitioner.py index 9ed0cb78be..b9e5e46658 100644 --- a/st2reactor/st2reactor/container/hash_partitioner.py +++ b/st2reactor/st2reactor/container/hash_partitioner.py @@ -17,25 +17,25 @@ import ctypes import hashlib -from st2reactor.container.partitioners import DefaultPartitioner, get_all_enabled_sensors +from st2reactor.container.partitioners import ( + DefaultPartitioner, + get_all_enabled_sensors, +) -__all__ = [ - 'HashPartitioner', - 'Range' -] +__all__ = ["HashPartitioner", "Range"] # The range expression serialized is of the form `RANGE_START..RANGE_END|RANGE_START..RANGE_END ...` -SUB_RANGE_SEPARATOR = '|' -RANGE_BOUNDARY_SEPARATOR = '..' +SUB_RANGE_SEPARATOR = "|" +RANGE_BOUNDARY_SEPARATOR = ".." class Range(object): - RANGE_MIN_ENUM = 'min' + RANGE_MIN_ENUM = "min" RANGE_MIN_VALUE = 0 - RANGE_MAX_ENUM = 'max' - RANGE_MAX_VALUE = 2**32 + RANGE_MAX_ENUM = "max" + RANGE_MAX_VALUE = 2 ** 32 def __init__(self, range_repr): self.range_start, self.range_end = self._get_range_boundaries(range_repr) @@ -44,15 +44,17 @@ def __contains__(self, item): return item >= self.range_start and item < self.range_end def _get_range_boundaries(self, range_repr): - range_repr = [value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR)] + range_repr = [ + value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR) + ] if len(range_repr) != 2: - raise ValueError('Unsupported sub-range format %s.' % range_repr) + raise ValueError("Unsupported sub-range format %s." % range_repr) range_start = self._get_valid_range_boundary(range_repr[0]) range_end = self._get_valid_range_boundary(range_repr[1]) if range_start > range_end: - raise ValueError('Misconfigured range [%d..%d]' % (range_start, range_end)) + raise ValueError("Misconfigured range [%d..%d]" % (range_start, range_end)) return (range_start, range_end) def _get_valid_range_boundary(self, boundary_value): @@ -73,7 +75,6 @@ def _get_valid_range_boundary(self, boundary_value): class HashPartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name, hash_ranges): super(HashPartitioner, self).__init__(sensor_node_name=sensor_node_name) self._hash_ranges = self._create_hash_ranges(hash_ranges) @@ -112,7 +113,7 @@ def _hash_sensor_ref(self, sensor_ref): h = ctypes.c_uint(0) for d in reversed(str(md5_hash_int_repr)): d = ctypes.c_uint(int(d)) - higherorder = ctypes.c_uint(h.value & 0xf8000000) + higherorder = ctypes.c_uint(h.value & 0xF8000000) h = ctypes.c_uint(h.value << 5) h = ctypes.c_uint(h.value ^ (higherorder.value >> 27)) h = ctypes.c_uint(h.value ^ d.value) diff --git a/st2reactor/st2reactor/container/manager.py b/st2reactor/st2reactor/container/manager.py index 694d3ce337..e9f251aebc 100644 --- a/st2reactor/st2reactor/container/manager.py +++ b/st2reactor/st2reactor/container/manager.py @@ -27,16 +27,13 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'SensorContainerManager' -] +__all__ = ["SensorContainerManager"] class SensorContainerManager(object): - def __init__(self, sensors_partitioner, single_sensor_mode=False): if not sensors_partitioner: - raise ValueError('sensors_partitioner should be non-None.') + raise ValueError("sensors_partitioner should be non-None.") self._sensors_partitioner = sensors_partitioner self._single_sensor_mode = single_sensor_mode @@ -44,10 +41,12 @@ def __init__(self, sensors_partitioner, single_sensor_mode=False): self._sensor_container = None self._container_thread = None - self._sensors_watcher = SensorWatcher(create_handler=self._handle_create_sensor, - update_handler=self._handle_update_sensor, - delete_handler=self._handle_delete_sensor, - queue_suffix='sensor_container') + self._sensors_watcher = SensorWatcher( + create_handler=self._handle_create_sensor, + update_handler=self._handle_update_sensor, + delete_handler=self._handle_delete_sensor, + queue_suffix="sensor_container", + ) def run_sensors(self): """ @@ -55,15 +54,18 @@ def run_sensors(self): """ sensors = self._sensors_partitioner.get_sensors() if sensors: - LOG.info('Setting up container to run %d sensors.', len(sensors)) - LOG.info('\tSensors list - %s.', [self._get_sensor_ref(sensor) for sensor in sensors]) + LOG.info("Setting up container to run %d sensors.", len(sensors)) + LOG.info( + "\tSensors list - %s.", + [self._get_sensor_ref(sensor) for sensor in sensors], + ) sensors_to_run = [] for sensor in sensors: # TODO: Directly pass DB object to the ProcessContainer sensors_to_run.append(self._to_sensor_object(sensor)) - LOG.info('(PID:%s) SensorContainer started.', os.getpid()) + LOG.info("(PID:%s) SensorContainer started.", os.getpid()) self._setup_sigterm_handler() exit_code = self._spin_container_and_wait(sensors_to_run) @@ -74,22 +76,25 @@ def _spin_container_and_wait(self, sensors): try: self._sensor_container = ProcessSensorContainer( - sensors=sensors, - single_sensor_mode=self._single_sensor_mode) + sensors=sensors, single_sensor_mode=self._single_sensor_mode + ) self._container_thread = concurrency.spawn(self._sensor_container.run) - LOG.debug('Starting sensor CUD watcher...') + LOG.debug("Starting sensor CUD watcher...") self._sensors_watcher.start() exit_code = self._container_thread.wait() - LOG.error('Process container quit with exit_code %d.', exit_code) - LOG.error('(PID:%s) SensorContainer stopped.', os.getpid()) + LOG.error("Process container quit with exit_code %d.", exit_code) + LOG.error("(PID:%s) SensorContainer stopped.", os.getpid()) except (KeyboardInterrupt, SystemExit): self._sensor_container.shutdown() self._sensors_watcher.stop() - LOG.info('(PID:%s) SensorContainer stopped. Reason - %s', os.getpid(), - sys.exc_info()[0].__name__) + LOG.info( + "(PID:%s) SensorContainer stopped. Reason - %s", + os.getpid(), + sys.exc_info()[0].__name__, + ) concurrency.kill(self._container_thread) self._container_thread = None @@ -99,7 +104,6 @@ def _spin_container_and_wait(self, sensors): return exit_code def _setup_sigterm_handler(self): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and we call sensor_container.shutdown() # there which cleans things up. @@ -110,16 +114,16 @@ def sigterm_handler(signum=None, frame=None): signal.signal(signal.SIGTERM, sigterm_handler) def _to_sensor_object(self, sensor_db): - file_path = sensor_db.artifact_uri.replace('file://', '') - class_name = sensor_db.entry_point.split('.')[-1] + file_path = sensor_db.artifact_uri.replace("file://", "") + class_name = sensor_db.entry_point.split(".")[-1] sensor_obj = { - 'pack': sensor_db.pack, - 'file_path': file_path, - 'class_name': class_name, - 'trigger_types': sensor_db.trigger_types, - 'poll_interval': sensor_db.poll_interval, - 'ref': self._get_sensor_ref(sensor_db) + "pack": sensor_db.pack, + "file_path": file_path, + "class_name": class_name, + "trigger_types": sensor_db.trigger_types, + "poll_interval": sensor_db.poll_interval, + "ref": self._get_sensor_ref(sensor_db), } return sensor_obj @@ -130,42 +134,50 @@ def _to_sensor_object(self, sensor_db): def _handle_create_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not supported. Ignoring create.', self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not supported. Ignoring create.", + self._get_sensor_ref(sensor), + ) return if not sensor.enabled: - LOG.info('sensor %s is not enabled.', self._get_sensor_ref(sensor)) + LOG.info("sensor %s is not enabled.", self._get_sensor_ref(sensor)) return - LOG.info('Adding sensor %s.', self._get_sensor_ref(sensor)) + LOG.info("Adding sensor %s.", self._get_sensor_ref(sensor)) self._sensor_container.add_sensor(sensor=self._to_sensor_object(sensor)) def _handle_update_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not assigned to this partition. Ignoring update. ', - self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not assigned to this partition. Ignoring update. ", + self._get_sensor_ref(sensor), + ) return sensor_ref = self._get_sensor_ref(sensor) sensor_obj = self._to_sensor_object(sensor) # Handle disabling sensor if not sensor.enabled: - LOG.info('Sensor %s disabled. Unloading sensor.', sensor_ref) + LOG.info("Sensor %s disabled. Unloading sensor.", sensor_ref) self._sensor_container.remove_sensor(sensor=sensor_obj) return - LOG.info('Sensor %s updated. Reloading sensor.', sensor_ref) + LOG.info("Sensor %s updated. Reloading sensor.", sensor_ref) try: self._sensor_container.remove_sensor(sensor=sensor_obj) except: - LOG.exception('Failed to reload sensor %s', sensor_ref) + LOG.exception("Failed to reload sensor %s", sensor_ref) else: self._sensor_container.add_sensor(sensor=sensor_obj) - LOG.info('Sensor %s reloaded.', sensor_ref) + LOG.info("Sensor %s reloaded.", sensor_ref) def _handle_delete_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not supported. Ignoring delete.', self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not supported. Ignoring delete.", + self._get_sensor_ref(sensor), + ) return - LOG.info('Unloading sensor %s.', self._get_sensor_ref(sensor)) + LOG.info("Unloading sensor %s.", self._get_sensor_ref(sensor)) self._sensor_container.remove_sensor(sensor=self._to_sensor_object(sensor)) def _get_sensor_ref(self, sensor): diff --git a/st2reactor/st2reactor/container/partitioner_lookup.py b/st2reactor/st2reactor/container/partitioner_lookup.py index c4f43db6da..1469b3c63c 100644 --- a/st2reactor/st2reactor/container/partitioner_lookup.py +++ b/st2reactor/st2reactor/container/partitioner_lookup.py @@ -18,16 +18,22 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.sensors import DEFAULT_PARTITION_LOADER, KVSTORE_PARTITION_LOADER, \ - FILE_PARTITION_LOADER, HASH_PARTITION_LOADER +from st2common.constants.sensors import ( + DEFAULT_PARTITION_LOADER, + KVSTORE_PARTITION_LOADER, + FILE_PARTITION_LOADER, + HASH_PARTITION_LOADER, +) from st2common.exceptions.sensors import SensorPartitionerNotSupportedException -from st2reactor.container.partitioners import DefaultPartitioner, KVStorePartitioner, \ - FileBasedPartitioner, SingleSensorPartitioner +from st2reactor.container.partitioners import ( + DefaultPartitioner, + KVStorePartitioner, + FileBasedPartitioner, + SingleSensorPartitioner, +) from st2reactor.container.hash_partitioner import HashPartitioner -__all__ = [ - 'get_sensors_partitioner' -] +__all__ = ["get_sensors_partitioner"] LOG = logging.getLogger(__name__) @@ -35,25 +41,28 @@ DEFAULT_PARTITION_LOADER: DefaultPartitioner, KVSTORE_PARTITION_LOADER: KVStorePartitioner, FILE_PARTITION_LOADER: FileBasedPartitioner, - HASH_PARTITION_LOADER: HashPartitioner + HASH_PARTITION_LOADER: HashPartitioner, } def get_sensors_partitioner(): if cfg.CONF.sensor_ref: - LOG.info('Running in single sensor mode, using a single sensor partitioner...') + LOG.info("Running in single sensor mode, using a single sensor partitioner...") return SingleSensorPartitioner(sensor_ref=cfg.CONF.sensor_ref) partition_provider_config = copy.copy(cfg.CONF.sensorcontainer.partition_provider) - partition_provider = partition_provider_config.pop('name') + partition_provider = partition_provider_config.pop("name") sensor_node_name = cfg.CONF.sensorcontainer.sensor_node_name provider = PROVIDERS.get(partition_provider.lower(), None) if not provider: - raise SensorPartitionerNotSupportedException('Partition provider %s not found.' % - (partition_provider)) + raise SensorPartitionerNotSupportedException( + "Partition provider %s not found." % (partition_provider) + ) - LOG.info('Using partitioner %s with sensornode %s.', partition_provider, sensor_node_name) + LOG.info( + "Using partitioner %s with sensornode %s.", partition_provider, sensor_node_name + ) # pass in extra config with no analysis return provider(sensor_node_name=sensor_node_name, **partition_provider_config) diff --git a/st2reactor/st2reactor/container/partitioners.py b/st2reactor/st2reactor/container/partitioners.py index 12a17f9081..02a6d6137b 100644 --- a/st2reactor/st2reactor/container/partitioners.py +++ b/st2reactor/st2reactor/container/partitioners.py @@ -18,18 +18,20 @@ import yaml from st2common import log as logging -from st2common.exceptions.sensors import SensorNotFoundException, \ - SensorPartitionMapMissingException +from st2common.exceptions.sensors import ( + SensorNotFoundException, + SensorPartitionMapMissingException, +) from st2common.persistence.keyvalue import KeyValuePair from st2common.persistence.sensor import SensorType __all__ = [ - 'get_all_enabled_sensors', - 'DefaultPartitioner', - 'KVStorePartitioner', - 'FileBasedPartitioner', - 'SingleSensorPartitioner' + "get_all_enabled_sensors", + "DefaultPartitioner", + "KVStorePartitioner", + "FileBasedPartitioner", + "SingleSensorPartitioner", ] LOG = logging.getLogger(__name__) @@ -38,12 +40,11 @@ def get_all_enabled_sensors(): # only query for enabled sensors. sensors = SensorType.query(enabled=True) - LOG.info('Found %d registered sensors in db scan.', len(sensors)) + LOG.info("Found %d registered sensors in db scan.", len(sensors)) return sensors class DefaultPartitioner(object): - def __init__(self, sensor_node_name): self.sensor_node_name = sensor_node_name @@ -78,7 +79,6 @@ def get_required_sensor_refs(self): class KVStorePartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name): super(KVStorePartitioner, self).__init__(sensor_node_name=sensor_node_name) self._supported_sensor_refs = None @@ -90,46 +90,51 @@ def get_required_sensor_refs(self): partition_lookup_key = self._get_partition_lookup_key(self.sensor_node_name) kvp = KeyValuePair.get_by_name(partition_lookup_key) - sensor_refs_str = kvp.value if kvp.value else '' - self._supported_sensor_refs = set([ - sensor_ref.strip() for sensor_ref in sensor_refs_str.split(',')]) + sensor_refs_str = kvp.value if kvp.value else "" + self._supported_sensor_refs = set( + [sensor_ref.strip() for sensor_ref in sensor_refs_str.split(",")] + ) return list(self._supported_sensor_refs) def _get_partition_lookup_key(self, sensor_node_name): - return '{}.sensor_partition'.format(sensor_node_name) + return "{}.sensor_partition".format(sensor_node_name) class FileBasedPartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name, partition_file): super(FileBasedPartitioner, self).__init__(sensor_node_name=sensor_node_name) self.partition_file = partition_file self._supported_sensor_refs = None def is_sensor_owner(self, sensor_db): - return sensor_db.get_reference().ref in self._supported_sensor_refs and sensor_db.enabled + return ( + sensor_db.get_reference().ref in self._supported_sensor_refs + and sensor_db.enabled + ) def get_required_sensor_refs(self): - with open(self.partition_file, 'r') as f: + with open(self.partition_file, "r") as f: partition_map = yaml.safe_load(f) sensor_refs = partition_map.get(self.sensor_node_name, None) if sensor_refs is None: - raise SensorPartitionMapMissingException('Sensor partition not found for %s in %s.' - % (self.sensor_node_name, - self.partition_file)) + raise SensorPartitionMapMissingException( + "Sensor partition not found for %s in %s." + % (self.sensor_node_name, self.partition_file) + ) self._supported_sensor_refs = set(sensor_refs) return list(self._supported_sensor_refs) class SingleSensorPartitioner(object): - def __init__(self, sensor_ref): self._sensor_ref = sensor_ref def get_sensors(self): sensor = SensorType.get_by_ref(self._sensor_ref) if not sensor: - raise SensorNotFoundException('Sensor %s not found in db.' % self._sensor_ref) + raise SensorNotFoundException( + "Sensor %s not found in db." % self._sensor_ref + ) return [sensor] def is_sensor_owner(self, sensor_db): diff --git a/st2reactor/st2reactor/container/process_container.py b/st2reactor/st2reactor/container/process_container.py index f8f1638d71..890bcccbb9 100644 --- a/st2reactor/st2reactor/container/process_container.py +++ b/st2reactor/st2reactor/container/process_container.py @@ -31,7 +31,7 @@ from st2common.constants.error_messages import PACK_VIRTUALENV_DOESNT_EXIST from st2common.constants.system import API_URL_ENV_VARIABLE_NAME from st2common.constants.system import AUTH_TOKEN_ENV_VARIABLE_NAME -from st2common.constants.triggers import (SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER) +from st2common.constants.triggers import SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER from st2common.constants.exit_codes import SUCCESS_EXIT_CODE from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.models.system.common import ResourceReference @@ -44,14 +44,12 @@ from st2common.util.sandboxing import get_sandbox_python_binary_path from st2common.util.sandboxing import get_sandbox_virtualenv_path -__all__ = [ - 'ProcessSensorContainer' -] +__all__ = ["ProcessSensorContainer"] -LOG = logging.getLogger('st2reactor.process_sensor_container') +LOG = logging.getLogger("st2reactor.process_sensor_container") BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_NAME = 'sensor_wrapper.py' +WRAPPER_SCRIPT_NAME = "sensor_wrapper.py" WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME) # How many times to try to subsequently respawn a sensor after a non-zero exit before giving up @@ -78,8 +76,15 @@ class ProcessSensorContainer(object): Sensor container which runs sensors in a separate process. """ - def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatcher=None, - wrapper_script_path=WRAPPER_SCRIPT_PATH, create_token=True): + def __init__( + self, + sensors, + poll_interval=5, + single_sensor_mode=False, + dispatcher=None, + wrapper_script_path=WRAPPER_SCRIPT_PATH, + create_token=True, + ): """ :param sensors: A list of sensor dicts. :type sensors: ``list`` of ``dict`` @@ -119,7 +124,9 @@ def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatche # Stores information needed for respawning dead sensors self._sensor_start_times = {} # maps sensor_id -> sensor start time - self._sensor_respawn_counts = defaultdict(int) # maps sensor_id -> number of respawns + self._sensor_respawn_counts = defaultdict( + int + ) # maps sensor_id -> number of respawns # A list of all the instance variables which hold internal state information about a # particular_sensor @@ -144,10 +151,10 @@ def run(self): sensor_ids = list(self._sensors.keys()) if len(sensor_ids) >= 1: - LOG.debug('%d active sensor(s)' % (len(sensor_ids))) + LOG.debug("%d active sensor(s)" % (len(sensor_ids))) self._poll_sensors_for_results(sensor_ids) else: - LOG.debug('No active sensors') + LOG.debug("No active sensors") concurrency.sleep(self._poll_interval) except success_exception_cls: @@ -157,12 +164,12 @@ def run(self): self._stopped = True return SUCCESS_EXIT_CODE except: - LOG.exception('Container failed to run sensors.') + LOG.exception("Container failed to run sensors.") self._stopped = True return FAILURE_EXIT_CODE self._stopped = True - LOG.error('Process container stopped.') + LOG.error("Process container stopped.") exit_code = self._exit_code or SUCCESS_EXIT_CODE return exit_code @@ -179,23 +186,29 @@ def _poll_sensors_for_results(self, sensor_ids): if status is not None: # Dead process detected - LOG.info('Process for sensor %s has exited with code %s', sensor_id, status) + LOG.info( + "Process for sensor %s has exited with code %s", sensor_id, status + ) sensor = self._sensors[sensor_id] self._delete_sensor(sensor_id) - self._dispatch_trigger_for_sensor_exit(sensor=sensor, - exit_code=status) + self._dispatch_trigger_for_sensor_exit(sensor=sensor, exit_code=status) # Try to respawn a dead process (maybe it was a simple failure which can be # resolved with a restart) - concurrency.spawn(self._respawn_sensor, sensor_id=sensor_id, sensor=sensor, - exit_code=status) + concurrency.spawn( + self._respawn_sensor, + sensor_id=sensor_id, + sensor=sensor, + exit_code=status, + ) else: sensor_start_time = self._sensor_start_times[sensor_id] sensor_respawn_count = self._sensor_respawn_counts[sensor_id] - successfully_started = ((now - sensor_start_time) >= - SENSOR_SUCCESSFUL_START_THRESHOLD) + successfully_started = ( + now - sensor_start_time + ) >= SENSOR_SUCCESSFUL_START_THRESHOLD if successfully_started and sensor_respawn_count >= 1: # Sensor has been successfully running more than threshold seconds, clear the @@ -209,7 +222,7 @@ def stopped(self): return self._stopped def shutdown(self, force=False): - LOG.info('Container shutting down. Invoking cleanup on sensors.') + LOG.info("Container shutting down. Invoking cleanup on sensors.") self._stopped = True if force: @@ -221,7 +234,7 @@ def shutdown(self, force=False): for sensor_id in sensor_ids: self._stop_sensor_process(sensor_id=sensor_id, exit_timeout=exit_timeout) - LOG.info('All sensors are shut down.') + LOG.info("All sensors are shut down.") self._sensors = {} self._processes = {} @@ -235,11 +248,11 @@ def add_sensor(self, sensor): sensor_id = self._get_sensor_id(sensor=sensor) if sensor_id in self._sensors: - LOG.warning('Sensor %s already exists and running.', sensor_id) + LOG.warning("Sensor %s already exists and running.", sensor_id) return False self._spawn_sensor_process(sensor=sensor) - LOG.debug('Sensor %s started.', sensor_id) + LOG.debug("Sensor %s started.", sensor_id) self._sensors[sensor_id] = sensor return True @@ -252,11 +265,11 @@ def remove_sensor(self, sensor): sensor_id = self._get_sensor_id(sensor=sensor) if sensor_id not in self._sensors: - LOG.warning('Sensor %s isn\'t running in this container.', sensor_id) + LOG.warning("Sensor %s isn't running in this container.", sensor_id) return False self._stop_sensor_process(sensor_id=sensor_id) - LOG.debug('Sensor %s stopped.', sensor_id) + LOG.debug("Sensor %s stopped.", sensor_id) return True def _run_all_sensors(self): @@ -264,7 +277,7 @@ def _run_all_sensors(self): for sensor_id in sensor_ids: sensor_obj = self._sensors[sensor_id] - LOG.info('Running sensor %s', sensor_id) + LOG.info("Running sensor %s", sensor_id) try: self._spawn_sensor_process(sensor=sensor_obj) @@ -275,7 +288,7 @@ def _run_all_sensors(self): del self._sensors[sensor_id] continue - LOG.info('Sensor %s started' % sensor_id) + LOG.info("Sensor %s started" % sensor_id) def _spawn_sensor_process(self, sensor): """ @@ -285,45 +298,53 @@ def _spawn_sensor_process(self, sensor): belonging to the sensor pack. """ sensor_id = self._get_sensor_id(sensor=sensor) - pack_ref = sensor['pack'] + pack_ref = sensor["pack"] virtualenv_path = get_sandbox_virtualenv_path(pack=pack_ref) python_path = get_sandbox_python_binary_path(pack=pack_ref) if virtualenv_path and not os.path.isdir(virtualenv_path): - format_values = {'pack': sensor['pack'], 'virtualenv_path': virtualenv_path} + format_values = {"pack": sensor["pack"], "virtualenv_path": virtualenv_path} msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values raise Exception(msg) - args = self._get_args_for_wrapper_script(python_binary=python_path, sensor=sensor) + args = self._get_args_for_wrapper_script( + python_binary=python_path, sensor=sensor + ) if self._enable_common_pack_libs: - pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref) + pack_common_libs_path = get_pack_common_libs_path_for_pack_ref( + pack_ref=pack_ref + ) else: pack_common_libs_path = None env = os.environ.copy() - sandbox_python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=True) + sandbox_python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=True + ) if self._enable_common_pack_libs and pack_common_libs_path: - env['PYTHONPATH'] = pack_common_libs_path + ':' + sandbox_python_path + env["PYTHONPATH"] = pack_common_libs_path + ":" + sandbox_python_path else: - env['PYTHONPATH'] = sandbox_python_path + env["PYTHONPATH"] = sandbox_python_path if self._create_token: # Include full api URL and API token specific to that sensor - LOG.debug('Creating temporary auth token for sensor %s' % (sensor['class_name'])) + LOG.debug( + "Creating temporary auth token for sensor %s" % (sensor["class_name"]) + ) ttl = cfg.CONF.auth.service_token_ttl metadata = { - 'service': 'sensors_container', - 'sensor_path': sensor['file_path'], - 'sensor_class': sensor['class_name'] + "service": "sensors_container", + "sensor_path": sensor["file_path"], + "sensor_class": sensor["class_name"], } - temporary_token = create_token(username='sensors_container', ttl=ttl, metadata=metadata, - service=True) + temporary_token = create_token( + username="sensors_container", ttl=ttl, metadata=metadata, service=True + ) env[API_URL_ENV_VARIABLE_NAME] = get_full_public_api_url() env[AUTH_TOKEN_ENV_VARIABLE_NAME] = temporary_token.token @@ -332,18 +353,27 @@ def _spawn_sensor_process(self, sensor): # TODO 2: Store metadata (wrapper process id) with the token and delete # tokens for old, dead processes on startup - cmd = ' '.join(args) + cmd = " ".join(args) LOG.debug('Running sensor subprocess (cmd="%s")', cmd) # TODO: Intercept stdout and stderr for aggregated logging purposes try: - process = subprocess.Popen(args=args, stdin=None, stdout=None, - stderr=None, shell=False, env=env, - preexec_fn=on_parent_exit('SIGTERM')) + process = subprocess.Popen( + args=args, + stdin=None, + stdout=None, + stderr=None, + shell=False, + env=env, + preexec_fn=on_parent_exit("SIGTERM"), + ) except Exception as e: - cmd = ' '.join(args) - message = ('Failed to spawn process for sensor %s ("%s"): %s' % - (sensor_id, cmd, six.text_type(e))) + cmd = " ".join(args) + message = 'Failed to spawn process for sensor %s ("%s"): %s' % ( + sensor_id, + cmd, + six.text_type(e), + ) raise Exception(message) self._processes[sensor_id] = process @@ -397,32 +427,35 @@ def _respawn_sensor(self, sensor_id, sensor, exit_code): """ Method for respawning a sensor which died with a non-zero exit code. """ - extra = {'sensor_id': sensor_id, 'sensor': sensor} + extra = {"sensor_id": sensor_id, "sensor": sensor} if self._single_sensor_mode: # In single sensor mode we want to exit immediately on failure - LOG.info('Not respawning a sensor since running in single sensor mode', - extra=extra) + LOG.info( + "Not respawning a sensor since running in single sensor mode", + extra=extra, + ) self._stopped = True self._exit_code = exit_code return if self._stopped: - LOG.debug('Stopped, not respawning a dead sensor', extra=extra) + LOG.debug("Stopped, not respawning a dead sensor", extra=extra) return - should_respawn = self._should_respawn_sensor(sensor_id=sensor_id, sensor=sensor, - exit_code=exit_code) + should_respawn = self._should_respawn_sensor( + sensor_id=sensor_id, sensor=sensor, exit_code=exit_code + ) if not should_respawn: - LOG.debug('Not respawning a dead sensor', extra=extra) + LOG.debug("Not respawning a dead sensor", extra=extra) return - LOG.debug('Respawning dead sensor', extra=extra) + LOG.debug("Respawning dead sensor", extra=extra) self._sensor_respawn_counts[sensor_id] += 1 - sleep_delay = (SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id]) + sleep_delay = SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id] concurrency.sleep(sleep_delay) try: @@ -443,7 +476,7 @@ def _should_respawn_sensor(self, sensor_id, sensor, exit_code): respawn_count = self._sensor_respawn_counts[sensor_id] if respawn_count >= SENSOR_MAX_RESPAWN_COUNTS: - LOG.debug('Sensor has already been respawned max times, giving up') + LOG.debug("Sensor has already been respawned max times, giving up") return False return True @@ -460,23 +493,23 @@ def _get_args_for_wrapper_script(self, python_binary, sensor): :rtype: ``list`` """ - trigger_type_refs = sensor['trigger_types'] or [] - trigger_type_refs = ','.join(trigger_type_refs) + trigger_type_refs = sensor["trigger_types"] or [] + trigger_type_refs = ",".join(trigger_type_refs) parent_args = json.dumps(sys.argv[1:]) args = [ python_binary, self._wrapper_script_path, - '--pack=%s' % (sensor['pack']), - '--file-path=%s' % (sensor['file_path']), - '--class-name=%s' % (sensor['class_name']), - '--trigger-type-refs=%s' % (trigger_type_refs), - '--parent-args=%s' % (parent_args) + "--pack=%s" % (sensor["pack"]), + "--file-path=%s" % (sensor["file_path"]), + "--class-name=%s" % (sensor["class_name"]), + "--trigger-type-refs=%s" % (trigger_type_refs), + "--parent-args=%s" % (parent_args), ] - if sensor['poll_interval']: - args.append('--poll-interval=%s' % (sensor['poll_interval'])) + if sensor["poll_interval"]: + args.append("--poll-interval=%s" % (sensor["poll_interval"])) return args @@ -486,32 +519,28 @@ def _get_sensor_id(self, sensor): :type sensor: ``dict`` """ - sensor_id = sensor['ref'] + sensor_id = sensor["ref"] return sensor_id def _dispatch_trigger_for_sensor_spawn(self, sensor, process, cmd): trigger = ResourceReference.to_string_reference( - name=SENSOR_SPAWN_TRIGGER['name'], - pack=SENSOR_SPAWN_TRIGGER['pack']) + name=SENSOR_SPAWN_TRIGGER["name"], pack=SENSOR_SPAWN_TRIGGER["pack"] + ) now = int(time.time()) payload = { - 'id': sensor['class_name'], - 'timestamp': now, - 'pid': process.pid, - 'cmd': cmd + "id": sensor["class_name"], + "timestamp": now, + "pid": process.pid, + "cmd": cmd, } self._dispatcher.dispatch(trigger, payload=payload) def _dispatch_trigger_for_sensor_exit(self, sensor, exit_code): trigger = ResourceReference.to_string_reference( - name=SENSOR_EXIT_TRIGGER['name'], - pack=SENSOR_EXIT_TRIGGER['pack']) + name=SENSOR_EXIT_TRIGGER["name"], pack=SENSOR_EXIT_TRIGGER["pack"] + ) now = int(time.time()) - payload = { - 'id': sensor['class_name'], - 'timestamp': now, - 'exit_code': exit_code - } + payload = {"id": sensor["class_name"], "timestamp": now, "exit_code": exit_code} self._dispatcher.dispatch(trigger, payload=payload) def _delete_sensor(self, sensor_id): diff --git a/st2reactor/st2reactor/container/sensor_wrapper.py b/st2reactor/st2reactor/container/sensor_wrapper.py index 56a37707d2..c605b47291 100644 --- a/st2reactor/st2reactor/container/sensor_wrapper.py +++ b/st2reactor/st2reactor/container/sensor_wrapper.py @@ -25,6 +25,7 @@ # for details. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -51,10 +52,7 @@ from st2common.services.datastore import SensorDatastoreService from st2common.util.monkey_patch import use_select_poll_workaround -__all__ = [ - 'SensorWrapper', - 'SensorService' -] +__all__ = ["SensorWrapper", "SensorService"] use_select_poll_workaround(nose_only=False) @@ -69,12 +67,15 @@ def __init__(self, sensor_wrapper): self._sensor_wrapper = sensor_wrapper self._logger = self._sensor_wrapper._logger - self._trigger_dispatcher_service = TriggerDispatcherService(logger=sensor_wrapper._logger) + self._trigger_dispatcher_service = TriggerDispatcherService( + logger=sensor_wrapper._logger + ) self._datastore_service = SensorDatastoreService( logger=self._logger, pack_name=self._sensor_wrapper._pack, class_name=self._sensor_wrapper._class_name, - api_username='sensor_service') + api_username="sensor_service", + ) self._client = None @@ -86,7 +87,7 @@ def get_logger(self, name): """ Retrieve an instance of a logger to be used by the sensor class. """ - logger_name = '%s.%s' % (self._sensor_wrapper._logger.name, name) + logger_name = "%s.%s" % (self._sensor_wrapper._logger.name, name) logger = logging.getLogger(logger_name) logger.propagate = True @@ -105,9 +106,12 @@ def get_user_info(self): def dispatch(self, trigger, payload=None, trace_tag=None): # Provided by the parent BaseTriggerDispatcherService class - return self._trigger_dispatcher_service.dispatch(trigger=trigger, payload=payload, - trace_tag=trace_tag, - throw_on_validation_error=False) + return self._trigger_dispatcher_service.dispatch( + trigger=trigger, + payload=payload, + trace_tag=trace_tag, + throw_on_validation_error=False, + ) def dispatch_with_context(self, trigger, payload=None, trace_context=None): """ @@ -123,10 +127,12 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None): :type trace_context: ``st2common.api.models.api.trace.TraceContext`` """ # Provided by the parent BaseTriggerDispatcherService class - return self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger, + return self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger, payload=payload, trace_context=trace_context, - throw_on_validation_error=False) + throw_on_validation_error=False, + ) ################################## # Methods for datastore management @@ -136,20 +142,31 @@ def list_values(self, local=True, prefix=None): return self.datastore_service.list_values(local=local, prefix=prefix) def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): - return self.datastore_service.get_value(name=name, local=local, scope=scope, - decrypt=decrypt) + return self.datastore_service.get_value( + name=name, local=local, scope=scope, decrypt=decrypt + ) - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): - return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local, - scope=scope, encrypt=encrypt) + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): + return self.datastore_service.set_value( + name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt + ) def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): return self.datastore_service.delete_value(name=name, local=local, scope=scope) class SensorWrapper(object): - def __init__(self, pack, file_path, class_name, trigger_types, - poll_interval=None, parent_args=None): + def __init__( + self, + pack, + file_path, + class_name, + trigger_types, + poll_interval=None, + parent_args=None, + ): """ :param pack: Name of the pack this sensor belongs to. :type pack: ``str`` @@ -185,32 +202,48 @@ def __init__(self, pack, file_path, class_name, trigger_types, pass # 2. Establish DB connection - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None - db_setup_with_retry(cfg.CONF.database.db_name, cfg.CONF.database.host, - cfg.CONF.database.port, username=username, password=password, - ssl=cfg.CONF.database.ssl, ssl_keyfile=cfg.CONF.database.ssl_keyfile, - ssl_certfile=cfg.CONF.database.ssl_certfile, - ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, - authentication_mechanism=cfg.CONF.database.authentication_mechanism, - ssl_match_hostname=cfg.CONF.database.ssl_match_hostname) + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) + db_setup_with_retry( + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ssl=cfg.CONF.database.ssl, + ssl_keyfile=cfg.CONF.database.ssl_keyfile, + ssl_certfile=cfg.CONF.database.ssl_certfile, + ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, + authentication_mechanism=cfg.CONF.database.authentication_mechanism, + ssl_match_hostname=cfg.CONF.database.ssl_match_hostname, + ) # 3. Instantiate the watcher - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix='sensorwrapper_%s_%s' % - (self._pack, self._class_name), - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix="sensorwrapper_%s_%s" % (self._pack, self._class_name), + exclusive=True, + ) # 4. Set up logging - self._logger = logging.getLogger('SensorWrapper.%s.%s' % - (self._pack, self._class_name)) + self._logger = logging.getLogger( + "SensorWrapper.%s.%s" % (self._pack, self._class_name) + ) logging.setup(cfg.CONF.sensorcontainer.logging) - if '--debug' in parent_args: + if "--debug" in parent_args: set_log_level_for_all_loggers() else: # NOTE: statsd logger logs everything by default under INFO so we ignore those log @@ -223,16 +256,17 @@ def run(self): atexit.register(self.stop) self._trigger_watcher.start() - self._logger.info('Watcher started') + self._logger.info("Watcher started") - self._logger.info('Running sensor initialization code') + self._logger.info("Running sensor initialization code") self._sensor_instance.setup() if self._poll_interval: - message = ('Running sensor in active mode (poll interval=%ss)' % - (self._poll_interval)) + message = "Running sensor in active mode (poll interval=%ss)" % ( + self._poll_interval + ) else: - message = 'Running sensor in passive mode' + message = "Running sensor in passive mode" self._logger.info(message) @@ -240,18 +274,20 @@ def run(self): self._sensor_instance.run() except Exception as e: # Include traceback - msg = ('Sensor "%s" run method raised an exception: %s.' % - (self._class_name, six.text_type(e))) + msg = 'Sensor "%s" run method raised an exception: %s.' % ( + self._class_name, + six.text_type(e), + ) self._logger.warn(msg, exc_info=True) raise Exception(msg) def stop(self): # Stop watcher - self._logger.info('Stopping trigger watcher') + self._logger.info("Stopping trigger watcher") self._trigger_watcher.stop() # Run sensor cleanup code - self._logger.info('Invoking cleanup on sensor') + self._logger.info("Invoking cleanup on sensor") self._sensor_instance.cleanup() ############################################## @@ -259,16 +295,18 @@ def stop(self): ############################################## def _handle_create_trigger(self, trigger): - self._logger.debug('Calling sensor "add_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "add_trigger" method (trigger.type=%s)' % (trigger.type) + ) self._trigger_names[str(trigger.id)] = trigger trigger = self._sanitize_trigger(trigger=trigger) self._sensor_instance.add_trigger(trigger=trigger) def _handle_update_trigger(self, trigger): - self._logger.debug('Calling sensor "update_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "update_trigger" method (trigger.type=%s)' % (trigger.type) + ) self._trigger_names[str(trigger.id)] = trigger trigger = self._sanitize_trigger(trigger=trigger) @@ -279,8 +317,9 @@ def _handle_delete_trigger(self, trigger): if trigger_id not in self._trigger_names: return - self._logger.debug('Calling sensor "remove_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "remove_trigger" method (trigger.type=%s)' % (trigger.type) + ) del self._trigger_names[trigger_id] trigger = self._sanitize_trigger(trigger=trigger) @@ -294,35 +333,45 @@ def _get_sensor_instance(self): module_name, _ = os.path.splitext(filename) try: - sensor_class = loader.register_plugin_class(base_class=Sensor, - file_path=self._file_path, - class_name=self._class_name) + sensor_class = loader.register_plugin_class( + base_class=Sensor, + file_path=self._file_path, + class_name=self._class_name, + ) except Exception as e: tb_msg = traceback.format_exc() - msg = ('Failed to load sensor class from file "%s" (sensor file most likely doesn\'t ' - 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to load sensor class from file "%s" (sensor file most likely doesn\'t ' + "exist or contains invalid syntax): %s" + % (self._file_path, six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) if not sensor_class: - raise ValueError('Sensor module is missing a class with name "%s"' % - (self._class_name)) + raise ValueError( + 'Sensor module is missing a class with name "%s"' % (self._class_name) + ) sensor_class_kwargs = {} - sensor_class_kwargs['sensor_service'] = SensorService(sensor_wrapper=self) + sensor_class_kwargs["sensor_service"] = SensorService(sensor_wrapper=self) sensor_config = self._get_sensor_config() - sensor_class_kwargs['config'] = sensor_config + sensor_class_kwargs["config"] = sensor_config if self._poll_interval and issubclass(sensor_class, PollingSensor): - sensor_class_kwargs['poll_interval'] = self._poll_interval + sensor_class_kwargs["poll_interval"] = self._poll_interval try: sensor_instance = sensor_class(**sensor_class_kwargs) except Exception: - self._logger.exception('Failed to instantiate "%s" sensor class' % (self._class_name)) - raise Exception('Failed to instantiate "%s" sensor class' % (self._class_name)) + self._logger.exception( + 'Failed to instantiate "%s" sensor class' % (self._class_name) + ) + raise Exception( + 'Failed to instantiate "%s" sensor class' % (self._class_name) + ) return sensor_instance @@ -342,31 +391,43 @@ def _sanitize_trigger(self, trigger): return sanitized -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Sensor runner wrapper') - parser.add_argument('--pack', required=True, - help='Name of the pack this sensor belongs to') - parser.add_argument('--file-path', required=True, - help='Path to the sensor module') - parser.add_argument('--class-name', required=True, - help='Name of the sensor class') - parser.add_argument('--trigger-type-refs', required=False, - help='Comma delimited string of trigger type references') - parser.add_argument('--poll-interval', type=int, default=None, required=False, - help='Sensor poll interval') - parser.add_argument('--parent-args', required=False, - help='Command line arguments passed to the parent process') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sensor runner wrapper") + parser.add_argument( + "--pack", required=True, help="Name of the pack this sensor belongs to" + ) + parser.add_argument("--file-path", required=True, help="Path to the sensor module") + parser.add_argument("--class-name", required=True, help="Name of the sensor class") + parser.add_argument( + "--trigger-type-refs", + required=False, + help="Comma delimited string of trigger type references", + ) + parser.add_argument( + "--poll-interval", + type=int, + default=None, + required=False, + help="Sensor poll interval", + ) + parser.add_argument( + "--parent-args", + required=False, + help="Command line arguments passed to the parent process", + ) args = parser.parse_args() trigger_types = args.trigger_type_refs - trigger_types = trigger_types.split(',') if trigger_types else [] + trigger_types = trigger_types.split(",") if trigger_types else [] parent_args = json.loads(args.parent_args) if args.parent_args else [] assert isinstance(parent_args, list) - obj = SensorWrapper(pack=args.pack, - file_path=args.file_path, - class_name=args.class_name, - trigger_types=trigger_types, - poll_interval=args.poll_interval, - parent_args=parent_args) + obj = SensorWrapper( + pack=args.pack, + file_path=args.file_path, + class_name=args.class_name, + trigger_types=trigger_types, + poll_interval=args.poll_interval, + parent_args=parent_args, + ) obj.run() diff --git a/st2reactor/st2reactor/container/utils.py b/st2reactor/st2reactor/container/utils.py index a156d209b0..6b05904627 100644 --- a/st2reactor/st2reactor/container/utils.py +++ b/st2reactor/st2reactor/container/utils.py @@ -22,10 +22,12 @@ from st2common.persistence.trigger import TriggerInstance from st2common.services.triggers import get_trigger_db_by_ref_or_dict -LOG = logging.getLogger('st2reactor.sensor.container_utils') +LOG = logging.getLogger("st2reactor.sensor.container_utils") -def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigger=False): +def create_trigger_instance( + trigger, payload, occurrence_time, raise_on_no_trigger=False +): """ This creates a trigger instance object given trigger and payload. Trigger can be just a string reference (pack.name) or a ``dict`` containing 'id' or @@ -40,9 +42,9 @@ def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigg trigger_db = get_trigger_db_by_ref_or_dict(trigger=trigger) if not trigger_db: - LOG.debug('No trigger in db for %s', trigger) + LOG.debug("No trigger in db for %s", trigger) if raise_on_no_trigger: - raise StackStormDBObjectNotFoundError('Trigger not found for %s' % trigger) + raise StackStormDBObjectNotFoundError("Trigger not found for %s" % trigger) return None trigger_ref = trigger_db.get_reference().ref diff --git a/st2reactor/st2reactor/garbage_collector/base.py b/st2reactor/st2reactor/garbage_collector/base.py index 3261458677..bb963e8e51 100644 --- a/st2reactor/st2reactor/garbage_collector/base.py +++ b/st2reactor/st2reactor/garbage_collector/base.py @@ -42,16 +42,17 @@ from st2common.garbage_collection.inquiries import purge_inquiries from st2common.garbage_collection.trigger_instances import purge_trigger_instances -__all__ = [ - 'GarbageCollectorService' -] +__all__ = ["GarbageCollectorService"] LOG = logging.getLogger(__name__) class GarbageCollectorService(object): - def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL, - sleep_delay=DEFAULT_SLEEP_DELAY): + def __init__( + self, + collection_interval=DEFAULT_COLLECTION_INTERVAL, + sleep_delay=DEFAULT_SLEEP_DELAY, + ): """ :param collection_interval: How often to check database for old data and perform garbage collection. @@ -64,7 +65,9 @@ def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL, self._collection_interval = collection_interval self._action_executions_ttl = cfg.CONF.garbagecollector.action_executions_ttl - self._action_executions_output_ttl = cfg.CONF.garbagecollector.action_executions_output_ttl + self._action_executions_output_ttl = ( + cfg.CONF.garbagecollector.action_executions_output_ttl + ) self._trigger_instances_ttl = cfg.CONF.garbagecollector.trigger_instances_ttl self._purge_inquiries = cfg.CONF.garbagecollector.purge_inquiries self._workflow_execution_max_idle = cfg.CONF.workflow_engine.gc_max_idle_sec @@ -91,7 +94,7 @@ def run(self): self._running = False return SUCCESS_EXIT_CODE except Exception as e: - LOG.exception('Exception in the garbage collector: %s' % (six.text_type(e))) + LOG.exception("Exception in the garbage collector: %s" % (six.text_type(e))) self._running = False return FAILURE_EXIT_CODE @@ -101,7 +104,7 @@ def _register_signal_handlers(self): signal.signal(signal.SIGUSR2, self.handle_sigusr2) def handle_sigusr2(self, signal_number, stack_frame): - LOG.info('Forcing garbage collection...') + LOG.info("Forcing garbage collection...") self._perform_garbage_collection() def shutdown(self): @@ -111,61 +114,88 @@ def _main_loop(self): while self._running: self._perform_garbage_collection() - LOG.info('Sleeping for %s seconds before next garbage collection...' % - (self._collection_interval)) + LOG.info( + "Sleeping for %s seconds before next garbage collection..." + % (self._collection_interval) + ) concurrency.sleep(self._collection_interval) def _validate_ttl_values(self): """ Validate that a user has supplied reasonable TTL values. """ - if self._action_executions_ttl and self._action_executions_ttl < MINIMUM_TTL_DAYS: - raise ValueError('Minimum possible TTL for action_executions_ttl in days is %s' % - (MINIMUM_TTL_DAYS)) - - if self._trigger_instances_ttl and self._trigger_instances_ttl < MINIMUM_TTL_DAYS: - raise ValueError('Minimum possible TTL for trigger_instances_ttl in days is %s' % - (MINIMUM_TTL_DAYS)) - - if self._action_executions_output_ttl and \ - self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT: - raise ValueError(('Minimum possible TTL for action_executions_output_ttl in days ' - 'is %s') % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)) + if ( + self._action_executions_ttl + and self._action_executions_ttl < MINIMUM_TTL_DAYS + ): + raise ValueError( + "Minimum possible TTL for action_executions_ttl in days is %s" + % (MINIMUM_TTL_DAYS) + ) + + if ( + self._trigger_instances_ttl + and self._trigger_instances_ttl < MINIMUM_TTL_DAYS + ): + raise ValueError( + "Minimum possible TTL for trigger_instances_ttl in days is %s" + % (MINIMUM_TTL_DAYS) + ) + + if ( + self._action_executions_output_ttl + and self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT + ): + raise ValueError( + ( + "Minimum possible TTL for action_executions_output_ttl in days " + "is %s" + ) + % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT) + ) def _perform_garbage_collection(self): - LOG.info('Performing garbage collection...') + LOG.info("Performing garbage collection...") proc_message = "Performing garbage collection for %s." skip_message = "Skipping garbage collection for %s since it's not configured." # Note: We sleep for a bit between garbage collection of each object type to prevent busy # waiting - obj_type = 'action executions' - if self._action_executions_ttl and self._action_executions_ttl >= MINIMUM_TTL_DAYS: + obj_type = "action executions" + if ( + self._action_executions_ttl + and self._action_executions_ttl >= MINIMUM_TTL_DAYS + ): LOG.info(proc_message, obj_type) self._purge_action_executions() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'action executions output' - if self._action_executions_output_ttl and \ - self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT: + obj_type = "action executions output" + if ( + self._action_executions_output_ttl + and self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT + ): LOG.info(proc_message, obj_type) self._purge_action_executions_output() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'trigger instances' - if self._trigger_instances_ttl and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS: + obj_type = "trigger instances" + if ( + self._trigger_instances_ttl + and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS + ): LOG.info(proc_message, obj_type) self._purge_trigger_instances() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'inquiries' + obj_type = "inquiries" if self._purge_inquiries: LOG.info(proc_message, obj_type) self._timeout_inquiries() @@ -173,7 +203,7 @@ def _perform_garbage_collection(self): else: LOG.debug(skip_message, obj_type) - obj_type = 'orphaned workflow executions' + obj_type = "orphaned workflow executions" if self._workflow_execution_max_idle > 0: LOG.info(proc_message, obj_type) self._purge_orphaned_workflow_executions() @@ -187,41 +217,53 @@ def _purge_action_executions(self): the criteria defined in the config. """ utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._action_executions_ttl)) + timestamp = utc_now - datetime.timedelta(days=self._action_executions_ttl) # Another sanity check to make sure we don't delete new executions if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting action executions older than: %s' % (timestamp_str)) + LOG.info("Deleting action executions older than: %s" % (timestamp_str)) assert timestamp < utc_now try: purge_executions(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to delete executions: %s' % (six.text_type(e))) + LOG.exception("Failed to delete executions: %s" % (six.text_type(e))) return True def _purge_action_executions_output(self): utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._action_executions_output_ttl)) + timestamp = utc_now - datetime.timedelta( + days=self._action_executions_output_ttl + ) # Another sanity check to make sure we don't delete new objects - if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + if timestamp > ( + utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT) + ): + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting action executions output objects older than: %s' % (timestamp_str)) + LOG.info( + "Deleting action executions output objects older than: %s" % (timestamp_str) + ) assert timestamp < utc_now try: purge_execution_output_objects(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to delete execution output objects: %s' % (six.text_type(e))) + LOG.exception( + "Failed to delete execution output objects: %s" % (six.text_type(e)) + ) return True @@ -230,31 +272,32 @@ def _purge_trigger_instances(self): Purge trigger instances which match the criteria defined in the config. """ utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._trigger_instances_ttl)) + timestamp = utc_now - datetime.timedelta(days=self._trigger_instances_ttl) # Another sanity check to make sure we don't delete new executions if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting trigger instances older than: %s' % (timestamp_str)) + LOG.info("Deleting trigger instances older than: %s" % (timestamp_str)) assert timestamp < utc_now try: purge_trigger_instances(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to trigger instances: %s' % (six.text_type(e))) + LOG.exception("Failed to trigger instances: %s" % (six.text_type(e))) return True def _timeout_inquiries(self): - """Mark Inquiries as "timeout" that have exceeded their TTL - """ + """Mark Inquiries as "timeout" that have exceeded their TTL""" try: purge_inquiries(logger=LOG) except Exception as e: - LOG.exception('Failed to purge inquiries: %s' % (six.text_type(e))) + LOG.exception("Failed to purge inquiries: %s" % (six.text_type(e))) return True @@ -265,6 +308,8 @@ def _purge_orphaned_workflow_executions(self): try: purge_orphaned_workflow_executions(logger=LOG) except Exception as e: - LOG.exception('Failed to purge orphaned workflow executions: %s' % (six.text_type(e))) + LOG.exception( + "Failed to purge orphaned workflow executions: %s" % (six.text_type(e)) + ) return True diff --git a/st2reactor/st2reactor/garbage_collector/config.py b/st2reactor/st2reactor/garbage_collector/config.py index 19cf53362e..9a0faf0dec 100644 --- a/st2reactor/st2reactor/garbage_collector/config.py +++ b/st2reactor/st2reactor/garbage_collector/config.py @@ -29,8 +29,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -49,48 +52,62 @@ def _register_common_opts(): def _register_garbage_collector_opts(): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.garbagecollector.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.garbagecollector.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(logging_opts, group='garbagecollector') + CONF.register_opts(logging_opts, group="garbagecollector") common_opts = [ cfg.IntOpt( - 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL, - help='How often to check database for old data and perform garbage collection.'), + "collection_interval", + default=DEFAULT_COLLECTION_INTERVAL, + help="How often to check database for old data and perform garbage collection.", + ), cfg.FloatOpt( - 'sleep_delay', default=DEFAULT_SLEEP_DELAY, - help='How long to wait / sleep (in seconds) between ' - 'collection of different object types.') + "sleep_delay", + default=DEFAULT_SLEEP_DELAY, + help="How long to wait / sleep (in seconds) between " + "collection of different object types.", + ), ] - CONF.register_opts(common_opts, group='garbagecollector') + CONF.register_opts(common_opts, group="garbagecollector") ttl_opts = [ cfg.IntOpt( - 'action_executions_ttl', default=None, - help='Action executions and related objects (live actions, action output ' - 'objects) older than this value (days) will be automatically deleted.'), + "action_executions_ttl", + default=None, + help="Action executions and related objects (live actions, action output " + "objects) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'action_executions_output_ttl', default=7, - help='Action execution output objects (ones generated by action output ' - 'streaming) older than this value (days) will be automatically deleted.'), + "action_executions_output_ttl", + default=7, + help="Action execution output objects (ones generated by action output " + "streaming) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'trigger_instances_ttl', default=None, - help='Trigger instances older than this value (days) will be automatically deleted.') + "trigger_instances_ttl", + default=None, + help="Trigger instances older than this value (days) will be automatically deleted.", + ), ] - CONF.register_opts(ttl_opts, group='garbagecollector') + CONF.register_opts(ttl_opts, group="garbagecollector") inquiry_opts = [ cfg.BoolOpt( - 'purge_inquiries', default=False, - help='Set to True to perform garbage collection on Inquiries (based on ' - 'the TTL value per Inquiry)') + "purge_inquiries", + default=False, + help="Set to True to perform garbage collection on Inquiries (based on " + "the TTL value per Inquiry)", + ) ] - CONF.register_opts(inquiry_opts, group='garbagecollector') + CONF.register_opts(inquiry_opts, group="garbagecollector") register_opts() diff --git a/st2reactor/st2reactor/rules/config.py b/st2reactor/st2reactor/rules/config.py index 004c45b870..637ef4e457 100644 --- a/st2reactor/st2reactor/rules/config.py +++ b/st2reactor/st2reactor/rules/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,11 +50,13 @@ def _register_common_opts(): def _register_rules_engine_opts(): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.rulesengine.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.rulesengine.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(logging_opts, group='rulesengine') + CONF.register_opts(logging_opts, group="rulesengine") register_opts() diff --git a/st2reactor/st2reactor/rules/enforcer.py b/st2reactor/st2reactor/rules/enforcer.py index 594f157482..4d34b86ce2 100644 --- a/st2reactor/st2reactor/rules/enforcer.py +++ b/st2reactor/st2reactor/rules/enforcer.py @@ -40,15 +40,15 @@ from st2common.exceptions import param as param_exc from st2common.exceptions import apivalidation as validation_exc -__all__ = [ - 'RuleEnforcer' -] +__all__ = ["RuleEnforcer"] -LOG = logging.getLogger('st2reactor.ruleenforcement.enforce') +LOG = logging.getLogger("st2reactor.ruleenforcement.enforce") -EXEC_KICKED_OFF_STATES = [action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_REQUESTED] +EXEC_KICKED_OFF_STATES = [ + action_constants.LIVEACTION_STATUS_SCHEDULED, + action_constants.LIVEACTION_STATUS_REQUESTED, +] class RuleEnforcer(object): @@ -58,95 +58,117 @@ def __init__(self, trigger_instance, rule): def get_action_execution_context(self, action_db, trace_context=None): context = { - 'trigger_instance': reference.get_ref_from_model(self.trigger_instance), - 'rule': reference.get_ref_from_model(self.rule), - 'user': get_system_username(), - 'pack': action_db.pack, + "trigger_instance": reference.get_ref_from_model(self.trigger_instance), + "rule": reference.get_ref_from_model(self.rule), + "user": get_system_username(), + "pack": action_db.pack, } if trace_context is not None: context[TRACE_CONTEXT] = trace_context # Additional non-action / global context - additional_context = { - TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload - } + additional_context = {TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload} return context, additional_context - def get_resolved_parameters(self, action_db, runnertype_db, params, context=None, - additional_contexts=None): + def get_resolved_parameters( + self, action_db, runnertype_db, params, context=None, additional_contexts=None + ): resolved_params = param_utils.render_live_params( runner_parameters=runnertype_db.runner_parameters, action_parameters=action_db.parameters, params=params, action_context=context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) return resolved_params def enforce(self): - rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid} - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id), - rule=rule_spec) - extra = { - 'trigger_instance_db': self.trigger_instance, - 'rule_db': self.rule + rule_spec = { + "ref": self.rule.ref, + "id": str(self.rule.id), + "uid": self.rule.uid, } + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(self.trigger_instance.id), rule=rule_spec + ) + extra = {"trigger_instance_db": self.trigger_instance, "rule_db": self.rule} execution_db = None try: execution_db = self._do_enforce() # pylint: disable=no-member enforcement_db.execution_id = str(execution_db.id) enforcement_db.status = RULE_ENFORCEMENT_STATUS_SUCCEEDED - extra['execution_db'] = execution_db + extra["execution_db"] = execution_db except Exception as e: # Record the failure reason in the RuleEnforcement. enforcement_db.status = RULE_ENFORCEMENT_STATUS_FAILED enforcement_db.failure_reason = six.text_type(e) - LOG.exception('Failed kicking off execution for rule %s.', self.rule, extra=extra) + LOG.exception( + "Failed kicking off execution for rule %s.", self.rule, extra=extra + ) finally: self._update_enforcement(enforcement_db) # pylint: disable=no-member if not execution_db or execution_db.status not in EXEC_KICKED_OFF_STATES: - LOG.audit('Rule enforcement failed. Execution of Action %s failed. ' - 'TriggerInstance: %s and Rule: %s', - self.rule.action.ref, self.trigger_instance, self.rule, - extra=extra) + LOG.audit( + "Rule enforcement failed. Execution of Action %s failed. " + "TriggerInstance: %s and Rule: %s", + self.rule.action.ref, + self.trigger_instance, + self.rule, + extra=extra, + ) else: - LOG.audit('Rule enforced. Execution %s, TriggerInstance %s and Rule %s.', - execution_db, self.trigger_instance, self.rule, extra=extra) + LOG.audit( + "Rule enforced. Execution %s, TriggerInstance %s and Rule %s.", + execution_db, + self.trigger_instance, + self.rule, + extra=extra, + ) return execution_db def _do_enforce(self): # TODO: Refactor this to avoid additional lookup in cast_params - action_ref = self.rule.action['ref'] + action_ref = self.rule.action["ref"] # Verify action referenced in the rule exists in the database action_db = action_utils.get_action_by_ref(action_ref) if not action_db: raise ValueError('Action "%s" doesn\'t exist' % (action_ref)) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) params = self.rule.action.parameters - LOG.info('Invoking action %s for trigger_instance %s with params %s.', - self.rule.action.ref, self.trigger_instance.id, - json.dumps(params)) + LOG.info( + "Invoking action %s for trigger_instance %s with params %s.", + self.rule.action.ref, + self.trigger_instance.id, + json.dumps(params), + ) # update trace before invoking the action. trace_context = self._update_trace() - LOG.debug('Updated trace %s with rule %s.', trace_context, self.rule.id) + LOG.debug("Updated trace %s with rule %s.", trace_context, self.rule.id) context, additional_contexts = self.get_action_execution_context( - action_db=action_db, - trace_context=trace_context) + action_db=action_db, trace_context=trace_context + ) - return self._invoke_action(action_db=action_db, runnertype_db=runnertype_db, params=params, - context=context, - additional_contexts=additional_contexts) + return self._invoke_action( + action_db=action_db, + runnertype_db=runnertype_db, + params=params, + context=context, + additional_contexts=additional_contexts, + ) def _update_trace(self): """ @@ -154,9 +176,13 @@ def _update_trace(self): """ trace_db = None try: - trace_db = trace_service.get_trace_db_by_trigger_instance(self.trigger_instance) + trace_db = trace_service.get_trace_db_by_trigger_instance( + self.trigger_instance + ) except: - LOG.exception('No Trace found for TriggerInstance %s.', self.trigger_instance.id) + LOG.exception( + "No Trace found for TriggerInstance %s.", self.trigger_instance.id + ) return None # This would signify some sort of coding error so assert. @@ -165,19 +191,23 @@ def _update_trace(self): trace_db = trace_service.add_or_update_given_trace_db( trace_db=trace_db, rules=[ - trace_service.get_trace_component_for_rule(self.rule, self.trigger_instance) - ]) + trace_service.get_trace_component_for_rule( + self.rule, self.trigger_instance + ) + ], + ) return vars(TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag)) def _update_enforcement(self, enforcement_db): try: RuleEnforcement.add_or_update(enforcement_db) except: - extra = {'enforcement_db': enforcement_db} - LOG.exception('Failed writing enforcement model to db.', extra=extra) + extra = {"enforcement_db": enforcement_db} + LOG.exception("Failed writing enforcement model to db.", extra=extra) - def _invoke_action(self, action_db, runnertype_db, params, context=None, - additional_contexts=None): + def _invoke_action( + self, action_db, runnertype_db, params, context=None, additional_contexts=None + ): """ Schedule an action execution. @@ -189,9 +219,13 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, :rtype: :class:`LiveActionDB` on successful scheduling, None otherwise. """ action_ref = action_db.ref - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) - liveaction_db = LiveActionDB(action=action_ref, context=context, parameters=params) + liveaction_db = LiveActionDB( + action=action_ref, context=context, parameters=params + ) try: liveaction_db.parameters = self.get_resolved_parameters( @@ -199,7 +233,8 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, action_db=action_db, params=liveaction_db.parameters, context=liveaction_db.context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) except param_exc.ParamException as e: # We still need to create a request, so liveaction_db is assigned an ID liveaction_db, execution_db = action_service.create_request(liveaction_db) @@ -209,8 +244,11 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, action_service.update_status( liveaction=liveaction_db, new_status=action_constants.LIVEACTION_STATUS_FAILED, - result={'error': six.text_type(e), - 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": six.text_type(e), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) # Might be a good idea to return the actual ActionExecution rather than bubble up # the exception. diff --git a/st2reactor/st2reactor/rules/engine.py b/st2reactor/st2reactor/rules/engine.py index 1d50d01c9e..453a0457da 100644 --- a/st2reactor/st2reactor/rules/engine.py +++ b/st2reactor/st2reactor/rules/engine.py @@ -21,11 +21,9 @@ from st2reactor.rules.matcher import RulesMatcher from st2common.metrics.base import get_driver -LOG = logging.getLogger('st2reactor.rules.RulesEngine') +LOG = logging.getLogger("st2reactor.rules.RulesEngine") -__all__ = [ - 'RulesEngine' -] +__all__ = ["RulesEngine"] class RulesEngine(object): @@ -40,7 +38,10 @@ def handle_trigger_instance(self, trigger_instance): # Enforce the rules. self.enforce_rules(enforcers) else: - LOG.info('No matching rules found for trigger instance %s.', trigger_instance['id']) + LOG.info( + "No matching rules found for trigger instance %s.", + trigger_instance["id"], + ) def get_matching_rules_for_trigger(self, trigger_instance): trigger = trigger_instance.trigger @@ -48,23 +49,34 @@ def get_matching_rules_for_trigger(self, trigger_instance): trigger_db = get_trigger_db_by_ref(trigger_instance.trigger) if not trigger_db: - LOG.error('No matching trigger found in db for trigger instance %s.', trigger_instance) + LOG.error( + "No matching trigger found in db for trigger instance %s.", + trigger_instance, + ) return None rules = get_rules_given_trigger(trigger=trigger) - LOG.info('Found %d rules defined for trigger %s', len(rules), - trigger_db.get_reference().ref) + LOG.info( + "Found %d rules defined for trigger %s", + len(rules), + trigger_db.get_reference().ref, + ) if len(rules) < 1: return rules - matcher = RulesMatcher(trigger_instance=trigger_instance, - trigger=trigger_db, rules=rules) + matcher = RulesMatcher( + trigger_instance=trigger_instance, trigger=trigger_db, rules=rules + ) matching_rules = matcher.get_matching_rules() - LOG.info('Matched %s rule(s) for trigger_instance %s (trigger=%s)', len(matching_rules), - trigger_instance['id'], trigger_db.ref) + LOG.info( + "Matched %s rule(s) for trigger_instance %s (trigger=%s)", + len(matching_rules), + trigger_instance["id"], + trigger_db.ref, + ) return matching_rules def create_rule_enforcers(self, trigger_instance, matching_rules): @@ -78,8 +90,8 @@ def create_rule_enforcers(self, trigger_instance, matching_rules): enforcers = [] for matching_rule in matching_rules: - metrics_driver.inc_counter('rule.matched') - metrics_driver.inc_counter('rule.%s.matched' % (matching_rule.ref)) + metrics_driver.inc_counter("rule.matched") + metrics_driver.inc_counter("rule.%s.matched" % (matching_rule.ref)) enforcers.append(RuleEnforcer(trigger_instance, matching_rule)) return enforcers @@ -89,4 +101,4 @@ def enforce_rules(self, enforcers): try: enforcer.enforce() # Should this happen in an eventlet pool? except: - LOG.exception('Exception enforcing rule %s.', enforcer.rule) + LOG.exception("Exception enforcing rule %s.", enforcer.rule) diff --git a/st2reactor/st2reactor/rules/filter.py b/st2reactor/st2reactor/rules/filter.py index 1c67538198..700d072c31 100644 --- a/st2reactor/st2reactor/rules/filter.py +++ b/st2reactor/st2reactor/rules/filter.py @@ -31,12 +31,10 @@ from st2common.util.payload import PayloadLookup from st2common.util.templating import render_template_with_system_context -__all__ = [ - 'RuleFilter' -] +__all__ = ["RuleFilter"] -LOG = logging.getLogger('st2reactor.ruleenforcement.filter') +LOG = logging.getLogger("st2reactor.ruleenforcement.filter") class RuleFilter(object): @@ -58,9 +56,9 @@ def __init__(self, trigger_instance, trigger, rule, extra_info=False): # Base context used with a logger self._base_logger_context = { - 'rule': self.rule, - 'trigger': self.trigger, - 'trigger_instance': self.trigger_instance + "rule": self.rule, + "trigger": self.trigger, + "trigger_instance": self.trigger_instance, } def filter(self): @@ -69,12 +67,18 @@ def filter(self): :rtype: ``bool`` """ - LOG.info('Validating rule %s for %s.', self.rule.ref, self.trigger['name'], - extra=self._base_logger_context) + LOG.info( + "Validating rule %s for %s.", + self.rule.ref, + self.trigger["name"], + extra=self._base_logger_context, + ) if not self.rule.enabled: if self.extra_info: - LOG.info('Validation failed for rule %s as it is disabled.', self.rule.ref) + LOG.info( + "Validation failed for rule %s as it is disabled.", self.rule.ref + ) return False criteria = self.rule.criteria @@ -85,52 +89,66 @@ def filter(self): payload_lookup = PayloadLookup(self.trigger_instance.payload) - LOG.debug('Trigger payload: %s', self.trigger_instance.payload, - extra=self._base_logger_context) + LOG.debug( + "Trigger payload: %s", + self.trigger_instance.payload, + extra=self._base_logger_context, + ) for (criterion_k, criterion_v) in six.iteritems(criteria): - is_rule_applicable, payload_value, criterion_pattern = self._check_criterion( - criterion_k, - criterion_v, - payload_lookup - ) + ( + is_rule_applicable, + payload_value, + criterion_pattern, + ) = self._check_criterion(criterion_k, criterion_v, payload_lookup) if not is_rule_applicable: if self.extra_info: - criteria_extra_info = '\n'.join([ - ' key: %s' % criterion_k, - ' pattern: %s' % criterion_pattern, - ' type: %s' % criterion_v['type'], - ' payload: %s' % payload_value - ]) - LOG.info('Validation for rule %s failed on criteria -\n%s', self.rule.ref, - criteria_extra_info, - extra=self._base_logger_context) + criteria_extra_info = "\n".join( + [ + " key: %s" % criterion_k, + " pattern: %s" % criterion_pattern, + " type: %s" % criterion_v["type"], + " payload: %s" % payload_value, + ] + ) + LOG.info( + "Validation for rule %s failed on criteria -\n%s", + self.rule.ref, + criteria_extra_info, + extra=self._base_logger_context, + ) break if not is_rule_applicable: - LOG.debug('Rule %s not applicable for %s.', self.rule.id, self.trigger['name'], - extra=self._base_logger_context) + LOG.debug( + "Rule %s not applicable for %s.", + self.rule.id, + self.trigger["name"], + extra=self._base_logger_context, + ) return is_rule_applicable def _check_criterion(self, criterion_k, criterion_v, payload_lookup): - if 'type' not in criterion_v: + if "type" not in criterion_v: # Comparison operator type not specified, can't perform a comparison return (False, None, None) - criteria_operator = criterion_v['type'] - criteria_condition = criterion_v.get('condition', None) - criteria_pattern = criterion_v.get('pattern', None) + criteria_operator = criterion_v["type"] + criteria_condition = criterion_v.get("condition", None) + criteria_pattern = criterion_v.get("pattern", None) # Render the pattern (it can contain a jinja expressions) try: criteria_pattern = self._render_criteria_pattern( criteria_pattern=criteria_pattern, - criteria_context=payload_lookup.context + criteria_context=payload_lookup.context, ) except Exception as e: - msg = ('Failed to render pattern value "%s" for key "%s"' % (criteria_pattern, - criterion_k)) + msg = 'Failed to render pattern value "%s" for key "%s"' % ( + criteria_pattern, + criterion_k, + ) LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -144,7 +162,7 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup): else: payload_value = None except Exception as e: - msg = ('Failed transforming criteria key %s' % criterion_k) + msg = "Failed transforming criteria key %s" % criterion_k LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -154,13 +172,18 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup): try: if criteria_operator == criteria_operators.SEARCH: - result = op_func(value=payload_value, criteria_pattern=criteria_pattern, - criteria_condition=criteria_condition, - check_function=self._bool_criterion) + result = op_func( + value=payload_value, + criteria_pattern=criteria_pattern, + criteria_condition=criteria_condition, + check_function=self._bool_criterion, + ) else: result = op_func(value=payload_value, criteria_pattern=criteria_pattern) except Exception as e: - msg = ('There might be a problem with the criteria in rule %s' % (self.rule.ref)) + msg = "There might be a problem with the criteria in rule %s" % ( + self.rule.ref + ) LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -185,9 +208,9 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context): return criteria_pattern LOG.debug( - 'Rendering criteria pattern (%s) with context: %s', + "Rendering criteria pattern (%s) with context: %s", criteria_pattern, - criteria_context + criteria_context, ) to_complex = False @@ -197,30 +220,24 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context): if len(re.findall(MATCH_CRITERIA, criteria_pattern)) > 0: LOG.debug("Rendering Complex") complex_criteria_pattern = re.sub( - MATCH_CRITERIA, r'\1\2 | to_complex\3', - criteria_pattern + MATCH_CRITERIA, r"\1\2 | to_complex\3", criteria_pattern ) try: criteria_rendered = render_template_with_system_context( - value=complex_criteria_pattern, - context=criteria_context + value=complex_criteria_pattern, context=criteria_context ) criteria_rendered = json.loads(criteria_rendered) to_complex = True except ValueError as error: - LOG.debug('Criteria pattern not valid JSON: %s', error) + LOG.debug("Criteria pattern not valid JSON: %s", error) if not to_complex: criteria_rendered = render_template_with_system_context( - value=criteria_pattern, - context=criteria_context + value=criteria_pattern, context=criteria_context ) - LOG.debug( - 'Rendered criteria pattern: %s', - criteria_rendered - ) + LOG.debug("Rendered criteria pattern: %s", criteria_rendered) return criteria_rendered @@ -231,19 +248,32 @@ def _create_rule_enforcement(self, failure_reason, exc): Without that, only way for users to find out about those failes matches is by inspecting the logs. """ - failure_reason = ('Failed to match rule "%s" against trigger instance "%s": %s: %s' % - (self.rule.ref, str(self.trigger_instance.id), failure_reason, str(exc))) - rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid} - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id), - rule=rule_spec, - failure_reason=failure_reason, - status=RULE_ENFORCEMENT_STATUS_FAILED) + failure_reason = ( + 'Failed to match rule "%s" against trigger instance "%s": %s: %s' + % ( + self.rule.ref, + str(self.trigger_instance.id), + failure_reason, + str(exc), + ) + ) + rule_spec = { + "ref": self.rule.ref, + "id": str(self.rule.id), + "uid": self.rule.uid, + } + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(self.trigger_instance.id), + rule=rule_spec, + failure_reason=failure_reason, + status=RULE_ENFORCEMENT_STATUS_FAILED, + ) try: RuleEnforcement.add_or_update(enforcement_db) except: - extra = {'enforcement_db': enforcement_db} - LOG.exception('Failed writing enforcement model to db.', extra=extra) + extra = {"enforcement_db": enforcement_db} + LOG.exception("Failed writing enforcement model to db.", extra=extra) return enforcement_db @@ -253,6 +283,7 @@ class SecondPassRuleFilter(RuleFilter): Special filter that handles all second pass rules. For not these are only backstop rules i.e. those that can match when no other rule has matched. """ + def __init__(self, trigger_instance, trigger, rule, first_pass_matched): """ :param trigger_instance: TriggerInstance DB object. @@ -277,4 +308,4 @@ def filter(self): return super(SecondPassRuleFilter, self).filter() def _is_backstop_rule(self): - return self.rule.type['ref'] == RULE_TYPE_BACKSTOP + return self.rule.type["ref"] == RULE_TYPE_BACKSTOP diff --git a/st2reactor/st2reactor/rules/matcher.py b/st2reactor/st2reactor/rules/matcher.py index b2ed198945..4b3a8a2483 100644 --- a/st2reactor/st2reactor/rules/matcher.py +++ b/st2reactor/st2reactor/rules/matcher.py @@ -18,7 +18,7 @@ from st2common.constants.rules import RULE_TYPE_BACKSTOP from st2reactor.rules.filter import RuleFilter, SecondPassRuleFilter -LOG = logging.getLogger('st2reactor.rules.RulesMatcher') +LOG = logging.getLogger("st2reactor.rules.RulesMatcher") class RulesMatcher(object): @@ -31,25 +31,44 @@ def __init__(self, trigger_instance, trigger, rules, extra_info=False): def get_matching_rules(self): first_pass, second_pass = self._split_rules_into_passes() # first pass - rule_filters = [RuleFilter(trigger_instance=self.trigger_instance, - trigger=self.trigger, - rule=rule, - extra_info=self.extra_info) - for rule in first_pass] - matched_rules = [rule_filter.rule for rule_filter in rule_filters if rule_filter.filter()] - LOG.debug('[1st_pass] %d rule(s) found to enforce for %s.', len(matched_rules), - self.trigger['name']) + rule_filters = [ + RuleFilter( + trigger_instance=self.trigger_instance, + trigger=self.trigger, + rule=rule, + extra_info=self.extra_info, + ) + for rule in first_pass + ] + matched_rules = [ + rule_filter.rule for rule_filter in rule_filters if rule_filter.filter() + ] + LOG.debug( + "[1st_pass] %d rule(s) found to enforce for %s.", + len(matched_rules), + self.trigger["name"], + ) # second pass - rule_filters = [SecondPassRuleFilter(self.trigger_instance, self.trigger, rule, - matched_rules) - for rule in second_pass] - matched_in_second_pass = [rule_filter.rule for rule_filter in rule_filters - if rule_filter.filter()] - LOG.debug('[2nd_pass] %d rule(s) found to enforce for %s.', len(matched_in_second_pass), - self.trigger['name']) + rule_filters = [ + SecondPassRuleFilter( + self.trigger_instance, self.trigger, rule, matched_rules + ) + for rule in second_pass + ] + matched_in_second_pass = [ + rule_filter.rule for rule_filter in rule_filters if rule_filter.filter() + ] + LOG.debug( + "[2nd_pass] %d rule(s) found to enforce for %s.", + len(matched_in_second_pass), + self.trigger["name"], + ) matched_rules.extend(matched_in_second_pass) - LOG.info('%d rule(s) found to enforce for %s.', len(matched_rules), - self.trigger['name']) + LOG.info( + "%d rule(s) found to enforce for %s.", + len(matched_rules), + self.trigger["name"], + ) return matched_rules def _split_rules_into_passes(self): @@ -68,4 +87,4 @@ def _split_rules_into_passes(self): return first_pass, second_pass def _is_first_pass_rule(self, rule): - return rule.type['ref'] != RULE_TYPE_BACKSTOP + return rule.type["ref"] != RULE_TYPE_BACKSTOP diff --git a/st2reactor/st2reactor/rules/tester.py b/st2reactor/st2reactor/rules/tester.py index 790148d82d..da4e3572c5 100644 --- a/st2reactor/st2reactor/rules/tester.py +++ b/st2reactor/st2reactor/rules/tester.py @@ -32,16 +32,19 @@ from st2reactor.rules.enforcer import RuleEnforcer from st2reactor.rules.matcher import RulesMatcher -__all__ = [ - 'RuleTester' -] +__all__ = ["RuleTester"] LOG = logging.getLogger(__name__) class RuleTester(object): - def __init__(self, rule_file_path=None, rule_ref=None, trigger_instance_file_path=None, - trigger_instance_id=None): + def __init__( + self, + rule_file_path=None, + rule_ref=None, + trigger_instance_file_path=None, + trigger_instance_id=None, + ): """ :param rule_file_path: Path to the file containing rule definition. :type rule_file_path: ``str`` @@ -69,13 +72,20 @@ def evaluate(self): # The trigger check needs to be performed here as that is not performed # by RulesMatcher. if rule_db.trigger != trigger_db.ref: - LOG.info('rule.trigger "%s" and trigger.ref "%s" do not match.', - rule_db.trigger, trigger_db.ref) + LOG.info( + 'rule.trigger "%s" and trigger.ref "%s" do not match.', + rule_db.trigger, + trigger_db.ref, + ) return False # Check if rule matches criteria. - matcher = RulesMatcher(trigger_instance=trigger_instance_db, trigger=trigger_db, - rules=[rule_db], extra_info=True) + matcher = RulesMatcher( + trigger_instance=trigger_instance_db, + trigger=trigger_db, + rules=[rule_db], + extra_info=True, + ) matching_rules = matcher.get_matching_rules() # Rule does not match so early exit. @@ -91,69 +101,86 @@ def evaluate(self): action_db.parameters = {} params = rule_db.action.parameters # pylint: disable=no-member - context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db, - trace_context=None) + context, additional_contexts = enforcer.get_action_execution_context( + action_db=action_db, trace_context=None + ) # Note: We only return partially resolved parameters. # To be able to return all parameters we would need access to corresponding ActionDB, # RunnerTypeDB and ConfigDB object, but this would add a dependency on the database and the # tool is meant to be used standalone. try: - params = enforcer.get_resolved_parameters(action_db=action_db, - runnertype_db=runner_type_db, - params=params, - context=context, - additional_contexts=additional_contexts) - - LOG.info('Action parameters resolved to:') + params = enforcer.get_resolved_parameters( + action_db=action_db, + runnertype_db=runner_type_db, + params=params, + context=context, + additional_contexts=additional_contexts, + ) + + LOG.info("Action parameters resolved to:") for param in six.iteritems(params): - LOG.info('\t%s: %s', param[0], param[1]) + LOG.info("\t%s: %s", param[0], param[1]) return True except (UndefinedError, ValueError) as e: - LOG.error('Failed to resolve parameters\n\tOriginal error : %s', six.text_type(e)) + LOG.error( + "Failed to resolve parameters\n\tOriginal error : %s", six.text_type(e) + ) return False except: - LOG.exception('Failed to resolve parameters.') + LOG.exception("Failed to resolve parameters.") return False def _get_rule_db(self): if self._rule_file_path: return self._get_rule_db_from_file( - file_path=os.path.realpath(self._rule_file_path)) + file_path=os.path.realpath(self._rule_file_path) + ) elif self._rule_ref: return Rule.get_by_ref(self._rule_ref) - raise ValueError('One of _rule_file_path or _rule_ref should be specified.') + raise ValueError("One of _rule_file_path or _rule_ref should be specified.") def _get_trigger_instance_db(self): if self._trigger_instance_file_path: return self._get_trigger_instance_db_from_file( - file_path=os.path.realpath(self._trigger_instance_file_path)) + file_path=os.path.realpath(self._trigger_instance_file_path) + ) elif self._trigger_instance_id: trigger_instance_db = TriggerInstance.get_by_id(self._trigger_instance_id) trigger_db = Trigger.get_by_ref(trigger_instance_db.trigger) return trigger_instance_db, trigger_db - raise ValueError('One of _trigger_instance_file_path or' - '_trigger_instance_id should be specified.') + raise ValueError( + "One of _trigger_instance_file_path or" + "_trigger_instance_id should be specified." + ) def _get_rule_db_from_file(self, file_path): data = self._meta_loader.load(file_path=file_path) - pack = data.get('pack', 'unknown') - name = data.get('name', 'unknown') - trigger = data['trigger']['type'] - criteria = data.get('criteria', None) - action = data.get('action', {}) - - rule_db = RuleDB(pack=pack, name=name, trigger=trigger, criteria=criteria, action=action, - enabled=True) - rule_db.id = 'rule_tester_rule' + pack = data.get("pack", "unknown") + name = data.get("name", "unknown") + trigger = data["trigger"]["type"] + criteria = data.get("criteria", None) + action = data.get("action", {}) + + rule_db = RuleDB( + pack=pack, + name=name, + trigger=trigger, + criteria=criteria, + action=action, + enabled=True, + ) + rule_db.id = "rule_tester_rule" return rule_db def _get_trigger_instance_db_from_file(self, file_path): data = self._meta_loader.load(file_path=file_path) instance = TriggerInstanceDB(**data) - instance.id = 'rule_tester_instance' + instance.id = "rule_tester_instance" - trigger_ref = ResourceReference.from_string_reference(instance['trigger']) - trigger_db = TriggerDB(pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref) + trigger_ref = ResourceReference.from_string_reference(instance["trigger"]) + trigger_db = TriggerDB( + pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref + ) return instance, trigger_db diff --git a/st2reactor/st2reactor/rules/worker.py b/st2reactor/st2reactor/rules/worker.py index 7dbe4a59e1..53e636a346 100644 --- a/st2reactor/st2reactor/rules/worker.py +++ b/st2reactor/st2reactor/rules/worker.py @@ -41,12 +41,12 @@ def __init__(self, connection, queues): self.rules_engine = RulesEngine() def pre_ack_process(self, message): - ''' + """ TriggerInstance from message is create prior to acknowledging the message. This gets us a way to not acknowledge messages. - ''' - trigger = message['trigger'] - payload = message['payload'] + """ + trigger = message["trigger"] + payload = message["payload"] # Accomodate for not being able to create a TrigegrInstance if a TriggerDB # is not found. @@ -54,16 +54,19 @@ def pre_ack_process(self, message): trigger, payload or {}, date_utils.get_datetime_utc_now(), - raise_on_no_trigger=True) + raise_on_no_trigger=True, + ) return self._compose_pre_ack_process_response(trigger_instance, message) def process(self, pre_ack_response): - trigger_instance, message = self._decompose_pre_ack_process_response(pre_ack_response) + trigger_instance, message = self._decompose_pre_ack_process_response( + pre_ack_response + ) if not trigger_instance: - raise ValueError('No trigger_instance provided for processing.') + raise ValueError("No trigger_instance provided for processing.") - get_driver().inc_counter('trigger.%s.processed' % (trigger_instance.trigger)) + get_driver().inc_counter("trigger.%s.processed" % (trigger_instance.trigger)) try: # Use trace_context from the message and if not found create a new context @@ -71,34 +74,39 @@ def process(self, pre_ack_response): trace_context = message.get(TRACE_CONTEXT, None) if not trace_context: trace_context = { - TRACE_ID: 'trigger_instance-%s' % str(trigger_instance.id) + TRACE_ID: "trigger_instance-%s" % str(trigger_instance.id) } # add a trace or update an existing trace with trigger_instance trace_service.add_or_update_given_trace_context( trace_context=trace_context, trigger_instances=[ - trace_service.get_trace_component_for_trigger_instance(trigger_instance) - ] + trace_service.get_trace_component_for_trigger_instance( + trigger_instance + ) + ], ) container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING + ) - with CounterWithTimer(key='rule.processed'): - with Timer(key='trigger.%s.processed' % (trigger_instance.trigger)): + with CounterWithTimer(key="rule.processed"): + with Timer(key="trigger.%s.processed" % (trigger_instance.trigger)): self.rules_engine.handle_trigger_instance(trigger_instance) container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED + ) except: # TODO : Capture the reason for failure. container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED + ) # This could be a large message but at least in case of an exception # we get to see more context. # Beyond this point code cannot really handle the exception anyway so # eating up the exception. - LOG.exception('Failed to handle trigger_instance %s.', trigger_instance) + LOG.exception("Failed to handle trigger_instance %s.", trigger_instance) return @staticmethod @@ -106,14 +114,14 @@ def _compose_pre_ack_process_response(trigger_instance, message): """ Codify response of the pre_ack_process method. """ - return {'trigger_instance': trigger_instance, 'message': message} + return {"trigger_instance": trigger_instance, "message": message} @staticmethod def _decompose_pre_ack_process_response(response): """ Break-down response of pre_ack_process into constituents for simpler consumption. """ - return response.get('trigger_instance', None), response.get('message', None) + return response.get("trigger_instance", None), response.get("message", None) def get_worker(): diff --git a/st2reactor/st2reactor/sensor/base.py b/st2reactor/st2reactor/sensor/base.py index f7fce2460b..a8309ba292 100644 --- a/st2reactor/st2reactor/sensor/base.py +++ b/st2reactor/st2reactor/sensor/base.py @@ -21,10 +21,7 @@ from st2common.util import concurrency -__all__ = [ - 'Sensor', - 'PollingSensor' -] +__all__ = ["Sensor", "PollingSensor"] @six.add_metaclass(abc.ABCMeta) @@ -107,7 +104,9 @@ class PollingSensor(BaseSensor): """ def __init__(self, sensor_service, config=None, poll_interval=5): - super(PollingSensor, self).__init__(sensor_service=sensor_service, config=config) + super(PollingSensor, self).__init__( + sensor_service=sensor_service, config=config + ) self._poll_interval = poll_interval @abc.abstractmethod diff --git a/st2reactor/st2reactor/sensor/config.py b/st2reactor/st2reactor/sensor/config.py index 981ddd9b8f..8126bdbc9f 100644 --- a/st2reactor/st2reactor/sensor/config.py +++ b/st2reactor/st2reactor/sensor/config.py @@ -26,8 +26,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(ignore_errors=False): @@ -46,48 +49,62 @@ def _register_common_opts(ignore_errors=False): def _register_sensor_container_opts(ignore_errors=False): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.sensorcontainer.conf', - help='location of the logging.conf file') + "logging", + default="/etc/st2/logging.sensorcontainer.conf", + help="location of the logging.conf file", + ) ] - st2cfg.do_register_opts(logging_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + logging_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # Partitioning options partition_opts = [ cfg.StrOpt( - 'sensor_node_name', default='sensornode1', - help='name of the sensor node.'), + "sensor_node_name", default="sensornode1", help="name of the sensor node." + ), cfg.Opt( - 'partition_provider', + "partition_provider", type=types.Dict(value_type=types.String()), - default={'name': DEFAULT_PARTITION_LOADER}, - help='Provider of sensor node partition config.') + default={"name": DEFAULT_PARTITION_LOADER}, + help="Provider of sensor node partition config.", + ), ] - st2cfg.do_register_opts(partition_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + partition_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # Other options other_opts = [ cfg.BoolOpt( - 'single_sensor_mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single_sensor_mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ) ] - st2cfg.do_register_opts(other_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + other_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # CLI options cli_opts = [ cfg.StrOpt( - 'sensor-ref', - help='Only run sensor with the provided reference. Value is of the form ' - '. (e.g. linux.FileWatchSensor).'), + "sensor-ref", + help="Only run sensor with the provided reference. Value is of the form " + ". (e.g. linux.FileWatchSensor).", + ), cfg.BoolOpt( - 'single-sensor-mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single-sensor-mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ), ] st2cfg.do_register_cli_opts(cli_opts, ignore_errors=ignore_errors) diff --git a/st2reactor/st2reactor/timer/base.py b/st2reactor/st2reactor/timer/base.py index ed99d90e77..723d362066 100644 --- a/st2reactor/st2reactor/timer/base.py +++ b/st2reactor/st2reactor/timer/base.py @@ -41,17 +41,20 @@ class St2Timer(object): """ A timer interface that uses APScheduler 3.0. """ + def __init__(self, local_timezone=None): self._timezone = local_timezone self._scheduler = BlockingScheduler(timezone=self._timezone) self._jobs = {} self._trigger_types = list(TIMER_TRIGGER_TYPES.keys()) - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=self.__class__.__name__, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=self.__class__.__name__, + exclusive=True, + ) self._trigger_dispatcher = TriggerDispatcher(LOG) def start(self): @@ -70,89 +73,109 @@ def update_trigger(self, trigger): self.add_trigger(trigger) def remove_trigger(self, trigger): - trigger_id = trigger['id'] + trigger_id = trigger["id"] try: job_id = self._jobs[trigger_id] except KeyError: - LOG.info('Job not found: %s', trigger_id) + LOG.info("Job not found: %s", trigger_id) return self._scheduler.remove_job(job_id) del self._jobs[trigger_id] def _add_job_to_scheduler(self, trigger): - trigger_type_ref = trigger['type'] + trigger_type_ref = trigger["type"] trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref] try: - util_schema.validate(instance=trigger['parameters'], - schema=trigger_type['parameters_schema'], - cls=util_schema.CustomValidator, - use_default=True, - allow_default_none=True) + util_schema.validate( + instance=trigger["parameters"], + schema=trigger_type["parameters_schema"], + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) except jsonschema.ValidationError as e: - LOG.error('Exception scheduling timer: %s, %s', - trigger['parameters'], e, exc_info=True) + LOG.error( + "Exception scheduling timer: %s, %s", + trigger["parameters"], + e, + exc_info=True, + ) raise # Or should we just return? - time_spec = trigger['parameters'] - time_zone = aps_utils.astimezone(trigger['parameters'].get('timezone')) + time_spec = trigger["parameters"] + time_zone = aps_utils.astimezone(trigger["parameters"].get("timezone")) time_type = None - if trigger_type['name'] == 'st2.IntervalTimer': - unit = time_spec.get('unit', None) - value = time_spec.get('delta', None) - time_type = IntervalTrigger(**{unit: value, 'timezone': time_zone}) - elif trigger_type['name'] == 'st2.DateTimer': + if trigger_type["name"] == "st2.IntervalTimer": + unit = time_spec.get("unit", None) + value = time_spec.get("delta", None) + time_type = IntervalTrigger(**{unit: value, "timezone": time_zone}) + elif trigger_type["name"] == "st2.DateTimer": # Raises an exception if date string isn't a valid one. - dat = date_parser.parse(time_spec.get('date', None)) + dat = date_parser.parse(time_spec.get("date", None)) time_type = DateTrigger(dat, timezone=time_zone) - elif trigger_type['name'] == 'st2.CronTimer': + elif trigger_type["name"] == "st2.CronTimer": cron = time_spec.copy() - cron['timezone'] = time_zone + cron["timezone"] = time_zone time_type = CronTrigger(**cron) utc_now = date_utils.get_datetime_utc_now() - if hasattr(time_type, 'run_date') and utc_now > time_type.run_date: - LOG.warning('Not scheduling expired timer: %s : %s', - trigger['parameters'], time_type.run_date) + if hasattr(time_type, "run_date") and utc_now > time_type.run_date: + LOG.warning( + "Not scheduling expired timer: %s : %s", + trigger["parameters"], + time_type.run_date, + ) else: self._add_job(trigger, time_type) return time_type def _add_job(self, trigger, time_type, replace=True): try: - job = self._scheduler.add_job(self._emit_trigger_instance, - trigger=time_type, - args=[trigger], - replace_existing=replace) - LOG.info('Job %s scheduled.', job.id) - self._jobs[trigger['id']] = job.id + job = self._scheduler.add_job( + self._emit_trigger_instance, + trigger=time_type, + args=[trigger], + replace_existing=replace, + ) + LOG.info("Job %s scheduled.", job.id) + self._jobs[trigger["id"]] = job.id except Exception as e: - LOG.error('Exception scheduling timer: %s, %s', - trigger['parameters'], e, exc_info=True) + LOG.error( + "Exception scheduling timer: %s, %s", + trigger["parameters"], + e, + exc_info=True, + ) def _emit_trigger_instance(self, trigger): utc_now = date_utils.get_datetime_utc_now() # debug logging is reasonable for this one. A high resolution timer will end up # trashing standard logs. - LOG.debug('Timer fired at: %s. Trigger: %s', str(utc_now), trigger) + LOG.debug("Timer fired at: %s. Trigger: %s", str(utc_now), trigger) payload = { - 'executed_at': str(utc_now), - 'schedule': trigger['parameters'].get('time') + "executed_at": str(utc_now), + "schedule": trigger["parameters"].get("time"), } - trace_context = TraceContext(trace_tag='%s-%s' % (self._get_trigger_type_name(trigger), - trigger.get('name', uuid.uuid4().hex))) + trace_context = TraceContext( + trace_tag="%s-%s" + % ( + self._get_trigger_type_name(trigger), + trigger.get("name", uuid.uuid4().hex), + ) + ) self._trigger_dispatcher.dispatch(trigger, payload, trace_context=trace_context) def _get_trigger_type_name(self, trigger): - trigger_type_ref = trigger['type'] + trigger_type_ref = trigger["type"] trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref] - return trigger_type['name'] + return trigger_type["name"] def _register_timer_trigger_types(self): return trigger_services.add_trigger_models(list(TIMER_TRIGGER_TYPES.values())) diff --git a/st2reactor/st2reactor/timer/config.py b/st2reactor/st2reactor/timer/config.py index db180f85dd..bbc1020cb9 100644 --- a/st2reactor/st2reactor/timer/config.py +++ b/st2reactor/st2reactor/timer/config.py @@ -25,8 +25,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): diff --git a/st2reactor/tests/integration/test_garbage_collector.py b/st2reactor/tests/integration/test_garbage_collector.py index 1de1e9c529..5b0f890ac3 100644 --- a/st2reactor/tests/integration/test_garbage_collector.py +++ b/st2reactor/tests/integration/test_garbage_collector.py @@ -37,33 +37,28 @@ from st2tests.fixturesloader import FixturesLoader from six.moves import range -__all__ = [ - 'GarbageCollectorServiceTestCase' -] +__all__ = ["GarbageCollectorServiceTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) -INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests2.conf') +INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests2.conf") INQUIRY_CONFIG_PATH = os.path.abspath(INQUIRY_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2garbagecollector') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2garbagecollector") BINARY = os.path.abspath(BINARY) -CMD = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH] -CMD_INQUIRY = [PYTHON_BINARY, BINARY, '--config-file', INQUIRY_CONFIG_PATH] +CMD = [PYTHON_BINARY, BINARY, "--config-file", ST2_CONFIG_PATH] +CMD_INQUIRY = [PYTHON_BINARY, BINARY, "--config-file", INQUIRY_CONFIG_PATH] -TEST_FIXTURES = { - 'runners': ['inquirer.yaml'], - 'actions': ['ask.yaml'] -} +TEST_FIXTURES = {"runners": ["inquirer.yaml"], "actions": ["ask.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" class GarbageCollectorServiceTestCase(IntegrationTestCase, CleanDbTestCase): @@ -75,7 +70,8 @@ def setUp(self): super(GarbageCollectorServiceTestCase, self).setUp() self.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) def test_garbage_collection(self): now = date_utils.get_datetime_utc_now() @@ -85,102 +81,125 @@ def test_garbage_collection(self): # config old_executions_count = 15 ttl_days = 30 # > 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, old_executions_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) # Insert come mock ActionExecutionDB objects with start_timestamp > TTL defined in the # config new_executions_count = 5 ttl_days = 2 # < 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, new_executions_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) # Insert some mock output objects where start_timestamp > action_executions_output_ttl new_output_count = 5 ttl_days = 15 # > 10 and < 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, new_output_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) execs = ActionExecution.get_all() - self.assertEqual(len(execs), - (old_executions_count + new_executions_count + new_output_count)) - - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') - self.assertEqual(len(stdout_dbs), - (old_executions_count + new_executions_count + new_output_count)) - - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') - self.assertEqual(len(stderr_dbs), - (old_executions_count + new_executions_count + new_output_count)) + self.assertEqual( + len(execs), (old_executions_count + new_executions_count + new_output_count) + ) + + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") + self.assertEqual( + len(stdout_dbs), + (old_executions_count + new_executions_count + new_output_count), + ) + + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") + self.assertEqual( + len(stderr_dbs), + (old_executions_count + new_executions_count + new_output_count), + ) # Start garbage collector process = self._start_garbage_collector() @@ -196,10 +215,10 @@ def test_garbage_collection(self): # Collection for output objects older than 10 days is also enabled, so those objects # should be deleted as well - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), (new_executions_count)) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), (new_executions_count)) def test_inquiry_garbage_collection(self): @@ -207,28 +226,28 @@ def test_inquiry_garbage_collection(self): # Insert some mock Inquiries with start_timestamp > TTL old_inquiry_count = 15 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, old_inquiry_count): self._create_inquiry(ttl=2, timestamp=timestamp) # Insert some mock Inquiries with TTL set to a "disabled" value disabled_inquiry_count = 3 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, disabled_inquiry_count): self._create_inquiry(ttl=0, timestamp=timestamp) # Insert some mock Inquiries with start_timestamp < TTL new_inquiry_count = 5 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, new_inquiry_count): self._create_inquiry(ttl=15, timestamp=timestamp) - filters = { - 'status': action_constants.LIVEACTION_STATUS_PENDING - } + filters = {"status": action_constants.LIVEACTION_STATUS_PENDING} inquiries = list(ActionExecution.query(**filters)) - self.assertEqual(len(inquiries), - (old_inquiry_count + new_inquiry_count + disabled_inquiry_count)) + self.assertEqual( + len(inquiries), + (old_inquiry_count + new_inquiry_count + disabled_inquiry_count), + ) # Start garbage collector process = self._start_garbage_collector() @@ -243,18 +262,25 @@ def test_inquiry_garbage_collection(self): self.assertEqual(len(inquiries), new_inquiry_count + disabled_inquiry_count) def _create_inquiry(self, ttl, timestamp): - action_db = self.models['actions']['ask.yaml'] + action_db = self.models["actions"]["ask.yaml"] liveaction_db = LiveActionDB() liveaction_db.status = action_constants.LIVEACTION_STATUS_PENDING liveaction_db.start_timestamp = timestamp - liveaction_db.action = ResourceReference(name=action_db.name, pack=action_db.pack).ref - liveaction_db.result = {'ttl': ttl} + liveaction_db.action = ResourceReference( + name=action_db.name, pack=action_db.pack + ).ref + liveaction_db.result = {"ttl": ttl} liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) def _start_garbage_collector(self): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(CMD_INQUIRY, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + CMD_INQUIRY, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process diff --git a/st2reactor/tests/integration/test_rules_engine.py b/st2reactor/tests/integration/test_rules_engine.py index 669a88797f..05ebce5e9e 100644 --- a/st2reactor/tests/integration/test_rules_engine.py +++ b/st2reactor/tests/integration/test_rules_engine.py @@ -26,18 +26,16 @@ from st2tests.base import IntegrationTestCase from st2tests.base import CleanDbTestCase -__all__ = [ - 'TimersEngineServiceEnableDisableTestCase' -] +__all__ = ["TimersEngineServiceEnableDisableTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2timersengine') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2timersengine") BINARY = os.path.abspath(BINARY) -CMD = [PYTHON_BINARY, BINARY, '--config-file'] +CMD = [PYTHON_BINARY, BINARY, "--config-file"] class TimersEngineServiceEnableDisableTestCase(IntegrationTestCase, CleanDbTestCase): @@ -46,7 +44,7 @@ def setUp(self): config_text = open(ST2_CONFIG_PATH).read() self.cfg_fd, self.cfg_path = tempfile.mkstemp() - with open(self.cfg_path, 'w') as f: + with open(self.cfg_path, "w") as f: f.write(config_text) self.cmd = [] self.cmd.extend(CMD) @@ -65,7 +63,7 @@ def test_timer_enable_implicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -78,12 +76,15 @@ def test_timer_enable_implicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_ENABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE) + ) def test_timer_enable_explicit(self): - self._append_to_cfg_file(cfg_path=self.cfg_path, - content='\n[timersengine]\nenable = True\n[timer]\nenable = True') + self._append_to_cfg_file( + cfg_path=self.cfg_path, + content="\n[timersengine]\nenable = True\n[timer]\nenable = True", + ) process = None seen_line = False @@ -91,7 +92,7 @@ def test_timer_enable_explicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -104,12 +105,15 @@ def test_timer_enable_explicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_ENABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE) + ) def test_timer_disable_explicit(self): - self._append_to_cfg_file(cfg_path=self.cfg_path, - content='\n[timersengine]\nenable = False\n[timer]\nenable = False') + self._append_to_cfg_file( + cfg_path=self.cfg_path, + content="\n[timersengine]\nenable = False\n[timer]\nenable = False", + ) process = None seen_line = False @@ -117,7 +121,7 @@ def test_timer_disable_explicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -130,18 +134,24 @@ def test_timer_disable_explicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_DISABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_DISABLED_LOG_LINE) + ) def _start_times_engine(self, cmd): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process def _append_to_cfg_file(self, cfg_path, content): - with open(cfg_path, 'a') as f: + with open(cfg_path, "a") as f: f.write(content) def _remove_tempfile(self, fd, path): diff --git a/st2reactor/tests/integration/test_sensor_container.py b/st2reactor/tests/integration/test_sensor_container.py index 7971e36106..41eb3307bc 100644 --- a/st2reactor/tests/integration/test_sensor_container.py +++ b/st2reactor/tests/integration/test_sensor_container.py @@ -30,28 +30,26 @@ from st2common.bootstrap.sensorsregistrar import register_sensors from st2tests.base import IntegrationTestCase -__all__ = [ - 'SensorContainerTestCase' -] +__all__ = ["SensorContainerTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2sensorcontainer') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2sensorcontainer") BINARY = os.path.abspath(BINARY) -PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, '../../../contrib')) +PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, "../../../contrib")) DEFAULT_CMD = [ PYTHON_BINARY, BINARY, - '--config-file', + "--config-file", ST2_CONFIG_PATH, - '--sensor-ref=examples.SamplePollingSensor' + "--sensor-ref=examples.SamplePollingSensor", ] @@ -69,11 +67,24 @@ def setUpClass(cls): st2tests.config.parse_args() - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) cls.db_connection = db_setup( - cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, ensure_indexes=False) + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ensure_indexes=False, + ) # NOTE: We need to perform this patching because test fixtures are located outside of the # packs base paths directory. This will never happen outside the context of test fixtures. @@ -83,11 +94,17 @@ def setUpClass(cls): register_sensors(packs_base_paths=[PACKS_BASE_PATH], use_pack_cache=False) # Create virtualenv for examples pack - virtualenv_path = '/tmp/virtualenvs/examples' + virtualenv_path = "/tmp/virtualenvs/examples" - run_command(cmd=['rm', '-rf', virtualenv_path]) + run_command(cmd=["rm", "-rf", virtualenv_path]) - cmd = ['virtualenv', '--system-site-packages', '--python', PYTHON_BINARY, virtualenv_path] + cmd = [ + "virtualenv", + "--system-site-packages", + "--python", + PYTHON_BINARY, + virtualenv_path, + ] run_command(cmd=cmd) def test_child_processes_are_killed_on_sigint(self): @@ -169,7 +186,13 @@ def test_child_processes_are_killed_on_sigkill(self): def test_single_sensor_mode(self): # 1. --sensor-ref not provided - cmd = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode'] + cmd = [ + PYTHON_BINARY, + BINARY, + "--config-file", + ST2_CONFIG_PATH, + "--single-sensor-mode", + ] process = self._start_sensor_container(cmd=cmd) pp = psutil.Process(process.pid) @@ -178,14 +201,24 @@ def test_single_sensor_mode(self): concurrency.sleep(4) stdout = process.stdout.read() - self.assertTrue((b'--sensor-ref argument must be provided when running in single sensor ' - b'mode') in stdout) + self.assertTrue( + ( + b"--sensor-ref argument must be provided when running in single sensor " + b"mode" + ) + in stdout + ) self.assertProcessExited(proc=pp) self.remove_process(process=process) # 2. sensor ref provided - cmd = [BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode', - '--sensor-ref=examples.SampleSensorExit'] + cmd = [ + BINARY, + "--config-file", + ST2_CONFIG_PATH, + "--single-sensor-mode", + "--sensor-ref=examples.SampleSensorExit", + ] process = self._start_sensor_container(cmd=cmd) pp = psutil.Process(process.pid) @@ -196,9 +229,11 @@ def test_single_sensor_mode(self): # Container should exit and not respawn a sensor in single sensor mode stdout = process.stdout.read() - self.assertTrue(b'Process for sensor examples.SampleSensorExit has exited with code 110') - self.assertTrue(b'Not respawning a sensor since running in single sensor mode') - self.assertTrue(b'Process container quit with exit_code 110.') + self.assertTrue( + b"Process for sensor examples.SampleSensorExit has exited with code 110" + ) + self.assertTrue(b"Not respawning a sensor since running in single sensor mode") + self.assertTrue(b"Process container quit with exit_code 110.") concurrency.sleep(2) self.assertProcessExited(proc=pp) @@ -207,7 +242,12 @@ def test_single_sensor_mode(self): def _start_sensor_container(self, cmd=DEFAULT_CMD): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process diff --git a/st2reactor/tests/integration/test_sensor_watcher.py b/st2reactor/tests/integration/test_sensor_watcher.py index 9727da92a6..6caee09c7f 100644 --- a/st2reactor/tests/integration/test_sensor_watcher.py +++ b/st2reactor/tests/integration/test_sensor_watcher.py @@ -22,19 +22,15 @@ from st2common.services.sensor_watcher import SensorWatcher from st2tests.base import IntegrationTestCase -__all__ = [ - 'SensorWatcherTestCase' -] +__all__ = ["SensorWatcherTestCase"] class SensorWatcherTestCase(IntegrationTestCase): - @classmethod def setUpClass(cls): super(SensorWatcherTestCase, cls).setUpClass() def test_sensor_watch_queue_gets_deleted_on_stop(self): - def create_handler(sensor_db): pass @@ -44,25 +40,32 @@ def update_handler(sensor_db): def delete_handler(sensor_db): pass - sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler, - queue_suffix='covfefe') + sensor_watcher = SensorWatcher( + create_handler, update_handler, delete_handler, queue_suffix="covfefe" + ) sensor_watcher.start() - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) start = monotonic() done = False while not done: concurrency.sleep(0.01) - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) done = len(sw_queues) > 0 or ((monotonic() - start) < 5) sensor_watcher.stop() - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) self.assertTrue(len(sw_queues) == 0) def _list_amqp_queues(self): - rabbit_client = Client('localhost:15672', 'guest', 'guest') - queues = [q['name'] for q in rabbit_client.get_queues()] + rabbit_client = Client("localhost:15672", "guest", "guest") + queues = [q["name"] for q in rabbit_client.get_queues()] return queues def _get_sensor_watcher_amqp_queues(self, queue_name): diff --git a/st2reactor/tests/unit/test_container_utils.py b/st2reactor/tests/unit/test_container_utils.py index d8c14bf1d5..24d297ba7d 100644 --- a/st2reactor/tests/unit/test_container_utils.py +++ b/st2reactor/tests/unit/test_container_utils.py @@ -23,20 +23,25 @@ from st2tests.base import CleanDbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ContainerUtilsTest(CleanDbTestCase): def setUp(self): super(ContainerUtilsTest, self).setUp() # Insert mock TriggerDB - trigger_db = TriggerDB(name='name1', pack='pack1', type='type1', - parameters={'a': 1, 'b': '2', 'c': 'foo'}) + trigger_db = TriggerDB( + name="name1", + pack="pack1", + type="type1", + parameters={"a": 1, "b": "2", "c": "foo"}, + ) self.trigger_db = Trigger.add_or_update(trigger_db) def test_create_trigger_instance_invalid_trigger(self): - trigger_instance = 'dummy_pack.footrigger' - instance = create_trigger_instance(trigger=trigger_instance, payload={}, - occurrence_time=None) + trigger_instance = "dummy_pack.footrigger" + instance = create_trigger_instance( + trigger=trigger_instance, payload={}, occurrence_time=None + ) self.assertIsNone(instance) def test_create_trigger_instance_success(self): @@ -46,34 +51,40 @@ def test_create_trigger_instance_success(self): occurrence_time = None # TriggerDB look up by id - trigger = {'id': self.trigger_db.id} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) - self.assertEqual(trigger_instance_db.trigger, 'pack1.name1') + trigger = {"id": self.trigger_db.id} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) + self.assertEqual(trigger_instance_db.trigger, "pack1.name1") # Object doesn't exist (invalid id) - trigger = {'id': '5776aa2b0640fd2991b15987'} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"id": "5776aa2b0640fd2991b15987"} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) # TriggerDB look up by uid - trigger = {'uid': self.trigger_db.uid} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) - self.assertEqual(trigger_instance_db.trigger, 'pack1.name1') + trigger = {"uid": self.trigger_db.uid} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) + self.assertEqual(trigger_instance_db.trigger, "pack1.name1") - trigger = {'uid': 'invaliduid'} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"uid": "invaliduid"} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) # TriggerDB look up by type and parameters (last resort) - trigger = {'type': 'pack1.name1', 'parameters': self.trigger_db.parameters} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"type": "pack1.name1", "parameters": self.trigger_db.parameters} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) - trigger = {'type': 'pack1.name1', 'parameters': {}} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"type": "pack1.name1", "parameters": {}} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) diff --git a/st2reactor/tests/unit/test_enforce.py b/st2reactor/tests/unit/test_enforce.py index 174216dbb4..4b282305bd 100644 --- a/st2reactor/tests/unit/test_enforce.py +++ b/st2reactor/tests/unit/test_enforce.py @@ -38,62 +38,68 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RuleEnforcerTestCase', - 'RuleEnforcerDataTransformationTestCase' -] +__all__ = ["RuleEnforcerTestCase", "RuleEnforcerDataTransformationTestCase"] -PACK = 'generic' +PACK = "generic" FIXTURES_1 = { - 'runners': ['testrunner1.yaml', 'testrunner2.yaml'], - 'actions': ['action1.yaml', 'a2.yaml', 'a2_default_value.yaml'], - 'triggertypes': ['triggertype1.yaml'], - 'triggers': ['trigger1.yaml'], - 'traces': ['trace_for_test_enforce.yaml', 'trace_for_test_enforce_2.yaml', - 'trace_for_test_enforce_3.yaml'] + "runners": ["testrunner1.yaml", "testrunner2.yaml"], + "actions": ["action1.yaml", "a2.yaml", "a2_default_value.yaml"], + "triggertypes": ["triggertype1.yaml"], + "triggers": ["trigger1.yaml"], + "traces": [ + "trace_for_test_enforce.yaml", + "trace_for_test_enforce_2.yaml", + "trace_for_test_enforce_3.yaml", + ], } FIXTURES_2 = { - 'rules': [ - 'rule1.yaml', - 'rule2.yaml', - 'rule_use_none_filter.yaml', - 'rule_none_no_use_none_filter.yaml', - 'rule_action_default_value.yaml', - 'rule_action_default_value_overridden.yaml', - 'rule_action_default_value_render_fail.yaml' + "rules": [ + "rule1.yaml", + "rule2.yaml", + "rule_use_none_filter.yaml", + "rule_none_no_use_none_filter.yaml", + "rule_action_default_value.yaml", + "rule_action_default_value_overridden.yaml", + "rule_action_default_value_render_fail.yaml", ] } MOCK_TRIGGER_INSTANCE = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE.id = 'triggerinstance-test' -MOCK_TRIGGER_INSTANCE.payload = {'t1_p': 't1_p_v'} +MOCK_TRIGGER_INSTANCE.id = "triggerinstance-test" +MOCK_TRIGGER_INSTANCE.payload = {"t1_p": "t1_p_v"} MOCK_TRIGGER_INSTANCE.occurrence_time = date_utils.get_datetime_utc_now() MOCK_TRIGGER_INSTANCE_2 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_2.id = 'triggerinstance-test2' -MOCK_TRIGGER_INSTANCE_2.payload = {'t1_p': None} +MOCK_TRIGGER_INSTANCE_2.id = "triggerinstance-test2" +MOCK_TRIGGER_INSTANCE_2.payload = {"t1_p": None} MOCK_TRIGGER_INSTANCE_2.occurrence_time = date_utils.get_datetime_utc_now() MOCK_TRIGGER_INSTANCE_3 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_3.id = 'triggerinstance-test3' -MOCK_TRIGGER_INSTANCE_3.payload = {'t1_p': None, 't2_p': 'value2'} +MOCK_TRIGGER_INSTANCE_3.id = "triggerinstance-test3" +MOCK_TRIGGER_INSTANCE_3.payload = {"t1_p": None, "t2_p": "value2"} MOCK_TRIGGER_INSTANCE_3.occurrence_time = date_utils.get_datetime_utc_now() -MOCK_TRIGGER_INSTANCE_PAYLOAD = {'k1': 'v1', 'k2': 'v2', 'k3': 3, 'k4': True, - 'k5': {'foo': 'bar'}, 'k6': [1, 3]} +MOCK_TRIGGER_INSTANCE_PAYLOAD = { + "k1": "v1", + "k2": "v2", + "k3": 3, + "k4": True, + "k5": {"foo": "bar"}, + "k6": [1, 3], +} MOCK_TRIGGER_INSTANCE_4 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_4.id = 'triggerinstance-test4' +MOCK_TRIGGER_INSTANCE_4.id = "triggerinstance-test4" MOCK_TRIGGER_INSTANCE_4.payload = MOCK_TRIGGER_INSTANCE_PAYLOAD MOCK_TRIGGER_INSTANCE_4.occurrence_time = date_utils.get_datetime_utc_now() MOCK_LIVEACTION = LiveActionDB() -MOCK_LIVEACTION.id = 'liveaction-test-1.id' -MOCK_LIVEACTION.status = 'requested' +MOCK_LIVEACTION.id = "liveaction-test-1.id" +MOCK_LIVEACTION.status = "requested" MOCK_EXECUTION = ActionExecutionDB() -MOCK_EXECUTION.id = 'exec-test-1.id' -MOCK_EXECUTION.status = 'requested' +MOCK_EXECUTION.id = "exec-test-1.id" +MOCK_EXECUTION.status = "requested" FAILURE_REASON = "fail!" @@ -111,11 +117,16 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) - cls.models.update(FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_2)) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) + cls.models.update( + FixturesLoader().save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=FIXTURES_2 + ) + ) MOCK_TRIGGER_INSTANCE.trigger = reference.get_ref_from_model( - cls.models['triggers']['trigger1.yaml']) + cls.models["triggers"]["trigger1.yaml"] + ) def setUp(self): super(BaseRuleEnforcerTestCase, self).setUp() @@ -124,335 +135,445 @@ def setUp(self): class RuleEnforcerTestCase(BaseRuleEnforcerTestCase): - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) def test_ruleenforcement_occurs(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) def test_ruleenforcement_casts(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(action_service.request.called) - self.assertIsInstance(action_service.request.call_args[0][0].parameters['objtype'], dict) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertIsInstance( + action_service.request.call_args[0][0].parameters["objtype"], dict + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_ruleenforcement_create_on_success(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule2.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule2.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_rule_enforcement_create_rule_none_param_casting(self): mock_trigger_instance = MOCK_TRIGGER_INSTANCE_2 # 1. Non None value, should be serialized as regular string - mock_trigger_instance.payload = {'t1_p': 'somevalue'} + mock_trigger_instance.payload = {"t1_p": "somevalue"} def mock_cast_string(x): - assert x == 'somevalue' + assert x == "somevalue" return casts._cast_string(x) - casts.CASTS['string'] = mock_cast_string - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_use_none_filter.yaml']) + casts.CASTS["string"] = mock_cast_string + + enforcer = RuleEnforcer( + mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"] + ) execution_db = enforcer.enforce() # Verify value has been serialized correctly call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], 'somevalue') + self.assertEqual(live_action_db.parameters["actionstr"], "somevalue") self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) # 2. Verify that None type from trigger instance is correctly serialized to # None when using "use_none" Jinja filter when invoking an action - mock_trigger_instance.payload = {'t1_p': None} + mock_trigger_instance.payload = {"t1_p": None} def mock_cast_string(x): assert x == data.NONE_MAGIC_VALUE return casts._cast_string(x) - casts.CASTS['string'] = mock_cast_string - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_use_none_filter.yaml']) + casts.CASTS["string"] = mock_cast_string + + enforcer = RuleEnforcer( + mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"] + ) execution_db = enforcer.enforce() # Verify None has been correctly serialized to None call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], None) + self.assertEqual(live_action_db.parameters["actionstr"], None) self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) - casts.CASTS['string'] = casts._cast_string + casts.CASTS["string"] = casts._cast_string # 3. Parameter value is a compound string one of which values is None, but "use_none" # filter is not used mock_trigger_instance = MOCK_TRIGGER_INSTANCE_3 - mock_trigger_instance.payload = {'t1_p': None, 't2_p': 'value2'} + mock_trigger_instance.payload = {"t1_p": None, "t2_p": "value2"} - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_none_no_use_none_filter.yaml']) + enforcer = RuleEnforcer( + mock_trigger_instance, + self.models["rules"]["rule_none_no_use_none_filter.yaml"], + ) execution_db = enforcer.enforce() # Verify None has been correctly serialized to None call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], 'None-value2') + self.assertEqual(live_action_db.parameters["actionstr"], "None-value2") self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_none_no_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - casts.CASTS['string'] = casts._cast_string - - @mock.patch.object(action_service, 'request', mock.MagicMock( - side_effect=ValueError(FAILURE_REASON))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_none_no_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) + + casts.CASTS["string"] = casts._cast_string + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(side_effect=ValueError(FAILURE_REASON)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_ruleenforcement_create_on_fail(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason, - FAILURE_REASON) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_FAILED) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) - @mock.patch('st2common.util.param.get_config', - mock.Mock(return_value={'arrtype_value': ['one 1', 'two 2', 'three 3']})) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].failure_reason, FAILURE_REASON + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_FAILED, + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) + @mock.patch( + "st2common.util.param.get_config", + mock.Mock(return_value={"arrtype_value": ["one 1", "two 2", "three 3"]}), + ) def test_action_default_jinja_parameter_value_is_rendered(self): # Verify that a default action parameter which is a Jinja variable is correctly rendered - rule = self.models['rules']['rule_action_default_value.yaml'] + rule = self.models["rules"]["rule_action_default_value.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) call_parameters = action_service.request.call_args[0][0].parameters - self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'}) - self.assertEqual(call_parameters['strtype'], 't1_p_v') - self.assertEqual(call_parameters['arrtype'], ['one 1', 'two 2', 'three 3']) + self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"}) + self.assertEqual(call_parameters["strtype"], "t1_p_v") + self.assertEqual(call_parameters["arrtype"], ["one 1", "two 2", "three 3"]) - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_action_default_jinja_parameter_value_overridden_in_rule(self): # Verify that it works correctly if default parameter value is overridden in rule - rule = self.models['rules']['rule_action_default_value_overridden.yaml'] + rule = self.models["rules"]["rule_action_default_value_overridden.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) call_parameters = action_service.request.call_args[0][0].parameters - self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'}) - self.assertEqual(call_parameters['strtype'], 't1_p_v') - self.assertEqual(call_parameters['arrtype'], ['override 1', 'override 2']) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(action_service, 'create_request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(action_service, 'update_status', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"}) + self.assertEqual(call_parameters["strtype"], "t1_p_v") + self.assertEqual(call_parameters["arrtype"], ["override 1", "override 2"]) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object( + action_service, + "create_request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object( + action_service, + "update_status", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_action_default_jinja_parameter_value_render_fail(self): # Action parameter render failure should result in a failed execution - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_FAILED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_FAILED, + ) self.assertFalse(action_service.request.called) self.assertTrue(action_service.create_request.called) - self.assertEqual(action_service.create_request.call_args[0][0].action, - 'wolfpack.a2_default_value') + self.assertEqual( + action_service.create_request.call_args[0][0].action, + "wolfpack.a2_default_value", + ) self.assertTrue(action_service.update_status.called) - self.assertEqual(action_service.update_status.call_args[1]['new_status'], - action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + action_service.update_status.call_args[1]["new_status"], + action_constants.LIVEACTION_STATUS_FAILED, + ) - expected_msg = ('Failed to render parameter "arrtype": \'dict object\' has no ' - 'attribute \'arrtype_value\'') + expected_msg = ( + "Failed to render parameter \"arrtype\": 'dict object' has no " + "attribute 'arrtype_value'" + ) - result = action_service.update_status.call_args[1]['result'] - self.assertEqual(result['error'], expected_msg) + result = action_service.update_status.call_args[1]["result"] + self.assertEqual(result["error"], expected_msg) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason, - expected_msg) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].failure_reason, expected_msg + ) class RuleEnforcerDataTransformationTestCase(BaseRuleEnforcerTestCase): - def test_payload_data_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] - params = {'ip1': '{{trigger.k1}}-static', - 'ip2': '{{trigger.k2}} static'} + params = {"ip1": "{{trigger.k1}}-static", "ip2": "{{trigger.k2}} static"} - expected_params = {'ip1': 'v1-static', 'ip2': 'v2 static'} + expected_params = {"ip1": "v1-static", "ip2": "v2 static"} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_int_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] - params = {'int': 666} - expected_params = {'int': 666} + params = {"int": 666} + expected_params = {"int": 666} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_bool_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - params = {'bool': True} - expected_params = {'bool': True} + params = {"bool": True} + expected_params = {"bool": True} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_complex_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - params = {'complex_dict': {'bool': True, 'int': 666, 'str': '{{trigger.k1}}-string'}} - expected_params = {'complex_dict': {'bool': True, 'int': 666, 'str': 'v1-string'}} + params = { + "complex_dict": {"bool": True, "int": 666, "str": "{{trigger.k1}}-string"} + } + expected_params = { + "complex_dict": {"bool": True, "int": 666, "str": "v1-string"} + } - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) - params = {'simple_list': [1, 2, 3]} - expected_params = {'simple_list': [1, 2, 3]} + params = {"simple_list": [1, 2, 3]} + expected_params = {"simple_list": [1, 2, 3]} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_hypenated_payload_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] - payload = {'headers': {'hypenated-header': 'dont-care'}, 'k2': 'v2'} + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] + payload = {"headers": {"hypenated-header": "dont-care"}, "k2": "v2"} MOCK_TRIGGER_INSTANCE_4.payload = payload - params = {'ip1': '{{trigger.headers[\'hypenated-header\']}}-static', - 'ip2': '{{trigger.k2}} static'} - expected_params = {'ip1': 'dont-care-static', 'ip2': 'v2 static'} - - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + params = { + "ip1": "{{trigger.headers['hypenated-header']}}-static", + "ip2": "{{trigger.k2}} static", + } + expected_params = {"ip1": "dont-care-static", "ip2": "v2 static"} + + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_system_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - k5 = KeyValuePair.add_or_update(KeyValuePairDB(name='k5', value='v5')) - k6 = KeyValuePair.add_or_update(KeyValuePairDB(name='k6', value='v6')) - k7 = KeyValuePair.add_or_update(KeyValuePairDB(name='k7', value='v7')) - k8 = KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value='v8', - scope=FULL_SYSTEM_SCOPE)) + k5 = KeyValuePair.add_or_update(KeyValuePairDB(name="k5", value="v5")) + k6 = KeyValuePair.add_or_update(KeyValuePairDB(name="k6", value="v6")) + k7 = KeyValuePair.add_or_update(KeyValuePairDB(name="k7", value="v7")) + k8 = KeyValuePair.add_or_update( + KeyValuePairDB(name="k8", value="v8", scope=FULL_SYSTEM_SCOPE) + ) - params = {'ip5': '{{trigger.k2}}-static', - 'ip6': '{{st2kv.system.k6}}-static', - 'ip7': '{{st2kv.system.k7}}-static'} - expected_params = {'ip5': 'v2-static', - 'ip6': 'v6-static', - 'ip7': 'v7-static'} + params = { + "ip5": "{{trigger.k2}}-static", + "ip6": "{{st2kv.system.k6}}-static", + "ip7": "{{st2kv.system.k7}}-static", + } + expected_params = {"ip5": "v2-static", "ip6": "v6-static", "ip7": "v7-static"} try: - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) finally: KeyValuePair.delete(k5) KeyValuePair.delete(k6) KeyValuePair.delete(k7) KeyValuePair.delete(k8) - def assertResolvedParamsMatchExpected(self, rule, trigger_instance, params, expected_params): + def assertResolvedParamsMatchExpected( + self, rule, trigger_instance, params, expected_params + ): runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} enforcer = RuleEnforcer(trigger_instance, rule) - context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db) + context, additional_contexts = enforcer.get_action_execution_context( + action_db=action_db + ) - resolved_params = enforcer.get_resolved_parameters(action_db=action_db, + resolved_params = enforcer.get_resolved_parameters( + action_db=action_db, runnertype_db=runner_type_db, params=params, context=context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) self.assertEqual(resolved_params, expected_params) diff --git a/st2reactor/tests/unit/test_filter.py b/st2reactor/tests/unit/test_filter.py index 4df7ef2360..d1e42eaece 100644 --- a/st2reactor/tests/unit/test_filter.py +++ b/st2reactor/tests/unit/test_filter.py @@ -27,57 +27,71 @@ from st2tests import DbTestCase -MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', type='system.test') +MOCK_TRIGGER = TriggerDB( + pack="dummy_pack_1", name="trigger-test.name", type="system.test" +) MOCK_TRIGGER_INSTANCE = TriggerInstanceDB( trigger=MOCK_TRIGGER.get_reference().ref, occurrence_time=date_utils.get_datetime_utc_now(), payload={ - 'p1': 'v1', - 'p2': 'preYYYpost', - 'bool': True, - 'int': 1, - 'float': 0.8, - 'list': ['v1', True, 1], - 'recursive_list': [ + "p1": "v1", + "p2": "preYYYpost", + "bool": True, + "int": 1, + "float": 0.8, + "list": ["v1", True, 1], + "recursive_list": [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Signed off by", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Signed off by", + "to_value": "Stanley", + }, ], - } + }, ) -MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack='wolfpack', name='action-test-1.name') +MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack="wolfpack", name="action-test-1.name") -MOCK_RULE_1 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some1', - trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), - criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction")) +MOCK_RULE_1 = RuleDB( + id=bson.ObjectId(), + pack="wolfpack", + name="some1", + trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), + criteria={}, + action=ActionExecutionSpecDB(ref="somepack.someaction"), +) -MOCK_RULE_2 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some2', - trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), - criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction")) +MOCK_RULE_2 = RuleDB( + id=bson.ObjectId(), + pack="wolfpack", + name="some2", + trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), + criteria={}, + action=ActionExecutionSpecDB(ref="somepack.someaction"), +) -@mock.patch.object(reference, 'get_model_by_resource_ref', - mock.MagicMock(return_value=MOCK_TRIGGER)) +@mock.patch.object( + reference, "get_model_by_resource_ref", mock.MagicMock(return_value=MOCK_TRIGGER) +) class FilterTest(DbTestCase): def test_empty_criteria(self): rule = MOCK_RULE_1 rule.criteria = {} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have failed.') + self.assertTrue(f.filter(), "equals check should have failed.") def test_empty_payload(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}} trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE) trigger_instance.payload = None f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") def test_empty_criteria_and_empty_payload(self): rule = MOCK_RULE_1 @@ -85,234 +99,247 @@ def test_empty_criteria_and_empty_payload(self): trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE) trigger_instance.payload = None f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have failed.') + self.assertTrue(f.filter(), "equals check should have failed.") def test_search_operator_pass_any_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'any', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "any", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Approved' - } - } + "item.to_value": {"type": "equals", "pattern": "Approved"}, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed evaluation') + self.assertTrue(f.filter(), "Failed evaluation") def test_search_operator_fail_any_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'any', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "any", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Denied', - } - } + "item.to_value": { + "type": "equals", + "pattern": "Denied", + }, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'Passed evaluation') + self.assertFalse(f.filter(), "Passed evaluation") def test_search_operator_pass_all_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'all', - 'pattern': { - 'item.field_name': { - 'type': 'startswith', - 'pattern': 'S', + "trigger.recursive_list": { + "type": "search", + "condition": "all", + "pattern": { + "item.field_name": { + "type": "startswith", + "pattern": "S", } - } + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed evaluation') + self.assertTrue(f.filter(), "Failed evaluation") def test_search_operator_fail_all_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'all', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "all", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Denied', - } - } + "item.to_value": { + "type": "equals", + "pattern": "Denied", + }, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'Passed evaluation') + self.assertFalse(f.filter(), "Passed evaluation") def test_matchregex_operator_pass_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v1$'}} + rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v1$"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed to pass evaluation.') + self.assertTrue(f.filter(), "Failed to pass evaluation.") def test_matchregex_operator_fail_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v$'}} + rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v$"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'regex check should have failed.') + self.assertFalse(f.filter(), "regex check should have failed.") def test_equals_operator_pass_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p1}}'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p1}}"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': "{{'%s' % trigger.p1 if trigger.int}}" + "trigger.p1": { + "type": "equals", + "pattern": "{{'%s' % trigger.p1 if trigger.int}}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") # Test our filter works if proper JSON is returned from user pattern rule = MOCK_RULE_1 rule.criteria = { - 'trigger.list': { - 'type': 'equals', - 'pattern': """ + "trigger.list": { + "type": "equals", + "pattern": """ [ {% for item in trigger.list %} {{item}}{% if not loop.last %},{% endif %} {% endfor %} ] - """ + """, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_operator_fail_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p2}}'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p2}}"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") rule = MOCK_RULE_1 rule.criteria = { - 'trigger.list': { - 'type': 'equals', - 'pattern': """ + "trigger.list": { + "type": "equals", + "pattern": """ [ {% for item in trigger.list %} {{item}} {% endfor %} ] - """ + """, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") def test_equals_bool_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': True}} + rule.criteria = {"trigger.bool": {"type": "equals", "pattern": True}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{trigger.bool}}'}} + rule.criteria = { + "trigger.bool": {"type": "equals", "pattern": "{{trigger.bool}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{ trigger.bool }}'}} + rule.criteria = { + "trigger.bool": {"type": "equals", "pattern": "{{ trigger.bool }}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_int_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': 1}} + rule.criteria = {"trigger.int": {"type": "equals", "pattern": 1}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': '{{trigger.int}}'}} + rule.criteria = { + "trigger.int": {"type": "equals", "pattern": "{{trigger.int}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_float_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': 0.8}} + rule.criteria = {"trigger.float": {"type": "equals", "pattern": 0.8}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': '{{trigger.float}}'}} + rule.criteria = { + "trigger.float": {"type": "equals", "pattern": "{{trigger.float}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_exists(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'exists'}} + rule.criteria = {"trigger.float": {"type": "exists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), '"float" key exists in trigger. Should return true.') - rule.criteria = {'trigger.floattt': {'type': 'exists'}} + self.assertTrue( + f.filter(), '"float" key exists in trigger. Should return true.' + ) + rule.criteria = {"trigger.floattt": {"type": "exists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.') + self.assertFalse( + f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.' + ) def test_nexists(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'nexists'}} + rule.criteria = {"trigger.float": {"type": "nexists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), '"float" key exists in trigger. Should return false.') - rule.criteria = {'trigger.floattt': {'type': 'nexists'}} + self.assertFalse( + f.filter(), '"float" key exists in trigger. Should return false.' + ) + rule.criteria = {"trigger.floattt": {"type": "nexists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.') + self.assertTrue( + f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.' + ) def test_gt_lt_falsy_pattern(self): # Make sure that the falsy value (number 0) is handled correctly rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'gt', 'pattern': 0}} + rule.criteria = {"trigger.int": {"type": "gt", "pattern": 0}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'trigger value is gt than 0 but didn\'t match') + self.assertTrue(f.filter(), "trigger value is gt than 0 but didn't match") - rule.criteria = {'trigger.int': {'type': 'lt', 'pattern': 0}} + rule.criteria = {"trigger.int": {"type": "lt", "pattern": 0}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'trigger value is gt than 0 but didn\'t fail') + self.assertFalse(f.filter(), "trigger value is gt than 0 but didn't fail") - @mock.patch('st2common.util.templating.KeyValueLookup') + @mock.patch("st2common.util.templating.KeyValueLookup") def test_criteria_pattern_references_a_datastore_item(self, mock_KeyValueLookup): class MockResultLookup(object): pass @@ -323,22 +350,24 @@ class MockSystemLookup(object): rule = MOCK_RULE_2 # Using a variable in pattern, referencing an inexistent datastore value - rule.criteria = {'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.inexistent_value }}'} + rule.criteria = { + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.inexistent_value }}", + } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) self.assertFalse(f.filter()) # Using a variable in pattern, referencing an existing value which doesn't match mock_result = MockSystemLookup() - mock_result.test_value_1 = 'non matching' + mock_result.test_value_1 = "non matching" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_1 }}' + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_1 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -346,13 +375,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which does match mock_result = MockSystemLookup() - mock_result.test_value_2 = 'v1' + mock_result.test_value_2 = "v1" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_2 }}' + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_2 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -360,13 +389,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which matches partially mock_result = MockSystemLookup() - mock_result.test_value_3 = 'YYY' + mock_result.test_value_3 = "YYY" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p2': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_3 }}' + "trigger.p2": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_3 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -374,13 +403,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which matches partially mock_result = MockSystemLookup() - mock_result.test_value_3 = 'YYY' + mock_result.test_value_3 = "YYY" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p2': { - 'type': 'equals', - 'pattern': 'pre{{ st2kv.system.test_value_3 }}post' + "trigger.p2": { + "type": "equals", + "pattern": "pre{{ st2kv.system.test_value_3 }}post", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) diff --git a/st2reactor/tests/unit/test_garbage_collector.py b/st2reactor/tests/unit/test_garbage_collector.py index 31442e8eb3..93de6b25d0 100644 --- a/st2reactor/tests/unit/test_garbage_collector.py +++ b/st2reactor/tests/unit/test_garbage_collector.py @@ -21,43 +21,48 @@ from oslo_config import cfg import st2tests.config as tests_config + tests_config.parse_args() from st2reactor.garbage_collector import base as garbage_collector class GarbageCollectorServiceTest(unittest.TestCase): - def tearDown(self): # Reset gc_max_idle_sec with a value of 1 to reenable for other tests. - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") super(GarbageCollectorServiceTest, self).tearDown() @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions', - mock.MagicMock(return_value=None)) + "_purge_action_executions", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions_output', - mock.MagicMock(return_value=None)) + "_purge_action_executions_output", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_trigger_instances', - mock.MagicMock(return_value=None)) + "_purge_trigger_instances", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_timeout_inquiries', - mock.MagicMock(return_value=None)) + "_timeout_inquiries", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_orphaned_workflow_executions', - mock.MagicMock(return_value=None)) + "_purge_orphaned_workflow_executions", + mock.MagicMock(return_value=None), + ) def test_orphaned_workflow_executions_gc_enabled(self): # Mock the default value of gc_max_idle_sec with a value >= 1 to enable. The config # gc_max_idle_sec is assigned to _workflow_execution_max_idle which gc checks to see # whether to run the routine. - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") # Run the garbage collection. gc = garbage_collector.GarbageCollectorService(sleep_delay=0) @@ -70,29 +75,34 @@ def test_orphaned_workflow_executions_gc_enabled(self): @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions', - mock.MagicMock(return_value=None)) + "_purge_action_executions", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions_output', - mock.MagicMock(return_value=None)) + "_purge_action_executions_output", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_trigger_instances', - mock.MagicMock(return_value=None)) + "_purge_trigger_instances", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_timeout_inquiries', - mock.MagicMock(return_value=None)) + "_timeout_inquiries", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_orphaned_workflow_executions', - mock.MagicMock(return_value=None)) + "_purge_orphaned_workflow_executions", + mock.MagicMock(return_value=None), + ) def test_orphaned_workflow_executions_gc_disabled(self): # Mock the default value of gc_max_idle_sec with a value of 0 to disable. The config # gc_max_idle_sec is assigned to _workflow_execution_max_idle which gc checks to see # whether to run the routine. - cfg.CONF.set_override('gc_max_idle_sec', 0, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 0, group="workflow_engine") # Run the garbage collection. gc = garbage_collector.GarbageCollectorService(sleep_delay=0) diff --git a/st2reactor/tests/unit/test_hash_partitioner.py b/st2reactor/tests/unit/test_hash_partitioner.py index 4412c07b97..12e522a10c 100644 --- a/st2reactor/tests/unit/test_hash_partitioner.py +++ b/st2reactor/tests/unit/test_hash_partitioner.py @@ -22,10 +22,8 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -PACK = 'generic' -FIXTURES_1 = { - 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml'] -} +PACK = "generic" +FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]} class HashPartitionerTest(DbTestCase): @@ -38,39 +36,42 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) config.parse_args() def test_full_range_hash_partitioner(self): - partitioner = HashPartitioner('node1', 'MIN..MAX') + partitioner = HashPartitioner("node1", "MIN..MAX") sensors = partitioner.get_sensors() - self.assertEqual(len(sensors), 3, 'Expected all sensors') + self.assertEqual(len(sensors), 3, "Expected all sensors") def test_multi_range_hash_partitioner(self): range_third = int(Range.RANGE_MAX_VALUE / 3) range_two_third = range_third * 2 - hash_ranges = \ - 'MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX'.format( - range_third=range_third, range_two_third=range_two_third) - partitioner = HashPartitioner('node1', hash_ranges) + hash_ranges = "MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX".format( + range_third=range_third, range_two_third=range_two_third + ) + partitioner = HashPartitioner("node1", hash_ranges) sensors = partitioner.get_sensors() - self.assertEqual(len(sensors), 3, 'Expected all sensors') + self.assertEqual(len(sensors), 3, "Expected all sensors") def test_split_range_hash_partitioner(self): range_mid = int(Range.RANGE_MAX_VALUE / 2) - partitioner = HashPartitioner('node1', 'MIN..%s' % range_mid) + partitioner = HashPartitioner("node1", "MIN..%s" % range_mid) sensors1 = partitioner.get_sensors() - partitioner = HashPartitioner('node2', '%s..MAX' % range_mid) + partitioner = HashPartitioner("node2", "%s..MAX" % range_mid) sensors2 = partitioner.get_sensors() - self.assertEqual(len(sensors1) + len(sensors2), 3, 'Expected all sensors') + self.assertEqual(len(sensors1) + len(sensors2), 3, "Expected all sensors") def test_hash_effectiveness(self): range_third = int(Range.RANGE_MAX_VALUE / 3) - partitioner1 = HashPartitioner('node1', 'MIN..%s' % range_third) - partitioner2 = HashPartitioner('node2', '%s..%s' % (range_third, range_third + range_third)) - partitioner3 = HashPartitioner('node2', '%s..MAX' % (range_third + range_third)) + partitioner1 = HashPartitioner("node1", "MIN..%s" % range_third) + partitioner2 = HashPartitioner( + "node2", "%s..%s" % (range_third, range_third + range_third) + ) + partitioner3 = HashPartitioner("node2", "%s..MAX" % (range_third + range_third)) refs_count = 1000 @@ -89,15 +90,21 @@ def test_hash_effectiveness(self): if partitioner3._is_in_hash_range(ref): p3_count += 1 - self.assertEqual(p1_count + p2_count + p3_count, refs_count, - 'Sum should equal all sensors.') + self.assertEqual( + p1_count + p2_count + p3_count, refs_count, "Sum should equal all sensors." + ) # Test effectiveness by checking if the sd is within 20% of mean mean = refs_count / 3 - variance = float((p1_count - mean)**2 + (p1_count - mean)**2 + (p3_count - mean)**2) / 3 + variance = ( + float( + (p1_count - mean) ** 2 + (p1_count - mean) ** 2 + (p3_count - mean) ** 2 + ) + / 3 + ) sd = math.sqrt(variance) - self.assertTrue(sd / mean <= 0.2, 'Some values deviate too much from the mean.') + self.assertTrue(sd / mean <= 0.2, "Some values deviate too much from the mean.") def _generate_refs(self, count=10): random_word_count = int(math.sqrt(count)) + 1 @@ -105,7 +112,7 @@ def _generate_refs(self, count=10): x_index = 0 y_index = 0 while count > 0: - yield '%s.%s' % (words[x_index], words[y_index]) + yield "%s.%s" % (words[x_index], words[y_index]) if y_index < len(words) - 1: y_index += 1 else: diff --git a/st2reactor/tests/unit/test_partitioners.py b/st2reactor/tests/unit/test_partitioners.py index 00e7681cc9..8c4213ec5b 100644 --- a/st2reactor/tests/unit/test_partitioners.py +++ b/st2reactor/tests/unit/test_partitioners.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from oslo_config import cfg -from st2common.constants.sensors import KVSTORE_PARTITION_LOADER, FILE_PARTITION_LOADER, \ - HASH_PARTITION_LOADER +from st2common.constants.sensors import ( + KVSTORE_PARTITION_LOADER, + FILE_PARTITION_LOADER, + HASH_PARTITION_LOADER, +) from st2common.models.db.keyvalue import KeyValuePairDB from st2common.persistence.keyvalue import KeyValuePair from st2reactor.container.partitioner_lookup import get_sensors_partitioner @@ -26,10 +29,8 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -PACK = 'generic' -FIXTURES_1 = { - 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml'] -} +PACK = "generic" +FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]} class PartitionerTest(DbTestCase): @@ -42,76 +43,91 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) config.parse_args() def test_default_partitioner(self): provider = get_sensors_partitioner() sensors = provider.get_sensors() - self.assertEqual(len(sensors), len(FIXTURES_1['sensors']), - 'Failed to provider all sensors') + self.assertEqual( + len(sensors), len(FIXTURES_1["sensors"]), "Failed to provider all sensors" + ) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) def test_kvstore_partitioner(self): - cfg.CONF.set_override(name='partition_provider', - override={'name': KVSTORE_PARTITION_LOADER}, - group='sensorcontainer') - kvp = KeyValuePairDB(**{'name': 'sensornode1.sensor_partition', - 'value': 'generic.Sensor1, generic.Sensor2'}) + cfg.CONF.set_override( + name="partition_provider", + override={"name": KVSTORE_PARTITION_LOADER}, + group="sensorcontainer", + ) + kvp = KeyValuePairDB( + **{ + "name": "sensornode1.sensor_partition", + "value": "generic.Sensor1, generic.Sensor2", + } + ) KeyValuePair.add_or_update(kvp, publish=False, dispatch_trigger=False) provider = get_sensors_partitioner() sensors = provider.get_sensors() - self.assertEqual(len(sensors), len(kvp.value.split(','))) + self.assertEqual(len(sensors), len(kvp.value.split(","))) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertFalse(provider.is_sensor_owner(sensor3)) def test_file_partitioner(self): partition_file = FixturesLoader().get_fixture_file_path_abs( - fixtures_pack=PACK, fixtures_type='sensors', fixture_name='partition_file.yaml') - cfg.CONF.set_override(name='partition_provider', - override={'name': FILE_PARTITION_LOADER, - 'partition_file': partition_file}, - group='sensorcontainer') + fixtures_pack=PACK, + fixtures_type="sensors", + fixture_name="partition_file.yaml", + ) + cfg.CONF.set_override( + name="partition_provider", + override={"name": FILE_PARTITION_LOADER, "partition_file": partition_file}, + group="sensorcontainer", + ) provider = get_sensors_partitioner() sensors = provider.get_sensors() self.assertEqual(len(sensors), 2) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertFalse(provider.is_sensor_owner(sensor3)) def test_hash_partitioner(self): # no specific partitioner testing here for that see test_hash_partitioner.py # This test is to make sure the wiring and some basics work - cfg.CONF.set_override(name='partition_provider', - override={'name': HASH_PARTITION_LOADER, - 'hash_ranges': '%s..%s' % (Range.RANGE_MIN_ENUM, - Range.RANGE_MAX_ENUM)}, - group='sensorcontainer') + cfg.CONF.set_override( + name="partition_provider", + override={ + "name": HASH_PARTITION_LOADER, + "hash_ranges": "%s..%s" % (Range.RANGE_MIN_ENUM, Range.RANGE_MAX_ENUM), + }, + group="sensorcontainer", + ) provider = get_sensors_partitioner() sensors = provider.get_sensors() self.assertEqual(len(sensors), 3) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor2 = self.models['sensors']['sensor2.yaml'] + sensor2 = self.models["sensors"]["sensor2.yaml"] self.assertTrue(provider.is_sensor_owner(sensor2)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertTrue(provider.is_sensor_owner(sensor3)) diff --git a/st2reactor/tests/unit/test_process_container.py b/st2reactor/tests/unit/test_process_container.py index 10ad700b8a..d1bcfdfe64 100644 --- a/st2reactor/tests/unit/test_process_container.py +++ b/st2reactor/tests/unit/test_process_container.py @@ -17,7 +17,7 @@ import os import time -from mock import (MagicMock, Mock, patch) +from mock import MagicMock, Mock, patch import unittest2 from st2reactor.container.process_container import ProcessSensorContainer @@ -26,14 +26,18 @@ from st2common.persistence.pack import Pack import st2tests.config as tests_config + tests_config.parse_args() -MOCK_PACK_DB = PackDB(ref='wolfpack', name='wolf pack', description='', - path='/opt/stackstorm/packs/wolfpack/') +MOCK_PACK_DB = PackDB( + ref="wolfpack", + name="wolf pack", + description="", + path="/opt/stackstorm/packs/wolfpack/", +) class ProcessContainerTests(unittest2.TestCase): - def test_no_sensors_dont_quit(self): process_container = ProcessSensorContainer(None, poll_interval=0.1) process_container_thread = concurrency.spawn(process_container.run) @@ -43,113 +47,133 @@ def test_no_sensors_dont_quit(self): process_container.shutdown() process_container_thread.kill() - @patch.object(ProcessSensorContainer, '_get_sensor_id', - MagicMock(return_value='wolfpack.StupidSensor')) - @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn', - MagicMock(return_value=None)) - @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB)) - @patch.object(os.path, 'isdir', MagicMock(return_value=True)) - @patch('subprocess.Popen') - @patch('st2reactor.container.process_container.create_token') - def test_common_lib_path_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen): + @patch.object( + ProcessSensorContainer, + "_get_sensor_id", + MagicMock(return_value="wolfpack.StupidSensor"), + ) + @patch.object( + ProcessSensorContainer, + "_dispatch_trigger_for_sensor_spawn", + MagicMock(return_value=None), + ) + @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB)) + @patch.object(os.path, "isdir", MagicMock(return_value=True)) + @patch("subprocess.Popen") + @patch("st2reactor.container.process_container.create_token") + def test_common_lib_path_in_pythonpath_env_var( + self, mock_create_token, mock_subproc_popen + ): process_mock = Mock() - attrs = {'communicate.return_value': ('output', 'error')} + attrs = {"communicate.return_value": ("output", "error")} process_mock.configure_mock(**attrs) mock_subproc_popen.return_value = process_mock mock_create_token = Mock() - mock_create_token.return_value = 'WHOLETTHEDOGSOUT' + mock_create_token.return_value = "WHOLETTHEDOGSOUT" mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) sensor = { - 'class_name': 'wolfpack.StupidSensor', - 'ref': 'wolfpack.StupidSensor', - 'id': '567890', - 'trigger_types': ['some_trigga'], - 'pack': 'wolfpack', - 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py', - 'poll_interval': 5 + "class_name": "wolfpack.StupidSensor", + "ref": "wolfpack.StupidSensor", + "id": "567890", + "trigger_types": ["some_trigga"], + "pack": "wolfpack", + "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py", + "poll_interval": 5, } process_container._enable_common_pack_libs = True - process_container._sensors = {'pack.StupidSensor': sensor} + process_container._sensors = {"pack.StupidSensor": sensor} process_container._spawn_sensor_process(sensor) _, call_kwargs = mock_subproc_popen.call_args - actual_env = call_kwargs['env'] - self.assertIn('PYTHONPATH', actual_env) - pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib' - self.assertIn(pack_common_lib_path, actual_env['PYTHONPATH']) - - @patch.object(ProcessSensorContainer, '_get_sensor_id', - MagicMock(return_value='wolfpack.StupidSensor')) - @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn', - MagicMock(return_value=None)) - @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB)) - @patch.object(os.path, 'isdir', MagicMock(return_value=True)) - @patch('subprocess.Popen') - @patch('st2reactor.container.process_container.create_token') - def test_common_lib_path_not_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen): + actual_env = call_kwargs["env"] + self.assertIn("PYTHONPATH", actual_env) + pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib" + self.assertIn(pack_common_lib_path, actual_env["PYTHONPATH"]) + + @patch.object( + ProcessSensorContainer, + "_get_sensor_id", + MagicMock(return_value="wolfpack.StupidSensor"), + ) + @patch.object( + ProcessSensorContainer, + "_dispatch_trigger_for_sensor_spawn", + MagicMock(return_value=None), + ) + @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB)) + @patch.object(os.path, "isdir", MagicMock(return_value=True)) + @patch("subprocess.Popen") + @patch("st2reactor.container.process_container.create_token") + def test_common_lib_path_not_in_pythonpath_env_var( + self, mock_create_token, mock_subproc_popen + ): process_mock = Mock() - attrs = {'communicate.return_value': ('output', 'error')} + attrs = {"communicate.return_value": ("output", "error")} process_mock.configure_mock(**attrs) mock_subproc_popen.return_value = process_mock mock_create_token = Mock() - mock_create_token.return_value = 'WHOLETTHEDOGSOUT' + mock_create_token.return_value = "WHOLETTHEDOGSOUT" mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) sensor = { - 'class_name': 'wolfpack.StupidSensor', - 'ref': 'wolfpack.StupidSensor', - 'id': '567890', - 'trigger_types': ['some_trigga'], - 'pack': 'wolfpack', - 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py', - 'poll_interval': 5 + "class_name": "wolfpack.StupidSensor", + "ref": "wolfpack.StupidSensor", + "id": "567890", + "trigger_types": ["some_trigga"], + "pack": "wolfpack", + "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py", + "poll_interval": 5, } process_container._enable_common_pack_libs = False - process_container._sensors = {'pack.StupidSensor': sensor} + process_container._sensors = {"pack.StupidSensor": sensor} process_container._spawn_sensor_process(sensor) _, call_kwargs = mock_subproc_popen.call_args - actual_env = call_kwargs['env'] - self.assertIn('PYTHONPATH', actual_env) - pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib' - self.assertNotIn(pack_common_lib_path, actual_env['PYTHONPATH']) + actual_env = call_kwargs["env"] + self.assertIn("PYTHONPATH", actual_env) + pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib" + self.assertNotIn(pack_common_lib_path, actual_env["PYTHONPATH"]) - @patch.object(time, 'time', MagicMock(return_value=1439441533)) + @patch.object(time, "time", MagicMock(return_value=1439441533)) def test_dispatch_triggers_on_spawn_exit(self): mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) - sensor = { - 'class_name': 'pack.StupidSensor' - } + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) + sensor = {"class_name": "pack.StupidSensor"} process = Mock() - process_attrs = {'pid': 1234} + process_attrs = {"pid": 1234} process.configure_mock(**process_attrs) - cmd = 'sensor_wrapper.py --class-name pack.StupidSensor' + cmd = "sensor_wrapper.py --class-name pack.StupidSensor" process_container._dispatch_trigger_for_sensor_spawn(sensor, process, cmd) mock_dispatcher.dispatch.assert_called_with( - 'core.st2.sensor.process_spawn', + "core.st2.sensor.process_spawn", payload={ - 'timestamp': 1439441533, - 'cmd': 'sensor_wrapper.py --class-name pack.StupidSensor', - 'pid': 1234, - 'id': 'pack.StupidSensor'}) + "timestamp": 1439441533, + "cmd": "sensor_wrapper.py --class-name pack.StupidSensor", + "pid": 1234, + "id": "pack.StupidSensor", + }, + ) process_container._dispatch_trigger_for_sensor_exit(sensor, 1) mock_dispatcher.dispatch.assert_called_with( - 'core.st2.sensor.process_exit', + "core.st2.sensor.process_exit", payload={ - 'id': 'pack.StupidSensor', - 'timestamp': 1439441533, - 'exit_code': 1 - }) + "id": "pack.StupidSensor", + "timestamp": 1439441533, + "exit_code": 1, + }, + ) diff --git a/st2reactor/tests/unit/test_rule_engine.py b/st2reactor/tests/unit/test_rule_engine.py index 39b1627268..2f70a2a9d7 100644 --- a/st2reactor/tests/unit/test_rule_engine.py +++ b/st2reactor/tests/unit/test_rule_engine.py @@ -18,9 +18,9 @@ from mongoengine import NotUniqueError from st2common.models.api.rule import RuleAPI -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB from st2common.persistence.rule import Rule -from st2common.persistence.trigger import (TriggerType, Trigger) +from st2common.persistence.trigger import TriggerType, Trigger from st2common.util import date as date_utils import st2reactor.container.utils as container_utils from st2reactor.rules.enforcer import RuleEnforcer @@ -29,30 +29,29 @@ class RuleEngineTest(DbTestCase): - @classmethod def setUpClass(cls): super(RuleEngineTest, cls).setUpClass() RuleEngineTest._setup_test_models() - @mock.patch.object(RuleEnforcer, 'enforce', mock.MagicMock(return_value=True)) + @mock.patch.object(RuleEnforcer, "enforce", mock.MagicMock(return_value=True)) def test_handle_trigger_instances(self): trigger_instance_1 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger_instance_2 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2", "k3": "v3"}, + date_utils.get_datetime_utc_now(), ) trigger_instance_3 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger2', - {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger2", + {"k1": "t1_p_v", "k2": "v2", "k3": "v3"}, + date_utils.get_datetime_utc_now(), ) instances = [trigger_instance_1, trigger_instance_2, trigger_instance_3] rules_engine = RulesEngine() @@ -60,32 +59,36 @@ def test_handle_trigger_instances(self): rules_engine.handle_trigger_instance(instance) def test_create_trigger_instance_for_trigger_with_params(self): - trigger = {'type': 'dummy_pack_1.st2.test.trigger4', 'parameters': {'url': 'sample'}} - payload = {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'} + trigger = { + "type": "dummy_pack_1.st2.test.trigger4", + "parameters": {"url": "sample"}, + } + payload = {"k1": "t1_p_v", "k2": "v2", "k3": "v3"} occurrence_time = date_utils.get_datetime_utc_now() - trigger_instance = container_utils.create_trigger_instance(trigger=trigger, - payload=payload, - occurrence_time=occurrence_time) + trigger_instance = container_utils.create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertTrue(trigger_instance) - self.assertEqual(trigger_instance.trigger, trigger['type']) + self.assertEqual(trigger_instance.trigger, trigger["type"]) self.assertEqual(trigger_instance.payload, payload) def test_get_matching_rules_filters_disabled_rules(self): trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) rules_engine = RulesEngine() matching_rules = rules_engine.get_matching_rules_for_trigger(trigger_instance) - expected_rules = ['st2.test.rule2'] + expected_rules = ["st2.test.rule2"] for rule in matching_rules: self.assertIn(rule.name, expected_rules) def test_handle_trigger_instance_no_rules(self): trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger3', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger3", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) rules_engine = RulesEngine() rules_engine.handle_trigger_instance(trigger_instance) # should not throw. @@ -96,14 +99,26 @@ def _setup_test_models(cls): RuleEngineTest._setup_sample_rules() @classmethod - def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2', - 'st2.test.trigger3', 'st2.test.trigger4']): + def _setup_sample_triggers( + self, + names=[ + "st2.test.trigger1", + "st2.test.trigger2", + "st2.test.trigger3", + "st2.test.trigger4", + ], + ): trigger_dbs = [] for name in names: trigtype = None try: - trigtype = TriggerTypeDB(pack='dummy_pack_1', name=name, description='', - payload_schema={}, parameters_schema={}) + trigtype = TriggerTypeDB( + pack="dummy_pack_1", + name=name, + description="", + payload_schema={}, + parameters_schema={}, + ) try: trigtype = TriggerType.get_by_name(name) except: @@ -111,11 +126,15 @@ def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2' except NotUniqueError: pass - created = TriggerDB(pack='dummy_pack_1', name=name, description='', - type=trigtype.get_reference().ref) + created = TriggerDB( + pack="dummy_pack_1", + name=name, + description="", + type=trigtype.get_reference().ref, + ) - if name in ['st2.test.trigger4']: - created.parameters = {'url': 'sample'} + if name in ["st2.test.trigger4"]: + created.parameters = {"url": "sample"} else: created.parameters = {} @@ -130,55 +149,40 @@ def _setup_sample_rules(self): # Rules for st2.test.trigger1 RULE_1 = { - 'enabled': True, - 'name': 'st2.test.rule1', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'k1': { # Missing prefix 'trigger'. This rule won't match. - 'pattern': 't1_p_v', - 'type': 'equals' + "enabled": True, + "name": "st2.test.rule1", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": { + "k1": { # Missing prefix 'trigger'. This rule won't match. + "pattern": "t1_p_v", + "type": "equals", } }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_1) rule_db = RuleAPI.to_model(rule_api) rule_db = Rule.add_or_update(rule_db) rules.append(rule_db) - RULE_2 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule2', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + RULE_2 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule2", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_2) rule_db = RuleAPI.to_model(rule_api) @@ -186,27 +190,17 @@ def _setup_sample_rules(self): rules.append(rule_db) RULE_3 = { - 'enabled': False, # Disabled rule shouldn't match. - 'name': 'st2.test.rule3', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "enabled": False, # Disabled rule shouldn't match. + "name": "st2.test.rule3", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_3) rule_db = RuleAPI.to_model(rule_api) @@ -215,27 +209,17 @@ def _setup_sample_rules(self): # Rules for st2.test.trigger2 RULE_4 = { - 'enabled': True, - 'name': 'st2.test.rule4', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger2' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "enabled": True, + "name": "st2.test.rule4", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger2"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_4) rule_db = RuleAPI.to_model(rule_api) diff --git a/st2reactor/tests/unit/test_rule_matcher.py b/st2reactor/tests/unit/test_rule_matcher.py index a5680fa094..46cc084662 100644 --- a/st2reactor/tests/unit/test_rule_matcher.py +++ b/st2reactor/tests/unit/test_rule_matcher.py @@ -19,9 +19,9 @@ import mock from st2common.models.api.rule import RuleAPI -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB from st2common.persistence.rule import Rule -from st2common.persistence.trigger import (TriggerType, Trigger) +from st2common.persistence.trigger import TriggerType, Trigger from st2common.services.triggers import get_trigger_db_by_ref from st2common.util import date as date_utils import st2reactor.container.utils as container_utils @@ -33,106 +33,68 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RuleMatcherTestCase', - 'BackstopRuleMatcherTestCase' -] +__all__ = ["RuleMatcherTestCase", "BackstopRuleMatcherTestCase"] # Mock rules RULE_1 = { - 'enabled': True, - 'name': 'st2.test.rule1', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'k1': { # Missing prefix 'trigger'. This rule won't match. - 'pattern': 't1_p_v', - 'type': 'equals' + "enabled": True, + "name": "st2.test.rule1", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": { + "k1": { # Missing prefix 'trigger'. This rule won't match. + "pattern": "t1_p_v", + "type": "equals", } }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } -RULE_2 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule2', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } +RULE_2 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule2", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } RULE_3 = { - 'enabled': False, # Disabled rule shouldn't match. - 'name': 'st2.test.rule3', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' + "enabled": False, # Disabled rule shouldn't match. + "name": "st2.test.rule3", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } -RULE_4 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule4', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger4' +RULE_4 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule4", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger4"}, + "criteria": {"trigger.k1": {"pattern": "t2_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't2_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } @@ -140,15 +102,15 @@ class RuleMatcherTestCase(CleanDbTestCase): rules = [] def test_get_matching_rules(self): - self._setup_sample_trigger('st2.test.trigger1') + self._setup_sample_trigger("st2.test.trigger1") rule_db_1 = self._setup_sample_rule(RULE_1) rule_db_2 = self._setup_sample_rule(RULE_2) rule_db_3 = self._setup_sample_rule(RULE_3) rules = [rule_db_1, rule_db_2, rule_db_3] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -159,17 +121,22 @@ def test_get_matching_rules(self): def test_trigger_instance_payload_with_special_values(self): # Test a rule where TriggerInstance payload contains a dot (".") and $ - self._setup_sample_trigger('st2.test.trigger1') - self._setup_sample_trigger('st2.test.trigger2') + self._setup_sample_trigger("st2.test.trigger1") + self._setup_sample_trigger("st2.test.trigger2") rule_db_1 = self._setup_sample_rule(RULE_1) rule_db_2 = self._setup_sample_rule(RULE_2) rule_db_3 = self._setup_sample_rule(RULE_3) rules = [rule_db_1, rule_db_2, rule_db_3] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger2', - {'k1': 't1_p_v', 'k2.k2': 'v2', 'k3.more.nested.deep': 'some.value', - 'k4.even.more.nested$': 'foo', 'yep$aaa': 'b'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger2", + { + "k1": "t1_p_v", + "k2.k2": "v2", + "k3.more.nested.deep": "some.value", + "k4.even.more.nested$": "foo", + "yep$aaa": "b", + }, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -178,20 +145,22 @@ def test_trigger_instance_payload_with_special_values(self): self.assertIsNotNone(matching_rules) self.assertEqual(len(matching_rules), 1) - @mock.patch('st2reactor.rules.matcher.RuleFilter._render_criteria_pattern', - mock.Mock(side_effect=Exception('exception in _render_criteria_pattern'))) + @mock.patch( + "st2reactor.rules.matcher.RuleFilter._render_criteria_pattern", + mock.Mock(side_effect=Exception("exception in _render_criteria_pattern")), + ) def test_rule_enforcement_is_created_on_exception_1(self): # 1. Exception in _render_criteria_pattern rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -203,29 +172,35 @@ def test_rule_enforcement_is_created_on_exception_1(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": Failed to render pattern value "t2_p_v" for key ' - '"trigger.k1": exception in _render_criteria_pattern' % - (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": Failed to render pattern value "t2_p_v" for key ' + '"trigger.k1": exception in _render_criteria_pattern' + % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) - @mock.patch('st2reactor.rules.filter.PayloadLookup.get_value', - mock.Mock(side_effect=Exception('exception in get_value'))) + @mock.patch( + "st2reactor.rules.filter.PayloadLookup.get_value", + mock.Mock(side_effect=Exception("exception in get_value")), + ) def test_rule_enforcement_is_created_on_exception_2(self): # 1. Exception in payload_lookup.get_value rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -237,28 +212,34 @@ def test_rule_enforcement_is_created_on_exception_2(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": Failed transforming criteria key trigger.k1: ' - 'exception in get_value' % (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": Failed transforming criteria key trigger.k1: ' + "exception in get_value" % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) - @mock.patch('st2common.operators.get_operator', - mock.Mock(return_value=mock.Mock(side_effect=Exception('exception in equals')))) + @mock.patch( + "st2common.operators.get_operator", + mock.Mock(return_value=mock.Mock(side_effect=Exception("exception in equals"))), + ) def test_rule_enforcement_is_created_on_exception_3(self): # 1. Exception in payload_lookup.get_value rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -270,22 +251,31 @@ def test_rule_enforcement_is_created_on_exception_3(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": There might be a problem with the criteria in rule ' - 'yoyohoneysingh.st2.test.rule4: exception in equals' % - (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": There might be a problem with the criteria in rule ' + "yoyohoneysingh.st2.test.rule4: exception in equals" + % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) def _setup_sample_trigger(self, name): - trigtype = TriggerTypeDB(name=name, pack='dummy_pack_1', payload_schema={}, - parameters_schema={}) + trigtype = TriggerTypeDB( + name=name, pack="dummy_pack_1", payload_schema={}, parameters_schema={} + ) TriggerType.add_or_update(trigtype) - created = TriggerDB(name=name, pack='dummy_pack_1', type=trigtype.get_reference().ref, - parameters={}) + created = TriggerDB( + name=name, + pack="dummy_pack_1", + type=trigtype.get_reference().ref, + parameters={}, + ) Trigger.add_or_update(created) def _setup_sample_rule(self, rule): @@ -295,14 +285,12 @@ def _setup_sample_rule(self, rule): return rule_db -PACK = 'backstop' +PACK = "backstop" FIXTURES_TRIGGERS = { - 'triggertypes': ['triggertype1.yaml'], - 'triggers': ['trigger1.yaml'] -} -FIXTURES_RULES = { - 'rules': ['backstop.yaml', 'success.yaml', 'fail.yaml'] + "triggertypes": ["triggertype1.yaml"], + "triggers": ["trigger1.yaml"], } +FIXTURES_RULES = {"rules": ["backstop.yaml", "success.yaml", "fail.yaml"]} class BackstopRuleMatcherTestCase(DbTestCase): @@ -315,33 +303,41 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = fixturesloader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS) - cls.models.update(fixturesloader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES)) + fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS + ) + cls.models.update( + fixturesloader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES + ) + ) def test_backstop_ignore(self): trigger_instance = container_utils.create_trigger_instance( - self.models['triggers']['trigger1.yaml'].ref, - {'k1': 'v1'}, - date_utils.get_datetime_utc_now() + self.models["triggers"]["trigger1.yaml"].ref, + {"k1": "v1"}, + date_utils.get_datetime_utc_now(), ) - trigger = self.models['triggers']['trigger1.yaml'] - rules = [rule for rule in six.itervalues(self.models['rules'])] + trigger = self.models["triggers"]["trigger1.yaml"] + rules = [rule for rule in six.itervalues(self.models["rules"])] rules_matcher = RulesMatcher(trigger_instance, trigger, rules) matching_rules = rules_matcher.get_matching_rules() self.assertEqual(len(matching_rules), 1) - self.assertEqual(matching_rules[0].id, self.models['rules']['success.yaml'].id) + self.assertEqual(matching_rules[0].id, self.models["rules"]["success.yaml"].id) def test_backstop_apply(self): trigger_instance = container_utils.create_trigger_instance( - self.models['triggers']['trigger1.yaml'].ref, - {'k1': 'v1'}, - date_utils.get_datetime_utc_now() + self.models["triggers"]["trigger1.yaml"].ref, + {"k1": "v1"}, + date_utils.get_datetime_utc_now(), ) - trigger = self.models['triggers']['trigger1.yaml'] - success_rule = self.models['rules']['success.yaml'] - rules = [rule for rule in six.itervalues(self.models['rules']) if rule != success_rule] + trigger = self.models["triggers"]["trigger1.yaml"] + success_rule = self.models["rules"]["success.yaml"] + rules = [ + rule + for rule in six.itervalues(self.models["rules"]) + if rule != success_rule + ] rules_matcher = RulesMatcher(trigger_instance, trigger, rules) matching_rules = rules_matcher.get_matching_rules() self.assertEqual(len(matching_rules), 1) - self.assertEqual(matching_rules[0].id, self.models['rules']['backstop.yaml'].id) + self.assertEqual(matching_rules[0].id, self.models["rules"]["backstop.yaml"].id) diff --git a/st2reactor/tests/unit/test_sensor_and_rule_registration.py b/st2reactor/tests/unit/test_sensor_and_rule_registration.py index 3f54e97c73..50075690e9 100644 --- a/st2reactor/tests/unit/test_sensor_and_rule_registration.py +++ b/st2reactor/tests/unit/test_sensor_and_rule_registration.py @@ -27,22 +27,20 @@ from st2common.bootstrap.sensorsregistrar import SensorsRegistrar from st2common.bootstrap.rulesregistrar import RulesRegistrar -__all__ = [ - 'SensorRegistrationTestCase', - 'RuleRegistrationTestCase' -] +__all__ = ["SensorRegistrationTestCase", "RuleRegistrationTestCase"] CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures/packs')) +PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures/packs")) # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_sensor'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_sensor")), +) class SensorRegistrationTestCase(DbTestCase): - - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def test_register_sensors(self): # Verify DB is empty at the beginning self.assertEqual(len(SensorType.get_all()), 0) @@ -61,29 +59,33 @@ def test_register_sensors(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 10) self.assertTrue(sensor_dbs[0].enabled) - self.assertEqual(sensor_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml') + self.assertEqual(sensor_dbs[0].metadata_file, "sensors/test_sensor_1.yaml") - self.assertEqual(sensor_dbs[1].name, 'TestSensorDisabled') + self.assertEqual(sensor_dbs[1].name, "TestSensorDisabled") self.assertEqual(sensor_dbs[1].poll_interval, 10) self.assertFalse(sensor_dbs[1].enabled) - self.assertEqual(sensor_dbs[1].metadata_file, 'sensors/test_sensor_2.yaml') + self.assertEqual(sensor_dbs[1].metadata_file, "sensors/test_sensor_2.yaml") - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") self.assertEqual(len(trigger_type_dbs[0].tags), 0) - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") self.assertEqual(len(trigger_type_dbs[1].tags), 2) - self.assertEqual(trigger_type_dbs[1].tags[0].name, 'tag1name') - self.assertEqual(trigger_type_dbs[1].tags[0].value, 'tag1 value') + self.assertEqual(trigger_type_dbs[1].tags[0].name, "tag1name") + self.assertEqual(trigger_type_dbs[1].tags[0].value, "tag1 value") # Triggered which are registered via sensors have metadata_file pointing to the sensor # definition file - self.assertEqual(trigger_type_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml') - self.assertEqual(trigger_type_dbs[1].metadata_file, 'sensors/test_sensor_1.yaml') + self.assertEqual( + trigger_type_dbs[0].metadata_file, "sensors/test_sensor_1.yaml" + ) + self.assertEqual( + trigger_type_dbs[1].metadata_file, "sensors/test_sensor_1.yaml" + ) # Verify second call to registration doesn't create a duplicate objects registrar.register_from_packs(base_dirs=[PACKS_DIR]) @@ -96,13 +98,13 @@ def test_register_sensors(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 10) - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") # Verify sensor and trigger data is updated on registration original_load = registrar._meta_loader.load @@ -110,9 +112,10 @@ def test_register_sensors(self): def mock_load(*args, **kwargs): # Update poll_interval and trigger_type_2 description data = original_load(*args, **kwargs) - data['poll_interval'] = 50 - data['trigger_types'][1]['description'] = 'test 2' + data["poll_interval"] = 50 + data["trigger_types"][1]["description"] = "test 2" return data + registrar._meta_loader.load = mock_load registrar.register_from_packs(base_dirs=[PACKS_DIR]) @@ -125,20 +128,22 @@ def mock_load(*args, **kwargs): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 50) - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].description, 'test 2') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].description, "test 2") # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_rules'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_rules")), +) class RuleRegistrationTestCase(DbTestCase): def test_register_rules(self): # Verify DB is empty at the beginning @@ -154,8 +159,8 @@ def test_register_rules(self): self.assertEqual(len(rule_dbs), 2) self.assertEqual(len(trigger_dbs), 1) - self.assertEqual(rule_dbs[0].name, 'sample.with_the_same_timer') - self.assertEqual(rule_dbs[1].name, 'sample.with_timer') + self.assertEqual(rule_dbs[0].name, "sample.with_the_same_timer") + self.assertEqual(rule_dbs[1].name, "sample.with_timer") self.assertIsNotNone(trigger_dbs[0].name) # Verify second register call updates existing models diff --git a/st2reactor/tests/unit/test_sensor_service.py b/st2reactor/tests/unit/test_sensor_service.py index 2064c25ee3..9d1e245e10 100644 --- a/st2reactor/tests/unit/test_sensor_service.py +++ b/st2reactor/tests/unit/test_sensor_service.py @@ -23,22 +23,20 @@ from st2common.constants.keyvalue import SYSTEM_SCOPE from st2common.constants.keyvalue import USER_SCOPE -__all__ = [ - 'SensorServiceTestCase' -] +__all__ = ["SensorServiceTestCase"] # This trigger has schema that uses all property types TEST_SCHEMA = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'age': {'type': 'integer'}, - 'name': {'type': 'string', 'required': True}, - 'address': {'type': 'string', 'default': '-'}, - 'career': {'type': 'array'}, - 'married': {'type': 'boolean'}, - 'awards': {'type': 'object'}, - 'income': {'anyOf': [{'type': 'integer'}, {'type': 'string'}]}, + "type": "object", + "additionalProperties": False, + "properties": { + "age": {"type": "integer"}, + "name": {"type": "string", "required": True}, + "address": {"type": "string", "default": "-"}, + "career": {"type": "array"}, + "married": {"type": "boolean"}, + "awards": {"type": "object"}, + "income": {"anyOf": [{"type": "integer"}, {"type": "string"}]}, }, } @@ -60,8 +58,9 @@ def side_effect(trigger, payload, trace_context): self.sensor_service = SensorService(mock.MagicMock()) self.sensor_service._trigger_dispatcher_service._dispatcher = mock.Mock() - self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = \ + self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = ( mock.MagicMock(side_effect=side_effect) + ) self._dispatched_count = 0 # Previously, cfg.CONF.system.validate_trigger_payload was set to False explicitly @@ -73,55 +72,65 @@ def tearDown(self): # Replace original configured value for payload validation cfg.CONF.system.validate_trigger_payload = self.validate_trigger_payload - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_valid_payload_validation_enabled(self): cfg.CONF.system.validate_trigger_payload = True # define a valid payload payload = { - 'name': 'John Doe', - 'age': 25, - 'career': ['foo, Inc.', 'bar, Inc.'], - 'married': True, - 'awards': {'2016': ['hoge prize', 'fuga prize']}, - 'income': 50000 + "name": "John Doe", + "age": 25, + "career": ["foo, Inc.", "bar, Inc."], + "married": True, + "awards": {"2016": ["hoge prize", "fuga prize"]}, + "income": 50000, } # dispatching a trigger - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # This assumed that the target tirgger dispatched self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) - @mock.patch('st2common.services.triggers.get_trigger_db_by_ref', - mock.MagicMock(return_value=TriggerDBMock(type='trigger-type-ref'))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_db_by_ref", + mock.MagicMock(return_value=TriggerDBMock(type="trigger-type-ref")), + ) def test_dispatch_success_with_validation_enabled_trigger_reference(self): # Test a scenario where a Trigger ref and not TriggerType ref is provided cfg.CONF.system.validate_trigger_payload = True # define a valid payload payload = { - 'name': 'John Doe', - 'age': 25, - 'career': ['foo, Inc.', 'bar, Inc.'], - 'married': True, - 'awards': {'2016': ['hoge prize', 'fuga prize']}, - 'income': 50000 + "name": "John Doe", + "age": 25, + "career": ["foo, Inc.", "bar, Inc."], + "married": True, + "awards": {"2016": ["hoge prize", "fuga prize"]}, + "income": 50000, } self.assertEqual(self._dispatched_count, 0) # dispatching a trigger - self.sensor_service.dispatch('pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b', payload) + self.sensor_service.dispatch( + "pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b", payload + ) # This assumed that the target tirgger dispatched self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_validation_disabled_and_invalid_payload(self): """ Tests that an invalid payload still results in dispatch success with default config @@ -143,29 +152,31 @@ def test_dispatch_success_with_validation_disabled_and_invalid_payload(self): # define a invalid payload (the type of 'age' is incorrect) payload = { - 'name': 'John Doe', - 'age': '25', + "name": "John Doe", + "age": "25", } - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # The default config is to disable validation. So, we want to make sure # the dispatch actually went through. self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_incorrect_type(self): # define a invalid payload (the type of 'age' is incorrect) payload = { - 'name': 'John Doe', - 'age': '25', + "name": "John Doe", + "age": "25", } # set config to stop dispatching when the payload comply with target trigger_type cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # This assumed that the target trigger isn't dispatched self.assertEqual(self._dispatched_count, 0) @@ -173,120 +184,130 @@ def test_dispatch_failure_caused_by_incorrect_type(self): # reset config to permit force dispatching cfg.CONF.system.validate_trigger_payload = False - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_lack_of_required_parameter(self): # define a invalid payload (lack of required property) payload = { - 'age': 25, + "age": 25, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 0) # reset config to permit force dispatching cfg.CONF.system.validate_trigger_payload = False - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_extra_parameter(self): # define a invalid payload ('hobby' is extra) payload = { - 'name': 'John Doe', - 'hobby': 'programming', + "name": "John Doe", + "hobby": "programming", } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 0) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_multiple_type_value(self): payload = { - 'name': 'John Doe', - 'income': 1234, + "name": "John Doe", + "income": 1234, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # reset payload which can have different type - payload['income'] = 'secret' + payload["income"] = "secret" - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 2) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_null(self): payload = { - 'name': 'John Doe', - 'age': None, + "name": "John Doe", + "age": None, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock())) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock()), + ) def test_dispatch_success_without_payload_schema(self): # the case trigger has no property - self.sensor_service.dispatch('trigger-name', {}) + self.sensor_service.dispatch("trigger-name", {}) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=None)) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=None), + ) def test_dispatch_trigger_type_not_in_db_should_not_dispatch(self): cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('not-in-database-ref', {}) + self.sensor_service.dispatch("not-in-database-ref", {}) self.assertEqual(self._dispatched_count, 0) def test_datastore_methods(self): self.sensor_service._datastore_service = mock.Mock() # Verify methods take encrypt, decrypt and scope arguments - self.sensor_service.get_value(name='foo1', scope=SYSTEM_SCOPE, decrypt=True) + self.sensor_service.get_value(name="foo1", scope=SYSTEM_SCOPE, decrypt=True) call_kwargs = self.sensor_service.datastore_service.get_value.call_args[1] expected_kwargs = { - 'name': 'foo1', - 'local': True, - 'scope': SYSTEM_SCOPE, - 'decrypt': True + "name": "foo1", + "local": True, + "scope": SYSTEM_SCOPE, + "decrypt": True, } self.assertEqual(call_kwargs, expected_kwargs) - self.sensor_service.set_value(name='foo2', value='bar', scope=USER_SCOPE, encrypt=True) + self.sensor_service.set_value( + name="foo2", value="bar", scope=USER_SCOPE, encrypt=True + ) call_kwargs = self.sensor_service.datastore_service.set_value.call_args[1] expected_kwargs = { - 'name': 'foo2', - 'value': 'bar', - 'ttl': None, - 'local': True, - 'scope': USER_SCOPE, - 'encrypt': True + "name": "foo2", + "value": "bar", + "ttl": None, + "local": True, + "scope": USER_SCOPE, + "encrypt": True, } self.assertEqual(call_kwargs, expected_kwargs) - self.sensor_service.delete_value(name='foo3', scope=USER_SCOPE) + self.sensor_service.delete_value(name="foo3", scope=USER_SCOPE) call_kwargs = self.sensor_service.datastore_service.delete_value.call_args[1] - expected_kwargs = { - 'name': 'foo3', - 'local': True, - 'scope': USER_SCOPE - } + expected_kwargs = {"name": "foo3", "local": True, "scope": USER_SCOPE} self.assertEqual(call_kwargs, expected_kwargs) diff --git a/st2reactor/tests/unit/test_sensor_wrapper.py b/st2reactor/tests/unit/test_sensor_wrapper.py index 735e0e545b..b2d637812d 100644 --- a/st2reactor/tests/unit/test_sensor_wrapper.py +++ b/st2reactor/tests/unit/test_sensor_wrapper.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -33,11 +34,9 @@ from st2reactor.sensor.base import Sensor, PollingSensor CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) -__all__ = [ - 'SensorWrapperTestCase' -] +__all__ = ["SensorWrapperTestCase"] class SensorWrapperTestCase(unittest2.TestCase): @@ -47,27 +46,33 @@ def setUpClass(cls): tests_config.parse_args() def test_sensor_instance_has_sensor_service(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) - self.assertIsNotNone(getattr(wrapper._sensor_instance, 'sensor_service', None)) - self.assertIsNotNone(getattr(wrapper._sensor_instance, 'config', None)) + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) + self.assertIsNotNone(getattr(wrapper._sensor_instance, "sensor_service", None)) + self.assertIsNotNone(getattr(wrapper._sensor_instance, "config", None)) def test_trigger_cud_event_handlers(self): - trigger_id = '57861fcb0640fd1524e577c0' - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + trigger_id = "57861fcb0640fd1524e577c0" + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) self.assertEqual(wrapper._trigger_names, {}) @@ -78,7 +83,9 @@ def test_trigger_cud_event_handlers(self): # Call create handler with a trigger which refers to this sensor self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_create_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {trigger_id: trigger}) self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 1) @@ -86,7 +93,9 @@ def test_trigger_cud_event_handlers(self): # Validate that update handler updates the trigger_names self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_update_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {trigger_id: trigger}) self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 1) @@ -94,70 +103,97 @@ def test_trigger_cud_event_handlers(self): # Validate that delete handler deletes the trigger from trigger_names self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_delete_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {}) self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 1) def test_sensor_creation_passive(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) self.assertIsInstance(wrapper._sensor_instance, Sensor) self.assertIsNotNone(wrapper._sensor_instance) def test_sensor_creation_active(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] poll_interval = 10 - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestPollingSensor', - trigger_types=trigger_types, - parent_args=parent_args, - poll_interval=poll_interval) + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestPollingSensor", + trigger_types=trigger_types, + parent_args=parent_args, + poll_interval=poll_interval, + ) self.assertIsNotNone(wrapper._sensor_instance) self.assertIsInstance(wrapper._sensor_instance, PollingSensor) self.assertEqual(wrapper._sensor_instance._poll_interval, poll_interval) def test_sensor_init_fails_file_doesnt_exist(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor_doesnt_exist.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - expected_msg = 'Failed to load sensor class from file.*? No such file or directory' - self.assertRaisesRegexp(IOError, expected_msg, SensorWrapper, - pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor_doesnt_exist.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + expected_msg = ( + "Failed to load sensor class from file.*? No such file or directory" + ) + self.assertRaisesRegexp( + IOError, + expected_msg, + SensorWrapper, + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) def test_sensor_init_fails_sensor_code_contains_typo(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor_with_typo.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - expected_msg = 'Failed to load sensor class from file.*? \'typobar\' is not defined' - self.assertRaisesRegexp(NameError, expected_msg, SensorWrapper, - pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor_with_typo.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + expected_msg = ( + "Failed to load sensor class from file.*? 'typobar' is not defined" + ) + self.assertRaisesRegexp( + NameError, + expected_msg, + SensorWrapper, + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) # Verify error message also contains traceback try: - SensorWrapper(pack='core', file_path=file_path, class_name='TestSensor', - trigger_types=trigger_types, parent_args=parent_args) + SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) except NameError as e: - self.assertIn('Traceback (most recent call last)', six.text_type(e)) - self.assertIn('line 20, in ', six.text_type(e)) + self.assertIn("Traceback (most recent call last)", six.text_type(e)) + self.assertIn("line 20, in ", six.text_type(e)) else: - self.fail('NameError not thrown') + self.fail("NameError not thrown") def test_sensor_wrapper_poll_method_still_works(self): # Verify that sensor wrapper correctly applied select.poll() eventlet workaround so code @@ -167,5 +203,5 @@ def test_sensor_wrapper_poll_method_still_works(self): import select self.assertTrue(eventlet.patcher.is_monkey_patched(select)) - self.assertTrue(select != eventlet.patcher.original('select')) + self.assertTrue(select != eventlet.patcher.original("select")) self.assertTrue(select.poll()) diff --git a/st2reactor/tests/unit/test_tester.py b/st2reactor/tests/unit/test_tester.py index f1f1b01886..60cd6919b8 100644 --- a/st2reactor/tests/unit/test_tester.py +++ b/st2reactor/tests/unit/test_tester.py @@ -25,65 +25,77 @@ BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS_TRIGGERS = { - 'triggertypes': ['triggertype1.yaml', 'triggertype2.yaml'], - 'triggers': ['trigger1.yaml', 'trigger2.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml', 'trigger_instance_2.yaml'] + "triggertypes": ["triggertype1.yaml", "triggertype2.yaml"], + "triggers": ["trigger1.yaml", "trigger2.yaml"], + "triggerinstances": ["trigger_instance_1.yaml", "trigger_instance_2.yaml"], } -TEST_MODELS_RULES = { - 'rules': ['rule1.yaml'] -} +TEST_MODELS_RULES = {"rules": ["rule1.yaml"]} -TEST_MODELS_ACTIONS = { - 'actions': ['action1.yaml'] -} +TEST_MODELS_ACTIONS = {"actions": ["action1.yaml"]} -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RuleTesterTestCase(CleanDbTestCase): def test_matching_trigger_from_file(self): - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_ACTIONS) - rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml') - trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_1.yaml') - tester = RuleTester(rule_file_path=rule_file_path, - trigger_instance_file_path=trigger_instance_file_path) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS + ) + rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml") + trigger_instance_file_path = os.path.join( + BASE_PATH, "../fixtures/trigger_instance_1.yaml" + ) + tester = RuleTester( + rule_file_path=rule_file_path, + trigger_instance_file_path=trigger_instance_file_path, + ) matching = tester.evaluate() self.assertTrue(matching) def test_non_matching_trigger_from_file(self): - rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml') - trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_2.yaml') - tester = RuleTester(rule_file_path=rule_file_path, - trigger_instance_file_path=trigger_instance_file_path) + rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml") + trigger_instance_file_path = os.path.join( + BASE_PATH, "../fixtures/trigger_instance_2.yaml" + ) + tester = RuleTester( + rule_file_path=rule_file_path, + trigger_instance_file_path=trigger_instance_file_path, + ) matching = tester.evaluate() self.assertFalse(matching) def test_matching_trigger_from_db(self): - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_ACTIONS) - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_TRIGGERS) - trigger_instance_db = models['triggerinstances']['trigger_instance_2.yaml'] - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_RULES) - rule_db = models['rules']['rule1.yaml'] - tester = RuleTester(rule_ref=rule_db.ref, - trigger_instance_id=str(trigger_instance_db.id)) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS + ) + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS + ) + trigger_instance_db = models["triggerinstances"]["trigger_instance_2.yaml"] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES + ) + rule_db = models["rules"]["rule1.yaml"] + tester = RuleTester( + rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id) + ) matching = tester.evaluate() self.assertTrue(matching) def test_non_matching_trigger_from_db(self): - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_TRIGGERS) - trigger_instance_db = models['triggerinstances']['trigger_instance_1.yaml'] - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_RULES) - rule_db = models['rules']['rule1.yaml'] - tester = RuleTester(rule_ref=rule_db.ref, - trigger_instance_id=str(trigger_instance_db.id)) + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS + ) + trigger_instance_db = models["triggerinstances"]["trigger_instance_1.yaml"] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES + ) + rule_db = models["rules"]["rule1.yaml"] + tester = RuleTester( + rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id) + ) matching = tester.evaluate() self.assertFalse(matching) diff --git a/st2reactor/tests/unit/test_timer.py b/st2reactor/tests/unit/test_timer.py index 861d74349e..f4311d18d8 100644 --- a/st2reactor/tests/unit/test_timer.py +++ b/st2reactor/tests/unit/test_timer.py @@ -60,9 +60,14 @@ def test_existing_rules_are_loaded_on_start(self): # Add a dummy timer Trigger object type_ = list(TIMER_TRIGGER_TYPES.keys())[0] - parameters = {'unit': 'seconds', 'delta': 1000} - trigger_db = TriggerDB(id=bson.ObjectId(), name='test_trigger_1', pack='dummy', - type=type_, parameters=parameters) + parameters = {"unit": "seconds", "delta": 1000} + trigger_db = TriggerDB( + id=bson.ObjectId(), + name="test_trigger_1", + pack="dummy", + type=type_, + parameters=parameters, + ) trigger_db = Trigger.add_or_update(trigger_db) # Verify object has been added @@ -74,7 +79,7 @@ def test_existing_rules_are_loaded_on_start(self): # Verify handlers are called timer._handle_create_trigger.assert_called_with(trigger_db) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_timer_trace_tag_creation(self, dispatch_mock): timer = St2Timer() timer._scheduler = mock.Mock() @@ -82,11 +87,14 @@ def test_timer_trace_tag_creation(self, dispatch_mock): # Add a dummy timer Trigger object type_ = list(TIMER_TRIGGER_TYPES.keys())[0] - parameters = {'unit': 'seconds', 'delta': 1} - trigger_db = TriggerDB(name='test_trigger_1', pack='dummy', type=type_, - parameters=parameters) + parameters = {"unit": "seconds", "delta": 1} + trigger_db = TriggerDB( + name="test_trigger_1", pack="dummy", type=type_, parameters=parameters + ) timer.add_trigger(trigger_db) timer._emit_trigger_instance(trigger=trigger_db.to_serializable_dict()) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, - '%s-%s' % (TIMER_TRIGGER_TYPES[type_]['name'], trigger_db.name)) + self.assertEqual( + dispatch_mock.call_args[1]["trace_context"].trace_tag, + "%s-%s" % (TIMER_TRIGGER_TYPES[type_]["name"], trigger_db.name), + ) diff --git a/st2stream/dist_utils.py b/st2stream/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2stream/dist_utils.py +++ b/st2stream/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2stream/setup.py b/st2stream/setup.py index af6b302f5d..f34692affc 100644 --- a/st2stream/setup.py +++ b/st2stream/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2stream import __version__ -ST2_COMPONENT = 'st2stream' +ST2_COMPONENT = "st2stream" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2stream' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2stream"], ) diff --git a/st2stream/st2stream/__init__.py b/st2stream/st2stream/__init__.py index bbe290db9a..e6d3f15e0b 100644 --- a/st2stream/st2stream/__init__.py +++ b/st2stream/st2stream/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2stream/st2stream/app.py b/st2stream/st2stream/app.py index 73d32eb4cf..0932dfcc27 100644 --- a/st2stream/st2stream/app.py +++ b/st2stream/st2stream/app.py @@ -43,9 +43,9 @@ def setup_app(config={}): - LOG.info('Creating st2stream: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2stream: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things @@ -54,30 +54,33 @@ def setup_app(config={}): st2stream_config.register_opts() capabilities = { - 'name': 'stream', - 'listen_host': cfg.CONF.stream.host, - 'listen_port': cfg.CONF.stream.port, - 'type': 'active' + "name": "stream", + "listen_host": cfg.CONF.stream.host, + "listen_port": cfg.CONF.stream.port, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='stream', config=st2stream_config, setup_db=True, - register_mq_exchanges=True, - register_signal_handlers=True, - register_internal_trigger_types=False, - run_migrations=False, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="stream", + config=st2stream_config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) - router = Router(debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable, - is_gunicorn=is_gunicorn) + router = Router( + debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn + ) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') - transforms = { - '^/stream/v1/': ['/', '/v1/'] - } + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") + transforms = {"^/stream/v1/": ["/", "/v1/"]} router.add_spec(spec, transforms=transforms) app = router.as_wsgi @@ -87,8 +90,8 @@ def setup_app(config={}): app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='stream') + app = ResponseInstrumentationMiddleware(app, router, service_name="stream") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='stream') + app = RequestInstrumentationMiddleware(app, router, service_name="stream") return app diff --git a/st2stream/st2stream/cmd/__init__.py b/st2stream/st2stream/cmd/__init__.py index 4d6cd0332d..85b1f07d71 100644 --- a/st2stream/st2stream/cmd/__init__.py +++ b/st2stream/st2stream/cmd/__init__.py @@ -15,4 +15,4 @@ from st2stream.cmd import api -__all__ = ['api'] +__all__ = ["api"] diff --git a/st2stream/st2stream/cmd/api.py b/st2stream/st2stream/cmd/api.py index cc1eec7d17..b4ce963ea5 100644 --- a/st2stream/st2stream/cmd/api.py +++ b/st2stream/st2stream/cmd/api.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -30,20 +31,20 @@ from st2common.util.wsgi import shutdown_server_kill_pending_requests from st2stream.signal_handlers import register_stream_signal_handlers from st2stream import config + config.register_opts() from st2stream import app -__all__ = [ - 'main' -] +__all__ = ["main"] eventlet.monkey_patch( os=True, select=True, socket=True, - thread=False if '--use-debugger' in sys.argv else True, - time=True) + thread=False if "--use-debugger" in sys.argv else True, + time=True, +) LOG = logging.getLogger(__name__) @@ -53,29 +54,43 @@ def _setup(): capabilities = { - 'name': 'stream', - 'listen_host': cfg.CONF.stream.host, - 'listen_port': cfg.CONF.stream.port, - 'type': 'active' + "name": "stream", + "listen_host": cfg.CONF.stream.host, + "listen_port": cfg.CONF.stream.port, + "type": "active", } - common_setup(service='stream', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=False, service_registry=True, capabilities=capabilities) + common_setup( + service="stream", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + ) def _run_server(): host = cfg.CONF.stream.host port = cfg.CONF.stream.port - LOG.info('(PID=%s) ST2 Stream API is serving on http://%s:%s.', os.getpid(), host, port) + LOG.info( + "(PID=%s) ST2 Stream API is serving on http://%s:%s.", os.getpid(), host, port + ) max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS worker_pool = eventlet.GreenPool(max_pool_size) sock = eventlet.listen((host, port)) def queue_shutdown(signal_number, stack_frame): - eventlet.spawn_n(shutdown_server_kill_pending_requests, sock=sock, - worker_pool=worker_pool, wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME) + eventlet.spawn_n( + shutdown_server_kill_pending_requests, + sock=sock, + worker_pool=worker_pool, + wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME, + ) # We register a custom SIGINT handler which allows us to kill long running active requests. # Note: Eventually we will support draining (waiting for short-running requests), but we @@ -97,12 +112,12 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except KeyboardInterrupt: - listener = get_listener_if_set(name='stream') + listener = get_listener_if_set(name="stream") if listener: listener.shutdown() except Exception: - LOG.exception('(PID=%s) ST2 Stream API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 Stream API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2stream/st2stream/config.py b/st2stream/st2stream/config.py index fe068dc0b2..bc117b556a 100644 --- a/st2stream/st2stream/config.py +++ b/st2stream/st2stream/config.py @@ -32,8 +32,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -54,17 +57,15 @@ def _register_app_opts(): # config since they are also used outside st2stream api_opts = [ cfg.StrOpt( - 'host', default='127.0.0.1', - help='StackStorm stream API server host'), - cfg.IntOpt( - 'port', default=9102, - help='StackStorm API stream, server port'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "host", default="127.0.0.1", help="StackStorm stream API server host" + ), + cfg.IntOpt("port", default=9102, help="StackStorm API stream, server port"), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), cfg.StrOpt( - 'logging', default='/etc/st2/logging.stream.conf', - help='location of the logging.conf file') + "logging", + default="/etc/st2/logging.stream.conf", + help="location of the logging.conf file", + ), ] - CONF.register_opts(api_opts, group='stream') + CONF.register_opts(api_opts, group="stream") diff --git a/st2stream/st2stream/controllers/v1/executions.py b/st2stream/st2stream/controllers/v1/executions.py index 379491e978..70023b8745 100644 --- a/st2stream/st2stream/controllers/v1/executions.py +++ b/st2stream/st2stream/controllers/v1/executions.py @@ -30,47 +30,46 @@ from st2common.rbac.types import PermissionType from st2common.stream.listener import get_listener -__all__ = [ - 'ActionExecutionOutputStreamController' -] +__all__ = ["ActionExecutionOutputStreamController"] LOG = logging.getLogger(__name__) # Event which is returned when no more data will be produced on this stream endpoint before closing # the connection. -NO_MORE_DATA_EVENT = 'event: EOF\ndata: \'\'\n\n' +NO_MORE_DATA_EVENT = "event: EOF\ndata: ''\n\n" class ActionExecutionOutputStreamController(ResourceController): model = ActionExecutionAPI access = ActionExecution - supported_filters = { - 'output_type': 'output_type' - } + supported_filters = {"output_type": "output_type"} CLOSE_STREAM_LIVEACTION_STATES = action_constants.LIVEACTION_COMPLETED_STATES + [ action_constants.LIVEACTION_STATUS_PAUSING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] - def get_one(self, id, output_type='all', requester_user=None): + def get_one(self, id, output_type="all", requester_user=None): # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).first() + if id == "last": + execution_db = ActionExecution.query().order_by("-id").limit(1).first() if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) execution_id = str(execution_db.id) query_filters = {} - if output_type and output_type != 'all': - query_filters['output_type'] = output_type + if output_type and output_type != "all": + query_filters["output_type"] = output_type def format_output_object(output_db_or_api): if isinstance(output_db_or_api, ActionExecutionOutputDB): @@ -78,25 +77,27 @@ def format_output_object(output_db_or_api): elif isinstance(output_db_or_api, ActionExecutionOutputAPI): data = output_db_or_api else: - raise ValueError('Unsupported format: %s' % (type(output_db_or_api))) + raise ValueError("Unsupported format: %s" % (type(output_db_or_api))) - event = 'st2.execution.output__create' - result = 'event: %s\ndata: %s\n\n' % (event, json_encode(data, indent=None)) + event = "st2.execution.output__create" + result = "event: %s\ndata: %s\n\n" % (event, json_encode(data, indent=None)) return result def existing_output_iter(): # Consume and return all of the existing lines - output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters) + output_dbs = ActionExecutionOutput.query( + execution_id=execution_id, **query_filters + ) # Note: We return all at once instead of yield line by line to avoid multiple socket # writes and to achieve better performance output = [format_output_object(output_db) for output_db in output_dbs] - output = ''.join(output) - yield six.binary_type(output.encode('utf-8')) + output = "".join(output) + yield six.binary_type(output.encode("utf-8")) def new_output_iter(): def noop_gen(): - yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8')) + yield six.binary_type(NO_MORE_DATA_EVENT.encode("utf-8")) # Bail out if execution has already completed / been paused if execution_db.status in self.CLOSE_STREAM_LIVEACTION_STATES: @@ -104,7 +105,9 @@ def noop_gen(): # Wait for and return any new line which may come in execution_ids = [execution_id] - listener = get_listener(name='execution_output') # pylint: disable=no-member + listener = get_listener( + name="execution_output" + ) # pylint: disable=no-member gen = listener.generator(execution_ids=execution_ids) def format(gen): @@ -117,28 +120,37 @@ def format(gen): # Note: gunicorn wsgi handler expect bytes, not unicode # pylint: disable=no-member if isinstance(model_api, ActionExecutionOutputAPI): - if output_type and output_type != 'all' and \ - model_api.output_type != output_type: + if ( + output_type + and output_type != "all" + and model_api.output_type != output_type + ): continue - output = format_output_object(model_api).encode('utf-8') + output = format_output_object(model_api).encode("utf-8") yield six.binary_type(output) elif isinstance(model_api, ActionExecutionAPI): if model_api.status in self.CLOSE_STREAM_LIVEACTION_STATES: - yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8')) + yield six.binary_type( + NO_MORE_DATA_EVENT.encode("utf-8") + ) break else: - LOG.debug('Unrecognized message type: %s' % (model_api)) + LOG.debug("Unrecognized message type: %s" % (model_api)) gen = format(gen) return gen def make_response(): app_iter = itertools.chain(existing_output_iter(), new_output_iter()) - res = Response(headerlist=[("X-Accel-Buffering", "no"), - ('Cache-Control', 'no-cache'), - ("Content-Type", "text/event-stream; charset=UTF-8")], - app_iter=app_iter) + res = Response( + headerlist=[ + ("X-Accel-Buffering", "no"), + ("Cache-Control", "no-cache"), + ("Content-Type", "text/event-stream; charset=UTF-8"), + ], + app_iter=app_iter, + ) return res res = make_response() diff --git a/st2stream/st2stream/controllers/v1/root.py b/st2stream/st2stream/controllers/v1/root.py index c9873127a6..2b9178f785 100644 --- a/st2stream/st2stream/controllers/v1/root.py +++ b/st2stream/st2stream/controllers/v1/root.py @@ -15,9 +15,7 @@ from st2stream.controllers.v1.stream import StreamController -__all__ = [ - 'RootController' -] +__all__ = ["RootController"] class RootController(object): diff --git a/st2stream/st2stream/controllers/v1/stream.py b/st2stream/st2stream/controllers/v1/stream.py index 19c7d71b1d..f6995c3300 100644 --- a/st2stream/st2stream/controllers/v1/stream.py +++ b/st2stream/st2stream/controllers/v1/stream.py @@ -21,58 +21,70 @@ from st2common.util.jsonify import json_encode from st2common.stream.listener import get_listener -__all__ = [ - 'StreamController' -] +__all__ = ["StreamController"] LOG = logging.getLogger(__name__) DEFAULT_EVENTS_WHITELIST = [ - 'st2.announcement__*', - - 'st2.execution__create', - 'st2.execution__update', - 'st2.execution__delete', - - 'st2.liveaction__create', - 'st2.liveaction__update', - 'st2.liveaction__delete', + "st2.announcement__*", + "st2.execution__create", + "st2.execution__update", + "st2.execution__delete", + "st2.liveaction__create", + "st2.liveaction__update", + "st2.liveaction__delete", ] def format(gen): - message = '''event: %s\ndata: %s\n\n''' + message = """event: %s\ndata: %s\n\n""" for pack in gen: if not pack: # Note: gunicorn wsgi handler expect bytes, not unicode - yield six.binary_type(b'\n') + yield six.binary_type(b"\n") else: (event, body) = pack # Note: gunicorn wsgi handler expect bytes, not unicode - yield six.binary_type((message % (event, json_encode(body, - indent=None))).encode('utf-8')) + yield six.binary_type( + (message % (event, json_encode(body, indent=None))).encode("utf-8") + ) class StreamController(object): - def get_all(self, end_execution_id=None, end_event=None, - events=None, action_refs=None, execution_ids=None, requester_user=None): + def get_all( + self, + end_execution_id=None, + end_event=None, + events=None, + action_refs=None, + execution_ids=None, + requester_user=None, + ): events = events if events else DEFAULT_EVENTS_WHITELIST action_refs = action_refs if action_refs else None execution_ids = execution_ids if execution_ids else None def make_response(): - listener = get_listener(name='stream') - app_iter = format(listener.generator(events=events, - action_refs=action_refs, - end_event=end_event, - end_statuses=action_constants.LIVEACTION_COMPLETED_STATES, - end_execution_id=end_execution_id, - execution_ids=execution_ids)) - res = Response(headerlist=[("X-Accel-Buffering", "no"), - ('Cache-Control', 'no-cache'), - ("Content-Type", "text/event-stream; charset=UTF-8")], - app_iter=app_iter) + listener = get_listener(name="stream") + app_iter = format( + listener.generator( + events=events, + action_refs=action_refs, + end_event=end_event, + end_statuses=action_constants.LIVEACTION_COMPLETED_STATES, + end_execution_id=end_execution_id, + execution_ids=execution_ids, + ) + ) + res = Response( + headerlist=[ + ("X-Accel-Buffering", "no"), + ("Cache-Control", "no-cache"), + ("Content-Type", "text/event-stream; charset=UTF-8"), + ], + app_iter=app_iter, + ) return res stream = make_response() diff --git a/st2stream/st2stream/signal_handlers.py b/st2stream/st2stream/signal_handlers.py index 56bc06450a..b292d8b67b 100644 --- a/st2stream/st2stream/signal_handlers.py +++ b/st2stream/st2stream/signal_handlers.py @@ -15,9 +15,7 @@ import signal -__all__ = [ - 'register_stream_signal_handlers' -] +__all__ = ["register_stream_signal_handlers"] def register_stream_signal_handlers(handler_func): diff --git a/st2stream/st2stream/wsgi.py b/st2stream/st2stream/wsgi.py index c177572ba1..14d847e2a1 100644 --- a/st2stream/st2stream/wsgi.py +++ b/st2stream/st2stream/wsgi.py @@ -18,8 +18,11 @@ from st2stream import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2stream/tests/unit/controllers/v1/base.py b/st2stream/tests/unit/controllers/v1/base.py index 24a59a5cd0..4f6e2ca336 100644 --- a/st2stream/tests/unit/controllers/v1/base.py +++ b/st2stream/tests/unit/controllers/v1/base.py @@ -16,9 +16,7 @@ from st2stream import app from st2tests.api import BaseFunctionalTest -__all__ = [ - 'FunctionalTest' -] +__all__ = ["FunctionalTest"] class FunctionalTest(BaseFunctionalTest): diff --git a/st2stream/tests/unit/controllers/v1/test_stream.py b/st2stream/tests/unit/controllers/v1/test_stream.py index 7ff7e62f3d..c67f3e2782 100644 --- a/st2stream/tests/unit/controllers/v1/test_stream.py +++ b/st2stream/tests/unit/controllers/v1/test_stream.py @@ -34,88 +34,72 @@ RUNNER_TYPE_1 = { - 'description': '', - 'enabled': True, - 'name': 'local-shell-cmd', - 'runner_module': 'local_runner', - 'runner_parameters': {} + "description": "", + "enabled": True, + "name": "local-shell-cmd", + "runner_module": "local_runner", + "runner_parameters": {}, } ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "local-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } EXECUTION_1 = { - 'id': '598dbf0c0640fd54bffc688b', - 'action': { - 'ref': 'sixpack.st2.dummy.action1' + "id": "598dbf0c0640fd54bffc688b", + "action": {"ref": "sixpack.st2.dummy.action1"}, + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, }, - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } } STDOUT_1 = { - 'execution_id': '598dbf0c0640fd54bffc688b', - 'action_ref': 'dummy.action1', - 'output_type': 'stdout' + "execution_id": "598dbf0c0640fd54bffc688b", + "action_ref": "dummy.action1", + "output_type": "stdout", } STDERR_1 = { - 'execution_id': '598dbf0c0640fd54bffc688b', - 'action_ref': 'dummy.action1', - 'output_type': 'stderr' + "execution_id": "598dbf0c0640fd54bffc688b", + "action_ref": "dummy.action1", + "output_type": "stderr", } class META(object): delivery_info = {} - def __init__(self, exchange='some', routing_key='thing'): - self.delivery_info['exchange'] = exchange - self.delivery_info['routing_key'] = routing_key + def __init__(self, exchange="some", routing_key="thing"): + self.delivery_info["exchange"] = exchange + self.delivery_info["routing_key"] = routing_key def ack(self): pass class TestStreamController(FunctionalTest): - @classmethod def setUpClass(cls): super(TestStreamController, cls).setUpClass() @@ -126,33 +110,35 @@ def setUpClass(cls): instance = ActionAPI(**ACTION_1) Action.add_or_update(ActionAPI.to_model(instance)) - @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock()) - @mock.patch('st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST', None) + @mock.patch.object(st2common.stream.listener, "listen", mock.Mock()) + @mock.patch("st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST", None) def test_get_all(self): resp = stream.StreamController().get_all() - self.assertEqual(resp._status, '200 OK') - self.assertIn(('Content-Type', 'text/event-stream; charset=UTF-8'), resp._headerlist) + self.assertEqual(resp._status, "200 OK") + self.assertIn( + ("Content-Type", "text/event-stream; charset=UTF-8"), resp._headerlist + ) - listener = st2common.stream.listener.get_listener(name='stream') + listener = st2common.stream.listener.get_listener(name="stream") process = listener.processor(LiveActionAPI) message = None for message in resp._app_iter: - message = message.decode('utf-8') - if message != '\n': + message = message.decode("utf-8") + if message != "\n": break process(LiveActionDB(**LIVE_ACTION_1), META()) - self.assertIn('event: some__thing', message) + self.assertIn("event: some__thing", message) self.assertIn('data: {"', message) self.assertNotIn(SUPER_SECRET_PARAMETER, message) - @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock()) + @mock.patch.object(st2common.stream.listener, "listen", mock.Mock()) def test_get_all_with_filters(self): - cfg.CONF.set_override(name='heartbeat', group='stream', override=0.1) + cfg.CONF.set_override(name="heartbeat", group="stream", override=0.1) - listener = st2common.stream.listener.get_listener(name='stream') + listener = st2common.stream.listener.get_listener(name="stream") process_execution = listener.processor(ActionExecutionAPI) process_liveaction = listener.processor(LiveActionAPI) process_output = listener.processor(ActionExecutionOutputAPI) @@ -164,50 +150,50 @@ def test_get_all_with_filters(self): output_api_stderr = ActionExecutionOutputDB(**STDERR_1) def dispatch_and_handle_mock_data(resp): - received_messages_data = '' + received_messages_data = "" for index, message in enumerate(resp._app_iter): if message.strip(): - received_messages_data += message.decode('utf-8') + received_messages_data += message.decode("utf-8") # Dispatch some mock events if index == 0: - meta = META('st2.execution', 'create') + meta = META("st2.execution", "create") process_execution(execution_api, meta) elif index == 1: - meta = META('st2.execution', 'update') + meta = META("st2.execution", "update") process_execution(execution_api, meta) elif index == 2: - meta = META('st2.execution', 'delete') + meta = META("st2.execution", "delete") process_execution(execution_api, meta) elif index == 3: - meta = META('st2.liveaction', 'create') + meta = META("st2.liveaction", "create") process_liveaction(liveaction_api, meta) elif index == 4: - meta = META('st2.liveaction', 'create') + meta = META("st2.liveaction", "create") process_liveaction(liveaction_api, meta) elif index == 5: - meta = META('st2.liveaction', 'delete') + meta = META("st2.liveaction", "delete") process_liveaction(liveaction_api, meta) elif index == 6: - meta = META('st2.liveaction', 'delete') + meta = META("st2.liveaction", "delete") process_liveaction(liveaction_api, meta) elif index == 7: - meta = META('st2.announcement', 'chatops') + meta = META("st2.announcement", "chatops") process_no_api_model({}, meta) elif index == 8: - meta = META('st2.execution.output', 'create') + meta = META("st2.execution.output", "create") process_output(output_api_stdout, meta) elif index == 9: - meta = META('st2.execution.output', 'create') + meta = META("st2.execution.output", "create") process_output(output_api_stderr, meta) elif index == 10: - meta = META('st2.announcement', 'errbot') + meta = META("st2.announcement", "errbot") process_no_api_model({}, meta) else: break - received_messages = received_messages_data.split('\n\n') + received_messages = received_messages_data.split("\n\n") received_messages = [message for message in received_messages if message] return received_messages @@ -217,10 +203,10 @@ def dispatch_and_handle_mock_data(resp): received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 9) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.liveaction__delete', received_messages[5]) - self.assertIn('st2.announcement__chatops', received_messages[7]) - self.assertIn('st2.announcement__errbot', received_messages[8]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.liveaction__delete", received_messages[5]) + self.assertIn("st2.announcement__chatops", received_messages[7]) + self.assertIn("st2.announcement__errbot", received_messages[8]) # 1. ?events= filter # No filter provided - all messages should be received @@ -229,79 +215,79 @@ def dispatch_and_handle_mock_data(resp): received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 11) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.announcement__chatops', received_messages[7]) - self.assertIn('st2.execution.output__create', received_messages[8]) - self.assertIn('st2.execution.output__create', received_messages[9]) - self.assertIn('st2.announcement__errbot', received_messages[10]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.announcement__chatops", received_messages[7]) + self.assertIn("st2.execution.output__create", received_messages[8]) + self.assertIn("st2.execution.output__create", received_messages[9]) + self.assertIn("st2.announcement__errbot", received_messages[10]) # Filter provided, only three messages should be received - events = ['st2.execution__create', 'st2.liveaction__delete'] + events = ["st2.execution__create", "st2.liveaction__delete"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 3) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.liveaction__delete', received_messages[1]) - self.assertIn('st2.liveaction__delete', received_messages[2]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.liveaction__delete", received_messages[1]) + self.assertIn("st2.liveaction__delete", received_messages[2]) # Filter provided, only three messages should be received - events = ['st2.liveaction__create', 'st2.liveaction__delete'] + events = ["st2.liveaction__create", "st2.liveaction__delete"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 4) - self.assertIn('st2.liveaction__create', received_messages[0]) - self.assertIn('st2.liveaction__create', received_messages[1]) - self.assertIn('st2.liveaction__delete', received_messages[2]) - self.assertIn('st2.liveaction__delete', received_messages[3]) + self.assertIn("st2.liveaction__create", received_messages[0]) + self.assertIn("st2.liveaction__create", received_messages[1]) + self.assertIn("st2.liveaction__delete", received_messages[2]) + self.assertIn("st2.liveaction__delete", received_messages[3]) # Glob filter - events = ['st2.announcement__*'] + events = ["st2.announcement__*"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) - self.assertIn('st2.announcement__chatops', received_messages[0]) - self.assertIn('st2.announcement__errbot', received_messages[1]) + self.assertIn("st2.announcement__chatops", received_messages[0]) + self.assertIn("st2.announcement__errbot", received_messages[1]) # Filter provided - events = ['st2.execution.output__create'] + events = ["st2.execution.output__create"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) - self.assertIn('st2.execution.output__create', received_messages[0]) - self.assertIn('st2.execution.output__create', received_messages[1]) + self.assertIn("st2.execution.output__create", received_messages[0]) + self.assertIn("st2.execution.output__create", received_messages[1]) # Filter provided, invalid , no message should be received - events = ['invalid1', 'invalid2'] + events = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) # 2. ?action_refs= filter - action_refs = ['invalid1', 'invalid2'] + action_refs = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(action_refs=action_refs) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) - action_refs = ['dummy.action1'] + action_refs = ["dummy.action1"] resp = stream.StreamController().get_all(action_refs=action_refs) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) # 3. ?execution_ids= filter - execution_ids = ['invalid1', 'invalid2'] + execution_ids = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(execution_ids=execution_ids) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) - execution_ids = [EXECUTION_1['id']] + execution_ids = [EXECUTION_1["id"]] resp = stream.StreamController().get_all(execution_ids=execution_ids) received_messages = dispatch_and_handle_mock_data(resp) diff --git a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py index deb76b4e97..d14dd029e8 100644 --- a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py +++ b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py @@ -30,50 +30,54 @@ from .base import FunctionalTest -__all__ = [ - 'ActionExecutionOutputStreamControllerTestCase' -] +__all__ = ["ActionExecutionOutputStreamControllerTestCase"] class ActionExecutionOutputStreamControllerTestCase(FunctionalTest): def test_get_one_id_last_no_executions_in_the_database(self): ActionExecution.query().delete() - resp = self.app.get('/v1/executions/last/output', expect_errors=True) + resp = self.app.get("/v1/executions/last/output", expect_errors=True) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(resp.json['faultstring'], 'No executions found in the database') + self.assertEqual( + resp.json["faultstring"], "No executions found in the database" + ) def test_get_output_running_execution(self): # Retrieve lister instance to avoid race with listener connection not being established # early enough for tests to pass. # NOTE: This only affects tests where listeners are not pre-initialized. - listener = get_listener(name='execution_output') + listener = get_listener(name="execution_output") eventlet.sleep(1.0) # Test the execution output API endpoint for execution which is running (blocking) status = action_constants.LIVEACTION_STATUS_RUNNING timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) - output_params = dict(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout before start\n') + output_params = dict( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout before start\n", + ) # Insert mock output object output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db, publish=False) def insert_mock_data(): - output_params['data'] = 'stdout mid 1\n' + output_params["data"] = "stdout mid 1\n" output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -81,7 +85,7 @@ def insert_mock_data(): # spawn an eventlet which eventually finishes the action. def publish_action_finished(action_execution_db): # Insert mock output object - output_params['data'] = 'stdout pre finish 1\n' + output_params["data"] = "stdout pre finish 1\n" output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -96,28 +100,32 @@ def publish_action_finished(action_execution_db): # Retrieve data while execution is running - endpoint return new data once it's available # and block until the execution finishes - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 4) - self.assertEqual(events[0][1]['data'], 'stdout before start\n') - self.assertEqual(events[1][1]['data'], 'stdout mid 1\n') - self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n') - self.assertEqual(events[3][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout before start\n") + self.assertEqual(events[1][1]["data"], "stdout mid 1\n") + self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n") + self.assertEqual(events[3][0], "EOF") # Once the execution is in completed state, existing output should be returned immediately - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 4) - self.assertEqual(events[0][1]['data'], 'stdout before start\n') - self.assertEqual(events[1][1]['data'], 'stdout mid 1\n') - self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n') - self.assertEqual(events[3][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout before start\n") + self.assertEqual(events[1][1]["data"], "stdout mid 1\n") + self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n") + self.assertEqual(events[3][0], "EOF") listener.shutdown() @@ -127,49 +135,57 @@ def test_get_output_finished_execution(self): # Insert mock execution and output objects status = action_constants.LIVEACTION_STATUS_SUCCEEDED timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) for i in range(1, 6): - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout %s\n' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) for i in range(10, 15): - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr %s\n' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 11) - self.assertEqual(events[0][1]['data'], 'stdout 1\n') - self.assertEqual(events[9][1]['data'], 'stderr 14\n') - self.assertEqual(events[10][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout 1\n") + self.assertEqual(events[9][1]["data"], "stderr 14\n") + self.assertEqual(events[10][0], "EOF") # Verify "last" short-hand id works - resp = self.app.get('/v1/executions/last/output', expect_errors=False) + resp = self.app.get("/v1/executions/last/output", expect_errors=False) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 11) - self.assertEqual(events[10][0], 'EOF') + self.assertEqual(events[10][0], "EOF") def _parse_response(self, response): """ @@ -177,12 +193,12 @@ def _parse_response(self, response): """ events = [] - lines = response.strip().split('\n') + lines = response.strip().split("\n") for index, line in enumerate(lines): - if 'data:' in line: + if "data:" in line: e_line = lines[index - 1] - event_name = e_line[e_line.find('event: ') + len('event:'):].strip() - event_data = line[line.find('data: ') + len('data :'):].strip() + event_name = e_line[e_line.find("event: ") + len("event:") :].strip() + event_data = line[line.find("data: ") + len("data :") :].strip() event_data = json.loads(event_data) if len(event_data) > 2 else {} events.append((event_name, event_data)) diff --git a/st2tests/dist_utils.py b/st2tests/dist_utils.py index a6f62c8cc2..2f2043cf29 100644 --- a/st2tests/dist_utils.py +++ b/st2tests/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2tests/integration/orquesta/base.py b/st2tests/integration/orquesta/base.py index 52e2277e4c..f5f13cce04 100644 --- a/st2tests/integration/orquesta/base.py +++ b/st2tests/integration/orquesta/base.py @@ -30,7 +30,7 @@ LIVEACTION_LAUNCHED_STATUSES = [ action_constants.LIVEACTION_STATUS_REQUESTED, action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_RUNNING + action_constants.LIVEACTION_STATUS_RUNNING, ] DEFAULT_WAIT_FIXED = 500 @@ -42,10 +42,9 @@ def retry_on_exceptions(exc): class WorkflowControlTestCaseMixin(object): - def _create_temp_file(self): _, temp_file_path = tempfile.mkstemp() - os.chmod(temp_file_path, 0o755) # nosec + os.chmod(temp_file_path, 0o755) # nosec return temp_file_path def _delete_temp_file(self, temp_file_path): @@ -57,18 +56,23 @@ def _delete_temp_file(self, temp_file_path): class TestWorkflowExecution(unittest2.TestCase): - @classmethod def setUpClass(cls): - cls.st2client = st2.Client(base_url='http://127.0.0.1') + cls.st2client = st2.Client(base_url="http://127.0.0.1") - def _execute_workflow(self, action, parameters=None, execute_async=True, - expected_status=None, expected_result=None): + def _execute_workflow( + self, + action, + parameters=None, + execute_async=True, + expected_status=None, + expected_result=None, + ): ex = models.LiveAction(action=action, parameters=(parameters or {})) ex = self.st2client.executions.create(ex) self.assertIsNotNone(ex.id) - self.assertEqual(ex.action['ref'], action) + self.assertEqual(ex.action["ref"], action) self.assertIn(ex.status, LIVEACTION_LAUNCHED_STATUSES) if execute_async: @@ -88,14 +92,16 @@ def _execute_workflow(self, action, parameters=None, execute_async=True, @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_state(self, ex, states): if isinstance(states, six.string_types): states = [states] for state in states: if state not in action_constants.LIVEACTION_STATUSES: - raise ValueError('Status %s is not valid.' % state) + raise ValueError("Status %s is not valid." % state) try: ex = self.st2client.executions.get_by_id(ex.id) @@ -104,8 +110,7 @@ def _wait_for_state(self, ex, states): if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( 'Execution is in completed state "%s" and ' - 'does not match expected state(s). %s' % - (ex.status, ex.result) + "does not match expected state(s). %s" % (ex.status, ex.result) ) else: raise @@ -117,13 +122,16 @@ def _get_children(self, ex): @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_task(self, ex, task, status=None, num_task_exs=1): ex = self.st2client.executions.get_by_id(ex.id) task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == task + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == task ] try: @@ -131,8 +139,9 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and does not match expected number of ' - 'tasks. Expected: %s Actual: %s' % (str(num_task_exs), str(len(task_exs))) + "Execution is in completed state and does not match expected number of " + "tasks. Expected: %s Actual: %s" + % (str(num_task_exs), str(len(task_exs))) ) else: raise @@ -143,7 +152,7 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and not all tasks ' + "Execution is in completed state and not all tasks " 'match expected status "%s".' % status ) else: @@ -153,17 +162,19 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_completion(self, ex): ex = self._wait_for_state(ex, action_constants.LIVEACTION_COMPLETED_STATES) try: - self.assertTrue(hasattr(ex, 'result')) + self.assertTrue(hasattr(ex, "result")) except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and does not ' - 'contain expected result.' + "Execution is in completed state and does not " + "contain expected result." ) else: raise diff --git a/st2tests/integration/orquesta/test_performance.py b/st2tests/integration/orquesta/test_performance.py index e68ecc7f5f..899b3090f9 100644 --- a/st2tests/integration/orquesta/test_performance.py +++ b/st2tests/integration/orquesta/test_performance.py @@ -27,34 +27,35 @@ class WiringTest(base.TestWorkflowExecution): - def test_concurrent_load(self): load_count = 3 delay_poll = load_count * 5 - wf_name = 'examples.orquesta-mock-create-vm' - wf_input = {'vm_name': 'demo1', 'meta': {'demo1.itests.org': '10.3.41.99'}} + wf_name = "examples.orquesta-mock-create-vm" + wf_input = {"vm_name": "demo1", "meta": {"demo1.itests.org": "10.3.41.99"}} exs = [self._execute_workflow(wf_name, wf_input) for i in range(load_count)] eventlet.sleep(delay_poll) for ex in exs: e = self._wait_for_completion(ex) - self.assertEqual(e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result)) - self.assertIn('output', e.result) - self.assertIn('vm_id', e.result['output']) + self.assertEqual( + e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result) + ) + self.assertIn("output", e.result) + self.assertIn("vm_id", e.result["output"]) def test_with_items_load(self): - wf_name = 'examples.orquesta-with-items-concurrency' + wf_name = "examples.orquesta-with-items-concurrency" num_items = 10 concurrency = 10 members = [str(i).zfill(5) for i in range(0, num_items)] - wf_input = {'members': members, 'concurrency': concurrency} + wf_input = {"members": members, "concurrency": concurrency} - message = '%s, resistance is futile!' - expected_output = {'items': [message % i for i in members]} - expected_result = {'output': expected_output} + message = "%s, resistance is futile!" + expected_output = {"items": [message % i for i in members]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) diff --git a/st2tests/integration/orquesta/test_wiring.py b/st2tests/integration/orquesta/test_wiring.py index f542c0d779..3e07d7b3fe 100644 --- a/st2tests/integration/orquesta/test_wiring.py +++ b/st2tests/integration/orquesta/test_wiring.py @@ -23,13 +23,12 @@ class WiringTest(base.TestWorkflowExecution): - def test_sequential(self): - wf_name = 'examples.orquesta-sequential' - wf_input = {'name': 'Thanos'} + wf_name = "examples.orquesta-sequential" + wf_input = {"name": "Thanos"} - expected_output = {'greeting': 'Thanos, All your base are belong to us!'} - expected_result = {'output': expected_output} + expected_output = {"greeting": "Thanos, All your base are belong to us!"} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -38,18 +37,18 @@ def test_sequential(self): self.assertDictEqual(ex.result, expected_result) def test_join(self): - wf_name = 'examples.orquesta-join' + wf_name = "examples.orquesta-join" expected_output = { - 'messages': [ - 'Fee fi fo fum', - 'I smell the blood of an English man', - 'Be alive, or be he dead', - 'I\'ll grind his bones to make my bread' + "messages": [ + "Fee fi fo fum", + "I smell the blood of an English man", + "Be alive, or be he dead", + "I'll grind his bones to make my bread", ] } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -58,10 +57,10 @@ def test_join(self): self.assertDictEqual(ex.result, expected_result) def test_cycle(self): - wf_name = 'examples.orquesta-rollback-retry' + wf_name = "examples.orquesta-rollback-retry" expected_output = None - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -70,12 +69,12 @@ def test_cycle(self): self.assertDictEqual(ex.result, expected_result) def test_action_less(self): - wf_name = 'examples.orquesta-test-action-less-tasks' - wf_input = {'name': 'Thanos'} + wf_name = "examples.orquesta-test-action-less-tasks" + wf_input = {"name": "Thanos"} - message = 'Thanos, All your base are belong to us!' - expected_output = {'greeting': message.upper()} - expected_result = {'output': expected_output} + message = "Thanos, All your base are belong to us!" + expected_output = {"greeting": message.upper()} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -84,73 +83,72 @@ def test_action_less(self): self.assertDictEqual(ex.result, expected_result) def test_st2_runtime_context(self): - wf_name = 'examples.orquesta-st2-ctx' + wf_name = "examples.orquesta-st2-ctx" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - expected_output = {'callback': 'http://127.0.0.1:9101/v1/executions/%s' % str(ex.id)} - expected_result = {'output': expected_output} + expected_output = { + "callback": "http://127.0.0.1:9101/v1/executions/%s" % str(ex.id) + } + expected_result = {"output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertDictEqual(ex.result, expected_result) def test_subworkflow(self): - wf_name = 'examples.orquesta-subworkflow' + wf_name = "examples.orquesta-subworkflow" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(ex, 'start', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "start", ac_const.LIVEACTION_STATUS_SUCCEEDED) - t2_ex = self._wait_for_task(ex, 'subworkflow', ac_const.LIVEACTION_STATUS_SUCCEEDED)[0] - self._wait_for_task(t2_ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(t2_ex, 'task2', ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(t2_ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t2_ex = self._wait_for_task( + ex, "subworkflow", ac_const.LIVEACTION_STATUS_SUCCEEDED + )[0] + self._wait_for_task(t2_ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(t2_ex, "task2", ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(t2_ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(ex, 'finish', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "finish", ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_output_on_error(self): - wf_name = 'examples.orquesta-output-on-error' + wf_name = "examples.orquesta-output-on-error" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - expected_output = { - 'progress': 25 - } + expected_output = {"progress": 25} expected_errors = [ { - 'type': 'error', - 'task_id': 'task2', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "type": "error", + "task_id": "task2", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_result = { - 'errors': expected_errors, - 'output': expected_output - } + expected_result = {"errors": expected_errors, "output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) self.assertDictEqual(ex.result, expected_result) def test_config_context_renders(self): config_value = "Testing" - wf_name = 'examples.render_config_context' + wf_name = "examples.render_config_context" - expected_output = {'context_value': config_value} - expected_result = {'output': expected_output} + expected_output = {"context_value": config_value} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -159,21 +157,21 @@ def test_config_context_renders(self): self.assertDictEqual(ex.result, expected_result) def test_field_escaping(self): - wf_name = 'examples.orquesta-test-field-escaping' + wf_name = "examples.orquesta-test-field-escaping" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) expected_output = { - 'wf.hostname.with.periods': { - 'hostname.domain.tld': 'vars.value.with.periods', - 'hostname2.domain.tld': { - 'stdout': 'vars.nested.value.with.periods', + "wf.hostname.with.periods": { + "hostname.domain.tld": "vars.value.with.periods", + "hostname2.domain.tld": { + "stdout": "vars.nested.value.with.periods", }, }, - 'wf.output.with.periods': 'vars.nested.value.with.periods', + "wf.output.with.periods": "vars.nested.value.with.periods", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertDictEqual(ex.result, expected_result) diff --git a/st2tests/integration/orquesta/test_wiring_cancel.py b/st2tests/integration/orquesta/test_wiring_cancel.py index ff9d0d378f..0e4edaf918 100644 --- a/st2tests/integration/orquesta/test_wiring_cancel.py +++ b/st2tests/integration/orquesta/test_wiring_cancel.py @@ -22,7 +22,9 @@ from st2common.constants import action as ac_const -class CancellationWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin): +class CancellationWiringTest( + base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin +): temp_file_path = None @@ -44,9 +46,9 @@ def test_cancellation(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - ex = self._execute_workflow('examples.orquesta-test-cancel', params) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path, "message": "foobar"} + ex = self._execute_workflow("examples.orquesta-test-cancel", params) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the workflow before the temp file is created. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -63,7 +65,7 @@ def test_cancellation(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) # Task is completed successfully for graceful exit. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) @@ -74,15 +76,15 @@ def test_task_cancellation(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - ex = self._execute_workflow('examples.orquesta-test-cancel', params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path, "message": "foobar"} + ex = self._execute_workflow("examples.orquesta-test-cancel", params) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the task execution. self.st2client.executions.delete(task_exs[0]) # Wait for the task and parent workflow to be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) @@ -93,10 +95,10 @@ def test_cancellation_cascade_down_to_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - action_ref = 'examples.orquesta-test-cancel-subworkflow' + params = {"tempfile": path, "message": "foobar"} + action_ref = "examples.orquesta-test-cancel-subworkflow" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -123,10 +125,10 @@ def test_cancellation_cascade_up_from_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - action_ref = 'examples.orquesta-test-cancel-subworkflow' + params = {"tempfile": path, "message": "foobar"} + action_ref = "examples.orquesta-test-cancel-subworkflow" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -155,12 +157,12 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path, 'file2': path} - action_ref = 'examples.orquesta-test-cancel-subworkflows' + params = {"file1": path, "file2": path} + action_ref = "examples.orquesta-test-cancel-subworkflows" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex_1 = task_exs[0] - task_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex_2 = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -168,19 +170,27 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self): self.st2client.executions.delete(subwf_ex_1) # Assert subworkflow is canceling. - subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING) + subwf_ex_1 = self._wait_for_state( + subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING + ) # Assert main workflow and the other subworkflow is canceling. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELING) - subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING) + subwf_ex_2 = self._wait_for_state( + subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file. os.remove(path) self.assertFalse(os.path.exists(path)) # Assert subworkflows are canceled. - subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED) - subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED) + subwf_ex_1 = self._wait_for_state( + subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED + ) + subwf_ex_2 = self._wait_for_state( + subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED + ) # Assert main workflow is canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) diff --git a/st2tests/integration/orquesta/test_wiring_data_flow.py b/st2tests/integration/orquesta/test_wiring_data_flow.py index a9569cf693..ed5fbfa23a 100644 --- a/st2tests/integration/orquesta/test_wiring_data_flow.py +++ b/st2tests/integration/orquesta/test_wiring_data_flow.py @@ -27,13 +27,12 @@ class WiringTest(base.TestWorkflowExecution): - def test_data_flow(self): - wf_name = 'examples.orquesta-data-flow' - wf_input = {'a1': 'fee fi fo fum'} + wf_name = "examples.orquesta-data-flow" + wf_input = {"a1": "fee fi fo fum"} - expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']} - expected_result = {'output': expected_output} + expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -42,15 +41,15 @@ def test_data_flow(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_unicode(self): - wf_name = 'examples.orquesta-data-flow' - wf_input = {'a1': '床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉'} + wf_name = "examples.orquesta-data-flow" + wf_input = {"a1": "床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉"} expected_output = { - 'a5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1'], - 'b5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1'] + "a5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"], + "b5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"], } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -59,16 +58,15 @@ def test_data_flow_unicode(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_unicode_concat_with_ascii(self): - wf_name = 'examples.orquesta-sequential' - wf_input = {'name': '薩諾斯'} + wf_name = "examples.orquesta-sequential" + wf_input = {"name": "薩諾斯"} expected_output = { - 'greeting': '%s, All your base are belong to us!' % ( - wf_input['name'].decode('utf-8') if six.PY2 else wf_input['name'] - ) + "greeting": "%s, All your base are belong to us!" + % (wf_input["name"].decode("utf-8") if six.PY2 else wf_input["name"]) } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -77,15 +75,17 @@ def test_data_flow_unicode_concat_with_ascii(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_big_data_size(self): - wf_name = 'examples.orquesta-data-flow' + wf_name = "examples.orquesta-data-flow" data_length = 100000 - data = ''.join(random.choice(string.ascii_lowercase) for _ in range(data_length)) + data = "".join( + random.choice(string.ascii_lowercase) for _ in range(data_length) + ) - wf_input = {'a1': data} + wf_input = {"a1": data} - expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']} - expected_result = {'output': expected_output} + expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) diff --git a/st2tests/integration/orquesta/test_wiring_delay.py b/st2tests/integration/orquesta/test_wiring_delay.py index f825475479..32b923b923 100644 --- a/st2tests/integration/orquesta/test_wiring_delay.py +++ b/st2tests/integration/orquesta/test_wiring_delay.py @@ -23,13 +23,12 @@ class TaskDelayWiringTest(base.TestWorkflowExecution): - def test_task_delay(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 1} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 1} - expected_output = {'greeting': 'Thanos, All your base are belong to us!'} - expected_result = {'output': expected_output} + expected_output = {"greeting": "Thanos, All your base are belong to us!"} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -38,12 +37,12 @@ def test_task_delay(self): self.assertDictEqual(ex.result, expected_result) def test_task_delay_workflow_cancellation(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 300} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 300} # Launch workflow and task1 should be delayed. ex = self._execute_workflow(wf_name, wf_input) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED) # Cancel the workflow before the temp file is created. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -53,24 +52,24 @@ def test_task_delay_workflow_cancellation(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) # Task execution should be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_task_delay_task_cancellation(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 300} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 300} # Launch workflow and task1 should be delayed. ex = self._execute_workflow(wf_name, wf_input) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED) # Cancel the task execution. self.st2client.executions.delete(task_exs[0]) # Wait for the task and parent workflow to be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) diff --git a/st2tests/integration/orquesta/test_wiring_error_handling.py b/st2tests/integration/orquesta/test_wiring_error_handling.py index f3c9f87fdd..130a68c7c5 100644 --- a/st2tests/integration/orquesta/test_wiring_error_handling.py +++ b/st2tests/integration/orquesta/test_wiring_error_handling.py @@ -22,236 +22,235 @@ class ErrorHandlingTest(base.TestWorkflowExecution): - def test_inspection_error(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] - ex = self._execute_workflow('examples.orquesta-fail-inspection') + ex = self._execute_workflow("examples.orquesta-fail-inspection") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_input_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-input-rendering') + ex = self._execute_workflow("examples.orquesta-fail-input-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_vars_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-vars-rendering') + ex = self._execute_workflow("examples.orquesta-fail-vars-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_start_task_error(self): self.maxDiff = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().name.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().name.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, }, { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'greeting\' ' - 'in expression \'<% ctx().greeting %>\' from context.' - ) - } + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'greeting' " + "in expression '<% ctx().greeting %>' from context." + ), + }, ] - ex = self._execute_workflow('examples.orquesta-fail-start-task') + ex = self._execute_workflow("examples.orquesta-fail-start-task") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_task_transition_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'value\' ' - 'in expression \'<% succeeded() and result().value %>\' from context.' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'value' " + "in expression '<% succeeded() and result().value %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_output = { - 'greeting': None - } + expected_output = {"greeting": None} - ex = self._execute_workflow('examples.orquesta-fail-task-transition') + ex = self._execute_workflow("examples.orquesta-fail-task-transition") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_task_publish_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'value\' ' - 'in expression \'<% result().value %>\' from context.' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'value' " + "in expression '<% result().value %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_output = { - 'greeting': None - } + expected_output = {"greeting": None} - ex = self._execute_workflow('examples.orquesta-fail-task-publish') + ex = self._execute_workflow("examples.orquesta-fail-task-publish") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_output_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-output-rendering') + ex = self._execute_workflow("examples.orquesta-fail-output-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_task_content_errors(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action reference "echo" is not formatted correctly.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task1.action' + "type": "content", + "message": 'The action reference "echo" is not formatted correctly.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task1.action", }, { - 'type': 'content', - 'message': 'The action "core.echoz" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task2.action' + "type": "content", + "message": 'The action "core.echoz" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task2.action", }, { - 'type': 'content', - 'message': 'Action "core.echo" is missing required input "message".', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task3.input' + "type": "content", + "message": 'Action "core.echo" is missing required input "message".', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task3.input", }, { - 'type': 'content', - 'message': 'Action "core.echo" has unexpected input "messages".', - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.properties.input.' - r'patternProperties.^\w+$' + "type": "content", + "message": 'Action "core.echo" has unexpected input "messages".', + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$.properties.input." + r"patternProperties.^\w+$" ), - 'spec_path': 'tasks.task3.input.messages' - } + "spec_path": "tasks.task3.input.messages", + }, ] - ex = self._execute_workflow('examples.orquesta-fail-inspection-task-contents') + ex = self._execute_workflow("examples.orquesta-fail-inspection-task-contents") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_remediate_then_fail(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' - } + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", + }, ] - ex = self._execute_workflow('examples.orquesta-remediate-then-fail') + ex = self._execute_workflow("examples.orquesta-remediate-then-fail") ex = self._wait_for_completion(ex) # Assert that the log task is executed. @@ -261,93 +260,95 @@ def test_remediate_then_fail(self): # tasks is reached (With some hard limit) before failing eventlet.sleep(2) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) - self._wait_for_task(ex, 'log', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "log", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_fail_manually(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' - } + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", + }, ] - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-fail-manually', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow( + "examples.orquesta-error-handling-fail-manually", wf_input + ) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) - self._wait_for_task(ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_fail_continue(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-continue', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow( + "examples.orquesta-error-handling-continue", wf_input + ) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_fail_noop(self): - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-noop', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow("examples.orquesta-error-handling-noop", wf_input) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertDictEqual(ex.result, {'output': expected_output}) + self.assertDictEqual(ex.result, {"output": expected_output}) diff --git a/st2tests/integration/orquesta/test_wiring_functions.py b/st2tests/integration/orquesta/test_wiring_functions.py index 91da108d39..538bf9ddd7 100644 --- a/st2tests/integration/orquesta/test_wiring_functions.py +++ b/st2tests/integration/orquesta/test_wiring_functions.py @@ -19,165 +19,174 @@ class FunctionsWiringTest(base.TestWorkflowExecution): - def test_data_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-data-functions' + wf_name = "examples.orquesta-test-yaql-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_none_str': '%*****__%NONE%__*****%', - 'data_str': 'foobar' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_none_str": "%*****__%NONE%__*****%", + "data_str": "foobar", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_data_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-data-functions' + wf_name = "examples.orquesta-test-jinja-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}', - 'data_none_str': '%*****__%NONE%__*****%', - 'data_str': 'foobar', - 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_pipe_str_1": '{"foo": {"bar": "foobar"}}', + "data_none_str": "%*****__%NONE%__*****%", + "data_str": "foobar", + "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_path_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-path-functions' + wf_name = "examples.orquesta-test-yaql-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_path_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-path-functions' + wf_name = "examples.orquesta-test-jinja-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_regex_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-regex-functions' + wf_name = "examples.orquesta-test-yaql-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_regex_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-regex-functions' + wf_name = "examples.orquesta-test-jinja-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_time_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-time-functions' + wf_name = "examples.orquesta-test-yaql-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_time_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-time-functions' + wf_name = "examples.orquesta-test-jinja-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_version_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-version-functions' + wf_name = "examples.orquesta-test-yaql-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_version_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-version-functions' + wf_name = "examples.orquesta-test-jinja-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) diff --git a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py index d02b8594c4..e4384c72cd 100644 --- a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py +++ b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py @@ -21,90 +21,76 @@ class DatastoreFunctionTest(base.TestWorkflowExecution): @classmethod - def set_kvp(cls, name, value, scope='system', secret=False): + def set_kvp(cls, name, value, scope="system", secret=False): kvp = models.KeyValuePair( - id=name, - name=name, - value=value, - scope=scope, - secret=secret + id=name, name=name, value=value, scope=scope, secret=secret ) cls.st2client.keys.update(kvp) @classmethod - def del_kvp(cls, name, scope='system'): - kvp = models.KeyValuePair( - id=name, - name=name, - scope=scope - ) + def del_kvp(cls, name, scope="system"): + kvp = models.KeyValuePair(id=name, name=name, scope=scope) cls.st2client.keys.delete(kvp) def test_st2kv_system_scope(self): - key = 'lakshmi' - value = 'kanahansnasnasdlsajks' + key = "lakshmi" + value = "kanahansnasnasdlsajks" self.set_kvp(key, value) - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': 'system.%s' % key} + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_user_scope(self): - key = 'winson' - value = 'SoDiamondEng' + key = "winson" + value = "SoDiamondEng" - self.set_kvp(key, value, 'user') - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': key} + self.set_kvp(key, value, "user") + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) # self.del_kvp(key) def test_st2kv_decrypt(self): - key = 'kami' - value = 'eggplant' + key = "kami" + value = "eggplant" self.set_kvp(key, value, secret=True) - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_nonexistent(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) @@ -112,69 +98,71 @@ def test_st2kv_nonexistent(self): self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_FAILED) - expected_error = 'The key "%s" does not exist in the StackStorm datastore.' % key + expected_error = ( + 'The key "%s" does not exist in the StackStorm datastore.' % key + ) - self.assertIn(expected_error, output.result['errors'][0]['message']) + self.assertIn(expected_error, output.result["errors"][0]["message"]) def test_st2kv_default_value(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': 'stone' - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": "stone"} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) def test_st2kv_default_value_with_empty_string(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': '' - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": ""} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) def test_st2kv_default_value_with_null(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': None - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": None} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) diff --git a/st2tests/integration/orquesta/test_wiring_functions_task.py b/st2tests/integration/orquesta/test_wiring_functions_task.py index 990b86752c..35d002c885 100644 --- a/st2tests/integration/orquesta/test_wiring_functions_task.py +++ b/st2tests/integration/orquesta/test_wiring_functions_task.py @@ -21,91 +21,94 @@ class FunctionsWiringTest(base.TestWorkflowExecution): - def test_task_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-task-functions' + wf_name = "examples.orquesta-test-yaql-task-functions" expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_task_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-task-functions' + wf_name = "examples.orquesta-test-jinja-task-functions" expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_task_nonexistent_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-task-nonexistent' + wf_name = "examples.orquesta-test-yaql-task-nonexistent" expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% task("task0") %>\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% task(\"task0\") %>'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': expected_output, 'errors': expected_errors} + expected_result = {"output": expected_output, "errors": expected_errors} self._execute_workflow( wf_name, execute_async=False, expected_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_result=expected_result + expected_result=expected_result, ) def test_task_nonexistent_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-task-nonexistent' + wf_name = "examples.orquesta-test-jinja-task-nonexistent" expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'JinjaEvaluationException: Unable to evaluate expression ' - '\'{{ task("task0") }}\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "JinjaEvaluationException: Unable to evaluate expression " + "'{{ task(\"task0\") }}'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': expected_output, 'errors': expected_errors} + expected_result = {"output": expected_output, "errors": expected_errors} self._execute_workflow( wf_name, execute_async=False, expected_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_result=expected_result + expected_result=expected_result, ) diff --git a/st2tests/integration/orquesta/test_wiring_inquiry.py b/st2tests/integration/orquesta/test_wiring_inquiry.py index 71d0ed9e96..688929c041 100644 --- a/st2tests/integration/orquesta/test_wiring_inquiry.py +++ b/st2tests/integration/orquesta/test_wiring_inquiry.py @@ -23,75 +23,88 @@ class InquiryWiringTest(base.TestWorkflowExecution): - def test_basic_inquiry(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-basic') + ex = self._execute_workflow("examples.orquesta-ask-basic") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the inquiry. - ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(ac_exs[0].id, {'approved': True}) + ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_consecutive_inquiries(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-consecutive') + ex = self._execute_workflow("examples.orquesta-ask-consecutive") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the first inquiry. - t1_ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True}) + t1_ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True}) # Wait for the workflow to pause again. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the second inquiry. - t2_ac_exs = self._wait_for_task(ex, 'get_confirmation', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True}) + t2_ac_exs = self._wait_for_task( + ex, "get_confirmation", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_parallel_inquiries(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-parallel') + ex = self._execute_workflow("examples.orquesta-ask-parallel") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the first inquiry. - t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True}) - t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_exs = self._wait_for_task( + ex, "ask_jack", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True}) + t1_ac_exs = self._wait_for_task( + ex, "ask_jack", ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Allow some time for the first inquiry to get processed. eventlet.sleep(1) # Respond to the second inquiry. - t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True}) - t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t2_ac_exs = self._wait_for_task( + ex, "ask_jill", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True}) + t2_ac_exs = self._wait_for_task( + ex, "ask_jill", ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_nested_inquiry(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-nested') + ex = self._execute_workflow("examples.orquesta-ask-nested") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Get the action execution of the subworkflow - ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PAUSED) + ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PAUSED + ) # Respond to the inquiry in the subworkflow. t2_t2_ac_exs = self._wait_for_task( - ac_exs[0], - 'get_approval', - ac_const.LIVEACTION_STATUS_PENDING + ac_exs[0], "get_approval", ac_const.LIVEACTION_STATUS_PENDING ) - self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {'approved': True}) + self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) diff --git a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py index 52eca1490f..9779ee26b8 100644 --- a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py +++ b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py @@ -22,7 +22,9 @@ from st2common.constants import action as ac_const -class PauseResumeWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin): +class PauseResumeWiringTest( + base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin +): temp_file_path_x = None temp_file_path_y = None @@ -47,9 +49,9 @@ def test_pause_and_resume(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause', params) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause", params) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the workflow before the temp file is deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -77,10 +79,10 @@ def test_pause_and_resume_cascade_to_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the workflow before the temp file is deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -113,11 +115,11 @@ def test_pause_and_resume_cascade_to_subworkflows(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the workflow before the temp files are deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -150,8 +152,12 @@ def test_pause_and_resume_cascade_to_subworkflows(self): ex = self.st2client.executions.resume(ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_and_resume_cascade_from_subworkflow(self): @@ -160,10 +166,10 @@ def test_pause_and_resume_cascade_from_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow will still be running. @@ -188,7 +194,9 @@ def test_pause_and_resume_cascade_from_subworkflow(self): tk_ac_ex = self._wait_for_state(tk_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused(self): + def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused( + self, + ): # Temp files are created during test setup. Ensure the temp files exist. path1 = self.temp_file_path_x self.assertTrue(os.path.exists(path1)) @@ -196,11 +204,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -228,17 +236,25 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau # The workflow will now be paused because no other task is running. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Resume the subworkflow. tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running(self): + def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running( + self, + ): # Temp files are created during test setup. Ensure the temp files exist. path1 = self.temp_file_path_x self.assertTrue(os.path.exists(path1)) @@ -246,11 +262,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -276,7 +292,9 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru # The subworkflow will succeed while the other subworkflow is still running. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_RUNNING) # Delete the temporary file for the other subworkflow. @@ -284,8 +302,12 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru self.assertFalse(os.path.exists(path2)) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): @@ -296,11 +318,11 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -336,7 +358,9 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id) # The subworkflow will succeed while the other subworkflow is still paused. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) @@ -344,8 +368,12 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): tk2_ac_ex = self.st2client.executions.resume(tk2_ac_ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): @@ -356,11 +384,11 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -396,6 +424,10 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): ex = self.st2client.executions.resume(ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) diff --git a/st2tests/integration/orquesta/test_wiring_rerun.py b/st2tests/integration/orquesta/test_wiring_rerun.py index b7a6de0efe..2fafee76e4 100644 --- a/st2tests/integration/orquesta/test_wiring_rerun.py +++ b/st2tests/integration/orquesta/test_wiring_rerun.py @@ -43,106 +43,104 @@ def tearDown(self): def test_rerun_workflow(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") ex = self.st2client.executions.re_run(orig_st2_ex_id) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertNotEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertNotEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_task_of_workflow_already_succeeded(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_and_reset_with_items_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task1']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task1"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) - children = self.st2client.executions.get_property(ex.id, 'children') + children = self.st2client.executions.get_property(ex.id, "children") self.assertEqual(len(children), 4) def test_rerun_and_resume_with_items_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") ex = self.st2client.executions.re_run( - orig_st2_ex_id, - tasks=['task1'], - no_reset=['task1'] + orig_st2_ex_id, tasks=["task1"], no_reset=["task1"] ) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) - children = self.st2client.executions.get_property(ex.id, 'children') + children = self.st2client.executions.get_property(ex.id, "children") self.assertEqual(len(children), 2) diff --git a/st2tests/integration/orquesta/test_wiring_task_retry.py b/st2tests/integration/orquesta/test_wiring_task_retry.py index c8d3bd1889..7bb7f3f258 100644 --- a/st2tests/integration/orquesta/test_wiring_task_retry.py +++ b/st2tests/integration/orquesta/test_wiring_task_retry.py @@ -23,9 +23,8 @@ class TaskRetryWiringTest(base.TestWorkflowExecution): - def test_task_retry(self): - wf_name = 'examples.orquesta-task-retry' + wf_name = "examples.orquesta-task-retry" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -34,14 +33,15 @@ def test_task_retry(self): # Assert there are retries for the task. task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "check" ] self.assertGreater(len(task_exs), 1) def test_task_retry_exhausted(self): - wf_name = 'examples.orquesta-task-retry-exhausted' + wf_name = "examples.orquesta-task-retry-exhausted" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -51,16 +51,18 @@ def test_task_retry_exhausted(self): # Assert the task has exhausted the number of retries task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "check" ] - self.assertListEqual(['failed'] * 3, [task_ex.status for task_ex in task_exs]) + self.assertListEqual(["failed"] * 3, [task_ex.status for task_ex in task_exs]) # Assert the task following the retry task is not run. task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'delete' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "delete" ] self.assertEqual(len(task_exs), 0) diff --git a/st2tests/integration/orquesta/test_wiring_with_items.py b/st2tests/integration/orquesta/test_wiring_with_items.py index b80e04e702..0bf83f1bf1 100644 --- a/st2tests/integration/orquesta/test_wiring_with_items.py +++ b/st2tests/integration/orquesta/test_wiring_with_items.py @@ -40,14 +40,14 @@ def tearDown(self): super(WithItemsWiringTest, self).tearDown() def test_with_items(self): - wf_name = 'examples.orquesta-with-items' + wf_name = "examples.orquesta-with-items" - members = ['Lakshmi', 'Lindsay', 'Tomaz', 'Matt', 'Drew'] - wf_input = {'members': members} + members = ["Lakshmi", "Lindsay", "Tomaz", "Matt", "Drew"] + wf_input = {"members": members} - message = '%s, resistance is futile!' - expected_output = {'items': [message % i for i in members]} - expected_result = {'output': expected_output} + message = "%s, resistance is futile!" + expected_output = {"items": [message % i for i in members]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -56,17 +56,17 @@ def test_with_items(self): self.assertDictEqual(ex.result, expected_result) def test_with_items_failure(self): - wf_name = 'examples.orquesta-test-with-items-failure' + wf_name = "examples.orquesta-test-with-items-failure" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - self._wait_for_task(ex, 'task1', num_task_exs=10) + self._wait_for_task(ex, "task1", num_task_exs=10) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) def test_with_items_concurrency(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 5 @@ -74,22 +74,22 @@ def test_with_items_concurrency(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) - self._wait_for_task(ex, 'task1', num_task_exs=2) + self._wait_for_task(ex, "task1", num_task_exs=2) os.remove(self.tempfiles[0]) os.remove(self.tempfiles[1]) - self._wait_for_task(ex, 'task1', num_task_exs=4) + self._wait_for_task(ex, "task1", num_task_exs=4) os.remove(self.tempfiles[2]) os.remove(self.tempfiles[3]) - self._wait_for_task(ex, 'task1', num_task_exs=5) + self._wait_for_task(ex, "task1", num_task_exs=5) os.remove(self.tempfiles[4]) ex = self._wait_for_completion(ex) @@ -97,7 +97,7 @@ def test_with_items_concurrency(self): self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_with_items_cancellation(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 2 @@ -105,19 +105,16 @@ def test_with_items_cancellation(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) # Wait for action executions to run. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_RUNNING, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency ) # Cancel the workflow execution. @@ -133,17 +130,14 @@ def test_with_items_cancellation(self): # Task is completed successfully for graceful exit. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the ex to be canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_with_items_concurrency_cancellation(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 4 @@ -151,19 +145,16 @@ def test_with_items_concurrency_cancellation(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) # Wait for action executions to run. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_RUNNING, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency ) # Cancel the workflow execution. @@ -180,27 +171,24 @@ def test_with_items_concurrency_cancellation(self): # Task is completed successfully for graceful exit. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the ex to be canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_with_items_pause_and_resume(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" num_items = 2 self.tempfiles = [] for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles} + wf_input = {"tempfiles": self.tempfiles} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) @@ -217,10 +205,7 @@ def test_with_items_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=num_items + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items ) # Wait for the workflow execution to pause. @@ -233,7 +218,7 @@ def test_with_items_pause_and_resume(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_with_items_concurrency_pause_and_resume(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 4 @@ -241,10 +226,10 @@ def test_with_items_concurrency_pause_and_resume(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) @@ -261,10 +246,7 @@ def test_with_items_concurrency_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the workflow execution to pause. @@ -280,17 +262,14 @@ def test_with_items_concurrency_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=num_items + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items ) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_subworkflow_empty_with_items(self): - wf_name = 'examples.orquesta-test-subworkflow-empty-with-items' + wf_name = "examples.orquesta-test-subworkflow-empty-with-items" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) diff --git a/st2tests/setup.py b/st2tests/setup.py index 3d5947be04..f5e17bb3a3 100644 --- a/st2tests/setup.py +++ b/st2tests/setup.py @@ -23,10 +23,10 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -ST2_COMPONENT = 'st2tests' +ST2_COMPONENT = "st2tests" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -INIT_FILE = os.path.join(BASE_DIR, 'st2tests/__init__.py') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +INIT_FILE = os.path.join(BASE_DIR, "st2tests/__init__.py") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -39,15 +39,17 @@ setup( name=ST2_COMPONENT, version=get_version_string(INIT_FILE), - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']) + packages=find_packages(exclude=["setuptools", "tests"]), ) diff --git a/st2tests/st2tests/__init__.py b/st2tests/st2tests/__init__.py index 594f0e2ae1..d087d05d2d 100644 --- a/st2tests/st2tests/__init__.py +++ b/st2tests/st2tests/__init__.py @@ -23,11 +23,11 @@ __all__ = [ - 'EventletTestCase', - 'DbTestCase', - 'ExecutionDbTestCase', - 'DbModelTestCase', - 'WorkflowTestCase' + "EventletTestCase", + "DbTestCase", + "ExecutionDbTestCase", + "DbModelTestCase", + "WorkflowTestCase", ] -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2tests/st2tests/action_aliases.py b/st2tests/st2tests/action_aliases.py index 301fd9a20f..88f02f9642 100644 --- a/st2tests/st2tests/action_aliases.py +++ b/st2tests/st2tests/action_aliases.py @@ -25,13 +25,13 @@ from st2common.util.pack import get_pack_ref_from_metadata from st2common.exceptions.content import ParseException from st2common.bootstrap.aliasesregistrar import AliasesRegistrar -from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db +from st2common.models.utils.action_alias_utils import ( + extract_parameters_for_action_alias_db, +) from st2common.models.utils.action_alias_utils import extract_parameters from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseActionAliasTestCase' -] +__all__ = ["BaseActionAliasTestCase"] class BaseActionAliasTestCase(BasePackResourceTestCase): @@ -48,7 +48,9 @@ def setUp(self): if not self.action_alias_name: raise ValueError('"action_alias_name" class attribute needs to be provided') - self.action_alias_db = self._get_action_alias_db_by_name(name=self.action_alias_name) + self.action_alias_db = self._get_action_alias_db_by_name( + name=self.action_alias_name + ) def assertCommandMatchesExactlyOneFormatString(self, format_strings, command): """ @@ -58,19 +60,22 @@ def assertCommandMatchesExactlyOneFormatString(self, format_strings, command): for format_string in format_strings: try: - extract_parameters(format_str=format_string, - param_stream=command) + extract_parameters(format_str=format_string, param_stream=command) except ParseException: continue matched_format_strings.append(format_string) if len(matched_format_strings) == 0: - msg = ('Command "%s" didn\'t match any of the provided format strings' % (command)) + msg = 'Command "%s" didn\'t match any of the provided format strings' % ( + command + ) raise AssertionError(msg) elif len(matched_format_strings) > 1: - msg = ('Command "%s" matched multiple format strings: %s' % - (command, ', '.join(matched_format_strings))) + msg = 'Command "%s" matched multiple format strings: %s' % ( + command, + ", ".join(matched_format_strings), + ) raise AssertionError(msg) def assertExtractedParametersMatch(self, format_string, command, parameters): @@ -83,11 +88,14 @@ def assertExtractedParametersMatch(self, format_string, command, parameters): extracted_params = extract_parameters_for_action_alias_db( action_alias_db=self.action_alias_db, format_str=format_string, - param_stream=command) + param_stream=command, + ) if extracted_params != parameters: - msg = ('Extracted parameters from command string "%s" against format string "%s"' - ' didn\'t match the provided parameters: ' % (command, format_string)) + msg = ( + 'Extracted parameters from command string "%s" against format string "%s"' + " didn't match the provided parameters: " % (command, format_string) + ) # Note: We intercept the exception so we can can include diff for the dictionaries try: @@ -117,13 +125,14 @@ def _get_action_alias_db_by_name(self, name): pack_loader = ContentPackLoader() registrar = AliasesRegistrar(use_pack_cache=False) - aliases_path = pack_loader.get_content_from_pack(pack_dir=base_pack_path, - content_type='aliases') + aliases_path = pack_loader.get_content_from_pack( + pack_dir=base_pack_path, content_type="aliases" + ) aliases = registrar._get_aliases_from_pack(aliases_dir=aliases_path) for alias_path in aliases: - action_alias_db = registrar._get_action_alias_db(pack=pack, - action_alias=alias_path, - ignore_metadata_file_error=True) + action_alias_db = registrar._get_action_alias_db( + pack=pack, action_alias=alias_path, ignore_metadata_file_error=True + ) if action_alias_db.name == name: return action_alias_db diff --git a/st2tests/st2tests/actions.py b/st2tests/st2tests/actions.py index f6026bc8bd..9caec9bca9 100644 --- a/st2tests/st2tests/actions.py +++ b/st2tests/st2tests/actions.py @@ -19,9 +19,7 @@ from st2tests.mocks.action import MockActionService from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseActionTestCase' -] +__all__ = ["BaseActionTestCase"] class BaseActionTestCase(BasePackResourceTestCase): @@ -35,7 +33,7 @@ def setUp(self): super(BaseActionTestCase, self).setUp() class_name = self.action_cls.__name__ - action_wrapper = MockActionWrapper(pack='tests', class_name=class_name) + action_wrapper = MockActionWrapper(pack="tests", class_name=class_name) self.action_service = MockActionService(action_wrapper=action_wrapper) def get_action_instance(self, config=None): @@ -43,7 +41,9 @@ def get_action_instance(self, config=None): Retrieve instance of the action class. """ # pylint: disable=not-callable - instance = get_action_class_instance(action_cls=self.action_cls, - config=config, - action_service=self.action_service) + instance = get_action_class_instance( + action_cls=self.action_cls, + config=config, + action_service=self.action_service, + ) return instance diff --git a/st2tests/st2tests/api.py b/st2tests/st2tests/api.py index 3b48df737a..7000ddd9a1 100644 --- a/st2tests/st2tests/api.py +++ b/st2tests/st2tests/api.py @@ -34,19 +34,19 @@ from st2tests import config as tests_config __all__ = [ - 'BaseFunctionalTest', - - 'FunctionalTest', - 'APIControllerWithIncludeAndExcludeFilterTestCase', - 'BaseInquiryControllerTestCase', - - 'FakeResponse', - 'TestApp' + "BaseFunctionalTest", + "FunctionalTest", + "APIControllerWithIncludeAndExcludeFilterTestCase", + "BaseInquiryControllerTestCase", + "FakeResponse", + "TestApp", ] -SUPER_SECRET_PARAMETER = 'SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS' -ANOTHER_SUPER_SECRET_PARAMETER = 'ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING' +SUPER_SECRET_PARAMETER = ( + "SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS" +) +ANOTHER_SUPER_SECRET_PARAMETER = "ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING" class ResponseValidationError(ValueError): @@ -61,32 +61,37 @@ class TestApp(webtest.TestApp): def do_request(self, req, **kwargs): self.cookiejar.clear() - if req.environ['REQUEST_METHOD'] != 'OPTIONS': + if req.environ["REQUEST_METHOD"] != "OPTIONS": # Making sure endpoint handles OPTIONS method properly - self.options(req.environ['PATH_INFO']) + self.options(req.environ["PATH_INFO"]) res = super(TestApp, self).do_request(req, **kwargs) - if res.headers.get('Warning', None): - raise ResponseValidationError('Endpoint produced invalid response. Make sure the ' - 'response matches OpenAPI scheme for the endpoint.') + if res.headers.get("Warning", None): + raise ResponseValidationError( + "Endpoint produced invalid response. Make sure the " + "response matches OpenAPI scheme for the endpoint." + ) - if not kwargs.get('expect_errors', None): + if not kwargs.get("expect_errors", None): try: body = res.body except AssertionError as e: - if 'Iterator read after closed' in six.text_type(e): - body = b'' + if "Iterator read after closed" in six.text_type(e): + body = b"" else: raise e - if six.b(SUPER_SECRET_PARAMETER) in body or \ - six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body: - raise ResponseLeakError('Endpoint response contains secret parameter. ' - 'Find the leak.') + if ( + six.b(SUPER_SECRET_PARAMETER) in body + or six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body + ): + raise ResponseLeakError( + "Endpoint response contains secret parameter. " "Find the leak." + ) - if 'Access-Control-Allow-Origin' not in res.headers: - raise ResponseValidationError('Response missing a required CORS header') + if "Access-Control-Allow-Origin" not in res.headers: + raise ResponseValidationError("Response missing a required CORS header") return res @@ -113,19 +118,19 @@ def tearDown(self): super(BaseFunctionalTest, self).tearDown() # Reset mock context for API requests - if getattr(self, 'request_context_mock', None): + if getattr(self, "request_context_mock", None): self.request_context_mock.stop() - if hasattr(Router, 'mock_context'): - del(Router.mock_context) + if hasattr(Router, "mock_context"): + del Router.mock_context @classmethod def _do_setUpClass(cls): tests_config.parse_args() - cfg.CONF.set_default('enable', cls.enable_auth, group='auth') + cfg.CONF.set_default("enable", cls.enable_auth, group="auth") - cfg.CONF.set_override(name='enable', override=False, group='rbac') + cfg.CONF.set_override(name="enable", override=False, group="rbac") # TODO(manas) : register action types here for now. RunnerType registration can be moved # to posting to /runnertypes but that implies implementing POST. @@ -142,11 +147,8 @@ def use_user(self, user_db): raise ValueError('"user_db" is mandatory') mock_context = { - 'user': user_db, - 'auth_info': { - 'method': 'authentication token', - 'location': 'header' - } + "user": user_db, + "auth_info": {"method": "authentication token", "location": "header"}, } self.request_context_mock = mock.PropertyMock(return_value=mock_context) Router.mock_context = self.request_context_mock @@ -184,40 +186,48 @@ class APIControllerWithIncludeAndExcludeFilterTestCase(object): # True if those tests are running with rbac enabled rbac_enabled = False - def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive(self): + def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive( + self, + ): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) - url = self.get_all_path + '?include_attributes=id&exclude_attributes=id' + url = self.get_all_path + "?include_attributes=id&exclude_attributes=id" resp = self.app.get(url, expect_errors=True) self.assertEqual(resp.status_int, 400) - expected_msg = ('exclude.*? and include.*? arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + expected_msg = ( + "exclude.*? and include.*? arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) def test_get_all_invalid_exclude_and_include_parameter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) # 1. Invalid exclude_attributes field - url = self.get_all_path + '?exclude_attributes=invalid_field' + url = self.get_all_path + "?exclude_attributes=invalid_field" resp = self.app.get(url, expect_errors=True) - expected_msg = ('Invalid or unsupported exclude attribute specified: .*invalid_field.*') + expected_msg = ( + "Invalid or unsupported exclude attribute specified: .*invalid_field.*" + ) self.assertEqual(resp.status_int, 400) - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) # 2. Invalid include_attributes field - url = self.get_all_path + '?include_attributes=invalid_field' + url = self.get_all_path + "?include_attributes=invalid_field" resp = self.app.get(url, expect_errors=True) - expected_msg = ('Invalid or unsupported include attribute specified: .*invalid_field.*') + expected_msg = ( + "Invalid or unsupported include attribute specified: .*invalid_field.*" + ) self.assertEqual(resp.status_int, 400) - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) def test_get_all_include_attributes_filter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) mandatory_include_fields = self.controller_cls.mandatory_include_fields_response @@ -226,8 +236,10 @@ def test_get_all_include_attributes_filter(self): object_ids = self._insert_mock_models() # Valid include attribute - mandatory field which should always be included - resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path, - mandatory_include_fields[0])) + resp = self.app.get( + "%s?include_attributes=%s" + % (self.get_all_path, mandatory_include_fields[0]) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -245,7 +257,9 @@ def test_get_all_include_attributes_filter(self): include_field = self.include_attribute_field_name assert include_field not in mandatory_include_fields - resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path, include_field)) + resp = self.app.get( + "%s?include_attributes=%s" % (self.get_all_path, include_field) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -263,7 +277,7 @@ def test_get_all_include_attributes_filter(self): def test_get_all_exclude_attributes_filter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) # Create any resources needed by those tests (if not already created inside setUp / # setUpClass) @@ -285,8 +299,9 @@ def test_get_all_exclude_attributes_filter(self): # 2. Verify attribute is excluded when filter is provided exclude_attribute = self.exclude_attribute_field_name - resp = self.app.get('%s?exclude_attributes=%s' % (self.get_all_path, - exclude_attribute)) + resp = self.app.get( + "%s?exclude_attributes=%s" % (self.get_all_path, exclude_attribute) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -300,8 +315,8 @@ def test_get_all_exclude_attributes_filter(self): def assertResponseObjectContainsField(self, resp_item, field): # Handle "." and nested fields - if '.' in field: - split = field.split('.') + if "." in field: + split = field.split(".") for index, field_part in enumerate(split): self.assertIn(field_part, resp_item) @@ -336,7 +351,6 @@ def _do_delete(self, object_id): class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code @@ -354,24 +368,27 @@ class BaseActionExecutionControllerTestCase(object): @staticmethod def _get_actionexecution_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def _get_liveaction_id(resp): - return resp.json['liveaction']['id'] + return resp.json["liveaction"]["id"] def _do_get_one(self, actionexecution_id, *args, **kwargs): - return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs) + return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs) def _do_post(self, liveaction, *args, **kwargs): - return self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + return self.app.post_json("/v1/executions", liveaction, *args, **kwargs) def _do_delete(self, actionexecution_id, expect_errors=False): - return self.app.delete('/v1/executions/%s' % actionexecution_id, - expect_errors=expect_errors) + return self.app.delete( + "/v1/executions/%s" % actionexecution_id, expect_errors=expect_errors + ) def _do_put(self, actionexecution_id, updates, *args, **kwargs): - return self.app.put_json('/v1/executions/%s' % actionexecution_id, updates, *args, **kwargs) + return self.app.put_json( + "/v1/executions/%s" % actionexecution_id, updates, *args, **kwargs + ) class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): @@ -380,6 +397,7 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): Inherits from CleanDbTestCase to preserve atomicity between tests """ + from st2api import app enable_auth = False @@ -387,26 +405,27 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): @staticmethod def _get_inquiry_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_execution(self, actionexecution_id, *args, **kwargs): - return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs) + return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs) def _do_get_one(self, inquiry_id, *args, **kwargs): - return self.app.get('/v1/inquiries/%s' % inquiry_id, *args, **kwargs) + return self.app.get("/v1/inquiries/%s" % inquiry_id, *args, **kwargs) def _do_get_all(self, limit=50, *args, **kwargs): - return self.app.get('/v1/inquiries/?limit=%s' % limit, *args, **kwargs) + return self.app.get("/v1/inquiries/?limit=%s" % limit, *args, **kwargs) def _do_respond(self, inquiry_id, response, *args, **kwargs): - payload = { - "id": inquiry_id, - "response": response - } - return self.app.put_json('/v1/inquiries/%s' % inquiry_id, payload, *args, **kwargs) + payload = {"id": inquiry_id, "response": response} + return self.app.put_json( + "/v1/inquiries/%s" % inquiry_id, payload, *args, **kwargs + ) - def _do_create_inquiry(self, liveaction, result, status='pending', *args, **kwargs): - post_resp = self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + def _do_create_inquiry(self, liveaction, result, status="pending", *args, **kwargs): + post_resp = self.app.post_json("/v1/executions", liveaction, *args, **kwargs) inquiry_id = self._get_inquiry_id(post_resp) - updates = {'status': status, 'result': result} - return self.app.put_json('/v1/executions/%s' % inquiry_id, updates, *args, **kwargs) + updates = {"status": status, "result": result} + return self.app.put_json( + "/v1/executions/%s" % inquiry_id, updates, *args, **kwargs + ) diff --git a/st2tests/st2tests/base.py b/st2tests/st2tests/base.py index 75a8f7ce02..4a4964763d 100644 --- a/st2tests/st2tests/base.py +++ b/st2tests/st2tests/base.py @@ -19,6 +19,7 @@ # NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. # See https://github.com/StackStorm/st2/pull/4834 for details from st2common.util.monkey_patch import monkey_patch + monkey_patch() try: @@ -50,6 +51,7 @@ # parse_args when BaseDbTestCase runs class setup. If that is removed, unit tests # will failed due to conflict with duplicate DB keys. import st2tests.config as tests_config + tests_config.parse_args() from st2common.util.api import get_full_public_api_url @@ -95,26 +97,23 @@ __all__ = [ - 'EventletTestCase', - 'DbTestCase', - 'DbModelTestCase', - 'CleanDbTestCase', - 'CleanFilesTestCase', - 'IntegrationTestCase', - 'RunnerTestCase', - 'ExecutionDbTestCase', - 'WorkflowTestCase', - + "EventletTestCase", + "DbTestCase", + "DbModelTestCase", + "CleanDbTestCase", + "CleanFilesTestCase", + "IntegrationTestCase", + "RunnerTestCase", + "ExecutionDbTestCase", + "WorkflowTestCase", # Pack test classes - 'BaseSensorTestCase', - 'BaseActionTestCase', - 'BaseActionAliasTestCase', - - 'get_fixtures_path', - 'get_resources_path', - - 'blocking_eventlet_spawn', - 'make_mock_stream_readline' + "BaseSensorTestCase", + "BaseActionTestCase", + "BaseActionAliasTestCase", + "get_fixtures_path", + "get_resources_path", + "blocking_eventlet_spawn", + "make_mock_stream_readline", ] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -135,7 +134,7 @@ ALL_MODELS.extend(rule_enforcement_model.MODELS) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -TESTS_CONFIG_PATH = os.path.join(BASE_DIR, '../conf/st2.conf') +TESTS_CONFIG_PATH = os.path.join(BASE_DIR, "../conf/st2.conf") class RunnerTestCase(unittest2.TestCase): @@ -148,17 +147,15 @@ def assertCommonSt2EnvVarsAvailableInEnv(self, env): """ for var_name in COMMON_ACTION_ENV_VARIABLES: self.assertIn(var_name, env) - self.assertEqual(env['ST2_ACTION_API_URL'], get_full_public_api_url()) + self.assertEqual(env["ST2_ACTION_API_URL"], get_full_public_api_url()) self.assertIsNotNone(env[AUTH_TOKEN_ENV_VARIABLE_NAME]) def loader(self, path): - """ Load the runner config - """ + """Load the runner config""" return self.meta_loader.load(path) class BaseTestCase(TestCase): - @classmethod def _register_packs(self): """ @@ -173,7 +170,9 @@ def _register_pack_configs(self, validate_configs=False): """ Register all the packs inside the fixtures directory. """ - registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=validate_configs) + registrar = ConfigsRegistrar( + use_pack_cache=False, validate_configs=validate_configs + ) registrar.register_from_packs(base_dirs=get_packs_base_paths()) @@ -189,18 +188,14 @@ def setUpClass(cls): os=True, select=True, socket=True, - thread=False if '--use-debugger' in sys.argv else True, - time=True + thread=False if "--use-debugger" in sys.argv else True, + time=True, ) @classmethod def tearDownClass(cls): eventlet.monkey_patch( - os=False, - select=False, - socket=False, - thread=False, - time=False + os=False, select=False, socket=False, thread=False, time=False ) @@ -222,17 +217,29 @@ def setUpClass(cls): tests_config.parse_args() if cls.DISPLAY_LOG_MESSAGES: - config_path = os.path.join(BASE_DIR, '../conf/logging.conf') - logging.config.fileConfig(config_path, - disable_existing_loggers=False) + config_path = os.path.join(BASE_DIR, "../conf/logging.conf") + logging.config.fileConfig(config_path, disable_existing_loggers=False) @classmethod def _establish_connection_and_re_create_db(cls): - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) cls.db_connection = db_setup( - cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, ensure_indexes=False) + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ensure_indexes=False, + ) cls._drop_collections() cls.db_connection.drop_database(cfg.CONF.database.db_name) @@ -242,12 +249,17 @@ def _establish_connection_and_re_create_db(cls): # NOTE: This is only needed in distributed scenarios (production deployments) where # multiple services can start up at the same time and race conditions are possible. if cls.ensure_indexes: - if len(cls.ensure_indexes_models) == 0 or len(cls.ensure_indexes_models) > 1: - msg = ('Ensuring indexes for all the models, this could significantly slow down ' - 'the tests') - print('#' * len(msg), file=sys.stderr) + if ( + len(cls.ensure_indexes_models) == 0 + or len(cls.ensure_indexes_models) > 1 + ): + msg = ( + "Ensuring indexes for all the models, this could significantly slow down " + "the tests" + ) + print("#" * len(msg), file=sys.stderr) print(msg, file=sys.stderr) - print('#' * len(msg), file=sys.stderr) + print("#" * len(msg), file=sys.stderr) db_ensure_indexes(cls.ensure_indexes_models) @@ -319,19 +331,19 @@ def run(self, result=None): class ExecutionDbTestCase(DbTestCase): - """" + """ " Base test class for tests which test various execution related code paths. This class offers some utility methods for waiting on execution status, etc. """ ensure_indexes = True - ensure_indexes_models = [ - ActionExecutionSchedulingQueueItemDB - ] + ensure_indexes_models = [ActionExecutionSchedulingQueueItemDB] - def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True): - assert isinstance(status, six.string_types), '%s is not of text type' % (status) + def _wait_on_status( + self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True + ): + assert isinstance(status, six.string_types), "%s is not of text type" % (status) for _ in range(0, retries): eventlet.sleep(delay) @@ -344,8 +356,12 @@ def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_e return liveaction_db - def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True): - assert isinstance(statuses, (list, tuple)), '%s is not of list type' % (statuses) + def _wait_on_statuses( + self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True + ): + assert isinstance(statuses, (list, tuple)), "%s is not of list type" % ( + statuses + ) for _ in range(0, retries): eventlet.sleep(delay) @@ -358,7 +374,9 @@ def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, rai return liveaction_db - def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, raise_exc=True): + def _wait_on_ac_ex_status( + self, execution_db, status, retries=300, delay=0.1, raise_exc=True + ): for _ in range(0, retries): eventlet.sleep(delay) execution_db = ex_db_access.ActionExecution.get_by_id(str(execution_db.id)) @@ -370,7 +388,9 @@ def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, ra return execution_db - def _wait_on_call_count(self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True): + def _wait_on_call_count( + self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True + ): for _ in range(0, retries): eventlet.sleep(delay) if mocked.call_count == expected_count: @@ -395,12 +415,14 @@ def setUpClass(cls): def _assert_fields_equal(self, a, b, exclude=None): exclude = exclude or [] - fields = {k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude} + fields = { + k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude + } assert_funcs = { - 'mongoengine.fields.DictField': self.assertDictEqual, - 'mongoengine.fields.ListField': self.assertListEqual, - 'mongoengine.fields.SortedListField': self.assertListEqual + "mongoengine.fields.DictField": self.assertDictEqual, + "mongoengine.fields.ListField": self.assertListEqual, + "mongoengine.fields.SortedListField": self.assertListEqual, } for k, v in six.iteritems(fields): @@ -410,10 +432,7 @@ def _assert_fields_equal(self, a, b, exclude=None): def _assert_values_equal(self, a, values=None): values = values or {} - assert_funcs = { - 'dict': self.assertDictEqual, - 'list': self.assertListEqual - } + assert_funcs = {"dict": self.assertDictEqual, "list": self.assertListEqual} for k, v in six.iteritems(values): assert_func = assert_funcs.get(type(v).__name__, self.assertEqual) @@ -421,7 +440,7 @@ def _assert_values_equal(self, a, values=None): def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is not already in the database. - self.assertIsNone(getattr(instance, 'id', None)) + self.assertIsNone(getattr(instance, "id", None)) # Assert default values are assigned. self._assert_values_equal(instance, values=defaults) @@ -429,7 +448,7 @@ def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is created in the datbaase. saved = self.access_type.add_or_update(instance) self.assertIsNotNone(saved.id) - self._assert_fields_equal(instance, saved, exclude=['id']) + self._assert_fields_equal(instance, saved, exclude=["id"]) retrieved = self.access_type.get_by_id(saved.id) self._assert_fields_equal(saved, retrieved) @@ -443,22 +462,23 @@ def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is deleted from the database. retrieved = self.access_type.get_by_id(instance.id) retrieved.delete() - self.assertRaises(StackStormDBObjectNotFoundError, - self.access_type.get_by_id, instance.id) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access_type.get_by_id, instance.id + ) def _assert_unique_key_constraint(self, instance): # Assert instance is not already in the database. - self.assertIsNone(getattr(instance, 'id', None)) + self.assertIsNone(getattr(instance, "id", None)) # Assert instance is created in the datbaase. saved = self.access_type.add_or_update(instance) self.assertIsNotNone(saved.id) # Assert exception is thrown if try to create same instance again. - delattr(instance, 'id') - self.assertRaises(StackStormDBObjectConflictError, - self.access_type.add_or_update, - instance) + delattr(instance, "id") + self.assertRaises( + StackStormDBObjectConflictError, self.access_type.add_or_update, instance + ) class CleanDbTestCase(BaseDbTestCase): @@ -486,6 +506,7 @@ class CleanFilesTestCase(TestCase): """ Base test class which deletes specified files and directories on setUp and `tearDown. """ + to_delete_files = [] to_delete_directories = [] @@ -555,8 +576,8 @@ def tearDown(self): stderr = None print('Process "%s"' % (process.pid)) - print('Stdout: %s' % (stdout)) - print('Stderr: %s' % (stderr)) + print("Stdout: %s" % (stdout)) + print("Stderr: %s" % (stderr)) def add_process(self, process): """ @@ -578,7 +599,7 @@ def assertProcessIsRunning(self, process): has succesfuly started and is running. """ if not process: - raise ValueError('process is None') + raise ValueError("process is None") return_code = process.poll() @@ -586,24 +607,27 @@ def assertProcessIsRunning(self, process): if process.stdout: stdout = process.stdout.read() else: - stdout = '' + stdout = "" if process.stderr: stderr = process.stderr.read() else: - stderr = '' + stderr = "" - msg = ('Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s' % - (return_code, stdout, stderr)) + msg = "Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s" % ( + return_code, + stdout, + stderr, + ) self.fail(msg) def assertProcessExited(self, proc): try: status = proc.status() except psutil.NoSuchProcess: - status = 'exited' + status = "exited" - if status not in ['exited', 'zombie']: + if status not in ["exited", "zombie"]: self.fail('Process with pid "%s" is still running' % (proc.pid)) @@ -613,49 +637,49 @@ class WorkflowTestCase(ExecutionDbTestCase): """ def get_wf_fixture_meta_data(self, fixture_pack_path, wf_meta_file_name): - wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name + wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name wf_meta_content = loader.load_meta_file(wf_meta_file_path) - wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name'] + wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"] return { - 'file_name': wf_meta_file_name, - 'file_path': wf_meta_file_path, - 'content': wf_meta_content, - 'name': wf_name + "file_name": wf_meta_file_name, + "file_path": wf_meta_file_path, + "content": wf_meta_content, + "name": wf_name, } def get_wf_def(self, test_pack_path, wf_meta): - rel_wf_def_path = wf_meta['content']['entry_point'] - abs_wf_def_path = os.path.join(test_pack_path, 'actions', rel_wf_def_path) + rel_wf_def_path = wf_meta["content"]["entry_point"] + abs_wf_def_path = os.path.join(test_pack_path, "actions", rel_wf_def_path) - with open(abs_wf_def_path, 'r') as def_file: + with open(abs_wf_def_path, "r") as def_file: return def_file.read() def mock_st2_context(self, ac_ex_db, context=None): st2_ctx = { - 'st2': { - 'api_url': api_util.get_full_public_api_url(), - 'action_execution_id': str(ac_ex_db.id), - 'user': 'stanley', - 'action': ac_ex_db.action['ref'], - 'runner': ac_ex_db.runner['name'] + "st2": { + "api_url": api_util.get_full_public_api_url(), + "action_execution_id": str(ac_ex_db.id), + "user": "stanley", + "action": ac_ex_db.action["ref"], + "runner": ac_ex_db.runner["name"], } } if context: - st2_ctx['parent'] = context + st2_ctx["parent"] = context return st2_ctx def prep_wf_ex(self, wf_ex_db): data = { - 'spec': wf_ex_db.spec, - 'graph': wf_ex_db.graph, - 'input': wf_ex_db.input, - 'context': wf_ex_db.context, - 'state': wf_ex_db.state, - 'output': wf_ex_db.output, - 'errors': wf_ex_db.errors + "spec": wf_ex_db.spec, + "graph": wf_ex_db.graph, + "input": wf_ex_db.input, + "context": wf_ex_db.context, + "state": wf_ex_db.state, + "output": wf_ex_db.output, + "errors": wf_ex_db.errors, } conductor = conducting.WorkflowConductor.deserialize(data) @@ -663,7 +687,7 @@ def prep_wf_ex(self, wf_ex_db): for task in conductor.get_next_tasks(): ac_ex_event = events.ActionExecutionEvent(wf_statuses.RUNNING) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) wf_ex_db.status = conductor.get_workflow_status() wf_ex_db.state = conductor.workflow_state.serialize() @@ -672,7 +696,9 @@ def prep_wf_ex(self, wf_ex_db): return wf_ex_db def get_task_ex(self, task_id, route): - task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route) + task_ex_dbs = wf_db_access.TaskExecution.query( + task_id=task_id, task_route=route + ) self.assertGreater(len(task_ex_dbs), 0) return task_ex_dbs[0] @@ -686,21 +712,29 @@ def get_action_ex(self, task_ex_id): self.assertEqual(len(ac_ex_dbs), 1) return ac_ex_dbs[0] - def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None, - expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED, - expected_tk_ex_db_status=wf_statuses.SUCCEEDED): - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + def run_workflow_step( + self, + wf_ex_db, + task_id, + route, + ctx=None, + expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED, + expected_tk_ex_db_status=wf_statuses.SUCCEEDED, + ): + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) - st2_ctx = {'execution_id': wf_ex_db.action_execution} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_spec = wf_spec.tasks.get_task(task_id) - task_actions = [{'action': task_spec.action, 'input': getattr(task_spec, 'input', {})}] + task_actions = [ + {"action": task_spec.action, "input": getattr(task_spec, "input", {})} + ] task_req = { - 'id': task_id, - 'route': route, - 'spec': task_spec, - 'ctx': ctx or {}, - 'actions': task_actions + "id": task_id, + "route": route, + "spec": task_spec, + "ctx": ctx or {}, + "actions": task_actions, } task_ex_db = wf_svc.request_task_execution(wf_ex_db, st2_ctx, task_req) @@ -712,10 +746,12 @@ def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None, self.assertEqual(task_ex_db.status, expected_tk_ex_db_status) def sort_workflow_errors(self, errors): - return sorted(errors, key=lambda x: x.get('task_id', None)) + return sorted(errors, key=lambda x: x.get("task_id", None)) def assert_task_not_started(self, task_id, route): - task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route) + task_ex_dbs = wf_db_access.TaskExecution.query( + task_id=task_id, task_route=route + ) self.assertEqual(len(task_ex_dbs), 0) def assert_task_running(self, task_id, route): @@ -734,7 +770,6 @@ def assert_workflow_completed(self, wf_ex_id, status=None): class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code @@ -748,11 +783,11 @@ def raise_for_status(self): def get_fixtures_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures') + return os.path.join(os.path.dirname(__file__), "fixtures") def get_resources_path(): - return os.path.join(os.path.dirname(__file__), 'resources') + return os.path.join(os.path.dirname(__file__), "resources") def blocking_eventlet_spawn(func, *args, **kwargs): diff --git a/st2tests/st2tests/config.py b/st2tests/st2tests/config.py index 7fa4ad7b6e..b140357839 100644 --- a/st2tests/st2tests/config.py +++ b/st2tests/st2tests/config.py @@ -77,57 +77,66 @@ def _register_config_opts(): def _override_db_opts(): - CONF.set_override(name='db_name', override='st2-test', group='database') - CONF.set_override(name='host', override='127.0.0.1', group='database') + CONF.set_override(name="db_name", override="st2-test", group="database") + CONF.set_override(name="host", override="127.0.0.1", group="database") def _override_common_opts(): packs_base_path = get_fixtures_packs_base_path() - CONF.set_override(name='base_path', override=packs_base_path, group='system') - CONF.set_override(name='validate_output_schema', override=True, group='system') - CONF.set_override(name='system_packs_base_path', override=packs_base_path, group='content') - CONF.set_override(name='packs_base_paths', override=packs_base_path, group='content') - CONF.set_override(name='api_url', override='http://127.0.0.1', group='auth') - CONF.set_override(name='mask_secrets', override=True, group='log') - CONF.set_override(name='stream_output', override=False, group='actionrunner') + CONF.set_override(name="base_path", override=packs_base_path, group="system") + CONF.set_override(name="validate_output_schema", override=True, group="system") + CONF.set_override( + name="system_packs_base_path", override=packs_base_path, group="content" + ) + CONF.set_override( + name="packs_base_paths", override=packs_base_path, group="content" + ) + CONF.set_override(name="api_url", override="http://127.0.0.1", group="auth") + CONF.set_override(name="mask_secrets", override=True, group="log") + CONF.set_override(name="stream_output", override=False, group="actionrunner") def _override_api_opts(): - CONF.set_override(name='allow_origin', override=['http://127.0.0.1:3000', 'http://dev'], - group='api') + CONF.set_override( + name="allow_origin", + override=["http://127.0.0.1:3000", "http://dev"], + group="api", + ) def _override_keyvalue_opts(): current_file_path = os.path.dirname(__file__) - rel_st2_base_path = os.path.join(current_file_path, '../..') + rel_st2_base_path = os.path.join(current_file_path, "../..") abs_st2_base_path = os.path.abspath(rel_st2_base_path) - rel_enc_key_path = 'st2tests/conf/st2_kvstore_tests.crypto.key.json' + rel_enc_key_path = "st2tests/conf/st2_kvstore_tests.crypto.key.json" ovr_enc_key_path = os.path.join(abs_st2_base_path, rel_enc_key_path) - CONF.set_override(name='encryption_key_path', override=ovr_enc_key_path, group='keyvalue') + CONF.set_override( + name="encryption_key_path", override=ovr_enc_key_path, group="keyvalue" + ) def _override_scheduler_opts(): - CONF.set_override(name='sleep_interval', group='scheduler', override=0.01) + CONF.set_override(name="sleep_interval", group="scheduler", override=0.01) def _override_coordinator_opts(noop=False): - driver = None if noop else 'zake://' - CONF.set_override(name='url', override=driver, group='coordination') - CONF.set_override(name='lock_timeout', override=1, group='coordination') + driver = None if noop else "zake://" + CONF.set_override(name="url", override=driver, group="coordination") + CONF.set_override(name="lock_timeout", override=1, group="coordination") def _override_workflow_engine_opts(): - cfg.CONF.set_override('retry_stop_max_msec', 500, group='workflow_engine') - cfg.CONF.set_override('retry_wait_fixed_msec', 100, group='workflow_engine') - cfg.CONF.set_override('retry_max_jitter_msec', 100, group='workflow_engine') - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("retry_stop_max_msec", 500, group="workflow_engine") + cfg.CONF.set_override("retry_wait_fixed_msec", 100, group="workflow_engine") + cfg.CONF.set_override("retry_max_jitter_msec", 100, group="workflow_engine") + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") def _register_common_opts(): try: common_config.register_opts(ignore_errors=True) except: - LOG.exception('Common config registration failed.') + LOG.exception("Common config registration failed.") def _register_api_opts(): @@ -135,225 +144,292 @@ def _register_api_opts(): # Brittle! pecan_opts = [ cfg.StrOpt( - 'root', default='st2api.controllers.root.RootController', - help='Pecan root controller'), - cfg.StrOpt('template_path', default='%(confdir)s/st2api/st2api/templates'), - cfg.ListOpt('modules', default=['st2api']), - cfg.BoolOpt('debug', default=True), - cfg.BoolOpt('auth_enable', default=True), - cfg.DictOpt('errors', default={404: '/error/404', '__force_dict__': True}) + "root", + default="st2api.controllers.root.RootController", + help="Pecan root controller", + ), + cfg.StrOpt("template_path", default="%(confdir)s/st2api/st2api/templates"), + cfg.ListOpt("modules", default=["st2api"]), + cfg.BoolOpt("debug", default=True), + cfg.BoolOpt("auth_enable", default=True), + cfg.DictOpt("errors", default={404: "/error/404", "__force_dict__": True}), ] - _register_opts(pecan_opts, group='api_pecan') + _register_opts(pecan_opts, group="api_pecan") api_opts = [ - cfg.BoolOpt('debug', default=True), + cfg.BoolOpt("debug", default=True), cfg.IntOpt( - 'max_page_size', default=100, - help='Maximum limit (page size) argument which can be specified by the user in a query ' - 'string. If a larger value is provided, it will default to this value.') + "max_page_size", + default=100, + help="Maximum limit (page size) argument which can be specified by the user in a query " + "string. If a larger value is provided, it will default to this value.", + ), ] - _register_opts(api_opts, group='api') + _register_opts(api_opts, group="api") messaging_opts = [ cfg.StrOpt( - 'url', default='amqp://guest:guest@127.0.0.1:5672//', - help='URL of the messaging server.'), + "url", + default="amqp://guest:guest@127.0.0.1:5672//", + help="URL of the messaging server.", + ), cfg.ListOpt( - 'cluster_urls', default=[], - help='URL of all the nodes in a messaging service cluster.'), + "cluster_urls", + default=[], + help="URL of all the nodes in a messaging service cluster.", + ), cfg.IntOpt( - 'connection_retries', default=10, - help='How many times should we retry connection before failing.'), + "connection_retries", + default=10, + help="How many times should we retry connection before failing.", + ), cfg.IntOpt( - 'connection_retry_wait', default=10000, - help='How long should we wait between connection retries.'), + "connection_retry_wait", + default=10000, + help="How long should we wait between connection retries.", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Use SSL / TLS to connect to the messaging server. Same as ' - 'appending "?ssl=true" at the end of the connection URL string.'), + "ssl", + default=False, + help="Use SSL / TLS to connect to the messaging server. Same as " + 'appending "?ssl=true" at the end of the connection URL string.', + ), cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against RabbitMQ.'), + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against RabbitMQ.", + ), cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the local connection (client).'), + "ssl_certfile", + default=None, + help="Certificate file used to identify the local connection (client).", + ), cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided.'), + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided.", + ), cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from RabbitMQ.'), + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from RabbitMQ.", + ), cfg.StrOpt( - 'login_method', default=None, - help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).') + "login_method", + default=None, + help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).", + ), ] - _register_opts(messaging_opts, group='messaging') + _register_opts(messaging_opts, group="messaging") ssh_runner_opts = [ cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.') + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), ] - _register_opts(ssh_runner_opts, group='ssh_runner') + _register_opts(ssh_runner_opts, group="ssh_runner") def _register_stream_opts(): stream_opts = [ cfg.IntOpt( - 'heartbeat', default=25, - help='Send empty message every N seconds to keep connection open'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "heartbeat", + default=25, + help="Send empty message every N seconds to keep connection open", + ), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), ] - _register_opts(stream_opts, group='stream') + _register_opts(stream_opts, group="stream") def _register_auth_opts(): auth_opts = [ - cfg.StrOpt('host', default='127.0.0.1'), - cfg.IntOpt('port', default=9100), - cfg.BoolOpt('use_ssl', default=False), - cfg.StrOpt('mode', default='proxy'), - cfg.StrOpt('backend', default='flat_file'), - cfg.StrOpt('backend_kwargs', default=None), - cfg.StrOpt('logging', default='conf/logging.conf'), - cfg.IntOpt('token_ttl', default=86400, help='Access token ttl in seconds.'), - cfg.BoolOpt('sso', default=True), - cfg.StrOpt('sso_backend', default='noop'), - cfg.StrOpt('sso_backend_kwargs', default=None), - cfg.BoolOpt('debug', default=True) + cfg.StrOpt("host", default="127.0.0.1"), + cfg.IntOpt("port", default=9100), + cfg.BoolOpt("use_ssl", default=False), + cfg.StrOpt("mode", default="proxy"), + cfg.StrOpt("backend", default="flat_file"), + cfg.StrOpt("backend_kwargs", default=None), + cfg.StrOpt("logging", default="conf/logging.conf"), + cfg.IntOpt("token_ttl", default=86400, help="Access token ttl in seconds."), + cfg.BoolOpt("sso", default=True), + cfg.StrOpt("sso_backend", default="noop"), + cfg.StrOpt("sso_backend_kwargs", default=None), + cfg.BoolOpt("debug", default=True), ] - _register_opts(auth_opts, group='auth') + _register_opts(auth_opts, group="auth") def _register_action_sensor_opts(): action_sensor_opts = [ cfg.BoolOpt( - 'enable', default=True, - help='Whether to enable or disable the ability to post a trigger on action.'), + "enable", + default=True, + help="Whether to enable or disable the ability to post a trigger on action.", + ), cfg.StrOpt( - 'triggers_base_url', default='http://127.0.0.1:9101/v1/triggertypes/', - help='URL for action sensor to post TriggerType.'), + "triggers_base_url", + default="http://127.0.0.1:9101/v1/triggertypes/", + help="URL for action sensor to post TriggerType.", + ), cfg.IntOpt( - 'request_timeout', default=1, - help='Timeout value of all httprequests made by action sensor.'), + "request_timeout", + default=1, + help="Timeout value of all httprequests made by action sensor.", + ), cfg.IntOpt( - 'max_attempts', default=10, - help='No. of times to retry registration.'), + "max_attempts", default=10, help="No. of times to retry registration." + ), cfg.IntOpt( - 'retry_wait', default=1, - help='Amount of time to wait prior to retrying a request.') + "retry_wait", + default=1, + help="Amount of time to wait prior to retrying a request.", + ), ] - _register_opts(action_sensor_opts, group='action_sensor') + _register_opts(action_sensor_opts, group="action_sensor") def _register_ssh_runner_opts(): ssh_runner_opts = [ cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.'), + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.IntOpt( - 'max_parallel_actions', default=50, - help='Max number of parallel remote SSH actions that should be run. ' - 'Works only with Paramiko SSH runner.'), + "max_parallel_actions", + default=50, + help="Max number of parallel remote SSH actions that should be run. " + "Works only with Paramiko SSH runner.", + ), ] - _register_opts(ssh_runner_opts, group='ssh_runner') + _register_opts(ssh_runner_opts, group="ssh_runner") def _register_scheduler_opts(): scheduler_opts = [ cfg.FloatOpt( - 'execution_scheduling_timeout_threshold_min', default=1, - help='How long GC to search back in minutes for orphaned scheduled actions'), + "execution_scheduling_timeout_threshold_min", + default=1, + help="How long GC to search back in minutes for orphaned scheduled actions", + ), cfg.IntOpt( - 'pool_size', default=10, - help='The size of the pool used by the scheduler for scheduling executions.'), + "pool_size", + default=10, + help="The size of the pool used by the scheduler for scheduling executions.", + ), cfg.FloatOpt( - 'sleep_interval', default=0.01, - help='How long to sleep between each action scheduler main loop run interval (in ms).'), + "sleep_interval", + default=0.01, + help="How long to sleep between each action scheduler main loop run interval (in ms).", + ), cfg.FloatOpt( - 'gc_interval', default=5, - help='How often to look for zombie executions before rescheduling them (in ms).'), + "gc_interval", + default=5, + help="How often to look for zombie executions before rescheduling them (in ms).", + ), cfg.IntOpt( - 'retry_max_attempt', default=3, - help='The maximum number of attempts that the scheduler retries on error.'), + "retry_max_attempt", + default=3, + help="The maximum number of attempts that the scheduler retries on error.", + ), cfg.IntOpt( - 'retry_wait_msec', default=100, - help='The number of milliseconds to wait in between retries.') + "retry_wait_msec", + default=100, + help="The number of milliseconds to wait in between retries.", + ), ] - _register_opts(scheduler_opts, group='scheduler') + _register_opts(scheduler_opts, group="scheduler") def _register_exporter_opts(): exporter_opts = [ cfg.StrOpt( - 'dump_dir', default='/opt/stackstorm/exports/', - help='Directory to dump data to.') + "dump_dir", + default="/opt/stackstorm/exports/", + help="Directory to dump data to.", + ) ] - _register_opts(exporter_opts, group='exporter') + _register_opts(exporter_opts, group="exporter") def _register_sensor_container_opts(): partition_opts = [ cfg.StrOpt( - 'sensor_node_name', default='sensornode1', - help='name of the sensor node.'), + "sensor_node_name", default="sensornode1", help="name of the sensor node." + ), cfg.Opt( - 'partition_provider', + "partition_provider", type=types.Dict(value_type=types.String()), - default={'name': DEFAULT_PARTITION_LOADER}, - help='Provider of sensor node partition config.') + default={"name": DEFAULT_PARTITION_LOADER}, + help="Provider of sensor node partition config.", + ), ] - _register_opts(partition_opts, group='sensorcontainer') + _register_opts(partition_opts, group="sensorcontainer") # Other options other_opts = [ cfg.BoolOpt( - 'single_sensor_mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single_sensor_mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ) ] - _register_opts(other_opts, group='sensorcontainer') + _register_opts(other_opts, group="sensorcontainer") # CLI options cli_opts = [ cfg.StrOpt( - 'sensor-ref', - help='Only run sensor with the provided reference. Value is of the form ' - '. (e.g. linux.FileWatchSensor).'), + "sensor-ref", + help="Only run sensor with the provided reference. Value is of the form " + ". (e.g. linux.FileWatchSensor).", + ), cfg.BoolOpt( - 'single-sensor-mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single-sensor-mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ), ] _register_cli_opts(cli_opts) @@ -362,40 +438,52 @@ def _register_sensor_container_opts(): def _register_garbage_collector_opts(): common_opts = [ cfg.IntOpt( - 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL, - help='How often to check database for old data and perform garbage collection.'), + "collection_interval", + default=DEFAULT_COLLECTION_INTERVAL, + help="How often to check database for old data and perform garbage collection.", + ), cfg.FloatOpt( - 'sleep_delay', default=DEFAULT_SLEEP_DELAY, - help='How long to wait / sleep (in seconds) between ' - 'collection of different object types.') + "sleep_delay", + default=DEFAULT_SLEEP_DELAY, + help="How long to wait / sleep (in seconds) between " + "collection of different object types.", + ), ] - _register_opts(common_opts, group='garbagecollector') + _register_opts(common_opts, group="garbagecollector") ttl_opts = [ cfg.IntOpt( - 'action_executions_ttl', default=None, - help='Action executions and related objects (live actions, action output ' - 'objects) older than this value (days) will be automatically deleted.'), + "action_executions_ttl", + default=None, + help="Action executions and related objects (live actions, action output " + "objects) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'action_executions_output_ttl', default=7, - help='Action execution output objects (ones generated by action output ' - 'streaming) older than this value (days) will be automatically deleted.'), + "action_executions_output_ttl", + default=7, + help="Action execution output objects (ones generated by action output " + "streaming) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'trigger_instances_ttl', default=None, - help='Trigger instances older than this value (days) will be automatically deleted.') + "trigger_instances_ttl", + default=None, + help="Trigger instances older than this value (days) will be automatically deleted.", + ), ] - _register_opts(ttl_opts, group='garbagecollector') + _register_opts(ttl_opts, group="garbagecollector") inquiry_opts = [ cfg.BoolOpt( - 'purge_inquiries', default=False, - help='Set to True to perform garbage collection on Inquiries (based on ' - 'the TTL value per Inquiry)') + "purge_inquiries", + default=False, + help="Set to True to perform garbage collection on Inquiries (based on " + "the TTL value per Inquiry)", + ) ] - _register_opts(inquiry_opts, group='garbagecollector') + _register_opts(inquiry_opts, group="garbagecollector") def _register_opts(opts, group=None): diff --git a/st2tests/st2tests/fixtures/history_views/__init__.py b/st2tests/st2tests/fixtures/history_views/__init__.py index dd42395788..24567ead6e 100644 --- a/st2tests/st2tests/fixtures/history_views/__init__.py +++ b/st2tests/st2tests/fixtures/history_views/__init__.py @@ -21,12 +21,12 @@ PATH = os.path.join(os.path.dirname(os.path.realpath(__file__))) -FILES = glob.glob('%s/*.yaml' % PATH) +FILES = glob.glob("%s/*.yaml" % PATH) ARTIFACTS = {} for f in FILES: f_name = os.path.split(f)[1] name = six.text_type(os.path.splitext(f_name)[0]) - with open(f, 'r') as fd: + with open(f, "r") as fd: ARTIFACTS[name] = yaml.safe_load(fd) diff --git a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py index b5184b586c..5b2cc19cc0 100755 --- a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py +++ b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py @@ -32,16 +32,16 @@ def print_random_chars(chars=1000, selection=ascii_letters + string.digits): s = [] for _ in range(chars - 1): s.append(random.choice(selection)) - s.append('@') - print(''.join(s)) + s.append("@") + print("".join(s)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--chars', type=int, metavar='N', default=10) + parser.add_argument("--chars", type=int, metavar="N", default=10) args = parser.parse_args() print_random_chars(args.chars) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py index 57f3f6eea3..40875e1182 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py @@ -17,6 +17,5 @@ class PrintPythonVersionAction(Action): - def run(self, value1): return {"context_value": value1} diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py index ef42e25d15..acd2832627 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py @@ -14,8 +14,8 @@ # limitations under the License. from __future__ import absolute_import -from invalid import Invalid # noqa +from invalid import Invalid # noqa -class Foo(): +class Foo: pass diff --git a/st2tests/st2tests/fixtures/packs/executions/__init__.py b/st2tests/st2tests/fixtures/packs/executions/__init__.py index 3faa0a81ad..ef9bf26a3f 100644 --- a/st2tests/st2tests/fixtures/packs/executions/__init__.py +++ b/st2tests/st2tests/fixtures/packs/executions/__init__.py @@ -22,17 +22,17 @@ PATH = os.path.dirname(os.path.realpath(__file__)) -FILES = glob.glob('%s/*.yaml' % PATH) +FILES = glob.glob("%s/*.yaml" % PATH) ARTIFACTS = {} for f in FILES: f_name = os.path.split(f)[1] name = six.text_type(os.path.splitext(f_name)[0]) - with open(f, 'r') as fd: + with open(f, "r") as fd: ARTIFACTS[name] = yaml.safe_load(fd) if isinstance(ARTIFACTS[name], dict): - ARTIFACTS[name][u'id'] = six.text_type(bson.ObjectId()) + ARTIFACTS[name]["id"] = six.text_type(bson.ObjectId()) elif isinstance(ARTIFACTS[name], list): for item in ARTIFACTS[name]: - item[u'id'] = six.text_type(bson.ObjectId()) + item["id"] = six.text_type(bson.ObjectId()) diff --git a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py index 0409202903..31258fae4e 100644 --- a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py +++ b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py index 435f7eb9b6..c48bb9aa67 100644 --- a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py +++ b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py index fe50e37ae5..5d18a77ccb 100644 --- a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py +++ b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py @@ -15,9 +15,7 @@ from st2actions.runners.pythonrunner import Action -__all__ = [ - 'GetLibraryPathAction' -] +__all__ = ["GetLibraryPathAction"] class GetLibraryPathAction(Action): diff --git a/st2tests/st2tests/fixturesloader.py b/st2tests/st2tests/fixturesloader.py index df9f2cef7f..dd1446153e 100644 --- a/st2tests/st2tests/fixturesloader.py +++ b/st2tests/st2tests/fixturesloader.py @@ -21,16 +21,21 @@ from st2common.content.loader import MetaLoader -from st2common.models.api.action import (ActionAPI, LiveActionAPI, ActionExecutionStateAPI, - RunnerTypeAPI, ActionAliasAPI) +from st2common.models.api.action import ( + ActionAPI, + LiveActionAPI, + ActionExecutionStateAPI, + RunnerTypeAPI, + ActionAliasAPI, +) from st2common.models.api.auth import ApiKeyAPI, UserAPI -from st2common.models.api.execution import (ActionExecutionAPI) -from st2common.models.api.policy import (PolicyTypeAPI, PolicyAPI) -from st2common.models.api.rule import (RuleAPI) +from st2common.models.api.execution import ActionExecutionAPI +from st2common.models.api.policy import PolicyTypeAPI, PolicyAPI +from st2common.models.api.rule import RuleAPI from st2common.models.api.rule_enforcement import RuleEnforcementAPI from st2common.models.api.sensor import SensorTypeAPI from st2common.models.api.trace import TraceAPI -from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI) +from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI from st2common.models.db.action import ActionDB from st2common.models.db.actionalias import ActionAliasDB @@ -38,13 +43,13 @@ from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.executionstate import ActionExecutionStateDB from st2common.models.db.runner import RunnerTypeDB -from st2common.models.db.execution import (ActionExecutionDB) -from st2common.models.db.policy import (PolicyTypeDB, PolicyDB) +from st2common.models.db.execution import ActionExecutionDB +from st2common.models.db.policy import PolicyTypeDB, PolicyDB from st2common.models.db.rule import RuleDB from st2common.models.db.rule_enforcement import RuleEnforcementDB from st2common.models.db.sensor import SensorTypeDB from st2common.models.db.trace import TraceDB -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB from st2common.persistence.action import Action from st2common.persistence.actionalias import ActionAlias from st2common.persistence.execution import ActionExecution @@ -52,107 +57,125 @@ from st2common.persistence.auth import ApiKey, User from st2common.persistence.liveaction import LiveAction from st2common.persistence.runner import RunnerType -from st2common.persistence.policy import (PolicyType, Policy) +from st2common.persistence.policy import PolicyType, Policy from st2common.persistence.rule import Rule from st2common.persistence.rule_enforcement import RuleEnforcement from st2common.persistence.sensor import SensorType from st2common.persistence.trace import Trace -from st2common.persistence.trigger import (Trigger, TriggerType, TriggerInstance) - - -ALLOWED_DB_FIXTURES = ['actions', 'actionstates', 'aliases', 'executions', 'liveactions', - 'policies', 'policytypes', 'rules', 'runners', 'sensors', - 'triggertypes', 'triggers', 'triggerinstances', 'traces', 'apikeys', - 'users', 'enforcements'] +from st2common.persistence.trigger import Trigger, TriggerType, TriggerInstance + + +ALLOWED_DB_FIXTURES = [ + "actions", + "actionstates", + "aliases", + "executions", + "liveactions", + "policies", + "policytypes", + "rules", + "runners", + "sensors", + "triggertypes", + "triggers", + "triggerinstances", + "traces", + "apikeys", + "users", + "enforcements", +] ALLOWED_FIXTURES = copy.copy(ALLOWED_DB_FIXTURES) -ALLOWED_FIXTURES.extend(['actionchains', 'workflows']) +ALLOWED_FIXTURES.extend(["actionchains", "workflows"]) FIXTURE_DB_MODEL = { - 'actions': ActionDB, - 'aliases': ActionAliasDB, - 'actionstates': ActionExecutionStateDB, - 'apikeys': ApiKeyDB, - 'enforcements': RuleEnforcementDB, - 'executions': ActionExecutionDB, - 'liveactions': LiveActionDB, - 'policies': PolicyDB, - 'policytypes': PolicyTypeDB, - 'rules': RuleDB, - 'runners': RunnerTypeDB, - 'sensors': SensorTypeDB, - 'traces': TraceDB, - 'triggertypes': TriggerTypeDB, - 'triggers': TriggerDB, - 'triggerinstances': TriggerInstanceDB, - 'users': UserDB + "actions": ActionDB, + "aliases": ActionAliasDB, + "actionstates": ActionExecutionStateDB, + "apikeys": ApiKeyDB, + "enforcements": RuleEnforcementDB, + "executions": ActionExecutionDB, + "liveactions": LiveActionDB, + "policies": PolicyDB, + "policytypes": PolicyTypeDB, + "rules": RuleDB, + "runners": RunnerTypeDB, + "sensors": SensorTypeDB, + "traces": TraceDB, + "triggertypes": TriggerTypeDB, + "triggers": TriggerDB, + "triggerinstances": TriggerInstanceDB, + "users": UserDB, } FIXTURE_API_MODEL = { - 'actions': ActionAPI, - 'aliases': ActionAliasAPI, - 'actionstates': ActionExecutionStateAPI, - 'apikeys': ApiKeyAPI, - 'enforcements': RuleEnforcementAPI, - 'executions': ActionExecutionAPI, - 'liveactions': LiveActionAPI, - 'policies': PolicyAPI, - 'policytypes': PolicyTypeAPI, - 'rules': RuleAPI, - 'runners': RunnerTypeAPI, - 'sensors': SensorTypeAPI, - 'traces': TraceAPI, - 'triggertypes': TriggerTypeAPI, - 'triggers': TriggerAPI, - 'triggerinstances': TriggerInstanceAPI, - 'users': UserAPI + "actions": ActionAPI, + "aliases": ActionAliasAPI, + "actionstates": ActionExecutionStateAPI, + "apikeys": ApiKeyAPI, + "enforcements": RuleEnforcementAPI, + "executions": ActionExecutionAPI, + "liveactions": LiveActionAPI, + "policies": PolicyAPI, + "policytypes": PolicyTypeAPI, + "rules": RuleAPI, + "runners": RunnerTypeAPI, + "sensors": SensorTypeAPI, + "traces": TraceAPI, + "triggertypes": TriggerTypeAPI, + "triggers": TriggerAPI, + "triggerinstances": TriggerInstanceAPI, + "users": UserAPI, } FIXTURE_PERSISTENCE_MODEL = { - 'actions': Action, - 'aliases': ActionAlias, - 'actionstates': ActionExecutionState, - 'apikeys': ApiKey, - 'enforcements': RuleEnforcement, - 'executions': ActionExecution, - 'liveactions': LiveAction, - 'policies': Policy, - 'policytypes': PolicyType, - 'rules': Rule, - 'runners': RunnerType, - 'sensors': SensorType, - 'traces': Trace, - 'triggertypes': TriggerType, - 'triggers': Trigger, - 'triggerinstances': TriggerInstance, - 'users': User + "actions": Action, + "aliases": ActionAlias, + "actionstates": ActionExecutionState, + "apikeys": ApiKey, + "enforcements": RuleEnforcement, + "executions": ActionExecution, + "liveactions": LiveAction, + "policies": Policy, + "policytypes": PolicyType, + "rules": Rule, + "runners": RunnerType, + "sensors": SensorType, + "traces": Trace, + "triggertypes": TriggerType, + "triggers": Trigger, + "triggerinstances": TriggerInstance, + "users": User, } GIT_SUBMODULES_NOT_CHECKED_OUT_ERROR = """ Git submodule "%s" is not checked out. Make sure to run "git submodule update --init --recursive" in the repository root directory to check out all the submodules. -""".replace('\n', '').strip() +""".replace( + "\n", "" +).strip() def get_fixtures_base_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures') + return os.path.join(os.path.dirname(__file__), "fixtures") def get_fixtures_packs_base_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures/packs') + return os.path.join(os.path.dirname(__file__), "fixtures/packs") def get_resources_base_path(): - return os.path.join(os.path.dirname(__file__), 'resources') + return os.path.join(os.path.dirname(__file__), "resources") class FixturesLoader(object): def __init__(self): self.meta_loader = MetaLoader() - def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, - use_object_ids=False): + def save_fixtures_to_db( + self, fixtures_pack="generic", fixtures_dict=None, use_object_ids=False + ): """ Loads fixtures specified in fixtures_dict into the database and returns DB models for the fixtures. @@ -193,17 +216,22 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, for fixture in fixtures: # Guard against copy and type and similar typos if fixture in loaded_fixtures: - msg = 'Fixture "%s" is specified twice, probably a typo.' % (fixture) + msg = 'Fixture "%s" is specified twice, probably a typo.' % ( + fixture + ) raise ValueError(msg) fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) api_model = API_MODEL(**fixture_dict) db_model = API_MODEL.to_model(api_model) # Make sure we also set and use object id if that functionality is used - if use_object_ids and 'id' in fixture_dict: - db_model.id = fixture_dict['id'] + if use_object_ids and "id" in fixture_dict: + db_model.id = fixture_dict["id"] db_model = PERSISTENCE_MODEL.add_or_update(db_model) loaded_fixtures[fixture] = db_model @@ -212,7 +240,7 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, return db_models - def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None): + def load_fixtures(self, fixtures_pack="generic", fixtures_dict=None): """ Loads fixtures specified in fixtures_dict. We simply want to load the meta into dict objects. @@ -241,13 +269,16 @@ def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None): loaded_fixtures = {} for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) loaded_fixtures[fixture] = fixture_dict all_fixtures[fixture_type] = loaded_fixtures return all_fixtures - def load_models(self, fixtures_pack='generic', fixtures_dict=None): + def load_models(self, fixtures_pack="generic", fixtures_dict=None): """ Loads fixtures specified in fixtures_dict as db models. This method must be used for fixtures that have associated DB models. We simply want to load the @@ -281,7 +312,10 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None): loaded_models = {} for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) api_model = API_MODEL(**fixture_dict) db_model = API_MODEL.to_model(api_model) loaded_models[fixture] = db_model @@ -289,8 +323,9 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None): return all_fixtures - def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None, - raise_on_fail=False): + def delete_fixtures_from_db( + self, fixtures_pack="generic", fixtures_dict=None, raise_on_fail=False + ): """ Deletes fixtures specified in fixtures_dict from the database. @@ -320,7 +355,10 @@ def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None, PERSISTENCE_MODEL = FIXTURE_PERSISTENCE_MODEL.get(fixture_type, None) for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) # Note that when we have a reference mechanism consistent for # every model, we can just do a get and delete the object. Until # then, this model conversions are necessary. @@ -362,28 +400,36 @@ def _validate_fixtures_pack(self, fixtures_pack): fixtures_pack_path = self._get_fixtures_pack_path(fixtures_pack) if not self._is_fixture_pack_exists(fixtures_pack_path): - raise Exception('Fixtures pack not found ' + - 'in fixtures path %s.' % get_fixtures_base_path()) + raise Exception( + "Fixtures pack not found " + + "in fixtures path %s." % get_fixtures_base_path() + ) return fixtures_pack_path def _validate_fixture_dict(self, fixtures_dict, allowed=ALLOWED_FIXTURES): fixture_types = list(fixtures_dict.keys()) for fixture_type in fixture_types: if fixture_type not in allowed: - raise Exception('Disallowed fixture type: %s. Valid fixture types are: %s' % ( - fixture_type, ", ".join(allowed))) + raise Exception( + "Disallowed fixture type: %s. Valid fixture types are: %s" + % (fixture_type, ", ".join(allowed)) + ) def _is_fixture_pack_exists(self, fixtures_pack_path): return os.path.exists(fixtures_pack_path) - def _get_fixture_file_path_abs(self, fixtures_pack_path, fixtures_type, fixture_name): + def _get_fixture_file_path_abs( + self, fixtures_pack_path, fixtures_type, fixture_name + ): return os.path.join(fixtures_pack_path, fixtures_type, fixture_name) def _get_fixtures_pack_path(self, fixtures_pack_name): return os.path.join(get_fixtures_base_path(), fixtures_pack_name) def get_fixture_file_path_abs(self, fixtures_pack, fixtures_type, fixture_name): - return os.path.join(get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name) + return os.path.join( + get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name + ) def assert_submodules_are_checked_out(): @@ -392,9 +438,9 @@ def assert_submodules_are_checked_out(): root of the directory and that the "st2tests/st2tests/fixtures/packs/test" git repo submodule used by the tests is checked out. """ - pack_path = os.path.join(get_fixtures_packs_base_path(), 'test_content_version/') + pack_path = os.path.join(get_fixtures_packs_base_path(), "test_content_version/") pack_path = os.path.abspath(pack_path) - submodule_git_dir_or_file_path = os.path.join(pack_path, '.git') + submodule_git_dir_or_file_path = os.path.join(pack_path, ".git") # NOTE: In newer versions of git, that .git is a file and not a directory if not os.path.exists(submodule_git_dir_or_file_path): diff --git a/st2tests/st2tests/http.py b/st2tests/st2tests/http.py index 4dce56f45a..e14672d001 100644 --- a/st2tests/st2tests/http.py +++ b/st2tests/st2tests/http.py @@ -18,7 +18,6 @@ class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code diff --git a/st2tests/st2tests/mocks/action.py b/st2tests/st2tests/mocks/action.py index f09d5a7f8d..ec8f7842b8 100644 --- a/st2tests/st2tests/mocks/action.py +++ b/st2tests/st2tests/mocks/action.py @@ -25,10 +25,7 @@ from python_runner.python_action_wrapper import ActionService from st2tests.mocks.datastore import MockDatastoreService -__all__ = [ - 'MockActionWrapper', - 'MockActionService' -] +__all__ = ["MockActionWrapper", "MockActionService"] class MockActionWrapper(object): @@ -49,9 +46,11 @@ def __init__(self, action_wrapper): # We use a Mock class so use can assert logger was called with particular arguments self._logger = Mock(spec=RootLogger) - self._datastore_service = MockDatastoreService(logger=self._logger, - pack_name=self._action_wrapper._pack, - class_name=self._action_wrapper._class_name) + self._datastore_service = MockDatastoreService( + logger=self._logger, + pack_name=self._action_wrapper._pack, + class_name=self._action_wrapper._class_name, + ) @property def datastore_service(self): diff --git a/st2tests/st2tests/mocks/auth.py b/st2tests/st2tests/mocks/auth.py index e0624aca42..6f322959cc 100644 --- a/st2tests/st2tests/mocks/auth.py +++ b/st2tests/st2tests/mocks/auth.py @@ -18,24 +18,18 @@ from st2auth.backends.base import BaseAuthenticationBackend # auser:apassword in b64 -DUMMY_CREDS = 'YXVzZXI6YXBhc3N3b3Jk' +DUMMY_CREDS = "YXVzZXI6YXBhc3N3b3Jk" -__all__ = [ - 'DUMMY_CREDS', - - 'MockAuthBackend', - 'MockRequest', - - 'get_mock_backend' -] +__all__ = ["DUMMY_CREDS", "MockAuthBackend", "MockRequest", "get_mock_backend"] class MockAuthBackend(BaseAuthenticationBackend): groups = [] def authenticate(self, username, password): - return ((username == 'auser' and password == 'apassword') or - (username == 'username' and password == 'password:password')) + return (username == "auser" and password == "apassword") or ( + username == "username" and password == "password:password" + ) def get_user(self, username): return username @@ -44,7 +38,7 @@ def get_user_groups(self, username): return self.groups -class MockRequest(): +class MockRequest: def __init__(self, ttl): self.ttl = ttl diff --git a/st2tests/st2tests/mocks/datastore.py b/st2tests/st2tests/mocks/datastore.py index fe8156bf9e..0282a18ffd 100644 --- a/st2tests/st2tests/mocks/datastore.py +++ b/st2tests/st2tests/mocks/datastore.py @@ -22,9 +22,7 @@ from st2common.services.datastore import BaseDatastoreService from st2client.models.keyvalue import KeyValuePair -__all__ = [ - 'MockDatastoreService' -] +__all__ = ["MockDatastoreService"] class MockDatastoreService(BaseDatastoreService): @@ -35,7 +33,7 @@ class MockDatastoreService(BaseDatastoreService): def __init__(self, logger, pack_name, class_name, api_username=None): self._pack_name = pack_name self._class_name = class_name - self._username = api_username or 'admin' + self._username = api_username or "admin" # Holds mock KeyValuePair objects # Key is a KeyValuePair name and value is the KeyValuePair object @@ -53,18 +51,9 @@ def get_user_info(self): :rtype: ``dict`` """ result = { - 'username': self._username, - 'rbac': { - 'is_admin': True, - 'enabled': True, - 'roles': [ - 'admin' - ] - }, - 'authentication': { - 'method': 'authentication token', - 'location': 'header' - } + "username": self._username, + "rbac": {"is_admin": True, "enabled": True, "roles": ["admin"]}, + "authentication": {"method": "authentication token", "location": "header"}, } return result @@ -101,12 +90,16 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): kvp = self._datastore_items[name] return kvp.value - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): """ Store a value in a dictionary which is local to this class. """ if ttl: - raise ValueError('MockDatastoreService.set_value doesn\'t support "ttl" argument') + raise ValueError( + 'MockDatastoreService.set_value doesn\'t support "ttl" argument' + ) name = self._get_full_key_name(name=name, local=local) diff --git a/st2tests/st2tests/mocks/execution.py b/st2tests/st2tests/mocks/execution.py index 1fdf8a4262..00e3c8ef11 100644 --- a/st2tests/st2tests/mocks/execution.py +++ b/st2tests/st2tests/mocks/execution.py @@ -21,13 +21,10 @@ from st2common.models.db.execution import ActionExecutionDB -__all__ = [ - 'MockExecutionPublisher' -] +__all__ = ["MockExecutionPublisher"] class MockExecutionPublisher(object): - @classmethod def publish_update(cls, payload): try: @@ -39,7 +36,6 @@ def publish_update(cls, payload): class MockExecutionPublisherNonBlocking(object): - @classmethod def publish_update(cls, payload): try: diff --git a/st2tests/st2tests/mocks/liveaction.py b/st2tests/st2tests/mocks/liveaction.py index 753224d9ea..2b329e6b25 100644 --- a/st2tests/st2tests/mocks/liveaction.py +++ b/st2tests/st2tests/mocks/liveaction.py @@ -26,14 +26,10 @@ from st2common.constants import action as action_constants from st2common.models.db.liveaction import LiveActionDB -__all__ = [ - 'MockLiveActionPublisher', - 'MockLiveActionPublisherNonBlocking' -] +__all__ = ["MockLiveActionPublisher", "MockLiveActionPublisherNonBlocking"] class MockLiveActionPublisher(object): - @classmethod def process(cls, payload): ex_req = scheduling.get_scheduler_entrypoint().process(payload) @@ -106,7 +102,6 @@ def wait_all(cls): class MockLiveActionPublisherSchedulingQueueOnly(object): - @classmethod def process(cls, payload): scheduling.get_scheduler_entrypoint().process(payload) diff --git a/st2tests/st2tests/mocks/runners/async_runner.py b/st2tests/st2tests/mocks/runners/async_runner.py index 0409202903..31258fae4e 100644 --- a/st2tests/st2tests/mocks/runners/async_runner.py +++ b/st2tests/st2tests/mocks/runners/async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/mocks/runners/polling_async_runner.py b/st2tests/st2tests/mocks/runners/polling_async_runner.py index 435f7eb9b6..c48bb9aa67 100644 --- a/st2tests/st2tests/mocks/runners/polling_async_runner.py +++ b/st2tests/st2tests/mocks/runners/polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/mocks/runners/runner.py b/st2tests/st2tests/mocks/runners/runner.py index 40d07516c6..b89b75b712 100644 --- a/st2tests/st2tests/mocks/runners/runner.py +++ b/st2tests/st2tests/mocks/runners/runner.py @@ -17,12 +17,9 @@ import json from st2common.runners.base import ActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED) +from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED -__all__ = [ - 'get_runner', - 'MockActionRunner' -] +__all__ = ["get_runner", "MockActionRunner"] def get_runner(config=None): @@ -31,7 +28,7 @@ def get_runner(config=None): class MockActionRunner(ActionRunner): def __init__(self): - super(MockActionRunner, self).__init__(runner_id='1') + super(MockActionRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False @@ -45,22 +42,15 @@ def run(self, action_params): self.run_called = True result = {} - if self.runner_parameters.get('raise', False): - raise Exception('Raise required.') + if self.runner_parameters.get("raise", False): + raise Exception("Raise required.") - default_result = { - 'ran': True, - 'action_params': action_params - } - default_context = { - 'third_party_system': { - 'ref_id': '1234' - } - } + default_result = {"ran": True, "action_params": action_params} + default_context = {"third_party_system": {"ref_id": "1234"}} - status = self.runner_parameters.get('mock_status', LIVEACTION_STATUS_SUCCEEDED) - result = self.runner_parameters.get('mock_result', default_result) - context = self.runner_parameters.get('mock_context', default_context) + status = self.runner_parameters.get("mock_status", LIVEACTION_STATUS_SUCCEEDED) + result = self.runner_parameters.get("mock_result", default_result) + context = self.runner_parameters.get("mock_context", default_context) return (status, json.dumps(result), context) diff --git a/st2tests/st2tests/mocks/sensor.py b/st2tests/st2tests/mocks/sensor.py index 1f06787b14..c65825786c 100644 --- a/st2tests/st2tests/mocks/sensor.py +++ b/st2tests/st2tests/mocks/sensor.py @@ -27,10 +27,7 @@ from st2reactor.container.sensor_wrapper import SensorService from st2tests.mocks.datastore import MockDatastoreService -__all__ = [ - 'MockSensorWrapper', - 'MockSensorService' -] +__all__ = ["MockSensorWrapper", "MockSensorService"] class MockSensorWrapper(object): @@ -54,9 +51,11 @@ def __init__(self, sensor_wrapper): # Holds a list of triggers which were dispatched self.dispatched_triggers = [] - self._datastore_service = MockDatastoreService(logger=self._logger, - pack_name=self._sensor_wrapper._pack, - class_name=self._sensor_wrapper._class_name) + self._datastore_service = MockDatastoreService( + logger=self._logger, + pack_name=self._sensor_wrapper._pack, + class_name=self._sensor_wrapper._class_name, + ) @property def datastore_service(self): @@ -74,14 +73,11 @@ def get_logger(self, name): def dispatch(self, trigger, payload=None, trace_tag=None): trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None - return self.dispatch_with_context(trigger=trigger, payload=payload, - trace_context=trace_context) + return self.dispatch_with_context( + trigger=trigger, payload=payload, trace_context=trace_context + ) def dispatch_with_context(self, trigger, payload=None, trace_context=None): - item = { - 'trigger': trigger, - 'payload': payload, - 'trace_context': trace_context - } + item = {"trigger": trigger, "payload": payload, "trace_context": trace_context} self.dispatched_triggers.append(item) return item diff --git a/st2tests/st2tests/mocks/workflow.py b/st2tests/st2tests/mocks/workflow.py index ef50b66389..051bf5cb83 100644 --- a/st2tests/st2tests/mocks/workflow.py +++ b/st2tests/st2tests/mocks/workflow.py @@ -23,13 +23,10 @@ from st2common.models.db import workflow as wf_ex_db -__all__ = [ - 'MockWorkflowExecutionPublisher' -] +__all__ = ["MockWorkflowExecutionPublisher"] class MockWorkflowExecutionPublisher(object): - @classmethod def publish_create(cls, payload): try: diff --git a/st2tests/st2tests/pack_resource.py b/st2tests/st2tests/pack_resource.py index 7d51d74219..51f5899218 100644 --- a/st2tests/st2tests/pack_resource.py +++ b/st2tests/st2tests/pack_resource.py @@ -19,9 +19,7 @@ from unittest2 import TestCase -__all__ = [ - 'BasePackResourceTestCase' -] +__all__ = ["BasePackResourceTestCase"] class BasePackResourceTestCase(TestCase): @@ -39,16 +37,16 @@ def get_fixture_content(self, fixture_path): :type fixture_path: ``str`` """ base_pack_path = self._get_base_pack_path() - fixtures_path = os.path.join(base_pack_path, 'tests/fixtures/') + fixtures_path = os.path.join(base_pack_path, "tests/fixtures/") fixture_path = os.path.join(fixtures_path, fixture_path) - with open(fixture_path, 'r') as fp: + with open(fixture_path, "r") as fp: content = fp.read() return content def _get_base_pack_path(self): test_file_path = inspect.getfile(self.__class__) - base_pack_path = os.path.join(os.path.dirname(test_file_path), '..') + base_pack_path = os.path.join(os.path.dirname(test_file_path), "..") base_pack_path = os.path.abspath(base_pack_path) return base_pack_path diff --git a/st2tests/st2tests/policies/concurrency.py b/st2tests/st2tests/policies/concurrency.py index e7494a134b..ecd6ffb51a 100644 --- a/st2tests/st2tests/policies/concurrency.py +++ b/st2tests/st2tests/policies/concurrency.py @@ -20,11 +20,12 @@ class FakeConcurrencyApplicator(BaseConcurrencyApplicator): - def __init__(self, policy_ref, policy_type, *args, **kwargs): - super(FakeConcurrencyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type, - threshold=kwargs.get('threshold', 0)) + super(FakeConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=kwargs.get("threshold", 0), + ) def get_threshold(self): return self.threshold @@ -35,7 +36,8 @@ def apply_before(self, target): target = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_CANCELED, liveaction_id=target.id, - publish=False) + publish=False, + ) return target diff --git a/st2tests/st2tests/policies/mock_exception.py b/st2tests/st2tests/policies/mock_exception.py index 298a8cb7bb..673eccbb54 100644 --- a/st2tests/st2tests/policies/mock_exception.py +++ b/st2tests/st2tests/policies/mock_exception.py @@ -18,9 +18,8 @@ class RaiseExceptionApplicator(base.ResourcePolicyApplicator): - def apply_before(self, target): - raise Exception('For honor!!!!') + raise Exception("For honor!!!!") def apply_after(self, target): return target diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py index 4994db82b3..6a124573e9 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py @@ -18,4 +18,4 @@ class Echoer(Action): def run(self, action_input): - return {'action_input': action_input} + return {"action_input": action_input} diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py index 926a56c73f..2597f811ad 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py @@ -18,14 +18,10 @@ class Test(object): - foo = 'bar' + foo = "bar" class NonSimpleTypeAction(Action): def run(self): - result = [ - {'a': '1'}, - {'c': 2, 'h': 3}, - {'e': Test()} - ] + result = [{"a": "1"}, {"c": 2, "h": 3}, {"e": Test()}] return result diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py index 3034e0352a..cacb89d005 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py @@ -30,35 +30,39 @@ def run(self, **kwargs): except Exception: pass - self.logger.info('test info log message') - self.logger.debug('test debug log message') - self.logger.error('test error log message') + self.logger.info("test info log message") + self.logger.debug("test debug log message") + self.logger.error("test error log message") return PascalRowAction._compute_pascal_row(**kwargs) @staticmethod def _compute_pascal_row(row_index=0): - print('Pascal row action') + print("Pascal row action") - if row_index == 'a': - return False, 'This is suppose to fail don\'t worry!!' - elif row_index == 'b': + if row_index == "a": + return False, "This is suppose to fail don't worry!!" + elif row_index == "b": return None - elif row_index == 'complex_type': + elif row_index == "complex_type": result = PascalRowAction() return (False, result) - elif row_index == 'c': + elif row_index == "c": return False, None - elif row_index == 'd': - return 'succeeded', [1, 2, 3, 4] - elif row_index == 'e': + elif row_index == "d": + return "succeeded", [1, 2, 3, 4] + elif row_index == "e": return [1, 2] elif row_index == 5: - return [math.factorial(row_index) / - (math.factorial(i) * math.factorial(row_index - i)) - for i in range(row_index + 1)] - elif row_index == 'f': - raise ValueError('Duplicate traceback test') + return [ + math.factorial(row_index) + / (math.factorial(i) * math.factorial(row_index - i)) + for i in range(row_index + 1) + ] + elif row_index == "f": + raise ValueError("Duplicate traceback test") else: - return True, [math.factorial(row_index) / - (math.factorial(i) * math.factorial(row_index - i)) - for i in range(row_index + 1)] + return True, [ + math.factorial(row_index) + / (math.factorial(i) * math.factorial(row_index - i)) + for i in range(row_index + 1) + ] diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py index f1f888f069..0bf6856145 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py @@ -22,5 +22,5 @@ class PrintConfigItemAction(Action): def run(self): print(self.config) # Verify .get() still works - print(self.config.get('item1', 'default_value')) - print(self.config['key']) + print(self.config.get("item1", "default_value")) + print(self.config["key"]) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py index 06c0d2f30a..9838e5bfb6 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py @@ -24,7 +24,7 @@ class PrintToStdoutAndStderrAction(Action): def run(self, stdout_count=3, stderr_count=3): for index in range(0, stdout_count): - sys.stdout.write('stdout line %s\n' % (index)) + sys.stdout.write("stdout line %s\n" % (index)) for index in range(0, stderr_count): - sys.stderr.write('stderr line %s\n' % (index)) + sys.stderr.write("stderr line %s\n" % (index)) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py index 717549347b..ffe7b69b3b 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py @@ -22,5 +22,5 @@ class PythonPathsAction(Action): def run(self): - print('sys.path: %s' % (sys.path)) - print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH'))) + print("sys.path: %s" % (sys.path)) + print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH"))) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py index d95939990f..eeed54fbb0 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py @@ -22,4 +22,4 @@ class TestAction(Action): def run(self): - return 'test action' + return "test action" diff --git a/st2tests/st2tests/sensors.py b/st2tests/st2tests/sensors.py index 0b6f31e6b6..52c0451f45 100644 --- a/st2tests/st2tests/sensors.py +++ b/st2tests/st2tests/sensors.py @@ -18,9 +18,7 @@ from st2tests.mocks.sensor import MockSensorService from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseSensorTestCase' -] +__all__ = ["BaseSensorTestCase"] class BaseSensorTestCase(BasePackResourceTestCase): @@ -37,22 +35,20 @@ def setUp(self): super(BaseSensorTestCase, self).setUp() class_name = self.sensor_cls.__name__ - sensor_wrapper = MockSensorWrapper(pack='tests', class_name=class_name) + sensor_wrapper = MockSensorWrapper(pack="tests", class_name=class_name) self.sensor_service = MockSensorService(sensor_wrapper=sensor_wrapper) def get_sensor_instance(self, config=None, poll_interval=None): """ Retrieve instance of the sensor class. """ - kwargs = { - 'sensor_service': self.sensor_service - } + kwargs = {"sensor_service": self.sensor_service} if config: - kwargs['config'] = config + kwargs["config"] = config if poll_interval is not None: - kwargs['poll_interval'] = poll_interval + kwargs["poll_interval"] = poll_interval instance = self.sensor_cls(**kwargs) # pylint: disable=not-callable return instance @@ -79,15 +75,15 @@ def assertTriggerDispatched(self, trigger, payload=None, trace_context=None): """ dispatched_triggers = self.get_dispatched_triggers() for item in dispatched_triggers: - trigger_matches = (item['trigger'] == trigger) + trigger_matches = item["trigger"] == trigger if payload: - payload_matches = (item['payload'] == payload) + payload_matches = item["payload"] == payload else: payload_matches = True if trace_context: - trace_context_matches = (item['trace_context'] == trace_context) + trace_context_matches = item["trace_context"] == trace_context else: trace_context_matches = True diff --git a/st2tests/testpacks/checks/actions/checks/check_loadavg.py b/st2tests/testpacks/checks/actions/checks/check_loadavg.py index 4a56834832..9439679df3 100755 --- a/st2tests/testpacks/checks/actions/checks/check_loadavg.py +++ b/st2tests/testpacks/checks/actions/checks/check_loadavg.py @@ -23,40 +23,40 @@ def print_load_avg(args): period = args[1] - loadavg_file = '/proc/loadavg' - cpuinfo_file = '/proc/cpuinfo' + loadavg_file = "/proc/loadavg" + cpuinfo_file = "/proc/cpuinfo" cpus = 0 try: - fh = open(loadavg_file, 'r') + fh = open(loadavg_file, "r") load = fh.readline().split()[0:3] fh.close() except: - sys.stderr.write('Error opening %s\n' % loadavg_file) + sys.stderr.write("Error opening %s\n" % loadavg_file) sys.exit(2) try: - fh = open(cpuinfo_file, 'r') + fh = open(cpuinfo_file, "r") for line in fh: - if 'processor' in line: + if "processor" in line: cpus += 1 fh.close() except: - sys.stderr.write('Error opeing %s\n' % cpuinfo_file) + sys.stderr.write("Error opeing %s\n" % cpuinfo_file) - one_min = '1 min load/core: %s' % str(float(load[0]) / cpus) - five_min = '5 min load/core: %s' % str(float(load[1]) / cpus) - fifteen_min = '15 min load/core: %s' % str(float(load[2]) / cpus) + one_min = "1 min load/core: %s" % str(float(load[0]) / cpus) + five_min = "5 min load/core: %s" % str(float(load[1]) / cpus) + fifteen_min = "15 min load/core: %s" % str(float(load[2]) / cpus) - if period == '1' or period == 'one': + if period == "1" or period == "one": print(one_min) - elif period == '5' or period == 'five': + elif period == "5" or period == "five": print(five_min) - elif period == '15' or period == 'fifteen': + elif period == "15" or period == "fifteen": print(fifteen_min) else: print(one_min + " " + five_min + " " + fifteen_min) -if __name__ == '__main__': +if __name__ == "__main__": print_load_avg(sys.argv) diff --git a/tools/config_gen.py b/tools/config_gen.py index e705161ea3..e0004d04e1 100755 --- a/tools/config_gen.py +++ b/tools/config_gen.py @@ -24,57 +24,57 @@ from oslo_config import cfg -CONFIGS = ['st2actions.config', - 'st2actions.scheduler.config', - 'st2actions.notifier.config', - 'st2actions.workflows.config', - 'st2api.config', - 'st2stream.config', - 'st2auth.config', - 'st2common.config', - 'st2exporter.config', - 'st2reactor.rules.config', - 'st2reactor.sensor.config', - 'st2reactor.timer.config', - 'st2reactor.garbage_collector.config'] - -SKIP_GROUPS = ['api_pecan', 'rbac', 'results_tracker'] +CONFIGS = [ + "st2actions.config", + "st2actions.scheduler.config", + "st2actions.notifier.config", + "st2actions.workflows.config", + "st2api.config", + "st2stream.config", + "st2auth.config", + "st2common.config", + "st2exporter.config", + "st2reactor.rules.config", + "st2reactor.sensor.config", + "st2reactor.timer.config", + "st2reactor.garbage_collector.config", +] + +SKIP_GROUPS = ["api_pecan", "rbac", "results_tracker"] # We group auth options together to make it a bit more clear what applies where AUTH_OPTIONS = { - 'common': [ - 'enable', - 'mode', - 'logging', - 'api_url', - 'token_ttl', - 'service_token_ttl', - 'sso', - 'sso_backend', - 'sso_backend_kwargs', - 'debug' + "common": [ + "enable", + "mode", + "logging", + "api_url", + "token_ttl", + "service_token_ttl", + "sso", + "sso_backend", + "sso_backend_kwargs", + "debug", + ], + "standalone": [ + "host", + "port", + "use_ssl", + "cert", + "key", + "backend", + "backend_kwargs", ], - 'standalone': [ - 'host', - 'port', - 'use_ssl', - 'cert', - 'key', - 'backend', - 'backend_kwargs' - ] } # Some of the config values change depending on the environment where this script is ran so we # set them to static values to ensure consistent and stable output STATIC_OPTION_VALUES = { - 'actionrunner': { - 'virtualenv_binary': '/usr/bin/virtualenv', - 'python_binary': '/usr/bin/python', + "actionrunner": { + "virtualenv_binary": "/usr/bin/virtualenv", + "python_binary": "/usr/bin/python", }, - 'webui': { - 'webui_base_url': 'https://localhost' - } + "webui": {"webui_base_url": "https://localhost"}, } COMMON_AUTH_OPTIONS_COMMENT = """ @@ -112,22 +112,28 @@ def _clear_config(): def _read_group(opt_group): all_options = list(opt_group._opts.values()) - if opt_group.name == 'auth': + if opt_group.name == "auth": print(COMMON_AUTH_OPTIONS_COMMENT) - print('') - common_options = [option for option in all_options if option['opt'].name in - AUTH_OPTIONS['common']] + print("") + common_options = [ + option + for option in all_options + if option["opt"].name in AUTH_OPTIONS["common"] + ] _print_options(opt_group=opt_group, options=common_options) - print('') + print("") print(STANDALONE_AUTH_OPTIONS_COMMENT) - print('') - standalone_options = [option for option in all_options if option['opt'].name in - AUTH_OPTIONS['standalone']] + print("") + standalone_options = [ + option + for option in all_options + if option["opt"].name in AUTH_OPTIONS["standalone"] + ] _print_options(opt_group=opt_group, options=standalone_options) if len(common_options) + len(standalone_options) != len(all_options): - msg = ('Not all options are declared in AUTH_OPTIONS dict, please update it') + msg = "Not all options are declared in AUTH_OPTIONS dict, please update it" raise Exception(msg) else: options = all_options @@ -137,33 +143,35 @@ def _read_group(opt_group): def _read_groups(opt_groups): opt_groups = collections.OrderedDict(sorted(opt_groups.items())) for name, opt_group in six.iteritems(opt_groups): - print('[%s]' % name) + print("[%s]" % name) _read_group(opt_group) - print('') + print("") def _print_options(opt_group, options): - for opt in sorted(options, key=lambda x: x['opt'].name): - opt = opt['opt'] + for opt in sorted(options, key=lambda x: x["opt"].name): + opt = opt["opt"] # Special case for options which could change during this script run - static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get(opt.name, None) + static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get( + opt.name, None + ) if static_option_value: opt.default = static_option_value # Special handling for list options if isinstance(opt, cfg.ListOpt): if opt.default: - value = ','.join(opt.default) + value = ",".join(opt.default) else: - value = '' + value = "" - value += ' # comma separated list allowed here.' + value += " # comma separated list allowed here." else: value = opt.default - print('# %s' % opt.help) - print('%s = %s' % (opt.name, value)) + print("# %s" % opt.help) + print("%s = %s" % (opt.name, value)) def main(args): @@ -176,5 +184,5 @@ def main(args): _read_groups(opt_groups) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv) diff --git a/tools/diff-db-disk.py b/tools/diff-db-disk.py index ec09a76709..a9e65d72ea 100755 --- a/tools/diff-db-disk.py +++ b/tools/diff-db-disk.py @@ -47,20 +47,20 @@ from st2common.persistence.action import Action registrar = ResourceRegistrar() -registrar.ALLOWED_EXTENSIONS = ['.yaml', '.yml', '.json'] +registrar.ALLOWED_EXTENSIONS = [".yaml", ".yml", ".json"] meta_loader = MetaLoader() API_MODELS_ARTIFACT_TYPES = { - 'actions': ActionAPI, - 'sensors': SensorTypeAPI, - 'rules': RuleAPI + "actions": ActionAPI, + "sensors": SensorTypeAPI, + "rules": RuleAPI, } API_MODELS_PERSISTENT_MODELS = { Action: ActionAPI, SensorType: SensorTypeAPI, - Rule: RuleAPI + Rule: RuleAPI, } @@ -77,13 +77,15 @@ def _get_api_models_from_db(persistence_model, pack_dir=None): filters = {} if pack_dir: pack_name = os.path.basename(os.path.normpath(pack_dir)) - filters = {'pack': pack_name} + filters = {"pack": pack_name} models = persistence_model.query(**filters) models_dict = {} for model in models: - model_pack = getattr(model, 'pack', None) or DEFAULT_PACK_NAME - model_ref = ResourceReference.to_string_reference(name=model.name, pack=model_pack) - if getattr(model, 'id', None): + model_pack = getattr(model, "pack", None) or DEFAULT_PACK_NAME + model_ref = ResourceReference.to_string_reference( + name=model.name, pack=model_pack + ) + if getattr(model, "id", None): del model.id API_MODEL = API_MODELS_PERSISTENT_MODELS[persistence_model] models_dict[model_ref] = API_MODEL.from_model(model) @@ -107,15 +109,14 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None): artifacts_paths = registrar.get_resources_from_pack(pack_path) for artifact_path in artifacts_paths: artifact = meta_loader.load(artifact_path) - if artifact_type == 'sensors': + if artifact_type == "sensors": sensors_dir = os.path.dirname(artifact_path) - sensor_file_path = os.path.join(sensors_dir, artifact['entry_point']) - artifact['artifact_uri'] = 'file://' + sensor_file_path - name = artifact.get('name', None) or artifact.get('class_name', None) - if not artifact.get('pack', None): - artifact['pack'] = pack_name - ref = ResourceReference.to_string_reference(name=name, - pack=pack_name) + sensor_file_path = os.path.join(sensors_dir, artifact["entry_point"]) + artifact["artifact_uri"] = "file://" + sensor_file_path + name = artifact.get("name", None) or artifact.get("class_name", None) + if not artifact.get("pack", None): + artifact["pack"] = pack_name + ref = ResourceReference.to_string_reference(name=name, pack=pack_name) API_MODEL = API_MODELS_ARTIFACT_TYPES[artifact_type] # Following conversions are required because we add some fields with # default values in db model. If we don't do these conversions, @@ -128,42 +129,49 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None): return artifacts_dict -def _content_diff(artifact_type=None, artifact_in_disk=None, artifact_in_db=None, - verbose=False): +def _content_diff( + artifact_type=None, artifact_in_disk=None, artifact_in_db=None, verbose=False +): artifact_in_disk_str = json.dumps( - artifact_in_disk.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_disk.__json__(), sort_keys=True, indent=4, separators=(",", ": ") ) artifact_in_db_str = json.dumps( - artifact_in_db.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_db.__json__(), sort_keys=True, indent=4, separators=(",", ": ") + ) + diffs = difflib.context_diff( + artifact_in_db_str.splitlines(), + artifact_in_disk_str.splitlines(), + fromfile="DB contents", + tofile="Disk contents", ) - diffs = difflib.context_diff(artifact_in_db_str.splitlines(), - artifact_in_disk_str.splitlines(), - fromfile='DB contents', tofile='Disk contents') printed = False for diff in diffs: if not printed: - identifier = getattr(artifact_in_db, 'ref', getattr(artifact_in_db, 'name')) - print('%s %s in db differs from what is in disk.' % (artifact_type.upper(), - identifier)) + identifier = getattr(artifact_in_db, "ref", getattr(artifact_in_db, "name")) + print( + "%s %s in db differs from what is in disk." + % (artifact_type.upper(), identifier) + ) printed = True print(diff) if verbose: - print('\n\nOriginal contents:') - print('===================\n') - print('Artifact in db:\n\n%s\n\n' % artifact_in_db_str) - print('Artifact in disk:\n\n%s\n\n' % artifact_in_disk_str) + print("\n\nOriginal contents:") + print("===================\n") + print("Artifact in db:\n\n%s\n\n" % artifact_in_db_str) + print("Artifact in disk:\n\n%s\n\n" % artifact_in_disk_str) -def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True, - content_diff=True): +def _diff( + persistence_model, artifact_type, pack_dir=None, verbose=True, content_diff=True +): artifacts_in_db_dict = _get_api_models_from_db(persistence_model, pack_dir=pack_dir) artifacts_in_disk_dict = _get_api_models_from_disk(artifact_type, pack_dir=pack_dir) # print(artifacts_in_disk_dict) - all_artifacts = set(list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys())) + all_artifacts = set( + list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys()) + ) for artifact in all_artifacts: artifact_in_db = artifacts_in_db_dict.get(artifact, None) @@ -172,76 +180,96 @@ def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True, artifact_in_db_pretty_json = None if verbose: - print('******************************************************************************') - print('Checking if artifact %s is present in both disk and db.' % artifact) + print( + "******************************************************************************" + ) + print("Checking if artifact %s is present in both disk and db." % artifact) if not artifact_in_db: - print('##############################################################################') - print('%s %s in disk not available in db.' % (artifact_type.upper(), artifact)) + print( + "##############################################################################" + ) + print( + "%s %s in disk not available in db." % (artifact_type.upper(), artifact) + ) artifact_in_disk_pretty_json = json.dumps( - artifact_in_disk.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_disk.__json__(), + sort_keys=True, + indent=4, + separators=(",", ": "), ) if verbose: - print('File contents: \n') + print("File contents: \n") print(artifact_in_disk_pretty_json) continue if not artifact_in_disk: - print('##############################################################################') - print('%s %s in db not available in disk.' % (artifact_type.upper(), artifact)) + print( + "##############################################################################" + ) + print( + "%s %s in db not available in disk." % (artifact_type.upper(), artifact) + ) artifact_in_db_pretty_json = json.dumps( - artifact_in_db.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_db.__json__(), + sort_keys=True, + indent=4, + separators=(",", ": "), ) if verbose: - print('DB contents: \n') + print("DB contents: \n") print(artifact_in_db_pretty_json) continue if verbose: - print('Artifact %s exists in both disk and db.' % artifact) + print("Artifact %s exists in both disk and db." % artifact) if content_diff: if verbose: - print('Performing content diff for artifact %s.' % artifact) + print("Performing content diff for artifact %s." % artifact) - _content_diff(artifact_type=artifact_type, - artifact_in_disk=artifact_in_disk, - artifact_in_db=artifact_in_db, - verbose=verbose) + _content_diff( + artifact_type=artifact_type, + artifact_in_disk=artifact_in_disk, + artifact_in_db=artifact_in_db, + verbose=verbose, + ) def _diff_actions(pack_dir=None, verbose=False, content_diff=True): - _diff(Action, 'actions', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff( + Action, "actions", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff + ) def _diff_sensors(pack_dir=None, verbose=False, content_diff=True): - _diff(SensorType, 'sensors', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff( + SensorType, + "sensors", + pack_dir=pack_dir, + verbose=verbose, + content_diff=content_diff, + ) def _diff_rules(pack_dir=None, verbose=True, content_diff=True): - _diff(Rule, 'rules', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff(Rule, "rules", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff) def main(): monkey_patch() cli_opts = [ - cfg.BoolOpt('sensors', default=False, - help='diff sensor alone.'), - cfg.BoolOpt('actions', default=False, - help='diff actions alone.'), - cfg.BoolOpt('rules', default=False, - help='diff rules alone.'), - cfg.BoolOpt('all', default=False, - help='diff sensors, actions and rules.'), - cfg.BoolOpt('verbose', default=False), - cfg.BoolOpt('simple', default=False, - help='In simple mode, tool only tells you if content is missing.' + - 'It doesn\'t show you content diff between disk and db.'), - cfg.StrOpt('pack-dir', default=None, help='Path to specific pack to diff.') + cfg.BoolOpt("sensors", default=False, help="diff sensor alone."), + cfg.BoolOpt("actions", default=False, help="diff actions alone."), + cfg.BoolOpt("rules", default=False, help="diff rules alone."), + cfg.BoolOpt("all", default=False, help="diff sensors, actions and rules."), + cfg.BoolOpt("verbose", default=False), + cfg.BoolOpt( + "simple", + default=False, + help="In simple mode, tool only tells you if content is missing." + + "It doesn't show you content diff between disk and db.", + ), + cfg.StrOpt("pack-dir", default=None, help="Path to specific pack to diff."), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -254,23 +282,35 @@ def main(): content_diff = not cfg.CONF.simple if cfg.CONF.all: - _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) - _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) - _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_sensors( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) + _diff_actions( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) + _diff_rules( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) return if cfg.CONF.sensors: - _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_sensors( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) if cfg.CONF.actions: - _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_actions( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) if cfg.CONF.rules: - _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_rules( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) # Disconnect from db. db_teardown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/direct_queue_publisher.py b/tools/direct_queue_publisher.py index bc01242085..0da7dd0b08 100755 --- a/tools/direct_queue_publisher.py +++ b/tools/direct_queue_publisher.py @@ -22,26 +22,27 @@ def main(queue, payload): - connection = pika.BlockingConnection(pika.ConnectionParameters( - host='localhost', - credentials=pika.credentials.PlainCredentials(username='guest', password='guest'))) + connection = pika.BlockingConnection( + pika.ConnectionParameters( + host="localhost", + credentials=pika.credentials.PlainCredentials( + username="guest", password="guest" + ), + ) + ) channel = connection.channel() channel.queue_declare(queue=queue, durable=True) - channel.basic_publish(exchange='', - routing_key=queue, - body=payload) + channel.basic_publish(exchange="", routing_key=queue, body=payload) print("Sent %s" % payload) connection.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Direct queue publisher') - parser.add_argument('--queue', required=True, - help='Routing key to use') - parser.add_argument('--payload', required=True, - help='Message payload') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Direct queue publisher") + parser.add_argument("--queue", required=True, help="Routing key to use") + parser.add_argument("--payload", required=True, help="Message payload") args = parser.parse_args() main(queue=args.queue, payload=args.payload) diff --git a/tools/enumerate-runners.py b/tools/enumerate-runners.py index 9610407411..9cae10cd18 100755 --- a/tools/enumerate-runners.py +++ b/tools/enumerate-runners.py @@ -20,15 +20,18 @@ from st2common.runners import get_backend_driver from st2common import config + config.parse_args() runner_names = get_available_backends() -print('Available / installed action runners:') +print("Available / installed action runners:") for name in runner_names: runner_driver = get_backend_driver(name) runner_instance = runner_driver.get_runner() runner_metadata = runner_driver.get_metadata() - print('- %s (runner_module=%s,cls=%s)' % (name, runner_metadata['runner_module'], - runner_instance.__class__)) + print( + "- %s (runner_module=%s,cls=%s)" + % (name, runner_metadata["runner_module"], runner_instance.__class__) + ) diff --git a/tools/json2yaml.py b/tools/json2yaml.py index 29959949e8..5aecb3711e 100755 --- a/tools/json2yaml.py +++ b/tools/json2yaml.py @@ -21,6 +21,7 @@ from __future__ import absolute_import import argparse import fnmatch + try: import simplejson as json except ImportError: @@ -33,7 +34,7 @@ PRINT = pprint.pprint -YAML_HEADER = '---' +YAML_HEADER = "---" def get_files_matching_pattern(dir_, pattern): @@ -47,47 +48,47 @@ def get_files_matching_pattern(dir_, pattern): def json_2_yaml_convert(filename): data = None try: - with open(filename, 'r') as json_file: + with open(filename, "r") as json_file: data = json.load(json_file) except: - PRINT('Failed on {}'.format(filename)) + PRINT("Failed on {}".format(filename)) traceback.print_exc() - return (filename, '') - new_filename = os.path.splitext(filename)[0] + '.yaml' - with open(new_filename, 'w') as yaml_file: - yaml_file.write(YAML_HEADER + '\n') + return (filename, "") + new_filename = os.path.splitext(filename)[0] + ".yaml" + with open(new_filename, "w") as yaml_file: + yaml_file.write(YAML_HEADER + "\n") yaml_file.write(yaml.safe_dump(data, default_flow_style=False)) return (filename, new_filename) def git_rm(filename): try: - subprocess.check_call(['git', 'rm', filename]) + subprocess.check_call(["git", "rm", filename]) except subprocess.CalledProcessError: - PRINT('Failed to git rm {}'.format(filename)) + PRINT("Failed to git rm {}".format(filename)) traceback.print_exc() return (False, filename) return (True, filename) def main(dir_, skip_convert): - files = get_files_matching_pattern(dir_, '*.json') + files = get_files_matching_pattern(dir_, "*.json") if skip_convert: PRINT(files) return results = [json_2_yaml_convert(filename) for filename in files] - PRINT('*** conversion done ***') - PRINT(['converted {} to {}'.format(result[0], result[1]) for result in results]) + PRINT("*** conversion done ***") + PRINT(["converted {} to {}".format(result[0], result[1]) for result in results]) results = [git_rm(filename) for filename, new_filename in results if new_filename] - PRINT('*** git rm done ***') + PRINT("*** git rm done ***") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='json2yaml converter.') - parser.add_argument('--dir', '-d', required=True, - help='The dir to look for json.') - parser.add_argument('--skipconvert', '-s', action='store_true', - help='Skip conversion') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="json2yaml converter.") + parser.add_argument("--dir", "-d", required=True, help="The dir to look for json.") + parser.add_argument( + "--skipconvert", "-s", action="store_true", help="Skip conversion" + ) args = parser.parse_args() main(dir_=args.dir, skip_convert=args.skipconvert) diff --git a/tools/list_group_members.py b/tools/list_group_members.py index e811eabd00..9cf575b62e 100755 --- a/tools/list_group_members.py +++ b/tools/list_group_members.py @@ -31,24 +31,26 @@ def main(group_id=None): if not group_id: group_ids = list(coordinator.get_groups().get()) - group_ids = [item.decode('utf-8') for item in group_ids] + group_ids = [item.decode("utf-8") for item in group_ids] - print('Available groups (%s):' % (len(group_ids))) + print("Available groups (%s):" % (len(group_ids))) for group_id in group_ids: - print(' - %s' % (group_id)) - print('') + print(" - %s" % (group_id)) + print("") else: group_ids = [group_id] for group_id in group_ids: member_ids = list(coordinator.get_members(group_id).get()) - member_ids = [member_id.decode('utf-8') for member_id in member_ids] + member_ids = [member_id.decode("utf-8") for member_id in member_ids] print('Members in group "%s" (%s):' % (group_id, len(member_ids))) for member_id in member_ids: - capabilities = coordinator.get_member_capabilities(group_id, member_id).get() - print(' - %s (capabilities=%s)' % (member_id, str(capabilities))) + capabilities = coordinator.get_member_capabilities( + group_id, member_id + ).get() + print(" - %s (capabilities=%s)" % (member_id, str(capabilities))) def do_register_cli_opts(opts, ignore_errors=False): @@ -60,11 +62,13 @@ def do_register_cli_opts(opts, ignore_errors=False): raise -if __name__ == '__main__': +if __name__ == "__main__": cli_opts = [ - cfg.StrOpt('group-id', default=None, - help='If provided, only list members for that group.'), - + cfg.StrOpt( + "group-id", + default=None, + help="If provided, only list members for that group.", + ), ] do_register_cli_opts(cli_opts) config.parse_args() diff --git a/tools/log_watcher.py b/tools/log_watcher.py index cafcb4efec..b16af95cc1 100755 --- a/tools/log_watcher.py +++ b/tools/log_watcher.py @@ -27,25 +27,9 @@ LOG_ALERT_PERCENT = 5 # default. -EVILS = [ - 'info', - 'debug', - 'warning', - 'exception', - 'error', - 'audit' -] - -LOG_VARS = [ - 'LOG', - 'Log', - 'log', - 'LOGGER', - 'Logger', - 'logger', - 'logging', - 'LOGGING' -] +EVILS = ["info", "debug", "warning", "exception", "error", "audit"] + +LOG_VARS = ["LOG", "Log", "log", "LOGGER", "Logger", "logger", "logging", "LOGGING"] FILE_LOG_COUNT = collections.defaultdict() FILE_LINE_COUNT = collections.defaultdict() @@ -55,25 +39,25 @@ def _parse_args(args): global LOG_ALERT_PERCENT params = {} if len(args) > 1: - params['alert_percent'] = args[1] + params["alert_percent"] = args[1] LOG_ALERT_PERCENT = int(args[1]) return params def _skip_file(filename): - if filename.startswith('.') or filename.startswith('_'): + if filename.startswith(".") or filename.startswith("_"): return True def _get_files(dir_path): if not os.path.exists(dir_path): - print('Directory %s doesn\'t exist.' % dir_path) + print("Directory %s doesn't exist." % dir_path) files = [] - exclude = set(['virtualenv', 'build', '.tox']) + exclude = set(["virtualenv", "build", ".tox"]) for root, dirnames, filenames in os.walk(dir_path): dirnames[:] = [d for d in dirnames if d not in exclude] - for filename in fnmatch.filter(filenames, '*.py'): + for filename in fnmatch.filter(filenames, "*.py"): if not _skip_file(filename): files.append(os.path.join(root, filename)) return files @@ -84,7 +68,7 @@ def _build_regex(): regex_strings = {} regexes = {} for level in EVILS: - regex_string = '|'.join([r'\.'.join([log, level]) for log in LOG_VARS]) + regex_string = "|".join([r"\.".join([log, level]) for log in LOG_VARS]) regex_strings[level] = regex_string # print('Level: %s, regex_string: %s' % (level, regex_strings[level])) regexes[level] = re.compile(regex_strings[level]) @@ -98,7 +82,7 @@ def _regex_match(line, regexes): def _build_str_matchers(): match_strings = {} for level in EVILS: - match_strings[level] = ['.'.join([log, level]) for log in LOG_VARS] + match_strings[level] = [".".join([log, level]) for log in LOG_VARS] return match_strings @@ -107,8 +91,10 @@ def _get_log_count_dict(): def _alert(fil, lines, logs, logs_level): - print('WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, ' - 'logs: %s' % (fil, lines, logs, float(logs) / lines * 100, logs_level)) + print( + "WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, " + "logs: %s" % (fil, lines, logs, float(logs) / lines * 100, logs_level) + ) def _match(line, match_strings): @@ -117,7 +103,7 @@ def _match(line, match_strings): if line.startswith(match_string): # print('Line: %s, match: %s' % (line, match_string)) return True, level, line - return False, 'UNKNOWN', line + return False, "UNKNOWN", line def _detect_log_lines(fil, matchers): @@ -148,23 +134,45 @@ def _post_process(file_dir): if total_log_count > 0: if float(total_log_count) / lines * 100 > LOG_ALERT_PERCENT: if file_dir in fil: - fil = fil[len(file_dir) + 1:] - alerts.append([fil, lines, total_log_count, float(total_log_count) / lines * 100, - log_lines_count_level['audit'], - log_lines_count_level['exception'], - log_lines_count_level['error'], - log_lines_count_level['warning'], - log_lines_count_level['info'], - log_lines_count_level['debug']]) + fil = fil[len(file_dir) + 1 :] + alerts.append( + [ + fil, + lines, + total_log_count, + float(total_log_count) / lines * 100, + log_lines_count_level["audit"], + log_lines_count_level["exception"], + log_lines_count_level["error"], + log_lines_count_level["warning"], + log_lines_count_level["info"], + log_lines_count_level["debug"], + ] + ) # sort by percent alerts.sort(key=lambda alert: alert[3], reverse=True) - print(tabulate(alerts, headers=['File', 'Lines', 'Logs', 'Percent', 'adt', 'exc', 'err', 'wrn', - 'inf', 'dbg'])) + print( + tabulate( + alerts, + headers=[ + "File", + "Lines", + "Logs", + "Percent", + "adt", + "exc", + "err", + "wrn", + "inf", + "dbg", + ], + ) + ) def main(args): params = _parse_args(args) - file_dir = params.get('dir', os.getcwd()) + file_dir = params.get("dir", os.getcwd()) files = _get_files(file_dir) matchers = _build_str_matchers() for f in files: @@ -172,5 +180,5 @@ def main(args): _post_process(file_dir) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv) diff --git a/tools/migrate_messaging_setup.py b/tools/migrate_messaging_setup.py index 095af26e0d..3fea8cab83 100755 --- a/tools/migrate_messaging_setup.py +++ b/tools/migrate_messaging_setup.py @@ -36,11 +36,13 @@ class Migrate_0_13_x_to_1_1_0(object): # changes or changes in durability proeprties. OLD_QS = [ # Name changed in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.timers', routing_key='#'), + reactor.get_trigger_cud_queue("st2.trigger.watch.timers", routing_key="#"), # Split to multiple queues in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.sensorwrapper', routing_key='#'), + reactor.get_trigger_cud_queue( + "st2.trigger.watch.sensorwrapper", routing_key="#" + ), # Name changed in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.webhooks', routing_key='#') + reactor.get_trigger_cud_queue("st2.trigger.watch.webhooks", routing_key="#"), ] def migrate(self): @@ -53,7 +55,7 @@ def _cleanup_old_queues(self): try: bound_q.delete() except: - print('Failed to delete %s.' % q.name) + print("Failed to delete %s." % q.name) traceback.print_exc() @@ -62,10 +64,10 @@ def main(): migrator = Migrate_0_13_x_to_1_1_0() migrator.migrate() except: - print('Messaging setup migration failed.') + print("Messaging setup migration failed.") traceback.print_exc() -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) main() diff --git a/tools/migrate_rules_to_include_pack.py b/tools/migrate_rules_to_include_pack.py index 8afd3faa15..1acdd26383 100755 --- a/tools/migrate_rules_to_include_pack.py +++ b/tools/migrate_rules_to_include_pack.py @@ -31,8 +31,11 @@ class Migration(object): - class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin): + class RuleDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + ): """Specifies the action to invoke on the occurrence of a Trigger. It also includes the transformation to perform to match the impedance between the payload of a TriggerInstance and input of a action. @@ -43,22 +46,23 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", + ) - meta = { - 'indexes': stormbase.TagsMixin.get_indexes() - } + meta = {"indexes": stormbase.TagsMixin.get_indexes()} # specialized access objects @@ -76,15 +80,17 @@ class RuleDB(stormbase.StormBaseDB, stormbase.TagsMixin): status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", + ) - meta = { - 'indexes': stormbase.TagsMixin.get_indexes() - } + meta = {"indexes": stormbase.TagsMixin.get_indexes()} rule_access_without_pack = MongoDBAccess(RuleDB) @@ -100,7 +106,7 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For Rule name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) @@ -126,13 +132,14 @@ def migrate_rules(): action=rule.action, enabled=rule.enabled, pack=DEFAULT_PACK_NAME, - ref=ResourceReference.to_string_reference(pack=DEFAULT_PACK_NAME, - name=rule.name) + ref=ResourceReference.to_string_reference( + pack=DEFAULT_PACK_NAME, name=rule.name + ), ) - print('Migrating rule: %s to rule: %s' % (rule.name, rule_with_pack.ref)) + print("Migrating rule: %s to rule: %s" % (rule.name, rule_with_pack.ref)) RuleWithPack.add_or_update(rule_with_pack) except Exception as e: - print('Migration failed. %s' % six.text_type(e)) + print("Migration failed. %s" % six.text_type(e)) def main(): @@ -148,5 +155,5 @@ def main(): db_teardown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/migrate_triggers_to_include_ref_count.py b/tools/migrate_triggers_to_include_ref_count.py index 3e8f1b79f0..af98a00a07 100755 --- a/tools/migrate_triggers_to_include_ref_count.py +++ b/tools/migrate_triggers_to_include_ref_count.py @@ -27,7 +27,6 @@ class TriggerMigrator(object): - def _get_trigger_with_parameters(self): """ All TriggerDB that has a parameter. @@ -38,7 +37,7 @@ def _get_rules_for_trigger(self, trigger_ref): """ All rules that reference the supplied trigger_ref. """ - return Rule.get_all(**{'trigger': trigger_ref}) + return Rule.get_all(**{"trigger": trigger_ref}) def _update_trigger_ref_count(self, trigger_db, ref_count): """ @@ -56,7 +55,7 @@ def migrate(self): trigger_ref = trigger_db.get_reference().ref rules = self._get_rules_for_trigger(trigger_ref=trigger_ref) ref_count = len(rules) - print('Updating Trigger %s to ref_count %s' % (trigger_ref, ref_count)) + print("Updating Trigger %s to ref_count %s" % (trigger_ref, ref_count)) self._update_trigger_ref_count(trigger_db=trigger_db, ref_count=ref_count) @@ -76,5 +75,5 @@ def main(): teartown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/queue_consumer.py b/tools/queue_consumer.py index bf19cbf1d5..69de164fbd 100755 --- a/tools/queue_consumer.py +++ b/tools/queue_consumer.py @@ -37,28 +37,31 @@ def __init__(self, connection, queue): self.queue = queue def get_consumers(self, Consumer, channel): - return [Consumer(queues=[self.queue], - accept=['pickle'], - callbacks=[self.process_task])] + return [ + Consumer( + queues=[self.queue], accept=["pickle"], callbacks=[self.process_task] + ) + ] def process_task(self, body, message): - print('===================================================') - print('Received message') - print('message.properties:') + print("===================================================") + print("Received message") + print("message.properties:") pprint(message.properties) - print('message.delivery_info:') + print("message.delivery_info:") pprint(message.delivery_info) - print('body:') + print("body:") pprint(body) - print('===================================================') + print("===================================================") message.ack() -def main(queue, exchange, routing_key='#'): - exchange = Exchange(exchange, type='topic') - queue = Queue(name=queue, exchange=exchange, routing_key=routing_key, - auto_delete=True) +def main(queue, exchange, routing_key="#"): + exchange = Exchange(exchange, type="topic") + queue = Queue( + name=queue, exchange=exchange, routing_key=routing_key, auto_delete=True + ) with transport_utils.get_connection() as connection: connection.connect() @@ -66,13 +69,11 @@ def main(queue, exchange, routing_key='#'): watcher.run() -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) - parser = argparse.ArgumentParser(description='Queue consumer') - parser.add_argument('--exchange', required=True, - help='Exchange to listen on') - parser.add_argument('--routing-key', default='#', - help='Routing key') + parser = argparse.ArgumentParser(description="Queue consumer") + parser.add_argument("--exchange", required=True, help="Exchange to listen on") + parser.add_argument("--routing-key", default="#", help="Routing key") args = parser.parse_args() queue_name = args.exchange + str(random.randint(1, 10000)) diff --git a/tools/queue_producer.py b/tools/queue_producer.py index c936088676..01476a26b8 100755 --- a/tools/queue_producer.py +++ b/tools/queue_producer.py @@ -30,22 +30,20 @@ def main(exchange, routing_key, payload): - exchange = Exchange(exchange, type='topic') + exchange = Exchange(exchange, type="topic") publisher = PoolPublisher() publisher.publish(payload=payload, exchange=exchange, routing_key=routing_key) eventlet.sleep(0.5) -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) - parser = argparse.ArgumentParser(description='Queue producer') - parser.add_argument('--exchange', required=True, - help='Exchange to publish the message to') - parser.add_argument('--routing-key', required=True, - help='Routing key to use') - parser.add_argument('--payload', required=True, - help='Message payload') + parser = argparse.ArgumentParser(description="Queue producer") + parser.add_argument( + "--exchange", required=True, help="Exchange to publish the message to" + ) + parser.add_argument("--routing-key", required=True, help="Routing key to use") + parser.add_argument("--payload", required=True, help="Message payload") args = parser.parse_args() - main(exchange=args.exchange, routing_key=args.routing_key, - payload=args.payload) + main(exchange=args.exchange, routing_key=args.routing_key, payload=args.payload) diff --git a/tools/st2-analyze-links.py b/tools/st2-analyze-links.py index 4daeeafa44..f66c158dea 100644 --- a/tools/st2-analyze-links.py +++ b/tools/st2-analyze-links.py @@ -44,8 +44,10 @@ try: from graphviz import Digraph except ImportError: - msg = ('Missing "graphviz" dependency. You can install it using pip: \n' - 'pip install graphviz') + msg = ( + 'Missing "graphviz" dependency. You can install it using pip: \n' + "pip install graphviz" + ) raise ImportError(msg) @@ -59,18 +61,20 @@ def do_register_cli_opts(opts, ignore_errors=False): class RuleLink(object): - def __init__(self, source_action_ref, rule_ref, dest_action_ref): self._source_action_ref = source_action_ref self._rule_ref = rule_ref self._dest_action_ref = dest_action_ref def __str__(self): - return '(%s -> %s -> %s)' % (self._source_action_ref, self._rule_ref, self._dest_action_ref) + return "(%s -> %s -> %s)" % ( + self._source_action_ref, + self._rule_ref, + self._dest_action_ref, + ) class LinksAnalyzer(object): - def __init__(self): self._rule_link_by_action_ref = {} self._rules = {} @@ -81,25 +85,30 @@ def analyze(self, root_action_ref, link_tigger_ref): for rule in rules: source_action_ref = self._get_source_action_ref(rule) if not source_action_ref: - print('No source_action_ref for rule %s' % rule.ref) + print("No source_action_ref for rule %s" % rule.ref) continue rule_links = self._rules.get(source_action_ref, None) if rule_links is None: rule_links = [] self._rules[source_action_ref] = rule_links - rule_links.append(RuleLink(source_action_ref=source_action_ref, rule_ref=rule.ref, - dest_action_ref=rule.action.ref)) + rule_links.append( + RuleLink( + source_action_ref=source_action_ref, + rule_ref=rule.ref, + dest_action_ref=rule.action.ref, + ) + ) analyzed = self._do_analyze(action_ref=root_action_ref) for (depth, rule_link) in analyzed: - print('%s%s' % (' ' * depth, rule_link)) + print("%s%s" % (" " * depth, rule_link)) return analyzed def _get_source_action_ref(self, rule): criteria = rule.criteria - source_action_ref = criteria.get('trigger.action_name', None) + source_action_ref = criteria.get("trigger.action_name", None) if not source_action_ref: - source_action_ref = criteria.get('trigger.action_ref', None) - return source_action_ref['pattern'] if source_action_ref else None + source_action_ref = criteria.get("trigger.action_ref", None) + return source_action_ref["pattern"] if source_action_ref else None def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0): if processed is None: @@ -111,24 +120,32 @@ def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0): rule_links.append((depth, rule_link)) if rule_link._dest_action_ref in processed: continue - self._do_analyze(rule_link._dest_action_ref, rule_links=rule_links, - processed=processed, depth=depth + 1) + self._do_analyze( + rule_link._dest_action_ref, + rule_links=rule_links, + processed=processed, + depth=depth + 1, + ) return rule_links class Grapher(object): def generate_graph(self, rule_links, out_file): - graph_label = 'Rule based visualizer' + graph_label = "Rule based visualizer" graph_attr = { - 'rankdir': 'TD', - 'labelloc': 't', - 'fontsize': '15', - 'label': graph_label + "rankdir": "TD", + "labelloc": "t", + "fontsize": "15", + "label": graph_label, } node_attr = {} - dot = Digraph(comment='Rule based links visualization', - node_attr=node_attr, graph_attr=graph_attr, format='png') + dot = Digraph( + comment="Rule based links visualization", + node_attr=node_attr, + graph_attr=graph_attr, + format="png", + ) nodes = set() for _, rule_link in rule_links: @@ -139,10 +156,14 @@ def generate_graph(self, rule_links, out_file): if rule_link._dest_action_ref not in nodes: nodes.add(rule_link._dest_action_ref) dot.node(rule_link._dest_action_ref, rule_link._dest_action_ref) - dot.edge(rule_link._source_action_ref, rule_link._dest_action_ref, constraint='true', - label=rule_link._rule_ref) + dot.edge( + rule_link._source_action_ref, + rule_link._dest_action_ref, + constraint="true", + label=rule_link._rule_ref, + ) output_path = os.path.join(os.getcwd(), out_file) - dot.format = 'png' + dot.format = "png" dot.render(output_path) @@ -150,11 +171,13 @@ def main(): monkey_patch() cli_opts = [ - cfg.StrOpt('action_ref', default=None, - help='Root action to begin analysis.'), - cfg.StrOpt('link_trigger_ref', default='core.st2.generic.actiontrigger', - help='Root action to begin analysis.'), - cfg.StrOpt('out_file', default='pipeline') + cfg.StrOpt("action_ref", default=None, help="Root action to begin analysis."), + cfg.StrOpt( + "link_trigger_ref", + default="core.st2.generic.actiontrigger", + help="Root action to begin analysis.", + ), + cfg.StrOpt("out_file", default="pipeline"), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -163,5 +186,5 @@ def main(): Grapher().generate_graph(rule_links, cfg.CONF.out_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/st2-inject-trigger-instances.py b/tools/st2-inject-trigger-instances.py index a20ba8bcbb..79b0f18e25 100755 --- a/tools/st2-inject-trigger-instances.py +++ b/tools/st2-inject-trigger-instances.py @@ -49,7 +49,9 @@ def do_register_cli_opts(opts, ignore_errors=False): raise -def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_throughput=False): +def _inject_instances( + trigger, rate_per_trigger, duration, payload=None, max_throughput=False +): payload = payload or {} start = date_utils.get_datetime_utc_now() @@ -72,37 +74,54 @@ def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_thr actual_rate = int(count / elapsed) - print('%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)' % - (trigger, count, elapsed, actual_rate)) + print( + "%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)" + % (trigger, count, elapsed, actual_rate) + ) # NOTE: Due to the overhead of dispatcher.dispatch call, we allow for 10% of deviation from # requested rate before warning if rate_per_trigger and (actual_rate < (rate_per_trigger * 0.9)): - print('') - print('Warning, requested rate was %s triggers / second, but only achieved %s ' - 'triggers / second' % (rate_per_trigger, actual_rate)) - print('Too increase the throuput you will likely need to run multiple instances of ' - 'this script in parallel.') + print("") + print( + "Warning, requested rate was %s triggers / second, but only achieved %s " + "triggers / second" % (rate_per_trigger, actual_rate) + ) + print( + "Too increase the throuput you will likely need to run multiple instances of " + "this script in parallel." + ) def main(): monkey_patch() cli_opts = [ - cfg.IntOpt('rate', default=100, - help='Rate of trigger injection measured in instances in per sec.' + - ' Assumes a default exponential distribution in time so arrival is poisson.'), - cfg.ListOpt('triggers', required=False, - help='List of triggers for which instances should be fired.' + - ' Uniform distribution will be followed if there is more than one' + - 'trigger.'), - cfg.StrOpt('schema_file', default=None, - help='Path to schema file defining trigger and payload.'), - cfg.IntOpt('duration', default=60, - help='Duration of stress test in seconds.'), - cfg.BoolOpt('max-throughput', default=False, - help='If True, "rate" argument will be ignored and this script will try to ' - 'saturize the CPU and achieve max utilization.') + cfg.IntOpt( + "rate", + default=100, + help="Rate of trigger injection measured in instances in per sec." + + " Assumes a default exponential distribution in time so arrival is poisson.", + ), + cfg.ListOpt( + "triggers", + required=False, + help="List of triggers for which instances should be fired." + + " Uniform distribution will be followed if there is more than one" + + "trigger.", + ), + cfg.StrOpt( + "schema_file", + default=None, + help="Path to schema file defining trigger and payload.", + ), + cfg.IntOpt("duration", default=60, help="Duration of stress test in seconds."), + cfg.BoolOpt( + "max-throughput", + default=False, + help='If True, "rate" argument will be ignored and this script will try to ' + "saturize the CPU and achieve max utilization.", + ), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -112,15 +131,20 @@ def main(): trigger_payload_schema = {} if not triggers: - if (cfg.CONF.schema_file is None or cfg.CONF.schema_file == '' or - not os.path.exists(cfg.CONF.schema_file)): - print('Either "triggers" need to be provided or a schema file containing' + - ' triggers should be provided.') + if ( + cfg.CONF.schema_file is None + or cfg.CONF.schema_file == "" + or not os.path.exists(cfg.CONF.schema_file) + ): + print( + 'Either "triggers" need to be provided or a schema file containing' + + " triggers should be provided." + ) return with open(cfg.CONF.schema_file) as fd: trigger_payload_schema = yaml.safe_load(fd) triggers = list(trigger_payload_schema.keys()) - print('Triggers=%s' % triggers) + print("Triggers=%s" % triggers) rate = cfg.CONF.rate rate_per_trigger = int(rate / len(triggers)) @@ -135,11 +159,17 @@ def main(): for trigger in triggers: payload = trigger_payload_schema.get(trigger, {}) - dispatcher_pool.spawn(_inject_instances, trigger, rate_per_trigger, duration, - payload=payload, max_throughput=max_throughput) + dispatcher_pool.spawn( + _inject_instances, + trigger, + rate_per_trigger, + duration, + payload=payload, + max_throughput=max_throughput, + ) eventlet.sleep(random.uniform(0, 1)) dispatcher_pool.waitall() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/visualize_action_chain.py b/tools/visualize_action_chain.py index 9981bd956c..c6742c460d 100755 --- a/tools/visualize_action_chain.py +++ b/tools/visualize_action_chain.py @@ -26,8 +26,10 @@ try: from graphviz import Digraph except ImportError: - msg = ('Missing "graphviz" dependency. You can install it using pip: \n' - 'pip install graphviz') + msg = ( + 'Missing "graphviz" dependency. You can install it using pip: \n' + "pip install graphviz" + ) raise ImportError(msg) from st2common.content.loader import MetaLoader @@ -41,25 +43,29 @@ def main(metadata_path, output_path, print_source=False): meta_loader = MetaLoader() data = meta_loader.load(metadata_path) - action_name = data['name'] - entry_point = data['entry_point'] + action_name = data["name"] + entry_point = data["entry_point"] workflow_metadata_path = os.path.join(metadata_dir, entry_point) chainspec = meta_loader.load(workflow_metadata_path) - chain_holder = ChainHolder(chainspec, 'workflow') + chain_holder = ChainHolder(chainspec, "workflow") - graph_label = '%s action-chain workflow visualization' % (action_name) + graph_label = "%s action-chain workflow visualization" % (action_name) graph_attr = { - 'rankdir': 'TD', - 'labelloc': 't', - 'fontsize': '15', - 'label': graph_label + "rankdir": "TD", + "labelloc": "t", + "fontsize": "15", + "label": graph_label, } node_attr = {} - dot = Digraph(comment='Action chain work-flow visualization', - node_attr=node_attr, graph_attr=graph_attr, format='png') + dot = Digraph( + comment="Action chain work-flow visualization", + node_attr=node_attr, + graph_attr=graph_attr, + format="png", + ) # dot.body.extend(['rankdir=TD', 'size="10,5"']) # Add all nodes @@ -74,23 +80,35 @@ def main(metadata_path, output_path, print_source=False): nodes = [node] while nodes: previous_node = nodes.pop() - success_node = chain_holder.get_next_node(curr_node_name=previous_node.name, - condition='on-success') - failure_node = chain_holder.get_next_node(curr_node_name=previous_node.name, - condition='on-failure') + success_node = chain_holder.get_next_node( + curr_node_name=previous_node.name, condition="on-success" + ) + failure_node = chain_holder.get_next_node( + curr_node_name=previous_node.name, condition="on-failure" + ) # Add success node (if any) if success_node: - dot.edge(previous_node.name, success_node.name, constraint='true', - color='green', label='on success') + dot.edge( + previous_node.name, + success_node.name, + constraint="true", + color="green", + label="on success", + ) if success_node.name not in processed_nodes: nodes.append(success_node) processed_nodes.add(success_node.name) # Add failure node (if any) if failure_node: - dot.edge(previous_node.name, failure_node.name, constraint='true', - color='red', label='on failure') + dot.edge( + previous_node.name, + failure_node.name, + constraint="true", + color="red", + label="on failure", + ) if failure_node.name not in processed_nodes: nodes.append(failure_node) processed_nodes.add(failure_node.name) @@ -103,21 +121,36 @@ def main(metadata_path, output_path, print_source=False): else: output_path = output_path or os.path.join(os.getcwd(), action_name) - dot.format = 'png' + dot.format = "png" dot.render(output_path) - print('Graph saved at %s' % (output_path + '.png')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Action chain visualization') - parser.add_argument('--metadata-path', action='store', required=True, - help='Path to the workflow action metadata file') - parser.add_argument('--output-path', action='store', required=False, - help='Output directory for the generated image') - parser.add_argument('--print-source', action='store_true', default=False, - help='Print graphviz source code to the stdout') + print("Graph saved at %s" % (output_path + ".png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Action chain visualization") + parser.add_argument( + "--metadata-path", + action="store", + required=True, + help="Path to the workflow action metadata file", + ) + parser.add_argument( + "--output-path", + action="store", + required=False, + help="Output directory for the generated image", + ) + parser.add_argument( + "--print-source", + action="store_true", + default=False, + help="Print graphviz source code to the stdout", + ) args = parser.parse_args() - main(metadata_path=args.metadata_path, output_path=args.output_path, - print_source=args.print_source) + main( + metadata_path=args.metadata_path, + output_path=args.output_path, + print_source=args.print_source, + ) From 3277415eabfc7c481c9ff350a38a9d8be4857b3f Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 22:44:24 +0100 Subject: [PATCH 03/22] Update black config, update .flake8 config so we ignore rules which conflict with black. --- lint-configs/python/.flake8 | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lint-configs/python/.flake8 b/lint-configs/python/.flake8 index 4edeebe162..271a9a21e6 100644 --- a/lint-configs/python/.flake8 +++ b/lint-configs/python/.flake8 @@ -5,7 +5,11 @@ enable-extensions = L101,L102 # We ignore some rules which conflict with black # E203 - whitespace before ':' - in direct conflict with black rule # W503 line break before binary operator - in direct conflict with black rule -ignore = E128,E402,E722,W504,E203,W503 +# E501 is line length limit +# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length +# We don't really need line length rule since black formatting takes care of +# that. +ignore = E128,E402,E722,W504,E501,E203,W503 exclude=*.egg/*,build,dist # Configuration for flake8-copyright extension From 4c526389b9beea8f34da64e0827dfc475ffc7156 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 23:42:55 +0100 Subject: [PATCH 04/22] Fix typo. --- Makefile | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index abf89f450b..67a4cc6b57 100644 --- a/Makefile +++ b/Makefile @@ -331,7 +331,7 @@ schemasgen: requirements .schemasgen black: requirements .black-check .PHONY: .black-check -.black: +.black-check: @echo @echo "================== black-check ====================" @echo @@ -349,8 +349,7 @@ black: requirements .black-check echo "==========================================================="; \ . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ done - # Python pack management actions - . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/* || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/ || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml scripts/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml pylint_plugins/*.py || exit 1; From 705d585a46c8568933b0b10ccdaf1ea2b409fdac Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:03:17 +0100 Subject: [PATCH 05/22] Add pre-commit config which runs various lint tools on the modified / added files. --- .pre-commit-config.yaml | 42 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..76a3c8363f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,42 @@ +# pre-commit hook which runs all the various lint checks + black auto formatting on the added +# files. +# This hook relies on development virtual environment being present in virtualenv/. +default_language_version: + python: python3.6 + +exclude: '(build|dist)' +repos: + - repo: local + hooks: + - id: black + name: black + entry: ./virtualenv/bin/python -m black --config pyproject.toml + language: script + types: [file, python] + - repo: local + hooks: + - id: flake8 + name: flake8 + entry: ./virtualenv/bin/python -m flake8 --config lint-configs/python/.flake8 + language: script + types: [file, python] + - repo: local + hooks: + - id: pylint + name: pylint + entry: ./virtualenv/bin/python -m pylint -E --rcfile=./lint-configs/python/.pylintrc + language: script + types: [file, python] + - repo: local + hooks: + - id: bandit + name: bandit + entry: ./virtualenv/bin/python -m bandit -lll -x build,dist + language: script + types: [file, python] + # - repo: https://github.com/pre-commit/pre-commit-hooks + # rev: v2.5.0 + # hooks: + # - id: trailing-whitespace + # - id: check-yaml + # exclude: (^openapi|fixtures) From 0720c0a50843d930ef498e3512f625dc2888e758 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:19:32 +0100 Subject: [PATCH 06/22] Update pre-commit config to also run trialing whitespace and check yaml syntax check. --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76a3c8363f..f4496a2278 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,9 +34,9 @@ repos: entry: ./virtualenv/bin/python -m bandit -lll -x build,dist language: script types: [file, python] - # - repo: https://github.com/pre-commit/pre-commit-hooks - # rev: v2.5.0 - # hooks: - # - id: trailing-whitespace - # - id: check-yaml - # exclude: (^openapi|fixtures) + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.5.0 + hooks: + - id: trailing-whitespace + - id: check-yaml + exclude: (openapi\.yaml) From e7945275e117898f346ea0cf0f078fe3ee28ca5f Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:22:52 +0100 Subject: [PATCH 07/22] Also run trailing whitespace + yaml checks as part of Make targets and CI. --- Makefile | 15 +++++++++++++-- test-requirements.txt | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 67a4cc6b57..8923cb7a5c 100644 --- a/Makefile +++ b/Makefile @@ -382,6 +382,17 @@ black: requirements .black-format . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml pylint_plugins/*.py || exit 1; +.PHONY: pre-commit-checks +black: requirements .pre-commit-checks + +# Ensure all files contain no trailing whitespace + that all YAML files are valid. +.PHONY: .pre-commit-checks +.pre-commit-checks: + @echo + @echo "================== pre-commit-checks ====================" + @echo + pre-commit run trailing-whitespace --all --show-diff-on-failure + pre-commit run check-yaml --all --show-diff-on-failure .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec @@ -474,7 +485,7 @@ bandit: requirements .bandit lint: requirements .lint .PHONY: .lint -.lint: .generate-api-spec .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check +.lint: .generate-api-spec .black-check .pre-commit-checks .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check .PHONY: clean clean: .cleanpycs @@ -1035,7 +1046,7 @@ debs: ci: ci-checks ci-unit ci-integration ci-packs-tests .PHONY: ci-checks -ci-checks: .generated-files-check .black-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages +ci-checks: .generated-files-check .black-check .pre-commit-checks .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages .PHONY: .rst-check .rst-check: diff --git a/test-requirements.txt b/test-requirements.txt index b1909e4535..c004342bc8 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,6 +6,7 @@ astroid==2.4.2 pylint==2.6.0 pylint-plugin-utils>=0.4 black==20.8b1 +pre-commit==2.1.0 bandit==1.5.1 ipython<6.0.0 isort>=4.2.5 From 969793f1fdbdd2c228e59ab112189166530d2680 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 00:23:20 +0100 Subject: [PATCH 08/22] Remove trailing whitespace from all the files. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- .../fixtures/execution_double_backslash.txt | 4 +-- .../tests/fixtures/execution_get_default.txt | 4 +-- .../fixtures/execution_get_has_schema.txt | 4 +-- .../fixtures/execution_unescape_newline.txt | 2 +- .../tests/fixtures/execution_unicode.txt | 4 +-- .../tests/fixtures/execution_unicode_py3.txt | 4 +-- st2common/bin/st2-run-pack-tests | 2 +- st2common/st2common/openapi.yaml | 6 ++-- st2reactor/Makefile | 2 +- .../fixtures/generic/runners/inquirer.yaml | 2 +- .../test_pause_resume_with_init_vars.yaml | 2 +- .../fixtures/packs/dummy_pack_20/pack.yaml | 2 +- .../workflows/jinja-version-functions.yaml | 2 +- ...low-default-value-from-action-context.yaml | 2 +- ...ow-source-channel-from-action-context.yaml | 2 +- .../workflows/yaql-version-functions.yaml | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 57 files changed, 85 insertions(+), 85 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd..bde4a90784 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2..a8b52fb674 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc..501e7b8f14 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd..b22e908d5c 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75..488939eb55 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 2357b08263..29078016d0 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88..c0b1692b03 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae420..86ead5303a 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc3..878772636a 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a702..35c5ab26d8 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6..82a131712c 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0..80047d2e5e 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5ed..4e3dfa38c2 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11..e949dc3742 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66..a247423948 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb..a1f203fb09 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36..404681a369 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b907898..6bcbb82c58 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af49..5833e27051 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7..907a18e8bf 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e..a8be531180 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e516441..003ab8b69d 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f..0c23ee6a82 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7be..149fb93b97 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a..4d4d9e5f39 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100a..4ddd986755 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc3..285bf972d7 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721..3a4b20cee0 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b59..6e24c0ec41 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf86..e2b9f09d44 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17..084fcad6a6 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a..191accd1c3 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df15..47091705f3 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df850..60d79a5b74 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d2..f61cedcba4 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a66..e17db7e4f6 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2client/tests/fixtures/execution_double_backslash.txt b/st2client/tests/fixtures/execution_double_backslash.txt index 5437c7add4..efb21dc7e3 100644 --- a/st2client/tests/fixtures/execution_double_backslash.txt +++ b/st2client/tests/fixtures/execution_double_backslash.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde333 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: echo 'C:\Users\ADMINI~1\AppData\Local\Temp\jking_vmware20_test' status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_get_default.txt b/st2client/tests/fixtures/execution_get_default.txt index c29c2c221d..4dea32224a 100644 --- a/st2client/tests/fixtures/execution_get_default.txt +++ b/st2client/tests/fixtures/execution_get_default.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_get_has_schema.txt b/st2client/tests/fixtures/execution_get_has_schema.txt index c29c2c221d..4dea32224a 100644 --- a/st2client/tests/fixtures/execution_get_has_schema.txt +++ b/st2client/tests/fixtures/execution_get_has_schema.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unescape_newline.txt b/st2client/tests/fixtures/execution_unescape_newline.txt index 4abac251a5..a0b0624e55 100644 --- a/st2client/tests/fixtures/execution_unescape_newline.txt +++ b/st2client/tests/fixtures/execution_unescape_newline.txt @@ -5,7 +5,7 @@ parameters: None status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unicode.txt b/st2client/tests/fixtures/execution_unicode.txt index 7b7491d3b7..54a9ccc254 100644 --- a/st2client/tests/fixtures/execution_unicode.txt +++ b/st2client/tests/fixtures/execution_unicode.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '‡'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_unicode_py3.txt b/st2client/tests/fixtures/execution_unicode_py3.txt index 0db50aa746..0e69f4eff4 100644 --- a/st2client/tests/fixtures/execution_unicode_py3.txt +++ b/st2client/tests/fixtures/execution_unicode_py3.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '\u2021'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed2826760..9f7c2306ab 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2common/st2common/openapi.yaml b/st2common/st2common/openapi.yaml index f1a116c3d6..ecab78f5ad 100644 --- a/st2common/st2common/openapi.yaml +++ b/st2common/st2common/openapi.yaml @@ -8,7 +8,7 @@ info: version: "1.0.0" title: StackStorm API description: | - + ## Welcome Welcome to the StackStorm API Reference documentation! You can use the StackStorm API to integrate StackStorm with 3rd-party systems and custom applications. Example integrations include writing your own self-service user portal, or integrating with other orquestation systems. @@ -197,7 +197,7 @@ info: Join our [Slack Community](https://stackstorm.com/community-signup) to get help from the engineering team and fellow users. You can also create issues against the main [StackStorm GitHub repo](https://github.com/StackStorm/st2/issues), or the [st2apidocs repo](https://github.com/StackStorm/st2apidocs) for documentation-specific issues. We also recommend reviewing the main [StackStorm documentation](https://docs.stackstorm.com/). - + paths: /api/v1/: @@ -1477,7 +1477,7 @@ paths: /api/v1/keys: get: operationId: st2api.controllers.v1.keyvalue:key_value_pair_controller.get_all - x-permissions: + x-permissions: description: Returns a list of all key value pairs. parameters: - name: prefix diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3e..232abed4dd 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml index 421262c52b..f49903a7e9 100644 --- a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml +++ b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml @@ -24,7 +24,7 @@ runner_parameters: roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml index dda36fb302..72114dd49e 100644 --- a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml +++ b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml @@ -8,7 +8,7 @@ chain: cmd: "while [ -e '{{tempfile}}' ]; do sleep 0.1; done" timeout: 180 publish: - var1: "{{var1|upper}}" + var1: "{{var1|upper}}" on-success: task2 - name: task2 diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml index c99661eb7c..a843fb64d4 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml @@ -9,6 +9,6 @@ attribute1: value1 attribute2: value2 attribute3: value3 attribute: 4 -some: +some: - "feature" - value diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml index 0d77ef8aef..579bc515bb 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '{{ version_more_than("0.9.0", "0.10.0") }}' - less_than: '{{ version_less_than("0.10.0", "0.9.0") }}' - match: '{{ version_match("0.10.1", ">0.10.0") }}' - - bump_major: '{{ version_bump_major("0.10.0") }}' + - bump_major: '{{ version_bump_major("0.10.0") }}' - bump_minor: '{{ version_bump_minor("0.10.0") }}' - bump_patch: '{{ version_bump_patch("0.10.0") }}' - strip_patch: '{{ version_strip_patch("0.10.0") }}' diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml index 90e3ac78ad..3301ab423a 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml index eedc5b8c3e..7a6bd62fa5 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml index cce350c46c..7bda9cc83b 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '<% version_more_than("0.9.0", "0.10.0") %>' - less_than: '<% version_less_than("0.10.0", "0.9.0") %>' - match: '<% version_match("0.10.1", ">0.10.0") %>' - - bump_major: '<% version_bump_major("0.10.0") %>' + - bump_major: '<% version_bump_major("0.10.0") %>' - bump_minor: '<% version_bump_minor("0.10.0") %>' - bump_patch: '<% version_bump_patch("0.10.0") %>' - strip_patch: '<% version_strip_patch("0.10.0") %>' diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c..06abc65227 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f36..2e6eadf6a2 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1..de40b85878 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From e4cdc0584deb6034d65163b586cfd59ecd4c1f47 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:25:58 +0100 Subject: [PATCH 09/22] Add .git-blame-ignore-rev file. --- .git-blame-ignore-rev | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .git-blame-ignore-rev diff --git a/.git-blame-ignore-rev b/.git-blame-ignore-rev new file mode 100644 index 0000000000..2e9f4011b2 --- /dev/null +++ b/.git-blame-ignore-rev @@ -0,0 +1,5 @@ +# Code was auto formatted using black +8496bb2407b969f0937431992172b98b545f6756 + +# Files were auto formatted to remove trailing whitespace +969793f1fdbdd2c228e59ab112189166530d2680 From 6051539231e4a8d1c547695f072882ed8cf384d3 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 10:52:46 +0100 Subject: [PATCH 10/22] Revert "Remove trailing whitespace from all the files." This reverts commit 969793f1fdbdd2c228e59ab112189166530d2680. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- .../fixtures/execution_double_backslash.txt | 4 +-- .../tests/fixtures/execution_get_default.txt | 4 +-- .../fixtures/execution_get_has_schema.txt | 4 +-- .../fixtures/execution_unescape_newline.txt | 2 +- .../tests/fixtures/execution_unicode.txt | 4 +-- .../tests/fixtures/execution_unicode_py3.txt | 4 +-- st2common/bin/st2-run-pack-tests | 2 +- st2common/st2common/openapi.yaml | 6 ++-- st2reactor/Makefile | 2 +- .../fixtures/generic/runners/inquirer.yaml | 2 +- .../test_pause_resume_with_init_vars.yaml | 2 +- .../fixtures/packs/dummy_pack_20/pack.yaml | 2 +- .../workflows/jinja-version-functions.yaml | 2 +- ...low-default-value-from-action-context.yaml | 2 +- ...ow-source-channel-from-action-context.yaml | 2 +- .../workflows/yaql-version-functions.yaml | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 57 files changed, 85 insertions(+), 85 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bde4a90784..bcbacbe3bd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a8b52fb674..d0460f4ea2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index 501e7b8f14..dfb8fb87bc 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index b22e908d5c..4d84895bbd 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 488939eb55..758b743e75 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 29078016d0..2357b08263 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index c0b1692b03..b9c04efa88 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index 86ead5303a..f226eae420 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 878772636a..3ff06eabc3 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 35c5ab26d8..85e774a702 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index 82a131712c..a0793f8bf6 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 80047d2e5e..5d9c6f22a0 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index 4e3dfa38c2..da9179b5ed 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index e949dc3742..61b14a3c11 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index b86d8ef25b..936db68ff3 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index a247423948..eaf09fed66 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index b86d8ef25b..936db68ff3 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index a1f203fb09..0d80b0dbcb 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 404681a369..3a03409d36 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index 6bcbb82c58..e20b907898 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 5833e27051..6a2cc4af49 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index 907a18e8bf..ce935f62f7 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index a8be531180..c0322d025e 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index 003ab8b69d..dd1e516441 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index 0c23ee6a82..a0deab1d8f 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 149fb93b97..0887d4a7be 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 4d4d9e5f39..8fd2a94d8a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 4ddd986755..403728100a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 285bf972d7..7123727cc3 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 3a4b20cee0..11eb22a721 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 6e24c0ec41..8af6899b59 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index e2b9f09d44..33d872cf86 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 084fcad6a6..7924e91e17 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 191accd1c3..1b8d0d572a 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 47091705f3..18d1b3df15 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 60d79a5b74..26711df850 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index f61cedcba4..4e1c1f22d2 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index e17db7e4f6..9d6cf70a66 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2client/tests/fixtures/execution_double_backslash.txt b/st2client/tests/fixtures/execution_double_backslash.txt index efb21dc7e3..5437c7add4 100644 --- a/st2client/tests/fixtures/execution_double_backslash.txt +++ b/st2client/tests/fixtures/execution_double_backslash.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde333 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: echo 'C:\Users\ADMINI~1\AppData\Local\Temp\jking_vmware20_test' status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_get_default.txt b/st2client/tests/fixtures/execution_get_default.txt index 4dea32224a..c29c2c221d 100644 --- a/st2client/tests/fixtures/execution_get_default.txt +++ b/st2client/tests/fixtures/execution_get_default.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_get_has_schema.txt b/st2client/tests/fixtures/execution_get_has_schema.txt index 4dea32224a..c29c2c221d 100644 --- a/st2client/tests/fixtures/execution_get_has_schema.txt +++ b/st2client/tests/fixtures/execution_get_has_schema.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unescape_newline.txt b/st2client/tests/fixtures/execution_unescape_newline.txt index a0b0624e55..4abac251a5 100644 --- a/st2client/tests/fixtures/execution_unescape_newline.txt +++ b/st2client/tests/fixtures/execution_unescape_newline.txt @@ -5,7 +5,7 @@ parameters: None status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unicode.txt b/st2client/tests/fixtures/execution_unicode.txt index 54a9ccc254..7b7491d3b7 100644 --- a/st2client/tests/fixtures/execution_unicode.txt +++ b/st2client/tests/fixtures/execution_unicode.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '‡'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_unicode_py3.txt b/st2client/tests/fixtures/execution_unicode_py3.txt index 0e69f4eff4..0db50aa746 100644 --- a/st2client/tests/fixtures/execution_unicode_py3.txt +++ b/st2client/tests/fixtures/execution_unicode_py3.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '\u2021'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index 9f7c2306ab..bed2826760 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2common/st2common/openapi.yaml b/st2common/st2common/openapi.yaml index ecab78f5ad..f1a116c3d6 100644 --- a/st2common/st2common/openapi.yaml +++ b/st2common/st2common/openapi.yaml @@ -8,7 +8,7 @@ info: version: "1.0.0" title: StackStorm API description: | - + ## Welcome Welcome to the StackStorm API Reference documentation! You can use the StackStorm API to integrate StackStorm with 3rd-party systems and custom applications. Example integrations include writing your own self-service user portal, or integrating with other orquestation systems. @@ -197,7 +197,7 @@ info: Join our [Slack Community](https://stackstorm.com/community-signup) to get help from the engineering team and fellow users. You can also create issues against the main [StackStorm GitHub repo](https://github.com/StackStorm/st2/issues), or the [st2apidocs repo](https://github.com/StackStorm/st2apidocs) for documentation-specific issues. We also recommend reviewing the main [StackStorm documentation](https://docs.stackstorm.com/). - + paths: /api/v1/: @@ -1477,7 +1477,7 @@ paths: /api/v1/keys: get: operationId: st2api.controllers.v1.keyvalue:key_value_pair_controller.get_all - x-permissions: + x-permissions: description: Returns a list of all key value pairs. parameters: - name: prefix diff --git a/st2reactor/Makefile b/st2reactor/Makefile index 232abed4dd..cd3eb75a3e 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml index f49903a7e9..421262c52b 100644 --- a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml +++ b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml @@ -24,7 +24,7 @@ runner_parameters: roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml index 72114dd49e..dda36fb302 100644 --- a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml +++ b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml @@ -8,7 +8,7 @@ chain: cmd: "while [ -e '{{tempfile}}' ]; do sleep 0.1; done" timeout: 180 publish: - var1: "{{var1|upper}}" + var1: "{{var1|upper}}" on-success: task2 - name: task2 diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml index a843fb64d4..c99661eb7c 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml @@ -9,6 +9,6 @@ attribute1: value1 attribute2: value2 attribute3: value3 attribute: 4 -some: +some: - "feature" - value diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml index 579bc515bb..0d77ef8aef 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '{{ version_more_than("0.9.0", "0.10.0") }}' - less_than: '{{ version_less_than("0.10.0", "0.9.0") }}' - match: '{{ version_match("0.10.1", ">0.10.0") }}' - - bump_major: '{{ version_bump_major("0.10.0") }}' + - bump_major: '{{ version_bump_major("0.10.0") }}' - bump_minor: '{{ version_bump_minor("0.10.0") }}' - bump_patch: '{{ version_bump_patch("0.10.0") }}' - strip_patch: '{{ version_strip_patch("0.10.0") }}' diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml index 3301ab423a..90e3ac78ad 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml index 7a6bd62fa5..eedc5b8c3e 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml index 7bda9cc83b..cce350c46c 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '<% version_more_than("0.9.0", "0.10.0") %>' - less_than: '<% version_less_than("0.10.0", "0.9.0") %>' - match: '<% version_match("0.10.1", ">0.10.0") %>' - - bump_major: '<% version_bump_major("0.10.0") %>' + - bump_major: '<% version_bump_major("0.10.0") %>' - bump_minor: '<% version_bump_minor("0.10.0") %>' - bump_patch: '<% version_bump_patch("0.10.0") %>' - strip_patch: '<% version_strip_patch("0.10.0") %>' diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index 06abc65227..ac38037d6c 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 2e6eadf6a2..5320dc2f36 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index de40b85878..451ceee8e1 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 33d9efd711ddd3626ef611e867c81a0a5084739c Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 11:00:37 +0100 Subject: [PATCH 11/22] Exclude test fixture files from trailing whitespace hook since it breaks some tests. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4496a2278..9ccc28e322 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,5 +38,5 @@ repos: rev: v2.5.0 hooks: - id: trailing-whitespace + exclude: (^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) - id: check-yaml - exclude: (openapi\.yaml) From 223a7ade496cbe0bb3f26b529d6d9d1c0f69a96c Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 11:04:22 +0100 Subject: [PATCH 12/22] Remove trailing whitespace. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 43 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd..bde4a90784 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2..a8b52fb674 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc..501e7b8f14 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd..b22e908d5c 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75..488939eb55 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 2357b08263..29078016d0 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88..c0b1692b03 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae420..86ead5303a 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc3..878772636a 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a702..35c5ab26d8 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6..82a131712c 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0..80047d2e5e 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5ed..4e3dfa38c2 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11..e949dc3742 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66..a247423948 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb..a1f203fb09 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36..404681a369 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b907898..6bcbb82c58 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af49..5833e27051 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7..907a18e8bf 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e..a8be531180 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e516441..003ab8b69d 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f..0c23ee6a82 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7be..149fb93b97 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a..4d4d9e5f39 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100a..4ddd986755 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc3..285bf972d7 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721..3a4b20cee0 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b59..6e24c0ec41 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf86..e2b9f09d44 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17..084fcad6a6 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a..191accd1c3 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df15..47091705f3 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df850..60d79a5b74 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d2..f61cedcba4 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a66..e17db7e4f6 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed2826760..9f7c2306ab 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3e..232abed4dd 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c..06abc65227 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f36..2e6eadf6a2 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1..de40b85878 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From eba4abe51ccb624f96a2fde8b9804d69fecda995 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 12:46:53 +0100 Subject: [PATCH 13/22] Also don't re-format config files. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ccc28e322..c539e0b854 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,5 +38,5 @@ repos: rev: v2.5.0 hooks: - id: trailing-whitespace - exclude: (^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) + exclude: (^conf/|^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) - id: check-yaml From 514bd279cc68d7dbaee33df57d7d94dc2183ee5b Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 12:47:10 +0100 Subject: [PATCH 14/22] Revert "Remove trailing whitespace." This reverts commit 223a7ade496cbe0bb3f26b529d6d9d1c0f69a96c. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 43 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bde4a90784..bcbacbe3bd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a8b52fb674..d0460f4ea2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index 501e7b8f14..dfb8fb87bc 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index b22e908d5c..4d84895bbd 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 488939eb55..758b743e75 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 29078016d0..2357b08263 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index c0b1692b03..b9c04efa88 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index 86ead5303a..f226eae420 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 878772636a..3ff06eabc3 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 35c5ab26d8..85e774a702 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index 82a131712c..a0793f8bf6 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 80047d2e5e..5d9c6f22a0 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index 4e3dfa38c2..da9179b5ed 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index e949dc3742..61b14a3c11 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index b86d8ef25b..936db68ff3 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index a247423948..eaf09fed66 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index b86d8ef25b..936db68ff3 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index a1f203fb09..0d80b0dbcb 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 404681a369..3a03409d36 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index 6bcbb82c58..e20b907898 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 5833e27051..6a2cc4af49 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index 907a18e8bf..ce935f62f7 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index a8be531180..c0322d025e 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index 003ab8b69d..dd1e516441 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index 0c23ee6a82..a0deab1d8f 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 149fb93b97..0887d4a7be 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 4d4d9e5f39..8fd2a94d8a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 4ddd986755..403728100a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 285bf972d7..7123727cc3 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 3a4b20cee0..11eb22a721 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 6e24c0ec41..8af6899b59 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index e2b9f09d44..33d872cf86 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 084fcad6a6..7924e91e17 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 191accd1c3..1b8d0d572a 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 47091705f3..18d1b3df15 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 60d79a5b74..26711df850 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index f61cedcba4..4e1c1f22d2 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index e17db7e4f6..9d6cf70a66 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index 9f7c2306ab..bed2826760 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index 232abed4dd..cd3eb75a3e 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index 06abc65227..ac38037d6c 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 2e6eadf6a2..5320dc2f36 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index de40b85878..451ceee8e1 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 56101b8481a330a05e7fe668d762ca9ba1c386ac Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 13:00:06 +0100 Subject: [PATCH 15/22] Make sure sample config doesn't contain trailing whitespace. --- conf/st2.conf.sample | 4 ++-- tools/config_gen.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75..488939eb55 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/tools/config_gen.py b/tools/config_gen.py index e0004d04e1..309bdf608f 100755 --- a/tools/config_gen.py +++ b/tools/config_gen.py @@ -170,8 +170,8 @@ def _print_options(opt_group, options): else: value = opt.default - print("# %s" % opt.help) - print("%s = %s" % (opt.name, value)) + print(("# %s" % opt.help).strip()) + print(("%s = %s" % (opt.name, value)).strip()) def main(args): From 100fbdb45d24d5829906f1e5e1a9fc1b398a7bf2 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 13:00:18 +0100 Subject: [PATCH 16/22] Remove trailing whitespace. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 41 files changed, 61 insertions(+), 61 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd..bde4a90784 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2..a8b52fb674 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc..501e7b8f14 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd..b22e908d5c 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88..c0b1692b03 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae420..86ead5303a 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc3..878772636a 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a702..35c5ab26d8 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6..82a131712c 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0..80047d2e5e 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5ed..4e3dfa38c2 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11..e949dc3742 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66..a247423948 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3..b86d8ef25b 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb..a1f203fb09 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36..404681a369 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b907898..6bcbb82c58 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af49..5833e27051 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7..907a18e8bf 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e..a8be531180 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e516441..003ab8b69d 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f..0c23ee6a82 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7be..149fb93b97 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a..4d4d9e5f39 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100a..4ddd986755 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc3..285bf972d7 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721..3a4b20cee0 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b59..6e24c0ec41 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf86..e2b9f09d44 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17..084fcad6a6 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a..191accd1c3 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df15..47091705f3 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df850..60d79a5b74 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d2..f61cedcba4 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a66..e17db7e4f6 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed2826760..9f7c2306ab 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3e..232abed4dd 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c..06abc65227 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f36..2e6eadf6a2 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1..de40b85878 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 549bcd00750a2ac31181279c3fd30b3947b7f30b Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 16:26:07 +0100 Subject: [PATCH 17/22] Update black config so we don't try to reformat submodule. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4d03482994..1889c6a5da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ exclude = ''' | \.git | \.virtualenv | __pycache__ + | test_content_version )/ ) ''' From 00157676b47373142fec620d124718ff44671534 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 17:48:57 +0100 Subject: [PATCH 18/22] Update Makefile. --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 8923cb7a5c..91439ae15b 100644 --- a/Makefile +++ b/Makefile @@ -391,8 +391,8 @@ black: requirements .pre-commit-checks @echo @echo "================== pre-commit-checks ====================" @echo - pre-commit run trailing-whitespace --all --show-diff-on-failure - pre-commit run check-yaml --all --show-diff-on-failure + . $(VIRTUALENV_DIR)/bin/activate; pre-commit run trailing-whitespace --all --show-diff-on-failure + . $(VIRTUALENV_DIR)/bin/activate; pre-commit run check-yaml --all --show-diff-on-failure .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec From efa46112f5091eb33eb18e26e91c1aaadb116533 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 18:39:52 +0100 Subject: [PATCH 19/22] Fix lint. --- st2common/st2common/models/api/rbac.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/st2common/st2common/models/api/rbac.py b/st2common/st2common/models/api/rbac.py index bd269ce3d6..ffaff75409 100644 --- a/st2common/st2common/models/api/rbac.py +++ b/st2common/st2common/models/api/rbac.py @@ -228,10 +228,9 @@ def validate(self, validate_role_exists=False): if validate_role_exists: # Validate that the referenced roles exist in the db rbac_service = get_rbac_backend().get_service_class() - rbac_service.validate_roles_exists( - role_names=self.roles - ) # pylint: disable=no-member - + # pylint: disable=no-member + rbac_service.validate_roles_exists(role_names=self.roles) + # pylint: enable=no-member return cleaned From 083b103649814425a782d0e9ee090e9087d5d4ed Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:40:45 +0100 Subject: [PATCH 20/22] Fix typo. --- st2common/st2common/util/virtualenvs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/st2common/st2common/util/virtualenvs.py b/st2common/st2common/util/virtualenvs.py index b3635d15a5..62cfc99b52 100644 --- a/st2common/st2common/util/virtualenvs.py +++ b/st2common/st2common/util/virtualenvs.py @@ -236,7 +236,7 @@ def remove_virtualenv(virtualenv_path, logger=None): logger.debug('Removing virtualenv in "%s"' % virtualenv_path) try: shutil.rmtree(virtualenv_path) - logger.debug("Virtualenv successfull removed.") + logger.debug("Virtualenv successfully removed.") except Exception as e: logger.error( 'Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e) From 370aa874c825f616a92123b75808a10a05f1f1bc Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:42:27 +0100 Subject: [PATCH 21/22] Add changelog entry. --- CHANGELOG.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 891719ccbb..0fa0a562b8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,13 @@ Changelog in development -------------- +Changed +~~~~~~~ + +* All the code has been refactored using black and black style is automatically enforced and + required for all the new code. (#5156) + + Contributed by @Kami. 3.4.0 - March 02, 2021 ---------------------- @@ -22,7 +29,8 @@ Added * Added st2-auth-ldap pip requirements for LDAP auth integartion. (new feature) #5082 Contributed by @hnanchahal -* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch. (part of upgrade instructions) [#5167] +* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch. + (part of upgrade instructions) [#5167] Contributed by @winem and @blag Changed From 5d07a5c6b456737c5032e0dca459afa5b45f30af Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:54:56 +0100 Subject: [PATCH 22/22] Fix typo. --- st2common/tests/integration/test_register_content_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/st2common/tests/integration/test_register_content_script.py b/st2common/tests/integration/test_register_content_script.py index cbc7983112..0a0dc8e7f1 100644 --- a/st2common/tests/integration/test_register_content_script.py +++ b/st2common/tests/integration/test_register_content_script.py @@ -183,5 +183,5 @@ def test_register_recreate_virtualenvs(self): self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr) self.assertIn("Setup virtualenv for 1 pack(s)", stderr) - self.assertIn("Virtualenv successfull removed.", stderr) + self.assertIn("Virtualenv successfully removed.", stderr) self.assertEqual(exit_code, 0)