12 changes: 12 additions & 0 deletions src/crawlee/_request.py
@@ -61,6 +61,9 @@ class CrawleeRequestData(BaseModel):
forefront: Annotated[bool, Field()] = False
"""Indicate whether the request should be enqueued at the front of the queue."""

crawl_depth: Annotated[int, Field(alias='crawlDepth')] = 0
"""The depth of the request in the crawl tree."""


class UserData(BaseModel, MutableMapping[str, JsonSerializable]):
"""Represents the `user_data` part of a Request.
@@ -360,6 +363,15 @@ def crawlee_data(self) -> CrawleeRequestData:

return user_data.crawlee_data

@property
def crawl_depth(self) -> int:
"""The depth of the request in the crawl tree."""
return self.crawlee_data.crawl_depth

@crawl_depth.setter
def crawl_depth(self, new_value: int) -> None:
self.crawlee_data.crawl_depth = new_value

@property
def state(self) -> RequestState | None:
"""Crawlee-specific request handling state."""
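For reference, a minimal sketch of how the new `crawl_depth` field behaves on a `Request` (assuming `Request` is importable from the top-level `crawlee` package, as used via `Request.from_url` in the test file below; the field defaults to 0 and the setter writes through to the underlying `CrawleeRequestData`):

from crawlee import Request  # import path assumed from the package layout in this diff

# A freshly created request sits at depth 0, the default of the new field.
request = Request.from_url('https://example.com/')
assert request.crawl_depth == 0

# The setter added here writes through to crawlee_data, so the depth is stored
# alongside the other Crawlee-specific fields in the request's user_data.
request.crawl_depth = 3
assert request.crawl_depth == 3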
30 changes: 22 additions & 8 deletions src/crawlee/basic_crawler/_basic_crawler.py
@@ -120,6 +120,10 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]):
configure_logging: NotRequired[bool]
"""If True, the crawler will set up logging infrastructure automatically."""

max_crawl_depth: NotRequired[int | None]
"""Limits crawl depth from 0 (initial requests) up to the specified `max_crawl_depth`.
Requests at the maximum depth are processed, but no further links are enqueued."""

_context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]]
"""Enables extending the request lifecycle and modifying the crawling context. Intended for use by
subclasses rather than direct instantiation of `BasicCrawler`."""
@@ -174,6 +178,7 @@ def __init__(
statistics: Statistics | None = None,
event_manager: EventManager | None = None,
configure_logging: bool = True,
max_crawl_depth: int | None = None,
_context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
_additional_context_managers: Sequence[AsyncContextManager] | None = None,
_logger: logging.Logger | None = None,
@@ -201,6 +206,7 @@ def __init__(
statistics: A custom `Statistics` instance, allowing the use of non-default configuration.
event_manager: A custom `EventManager` instance, allowing the use of non-default configuration.
configure_logging: If True, the crawler will set up logging infrastructure automatically.
max_crawl_depth: Maximum crawl depth. If set, requests at this depth are still processed,
but no further links are enqueued from them.
_context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
_additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -283,6 +289,7 @@ def __init__(

self._running = False
self._has_finished_before = False
self._max_crawl_depth = max_crawl_depth

@property
def log(self) -> logging.Logger:
@@ -787,14 +794,21 @@ async def _commit_request_handler_result(
else:
dst_request = Request.from_base_request_data(request)

if self._check_enqueue_strategy(
add_requests_call.get('strategy', EnqueueStrategy.ALL),
target_url=urlparse(dst_request.url),
origin_url=urlparse(origin),
) and self._check_url_patterns(
dst_request.url,
add_requests_call.get('include', None),
add_requests_call.get('exclude', None),
# Update the crawl depth of the request.
dst_request.crawl_depth = context.request.crawl_depth + 1

if (
(self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth)
and self._check_enqueue_strategy(
add_requests_call.get('strategy', EnqueueStrategy.ALL),
target_url=urlparse(dst_request.url),
origin_url=urlparse(origin),
)
and self._check_url_patterns(
dst_request.url,
add_requests_call.get('include', None),
add_requests_call.get('exclude', None),
)
):
requests.append(dst_request)

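Taken together, a hedged usage sketch of the new option (import paths are assumed from the file layout in this diff, and the URLs are placeholders): with `max_crawl_depth=1`, pages at depth 1 are still processed, but links discovered on them would land at depth 2 and are dropped by the check above before they are enqueued.

import asyncio

# Import paths assumed from this diff's layout; adjust to the public exports if they differ.
from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    # Depth 0 is the start request; links it adds get depth 1 and are processed,
    # while links found at depth 1 would exceed max_crawl_depth and are skipped.
    crawler = BasicCrawler(max_crawl_depth=1)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        print(f'Visited {context.request.url} at depth {context.request.crawl_depth}')
        # Placeholder link; each enqueued request gets crawl_depth = parent depth + 1.
        await context.add_requests(['https://example.com/deeper'])

    await crawler.run(['https://example.com/'])


if __name__ == '__main__':
    asyncio.run(main())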
29 changes: 29 additions & 0 deletions tests/unit/basic_crawler/test_basic_crawler.py
@@ -654,6 +654,35 @@ async def handler(context: BasicCrawlingContext) -> None:
assert stats.requests_finished == 3


async def test_max_crawl_depth(httpbin: str) -> None:
processed_urls = []

start_request = Request.from_url('https://someplace.com/', label='start')
start_request.crawl_depth = 2

# Set max_concurrency to 1 to ensure testing max_crawl_depth accurately
crawler = BasicCrawler(
concurrency_settings=ConcurrencySettings(max_concurrency=1),
max_crawl_depth=2,
request_provider=RequestList([start_request]),
)

@crawler.router.handler('start')
async def start_handler(context: BasicCrawlingContext) -> None:
processed_urls.append(context.request.url)
await context.add_requests(['https://someplace.com/too-deep'])

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
processed_urls.append(context.request.url)

stats = await crawler.run()

assert len(processed_urls) == 1
assert stats.requests_total == 1
assert stats.requests_finished == 1


def test_crawler_log() -> None:
crawler = BasicCrawler()
assert isinstance(crawler.log, logging.Logger)