11from __future__ import annotations
22
33import asyncio
4- from collections .abc import AsyncIterable , AsyncIterator , Iterable
5- from typing import TYPE_CHECKING
4+ import contextlib
5+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Iterable
6+ from logging import getLogger
7+ from typing import Annotated
68
9+ from pydantic import BaseModel , ConfigDict , Field
710from typing_extensions import override
811
12+ from crawlee ._request import Request
913from crawlee ._utils .docs import docs_group
1014from crawlee .request_loaders ._request_loader import RequestLoader
1115
12- if TYPE_CHECKING :
13- from crawlee ._request import Request
16+ logger = getLogger (__name__ )
17+
18+
class RequestListState(BaseModel):
    """Progress information of a `RequestList` that is kept in a recoverable state.

    Fields can be populated either by their Python name or by their camelCase alias.
    """

    model_config = ConfigDict(populate_by_name=True)

    # Number of requests that were already handed out by `fetch_next_request`; on restore, this many requests
    # are skipped in the source.
    next_index: Annotated[int, Field(alias='nextIndex')] = 0

    # Unique key of the request expected at position `next_index`; used to check that the restored state still
    # matches the request source. `None` when no further request is known.
    next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None

    # IDs of requests that were fetched but not yet marked as handled.
    in_progress: Annotated[set[str], Field(alias='inProgress')] = set()
25+
26+
class RequestListData(BaseModel):
    """Snapshot of the request data loaded from the `requests` iterator of a `RequestList`."""

    # All requests read from the source, in iteration order.
    requests: Annotated[list[Request], Field()]
