Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions ebrains_drive/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ def put(self, *args, **kwargs):
def delete(self, *args, **kwargs):
return self.send_request("DELETE", *args, **kwargs)

def _exchange_oidc_for_seafile_token(self):
url = self.server.rstrip("/") + "/api2/account/token/"
headers = {"Authorization": f"Bearer {self._token}"}

resp = self.session.get(url, headers=headers)

if resp.status_code != 200:
raise Exception(f"Failed to exchange OIDC token for Seafile token: {resp.status_code} {resp.text}")

self._seafile_token = resp.text.strip()
return self._seafile_token

def send_request(self, method: str, url: str, *args, **kwargs):
if not url.startswith("http"):
# sanity checks.
Expand All @@ -82,21 +94,31 @@ def send_request(self, method: str, url: str, *args, **kwargs):
# We cannot deepcopy the whole thing, because some values (e.g. BufferedReader objects)
# cannot be pickled
kwargs = copy(kwargs)
headers = deepcopy(kwargs.get("headers", {}))
headers.setdefault("Authorization", "Bearer " + self._token)
kwargs["headers"] = headers
headers = kwargs.pop("headers", {}).copy()

if self._seafile_token:
headers.setdefault("Authorization", "Token " + self._seafile_token)
else:
headers.setdefault("Authorization", "Bearer " + self._token)

expected = kwargs.pop("expected", 200)
if not hasattr(expected, "__iter__"):
expected = (expected,)
resp = self.session.request(method, url, *args, **kwargs)

resp = self.session.request(method, url, headers=headers, *args, **kwargs)

if resp.status_code == 401 and not self._seafile_token:
self._seafile_token = self._exchange_oidc_for_seafile_token()

headers["Authorization"] = "Token " + self._seafile_token
resp = self.session.request(method, url, headers=headers, *args, **kwargs)

if resp.status_code not in expected:
msg = "Expected %s, but get %s" % (" or ".join(map(str, expected)), resp.status_code)
msg = f"Expected {expected}, but got {resp.status_code}"
raise ClientHttpError(resp.status_code, msg)

return resp


class DriveApiClient(ClientBase):
"""Wraps seafile web api"""

Expand All @@ -110,6 +132,7 @@ def __init__(self, username=None, password=None, token=None, env=""):
self.repos = Repos(self)
self.groups = Groups(self)
self.file = File(self)
self._seafile_token = None

def _set_env(self, env=""):
super()._set_env(env)
Expand Down