Skip to content

Commit b40ef68

Browse files
committed
Resolved comments
1 parent 73c46c0 commit b40ef68

File tree

2 files changed

+52
-33
lines changed

2 files changed

+52
-33
lines changed

semhash/datamodels.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import warnings
22
from collections import defaultdict
33
from dataclasses import dataclass, field
4-
from typing import Any, Generic, Hashable, Sequence, TypeVar
4+
from typing import Any, Generic, Hashable, Sequence, TypeAlias, TypeVar
5+
6+
from frozendict import frozendict
57

68
from semhash.utils import to_frozendict
79

810
Record = TypeVar("Record", str, dict[str, Any])
11+
DuplicateList: TypeAlias = list[tuple[Record, float]]
912

1013

1114
@dataclass
@@ -23,13 +26,29 @@ class DuplicateRecord(Generic[Record]):
2326

2427
record: Record
2528
exact: bool
26-
duplicates: list[tuple[Record, float]] = field(default_factory=list)
29+
duplicates: DuplicateList = field(default_factory=list)
2730

2831
def _rethreshold(self, threshold: float) -> None:
2932
"""Rethreshold the duplicates."""
3033
self.duplicates = [(d, score) for d, score in self.duplicates if score >= threshold]
3134

3235

36+
@dataclass
37+
class SelectedWithDuplicates(Generic[Record]):
38+
"""
39+
A record that has been selected along with its duplicates.
40+
41+
Attributes
42+
----------
43+
record: The original record being selected.
44+
duplicates: List of tuples consisting of duplicate records and their associated scores.
45+
46+
"""
47+
48+
record: Record
49+
duplicates: DuplicateList = field(default_factory=list)
50+
51+
3352
@dataclass
3453
class DeduplicationResult(Generic[Record]):
3554
"""
@@ -49,7 +68,7 @@ class DeduplicationResult(Generic[Record]):
4968
selected: list[Record] = field(default_factory=list)
5069
filtered: list[DuplicateRecord] = field(default_factory=list)
5170
threshold: float = field(default=0.9)
52-
columns: Sequence[str] = field(default_factory=list)
71+
columns: Sequence[str] | None = field(default=None)
5372
deduplicated: list[Record] = field(default_factory=list) # Deprecated
5473
duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated
5574

@@ -108,33 +127,34 @@ def rethreshold(self, threshold: float) -> None:
108127
self.threshold = threshold
109128

110129
@property
111-
def selected_with_duplicates(self) -> list[tuple[Record, list[tuple[Record, float]]]]:
130+
def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]:
112131
"""
113132
For every kept record, return the duplicates that were removed along with their similarity scores.
114133
115134
:return: A list of tuples where each tuple contains a kept record
116135
and a list of its duplicates with their similarity scores.
117136
"""
118137

119-
def _to_hashable(record: Record) -> Hashable:
120-
if isinstance(record, dict):
138+
def _to_hashable(record: Record) -> frozendict[str, str] | str:
139+
"""Convert a record to a hashable representation."""
140+
if isinstance(record, dict) and self.columns is not None:
121141
# Convert dict to frozendict for immutability and hashability
122142
return to_frozendict(record, set(self.columns))
123-
return record
143+
return str(record)
124144

125145
# Build a mapping from original-record to [(duplicate, score), …]
126-
buckets: defaultdict[Hashable, list[tuple[Record, float]]] = defaultdict(list)
146+
buckets: defaultdict[Hashable, DuplicateList] = defaultdict(list)
127147
for duplicate_record in self.filtered:
128148
for original_record, score in duplicate_record.duplicates:
129149
buckets[_to_hashable(original_record)].append((duplicate_record.record, float(score)))
130150

131-
result: list[tuple[Record, list[tuple[Record, float]]]] = []
151+
result: list[SelectedWithDuplicates[Record]] = []
132152
for selected in self.selected:
133153
# Get the list of duplicates for the selected record
134154
raw_list = buckets.get(_to_hashable(selected), [])
135155
# Ensure we don't have duplicates in the list
136156
deduped = {_to_hashable(rec): (rec, score) for rec, score in raw_list}
137-
result.append((selected, list(deduped.values())))
157+
result.append(SelectedWithDuplicates(record=selected, duplicates=list(deduped.values())))
138158

139159
return result
140160

tests/test_datamodels.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import semhash
44
import semhash.version
5-
from semhash.datamodels import DeduplicationResult, DuplicateRecord
5+
from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates
66

77

88
def test_deduplication_scoring() -> None:
@@ -11,7 +11,6 @@ def test_deduplication_scoring() -> None:
1111
["a", "b", "c"],
1212
[DuplicateRecord("a", False, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])],
1313
0.8,
14-
columns=["text"],
1514
)
1615
assert d.duplicate_ratio == 0.4
1716

@@ -22,7 +21,6 @@ def test_deduplication_scoring_exact() -> None:
2221
["a", "b", "c"],
2322
[DuplicateRecord("a", True, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])],
2423
0.8,
25-
columns=["text"],
2624
)
2725
assert d.exact_duplicate_ratio == 0.2
2826

@@ -59,7 +57,6 @@ def test_get_least_similar_from_duplicates() -> None:
5957
["a", "b", "c"],
6058
[DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])],
6159
0.8,
62-
columns=["text"],
6360
)
6461
result = d.get_least_similar_from_duplicates(1)
6562
assert result == [("a", "c", 0.7)]
@@ -80,7 +77,6 @@ def test_rethreshold_deduplication_result() -> None:
8077
DuplicateRecord("e", False, [("z", 0.8)]),
8178
],
8279
0.8,
83-
columns=["text"],
8480
)
8581
d.rethreshold(0.85)
8682
assert d.filtered == [DuplicateRecord("d", False, [("x", 0.9)])]
@@ -96,7 +92,6 @@ def test_rethreshold_exception() -> None:
9692
DuplicateRecord("e", False, [("z", 0.8)]),
9793
],
9894
0.7,
99-
columns=["text"],
10095
)
10196
with pytest.raises(ValueError):
10297
d.rethreshold(0.6)
@@ -113,7 +108,6 @@ def test_deprecation_deduplicated_duplicates() -> None:
113108
DuplicateRecord("e", False, [("z", 0.8)]),
114109
],
115110
threshold=0.8,
116-
columns=["text"],
117111
)
118112
else:
119113
raise ValueError("deprecate `deduplicated` and `duplicates` fields in `DeduplicationResult`")
@@ -133,10 +127,14 @@ def test_selected_with_duplicates_strings() -> None:
133127
DuplicateRecord("duplicate_2", False, [("original", 0.8)]),
134128
],
135129
threshold=0.8,
136-
columns=["text"],
137130
)
138131

139-
expected = [("original", [("duplicate_1", 0.9), ("duplicate_2", 0.8)])]
132+
expected = [
133+
SelectedWithDuplicates(
134+
record="original",
135+
duplicates=[("duplicate_1", 0.9), ("duplicate_2", 0.8)],
136+
)
137+
]
140138
assert d.selected_with_duplicates == expected
141139

142140

@@ -153,9 +151,10 @@ def test_selected_with_duplicates_dicts() -> None:
153151
columns=["text"],
154152
)
155153

156-
pairs = d.selected_with_duplicates
157-
assert len(pairs) == 1
158-
kept, dups = pairs[0]
154+
items = d.selected_with_duplicates
155+
assert len(items) == 1
156+
kept = items[0].record
157+
dups = items[0].duplicates
159158
assert kept == selected
160159
assert {r["id"] for r, _ in dups} == {1, 2}
161160

@@ -173,16 +172,16 @@ def test_selected_with_duplicates_multi_column() -> None:
173172
columns=["text", "text2"],
174173
)
175174

176-
pairs = d.selected_with_duplicates
177-
assert len(pairs) == 1
178-
kept, _ = pairs[0]
175+
items = d.selected_with_duplicates
176+
assert len(items) == 1
177+
kept = items[0].record
179178
assert kept == selected
180179

181180

182181
def test_selected_with_duplicates_unhashable_values() -> None:
183182
"""Test selected_with_duplicates with unhashable values in records."""
184-
selected = {"a": [1, 2, 3]} # list -> unhashable value
185-
filtered = {"a": [1, 2, 3], "flag": True}
183+
selected = {"text": "hello", "a": [1, 2, 3]} # list -> unhashable value
184+
filtered = {"text": "hello", "a": [1, 2, 3], "flag": True}
186185

187186
d = DeduplicationResult(
188187
selected=[selected],
@@ -191,8 +190,8 @@ def test_selected_with_duplicates_unhashable_values() -> None:
191190
columns=["text"],
192191
)
193192

194-
pairs = d.selected_with_duplicates
195-
assert pairs == [(selected, [(filtered, 1.0)])]
193+
items = d.selected_with_duplicates
194+
assert items == [SelectedWithDuplicates(record=selected, duplicates=[(filtered, 1.0)])]
196195

197196

198197
def test_selected_with_duplicates_removes_internal_duplicates() -> None:
@@ -210,11 +209,11 @@ def test_selected_with_duplicates_removes_internal_duplicates() -> None:
210209
columns=["text"],
211210
)
212211

213-
selected_with_duplicates = d.selected_with_duplicates
214-
215-
assert len(selected_with_duplicates) == 1
212+
items = d.selected_with_duplicates
213+
assert len(items) == 1
216214

217-
selected_record, duplicate_list = selected_with_duplicates[0]
215+
selected_record = items[0].record
216+
duplicate_list = items[0].duplicates
218217
# Should keep the kept record unchanged
219218
assert selected_record == selected
220219
# The duplicate row must appear only once

0 commit comments

Comments
 (0)