Skip to content

Commit 56521dc

Browse files
authored
Implement execution for @transform applied to filters' left-hand operand. (#626)
1 parent d99f842 commit 56521dc

File tree

6 files changed

+393
-93
lines changed

6 files changed

+393
-93
lines changed

trustfall_core/src/interpreter/execution.rs

Lines changed: 127 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@ use crate::{
88
ir::{
99
Argument, ContextField, EdgeParameters, Eid, FieldRef, FieldValue, FoldSpecificFieldKind,
1010
IREdge, IRFold, IRQueryComponent, IRVertex, IndexedQuery, LocalField, Operation,
11-
OperationSubject, Recursive, Vid,
11+
OperationSubject, Recursive, TransformBase, Vid,
1212
},
1313
util::BTreeMapTryInsertExt,
1414
};
1515

1616
use super::{
17-
error::QueryArgumentsError, filtering::apply_filter, Adapter, AsVertex, ContextIterator,
18-
ContextOutcomeIterator, DataContext, InterpretedQuery, ResolveEdgeInfo, ResolveInfo,
19-
TaggedValue, ValueOrVec, VertexIterator,
17+
error::QueryArgumentsError,
18+
filtering::apply_filter,
19+
transformation::{apply_transforms, push_transform_argument_tag_values_onto_stack},
20+
Adapter, AsVertex, ContextIterator, ContextOutcomeIterator, DataContext, InterpretedQuery,
21+
ResolveEdgeInfo, ResolveInfo, TaggedValue, ValueOrVec, VertexIterator,
2022
};
2123

2224
#[derive(Debug, Clone)]
@@ -621,25 +623,15 @@ fn compute_fold<'query, AdapterT: Adapter<'query> + 'query>(
621623
let mut post_filtered_iterator: ContextIterator<'query, AdapterT::Vertex> =
622624
Box::new(folded_iterator);
623625
for post_fold_filter in fold.post_filters.iter() {
624-
let left = post_fold_filter.left();
625-
match left {
626-
OperationSubject::FoldSpecificField(fold_specific_field) => {
627-
let remapped_operation = post_fold_filter.map(|_| fold_specific_field.kind, |x| x);
628-
post_filtered_iterator = apply_fold_specific_filter(
629-
adapter.as_ref(),
630-
carrier,
631-
parent_component,
632-
fold.as_ref(),
633-
expanding_from.vid,
634-
&remapped_operation,
635-
post_filtered_iterator,
636-
);
637-
}
638-
OperationSubject::TransformedField(_) => todo!(),
639-
OperationSubject::LocalField(_) => {
640-
unreachable!("unexpectedly found a fold post-filtering step that references a LocalField: {fold:#?}");
641-
}
642-
}
626+
post_filtered_iterator = apply_fold_specific_filter(
627+
adapter.as_ref(),
628+
carrier,
629+
parent_component,
630+
fold.as_ref(),
631+
expanding_from.vid,
632+
post_fold_filter,
633+
post_filtered_iterator,
634+
);
643635
}
644636

