From f48e8d491003f3b138c88a32e9c1216dcf1c5672 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Mon, 28 Mar 2022 11:06:40 +0100 Subject: [PATCH] fix: Pass in token to fetch GH default branch You can't fetch the default branch for a private repo without a token. --- src/nitpick/style/fetchers/github.py | 31 ++++++++++++++++--------- tests/test_style.py | 34 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/nitpick/style/fetchers/github.py b/src/nitpick/style/fetchers/github.py index 0a0f7b1f..42afd396 100644 --- a/src/nitpick/style/fetchers/github.py +++ b/src/nitpick/style/fetchers/github.py @@ -34,21 +34,30 @@ class GitHubURL: def default_branch(self) -> str: """Default GitHub branch.""" # get_default_branch() is memoized - return get_default_branch(self.api_url.url) + return get_default_branch(self.api_url.url, token=self.token) @property - def credentials(self) -> tuple[str, str] | tuple[()]: - """Credentials encoded in this URL. + def token(self) -> str | None: + """Token encoded in this URL. - A tuple of ``(api_token, '')`` if present, or empty tuple otherwise. If - the value of ``api_token`` begins with ``$``, it will be replaced with - the value of the environment corresponding to the remaining part of the + If present and it starts with a ``$``, it will be replaced with the + value of the environment corresponding to the remaining part of the string. + """ token = self.auth_token if token is not None and token.startswith("$"): token = os.getenv(token[1:]) + return token + @property + def credentials(self) -> tuple[str, str] | tuple[()]: + """Credentials encoded in this URL. + + A tuple of ``(api_token, '')`` if present, or empty tuple otherwise. + + """ + token = self.token return (token, "") if token else () @property @@ -128,18 +137,18 @@ def _build_url(self, scheme: str) -> furl: @lru_cache() -def get_default_branch(api_url: str) -> str: +def get_default_branch(api_url: str, *, token: str | None = None) -> str: """Get the default branch from the GitHub repo using the API. - For now, the request is not authenticated on GitHub, so it might hit a rate limit with: + For now, for URLs without an authorization token embedded, the request is + not authenticated on GitHub, so it might hit a rate limit with: ``requests.exceptions.HTTPError: 403 Client Error: rate limit exceeded for url`` This function is using ``lru_cache()`` as a simple memoizer, trying to avoid this rate limit error. - Another option for the future: perform an authenticated request to GitHub. - That would require some user credentials. """ - response = API_SESSION.get(api_url) + headers = {"Authorization": f"token {token}"} if token else None + response = API_SESSION.get(api_url, headers=headers) response.raise_for_status() return response.json()["default_branch"] diff --git a/tests/test_style.py b/tests/test_style.py index ab4b6aa5..3f614ab6 100644 --- a/tests/test_style.py +++ b/tests/test_style.py @@ -431,6 +431,40 @@ def test_fetch_private_github_urls(tmp_path): project.flake8(offline=True).assert_no_errors() +@responses.activate +def test_fetch_private_github_urls_no_branch(tmp_path): + """Fetch private GitHub URLs with a token on the query string.""" + file_token = "query-string-token-generated-by-github-for-private-files" + gh_url = f"gh://{file_token}@user/private_repo/path/to/nitpick-style" + api_url = "https://api.github.com/repos/user/private_repo" + api_response = '{"default_branch": "branch"}' + full_raw_url = f"https://raw.githubusercontent.com/user/private_repo/branch/path/to/nitpick-style{TOML_EXTENSION}" + body = """ + ["pyproject.toml".tool.black] + missing = "thing" + """ + responses.add(responses.GET, api_url, api_response, status=200) + responses.add(responses.GET, full_raw_url, dedent(body), status=200) + + project = ProjectMock(tmp_path).pyproject_toml( + f""" + [tool.nitpick] + style = "{gh_url}" + """ + ) + project.flake8(offline=False).assert_single_error( + f""" + NIP318 File pyproject.toml has missing values:{SUGGESTION_BEGIN} + [tool.black] + missing = "thing"{SUGGESTION_END} + """ + ) + assert responses.calls[0].request.headers["Authorization"] == f"token {file_token}" + token_on_basic_auth = b64encode(f"{file_token}:".encode()).decode().strip() + assert responses.calls[1].request.headers["Authorization"] == f"Basic {token_on_basic_auth}" + project.flake8(offline=True).assert_no_errors() + + @pytest.mark.parametrize( "style_url", [