Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
[workspace]
resolver = "2"
members = [ "integration",
members = [
"integration",
"samples",
"samples-loose-types",
]
default-members = [
"integration",
"samples",
]

Expand Down
1 change: 1 addition & 0 deletions integration/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
samples = { version = "0.1.0", path = "../samples" }
10 changes: 0 additions & 10 deletions integration/src/lib.rs

This file was deleted.

31 changes: 12 additions & 19 deletions integration/tests/examples/array.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
#![feature(autodiff)]

#[autodiff(d_array, Reverse, Active, Duplicated)]
fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 {
arr[0][0][0] * arr[1][1][1]
}

fn main() {
let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]];
let mut d_arr = [[[0.0; 2]; 2]; 2];

d_array(&arr, &mut d_arr, 1.0);

dbg!(&d_arr);
}
samples::test! {
reverse_duplicated_active;
#[autodiff(d_array, Reverse, Duplicated, Active)]
fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 {
arr[0][0][0] * arr[1][1][1]
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
let arr = [[[2.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 3.0]]];
let mut b_arr = [[[0.0; 2]; 2]; 2];

let y = d_array(&arr, &mut b_arr, 1.0);
assert_eq!(6.0, y);
assert_eq!([[[3.0, 0.0], [0.0; 2]], [[0.0; 2], [0.0, 2.0]]], b_arr);
}
}
25 changes: 0 additions & 25 deletions integration/tests/examples/box.rs

This file was deleted.

15 changes: 15 additions & 0 deletions integration/tests/examples/boxed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
samples::test! {
duplicated_active;
#[autodiff(cos_box, Reverse, Duplicated, Active)]
fn sin(x: &Box<f32>) -> f32 {
f32::sin(**x)
}

fn main() {
let x = Box::<f32>::new(3.14);
let mut df_dx = Box::<f32>::new(0.0);
let y = cos_box(&x, &mut df_dx, 1.0);
assert_eq!(f32::sin(*x), y);
assert_eq!(f32::cos(*x), *df_dx);
}
}
38 changes: 0 additions & 38 deletions integration/tests/examples/enum.rs

This file was deleted.

39 changes: 39 additions & 0 deletions integration/tests/examples/enum1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#[cfg(broken)]
samples::test! {
f32_i32;
enum Foo {
A(f32),
B(i32),
}

#[autodiff(d_bar, Reverse, Duplicated, Active)]
fn bar(x: &f32) -> f32 {
let val: Foo =
if *x > 0.0 {
Foo::A(*x)
} else {
Foo::B(12)
};

std::hint::black_box(&val);
match val {
Foo::A(f) => f * f,
Foo::B(_) => 4.0,
}
}

fn main() {
let x = 1.0;
let x2 = -1.0;
let mut dx = 0.0;
let mut dx2 = 0.0;
let out = bar(&x);
let dout = d_bar(&x, &mut dx, 1.0);
let dout2 = d_bar(&x2, &mut dx2, 1.0);
println!("x: {out}");
println!("dx: {dout}");
println!("dx2: {dout2}");
assert_eq!(2.0, dx);
assert_eq!(0.0, dx2);
}
}
20 changes: 11 additions & 9 deletions integration/tests/examples/foo.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#![feature(autodiff)]
samples::test! {
foo;
#[autodiff(df, Forward, Dual, Dual)]
fn f(x: &[f32; 2]) -> f32 { x[0] * x[0] + x[1] * x[0] }

#[autodiff(df, Forward, Dual, Dual)]
fn f(x: &[f32]) -> f32 { x[0] * x[0] + x[1] * x[0] }

fn main() {
let x = [2.0, 2.0];
let dx = [1.0, 0.0];
let (y, dy) = df(&x, &dx);
assert_eq!(dy, 2.0 * x[0] + x[1]);
fn main() {
let x = [2.0, 2.0];
let dx = [1.0, 0.0];
let (y, dy) = df(&x, &dx);
assert_eq!(y, 8.0);
assert_eq!(dy, 2.0 * x[0] + x[1]);
}
}
42 changes: 19 additions & 23 deletions integration/tests/examples/hessian_sin.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
#![feature(rustc_attrs)]
#![feature(autodiff)]

fn sin(x: &Vec<f32>, y: &mut f32) {
*y = x.into_iter().map(|x| f32::sin(*x)).sum()
}

#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)]
fn jac(x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32);
samples::test! {
vec;
#[autodiff(jac, ReverseFirst, Duplicated, Duplicated)]
fn sin(x: &Vec<f32>, y: &mut f32) {
*y = x.into_iter().map(|x| f32::sin(*x)).sum()
}

#[autodiff(jac, Forward, Const, Dual, Const, Const, Const)]
fn hessian(x: &Vec<f32>, y_x: &Vec<f32>, d_x: &mut Vec<f32>, y: &mut f32, y_t: &f32);
#[autodiff(hessian, Forward, Dual, Dual, Const, Const)]
fn jac2(x: &Vec<f32>, b_x: &mut Vec<f32>, y: &mut f32, b_y: &mut f32) {
jac(x, b_x, y, b_y);
}

fn main() {
let inp = vec![3.1415 / 2., 1.0, 0.5];
let mut d_inp = vec![0.0, 0.0, 0.0];
let mut y = 0.0;
let tang = vec![1.0, 0.0, 0.0];
hessian(&inp, &tang, &mut d_inp, &mut y, &1.0);
dbg!(&d_inp);
}

#[cfg(test)]
mod tests {
#[test]
fn main() {
super::main()
let inp = vec![3.1415 / 2., 1.0, 0.5];
let mut b_inp = vec![0.0, 0.0, 0.0];
let mut db_inp = vec![0.0, 0.0, 0.0];
let mut y = 0.0;
let tang = vec![0.0, 1.0, 0.0];
hessian(&inp, &tang, &mut b_inp, &mut db_inp, &mut y, &mut 1.0);
assert_eq!(inp.iter().map(|x| x.sin()).sum::<f32>(), y);
assert_eq!(inp.iter().map(|x| x.cos()).collect::<Vec<_>>(), b_inp);
assert_eq!(vec![0.0, -inp[1].sin(), 0.0], db_inp);
}
}
5 changes: 5 additions & 0 deletions integration/tests/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod array;
mod boxed;
mod enum1;
mod foo;
mod hessian_sin;
3 changes: 2 additions & 1 deletion integration/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#![feature(autodiff)]


mod examples;
8 changes: 8 additions & 0 deletions samples-loose-types/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "samples-loose-types"
version = "0.1.0"
edition = "2021"
description = "Contains versions of samples that require ENZYME_LOOSE_TYPES=1 to compile"

[dependencies]
samples = { version = "0.1.0", path = "../samples" }
3 changes: 3 additions & 0 deletions samples-loose-types/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![feature(autodiff)]

mod neohookean;
Loading