Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 123 additions & 3 deletions test/test_sqlite_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

"""Tests for SQLite index implementations with real embeddings."""

import os
import sqlite3
import tempfile
from typing import Generator

import pytest
Expand Down Expand Up @@ -368,6 +370,39 @@ async def test_fuzzy_lookup_edge_cases(
# All results should have weights (even if surprisingly high)
assert all(result.weight is not None for result in results)

@pytest.mark.asyncio
async def test_fuzzy_serialize(
self,
sqlite_db: sqlite3.Connection,
embedding_settings: TextEmbeddingIndexSettings,
needs_auth: None,
):
"""Test serialization of fuzzy index data."""
index = SqliteRelatedTermsFuzzy(sqlite_db, embedding_settings)

# Add some terms
test_terms = ["chess", "grandmaster", "artificial intelligence"]
await index.add_terms(test_terms)

# Serialize the fuzzy index
data = index.serialize()

# Verify serialized data structure
assert data is not None
assert "textItems" in data
assert "embeddings" in data

# Verify text items
text_items = data["textItems"]
assert len(text_items) == 3
assert "chess" in text_items
assert "grandmaster" in text_items
assert "artificial intelligence" in text_items

# Verify embeddings exist
embeddings = data["embeddings"]
assert embeddings is not None


class TestSqliteRelatedTermsIndex:
"""Test SqliteRelatedTermsIndex combined functionality."""
Expand Down Expand Up @@ -401,6 +436,91 @@ async def test_combined_index_basic(
assert len(alias_results) == 1
assert alias_results[0].text == "artificial intelligence"

@pytest.mark.asyncio
async def test_serialize_and_deserialize(
self,
sqlite_db: sqlite3.Connection,
embedding_settings: TextEmbeddingIndexSettings,
needs_auth: None,
):
"""Test serialization and deserialization of the combined related terms index."""
index = SqliteRelatedTermsIndex(sqlite_db, embedding_settings)

# Ensure fuzzy_index is available
assert index.fuzzy_index is not None

# Add data to both sub-indexes
await index.fuzzy_index.add_terms(["chess", "grandmaster", "tournament"])
await index.aliases.add_related_term("AI", Term("artificial intelligence"))
await index.aliases.add_related_term("ML", Term("machine learning"))

# Serialize the combined index
data = await index.serialize()

# Verify serialized data structure
assert data is not None
assert "aliasData" in data
assert "textEmbeddingData" in data

# Verify alias data
alias_data = data["aliasData"]
assert alias_data is not None
assert "relatedTerms" in alias_data
related_terms = alias_data["relatedTerms"]
assert related_terms is not None
assert len(related_terms) == 2

# Verify text embedding data
text_embedding_data = data["textEmbeddingData"]
assert text_embedding_data is not None
assert "textItems" in text_embedding_data
assert len(text_embedding_data["textItems"]) == 3
assert "chess" in text_embedding_data["textItems"]
assert "grandmaster" in text_embedding_data["textItems"]
assert "embeddings" in text_embedding_data

# Create a fresh database and index to test deserialization
with tempfile.TemporaryDirectory() as tmpdir:
fresh_db_path = os.path.join(tmpdir, "fresh_test.db")
fresh_db = sqlite3.connect(fresh_db_path)
init_db_schema(fresh_db)

fresh_index = SqliteRelatedTermsIndex(fresh_db, embedding_settings)

# Deserialize into fresh index
await fresh_index.deserialize(data)

# Verify alias data was restored
ai_results = await fresh_index.aliases.lookup_term("AI")
assert ai_results is not None
assert len(ai_results) == 1
assert ai_results[0].text == "artificial intelligence"

ml_results = await fresh_index.aliases.lookup_term("ML")
assert ml_results is not None
assert len(ml_results) == 1
assert ml_results[0].text == "machine learning"

# Verify fuzzy index data was restored
fresh_fuzzy = fresh_index.fuzzy_index
assert fresh_fuzzy is not None
assert await fresh_fuzzy.size() == 3

# Use concrete type to access get_terms (not in interface)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if get_terms() should be added to the interface. (Probably not in this PR unless it's a one-liner -- which it isn't since you'd have to implement it for the memory provider too.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gvanrossum I agree get_terms() could be useful in the interface. I'll keep it out of this PR to keep the scope small, but happy to open a follow-up issue/PR to add it to the interface and implement it for both providers if that would be helpful.

assert isinstance(fresh_fuzzy, SqliteRelatedTermsFuzzy)
fuzzy_terms = await fresh_fuzzy.get_terms()
assert "chess" in fuzzy_terms
assert "grandmaster" in fuzzy_terms
assert "tournament" in fuzzy_terms

# Verify fuzzy lookup works after deserialization
fuzzy_results = await fresh_fuzzy.lookup_term(
"chess", max_hits=5, min_score=0.1
)
assert len(fuzzy_results) > 0

fresh_db.close()


# Integration test to verify the fix for the regression we encountered
class TestRegressionPrevention:
Expand Down Expand Up @@ -656,9 +776,9 @@ async def test_serialization_edge_cases(
fuzzy_index = SqliteRelatedTermsFuzzy(sqlite_db, embedding_settings)

# Test serialization of empty index
# Note: fuzzy index doesn't implement serialize (returns empty for SQLite)
# But test that calling it doesn't crash
# This would be implemented if needed
empty_data = fuzzy_index.serialize()
assert empty_data is not None
assert empty_data["textItems"] == []

# Test fuzzy index with some data then clear
await fuzzy_index.add_terms(["test1", "test2"])
Expand Down
13 changes: 12 additions & 1 deletion typeagent/storage/sqlite/reltermsindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ async def lookup_terms(
results.append(term_results)
return results

def serialize(self) -> interfaces.TextEmbeddingIndexData:
"""Serialize the fuzzy index data."""
return interfaces.TextEmbeddingIndexData(
textItems=self._terms_list.copy(),
embeddings=self._vector_base.serialize(),
)

async def deserialize(self, data: interfaces.TextEmbeddingIndexData) -> None:
"""Deserialize fuzzy index data from JSON into SQLite database."""
# Clear existing data
Expand Down Expand Up @@ -313,7 +320,11 @@ def fuzzy_index(self) -> interfaces.ITermToRelatedTermsFuzzy | None:
return self._fuzzy_index

async def serialize(self) -> interfaces.TermsToRelatedTermsIndexData:
raise NotImplementedError("TODO")
"""Serialize the related terms index (both aliases and fuzzy index)."""
return interfaces.TermsToRelatedTermsIndexData(
aliasData=await self._aliases.serialize(),
textEmbeddingData=self._fuzzy_index.serialize(),
)

async def deserialize(self, data: interfaces.TermsToRelatedTermsIndexData) -> None:
"""Deserialize related terms index data."""
Expand Down