25 changes: 24 additions & 1 deletion src/crawlee/storage_clients/_sql/_storage_client.py
@@ -3,7 +3,7 @@
import warnings
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
@@ -94,6 +94,29 @@ async def __aexit__(
"""Async context manager exit."""
await self.close()

def __deepcopy__(self, memo: dict[int, Any] | None) -> SqlStorageClient:
        # AsyncEngine is not deep-copyable, so reuse the same instance.
if memo is None:
memo = {}

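        # `memo` maps id(original) -> copy for objects already handled in this deepcopy
        # pass; returning the cached copy preserves shared references and avoids
        # infinite recursion on cycles.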
if id(self) in memo:
return cast('SqlStorageClient', memo[id(self)])

# Suppress warnings about experimental feature during deepcopy
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
if self._engine is not None:
new_client = self.__class__(engine=self._engine)
else:
new_client = self.__class__(connection_string=self._connection_string)

# Copy simple attributes
for attr in ('_initialized', '_dialect_name', '_default_flag', '_accessed_modified_update_interval'):
setattr(new_client, attr, getattr(self, attr))

memo[id(self)] = new_client
return new_client

@property
def engine(self) -> AsyncEngine:
"""Get the SQLAlchemy AsyncEngine instance."""
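Note: a minimal sketch of the behaviour added above (illustration only; the 'crawlee.db' file name is arbitrary, and it assumes the constructor stores a passed-in engine on `_engine`, which the engine branch of `__deepcopy__` relies on). Deep-copying the client yields a new instance that reuses the original AsyncEngine instead of attempting to copy it:

from copy import deepcopy

from sqlalchemy.ext.asyncio import create_async_engine

from crawlee.storage_clients import SqlStorageClient

engine = create_async_engine('sqlite+aiosqlite:///crawlee.db')
client = SqlStorageClient(engine=engine)

# The copy is a distinct client object, but the underlying AsyncEngine is shared.
clone = deepcopy(client)
assert clone is not client
assert clone._engine is client._engine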
57 changes: 57 additions & 0 deletions tests/unit/storage_clients/_sql/test_sql_client.py
@@ -0,0 +1,57 @@
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy.ext.asyncio import create_async_engine

from crawlee.configuration import Configuration
from crawlee.storage_clients import SqlStorageClient

if TYPE_CHECKING:
from pathlib import Path


@pytest.fixture
def configuration(tmp_path: Path) -> Configuration:
"""Temporary configuration for tests."""
return Configuration(
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
)


@pytest.mark.parametrize(
('connection_parameter'),
[pytest.param('connection_string', id='with connection string'), pytest.param('engine', id='with engine')],
)
async def test_deepcopy_with_engine_init(
configuration: Configuration, tmp_path: Path, connection_parameter: str
) -> None:
    """Test that SqlStorageClient can be deep-copied when initialized with either a connection string or an engine."""
storage_dir = tmp_path / 'test_table.db'
connection_string = f'sqlite+aiosqlite:///{storage_dir}'
sql_kwargs: dict[str, Any] = {}
if connection_parameter == 'connection_string':
sql_kwargs['connection_string'] = connection_string
else:
engine = create_async_engine(connection_string, future=True, echo=False)
sql_kwargs['engine'] = engine
async with SqlStorageClient(**sql_kwargs) as storage_client:
copy_storage_client = deepcopy(storage_client)

# Ensure that the copy is a new instance
assert copy_storage_client is not storage_client
# Ensure that the copy uses the same engine
assert copy_storage_client._engine is storage_client._engine

# Ensure that the copy can create a new dataset client
copy_dataset_client = await copy_storage_client.create_dataset_client(
configuration=configuration, name='new-dataset'
)
dataset_client = await storage_client.create_dataset_client(configuration=configuration, name='new-dataset')

# Ensure that the metadata from both clients is the same
copy_metadata = await copy_dataset_client.get_metadata()
storage_metadata = await dataset_client.get_metadata()
assert copy_metadata == storage_metadata