Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/crawlee/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ class UseStateFunction(Protocol):

def __call__(
self,
key: str,
default_value: dict[str, JsonSerializable] | None = None,
) -> Coroutine[None, None, dict[str, JsonSerializable]]: ...

Expand Down
8 changes: 4 additions & 4 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class BasicCrawler(Generic[TCrawlingContext]):
- and more.
"""

_CRAWLEE_STATE_KEY = 'CRAWLEE_STATE'

def __init__(
self,
*,
Expand Down Expand Up @@ -564,11 +566,9 @@ async def add_requests(
wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
)

async def _use_state(
self, key: str, default_value: dict[str, JsonSerializable] | None = None
) -> dict[str, JsonSerializable]:
async def _use_state(self, default_value: dict[str, JsonSerializable] | None = None) -> dict[str, JsonSerializable]:
store = await self.get_key_value_store()
return await store.get_auto_saved_value(key, default_value)
return await store.get_auto_saved_value(self._CRAWLEE_STATE_KEY, default_value)

async def _save_crawler_state(self) -> None:
store = await self.get_key_value_store()
Expand Down
19 changes: 11 additions & 8 deletions src/crawlee/storages/_key_value_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload

Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient)

# Get resource clients from storage client
self._resource_client = storage_client.key_value_store(self._id)
self._autosave_lock = asyncio.Lock()

@property
@override
Expand Down Expand Up @@ -198,17 +200,18 @@ async def get_auto_saved_value(
"""
default_value = {} if default_value is None else default_value

if key in self._cache:
return self._cache[key]
async with self._autosave_lock:
if key in self._cache:
return self._cache[key]

value = await self.get_value(key, default_value)
value = await self.get_value(key, default_value)

if not isinstance(value, dict):
raise TypeError(
f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}'
)
if not isinstance(value, dict):
raise TypeError(
f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}'
)

self._cache[key] = value
self._cache[key] = value

self._ensure_persist_event()

Expand Down
41 changes: 35 additions & 6 deletions tests/unit/crawlers/_basic/test_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import os
import sys
from collections import Counter
from dataclasses import dataclass
from datetime import timedelta
Expand Down Expand Up @@ -710,13 +711,13 @@ async def test_context_use_state(key_value_store: KeyValueStore) -> None:

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
await context.use_state('state', {'hello': 'world'})
await context.use_state({'hello': 'world'})

await crawler.run(['https://hello.world'])

store = await crawler.get_key_value_store()

assert (await store.get_value('state')) == {'hello': 'world'}
assert (await store.get_value(BasicCrawler._CRAWLEE_STATE_KEY)) == {'hello': 'world'}


async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> None:
Expand All @@ -728,20 +729,20 @@ async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> Non

@crawler.router.handler('one')
async def handler_one(context: BasicCrawlingContext) -> None:
state = await context.use_state('state', {'hello': 'world'})
state = await context.use_state({'hello': 'world'})
state_in_handler_one.update(state)
state['hello'] = 'new_world'
await context.add_requests([Request.from_url('https://crawlee.dev/docs/quick-start', label='two')])

@crawler.router.handler('two')
async def handler_two(context: BasicCrawlingContext) -> None:
state = await context.use_state('state', {'hello': 'world'})
state = await context.use_state({'hello': 'world'})
state_in_handler_two.update(state)
state['hello'] = 'last_world'

@crawler.router.handler('three')
async def handler_three(context: BasicCrawlingContext) -> None:
state = await context.use_state('state', {'hello': 'world'})
state = await context.use_state({'hello': 'world'})
state_in_handler_three.update(state)

await crawler.run([Request.from_url('https://crawlee.dev/', label='one')])
Expand All @@ -759,7 +760,7 @@ async def handler_three(context: BasicCrawlingContext) -> None:
store = await crawler.get_key_value_store()

# The state in the KVS must match with the last set state
assert (await store.get_value('state')) == {'hello': 'last_world'}
assert (await store.get_value(BasicCrawler._CRAWLEE_STATE_KEY)) == {'hello': 'last_world'}


