@@ -12,6 +12,8 @@

class CustomRenderingTypePredictor(RenderingTypePredictor):
def __init__(self) -> None:
super().__init__()

self._learning_data = list[tuple[Request, RenderingType]]()

def predict(self, request: Request) -> RenderingTypePrediction:
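Because the base class now defines `__init__` (setting the active-state flag), custom predictors have to chain to it, as the hunk above shows. A minimal sketch of a conforming subclass — the always-static behaviour and the import paths are illustrative assumptions, not part of this PR:

from typing_extensions import override

from crawlee import Request
from crawlee.crawlers import RenderingType, RenderingTypePrediction, RenderingTypePredictor


class AlwaysStaticPredictor(RenderingTypePredictor):
    def __init__(self) -> None:
        super().__init__()  # Required so the base class can track its active state.

    @override
    def predict(self, request: Request) -> RenderingTypePrediction:
        # Always recommend the cheap static crawler and never ask for detection.
        return RenderingTypePrediction(rendering_type='static', detection_probability_recommendation=0.0)

    @override
    def store_result(self, request: Request, rendering_type: RenderingType) -> None:
        pass  # Nothing to learn from detection results.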
@@ -205,6 +205,7 @@ async def adaptive_pre_navigation_hook_pw(context: PlaywrightPreNavCrawlingConte

self._additional_context_managers = [
*self._additional_context_managers,
self.rendering_type_predictor,
static_crawler.statistics,
playwright_crawler.statistics,
playwright_crawler._browser_pool, # noqa: SLF001 # Intentional access to private member.
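For orientation: by appending `self.rendering_type_predictor` to `_additional_context_managers`, the crawler enters and exits the predictor together with its other owned resources. Conceptually something like the following — a simplified sketch, not the crawler's actual code:

from contextlib import AsyncExitStack


async def run_with_resources(context_managers: list) -> None:
    # AsyncExitStack enters each resource and unwinds them in reverse order,
    # even if the crawl fails part-way through.
    async with AsyncExitStack() as stack:
        for cm in context_managers:
            await stack.enter_async_context(cm)
        ...  # Crawl while all resources, including the predictor, are active.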
113 changes: 101 additions & 12 deletions src/crawlee/crawlers/_adaptive_playwright/_rendering_type_predictor.py
@@ -1,23 +1,49 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from itertools import zip_longest
from logging import getLogger
from statistics import mean
from typing import Literal
from typing import TYPE_CHECKING, Annotated, Literal
from urllib.parse import urlparse

from jaro import jaro_winkler_metric
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
from sklearn.linear_model import LogisticRegression
from typing_extensions import override

from crawlee import Request
from crawlee._utils.docs import docs_group
from crawlee._utils.recoverable_state import RecoverableState

from ._utils import sklearn_model_serializer, sklearn_model_validator

if TYPE_CHECKING:
from types import TracebackType

from crawlee import Request

logger = getLogger(__name__)

UrlComponents = list[str]
RenderingType = Literal['static', 'client only']
FeatureVector = tuple[float, float]


class RenderingTypePredictorState(BaseModel):
model_config = ConfigDict(populate_by_name=True)

model: Annotated[
LogisticRegression,
Field(default_factory=LogisticRegression),
PlainValidator(sklearn_model_validator),
PlainSerializer(sklearn_model_serializer),
]

labels_coefficients: Annotated[defaultdict[str, float], Field(alias='labelsCoefficients')]


@docs_group('Other')
@dataclass(frozen=True)
class RenderingTypePrediction:
@@ -36,6 +62,11 @@ class RenderingTypePrediction:
class RenderingTypePredictor(ABC):
"""Stores rendering type for previously crawled URLs and predicts the rendering type for unvisited urls."""

def __init__(self) -> None:
"""Initialize a new instance."""
# Flag indicating whether the predictor has been initialized and is ready for use.
self._active = False

@abstractmethod
def predict(self, request: Request) -> RenderingTypePrediction:
"""Get `RenderingTypePrediction` based on the input request.
@@ -53,6 +84,32 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
rendering_type: Known suitable `RenderingType`.
"""

async def initialize(self) -> None:
"""Initialize additional resources required for the predictor operation."""
if self._active:
raise RuntimeError(f'The {self.__class__.__name__} is already active.')
self._active = True

async def clear(self) -> None:
"""Clear and release additional resources used by the predictor."""
if not self._active:
raise RuntimeError(f'The {self.__class__.__name__} is not active.')
self._active = False

async def __aenter__(self) -> RenderingTypePredictor:
"""Initialize the predictor upon entering the context manager."""
await self.initialize()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
"""Clear the predictor upon exiting the context manager."""
await self.clear()
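With the `initialize`/`clear` pair and the dunder methods above, any predictor can be driven explicitly or via `async with`; a small equivalence sketch (the `use_predictor` helper is illustrative and assumes the `RenderingTypePredictor` class defined above):

async def use_predictor(predictor: RenderingTypePredictor) -> None:
    # Explicit lifecycle management:
    await predictor.initialize()
    try:
        ...  # Call predictor.predict() / predictor.store_result() here.
    finally:
        await predictor.clear()

    # Equivalent, via the context-manager protocol defined above:
    async with predictor:
        ...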


@docs_group('Other')
class DefaultRenderingTypePredictor(RenderingTypePredictor):
@@ -62,24 +119,55 @@ class DefaultRenderingTypePredictor(RenderingTypePredictor):
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
"""

def __init__(self, detection_ratio: float = 0.1) -> None:
def __init__(
self, detection_ratio: float = 0.1, *, persistence_enabled: bool = False, persist_state_key: str | None = None
) -> None:
"""Initialize a new instance.

Args:
detection_ratio: A number between 0 and 1 that determines the desired ratio of rendering type detections.
persist_state_key: Key in the key-value storage where the trained model parameters will be saved.
If None, defaults to 'rendering-type-predictor-state'.
persistence_enabled: Whether to enable persistence of the trained model parameters for reuse.

"""
super().__init__()

self._rendering_type_detection_results: dict[RenderingType, dict[str, list[UrlComponents]]] = {
'static': defaultdict(list),
'client only': defaultdict(list),
}
self._model = LogisticRegression(max_iter=1000)
self._detection_ratio = max(0, min(1, detection_ratio))

# Used to increase the detection probability recommendation for the initial recommendations of each label.
# It reaches 1 (no additional increase) once n samples for a specific label are already present in
# `self._rendering_type_detection_results`.
n = 3
self._labels_coefficients: dict[str, float] = defaultdict(lambda: n + 2)

self._state = RecoverableState(
default_state=RenderingTypePredictorState(
model=LogisticRegression(max_iter=1000), labels_coefficients=defaultdict(lambda: n + 2)
),
persist_state_key=persist_state_key or 'rendering-type-predictor-state',
persistence_enabled=persistence_enabled,
logger=logger,
)

@override
async def initialize(self) -> None:
"""Get current state of the predictor."""
await super().initialize()

if not self._state.is_initialized:
await self._state.initialize()

@override
async def clear(self) -> None:
"""Clear the predictor state."""
await super().clear()

if self._state.is_initialized:
await self._state.teardown()

@override
def predict(self, request: Request) -> RenderingTypePrediction:
@@ -91,19 +179,20 @@ def predict(self, request: Request) -> RenderingTypePrediction:
similarity_threshold = 0.1 # Prediction probability difference threshold to consider prediction unreliable.
label = request.label or ''

if self._rendering_type_detection_results['static'] or self._rendering_type_detection_results['client only']:
# Check that the model has already been fitted.
if hasattr(self._state.current_value.model, 'coef_'):
url_feature = self._calculate_feature_vector(get_url_components(request.url), label)
prediction = self._model.predict([url_feature])[0]
probability = self._model.predict_proba([url_feature])[0]
prediction = self._state.current_value.model.predict([url_feature])[0]
probability = self._state.current_value.model.predict_proba([url_feature])[0]

if abs(probability[0] - probability[1]) < similarity_threshold:
# Prediction not reliable.
detection_probability_recommendation = 1.0
else:
detection_probability_recommendation = self._detection_ratio
# Increase recommendation for uncommon labels.
detection_probability_recommendation *= self._labels_coefficients[label]
detection_probability_recommendation *= self._state.current_value.labels_coefficients[label]

return RenderingTypePrediction(
rendering_type=('client only', 'static')[int(prediction)],
@@ -122,8 +211,8 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
"""
label = request.label or ''
self._rendering_type_detection_results[rendering_type][label].append(get_url_components(request.url))
if self._labels_coefficients[label] > 1:
self._labels_coefficients[label] -= 1
if self._state.current_value.labels_coefficients[label] > 1:
self._state.current_value.labels_coefficients[label] -= 1
self._retrain()

def _retrain(self) -> None:
@@ -137,7 +226,7 @@ def _retrain(self) -> None:
x.append(self._calculate_feature_vector(url_components, label))
y.append(encoded_rendering_type)

self._model.fit(x, y)
self._state.current_value.model.fit(x, y)

def _calculate_mean_similarity(self, url: UrlComponents, label: str, rendering_type: RenderingType) -> float:
if not self._rendering_type_detection_results[rendering_type][label]:
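Putting the pieces together: a hedged usage sketch of the new persistence flow. The storage key and URL are illustrative, and the import path assumes the predictor is re-exported from crawlee's public crawlers package:

import asyncio

from crawlee import Request
from crawlee.crawlers import DefaultRenderingTypePredictor


async def main() -> None:
    predictor = DefaultRenderingTypePredictor(
        persistence_enabled=True,
        persist_state_key='my-predictor-state',  # Illustrative key.
    )
    async with predictor:  # initialize() restores any previously persisted model.
        prediction = predictor.predict(Request.from_url('https://example.com'))
        print(prediction.rendering_type, prediction.detection_probability_recommendation)
    # On exit, clear() releases the state; with persistence enabled, the trained
    # model parameters are saved for the next run.


asyncio.run(main())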
32 changes: 32 additions & 0 deletions src/crawlee/crawlers/_adaptive_playwright/_utils.py
@@ -0,0 +1,32 @@
from typing import Any

import numpy as np
from sklearn.linear_model import LogisticRegression


def sklearn_model_validator(v: LogisticRegression | dict[str, Any]) -> LogisticRegression:
    if isinstance(v, LogisticRegression):
        return v

    # Restore the serialized hyperparameters first, then the fitted coefficients (if any).
    model = LogisticRegression(max_iter=v.get('max_iter', 1000), solver=v.get('solver', 'lbfgs'))
    if v.get('is_fitted', False):
        model.coef_ = np.array(v['coef'])
        model.intercept_ = np.array(v['intercept'])
        model.classes_ = np.array(v['classes'])
        model.n_iter_ = np.array(v.get('n_iter', [1000]))

    return model


def sklearn_model_serializer(model: LogisticRegression) -> dict[str, Any]:
    if hasattr(model, 'coef_'):
        return {
            'coef': model.coef_.tolist(),
            'intercept': model.intercept_.tolist(),
            'classes': model.classes_.tolist(),
            'n_iter': model.n_iter_.tolist() if hasattr(model, 'n_iter_') else [1000],
            'is_fitted': True,
            'max_iter': model.max_iter,
            'solver': model.solver,
        }
    return {'is_fitted': False, 'max_iter': model.max_iter, 'solver': model.solver}
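A quick round-trip check of the two helpers above, on toy data (assumes scikit-learn is installed; not part of the PR):

from sklearn.linear_model import LogisticRegression

from crawlee.crawlers._adaptive_playwright._utils import (  # Internal module shown above.
    sklearn_model_serializer,
    sklearn_model_validator,
)

model = LogisticRegression(max_iter=1000)
model.fit([[0.0, 0.0], [1.0, 1.0]], [0, 1])  # Fit on two toy samples, one per class.

payload = sklearn_model_serializer(model)    # Plain, JSON-serializable dict.
restored = sklearn_model_validator(payload)  # Rebuilt without refitting.

assert restored.predict([[1.0, 1.0]]).tolist() == model.predict([[1.0, 1.0]]).tolist()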
@@ -81,6 +81,8 @@ def __init__(
rendering_types: Iterator[RenderingType] | None = None,
detection_probability_recommendation: None | Iterator[float] = None,
) -> None:
super().__init__()

self._rendering_types = rendering_types or cycle(['static'])
self._detection_probability_recommendation = detection_probability_recommendation or cycle([1])
