88
99from diracx .core .exceptions import (
1010 BadPilotCredentialsError ,
11+ BadPilotVOError ,
1112 CredentialsNotFoundError ,
1213 OverusedSecretError ,
1314 PilotAlreadyExistsError ,
@@ -112,32 +113,40 @@ async def verify_pilot_secret(
112113 )
113114
114115 # 5. Verify the secret counter
115- if secret ["SecretGlobalUseCount" ] + 1 > secret ["SecretGlobalUseCountMax" ]:
116- raise OverusedSecretError (
117- data = {
118- "pilot_stamp" "pilot_hashed_secret" : pilot_hashed_secret ,
119- "secret_global_use_count" : secret ["SecretGlobalUseCount" ],
120- "secret_global_use_count_max" : secret ["SecretGlobalUseCountMax" ],
121- }
122- )
123-
124- # Now the pilot is authorized. Increment the counters (globally and locally).
116+ # 5.1 Only check if the SecretGlobalUseCountMax is defined
117+ # If not defined, there is an infinite use.
118+ if secret ["SecretGlobalUseCountMax" ]:
119+ # 5.2 Finite use, we check if we can still login
120+ if secret ["SecretGlobalUseCount" ] + 1 > secret ["SecretGlobalUseCountMax" ]:
121+ raise OverusedSecretError (
122+ data = {
123+ "pilot_stamp" "pilot_hashed_secret" : pilot_hashed_secret ,
124+ "secret_global_use_count" : secret ["SecretGlobalUseCount" ],
125+ "secret_global_use_count_max" : secret [
126+ "SecretGlobalUseCountMax"
127+ ],
128+ }
129+ )
130+
131+ # 6. Now the pilot is authorized, increment the counters (globally and locally).
125132 try :
126- # Increment the local count
133+ # 6.1 Increment the local count
127134 await self .increment_pilot_local_secret_and_last_time_use (
128135 pilot_secret_id = pilot_credentials ["PilotSecretID" ],
129136 pilot_stamp = pilot_credentials ["PilotStamp" ],
130137 )
131138
132- # Increment the global count
139+ # 6.2 Increment the global count
133140 await self .increment_global_secret_use (
134141 secret_id = pilot_credentials ["PilotSecretID" ]
135142 )
136143 except Exception as e : # Generic, to catch it.
137144 # Should NOT happen
138- # Still caught in case of an error in the counters
145+ # Wrapped in a try/catch to still catch in case of an error in the counters
139146 # Caught and raised here to avoid raising a 4XX error
140- raise DBInBadStateError (detail = "This should not happen." ) from e
147+ raise DBInBadStateError (
148+ detail = "This should not happen. Pilot has credentials, but has a corrupted secret."
149+ ) from e
141150
142151 async def add_pilots_bulk (
143152 self ,
@@ -175,6 +184,7 @@ async def add_pilots_credentials_bulk(
175184 self ,
176185 pilot_stamps : list [str ],
177186 pilot_hashed_secrets : list [str ],
187+ vo : str | None ,
178188 pilot_secret_use_count_max : int = 1 ,
179189 ) -> list [dict ]:
180190
@@ -186,6 +196,7 @@ async def add_pilots_credentials_bulk(
186196 await self .insert_unique_secrets_bulk (
187197 hashed_secrets = pilot_hashed_secrets ,
188198 secret_global_use_count_max = pilot_secret_use_count_max ,
199+ vo = vo ,
189200 )
190201
191202 # Get the secret ids to later associate them with pilots
@@ -207,13 +218,17 @@ async def add_pilots_credentials_bulk(
207218 return secrets
208219
209220 async def insert_unique_secrets_bulk (
210- self , hashed_secrets : list [str ], secret_global_use_count_max : int = 1
221+ self ,
222+ hashed_secrets : list [str ],
223+ vo : str | None ,
224+ secret_global_use_count_max : int = 1 ,
211225 ):
212226 """Bulk insert secrets. Raises an error in case of a Integrity violation."""
213227 values = [
214228 {
215229 "SecretGlobalUseCountMax" : secret_global_use_count_max ,
216230 "HashedSecret" : hashed_secret ,
231+ "SecretVO" : vo ,
217232 }
218233 for hashed_secret in hashed_secrets
219234 ]
@@ -234,13 +249,19 @@ async def insert_unique_secrets_bulk(
234249 ) from e
235250
236251 # Other errors to catch
237- raise DBInBadStateError ("Engine Specific error not caught" ) from e
252+ raise DBInBadStateError ("Engine Specific error not caught" + str ( e ) ) from e
238253
239254 async def associate_pilots_with_secrets_bulk (
240255 self , pilot_to_secret_id_mapping_values : list [dict [str , Any ]]
241256 ):
242257 """Bulk associate pilots with secrets. Raises an error in case of a Integrity violation."""
243258 # Better to give as a parameter pilot to secret associations, rather than associating here.
259+
260+ # First verify that pilots can access a certain secret
261+ await self .verify_that_pilot_can_access_secret_bulk (
262+ pilot_to_secret_id_mapping_values
263+ )
264+
244265 stmt = insert (PilotToSecretMapping ).values (pilot_to_secret_id_mapping_values )
245266
246267 try :
@@ -261,6 +282,52 @@ async def associate_pilots_with_secrets_bulk(
261282 detail = "at least one of these pilots already have a secret" ,
262283 ) from e
263284
285+ async def verify_that_pilot_can_access_secret_bulk (
286+ self , pilot_to_secret_id_mapping_values : list [dict [str , Any ]]
287+ ):
288+ # 1. Extract unique pilot_stamps and secret_ids
289+ pilot_stamps = [
290+ entry ["PilotStamp" ] for entry in pilot_to_secret_id_mapping_values
291+ ]
292+ secret_ids = [
293+ entry ["PilotSecretID" ] for entry in pilot_to_secret_id_mapping_values
294+ ]
295+
296+ # 2. Bulk fetch pilot and secret info
297+ pilots = await self .get_pilots_by_stamp_bulk (pilot_stamps )
298+ secrets = await self .get_secrets_by_secret_ids_bulk (secret_ids )
299+
300+ # 3. Build lookup maps
301+ pilot_vo_map = {pilot ["PilotStamp" ]: pilot ["VO" ] for pilot in pilots }
302+ secret_vo_map = {secret ["SecretID" ]: secret ["SecretVO" ] for secret in secrets }
303+
304+ # 4. Validate access
305+ bad_mapping = []
306+
307+ for mapping in pilot_to_secret_id_mapping_values :
308+ pilot_stamp = mapping ["PilotStamp" ]
309+ secret_id = mapping ["PilotSecretID" ]
310+
311+ pilot_vo = pilot_vo_map [pilot_stamp ]
312+ secret_vo = secret_vo_map [secret_id ]
313+
314+ # If secret_vo is set to NULL, everybody can access it
315+ if not secret_vo :
316+ continue
317+
318+ # Access allowed only if VOs match or secret_vo is open (None)
319+ if secret_vo is not None and pilot_vo != secret_vo :
320+ bad_mapping .append (
321+ {
322+ "pilot_stamp" : pilot_stamp ,
323+ "given_vo" : pilot_vo ,
324+ "expected_vo" : secret_vo ,
325+ }
326+ )
327+
328+ if bad_mapping :
329+ raise BadPilotVOError (data = {"bad_mapping" : str (bad_mapping )})
330+
264331 async def set_secret_expirations_bulk (
265332 self , secret_ids : list [int ], pilot_secret_expiration_dates : list [DateTime ]
266333 ):
@@ -288,6 +355,7 @@ async def get_pilots_by_stamp_bulk(self, pilot_stamps: list[str]) -> list[dict]:
288355 PilotAgents ,
289356 PilotNotFoundError ,
290357 "pilot_stamp" ,
358+ "PilotStamp" ,
291359 pilot_stamps ,
292360 )
293361
@@ -300,6 +368,7 @@ async def get_pilots_credentials_by_stamps_bulk(
300368 PilotToSecretMapping ,
301369 CredentialsNotFoundError ,
302370 "pilot_stamp" ,
371+ "PilotStamp" ,
303372 pilot_stamps ,
304373 )
305374
@@ -310,6 +379,18 @@ async def get_secrets_by_hashed_secrets_bulk(self, hashed_secrets: list[str]):
310379 PilotSecrets ,
311380 SecretNotFoundError ,
312381 "hashed_secret" ,
382+ "HashedSecret" ,
313383 hashed_secrets ,
314384 order_by = ("secret_id" , "asc" ),
315385 )
386+
387+ async def get_secrets_by_secret_ids_bulk (self , secret_ids : list [int ]):
388+ """Bulk fetch secrets. Ensure all secrets are found, else raise an error."""
389+ return await fetch_records_bulk_or_raises (
390+ self .conn ,
391+ PilotSecrets ,
392+ SecretNotFoundError ,
393+ "secret_id" ,
394+ "SecretID" ,
395+ secret_ids ,
396+ )
0 commit comments