async def test_max_requests_per_crawl(httpbin: URL) -> None:
Expand Down Expand Up @@ -1046,3 +1047,31 @@ async def handler(context: BasicCrawlingContext) -> None:

data = await dataset.get_data()
assert data.items == [{'foo': 'bar'}]


@pytest.mark.skipif(sys.version_info[:3] < (3, 11), reason='asyncio.Barrier was introduced in Python 3.11.')
async def test_context_use_state_race_condition_in_handlers(key_value_store: KeyValueStore) -> None:
    """Verify that `use_state` is safe against concurrent access from parallel handlers.

    Two handlers running in parallel each increment a shared counter obtained through
    `use_state`. If the implementation prevents race conditions, the final value is 2.
    """
    from asyncio import Barrier  # type:ignore[attr-defined]  # Safe: the test is skipped on older Python versions.

    crawler = BasicCrawler()

    # Seed the persisted state with a zeroed counter before any handler runs.
    kvs = await crawler.get_key_value_store()
    await kvs.set_value(BasicCrawler._CRAWLEE_STATE_KEY, {'counter': 0})

    sync_point = Barrier(2)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        shared_state = cast(dict[str, int], await context.use_state())
        # Wait until both handlers hold a reference to the state...
        await sync_point.wait()
        shared_state['counter'] += 1
        # ...and until both have applied their increment.
        await sync_point.wait()

    await crawler.run(['https://crawlee.dev/', 'https://crawlee.dev/docs/quick-start'])

    kvs = await crawler.get_key_value_store()
    # Flush the autosaved local state back to the kvs before reading it.
    await kvs.persist_autosaved_values()
    assert (await kvs.get_value(BasicCrawler._CRAWLEE_STATE_KEY))['counter'] == 2
29 changes: 28 additions & 1 deletion tests/unit/storages/test_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import asyncio
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from itertools import chain, repeat
from typing import TYPE_CHECKING, cast
from unittest.mock import patch
from urllib.parse import urlparse

Expand Down Expand Up @@ -178,3 +179,29 @@ async def autosaved_within_deadline(key: str, expected_value: dict[str, str]) ->

value['hello'] = 'new_world'
assert await autosaved_within_deadline(key=key_name, expected_value={'hello': 'new_world'})


async def test_get_auto_saved_value_auto_save_race_conditions(key_value_store: KeyValueStore) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this reliably reproduce the error if you remove the lock?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried stress testing it 1,000 times and it failed with the old implementation every time. But I acknowledge that this is not guaranteed and in theory it can pass. Similarly, in reality, if such a delay in KVS access occurs, it will most likely end in a race condition, but in a theoretical case it does not have to.
I did try to come up with a test without this weakness, but I failed to find a better approach.

This test should not have any false positives and should not be flaky, but it has a theoretical chance of a false negative. In CI we run the same test on various Python implementations, which makes the already purely theoretical chance completely negligible.

"""Two parallel functions increment global variable obtained by `get_auto_saved_value`.

Result should be incremented by 2.
Method `get_auto_saved_value` must be implemented in a way that prevents race conditions in such scenario.
Test creates situation where first `get_auto_saved_value` call to kvs gets delayed. Such situation can happen
and unless handled, it can cause race condition in getting the state value."""
await key_value_store.set_value('state', {'counter': 0})

sleep_time_iterator = chain(iter([0.5]), repeat(0))

async def delayed_get_value(key: str, default_value: None) -> None:
await asyncio.sleep(next(sleep_time_iterator))
return await KeyValueStore.get_value(key_value_store, key=key, default_value=default_value)

async def increment_counter() -> None:
state = cast(dict[str, int], await key_value_store.get_auto_saved_value('state'))
state['counter'] += 1

with patch.object(key_value_store, 'get_value', delayed_get_value):
tasks = [asyncio.create_task(increment_counter()), asyncio.create_task(increment_counter())]
await asyncio.gather(*tasks)

assert (await key_value_store.get_auto_saved_value('state'))['counter'] == 2
Loading