Skip to content

Commit 63b83b8

Browse files
committed
added tests for expressions
1 parent c4cd995 commit 63b83b8

File tree

7 files changed

+170
-48
lines changed

7 files changed

+170
-48
lines changed

google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,23 @@ class Selectable(Expr):
340340
"""Base class for expressions that can be selected or aliased in projection stages."""
341341

342342
@abstractmethod
343-
def _to_map(self):
343+
def _to_map(self) -> tuple[str, Value]:
344+
"""
345+
Returns a str: Value representation of the Selectable
346+
"""
344347
raise NotImplementedError
345348

349+
@classmethod
350+
def _value_from_selectables(cls, *selectables: Selectable) -> Value:
351+
"""
352+
Returns a Value representing a map of Selectables
353+
"""
354+
return Value(
355+
map_value={
356+
"fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]}
357+
}
358+
)
359+
346360

347361
class Field(Selectable):
348362
"""Represents a reference to a field within a document."""
@@ -384,8 +398,8 @@ class FilterCondition(Function):
384398
def __init__(
385399
self,
386400
*args,
387-
use_infix_repr:bool = True,
388-
infix_name_override:str | None= None,
401+
use_infix_repr: bool = True,
402+
infix_name_override: str | None = None,
389403
**kwargs,
390404
):
391405
self._use_infix_repr = use_infix_repr
@@ -521,7 +535,9 @@ class In(FilterCondition):
521535
"""Represents checking if an expression's value is within a list of values."""
522536

523537
def __init__(self, left: Expr, others: List[Expr]):
524-
super().__init__("in", [left, ListOfExprs(others)], infix_name_override="in_any")
538+
super().__init__(
539+
"in", [left, ListOfExprs(others)], infix_name_override="in_any"
540+
)
525541

526542

527543
class IsNaN(FilterCondition):

google/cloud/firestore_v1/pipeline_stages.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,7 @@ def __init__(self, *selections: str | Selectable):
128128
self.projections = [Field(s) if isinstance(s, str) else s for s in selections]
129129

130130
def _pb_args(self) -> list[Value]:
131-
return [
132-
Value(
133-
map_value={
134-
"fields": {
135-
m[0]: m[1] for m in [f._to_map() for f in self.projections]
136-
}
137-
}
138-
)
139-
]
131+
return [Selectable._value_from_selectables(*self.projections)]
140132

141133

142134
class Sort(Stage):

tests/unit/v1/test_async_pipeline.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,19 +309,22 @@ async def test_async_pipeline_execute_with_transaction():
309309
assert request.database == "projects/A/databases/B"
310310
assert request.transaction == b"123"
311311

312-
@pytest.mark.parametrize("method,args,result_cls", [
313-
("select", (), stages.Select),
314-
("where", (mock.Mock(),), stages.Where),
315-
("sort", (), stages.Sort),
316-
("offset", (1,), stages.Offset),
317-
("limit", (1,), stages.Limit),
318-
319-
])
312+
313+
@pytest.mark.parametrize(
314+
"method,args,result_cls",
315+
[
316+
("select", (), stages.Select),
317+
("where", (mock.Mock(),), stages.Where),
318+
("sort", (), stages.Sort),
319+
("offset", (1,), stages.Offset),
320+
("limit", (1,), stages.Limit),
321+
],
322+
)
320323
def test_async_pipeline_methods(method, args, result_cls):
321324
start_ppl = _make_async_pipeline()
322325
method_ptr = getattr(start_ppl, method)
323326
result_ppl = method_ptr(*args)
324327
assert result_ppl != start_ppl
325328
assert len(start_ppl.stages) == 0
326329
assert len(result_ppl.stages) == 1
327-
assert isinstance(result_ppl.stages[0], result_cls)
330+
assert isinstance(result_ppl.stages[0], result_cls)

tests/unit/v1/test_pipeline.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -290,19 +290,22 @@ def test_pipeline_execute_with_transaction():
290290
assert request.database == "projects/A/databases/B"
291291
assert request.transaction == b"123"
292292