1429
1530
1631@docs_group ('Request loaders' )
1732class RequestList (RequestLoader ):
18- """Represents a (potentially very large) list of URLs to crawl.
19-
20- Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently
21- intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory
22- storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue:
23- https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`.
24- """
33+ """Represents a (potentially very large) list of URLs to crawl."""
2534
2635 def __init__ (
2736 self ,
2837 requests : Iterable [str | Request ] | AsyncIterable [str | Request ] | None = None ,
2938 name : str | None = None ,
39+ persist_state_key : str | None = None ,
40+ persist_requests_key : str | None = None ,
3041 ) -> None :
3142 """Initialize a new instance.
3243
3344 Args:
3445 requests: The request objects (or their string representations) to be added to the provider.
3546 name: A name of the request list.
47+ persist_state_key: A key for persisting the progress information of the RequestList.
48+ If you do not pass a key but pass a `name`, a key will be derived using the name.
49+ Otherwise, state will not be persisted.
50+ persist_requests_key: A key for persisting the request data loaded from the `requests` iterator.
51+ If specified, the request data will be stored in the KeyValueStore to make sure that they don't change
52+ over time. This is useful if the `requests` iterator pulls the data dynamically.
3653 """
54+ from crawlee ._utils .recoverable_state import RecoverableState # noqa: PLC0415
55+
3756 self ._name = name
3857 self ._handled_count = 0
3958 self ._assumed_total_count = 0
4059
41- self ._in_progress = set [str ]()
42- self ._next : Request | None = None
60+ self ._next : tuple [Request | None , Request | None ] = (None , None )
61+
62+ if persist_state_key is None and name is not None :
63+ persist_state_key = f'SDK_REQUEST_LIST_STATE-{ name } '
64+
65+ self ._state = RecoverableState (
66+ default_state = RequestListState (),
67+ persistence_enabled = bool (persist_state_key ),
68+ persist_state_key = persist_state_key or '' ,
69+ logger = logger ,
70+ )
71+
72+ self ._persist_request_data = bool (persist_requests_key )
73+
74+ self ._requests_data = RecoverableState (
75+ default_state = RequestListData (requests = []),
76+ # With request data persistence enabled, a snapshot of the requests will be done on initialization
77+ persistence_enabled = 'explicit_only' if self ._persist_request_data else False ,
78+ persist_state_key = persist_requests_key or '' ,
79+ logger = logger ,
80+ )
4381
4482 if isinstance (requests , AsyncIterable ):
4583 self ._requests = requests .__aiter__ ()
@@ -50,6 +88,53 @@ def __init__(
5088
5189 self ._requests_lock : asyncio .Lock | None = None
5290
async def _get_state(self) -> RequestListState:
    """Return the progress state, initializing it and the request source on the first call.

    On the first call, both recoverable states are initialized and the request iterator is brought to the
    position recorded in the state: either by replaying the persisted request snapshot from `next_index`, or by
    skipping `next_index` items of the original iterator. Afterwards the stored `next_unique_key` (if any) is
    compared with the request that is actually next.

    Returns:
        The current `RequestListState`.

    Raises:
        RuntimeError: If the request found at the stored position has a different unique key than the stored one.
    """
    # Fast path - everything was already set up by an earlier call.
    if self._state.is_initialized:
        return self._state.current_value

    # Set up both recoverable states before touching the request source.
    await self._state.initialize()
    await self._requests_data.initialize()

    if self._requests_lock is None:
        self._requests_lock = asyncio.Lock()

    current_state = self._state.current_value

    if not self._persist_request_data:
        # Without a persisted snapshot, skip the requests that were already handed out earlier.
        async with self._requests_lock:
            for _ in range(current_state.next_index):
                with contextlib.suppress(StopAsyncIteration):
                    await self._requests.__anext__()
    else:
        # With request data persistence, iterate over a stored copy of the requests instead of the live source.
        async with self._requests_lock:
            if not await self._requests_data.has_persisted_state():
                # First run - take a snapshot of the whole source and store it explicitly.
                self._requests_data.current_value.requests = [
                    item if isinstance(item, Request) else Request.from_url(item) async for item in self._requests
                ]
                await self._requests_data.persist_state()

            remaining_requests = self._requests_data.current_value.requests[current_state.next_index :]
            self._requests = self._iterate_in_threadpool(remaining_requests)

    # Verify that the stored state is consistent with the request source.
    expected_unique_key = current_state.next_unique_key
    if expected_unique_key is not None:
        await self._ensure_next_request()

        upcoming_request = self._next[0]
        actual_unique_key = None if upcoming_request is None else upcoming_request.unique_key

        if actual_unique_key != expected_unique_key:
            raise RuntimeError(
                f'Mismatch at index {current_state.next_index} in persisted requests - '
                f'Expected unique key `{expected_unique_key}`, got `{actual_unique_key}`'
            )

    return current_state
137+
53138 @property
def name(self) -> str | None:
    """The name given to this request list on construction, or `None` if it has none."""
    return self._name
@@ -65,42 +150,62 @@ async def get_total_count(self) -> int:
65150 @override
async def is_empty(self) -> bool:
    """Return `True` when there is no further request available to be fetched."""
    # Make sure the look-ahead pair is filled before inspecting it.
    await self._ensure_next_request()
    upcoming_request = self._next[0]
    return upcoming_request is None
69154
70155 @override
async def is_finished(self) -> bool:
    """Return `True` when no request is in progress and no further request is available."""
    state = await self._get_state()

    # Requests that were fetched but not yet marked as handled keep the list unfinished.
    if state.in_progress:
        return False

    return await self.is_empty()
73159
74160 @override
async def fetch_next_request(self) -> Request | None:
    """Hand out the next request of the list.

    The returned request is recorded as in progress until `mark_request_as_handled` is called for it, and the
    stored position (`next_index` and `next_unique_key`) is moved one request forward.

    Returns:
        The next request, or `None` if there is no further request.
    """
    state = await self._get_state()
    await self._ensure_next_request()

    # `self._next` holds the request to hand out now and the one that follows it.
    next_request, following_request = self._next
    if next_request is None:
        return None

    state.in_progress.add(next_request.id)
    self._assumed_total_count += 1

    # Advance the stored position past the request that is being handed out.
    state.next_index += 1
    state.next_unique_key = following_request.unique_key if following_request is not None else None

    # Shift the look-ahead pair by one and refill the slot that was freed.
    self._next = (following_request, None)
    await self._ensure_next_request()

    return next_request
88181
89182 @override
async def mark_request_as_handled(self, request: Request) -> None:
    """Mark a previously fetched request as handled.

    Args:
        request: The request to remove from the set of in-progress requests.

    Raises:
        KeyError: If the request is not currently in progress.
    """
    state = await self._get_state()

    # Remove first, so that the handled counter only grows for requests that actually were in progress.
    state.in_progress.remove(request.id)
    self._handled_count += 1
93187
async def _ensure_next_request(self) -> None:
    """Fill the empty slots of the look-ahead pair `self._next` from the request source.

    The pair holds the request to be handed out next and the one following it. A slot stays `None` when the
    source has no more requests.
    """
    await self._get_state()

    # The lock is created lazily on first use.
    if self._requests_lock is None:
        self._requests_lock = asyncio.Lock()

    async with self._requests_lock:
        upcoming, following = self._next

        if upcoming is None:
            # Nothing buffered in the first slot - read two requests to fill the whole pair.
            fetched = [item async for item in self._dequeue_requests(2)]
            self._next = (fetched[0], fetched[1])
        elif following is None:
            # Only the second slot is empty - read a single request.
            fetched = [item async for item in self._dequeue_requests(1)]
            self._next = (upcoming, fetched[0])
202+
async def _dequeue_requests(self, count: int) -> AsyncGenerator[Request | None]:
    """Yield exactly `count` items taken from the request source.

    Each item read from the source is passed through `_transform_request`. Once the source is exhausted,
    `None` is yielded for every remaining position.
    """
    remaining = count
    while remaining > 0:
        remaining -= 1
        try:
            transformed = self._transform_request(await self._requests.__anext__())
            yield transformed
        except StopAsyncIteration:  # noqa: PERF203
            yield None
104209
105210 async def _iterate_in_threadpool (self , iterable : Iterable [str | Request ]) -> AsyncIterator [str | Request ]:
106211 """Inspired by a function of the same name from encode/starlette."""
0 commit comments