Skip to content

Commit 5023c9d

Browse files
Merge pull request #19 from soedinglab/dev
change tests
2 parents eefc4ec + 82a3c53 commit 5023c9d

2 files changed

Lines changed: 10 additions & 9 deletions

File tree

benchmark_test/felsenstein.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def __init__(self, tree):
4646

4747
def log_likelihood(self, S, sqrt_pi):
4848
if S.shape[0] == 1:
49-
S = S.expand(300, -1, -1)
50-
sqrt_pi = sqrt_pi.expand(300, -1)
49+
# Quick hack to get single side to work for the testing, ideally it should know what L is here
50+
S = S.expand(100, -1, -1)
51+
sqrt_pi = sqrt_pi.expand(100, -1)
5152
matrices = rate_matrix_from_S(S, sqrt_pi)
5253
self.root.compute(matrices)
5354

@@ -69,13 +70,13 @@ def gradients(tree, S, sqrt_pi):
6970
S.requires_grad = True
7071
sqrt_pi.requires_grad = True
7172
if S.shape[0] == 1:
72-
S_expand = S.expand(300, -1, -1)
73-
sqrt_pi_expand = sqrt_pi.expand(300, -1)
73+
S_expand = S.expand(100, -1, -1)
74+
sqrt_pi_expand = sqrt_pi.expand(100, -1)
7475
else:
7576
S_expand = S
7677
sqrt_pi_expand = sqrt_pi
7778
logP = tree.log_likelihood(S_expand, sqrt_pi_expand)
7879

7980
logP.sum().backward()
8081

81-
return S.grad, sqrt_pi.grad
82+
return S.grad, sqrt_pi.grad

benchmark_test/test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def gen_tree(t_dtype, dim):
1414
torch.manual_seed(0)
1515
newick = read_newick('test_tree.newick')
16-
L = 300
16+
L = 100
1717
leaf_log_p = torch.randn([L, newick['num_leaf'], dim], dtype=t_dtype)
1818
leaf_log_p = torch.nn.functional.log_softmax(leaf_log_p, dim=2)
1919

@@ -23,7 +23,7 @@ def gen_tree(t_dtype, dim):
2323

2424
def gen_data(t_dtype, dim, seed):
2525
torch.manual_seed(seed)
26-
L = 300
26+
L = 100
2727
S = torch.exp(torch.randn(L,dim, dim, dtype=t_dtype))
2828
sqrt_pi = torch.sqrt(torch.nn.functional.softmax(torch.randn(L,dim, dtype=t_dtype), dim = 1))
2929
return S, sqrt_pi
@@ -108,7 +108,7 @@ def test_grads():
108108
helper_test("f64", 20, True)
109109

110110
def test_grads_single_model():
111+
helper_test("f64", 4, True, single_model=True)
112+
helper_test("f64", 20, True, single_model=True)
111113
helper_test("f32", 4, True, single_model=True)
112114
helper_test("f32", 20, True, single_model=True)
113-
helper_test("f64", 4, True, single_model=True)
114-
helper_test("f64", 20, True, single_model=True)

0 commit comments

Comments
 (0)