1313def gen_tree (t_dtype , dim ):
1414 torch .manual_seed (0 )
1515 newick = read_newick ('test_tree.newick' )
16- L = 300
16+ L = 100
1717 leaf_log_p = torch .randn ([L , newick ['num_leaf' ], dim ], dtype = t_dtype )
1818 leaf_log_p = torch .nn .functional .log_softmax (leaf_log_p , dim = 2 )
1919
@@ -23,7 +23,7 @@ def gen_tree(t_dtype, dim):
2323
2424def gen_data (t_dtype , dim , seed ):
2525 torch .manual_seed (seed )
26- L = 300
26+ L = 100
2727 S = torch .exp (torch .randn (L ,dim , dim , dtype = t_dtype ))
2828 sqrt_pi = torch .sqrt (torch .nn .functional .softmax (torch .randn (L ,dim , dtype = t_dtype ), dim = 1 ))
2929 return S , sqrt_pi
@@ -108,7 +108,7 @@ def test_grads():
108108 helper_test ("f64" , 20 , True )
109109
110110def test_grads_single_model ():
111+ helper_test ("f64" , 4 , True , single_model = True )
112+ helper_test ("f64" , 20 , True , single_model = True )
111113 helper_test ("f32" , 4 , True , single_model = True )
112114 helper_test ("f32" , 20 , True , single_model = True )
113- helper_test ("f64" , 4 , True , single_model = True )
114- helper_test ("f64" , 20 , True , single_model = True )
0 commit comments