Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/nitpick/style/fetchers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,30 @@ class GitHubURL:
def default_branch(self) -> str:
"""Default GitHub branch."""
# get_default_branch() is memoized
return get_default_branch(self.api_url.url)
return get_default_branch(self.api_url.url, token=self.token)

@property
def credentials(self) -> tuple[str, str] | tuple[()]:
"""Credentials encoded in this URL.
def token(self) -> str | None:
"""Token encoded in this URL.

A tuple of ``(api_token, '')`` if present, or empty tuple otherwise. If
the value of ``api_token`` begins with ``$``, it will be replaced with
the value of the environment corresponding to the remaining part of the
If present and it starts with a ``$``, it will be replaced with the
value of the environment corresponding to the remaining part of the
string.

"""
token = self.auth_token
if token is not None and token.startswith("$"):
token = os.getenv(token[1:])
return token

@property
def credentials(self) -> tuple[str, str] | tuple[()]:
"""Credentials encoded in this URL.

A tuple of ``(api_token, '')`` if present, or empty tuple otherwise.

"""
token = self.token
return (token, "") if token else ()

@property
Expand Down Expand Up @@ -128,18 +137,18 @@ def _build_url(self, scheme: str) -> furl:


@lru_cache()
def get_default_branch(api_url: str) -> str:
def get_default_branch(api_url: str, *, token: str | None = None) -> str:
"""Get the default branch from the GitHub repo using the API.

For now, the request is not authenticated on GitHub, so it might hit a rate limit with:
For now, for URLs without an authorization token embedded, the request is
not authenticated on GitHub, so it might hit a rate limit with:
``requests.exceptions.HTTPError: 403 Client Error: rate limit exceeded for url``

This function is using ``lru_cache()`` as a simple memoizer, trying to avoid this rate limit error.

Another option for the future: perform an authenticated request to GitHub.
That would require some user credentials.
"""
response = API_SESSION.get(api_url)
headers = {"Authorization": f"token {token}"} if token else None
response = API_SESSION.get(api_url, headers=headers)
response.raise_for_status()

return response.json()["default_branch"]
Expand Down
34 changes: 34 additions & 0 deletions tests/test_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,40 @@ def test_fetch_private_github_urls(tmp_path):
project.flake8(offline=True).assert_no_errors()


@responses.activate
def test_fetch_private_github_urls_no_branch(tmp_path):
"""Fetch private GitHub URLs with a token on the query string."""
file_token = "query-string-token-generated-by-github-for-private-files"
gh_url = f"gh://{file_token}@user/private_repo/path/to/nitpick-style"
api_url = "https://api.github.com/repos/user/private_repo"
api_response = '{"default_branch": "branch"}'
full_raw_url = f"https://raw.githubusercontent.com/user/private_repo/branch/path/to/nitpick-style{TOML_EXTENSION}"
body = """
["pyproject.toml".tool.black]
missing = "thing"
"""
responses.add(responses.GET, api_url, api_response, status=200)
responses.add(responses.GET, full_raw_url, dedent(body), status=200)

project = ProjectMock(tmp_path).pyproject_toml(
f"""
[tool.nitpick]
style = "{gh_url}"
"""
)
project.flake8(offline=False).assert_single_error(
f"""
NIP318 File pyproject.toml has missing values:{SUGGESTION_BEGIN}
[tool.black]
missing = "thing"{SUGGESTION_END}
"""
)
assert responses.calls[0].request.headers["Authorization"] == f"token {file_token}"
token_on_basic_auth = b64encode(f"{file_token}:".encode()).decode().strip()
assert responses.calls[1].request.headers["Authorization"] == f"Basic {token_on_basic_auth}"
project.flake8(offline=True).assert_no_errors()


@pytest.mark.parametrize(
"style_url",
[
Expand Down