645637
// Compute the outputs from this fold.
@@ -809,7 +801,63 @@ fn apply_filter_with_non_folded_field_subject<'query, AdapterT: Adapter<'query>>
809801
filter.map_left(|_| field),
810802
iterator,
811803
),
812-
OperationSubject::TransformedField(_) => todo!(),
804+
OperationSubject::TransformedField(transformed) => {
805+
let prepped_iterator = push_transform_argument_tag_values_onto_stack(
806+
adapter,
807+
carrier,
808+
component,
809+
current_vid,
810+
&transformed.value.transforms,
811+
iterator,
812+
);
813+
814+
let query_variables =
815+
Arc::clone(&carrier.query.as_ref().expect("query was not returned").arguments);
816+
let transform_data = Arc::clone(&transformed.value);
817+
818+
match &transformed.value.base {
819+
TransformBase::ContextField(field) => {
820+
assert_eq!(current_vid, field.vertex_id, "filter left-hand side was a transformed field from a different vertex: {current_vid:?} {filter:?}");
821+
let local_field = LocalField {
822+
field_name: field.field_name.clone(),
823+
field_type: field.field_type.clone(),
824+
};
825+
826+
let filter_input_iterator = Box::new(
827+
compute_local_field_with_separate_value(
828+
adapter,
829+
carrier,
830+
component,
831+
current_vid,
832+
&local_field,
833+
prepped_iterator,
834+
)
835+
.map(move |(mut ctx, mut value)| {
836+
value = apply_transforms(
837+
&transform_data,
838+
&query_variables,
839+
&mut ctx.values,
840+
value,
841+
);
842+
ctx.values.push(value);
843+
ctx
844+
}),
845+
);
846+
847+
apply_filter(
848+
adapter,
849+
carrier,
850+
component,
851+
current_vid,
852+
&filter.map(|_| (), |r| r),
853+
filter_input_iterator,
854+
)
855+
}
856+
TransformBase::FoldSpecificField(..) => unreachable!(
857+
"illegal filter over fold-specific field passed to this function: {filter:?}"
858+
),
859+
}
860+
}
813861
OperationSubject::FoldSpecificField(..) => unreachable!(
814862
"illegal filter over fold-specific field passed to this function: {filter:?}"
815863
),
@@ -844,27 +892,70 @@ fn apply_fold_specific_filter<'query, AdapterT: Adapter<'query>>(
844892
component: &IRQueryComponent,
845893
fold: &IRFold,
846894
current_vid: Vid,
847-
filter: &Operation<FoldSpecificFieldKind, &Argument>,
895+
filter: &Operation<OperationSubject, Argument>,
848896
iterator: ContextIterator<'query, AdapterT::Vertex>,
849897
) -> ContextIterator<'query, AdapterT::Vertex> {
850-
let fold_specific_field = filter.left();
851-
let field_iterator = Box::new(compute_fold_specific_field_with_separate_value(fold.eid, fold_specific_field, iterator).map(|(mut ctx, tagged_value)| {
852-
let value = match tagged_value {
853-
TaggedValue::Some(value) => value,
854-
TaggedValue::NonexistentOptional => {
855-
unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}")
898+
let left = filter.left();
899+
let (fold_specific_field, transform_data) = match left {
900+
OperationSubject::FoldSpecificField(field) => (field, None),
901+
OperationSubject::TransformedField(transformed) => match &transformed.value.base {
902+
TransformBase::FoldSpecificField(field) => (field, Some(&transformed.value)),
903+
TransformBase::ContextField(_) => {
904+
unreachable!("post-fold filter does not refer to a fold-specific field: {left:?}")
856905
}
857-
};
858-
ctx.values.push(value);
859-
ctx
860-
}));
906+
},
907+
OperationSubject::LocalField(_) => {
908+
unreachable!("post-fold filter does not refer to a fold-specific field: {left:?}")
909+
}
910+
};
911+
912+
let field_iterator: ContextIterator<'query, AdapterT::Vertex> = if let Some(transform_data) =
913+
transform_data
914+
{
915+
let prepped_iterator = push_transform_argument_tag_values_onto_stack(
916+
adapter,
917+
carrier,
918+
component,
919+
current_vid,
920+
&transform_data.transforms,
921+
iterator,
922+
);
923+
924+
let query_variables =
925+
Arc::clone(&carrier.query.as_ref().expect("query was not returned").arguments);
926+
let transform_data = Arc::clone(transform_data);
927+
Box::new(compute_fold_specific_field_with_separate_value(fold.eid, &fold_specific_field.kind, prepped_iterator).map(move |(mut ctx, tagged_value)| {
928+
let mut value = match tagged_value {
929+
TaggedValue::Some(value) => value,
930+
TaggedValue::NonexistentOptional => {
931+
unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}")
932+
}
933+
};
934+
935+
value = apply_transforms(&transform_data, &query_variables, &mut ctx.values, value);
936+
937+
ctx.values.push(value);
938+
ctx
939+
}))
940+
} else {
941+
Box::new(compute_fold_specific_field_with_separate_value(fold.eid, &fold_specific_field.kind, iterator).map(|(mut ctx, tagged_value)| {
942+
let value = match tagged_value {
943+
TaggedValue::Some(value) => value,
944+
TaggedValue::NonexistentOptional => {
945+
unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}")
946+
}
947+
};
948+
ctx.values.push(value);
949+
ctx
950+
}))
951+
};
861952