293-
@pytest.mark.parametrize("method,args,result_cls", [
294-
("select", (), stages.Select),
295-
("where", (mock.Mock(),), stages.Where),
296-
("sort", (), stages.Sort),
297-
("offset", (1,), stages.Offset),
298-
("limit", (1,), stages.Limit),
299-
300-
])
293+
294+
@pytest.mark.parametrize(
295+
"method,args,result_cls",
296+
[
297+
("select", (), stages.Select),
298+
("where", (mock.Mock(),), stages.Where),
299+
("sort", (), stages.Sort),
300+
("offset", (1,), stages.Offset),
301+
("limit", (1,), stages.Limit),
302+
],
303+
)
301304
def test_pipeline_methods(method, args, result_cls):
302305
start_ppl = _make_pipeline()
303306
method_ptr = getattr(start_ppl, method)
304307
result_ppl = method_ptr(*args)
305308
assert result_ppl != start_ppl
306309
assert len(start_ppl.stages) == 0
307310
assert len(result_ppl.stages) == 1
308-
assert isinstance(result_ppl.stages[0], result_cls)
311+
assert isinstance(result_ppl.stages[0], result_cls)

tests/unit/v1/test_pipeline_expressions.py

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,47 @@ def mock_client():
3333
return client
3434

3535

36+
class TestOrdering:
37+
@pytest.mark.parametrize(
38+
"direction_arg,expected_direction",
39+
[
40+
("ASCENDING", expr.Ordering.Direction.ASCENDING),
41+
("DESCENDING", expr.Ordering.Direction.DESCENDING),
42+
("ascending", expr.Ordering.Direction.ASCENDING),
43+
("descending", expr.Ordering.Direction.DESCENDING),
44+
(expr.Ordering.Direction.ASCENDING, expr.Ordering.Direction.ASCENDING),
45+
(expr.Ordering.Direction.DESCENDING, expr.Ordering.Direction.DESCENDING),
46+
],
47+
)
48+
def test_ctor(self, direction_arg, expected_direction):
49+
instance = expr.Ordering("field1", direction_arg)
50+
assert isinstance(instance.expr, expr.Field)
51+
assert instance.expr.path == "field1"
52+
assert instance.order_dir == expected_direction
53+
54+
def test_repr(self):
55+
field_expr = expr.Field.of("field1")
56+
instance = expr.Ordering(field_expr, "ASCENDING")
57+
repr_str = repr(instance)
58+
assert repr_str == "Field.of('field1').ascending()"
59+
60+
instance = expr.Ordering(field_expr, "DESCENDING")
61+
repr_str = repr(instance)
62+
assert repr_str == "Field.of('field1').descending()"
63+
64+
def test_to_pb(self):
65+
field_expr = expr.Field.of("field1")
66+
instance = expr.Ordering(field_expr, "ASCENDING")
67+
result = instance._to_pb()
68+
assert result.map_value.fields["expression"].field_reference_value == "field1"
69+
assert result.map_value.fields["direction"].string_value == "ascending"
70+
71+
instance = expr.Ordering(field_expr, "DESCENDING")
72+
result = instance._to_pb()
73+
assert result.map_value.fields["expression"].field_reference_value == "field1"
74+
assert result.map_value.fields["direction"].string_value == "descending"
75+
76+
3677
class TestExpr:
3778
def test_ctor(self):
3879
"""
@@ -116,6 +157,65 @@ def test_repr(self, input_val, expected):
116157
assert repr_string == expected
117158

118159

160+
class TestListOfExprs:
161+
def test_to_pb(self):
162+
instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)])
163+
result = instance._to_pb()
164+
assert len(result.array_value.values) == 2
165+
assert result.array_value.values[0].integer_value == 1
166+
assert result.array_value.values[1].integer_value == 2
167+
168+
def test_empty_to_pb(self):
169+
instance = expr.ListOfExprs([])
170+
result = instance._to_pb()
171+
assert len(result.array_value.values) == 0
172+
173+
def test_repr(self):
174+
instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)])
175+
repr_string = repr(instance)
176+
assert repr_string == "ListOfExprs([Constant.of(1), Constant.of(2)])"
177+
empty_instance = expr.ListOfExprs([])
178+
empty_repr_string = repr(empty_instance)
179+
assert empty_repr_string == "ListOfExprs([])"
180+
181+
182+
class TestSelectable:
183+
def test_ctor(self):
184+
"""
185+
Base class should be abstract
186+
"""
187+
with pytest.raises(TypeError):
188+
expr.Selectable()
189+
190+
def test_value_from_selectables(self):
191+
selectable_list = [expr.Field.of("field1"), expr.Field.of("field2")]
192+
result = expr.Selectable._value_from_selectables(*selectable_list)
193+
assert len(result.map_value.fields) == 2
194+
assert result.map_value.fields["field1"].field_reference_value == "field1"
195+
assert result.map_value.fields["field2"].field_reference_value == "field2"
196+
197+
class TestField:
198+
def test_repr(self):
199+
instance = expr.Field.of("field1")
200+
repr_string = repr(instance)
201+
assert repr_string == "Field.of('field1')"
202+
203+
def test_of(self):
204+
instance = expr.Field.of("field1")
205+
assert instance.path == "field1"
206+
207+
def test_to_pb(self):
208+
instance = expr.Field.of("field1")
209+
result = instance._to_pb()
210+
assert result.field_reference_value == "field1"
211+
212+
def test_to_map(self):
213+
instance = expr.Field.of("field1")
214+
result = instance._to_map()
215+
assert result[0] == "field1"
216+
assert result[1] == Value(field_reference_value="field1")
217+
218+
119219
class TestFilterCondition:
120220
def test__from_query_filter_pb_composite_filter_or(self, mock_client):
121221
"""
@@ -420,21 +520,23 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client):
420520
with pytest.raises(TypeError, match="Unexpected filter type"):
421521
FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client)
422522

