1+ from __future__ import annotations
2+
13from abc import ABC , abstractmethod
24from collections import defaultdict
35from dataclasses import dataclass
46from itertools import zip_longest
7+ from logging import getLogger
58from statistics import mean
6- from typing import Literal
9+ from typing import TYPE_CHECKING , Annotated , Literal
710from urllib .parse import urlparse
811
912from jaro import jaro_winkler_metric
13+ from pydantic import BaseModel , ConfigDict , Field , PlainSerializer , PlainValidator
1014from sklearn .linear_model import LogisticRegression
1115from typing_extensions import override
1216
13- from crawlee import Request
1417from crawlee ._utils .docs import docs_group
18+ from crawlee ._utils .recoverable_state import RecoverableState
19+
20+ from ._utils import sklearn_model_serializer , sklearn_model_validator
21+
22+ if TYPE_CHECKING :
23+ from types import TracebackType
24+
25+ from crawlee import Request
26+
27+ logger = getLogger (__name__ )
1528
1629UrlComponents = list [str ]
1730RenderingType = Literal ['static' , 'client only' ]
1831FeatureVector = tuple [float , float ]
1932
2033
class RenderingTypePredictorState(BaseModel):
    """State of a rendering type predictor that can be persisted and recovered.

    Holds the scikit-learn model together with the per-label coefficients that scale
    the detection probability recommendation.
    """

    model_config = ConfigDict(populate_by_name=True)

    # Validation and serialization of the estimator are delegated to the sklearn helper functions.
    # `default_factory` is used so that an omitted `model` yields a fresh estimator instance; the first
    # positional argument of `Field` is `default`, so `Field(LogisticRegression)` would make the class
    # object itself the default value.
    model: Annotated[
        LogisticRegression,
        Field(default_factory=LogisticRegression),
        PlainValidator(sklearn_model_validator),
        PlainSerializer(sklearn_model_serializer),
    ]

    # Exposed under the camelCase alias `labelsCoefficients`; `populate_by_name` also accepts the field name.
    # NOTE(review): when this field is validated from a plain mapping (e.g. restored persisted state),
    # pydantic presumably derives the default factory from the value type (`float` -> 0.0) instead of
    # keeping the factory supplied by the predictor - confirm that unseen labels behave as intended.
    labels_coefficients: Annotated[defaultdict[str, float], Field(alias='labelsCoefficients')]
45+
46+
2147@docs_group ('Other' )
2248@dataclass (frozen = True )
2349class RenderingTypePrediction :
@@ -36,6 +62,11 @@ class RenderingTypePrediction:
3662class RenderingTypePredictor (ABC ):
3763 """Stores rendering type for previously crawled URLs and predicts the rendering type for unvisited urls."""
3864
65+ def __init__ (self ) -> None :
66+ """Initialize a new instance."""
67+ # Flag to indicate the state.
68+ self ._active = False
69+
3970 @abstractmethod
4071 def predict (self , request : Request ) -> RenderingTypePrediction :
4172 """Get `RenderingTypePrediction` based on the input request.
@@ -53,6 +84,32 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
5384 rendering_type: Known suitable `RenderingType`.
5485 """
5586
async def initialize(self) -> None:
    """Mark the predictor as active so that it can be used.

    Raises:
        RuntimeError: If the predictor is already active.
    """
    already_active = self._active
    if already_active:
        raise RuntimeError(f'The {self.__class__.__name__} is already active.')

    self._active = True
92+
async def clear(self) -> None:
    """Mark the predictor as inactive again.

    Raises:
        RuntimeError: If the predictor is not active.
    """
    currently_active = self._active
    if not currently_active:
        raise RuntimeError(f'The {self.__class__.__name__} is not active.')

    self._active = False
