Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ function (r::Rule)(term)
end
end

"""
rewrite_rhs(expr::Expr)

Rewrite the `expr` by dealing with `:where` if necessary.
The `:where` is rewritten from, for example, `~x where f(~x)` to `f(~x) ? ~x : nothing`.
"""
function rewrite_rhs(expr::Expr)
if expr.head == :where
rhs = expr.args[1]
predicate = expr.args[2]
expr = Meta.parse("$predicate ? $rhs : nothing")
end
return expr
end
rewrite_rhs(expr) = expr

"""
@rule LHS => RHS

Expand Down Expand Up @@ -221,7 +237,8 @@ julia> r(2 * (a+b+c))

**Predicates**:

Predicates can be used on both `~x` and `~~x` by using the `~x::f` or `~~x::f`.
There are two kinds of predicates, namely over slot variables and over the whole rule.
For the former, predicates can be used on both `~x` and `~~x` by using the `~x::f` or `~~x::f`.
Here `f` can be any julia function. In the case of a slot the function gets a single
matched subexpression, in the case of segment, it gets an array of matched expressions.

Expand Down Expand Up @@ -249,6 +266,25 @@ sin((a + c))

Predicate function gets an array of values if attached to a segment variable (`~~x`).

For the predicate over the whole rule, use `@rule <LHS> => <RHS> where <predicate>`:

```
julia> @syms a b;

julia> predicate(x) = x === a;

julia> r = @rule ~x => ~x where f(~x);

julia> r(a)
a

julia> r(b) === nothing
true
```

Note that this is syntactic sugar and that it is the same as something like
`@rule ~x => f(~x) ? ~x : nothing`.

**Context**:

_In predicates_: Contextual predicates are functions wrapped in the `Contextual` type.
Expand All @@ -264,7 +300,8 @@ of an expression.
"""
macro rule(expr)
@assert expr.head == :call && expr.args[1] == :(=>)
lhs,rhs = expr.args[2], expr.args[3]
lhs = expr.args[2]
rhs = rewrite_rhs(expr.args[3])
keys = Symbol[]
lhs_term = makepattern(lhs, keys)
unique!(keys)
Expand Down
15 changes: 15 additions & 0 deletions test/rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ end
@test isnothing(@rule(f(1)(a) => 2)(sin(a)))
@test @rule($(f(1))(a) => 2)(sin(a)) == 2
end

@testset "where" begin
expected = Meta.parse("f(~x) ? ~x + ~y : nothing")
@test SymbolicUtils.rewrite_rhs(:((~x + ~y) where f(~x))) == expected

@syms a b
f(x) = x === a
r = @rule ~x => ~x where f(~x)
@eqtest r(a) == a
@test isnothing(r(b))

r = @acrule ~x => ~x where f(~x)
@eqtest r(a) == a
@test r(b) === nothing
end