Skip to content

Commit 306941f

Browse files
fix: Fix, bug not caught by CI (sqlite not raising anything)
1 parent cc8de96 commit 306941f

File tree

4 files changed

+22
-22
lines changed

4 files changed

+22
-22
lines changed

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def hash(code: str):
157157
return hashlib.sha256(code.encode()).hexdigest()
158158

159159

160+
def raw_hash(code: str):
161+
return hashlib.sha256(code.encode()).digest()
162+
163+
160164
async def fetch_records_bulk_or_raises(
161165
conn: AsyncConnection,
162166
model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any

diracx-db/tests/pilot_agents/test_pilot_auth.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from diracx.db.exceptions import DBInBadStateError
1818
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
19-
from diracx.db.sql.utils.functions import hash
19+
from diracx.db.sql.utils.functions import raw_hash
2020
from diracx.testing.time import mock_sqlite_time
2121

2222
MAIN_VO = "lhcb"
@@ -70,7 +70,7 @@ async def add_secrets_and_time(
7070
stamps = [pilot["PilotStamp"] for pilot in add_stamps]
7171

7272
secrets = [f"AW0nd3rfulS3cr3t_{str(i)}" for i in range(len(stamps))]
73-
hashed_secrets = [hash(secret).encode() for secret in secrets]
73+
hashed_secrets = [raw_hash(secret) for secret in secrets]
7474

7575
# Add creds
7676
await pilot_agents_db.insert_unique_secrets_bulk(
@@ -115,7 +115,7 @@ async def add_secrets_and_time(
115115
async def verify_pilot_secret(
116116
pilot_stamp: str,
117117
pilot_db: PilotAgentsDB,
118-
hashed_secret: str,
118+
hashed_secret: bytes,
119119
frozen_time: freezegun.FreezeGun,
120120
) -> None:
121121

@@ -125,9 +125,7 @@ async def verify_pilot_secret(
125125
real_secret_uuid = pilot["PilotSecretUUID"]
126126

127127
# 2. Get the secret itself
128-
given_secrets = await pilot_db.get_secrets_by_hashed_secrets_bulk(
129-
[hashed_secret.encode()]
130-
)
128+
given_secrets = await pilot_db.get_secrets_by_hashed_secrets_bulk([hashed_secret])
131129
given_secret = given_secrets[0]
132130
given_secret_uuid = given_secret[
133131
"SecretUUID"
@@ -217,23 +215,23 @@ async def test_create_pilot_and_verify_secret(
217215
await verify_pilot_secret(
218216
pilot_db=pilot_agents_db,
219217
pilot_stamp=stamp,
220-
hashed_secret=hash(secret),
218+
hashed_secret=raw_hash(secret),
221219
frozen_time=frozen_time,
222220
)
223221

224222
with pytest.raises(SecretNotFoundError):
225223
await verify_pilot_secret(
226224
pilot_db=pilot_agents_db,
227225
pilot_stamp=stamps[0],
228-
hashed_secret=hash("I love stawberries :)"),
226+
hashed_secret=raw_hash("I love stawberries :)"),
229227
frozen_time=frozen_time,
230228
)
231229

232230
with pytest.raises(PilotNotFoundError):
233231
await verify_pilot_secret(
234232
pilot_db=pilot_agents_db,
235233
pilot_stamp="I am a spider",
236-
hashed_secret=hash(secrets[0]),
234+
hashed_secret=raw_hash(secrets[0]),
237235
frozen_time=frozen_time,
238236
)
239237

@@ -258,7 +256,7 @@ async def test_create_pilot_and_verify_secret_with_delay(
258256
await verify_pilot_secret(
259257
pilot_db=pilot_agents_db,
260258
pilot_stamp=stamps[0],
261-
hashed_secret=hash(secrets[0]),
259+
hashed_secret=raw_hash(secrets[0]),
262260
frozen_time=frozen_time,
263261
)
264262

@@ -281,7 +279,7 @@ async def test_create_pilot_and_verify_secret_too_much_secret_use(
281279
await verify_pilot_secret(
282280
pilot_db=pilot_agents_db,
283281
pilot_stamp=stamps[0],
284-
hashed_secret=hash(secrets[0]),
282+
hashed_secret=raw_hash(secrets[0]),
285283
frozen_time=frozen_time,
286284
)
287285

@@ -291,7 +289,7 @@ async def test_create_pilot_and_verify_secret_too_much_secret_use(
291289
await verify_pilot_secret(
292290
pilot_db=pilot_agents_db,
293291
pilot_stamp=stamps[0],
294-
hashed_secret=hash(secrets[0]),
292+
hashed_secret=raw_hash(secrets[0]),
295293
frozen_time=frozen_time,
296294
)
297295

@@ -316,6 +314,6 @@ async def test_create_pilot_and_login_with_bad_secret(
316314
await verify_pilot_secret(
317315
pilot_db=pilot_agents_db,
318316
pilot_stamp=stamps[0],
319-
hashed_secret=hash(secret),
317+
hashed_secret=raw_hash(secret),
320318
frozen_time=frozen_time,
321319
)

diracx-logic/src/diracx/logic/pilots/auth.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from diracx.db.sql import AuthDB, PilotAgentsDB
2929

3030
# TODO: Move this hash function in diracx-logic, and rename it
31-
from diracx.db.sql.utils.functions import hash
31+
from diracx.db.sql.utils.functions import raw_hash
3232
from diracx.logic.auth.token import exchange_token, get_token_info_from_refresh_flow
3333

3434

@@ -50,7 +50,7 @@ async def create_raw_secrets(
5050
# Can be customized
5151
random_secrets = [generate_pilot_secret() for _ in range(n)]
5252

53-
hashed_secrets = [hash(random_secret).encode() for random_secret in random_secrets]
53+
hashed_secrets = [raw_hash(random_secret) for random_secret in random_secrets]
5454

5555
# Insert secrets
5656
await pilot_db.insert_unique_secrets_bulk(
@@ -120,7 +120,7 @@ async def associate_pilots_with_secrets(
120120
):
121121

122122
# 1. Hash the secrets
123-
hashed_secrets = [hash(secret).encode() for secret in pilot_secrets]
123+
hashed_secrets = [raw_hash(secret) for secret in pilot_secrets]
124124

125125
# 2. Get the secret ids to later associate them with pilots
126126
secrets_obj = await pilot_db.get_secrets_by_hashed_secrets_bulk(hashed_secrets)
@@ -254,17 +254,15 @@ async def verify_pilot_credentials(
254254
available_properties: set[SecurityProperty],
255255
) -> tuple[AccessTokenPayload, RefreshTokenPayload | None]:
256256

257-
hashed_secret = hash(pilot_secret)
257+
hashed_secret = raw_hash(pilot_secret)
258258

259259
# 1. Get the pilot
260260
pilots = await pilot_db.get_pilots_by_stamp_bulk([pilot_stamp])
261261
pilot = dict(pilots[0]) # Semantic, assured by fetch_records_bulk_or_raises
262262
real_secret_uuid = pilot["PilotSecretUUID"]
263263

264264
# 2. Get the secret itself
265-
given_secrets = await pilot_db.get_secrets_by_hashed_secrets_bulk(
266-
[hashed_secret.encode()]
267-
)
265+
given_secrets = await pilot_db.get_secrets_by_hashed_secrets_bulk([hashed_secret])
268266
given_secret = given_secrets[0]
269267
given_secret_uuid = given_secret[
270268
"SecretUUID"

diracx-routers/tests/pilots/test_pilot_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from diracx.core.utils import extract_timestamp_from_uuid7
99
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
10-
from diracx.db.sql.utils import hash
10+
from diracx.db.sql.utils.functions import raw_hash
1111

1212
pytestmark = pytest.mark.enabled_dependencies(
1313
[
@@ -75,7 +75,7 @@ async def add_secrets_and_time(test_client, add_stamps, secret_duration_sec):
7575
stamps = [pilot["PilotStamp"] for pilot in add_stamps]
7676

7777
secrets = [f"AW0nd3rfulS3cr3t_{str(i)}" for i in range(len(stamps))]
78-
hashed_secrets = [hash(secret).encode() for secret in secrets]
78+
hashed_secrets = [raw_hash(secret) for secret in secrets]
7979

8080
# Add creds
8181
await pilot_agents_db.insert_unique_secrets_bulk(

0 commit comments

Comments
 (0)