77from diracx .core .config import Config
88from diracx .core .exceptions import (
99 ConfigurationError ,
10+ CredentialsAlreadyExistError ,
1011 PilotAlreadyExistsError ,
1112 PilotNotFoundError ,
1213)
13- from diracx .core .models import PilotCredentialsInfo , PilotCredentialsResponse
14+ from diracx .core .models import PilotCredentialsInfo , PilotSecretsInfo , PilotStampInfo
1415from diracx .core .settings import AuthSettings
1516from diracx .db .sql import PilotAgentsDB
1617
@@ -52,7 +53,7 @@ async def create_credentials(
5253 secret ["SecretCreationDate" ]
5354 + timedelta (
5455 seconds = (
55- expiration_minutes
56+ expiration_minutes * 60
5657 if expiration_minutes
5758 else settings .pilot_secret_expire_seconds
5859 )
@@ -74,6 +75,36 @@ async def create_credentials(
7475 return random_secrets , hashed_secrets , expiration_dates_timestamps
7576
7677
78+ async def associate_pilots_with_secrets (
79+ pilot_db : PilotAgentsDB ,
80+ pilot_stamps : list [str ],
81+ secrets : list [str ] | None = None ,
82+ hashed_secrets : list [str ] | None = None ,
83+ ):
84+
85+ if not hashed_secrets :
86+ assert secrets
87+ hashed_secrets = [hash (secret ) for secret in secrets ]
88+
89+ # Get the secret ids to later associate them with pilots
90+ secrets_obj = await pilot_db .get_secrets_by_hashed_secrets_bulk (hashed_secrets )
91+ secret_ids = [secret ["SecretID" ] for secret in secrets_obj ]
92+
93+ if len (secret_ids ) == 1 :
94+ secret_ids = secret_ids * len (pilot_stamps )
95+
96+ # Associates pilots with their secrets
97+ pilot_to_secret_id_mapping_values = [
98+ {
99+ "PilotSecretID" : secret_id ,
100+ "PilotStamp" : pilot_stamp ,
101+ }
102+ for pilot_stamp , secret_id in zip (pilot_stamps , secret_ids )
103+ ]
104+
105+ await pilot_db .associate_pilots_with_secrets_bulk (pilot_to_secret_id_mapping_values )
106+
107+
77108async def add_pilot_credentials (
78109 pilot_stamps : list [str ],
79110 pilot_db : PilotAgentsDB ,
@@ -91,29 +122,25 @@ async def add_pilot_credentials(
91122 )
92123 )
93124
94- # Get the secret ids to later associate them with pilots
95- secrets = await pilot_db .get_secrets_by_hashed_secrets_bulk (hashed_secrets )
96- secret_ids = [secret ["SecretID" ] for secret in secrets ]
97-
98- # Associates pilots with their secrets
99- pilot_to_secret_id_mapping_values = [
100- {
101- "PilotSecretID" : secret_id ,
102- "PilotStamp" : pilot_stamp ,
103- }
104- for pilot_stamp , secret_id in zip (pilot_stamps , secret_ids )
105- ]
106- await pilot_db .associate_pilots_with_secrets_bulk (pilot_to_secret_id_mapping_values )
125+ try :
126+ await associate_pilots_with_secrets (
127+ pilot_db = pilot_db , hashed_secrets = hashed_secrets , pilot_stamps = pilot_stamps
128+ )
129+ except CredentialsAlreadyExistError as e :
130+ # Undo everything in case of an error.
131+ # TODO: Validate in PR
132+ await pilot_db .conn .rollback ()
133+ raise e
107134
108135 return random_secrets , expiration_dates_timestamps
109136
110137
111138def create_pilot_credentials_response (
112- pilot_stamps : list [str | None ],
139+ pilot_stamps : list [str ],
113140 pilot_secrets : list [str ],
114141 pilot_expiration_dates : list [int ],
115- ) -> PilotCredentialsResponse :
116- credentials_list = [
142+ ) -> list [ PilotCredentialsInfo ] :
143+ return [
117144 PilotCredentialsInfo (
118145 pilot_stamp = pilot_stamp ,
119146 pilot_secret = secret ,
@@ -124,7 +151,22 @@ def create_pilot_credentials_response(
124151 )
125152 ]
126153
127- return PilotCredentialsResponse (pilot_credentials = credentials_list )
154+
155+ def create_secrets_response (
156+ pilot_secrets : list [str ],
157+ pilot_expiration_dates : list [int ],
158+ ) -> list [PilotSecretsInfo ]:
159+ return [
160+ PilotSecretsInfo (
161+ pilot_secret = secret ,
162+ pilot_secret_expires_in = expires_in ,
163+ )
164+ for secret , expires_in in zip (pilot_secrets , pilot_expiration_dates )
165+ ]
166+
167+
168+ def create_stamp_response (pilot_stamps : list [str ]) -> list [PilotStampInfo ]:
169+ return [PilotStampInfo (pilot_stamp = stamp ) for stamp in pilot_stamps ]
128170
129171
130172def get_registry_and_group_configuration (config : Config , vo : str ):
@@ -212,6 +254,8 @@ async def register_new_pilots(
212254 pilots_that_already_exist = set (pilot_stamps ) - set (
213255 literal_eval (e .detail )
214256 )
257+ else :
258+ raise ValueError ("Bad internal error." )
215259 except AttributeError as e2 :
216260 raise ValueError ("Must be defined and a set string representation" ) from e2
217261
0 commit comments