11from __future__ import annotations
22
33from datetime import datetime , timezone
4- from typing import Any , Sequence
4+ from typing import Any
55
6- from sqlalchemy import RowMapping , bindparam , func
6+ from sqlalchemy import bindparam
77from sqlalchemy .exc import IntegrityError
8- from sqlalchemy .sql import delete , insert , select , update
8+ from sqlalchemy .sql import delete , insert , update
99
1010from diracx .core .exceptions import (
11- InvalidQueryError ,
1211 PilotAlreadyAssociatedWithJobError ,
13- PilotJobsNotFoundError ,
1412 PilotNotFoundError ,
1513)
1614from diracx .core .models import (
2119
2220from ..utils import (
2321 BaseSQLDB ,
24- _get_columns ,
25- apply_search_filters ,
26- apply_sort_constraints ,
27- fetch_records_bulk_or_raises ,
2822)
2923from .schema import (
3024 JobToPilotMapping ,
@@ -43,7 +37,7 @@ async def add_pilots_bulk(
4337 pilot_stamps : list [str ],
4438 vo : str ,
4539 grid_type : str = "DIRAC" ,
46- pilot_references : dict | None = None ,
40+ pilot_references : dict [ str , str ] | None = None ,
4741 ):
4842 """Bulk add pilots in the DB.
4943
@@ -85,7 +79,9 @@ async def delete_pilots_by_stamps_bulk(self, pilot_stamps: list[str]):
8579 if res .rowcount != len (pilot_stamps ):
8680 raise PilotNotFoundError (data = {"pilot_stamps" : str (pilot_stamps )})
8781
88- async def associate_pilot_with_jobs (self , job_to_pilot_mapping : list [dict ]):
82+ async def associate_pilot_with_jobs (
83+ self , job_to_pilot_mapping : list [dict [str , Any ]]
84+ ):
8985 """Associate a pilot with jobs.
9086
9187 job_to_pilot_mapping format:
@@ -182,61 +178,28 @@ async def update_pilot_fields_bulk(
182178 data = {"mapping" : str (pilot_stamps_to_fields_mapping )}
183179 )
184180
185- async def get_pilots_by_stamp_bulk (
186- self , pilot_stamps : list [str ]
187- ) -> Sequence [RowMapping ]:
188- """Bulk fetch pilots.
189-
190- Raises PilotNotFoundError if one of the stamp is not associated with a pilot.
191-
192- """
193- results = await fetch_records_bulk_or_raises (
194- self .conn ,
195- PilotAgents ,
196- PilotNotFoundError ,
197- "pilot_stamp" ,
198- "PilotStamp" ,
199- pilot_stamps ,
200- allow_no_result = True ,
201- )
202-
203- # Custom handling, to see which pilot_stamp does not exist (if so, say which one)
204- found_keys = {row ["PilotStamp" ] for row in results }
205- missing = set (pilot_stamps ) - found_keys
206-
207- if missing :
208- raise PilotNotFoundError (
209- data = {"pilot_stamp" : str (missing )},
210- detail = str (missing ),
211- non_existing_pilots = missing ,
212- )
213-
214- return results
215-
216- async def get_pilot_jobs_ids_by_pilot_id (self , pilot_id : int ) -> list [int ]:
217- """Fetch pilot jobs."""
218- job_to_pilot_mapping = await fetch_records_bulk_or_raises (
219- self .conn ,
220- JobToPilotMapping ,
221- PilotJobsNotFoundError ,
222- "pilot_id" ,
223- "PilotID" ,
224- [pilot_id ],
225- allow_more_than_one_result_per_input = True ,
226- allow_no_result = True ,
181+ async def search_pilots (
182+ self ,
183+ parameters : list [str ] | None ,
184+ search : list [SearchSpec ],
185+ sorts : list [SortSpec ],
186+ * ,
187+ distinct : bool = False ,
188+ per_page : int = 100 ,
189+ page : int | None = None ,
190+ ) -> tuple [int , list [dict [Any , Any ]]]:
191+ """Search for pilots in the database."""
192+ return await self .search (
193+ model = PilotAgents ,
194+ parameters = parameters ,
195+ search = search ,
196+ sorts = sorts ,
197+ distinct = distinct ,
198+ per_page = per_page ,
199+ page = page ,
227200 )
228201
229- return [mapping ["JobID" ] for mapping in job_to_pilot_mapping ]
230-
231- async def get_pilot_ids_by_stamps (self , pilot_stamps : list [str ]) -> list [int ]:
232- """Get pilot ids."""
233- # This function is currently needed while we are relying on pilot_ids instead of pilot_stamps
234- # (Ex: JobToPilotMapping)
235- pilots = await self .get_pilots_by_stamp_bulk (pilot_stamps )
236-
237- return [pilot ["PilotID" ] for pilot in pilots ]
238-
239- async def search (
202+ async def search_pilot_to_job_mapping (
240203 self ,
241204 parameters : list [str ] | None ,
242205 search : list [SearchSpec ],
@@ -247,39 +210,15 @@ async def search(
247210 page : int | None = None ,
248211 ) -> tuple [int , list [dict [Any , Any ]]]:
249212 """Search for pilots in the database."""
250- # TODO: Refactorize with the search function for jobs.
251- # Find which columns to select
252- columns = _get_columns ( PilotAgents . __table__ , parameters )
253-
254- stmt = select ( * columns )
255-
256- stmt = apply_search_filters (
257- PilotAgents . __table__ . columns . __getitem__ , stmt , search
213+ return await self . search (
214+ model = JobToPilotMapping ,
215+ parameters = parameters ,
216+ search = search ,
217+ sorts = sorts ,
218+ distinct = distinct ,
219+ per_page = per_page ,
220+ page = page ,
258221 )
259- stmt = apply_sort_constraints (
260- PilotAgents .__table__ .columns .__getitem__ , stmt , sorts
261- )
262-
263- if distinct :
264- stmt = stmt .distinct ()
265-
266- # Calculate total count before applying pagination
267- total_count_subquery = stmt .alias ()
268- total_count_stmt = select (func .count ()).select_from (total_count_subquery )
269- total = (await self .conn .execute (total_count_stmt )).scalar_one ()
270-
271- # Apply pagination
272- if page is not None :
273- if page < 1 :
274- raise InvalidQueryError ("Page must be a positive integer" )
275- if per_page < 1 :
276- raise InvalidQueryError ("Per page must be a positive integer" )
277- stmt = stmt .offset ((page - 1 ) * per_page ).limit (per_page )
278-
279- # Execute the query
280- return total , [
281- dict (row ._mapping ) async for row in (await self .conn .stream (stmt ))
282- ]
283222
284223 async def clear_pilots_bulk (
285224 self , cutoff_date : datetime , delete_only_aborted : bool
0 commit comments