diff --git a/docs/guides/code_examples/playwright_crawler_adaptive/init_prediction.py b/docs/guides/code_examples/playwright_crawler_adaptive/init_prediction.py
index b07b1592ae..a8409d6150 100644
--- a/docs/guides/code_examples/playwright_crawler_adaptive/init_prediction.py
+++ b/docs/guides/code_examples/playwright_crawler_adaptive/init_prediction.py
@@ -12,6 +12,8 @@ class CustomRenderingTypePredictor(RenderingTypePredictor):
     def __init__(self) -> None:
+        super().__init__()
+
         self._learning_data = list[tuple[Request, RenderingType]]()
 
     def predict(self, request: Request) -> RenderingTypePrediction:
diff --git a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py
index 673e8678c6..84c477dc13 100644
--- a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py
+++ b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py
@@ -205,6 +205,7 @@ async def adaptive_pre_navigation_hook_pw(context: PlaywrightPreNavCrawlingConte
         self._additional_context_managers = [
             *self._additional_context_managers,
+            self.rendering_type_predictor,
             static_crawler.statistics,
             playwright_crawler.statistics,
             playwright_crawler._browser_pool,  # noqa: SLF001  # Intentional access to private member.
diff --git a/src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py b/src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py
index c9352ac65a..8e7fe5b58f 100644
--- a/src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py
+++ b/src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py
@@ -1,23 +1,49 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from itertools import zip_longest
+from logging import getLogger
 from statistics import mean
-from typing import Literal
+from typing import TYPE_CHECKING, Annotated, Literal
 from urllib.parse import urlparse
 
 from jaro import jaro_winkler_metric
+from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
 from sklearn.linear_model import LogisticRegression
 from typing_extensions import override
 
-from crawlee import Request
 from crawlee._utils.docs import docs_group
+from crawlee._utils.recoverable_state import RecoverableState
+
+from ._utils import sklearn_model_serializer, sklearn_model_validator
+
+if TYPE_CHECKING:
+    from types import TracebackType
+
+    from crawlee import Request
+
+logger = getLogger(__name__)
 
 UrlComponents = list[str]
 RenderingType = Literal['static', 'client only']
 FeatureVector = tuple[float, float]
 
 
+class RenderingTypePredictorState(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+    model: Annotated[
+        LogisticRegression,
+        Field(default_factory=LogisticRegression),
+        PlainValidator(sklearn_model_validator),
+        PlainSerializer(sklearn_model_serializer),
+    ]
+
+    labels_coefficients: Annotated[defaultdict[str, float], Field(alias='labelsCoefficients')]
+
+
 @docs_group('Other')
 @dataclass(frozen=True)
 class RenderingTypePrediction:
@@ -36,6 +62,11 @@
 class RenderingTypePredictor(ABC):
     """Stores rendering type for previously crawled URLs and predicts the rendering type for unvisited urls."""
 
+    def __init__(self) -> None:
+        """Initialize a new instance."""
+        # Flag indicating whether the predictor has been initialized and is ready to use.
+        self._active = False
+
     @abstractmethod
     def predict(self, request: Request) -> RenderingTypePrediction:
         """Get `RenderingTypePrediction` based on the input request.
@@ -53,6 +84,32 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
             rendering_type: Known suitable `RenderingType`.
         """
 
+    async def initialize(self) -> None:
+        """Initialize additional resources required for the predictor operation."""
+        if self._active:
+            raise RuntimeError(f'The {self.__class__.__name__} is already active.')
+        self._active = True
+
+    async def clear(self) -> None:
+        """Clear and release additional resources used by the predictor."""
+        if not self._active:
+            raise RuntimeError(f'The {self.__class__.__name__} is not active.')
+        self._active = False
+
+    async def __aenter__(self) -> RenderingTypePredictor:
+        """Initialize the predictor upon entering the context manager."""
+        await self.initialize()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        exc_traceback: TracebackType | None,
+    ) -> None:
+        """Clear the predictor upon exiting the context manager."""
+        await self.clear()
+
 
 @docs_group('Other')
 class DefaultRenderingTypePredictor(RenderingTypePredictor):
@@ -62,24 +119,59 @@ class DefaultRenderingTypePredictor(RenderingTypePredictor):
     https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
     """
 
-    def __init__(self, detection_ratio: float = 0.1) -> None:
+    def __init__(
+        self,
+        detection_ratio: float = 0.1,
+        *,
+        persistence_enabled: bool = False,
+        persist_state_key: str = 'rendering-type-predictor-state',
+    ) -> None:
         """Initialize a new instance.
 
         Args:
             detection_ratio: A number between 0 and 1 that determines the desired ratio of rendering type detections.
+            persistence_enabled: Whether to enable persistence of the trained model parameters for reuse.
+            persist_state_key: Key in the key-value store where the trained model parameters will be saved.
+                Defaults to 'rendering-type-predictor-state'.
         """
+        super().__init__()
+
         self._rendering_type_detection_results: dict[RenderingType, dict[str, list[UrlComponents]]] = {
             'static': defaultdict(list),
             'client only': defaultdict(list),
         }
-        self._model = LogisticRegression(max_iter=1000)
         self._detection_ratio = max(0, min(1, detection_ratio))
 
         # Used to increase detection probability recommendation for initial recommendations of each label.
        # Reaches 1 (no additional increase) after n samples of specific label is already present in
        # `self._rendering_type_detection_results`.
        n = 3
-        self._labels_coefficients: dict[str, float] = defaultdict(lambda: n + 2)
+
+        self._state = RecoverableState(
+            default_state=RenderingTypePredictorState(
+                model=LogisticRegression(max_iter=1000), labels_coefficients=defaultdict(lambda: n + 2)
+            ),
+            persist_state_key=persist_state_key,
+            persistence_enabled=persistence_enabled,
+            logger=logger,
+        )
+
+    @override
+    async def initialize(self) -> None:
+        """Initialize the predictor and load its recoverable state."""
+        await super().initialize()
+
+        if not self._state.is_initialized:
+            await self._state.initialize()
+
+    @override
+    async def clear(self) -> None:
+        """Tear down the recoverable state of the predictor."""
+        await super().clear()
+
+        if self._state.is_initialized:
+            await self._state.teardown()
 
     @override
     def predict(self, request: Request) -> RenderingTypePrediction:
@@ -91,11 +183,12 @@ def predict(self, request: Request) -> RenderingTypePrediction:
         similarity_threshold = 0.1  # Prediction probability difference threshold to consider prediction unreliable.
         label = request.label or ''
 
-        if self._rendering_type_detection_results['static'] or self._rendering_type_detection_results['client only']:
+        # Check that the model has already been fitted.
+        if hasattr(self._state.current_value.model, 'coef_'):
             url_feature = self._calculate_feature_vector(get_url_components(request.url), label)
             # Are both calls expensive?
-            prediction = self._model.predict([url_feature])[0]
-            probability = self._model.predict_proba([url_feature])[0]
+            prediction = self._state.current_value.model.predict([url_feature])[0]
+            probability = self._state.current_value.model.predict_proba([url_feature])[0]
 
             if abs(probability[0] - probability[1]) < similarity_threshold:
                 # Prediction not reliable.
@@ -103,7 +196,7 @@ def predict(self, request: Request) -> RenderingTypePrediction:
             else:
                 detection_probability_recommendation = self._detection_ratio
                 # Increase recommendation for uncommon labels.
-                detection_probability_recommendation *= self._labels_coefficients[label]
+                detection_probability_recommendation *= self._state.current_value.labels_coefficients[label]
 
         return RenderingTypePrediction(
             rendering_type=('client only', 'static')[int(prediction)],
@@ -122,8 +215,8 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
         """
         label = request.label or ''
         self._rendering_type_detection_results[rendering_type][label].append(get_url_components(request.url))
-        if self._labels_coefficients[label] > 1:
-            self._labels_coefficients[label] -= 1
+        if self._state.current_value.labels_coefficients[label] > 1:
+            self._state.current_value.labels_coefficients[label] -= 1
         self._retrain()
 
     def _retrain(self) -> None:
@@ -137,7 +230,7 @@ def _retrain(self) -> None:
                 x.append(self._calculate_feature_vector(url_components, label))
                 y.append(encoded_rendering_type)
 
-        self._model.fit(x, y)
+        self._state.current_value.model.fit(x, y)
 
     def _calculate_mean_similarity(self, url: UrlComponents, label: str, rendering_type: RenderingType) -> float:
         if not self._rendering_type_detection_results[rendering_type][label]:
diff --git a/src/crawlee/crawlers/_adaptive_playwright/_utils.py b/src/crawlee/crawlers/_adaptive_playwright/_utils.py
new file mode 100644
index 0000000000..5a665b041b
--- /dev/null
+++ b/src/crawlee/crawlers/_adaptive_playwright/_utils.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+import numpy as np
+from sklearn.linear_model import LogisticRegression
+
+
+def sklearn_model_validator(v: LogisticRegression | dict[str, Any]) -> LogisticRegression:
+    """Rebuild a `LogisticRegression` instance from its serialized parameters."""
+    if isinstance(v, LogisticRegression):
+        return v
+
+    # Restore the hyperparameters that were serialized alongside the fitted coefficients.
+    model = LogisticRegression(max_iter=v.get('max_iter', 1000), solver=v.get('solver', 'lbfgs'))
+    if v.get('is_fitted', False):
+        model.coef_ = np.array(v['coef'])
+        model.intercept_ = np.array(v['intercept'])
+        model.classes_ = np.array(v['classes'])
+        model.n_iter_ = np.array(v.get('n_iter', [1000]))
+
+    return model
+
+
+def sklearn_model_serializer(model: LogisticRegression) -> dict[str, Any]:
+    """Serialize a `LogisticRegression` instance to a JSON-compatible dict."""
+    if hasattr(model, 'coef_'):
+        return {
+            'coef': model.coef_.tolist(),
+            'intercept': model.intercept_.tolist(),
+            'classes': model.classes_.tolist(),
+            'n_iter': model.n_iter_.tolist() if hasattr(model, 'n_iter_') else [1000],
+            'is_fitted': True,
+            'max_iter': model.max_iter,
+            'solver': model.solver,
+        }
+    return {'is_fitted': False, 'max_iter': model.max_iter, 'solver': model.solver}
diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py
index d91eb8f8e0..5fdb621718 100644
--- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py
+++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py
@@ -81,6 +81,8 @@ def __init__(
         rendering_types: Iterator[RenderingType] | None = None,
         detection_probability_recommendation: None | Iterator[float] = None,
     ) -> None:
+        super().__init__()
+
         self._rendering_types = rendering_types or cycle(['static'])
         self._detection_probability_recommendation = detection_probability_recommendation or cycle([1])
diff --git a/tests/unit/crawlers/_adaptive_playwright/test_predictor.py b/tests/unit/crawlers/_adaptive_playwright/test_predictor.py
index 157daa0a11..7a767a2e57 100644
--- a/tests/unit/crawlers/_adaptive_playwright/test_predictor.py
+++ b/tests/unit/crawlers/_adaptive_playwright/test_predictor.py
@@ -9,6 +9,7 @@
     calculate_url_similarity,
     get_url_components,
 )
+from crawlee.storages import KeyValueStore
 
 
 @pytest.mark.parametrize('label', ['some label', None])
@@ -23,100 +24,177 @@
         ('http://www.ddf.com/some', 'client only'),
     ],
 )
-def test_predictor_same_label(url: str, expected_prediction: RenderingType, label: str | None) -> None:
-    predictor = DefaultRenderingTypePredictor()
-
-    learning_inputs: tuple[tuple[str, RenderingType], ...] = (
-        ('http://www.aaa.com/some/stuff', 'static'),
-        ('http://www.aab.com/some/stuff', 'static'),
-        ('http://www.aac.com/some/stuff', 'static'),
-        ('http://www.ddd.com/some/stuff', 'client only'),
-        ('http://www.dde.com/some/stuff', 'client only'),
-        ('http://www.ddf.com/some/stuff', 'client only'),
-    )
+async def test_predictor_same_label(url: str, expected_prediction: RenderingType, label: str | None) -> None:
+    async with DefaultRenderingTypePredictor() as predictor:
+        learning_inputs: tuple[tuple[str, RenderingType], ...] = (
+            ('http://www.aaa.com/some/stuff', 'static'),
+            ('http://www.aab.com/some/stuff', 'static'),
+            ('http://www.aac.com/some/stuff', 'static'),
+            ('http://www.ddd.com/some/stuff', 'client only'),
+            ('http://www.dde.com/some/stuff', 'client only'),
+            ('http://www.ddf.com/some/stuff', 'client only'),
+        )
 
-    # Learn from small set
-    for learned_url, rendering_type in learning_inputs:
-        predictor.store_result(Request.from_url(url=learned_url, label=label), rendering_type=rendering_type)
+        # Learn from small set
+        for learned_url, rendering_type in learning_inputs:
+            predictor.store_result(Request.from_url(url=learned_url, label=label), rendering_type=rendering_type)
 
-    assert predictor.predict(Request.from_url(url=url, label=label)).rendering_type == expected_prediction
+        assert predictor.predict(Request.from_url(url=url, label=label)).rendering_type == expected_prediction
 
 
-def test_predictor_new_label_increased_detection_probability_recommendation() -> None:
+async def test_predictor_new_label_increased_detection_probability_recommendation() -> None:
     """Test that urls of uncommon labels have increased detection recommendation.
 
     This increase should gradually drop as the predictor learns more data with this label."""
     detection_ratio = 0.01
     label = 'some label'
-    predictor = DefaultRenderingTypePredictor(detection_ratio=detection_ratio)
-
-    # Learn first prediction of this label
-    predictor.store_result(Request.from_url(url='http://www.aaa.com/some/stuff', label=label), rendering_type='static')
-    # Increased detection_probability_recommendation
-    prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
-    assert prediction.rendering_type == 'static'
-    assert prediction.detection_probability_recommendation == detection_ratio * 4
-
-    # Learn second prediction of this label
-    predictor.store_result(Request.from_url(url='http://www.aaa.com/some/stuffe', label=label), rendering_type='static')
-    # Increased detection_probability_recommendation
-    prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
-    assert prediction.rendering_type == 'static'
-    assert prediction.detection_probability_recommendation == detection_ratio * 3
-
-    # Learn third prediction of this label
-    predictor.store_result(Request.from_url(url='http://www.aaa.com/some/stuffi', label=label), rendering_type='static')
-    # Increased detection_probability_recommendation
-    prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
-    assert prediction.rendering_type == 'static'
-    assert prediction.detection_probability_recommendation == detection_ratio * 2
-
-    # Learn fourth prediction of this label.
-    predictor.store_result(Request.from_url(url='http://www.aaa.com/some/stuffo', label=label), rendering_type='static')
-    # Label considered stable now. There should be no increase of detection_probability_recommendation.
-    prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
-    assert prediction.rendering_type == 'static'
-    assert prediction.detection_probability_recommendation == detection_ratio
-
-
-def test_unreliable_prediction() -> None:
+    async with DefaultRenderingTypePredictor(detection_ratio=detection_ratio) as predictor:
+        # Learn first prediction of this label
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label=label), rendering_type='static'
+        )
+        # Increased detection_probability_recommendation
+        prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
+        assert prediction.rendering_type == 'static'
+        assert prediction.detection_probability_recommendation == detection_ratio * 4
+
+        # Learn second prediction of this label
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuffe', label=label), rendering_type='static'
+        )
+        # Increased detection_probability_recommendation
+        prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
+        assert prediction.rendering_type == 'static'
+        assert prediction.detection_probability_recommendation == detection_ratio * 3
+
+        # Learn third prediction of this label
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuffi', label=label), rendering_type='static'
+        )
+        # Increased detection_probability_recommendation
+        prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
+        assert prediction.rendering_type == 'static'
+        assert prediction.detection_probability_recommendation == detection_ratio * 2
+
+        # Learn fourth prediction of this label.
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuffo', label=label), rendering_type='static'
+        )
+        # Label considered stable now. There should be no increase of detection_probability_recommendation.
+        prediction = predictor.predict(Request.from_url(url='http://www.aaa.com/some/stuffa', label=label))
+        assert prediction.rendering_type == 'static'
+        assert prediction.detection_probability_recommendation == detection_ratio
+
+
+async def test_unreliable_prediction() -> None:
     """Test that detection_probability_recommendation for unreliable predictions is 1.
 
     Create situation where no learning data of new label is available for the predictor. It's first prediction is
     not reliable as both options have 50% chance, so it should set maximum detection_probability_recommendation."""
     learnt_label = 'some label'
-    predictor = DefaultRenderingTypePredictor()
-
-    # Learn two predictions of some label. One of each to make predictor very uncertain.
-    predictor.store_result(
-        Request.from_url(url='http://www.aaa.com/some/stuff', label=learnt_label), rendering_type='static'
-    )
-    predictor.store_result(
-        Request.from_url(url='http://www.aaa.com/some/otherstuff', label=learnt_label), rendering_type='client only'
-    )
-    # Predict for new label. Predictor does not have enough information to give any reliable guess and should make it
-    # clear by setting detection_probability_recommendation=1
-    assert (
-        predictor.predict(
+    async with DefaultRenderingTypePredictor() as predictor:
+        # Learn two predictions of some label. One of each to make predictor very uncertain.
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label=learnt_label), rendering_type='static'
+        )
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/otherstuff', label=learnt_label), rendering_type='client only'
+        )
+
+        # Predict for new label. Predictor does not have enough information to give any reliable guess and should make
+        # it clear by setting detection_probability_recommendation=1
+        probability = predictor.predict(
             Request.from_url(url='http://www.unknown.com', label='new label')
         ).detection_probability_recommendation
-        == 1
-    )
+        assert probability == 1
 
 
-def test_no_learning_data_prediction() -> None:
+async def test_no_learning_data_prediction() -> None:
     """Test that predictor can predict even if it never learnt anything before.
 
     It should give some prediction, but it has to set detection_probability_recommendation=1"""
-    predictor = DefaultRenderingTypePredictor()
-    assert (
-        predictor.predict(
+    async with DefaultRenderingTypePredictor() as predictor:
+        probability = predictor.predict(
             Request.from_url(url='http://www.unknown.com', label='new label')
         ).detection_probability_recommendation
-        == 1
-    )
+
+        assert probability == 1
+
+
+async def test_persistent_no_learning_data_prediction() -> None:
+    """Test that the freshly initialized (untrained) model is persisted to the KeyValueStore."""
+    persist_key = 'test-no_learning-state'
+    async with DefaultRenderingTypePredictor(persistence_enabled=True, persist_state_key=persist_key) as _predictor:
+        pass
+
+    kvs = await KeyValueStore.open()
+
+    persisted_data = await kvs.get_value(persist_key)
+
+    assert persisted_data is not None
+    assert persisted_data['model']['is_fitted'] is False
+
+
+async def test_persistent_prediction() -> None:
+    """Test that the trained model state is persisted to the KeyValueStore."""
+    persist_key = 'test-persistent-state'
+    async with DefaultRenderingTypePredictor(persistence_enabled=True, persist_state_key=persist_key) as predictor:
+        # Learn some data
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label='some label'), rendering_type='static'
+        )
+
+    kvs = await KeyValueStore.open()
+
+    persisted_data = await kvs.get_value(persist_key)
+
+    assert persisted_data is not None
+    assert persisted_data['model']['is_fitted'] is True
+
+
+@pytest.mark.parametrize(
+    ('persistence_enabled', 'same_result'),
+    [
+        pytest.param(True, True, id='with persistence'),
+        pytest.param(False, False, id='without persistence'),
+    ],
+)
+async def test_persistent_prediction_recovery(*, persistence_enabled: bool, same_result: bool) -> None:
+    """Test that the model state is recovered from the KeyValueStore."""
+    persist_key = 'test-persistent-state-recovery'
+
+    async with DefaultRenderingTypePredictor(
+        detection_ratio=0.01, persistence_enabled=persistence_enabled, persist_state_key=persist_key
+    ) as predictor:
+        # Learn some data
+        predictor.store_result(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label='some label'), rendering_type='static'
+        )
+        before_recover_prediction = predictor.predict(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label='some label')
+        )
+
+    # Recover predictor
+    async with DefaultRenderingTypePredictor(
+        detection_ratio=0.01, persistence_enabled=True, persist_state_key=persist_key
+    ) as recover_predictor:
+        after_recover_prediction = recover_predictor.predict(
+            Request.from_url(url='http://www.aaa.com/some/stuff', label='some label')
+        )
+
+    # If persistence is enabled, the predicted results must be the same.
+    if same_result:
+        assert (
+            before_recover_prediction.detection_probability_recommendation
+            == after_recover_prediction.detection_probability_recommendation
+        )
+    else:
+        assert (
+            before_recover_prediction.detection_probability_recommendation
+            != after_recover_prediction.detection_probability_recommendation
+        )
 
 
 @pytest.mark.parametrize(
diff --git a/uv.lock b/uv.lock
index 386ed49db2..17372db8ee 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.13'",
@@ -574,7 +574,7 @@ toml = [
 
 [[package]]
 name = "crawlee"
-version = "0.6.12"
+version = "0.6.13"
 source = { editable = "." }
 dependencies = [
     { name = "cachetools" },
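
---

Reviewer notes (not part of the patch).

A minimal usage sketch of the new predictor lifecycle. The keyword arguments match this diff; the import path is an assumption, taken from the module location the patch touches. Entering the async context calls `initialize()`, which loads any previously persisted model state, and exiting calls `clear()`, which tears it down again:

    import asyncio

    from crawlee import Request
    # Import path assumed; the class lives in
    # src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py in this diff.
    from crawlee.crawlers._adaptive_playwright._rendering_type_predictor import (
        DefaultRenderingTypePredictor,
    )


    async def main() -> None:
        async with DefaultRenderingTypePredictor(
            detection_ratio=0.1,
            persistence_enabled=True,
            persist_state_key='rendering-type-predictor-state',
        ) as predictor:
            # Teach the predictor one example of each rendering type ...
            predictor.store_result(
                Request.from_url('http://www.aaa.com/a', label='detail'), rendering_type='static'
            )
            predictor.store_result(
                Request.from_url('http://www.ddd.com/a', label='detail'), rendering_type='client only'
            )
            # ... then ask for a prediction; the fitted model is persisted when the context exits.
            prediction = predictor.predict(Request.from_url('http://www.aab.com/a', label='detail'))
            print(prediction.rendering_type, prediction.detection_probability_recommendation)


    asyncio.run(main())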
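
The docs example change at the top of the patch hints at the new subclassing contract: a custom predictor must call `super().__init__()` so the base class can track its `_active` flag, and it may override the async `initialize()`/`clear()` hooks for its own resources. A minimal sketch, assuming the public exports used by the existing docs example (`RenderingType`, `RenderingTypePrediction`, `RenderingTypePredictor`) are importable from `crawlee.crawlers`:

    from crawlee import Request
    from crawlee.crawlers import (  # Import path assumed from the docs example.
        RenderingType,
        RenderingTypePrediction,
        RenderingTypePredictor,
    )


    class CustomRenderingTypePredictor(RenderingTypePredictor):
        def __init__(self) -> None:
            super().__init__()  # Required so the base class can track the active state.
            self._learning_data = list[tuple[Request, RenderingType]]()

        async def initialize(self) -> None:
            await super().initialize()
            # Acquire any additional resources (open files, warm caches, ...) here.

        async def clear(self) -> None:
            await super().clear()
            # Release those resources again here.

        def predict(self, request: Request) -> RenderingTypePrediction:
            # A stand-in heuristic: always recommend a static crawl with a small detection ratio.
            return RenderingTypePrediction(rendering_type='static', detection_probability_recommendation=0.1)

        def store_result(self, request: Request, rendering_type: RenderingType) -> None:
            self._learning_data.append((request, rendering_type))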
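
For the persistence layer itself, the `sklearn_model_serializer` / `sklearn_model_validator` pair in the new `_utils.py` is expected to round-trip a fitted model through a plain dict, which is what ends up in the key-value store. A minimal self-check sketch, assuming those internal helpers stay importable from `crawlee.crawlers._adaptive_playwright._utils`; the feature vectors below are made up for illustration:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    from crawlee.crawlers._adaptive_playwright._utils import (
        sklearn_model_serializer,
        sklearn_model_validator,
    )

    # Fit a tiny model on two illustrative 2-dimensional feature vectors per class.
    model = LogisticRegression(max_iter=1000)
    model.fit([(0.9, 0.9), (0.8, 0.95), (0.1, 0.2), (0.05, 0.1)], [1, 1, 0, 0])

    # Serialize to a JSON-compatible dict (the shape asserted on in the persistence tests) ...
    state = sklearn_model_serializer(model)
    assert state['is_fitted'] is True

    # ... and rebuild an equivalent model from it.
    restored = sklearn_model_validator(state)
    assert np.array_equal(restored.coef_, model.coef_)
    assert restored.predict([(0.85, 0.9)])[0] == model.predict([(0.85, 0.9)])[0]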