diff --git a/.azure-pipelines/test_plan.py b/.azure-pipelines/test_plan.py index c238ddd5504..20ea07f6079 100644 --- a/.azure-pipelines/test_plan.py +++ b/.azure-pipelines/test_plan.py @@ -5,6 +5,7 @@ import json import os import sys +import subprocess import copy import time from datetime import datetime, timedelta @@ -22,7 +23,7 @@ PR_TEST_SCRIPTS_FILE = "pr_test_scripts.yaml" SPECIFIC_PARAM_KEYWORD = "specific_param" TOLERATE_HTTP_EXCEPTION_TIMES = 20 -TOKEN_EXPIRE_HOURS = 6 +TOKEN_EXPIRE_HOURS = 1 MAX_GET_TOKEN_RETRY_TIMES = 3 @@ -130,16 +131,16 @@ def __init__(self): super(FinishStatus, self).__init__(TestPlanStatus.FINISHED) -def get_scope(elastictest_url): - scope = "api://sonic-testbed-tools-dev/.default" - if elastictest_url in [ - "http://sonic-testbed2-scheduler-backend.azurewebsites.net", - "https://sonic-testbed2-scheduler-backend.azurewebsites.net", - "http://sonic-elastictest-prod-scheduler-backend-webapp.azurewebsites.net", - "https://sonic-elastictest-prod-scheduler-backend-webapp.azurewebsites.net" - ]: - scope = "api://sonic-testbed-tools-prod/.default" - return scope +# def get_scope(elastictest_url): +# scope = "api://sonic-testbed-tools-dev/.default" +# if elastictest_url in [ +# "http://sonic-testbed2-scheduler-backend.azurewebsites.net", +# "https://sonic-testbed2-scheduler-backend.azurewebsites.net", +# "http://sonic-elastictest-prod-scheduler-backend-webapp.azurewebsites.net", +# "https://sonic-elastictest-prod-scheduler-backend-webapp.azurewebsites.net" +# ]: +# scope = "api://sonic-testbed-tools-prod/.default" +# return scope def parse_list_from_str(s): @@ -157,51 +158,86 @@ def parse_list_from_str(s): class TestPlanManager(object): - def __init__(self, url, frontend_url, tenant_id=None, client_id=None, client_secret=None, ): + def __init__(self, url, frontend_url, client_id=None): self.url = url self.frontend_url = frontend_url - self.tenant_id = tenant_id self.client_id = client_id - self.client_secret = client_secret self.with_auth = False self._token = None - self._token_generate_time = None - if self.tenant_id and self.client_id and self.client_secret: + self._token_expires_on = None + if self.client_id: self.with_auth = True self.get_token() + def cmd(self, cmds): + process = subprocess.Popen( + cmds, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate() + return_code = process.returncode + + return stdout, stderr, return_code + def get_token(self): - token_generate_time_valid = \ - self._token_generate_time is not None and \ - (datetime.utcnow() - self._token_generate_time) < timedelta(hours=TOKEN_EXPIRE_HOURS) + token_is_valid = \ + self._token_expires_on is not None and \ + (self._token_expires_on - datetime.now()) > timedelta(hours=TOKEN_EXPIRE_HOURS) - if self._token is not None and token_generate_time_valid: + if self._token is not None and token_is_valid: return self._token - token_url = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(self.tenant_id) - headers = { - "Content-Type": "application/x-www-form-urlencoded" - } - - payload = { - "grant_type": "client_credentials", - "client_id": self.client_id, - "client_secret": self.client_secret, - "scope": get_scope(self.url) - } + cmd = 'az account get-access-token --resource {}'.format(self.client_id) attempt = 0 while (attempt < MAX_GET_TOKEN_RETRY_TIMES): + stdout, stderr, return_code = self.cmd(cmd.split()) try: - resp = requests.post(token_url, headers=headers, data=payload, timeout=10).json() - self._token = resp["access_token"] - self._token_generate_time = datetime.utcnow() + if return_code != 0: + raise Exception("Failed to get token: rc: {}, error: {}".format(return_code, stderr)) + + token = json.loads(stdout.decode("utf-8")) + self._token = token.get("accessToken", None) + if not self._token: + raise Exception("Parse token from stdout failed") + + # Parse token expires time from string + token_expires_on = token.get("expiresOn", "") + self._token_expires_on = datetime.strptime(token_expires_on, "%Y-%m-%d %H:%M:%S.%f") + return self._token + except Exception as exception: attempt += 1 - print("Get token failed with exception: {}. Retry {} times to get token." - .format(repr(exception), MAX_GET_TOKEN_RETRY_TIMES - attempt)) + print("Failed to get token with exception: {}".format(repr(exception))) + raise Exception("Failed to get token after {} attempts".format(MAX_GET_TOKEN_RETRY_TIMES)) + # token_url = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(self.tenant_id) + # headers = { + # "Content-Type": "application/x-www-form-urlencoded" + # } + + # payload = { + # "grant_type": "client_credentials", + # "client_id": self.client_id, + # "client_secret": self.client_secret, + # "scope": get_scope(self.url) + # } + # attempt = 0 + # while (attempt < MAX_GET_TOKEN_RETRY_TIMES): + # try: + # resp = requests.post(token_url, headers=headers, data=payload, timeout=10).json() + # self._token = resp["access_token"] + # self._token_generate_time = datetime.utcnow() + # return self._token + # except Exception as exception: + # attempt += 1 + # print("Get token failed with exception: {}. Retry {} times to get token." + # .format(repr(exception), MAX_GET_TOKEN_RETRY_TIMES - attempt)) + # raise Exception("Failed to get token after {} attempts".format(MAX_GET_TOKEN_RETRY_TIMES)) + def create(self, topology, test_plan_name="my_test_plan", deploy_mg_extra_params="", kvm_build_id="", min_worker=None, max_worker=None, pr_id="unknown", output=None, common_extra_params="", **kwargs): @@ -852,7 +888,7 @@ def poll(self, test_plan_id, interval=60, timeout=-1, expected_state="", expecte args.test_plan_id = args.test_plan_id.replace("'", "") print("Test plan utils parameters: {}".format(args)) - auth_env = ["TENANT_ID", "CLIENT_ID", "CLIENT_SECRET"] + auth_env = ["CLIENT_ID"] required_env = ["ELASTICTEST_SCHEDULER_BACKEND_URL"] if args.action in ["create", "cancel"]: @@ -860,9 +896,7 @@ def poll(self, test_plan_id, interval=60, timeout=-1, expected_state="", expecte env = { "elastictest_scheduler_backend_url": os.environ.get("ELASTICTEST_SCHEDULER_BACKEND_URL"), - "tenant_id": os.environ.get("ELASTICTEST_MSAL_TENANT_ID"), "client_id": os.environ.get("ELASTICTEST_MSAL_CLIENT_ID"), - "client_secret": os.environ.get("ELASTICTEST_MSAL_CLIENT_SECRET"), "frontend_url": os.environ.get("ELASTICTEST_FRONTEND_URL", "https://elastictest.org"), } env_missing = [k.upper() for k, v in env.items() if k.upper() in required_env and not v] @@ -874,9 +908,7 @@ def poll(self, test_plan_id, interval=60, timeout=-1, expected_state="", expecte tp = TestPlanManager( env["elastictest_scheduler_backend_url"], env["frontend_url"], - env["tenant_id"], - env["client_id"], - env["client_secret"]) + env["client_id"]) if args.action == "create": pr_id = os.environ.get("SYSTEM_PULLREQUEST_PULLREQUESTNUMBER") or os.environ.get(