423-
424-
@pytest.mark.parametrize("method,args,result_cls", [
425-
("eq", (2,), expr.Eq),
426-
("neq", (2,), expr.Neq),
427-
("lt", (2,), expr.Lt),
428-
("lte", (2,), expr.Lte),
429-
("gt", (2,), expr.Gt),
430-
("gte", (2,), expr.Gte),
431-
("in_any", ([None],), expr.In),
432-
("not_in_any", ([None],), expr.Not),
433-
("array_contains", (None,), expr.ArrayContains),
434-
("array_contains_any", ([None],), expr.ArrayContainsAny),
435-
("is_nan", (), expr.IsNaN),
436-
("exists", (), expr.Exists),
437-
])
523+
@pytest.mark.parametrize(
524+
"method,args,result_cls",
525+
[
526+
("eq", (2,), expr.Eq),
527+
("neq", (2,), expr.Neq),
528+
("lt", (2,), expr.Lt),
529+
("lte", (2,), expr.Lte),
530+
("gt", (2,), expr.Gt),
531+
("gte", (2,), expr.Gte),
532+
("in_any", ([None],), expr.In),
533+
("not_in_any", ([None],), expr.Not),
534+
("array_contains", (None,), expr.ArrayContains),
535+
("array_contains_any", ([None],), expr.ArrayContainsAny),
536+
("is_nan", (), expr.IsNaN),
537+
("exists", (), expr.Exists),
538+
],
539+
)
438540
def test_infix_call(self, method, args, result_cls):
439541
"""
440542
most FilterExpressions should support infix execution
@@ -483,7 +585,10 @@ def test_array_contains_any(self):
483585
assert isinstance(instance.params[1], ListOfExprs)
484586
assert instance.params[0] == arg1
485587
assert instance.params[1].exprs == [arg2, arg3]
486-
assert repr(instance) == "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))"
588+
assert (
589+
repr(instance)
590+
== "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))"
591+
)
487592

488593
def test_exists(self):
489594
arg1 = self._make_arg("Field")

tests/unit/v1/test_pipeline_source.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_collection_group(self):
5353
assert isinstance(first_stage, stages.CollectionGroup)
5454
assert first_stage.collection_id == "id"
5555

56+
5657
class TestPipelineSourceWithAsyncClient(TestPipelineSource):
5758
"""
5859
When an async client is used, it should produce async pipelines

tests/unit/v1/test_pipeline_stages.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def _make_one(self, *args, **kwargs):
168168
def test_repr(self):
169169
instance = self._make_one("field1", Field.of("field2"))
170170
repr_str = repr(instance)
171-
assert repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])"
171+
assert (
172+
repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])"
173+
)
172174

173175
def test_to_pb(self):
174176
instance = self._make_one("field1", "field2.subfield", Field.of("field3"))

0 commit comments

Comments
 (0)