Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/crawlee/_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
)
from typing_extensions import Self

from crawlee._types import EnqueueStrategy, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._types import EnqueueStrategy, HttpHeaders, HttpMethod, HttpPayload, HttpQueryParams
from crawlee._utils.http import normalize_headers
from crawlee._utils.requests import compute_unique_key, unique_key_to_request_id
from crawlee._utils.urls import extract_query_params, validate_http_url

Expand Down Expand Up @@ -119,7 +120,12 @@ class BaseRequestData(BaseModel):
method: HttpMethod = 'GET'
"""HTTP request method."""

headers: Annotated[dict[str, str], Field(default_factory=dict)] = {}
headers: Annotated[
HttpHeaders,
# Normalize headers to lowercase keys and sort them.
PlainValidator(lambda value: normalize_headers(value)),
Field(default_factory=dict),
] = {}
"""HTTP request headers."""

query_params: Annotated[HttpQueryParams, Field(alias='queryParams', default_factory=dict)] = {}
Expand Down
54 changes: 3 additions & 51 deletions src/crawlee/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal, Protocol, Union
Expand All @@ -10,7 +9,7 @@
if TYPE_CHECKING:
import logging
import re
from collections.abc import Coroutine, Iterator, Sequence
from collections.abc import Coroutine, Sequence

from crawlee import Glob
from crawlee._request import BaseRequestData, Request
Expand All @@ -27,6 +26,8 @@

HttpMethod: TypeAlias = Literal['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH']

HttpHeaders: TypeAlias = dict[str, str]

HttpQueryParams: TypeAlias = dict[str, str]

HttpPayload: TypeAlias = Union[str, bytes]
Expand Down Expand Up @@ -222,52 +223,3 @@ async def add_requests(
) -> None:
"""Track a call to the `add_requests` context helper."""
self.add_requests_calls.append(AddRequestsFunctionCall(requests=requests, **kwargs))