98+
99+ async def __aenter__ (self ) -> RenderingTypePredictor :
100+ """Initialize the predictor upon entering the context manager."""
101+ await self .initialize ()
102+ return self
103+
104+ async def __aexit__ (
105+ self ,
106+ exc_type : type [BaseException ] | None ,
107+ exc_value : BaseException | None ,
108+ exc_traceback : TracebackType | None ,
109+ ) -> None :
110+ """Clear the predictor upon exiting the context manager."""
111+ await self .clear ()
112+
56113
57114@docs_group ('Other' )
58115class DefaultRenderingTypePredictor (RenderingTypePredictor ):
@@ -62,24 +119,59 @@ class DefaultRenderingTypePredictor(RenderingTypePredictor):
62119 https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
63120 """
64121
def __init__(
    self,
    detection_ratio: float = 0.1,
    *,
    persistence_enabled: bool = False,
    persist_state_key: str = 'rendering-type-predictor-state',
) -> None:
    """Initialize a new instance.

    Args:
        detection_ratio: A number between 0 and 1 that determines the desired ratio of rendering type detections.
            Values outside of this range are clamped to it.
        persistence_enabled: Whether to enable persistence of the trained model parameters for reuse.
        persist_state_key: Key in the key-value storage where the trained model parameters will be saved.
    """
    super().__init__()

    # Clamp the requested ratio into the [0, 1] interval.
    self._detection_ratio = max(0, min(1, detection_ratio))

    # Url components of already detected requests, grouped by rendering type and then by request label.
    self._rendering_type_detection_results: dict[RenderingType, dict[str, list[UrlComponents]]] = {
        'static': defaultdict(list),
        'client only': defaultdict(list),
    }

    # Boosts the detection probability recommendation for labels with few stored samples. Every label
    # starts with the coefficient `initial_samples + 2`, and `store_result` lowers it by 1 for each stored
    # sample of that label until it reaches 1 (no additional increase).
    initial_samples = 3
    fresh_state = RenderingTypePredictorState(
        model=LogisticRegression(max_iter=1000),
        labels_coefficients=defaultdict(lambda: initial_samples + 2),
    )

    self._state = RecoverableState(
        default_state=fresh_state,
        persist_state_key=persist_state_key,
        persistence_enabled=persistence_enabled,
        logger=logger,
    )
159+
@override
async def initialize(self) -> None:
    """Activate the predictor and initialize its recoverable state if that has not happened yet."""
    await super().initialize()

    state = self._state
    if state.is_initialized:
        # Nothing to do, the state is already initialized.
        return

    await state.initialize()
167+
@override
async def clear(self) -> None:
    """Deactivate the predictor and tear down its recoverable state if it was initialized."""
    await super().clear()

    state = self._state
    if not state.is_initialized:
        # The state was never initialized, so there is nothing to tear down.
        return

    await state.teardown()
83175
84176 @override
85177 def predict (self , request : Request ) -> RenderingTypePrediction :
@@ -91,19 +183,20 @@ def predict(self, request: Request) -> RenderingTypePrediction:
91183 similarity_threshold = 0.1 # Prediction probability difference threshold to consider prediction unreliable.
92184 label = request .label or ''
93185
94- if self ._rendering_type_detection_results ['static' ] or self ._rendering_type_detection_results ['client only' ]:
186+ # Check that the model has already been fitted.
187+ if hasattr (self ._state .current_value .model , 'coef_' ):
95188 url_feature = self ._calculate_feature_vector (get_url_components (request .url ), label )
96189 # Are both calls expensive?
97- prediction = self ._model .predict ([url_feature ])[0 ]
98- probability = self ._model .predict_proba ([url_feature ])[0 ]
190+ prediction = self ._state . current_value . model .predict ([url_feature ])[0 ]
191+ probability = self ._state . current_value . model .predict_proba ([url_feature ])[0 ]
99192
100193 if abs (probability [0 ] - probability [1 ]) < similarity_threshold :
101194 # Prediction not reliable.
102195 detection_probability_recommendation = 1.0
103196 else :
104197 detection_probability_recommendation = self ._detection_ratio
105198 # Increase recommendation for uncommon labels.
106- detection_probability_recommendation *= self ._labels_coefficients [label ]
199+ detection_probability_recommendation *= self ._state . current_value . labels_coefficients [label ]
107200
108201 return RenderingTypePrediction (
109202 rendering_type = ('client only' , 'static' )[int (prediction )],
@@ -122,8 +215,8 @@ def store_result(self, request: Request, rendering_type: RenderingType) -> None:
122215 """
123216 label = request .label or ''
124217 self ._rendering_type_detection_results [rendering_type ][label ].append (get_url_components (request .url ))
125- if self ._labels_coefficients [label ] > 1 :
126- self ._labels_coefficients [label ] -= 1
218+ if self ._state . current_value . labels_coefficients [label ] > 1 :
219+ self ._state . current_value . labels_coefficients [label ] -= 1
127220 self ._retrain ()
128221
129222 def _retrain (self ) -> None :
@@ -137,7 +230,7 @@ def _retrain(self) -> None:
137230 x .append (self ._calculate_feature_vector (url_components , label ))
138231 y .append (encoded_rendering_type )
139232
140- self ._model .fit (x , y )
233+ self ._state . current_value . model .fit (x , y )
141234
142235 def _calculate_mean_similarity (self , url : UrlComponents , label : str , rendering_type : RenderingType ) -> float :
143236 if not self ._rendering_type_detection_results [rendering_type ][label ]:
0 commit comments