Skip to content

Commit 3f013df

Browse files
feat: Restricting a secret for a certain VO
1 parent c2181aa commit 3f013df

File tree

8 files changed

+193
-39
lines changed

8 files changed

+193
-39
lines changed

diracx-core/src/diracx/core/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ class BadPilotCredentialsError(GenericError):
137137
tail = ""
138138

139139

140+
class BadPilotVOError(GenericError):
141+
head = "Bad VO"
142+
tail = ""
143+
144+
140145
class SecretNotFoundError(GenericError):
141146
head = "Secret"
142147
tail = "not found"

diracx-db/src/diracx/db/sql/pilot_agents/db.py

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from diracx.core.exceptions import (
1010
BadPilotCredentialsError,
11+
BadPilotVOError,
1112
CredentialsNotFoundError,
1213
OverusedSecretError,
1314
PilotAlreadyExistsError,
@@ -112,32 +113,40 @@ async def verify_pilot_secret(
112113
)
113114

114115
# 5. Verify the secret counter
115-
if secret["SecretGlobalUseCount"] + 1 > secret["SecretGlobalUseCountMax"]:
116-
raise OverusedSecretError(
117-
data={
118-
"pilot_stamp" "pilot_hashed_secret": pilot_hashed_secret,
119-
"secret_global_use_count": secret["SecretGlobalUseCount"],
120-
"secret_global_use_count_max": secret["SecretGlobalUseCountMax"],
121-
}
122-
)
123-
124-
# Now the pilot is authorized. Increment the counters (globally and locally).
116+
# 5.1 Only check if the SecretGlobalUseCountMax is defined
117+
# If not defined, there is an infinite use.
118+
if secret["SecretGlobalUseCountMax"]:
119+
# 5.2 Finite use, we check if we can still login
120+
if secret["SecretGlobalUseCount"] + 1 > secret["SecretGlobalUseCountMax"]:
121+
raise OverusedSecretError(
122+
data={
123+
"pilot_stamp" "pilot_hashed_secret": pilot_hashed_secret,
124+
"secret_global_use_count": secret["SecretGlobalUseCount"],
125+
"secret_global_use_count_max": secret[
126+
"SecretGlobalUseCountMax"
127+
],
128+
}
129+
)
130+
131+
# 6. Now the pilot is authorized, increment the counters (globally and locally).
125132
try:
126-
# Increment the local count
133+
# 6.1 Increment the local count
127134
await self.increment_pilot_local_secret_and_last_time_use(
128135
pilot_secret_id=pilot_credentials["PilotSecretID"],
129136
pilot_stamp=pilot_credentials["PilotStamp"],
130137
)
131138

132-
# Increment the global count
139+
# 6.2 Increment the global count
133140
await self.increment_global_secret_use(
134141
secret_id=pilot_credentials["PilotSecretID"]
135142
)
136143
except Exception as e: # Generic, to catch it.
137144
# Should NOT happen
138-
# Still caught in case of an error in the counters
145+
# Wrapped in a try/catch to still catch in case of an error in the counters
139146
# Caught and raised here to avoid raising a 4XX error
140-
raise DBInBadStateError(detail="This should not happen.") from e
147+
raise DBInBadStateError(
148+
detail="This should not happen. Pilot has credentials, but has a corrupted secret."
149+
) from e
141150

142151
async def add_pilots_bulk(
143152
self,
@@ -175,6 +184,7 @@ async def add_pilots_credentials_bulk(
175184
self,
176185
pilot_stamps: list[str],
177186
pilot_hashed_secrets: list[str],
187+
vo: str | None,
178188
pilot_secret_use_count_max: int = 1,
179189
) -> list[dict]:
180190

@@ -186,6 +196,7 @@ async def add_pilots_credentials_bulk(
186196
await self.insert_unique_secrets_bulk(
187197
hashed_secrets=pilot_hashed_secrets,
188198
secret_global_use_count_max=pilot_secret_use_count_max,
199+
vo=vo,
189200
)
190201

191202
# Get the secret ids to later associate them with pilots
@@ -207,13 +218,17 @@ async def add_pilots_credentials_bulk(
207218
return secrets
208219

209220
async def insert_unique_secrets_bulk(
210-
self, hashed_secrets: list[str], secret_global_use_count_max: int = 1
221+
self,
222+
hashed_secrets: list[str],
223+
vo: str | None,
224+
secret_global_use_count_max: int = 1,
211225
):
212226
"""Bulk insert secrets. Raises an error in case of a Integrity violation."""
213227
values = [
214228
{
215229
"SecretGlobalUseCountMax": secret_global_use_count_max,
216230
"HashedSecret": hashed_secret,
231+
"SecretVO": vo,
217232
}
218233
for hashed_secret in hashed_secrets
219234
]
@@ -234,13 +249,19 @@ async def insert_unique_secrets_bulk(
234249
) from e
235250

236251
# Other errors to catch
237-
raise DBInBadStateError("Engine Specific error not caught") from e
252+
raise DBInBadStateError("Engine Specific error not caught" + str(e)) from e
238253

239254
async def associate_pilots_with_secrets_bulk(
240255
self, pilot_to_secret_id_mapping_values: list[dict[str, Any]]
241256
):
242257
"""Bulk associate pilots with secrets. Raises an error in case of a Integrity violation."""
243258
# Better to give as a parameter pilot to secret associations, rather than associating here.
259+
260+
# First verify that pilots can access a certain secret
261+
await self.verify_that_pilot_can_access_secret_bulk(
262+
pilot_to_secret_id_mapping_values
263+
)
264+
244265
stmt = insert(PilotToSecretMapping).values(pilot_to_secret_id_mapping_values)
245266

246267
try:
@@ -261,6 +282,52 @@ async def associate_pilots_with_secrets_bulk(
261282
detail="at least one of these pilots already have a secret",
262283
) from e
263284

285+
async def verify_that_pilot_can_access_secret_bulk(
286+
self, pilot_to_secret_id_mapping_values: list[dict[str, Any]]
287+
):
288+
# 1. Extract unique pilot_stamps and secret_ids
289+
pilot_stamps = [
290+
entry["PilotStamp"] for entry in pilot_to_secret_id_mapping_values
291+
]
292+
secret_ids = [
293+
entry["PilotSecretID"] for entry in pilot_to_secret_id_mapping_values
294+
]
295+
296+
# 2. Bulk fetch pilot and secret info
297+
pilots = await self.get_pilots_by_stamp_bulk(pilot_stamps)
298+
secrets = await self.get_secrets_by_secret_ids_bulk(secret_ids)
299+
300+
# 3. Build lookup maps
301+
pilot_vo_map = {pilot["PilotStamp"]: pilot["VO"] for pilot in pilots}
302+
secret_vo_map = {secret["SecretID"]: secret["SecretVO"] for secret in secrets}
303+
304+
# 4. Validate access
305+
bad_mapping = []
306+
307+
for mapping in pilot_to_secret_id_mapping_values:
308+
pilot_stamp = mapping["PilotStamp"]
309+
secret_id = mapping["PilotSecretID"]
310+
311+
pilot_vo = pilot_vo_map[pilot_stamp]
312+
secret_vo = secret_vo_map[secret_id]
313+
314+
# If secret_vo is set to NULL, everybody can access it
315+
if not secret_vo:
316+
continue
317+
318+
# Access allowed only if VOs match or secret_vo is open (None)
319+
if secret_vo is not None and pilot_vo != secret_vo:
320+
bad_mapping.append(
321+
{
322+
"pilot_stamp": pilot_stamp,
323+
"given_vo": pilot_vo,
324+
"expected_vo": secret_vo,
325+
}
326+
)
327+
328+
if bad_mapping:
329+
raise BadPilotVOError(data={"bad_mapping": str(bad_mapping)})
330+
264331
async def set_secret_expirations_bulk(
265332
self, secret_ids: list[int], pilot_secret_expiration_dates: list[DateTime]
266333
):
@@ -288,6 +355,7 @@ async def get_pilots_by_stamp_bulk(self, pilot_stamps: list[str]) -> list[dict]:
288355
PilotAgents,
289356
PilotNotFoundError,
290357
"pilot_stamp",
358+
"PilotStamp",
291359
pilot_stamps,
292360
)
293361

@@ -300,6 +368,7 @@ async def get_pilots_credentials_by_stamps_bulk(
300368
PilotToSecretMapping,
301369
CredentialsNotFoundError,
302370
"pilot_stamp",
371+
"PilotStamp",
303372
pilot_stamps,
304373
)
305374

@@ -310,6 +379,18 @@ async def get_secrets_by_hashed_secrets_bulk(self, hashed_secrets: list[str]):
310379
PilotSecrets,
311380
SecretNotFoundError,
312381
"hashed_secret",
382+
"HashedSecret",
313383
hashed_secrets,
314384
order_by=("secret_id", "asc"),
315385
)
386+
387+
async def get_secrets_by_secret_ids_bulk(self, secret_ids: list[int]):
388+
"""Bulk fetch secrets. Ensure all secrets are found, else raise an error."""
389+
return await fetch_records_bulk_or_raises(
390+
self.conn,
391+
PilotSecrets,
392+
SecretNotFoundError,
393+
"secret_id",
394+
"SecretID",
395+
secret_ids,
396+
)

diracx-db/src/diracx/db/sql/pilot_agents/schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class PilotSecrets(PilotAgentsDBBase):
7676
)
7777
secret_creation_time = DateNowColumn("SecretCreationDate")
7878
secret_expiration_date = NullColumn("SecretExpirationDate", DateTime(timezone=True))
79+
# To authorize only pilots from a specific VO to access a secret
80+
# Null VO => Can be used by everyone
81+
secret_vo = NullColumn("SecretVO", String(128))
7982

8083
__table_args__ = (UniqueConstraint("HashedSecret", name="uq_hashed_secret"),)
8184

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import hashlib
4-
import re
54
from datetime import datetime, timedelta, timezone
65
from typing import TYPE_CHECKING, Any, Type
76

@@ -151,21 +150,12 @@ def rows_to_dicts(rows):
151150
return [dict(row._mapping) for row in rows]
152151

153152

154-
# For efficiency
155-
_SNAKE_SPLIT_RE = re.compile(r"_+")
156-
157-
158-
def snake_to_pascal(snake_str: str) -> str:
159-
# Split the string using the precompiled regex
160-
# Capitalize each word and join to form PascalCase
161-
return "".join(word.capitalize() for word in _SNAKE_SPLIT_RE.split(snake_str))
162-
163-
164153
async def fetch_records_bulk_or_raises(
165154
conn: AsyncConnection,
166155
model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any
167156
missing_elements_error_cls: Type[GenericError],
168-
column_to_use: str,
157+
column_attribute_name: str,
158+
column_name: str,
169159
elements_to_fetch: list,
170160
order_by: tuple[str, str] | None = None,
171161
) -> list[dict]:
@@ -178,6 +168,7 @@ async def fetch_records_bulk_or_raises(
178168
self.conn,
179169
PilotAgents,
180170
PilotNotFound,
171+
"pilot_id",
181172
"PilotID",
182173
[1,2,3]
183174
)
@@ -186,7 +177,7 @@ async def fetch_records_bulk_or_raises(
186177
assert elements_to_fetch
187178

188179
# Get the column that needs to be in elements_to_fetch
189-
column = getattr(model, column_to_use)
180+
column = getattr(model, column_attribute_name)
190181

191182
# Create the request
192183
stmt = select(model).with_for_update().where(column.in_(elements_to_fetch))
@@ -209,13 +200,12 @@ async def fetch_records_bulk_or_raises(
209200
raise DBInBadStateError(detail="Seems to have duplicates in the database.")
210201

211202
# Checks if we have every elements we wanted
212-
camel_case_column_to_use = snake_to_pascal(column_to_use)
213-
found_keys = {row[camel_case_column_to_use] for row in results}
203+
found_keys = {row[column_name] for row in results}
214204
missing = set(elements_to_fetch) - found_keys
215205

216206
if missing:
217207
raise missing_elements_error_cls(
218-
data={camel_case_column_to_use: str(missing)}, detail=str(missing)
208+
data={column_name: str(missing)}, detail=str(missing)
219209
)
220210

221211
return results

0 commit comments

Comments
 (0)