862953
apply_filter(
863954
adapter,
864955
carrier,
865956
component,
866957
current_vid,
867-
&filter.map(|_| (), |r| *r),
958+
&filter.map(|_| (), |r| r),
868959
field_iterator,
869960
)
870961
}

trustfall_core/src/interpreter/filtering.rs

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@ use std::{fmt::Debug, mem};
22

33
use regex::Regex;
44

5-
use crate::ir::{Argument, FieldRef, FieldValue, IRQueryComponent, LocalField, Operation, Vid};
5+
use crate::ir::{Argument, FieldValue, IRQueryComponent, Operation, Vid};
66

77
use super::{
8-
execution::{
9-
compute_context_field_with_separate_value, compute_fold_specific_field_with_separate_value,
10-
compute_local_field_with_separate_value, QueryCarrier,
11-
},
12-
Adapter, ContextIterator, ContextOutcomeIterator, TaggedValue,
8+
execution::QueryCarrier, tags::compute_tag_with_separate_value, Adapter, ContextIterator,
9+
ContextOutcomeIterator, TaggedValue,
1310
};
1411

1512
#[inline(always)]
@@ -281,56 +278,15 @@ pub(super) fn apply_filter<'query, AdapterT: Adapter<'query>>(
281278
let right_value = query_arguments[var.variable_name.as_ref()].to_owned();
282279
apply_filter_with_static_argument_value(filter, right_value, iterator)
283280
}
284-
Some(Argument::Tag(FieldRef::ContextField(context_field))) => {
285-
// TODO: Benchmark if it would be faster to duplicate the filtering code to special-case
286-
// the situation when the tag is always known to exist, so we don't have to unwrap
287-
// a TaggedValue enum, because we know it would be TaggedValue::Some.
288-
let argument_value_iterator = if context_field.vertex_id == current_vid {
289-
// This tag is from the vertex we're currently filtering. That means the field
290-
// whose value we want to get is actually local, so there's no need to compute it
291-
// using the more expensive approach we use for non-local fields.
292-
let local_equivalent_field = LocalField {
293-
field_name: context_field.field_name.clone(),
294-
field_type: context_field.field_type.clone(),
295-
};
296-
Box::new(
297-
compute_local_field_with_separate_value(
298-
adapter,
299-
carrier,
300-
component,
301-
current_vid,
302-
&local_equivalent_field,
303-
iterator,
304-
)
305-
.map(|(ctx, value)| (ctx, TaggedValue::Some(value))),
306-
)
307-
} else {
308-
compute_context_field_with_separate_value(
309-
adapter,
310-
carrier,
311-
component,
312-
context_field,
313-
iterator,
314-
)
315-
};
316-
apply_filter_with_tagged_argument_value(filter, argument_value_iterator)
317-
}
318-
Some(Argument::Tag(field_ref @ FieldRef::FoldSpecificField(fold_field))) => {
319-
let argument_value_iterator = if component.folds.contains_key(&fold_field.fold_eid) {
320-
compute_fold_specific_field_with_separate_value(
321-
fold_field.fold_eid,
322-
&fold_field.kind,
323-
iterator,
324-
)
325-
} else {
326-
// This value represents an imported tag value from an outer component.
327-
// Grab its value from the context itself.
328-
let cloned_ref = field_ref.clone();
329-
Box::new(iterator.map(move |ctx| {
330-
let right_value = ctx.imported_tags[&cloned_ref].clone();
331-
(ctx, right_value)
332-
}))
333-
};
281+
Some(Argument::Tag(field_ref)) => {
282+
let argument_value_iterator = compute_tag_with_separate_value(
283+
adapter,
284+
carrier,
285+
component,
286+
current_vid,
287+
field_ref,
288+
iterator,
289+
);
334290
apply_filter_with_tagged_argument_value(filter, argument_value_iterator)
335291
}
336292
Some(Argument::Tag(FieldRef::TransformedField(_))) => {

trustfall_core/src/interpreter/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ mod filtering;
1717
pub mod helpers;
1818
mod hints;
1919
pub mod replay;
20+
mod tags;
2021
pub mod trace;
22+
mod transformation;
2123

2224
pub use hints::{
2325
CandidateValue, DynamicallyResolvedValue, EdgeInfo, NeighborInfo, QueryInfo, Range,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use crate::ir::{FieldRef, IRQueryComponent, LocalField, Vid};
2+
3+
use super::{
4+
execution::{
5+
compute_context_field_with_separate_value, compute_fold_specific_field_with_separate_value,
6+
compute_local_field_with_separate_value, QueryCarrier,
7+
},
8+
Adapter, ContextIterator, DataContext, TaggedValue,
9+
};
10+
11+
pub(super) fn compute_tag_with_separate_value<'query, AdapterT: Adapter<'query>>(
12+
adapter: &AdapterT,
13+
carrier: &mut QueryCarrier,
14+
component: &IRQueryComponent,
15+
current_vid: Vid,
16+
field_ref: &FieldRef,
17+
iterator: ContextIterator<'query, AdapterT::Vertex>,
18+
) -> Box<dyn Iterator<Item = (DataContext<AdapterT::Vertex>, TaggedValue)> + 'query> {
19+
match field_ref {
20+
FieldRef::ContextField(context_field) => {
21+
// TODO: Benchmark if it would be faster to duplicate the code to special-case
22+
// the situation when the tag is always known to exist, so we don't have to unwrap
23+
// a TaggedValue enum, because we know it would be TaggedValue::Some.
24+
if context_field.vertex_id == current_vid {
25+
// This tag is from the vertex we're currently evaluating. That means the field
26+
// whose value we want to get is actually local, so there's no need to compute it
27+
// using the more expensive approach we use for non-local fields.
28+
let local_equivalent_field = LocalField {
29+
field_name: context_field.field_name.clone(),
30+
field_type: context_field.field_type.clone(),
31+
};
32+
Box::new(
33+
compute_local_field_with_separate_value(
34+
adapter,
35+
carrier,
36+
component,
37+
current_vid,
38+
&local_equivalent_field,
39+
iterator,
40+
)
41+
.map(|(ctx, value)| (ctx, TaggedValue::Some(value))),
42+
)
43+
} else {
44+
compute_context_field_with_separate_value(
45+
adapter,
46+
carrier,
47+
component,
48+
context_field,
49+
iterator,
50+
)
51+
}
52+
}
53+
FieldRef::FoldSpecificField(fold_field) => {
54+
if component.folds.contains_key(&fold_field.fold_eid) {
55+
compute_fold_specific_field_with_separate_value(
56+
fold_field.fold_eid,
57+
&fold_field.kind,
58+
iterator,
59+
)
60+
} else {
61+
// This value represents an imported tag value from an outer component.
62+
// Grab its value from the context itself.
63+
let cloned_ref = field_ref.clone();
64+
Box::new(iterator.map(move |ctx| {
65+
let right_value = ctx.imported_tags[&cloned_ref].clone();
66+
(ctx, right_value)
67+
}))
68+
}
69+
}
70+
}
71+
}

0 commit comments

Comments
 (0)