Skip to content

Commit de29c0e

Browse files
authored
fix: Improve Request.user_data serialization (#540)
Resolves #524. This adds validation to `Request.user_data` so that users cannot pass in data that is not JSON-serializable.
1 parent e8fc644 commit de29c0e

File tree

3 files changed

+145
-42
lines changed

3 files changed

+145
-42
lines changed

src/crawlee/_request.py

Lines changed: 97 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Iterator, MutableMapping
56
from datetime import datetime
67
from decimal import Decimal
78
from enum import Enum
8-
from typing import Annotated, Any
9-
10-
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
9+
from typing import Annotated, Any, cast
10+
11+
from pydantic import (
12+
BaseModel,
13+
BeforeValidator,
14+
ConfigDict,
15+
Field,
16+
JsonValue,
17+
PlainSerializer,
18+
PlainValidator,
19+
TypeAdapter,
20+
)
1121
from typing_extensions import Self
1222

1323
from crawlee._types import EnqueueStrategy, HttpMethod
@@ -28,6 +38,64 @@ class RequestState(Enum):
2838
SKIPPED = 7
2939

3040

41+
class CrawleeRequestData(BaseModel):
    """Crawlee-specific configuration stored in the `user_data`."""

    max_retries: Annotated[int | None, Field(alias='maxRetries')] = None
    """Maximum number of retries for this request. Allows to override the global `max_request_retries` option of
    `BasicCrawler`."""

    # Per-request override of the link-enqueueing strategy (serialized under 'enqueueStrategy').
    enqueue_strategy: Annotated[str | None, Field(alias='enqueueStrategy')] = None

    state: RequestState | None = None
    """Describes the request's current lifecycle state."""

    # Number of session rotations performed for this request — presumably incremented on
    # session-related failures; verify against the crawler code that writes it.
    session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None

    # NOTE(review): flag consumed elsewhere under the 'skipNavigation' alias — semantics
    # (skipping browser navigation) not visible from this file; confirm at the call site.
    skip_navigation: Annotated[bool, Field(alias='skipNavigation')] = False

    # Proxy tier used on the last attempt (serialized under 'lastProxyTier').
    last_proxy_tier: Annotated[int | None, Field(alias='lastProxyTier')] = None

    # When True, the request is enqueued at the front of the queue.
    forefront: Annotated[bool, Field()] = False
60+
61+
62+
class UserData(BaseModel, MutableMapping[str, JsonValue]):
    """Represents the `user_data` part of a Request.

    Apart from the well-known attributes (`label` and `__crawlee`), it can also contain arbitrary JSON-compatible
    values.
    """

    # Allow arbitrary extra keys; pydantic validates them against `JsonValue` via the
    # `__pydantic_extra__` annotation below, which is what enforces JSON-serializability.
    model_config = ConfigDict(extra='allow')
    __pydantic_extra__: dict[str, JsonValue] = Field(init=False)  # pyright: ignore

    # Crawlee's internal configuration, stored under the '__crawlee' key.
    crawlee_data: Annotated[CrawleeRequestData | None, Field(alias='__crawlee')] = None
    # Well-known user attribute used to route requests to handlers.
    label: Annotated[str | None, Field()] = None

    def __getitem__(self, key: str) -> JsonValue:
        return self.__pydantic_extra__[key]

    def __setitem__(self, key: str, value: JsonValue) -> None:
        if key == 'label':
            # `label` is a declared field, so mirror the write onto it (with a type check)
            # in addition to storing it in the extra-keys mapping.
            if value is not None and not isinstance(value, str):
                raise ValueError('`label` must be str or None')

            self.label = value
        self.__pydantic_extra__[key] = value

    def __delitem__(self, key: str) -> None:
        del self.__pydantic_extra__[key]

    def __iter__(self) -> Iterator[str]:  # type: ignore
        # Iterates only the extra (user-supplied) keys, not the declared fields.
        yield from self.__pydantic_extra__

    def __len__(self) -> int:
        return len(self.__pydantic_extra__)
94+
95+
96+
# Module-level adapter reused by `BaseRequestData.user_data` for both validation and serialization.
user_data_adapter = TypeAdapter(UserData)
97+
98+
3199
class BaseRequestData(BaseModel):
32100
"""Data needed to create a new crawling request."""
33101

@@ -58,7 +126,20 @@ class BaseRequestData(BaseModel):
58126

59127
data: Annotated[dict[str, Any] | None, Field(default_factory=dict)] = None
60128

61-
user_data: Annotated[dict[str, Any], Field(alias='userData', default_factory=dict)]
129+
user_data: Annotated[
    dict[str, JsonValue],  # Internally, the model contains `UserData`, this is just for convenience
    Field(alias='userData', default_factory=lambda: UserData()),
    # Validate any incoming mapping into `UserData`, rejecting non-JSON-serializable values early.
    PlainValidator(user_data_adapter.validate_python),
    # Serialize back through the adapter, dropping None/unset/default fields to keep payloads minimal.
    PlainSerializer(
        lambda instance: user_data_adapter.dump_python(
            instance,
            by_alias=True,
            exclude_none=True,
            exclude_unset=True,
            exclude_defaults=True,
        )
    ),
]
"""Custom user data assigned to the request. Use this to save any request related data to the
request's scope, keeping them accessible on retries, failures etc.
"""
@@ -216,14 +297,16 @@ def from_base_request_data(cls, base_request_data: BaseRequestData, *, id: str |
216297
@property
def label(self) -> str | None:
    """A string used to differentiate between arbitrary request types."""
    # `user_data` is typed as a dict for convenience but holds a `UserData` model at runtime.
    return cast(UserData, self.user_data).label
222301

223302
@property
def crawlee_data(self) -> CrawleeRequestData:
    """Crawlee-specific configuration stored in the user_data."""
    # `user_data` is typed as a dict for convenience but holds a `UserData` model at runtime.
    user_data = cast(UserData, self.user_data)
    if user_data.crawlee_data is None:
        # Lazily create the container so the setters below can mutate it in place.
        user_data.crawlee_data = CrawleeRequestData()

    return user_data.crawlee_data
227310

228311
@property
229312
def state(self) -> RequestState | None:
@@ -232,8 +315,7 @@ def state(self) -> RequestState | None:
232315

233316
@state.setter
def state(self, new_state: RequestState) -> None:
    # Delegates to `crawlee_data`, which lazily creates the `__crawlee` container.
    self.crawlee_data.state = new_state
237319

238320
@property
239321
def max_retries(self) -> int | None:
@@ -242,8 +324,7 @@ def max_retries(self) -> int | None:
242324

243325
@max_retries.setter
def max_retries(self, new_max_retries: int) -> None:
    # Delegates to `crawlee_data`, which lazily creates the `__crawlee` container.
    self.crawlee_data.max_retries = new_max_retries
247328

248329
@property
249330
def session_rotation_count(self) -> int | None:
@@ -252,8 +333,7 @@ def session_rotation_count(self) -> int | None:
252333

253334
@session_rotation_count.setter
def session_rotation_count(self, new_session_rotation_count: int) -> None:
    # Delegates to `crawlee_data`, which lazily creates the `__crawlee` container.
    self.crawlee_data.session_rotation_count = new_session_rotation_count
257337

258338
@property
259339
def enqueue_strategy(self) -> EnqueueStrategy:
@@ -266,8 +346,7 @@ def enqueue_strategy(self) -> EnqueueStrategy:
266346

267347
@enqueue_strategy.setter
def enqueue_strategy(self, new_enqueue_strategy: EnqueueStrategy) -> None:
    # NOTE(review): `CrawleeRequestData.enqueue_strategy` is declared `str | None`; assigning the
    # enum relies on pydantic coercing it (the previous code stored `str(new_enqueue_strategy)`).
    # Confirm `EnqueueStrategy` is a str-based enum or coerce explicitly.
    self.crawlee_data.enqueue_strategy = new_enqueue_strategy
271350

272351
@property
273352
def last_proxy_tier(self) -> int | None:
@@ -276,8 +355,7 @@ def last_proxy_tier(self) -> int | None:
276355

277356
@last_proxy_tier.setter
def last_proxy_tier(self, new_value: int) -> None:
    # Delegates to `crawlee_data`, which lazily creates the `__crawlee` container.
    self.crawlee_data.last_proxy_tier = new_value
281359

282360
@property
283361
def forefront(self) -> bool:
@@ -286,32 +364,10 @@ def forefront(self) -> bool:
286364

287365
@forefront.setter
def forefront(self, new_value: bool) -> None:
    # Delegates to `crawlee_data`, which lazily creates the `__crawlee` container.
    self.crawlee_data.forefront = new_value
291368

292369

293370
class RequestWithLock(Request):
    """A crawling request with information about locks."""

    # Expiration time of the lock held on this request (timezone handling not visible here —
    # NOTE(review): confirm the producer emits aware UTC datetimes).
    lock_expires_at: Annotated[datetime, Field(alias='lockExpiresAt')]
297-
298-
299-
class CrawleeRequestData(BaseModel):
300-
"""Crawlee-specific configuration stored in the user_data."""
301-
302-
max_retries: Annotated[int | None, Field(alias='maxRetries')] = None
303-
"""Maximum number of retries for this request. Allows to override the global `max_request_retries` option of
304-
`BasicCrawler`."""
305-
306-
enqueue_strategy: Annotated[str | None, Field(alias='enqueueStrategy')] = None
307-
308-
state: RequestState | None = None
309-
"""Describes the request's current lifecycle state."""
310-
311-
session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None
312-
313-
skip_navigation: Annotated[bool, Field(alias='skipNavigation')] = False
314-
315-
last_proxy_tier: Annotated[int | None, Field(alias='lastProxyTier')] = None
316-
317-
forefront: Annotated[bool, Field()] = False

tests/unit/basic_crawler/test_basic_crawler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ async def test_enqueue_strategy(test_input: AddRequestsTestInput) -> None:
390390
crawler = BasicCrawler(request_provider=RequestList([Request.from_url('https://someplace.com/', label='start')]))
391391

392392
@crawler.router.handler('start')
393-
async def default_handler(context: BasicCrawlingContext) -> None:
393+
async def start_handler(context: BasicCrawlingContext) -> None:
394394
await context.add_requests(
395395
test_input.requests,
396396
**test_input.kwargs,

tests/unit/storages/test_request_queue.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING
66

77
import pytest
8+
from pydantic import ValidationError
89

910
from crawlee import Request
1011
from crawlee.storages import RequestQueue
@@ -162,3 +163,49 @@ async def test_add_batched_requests(
162163

163164
# Confirm the queue is empty after processing all requests
164165
assert await request_queue.is_empty() is True
166+
167+
168+
async def test_invalid_user_data_serialization() -> None:
    """Non-JSON-serializable user_data values (datetime, set) must be rejected at creation time."""
    with pytest.raises(ValidationError):
        Request.from_url(
            'https://crawlee.dev',
            user_data={
                'foo': datetime(year=2020, month=7, day=4, tzinfo=timezone.utc),
                'bar': {datetime(year=2020, month=4, day=7, tzinfo=timezone.utc)},
            },
        )
177+
178+
179+
async def test_user_data_serialization(request_queue: RequestQueue) -> None:
    """JSON-compatible user_data survives a round trip through the request queue."""
    request = Request.from_url(
        'https://crawlee.dev',
        user_data={
            'hello': 'world',
            'foo': 42,
        },
    )

    await request_queue.add_request(request)

    dequeued_request = await request_queue.fetch_next_request()
    assert dequeued_request is not None

    # Values must come back unchanged, with their original types.
    assert dequeued_request.user_data['hello'] == 'world'
    assert dequeued_request.user_data['foo'] == 42
195+
196+
197+
async def test_complex_user_data_serialization(request_queue: RequestQueue) -> None:
    """Mutations via the mapping interface and `crawlee_data` serialize under the expected keys."""
    request = Request.from_url('https://crawlee.dev')
    request.user_data['hello'] = 'world'
    request.user_data['foo'] = 42
    request.crawlee_data.max_retries = 1

    await request_queue.add_request(request)

    dequeued_request = await request_queue.fetch_next_request()
    assert dequeued_request is not None

    data = dequeued_request.model_dump(by_alias=True)
    assert data['userData']['hello'] == 'world'
    assert data['userData']['foo'] == 42
    # Only the explicitly-set crawlee field is serialized (defaults/unset/None are excluded).
    assert data['userData']['__crawlee'] == {'maxRetries': 1}

0 commit comments

Comments
 (0)