class HttpHeaders(Mapping[str, str]):
    """An immutable mapping for HTTP headers that ensures case-insensitivity for header names."""

    def __init__(self, headers: Mapping[str, str] | None = None) -> None:
        """Create a new instance.

        Args:
            headers: A mapping of header names to values.
        """
        source = dict(headers) if headers else {}
        # Canonicalize names via str.capitalize() and freeze a deterministic, sorted order.
        canonical = {name.capitalize(): value for name, value in source.items()}
        self._store = dict(sorted(canonical.items()))

    def __getitem__(self, key: str) -> str:
        """Get the value of a header by its name, case-insensitive."""
        return self._store[key.capitalize()]

    def __iter__(self) -> Iterator[str]:
        """Return an iterator over the header names."""
        yield from self._store

    def __len__(self) -> int:
        """Return the number of headers."""
        return len(self._store)

    def __repr__(self) -> str:
        """Return a string representation of the object."""
        return f'{self.__class__.__name__}({self._store})'

    def __setitem__(self, key: str, value: str) -> None:
        """Prevent setting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')

    def __delitem__(self, key: str) -> None:
        """Prevent deleting a header, as the object is immutable."""
        raise TypeError(f'{self.__class__.__name__} is immutable')

    def __or__(self, other: Mapping[str, str]) -> HttpHeaders:
        """Return a new instance of `HttpHeaders` combining this one with another one."""
        merged = dict(self._store)
        merged.update(other)
        return HttpHeaders(merged)

    def __ror__(self, other: Mapping[str, str]) -> HttpHeaders:
        """Support reversed | operation (other | self)."""
        merged = dict(other)
        merged.update(self._store)
        return HttpHeaders(merged)
12 changes: 12 additions & 0 deletions src/crawlee/_utils/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from crawlee._types import HttpHeaders


def normalize_headers(headers: HttpHeaders) -> HttpHeaders:
    """Convert all header keys to lowercase and return them sorted by key.

    When two keys differ only in case, the one occurring later in the input
    mapping wins (standard dict-merge semantics).
    """
    lowered = {key.lower(): value for key, value in headers.items()}
    return dict(sorted(lowered.items()))


def is_status_code_error(value: int) -> bool:
"""Returns `True` for 4xx or 5xx status codes, `False` otherwise."""
Expand Down
12 changes: 5 additions & 7 deletions src/crawlee/fingerprint_suite/_header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import random
from typing import TYPE_CHECKING

from crawlee._types import HttpHeaders
from crawlee.fingerprint_suite._consts import (
COMMON_ACCEPT,
COMMON_ACCEPT_LANGUAGE,
Expand All @@ -17,6 +16,7 @@
)

if TYPE_CHECKING:
from crawlee._types import HttpHeaders
from crawlee.browsers._types import BrowserType


Expand All @@ -29,16 +29,14 @@ def get_common_headers(self) -> HttpHeaders:
We do not modify the "Accept-Encoding", "Connection" and other headers. They should be included and handled
by the HTTP client or browser.
"""
headers = {
return {
'Accept': COMMON_ACCEPT,
'Accept-Language': COMMON_ACCEPT_LANGUAGE,
}
return HttpHeaders(headers)

def get_random_user_agent_header(self) -> HttpHeaders:
    """Get a random User-Agent header.

    Returns:
        A single-entry mapping whose 'User-Agent' value is chosen at random
        from `USER_AGENT_POOL`.
    """
    # NOTE(review): the span contained both the pre- and post-change bodies
    # (diff residue); this is the consolidated final version.
    return {'User-Agent': random.choice(USER_AGENT_POOL)}

def get_user_agent_header(
self,
Expand All @@ -60,7 +58,7 @@ def get_user_agent_header(
else:
raise ValueError(f'Unsupported browser type: {browser_type}')

return HttpHeaders(headers)
return headers

def get_sec_ch_ua_headers(
self,
Expand All @@ -85,4 +83,4 @@ def get_sec_ch_ua_headers(
else:
raise ValueError(f'Unsupported browser type: {browser_type}')

return HttpHeaders(headers)
return headers
8 changes: 3 additions & 5 deletions src/crawlee/http_clients/_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,9 @@ def _get_client(self, proxy_url: str | None) -> httpx.AsyncClient:

def _combine_headers(self, explicit_headers: HttpHeaders | None) -> HttpHeaders | None:
    """Helper to get the headers for a HTTP request.

    Precedence (lowest to highest): common headers, a random User-Agent
    header, then any explicitly provided headers.

    Args:
        explicit_headers: Headers supplied by the caller, if any.

    Returns:
        The merged headers, or `None` if the result would be empty.
    """
    # NOTE(review): the span contained both the pre- and post-change bodies
    # (diff residue); this is the consolidated final version.
    common_headers = self._header_generator.get_common_headers() if self._header_generator else {}
    user_agent_header = self._header_generator.get_random_user_agent_header() if self._header_generator else {}
    explicit_headers = explicit_headers or {}
    headers = common_headers | user_agent_header | explicit_headers
    return headers if headers else None

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/basic_crawler/test_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def handler(context: BasicCrawlingContext) -> None:

@crawler.error_handler
async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request:
headers = context.request.headers or HttpHeaders()
headers = context.request.headers or {}
custom_retry_count = int(headers.get('custom_retry_count', '0'))
calls.append((context, error, custom_retry_count))

Expand Down Expand Up @@ -289,7 +289,7 @@ async def handler(context: BasicCrawlingContext) -> None:

response = await context.send_request('http://b.com/')
response_body = response.read()
response_headers = HttpHeaders(response.headers)
response_headers = response.headers

await crawler.run()
assert respx_mock['test_endpoint'].called
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/fingerprint_suite/test_header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest

from crawlee._types import HttpHeaders
from crawlee.fingerprint_suite import HeaderGenerator
from crawlee.fingerprint_suite._consts import (
PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA,
Expand All @@ -21,7 +20,6 @@ def test_get_common_headers() -> None:

assert 'Accept' in headers
assert 'Accept-Language' in headers
assert isinstance(headers, HttpHeaders)


def test_get_random_user_agent_header() -> None:
Expand All @@ -30,7 +28,6 @@ def test_get_random_user_agent_header() -> None:
headers = header_generator.get_random_user_agent_header()

assert 'User-Agent' in headers
assert isinstance(headers, HttpHeaders)
assert headers['User-Agent'] in USER_AGENT_POOL


Expand All @@ -41,7 +38,6 @@ def test_get_user_agent_header_chromium() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_CHROMIUM_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_firefox() -> None:
Expand All @@ -51,7 +47,6 @@ def test_get_user_agent_header_firefox() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_FIREFOX_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_webkit() -> None:
Expand All @@ -61,7 +56,6 @@ def test_get_user_agent_header_webkit() -> None:

assert 'User-Agent' in headers
assert headers['User-Agent'] == PW_WEBKIT_HEADLESS_DEFAULT_USER_AGENT
assert isinstance(headers, HttpHeaders)


def test_get_user_agent_header_invalid_browser_type() -> None:
Expand All @@ -83,7 +77,6 @@ def test_get_sec_ch_ua_headers_chromium() -> None:
assert headers['Sec-Ch-Ua-Mobile'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_MOBILE
assert 'Sec-Ch-Ua-Platform' in headers
assert headers['Sec-Ch-Ua-Platform'] == PW_CHROMIUM_HEADLESS_DEFAULT_SEC_CH_UA_PLATFORM
assert isinstance(headers, HttpHeaders)


def test_get_sec_ch_ua_headers_firefox() -> None:
Expand Down
Loading