This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit aff3ddd

Remove use of aesara.tensor.nnet in other tests

1 parent: e757807

6 files changed: 5 additions & 455 deletions

tests/link/jax/test_elemwise.py

Lines changed: 3 additions & 4 deletions
@@ -5,10 +5,9 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import get_test_value
 from aesara.tensor import elemwise as at_elemwise
-from aesara.tensor import nnet as at_nnet
 from aesara.tensor.math import SoftmaxGrad
 from aesara.tensor.math import all as at_all
-from aesara.tensor.math import prod
+from aesara.tensor.math import log_softmax, prod, softmax
 from aesara.tensor.math import sum as at_sum
 from aesara.tensor.type import matrix, tensor, vector
 from tests.link.jax.test_basic import compare_jax_and_py

@@ -76,7 +75,7 @@ def test_jax_CAReduce():
 def test_softmax(axis):
     x = matrix("x")
     x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
-    out = at_nnet.softmax(x, axis=axis)
+    out = softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

@@ -85,7 +84,7 @@ def test_softmax(axis):
 def test_logsoftmax(axis):
     x = matrix("x")
     x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
-    out = at_nnet.logsoftmax(x, axis=axis)
+    out = log_softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
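The change above swaps the deprecated `aesara.tensor.nnet` entry points for `softmax` and `log_softmax` from `aesara.tensor.math` (note the rename from `logsoftmax` to `log_softmax`). A minimal sketch of the new call sites, assuming a working Aesara installation; the variable names below are illustrative, not taken from the commit:

```python
import numpy as np

import aesara
from aesara.tensor.math import log_softmax, softmax
from aesara.tensor.type import matrix

# Build a graph with the aesara.tensor.math versions of the ops.
x = matrix("x")
f = aesara.function([x], [softmax(x, axis=-1), log_softmax(x, axis=-1)])

xv = np.arange(6, dtype=aesara.config.floatX).reshape(2, 3)
sm, lsm = f(xv)

# Each row of the softmax output sums to one, and log_softmax
# agrees with the elementwise log of softmax.
assert np.allclose(sm.sum(axis=-1), 1.0)
assert np.allclose(np.log(sm), lsm, atol=1e-6)
```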

tests/link/jax/test_scalar.py

Lines changed: 0 additions & 5 deletions
@@ -7,7 +7,6 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import get_test_value
 from aesara.scalar.basic import Composite
-from aesara.tensor import nnet as at_nnet
 from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import (

@@ -128,10 +127,6 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = at_nnet.ultra_fast_sigmoid(x)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
     out = softplus(x)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
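With the `at_nnet.ultra_fast_sigmoid` block dropped, `test_nnet` continues to exercise the activations that live in `aesara.tensor.math`. A standalone sketch of those ops outside the test harness (illustrative only; `compare_jax_and_py` is internal to the test suite):

```python
import numpy as np

import aesara
from aesara.tensor.math import sigmoid, softplus
from aesara.tensor.type import vector

x = vector("x")
f = aesara.function([x], [sigmoid(x), softplus(x)])

xv = np.linspace(-3, 3, 7).astype(aesara.config.floatX)
sig, sp = f(xv)

# Compare against the NumPy reference formulas.
assert np.allclose(sig, 1.0 / (1.0 + np.exp(-xv)), atol=1e-6)
assert np.allclose(sp, np.log1p(np.exp(xv)), atol=1e-6)
```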

tests/scalar/test_basic.py

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ def test_grad_inrange():


 def test_grad_abs():
     a = fscalar("a")
-    b = aesara.tensor.nnet.relu(a)
+    b = 0.5 * (a + aesara.tensor.abs(a))
     c = aesara.grad(b, a)
     f = aesara.function([a], c, mode=Mode(optimizer=None))
     # Currently Aesara return 0.5, but it isn't sure it won't change
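The replacement relies on the identity relu(x) = 0.5 * (x + |x|), which drops the `aesara.tensor.nnet` import while preserving the gradient behaviour, including the 0.5 value at x = 0 that the test's comment mentions. A self-contained sketch, assuming a working Aesara installation:

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara.tensor.type import fscalar

a = fscalar("a")
# relu(a) expressed without aesara.tensor.nnet: relu(x) = 0.5 * (x + |x|)
b = 0.5 * (a + at.abs(a))
g = aesara.function([a], aesara.grad(b, a))

assert np.isclose(g(2.0), 1.0)   # slope 1 on the positive side
assert np.isclose(g(-2.0), 0.0)  # slope 0 on the negative side
assert np.isclose(g(0.0), 0.5)   # the subgradient value the test's comment notes
```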

tests/scan/test_basic.py

Lines changed: 0 additions & 32 deletions
@@ -43,7 +43,6 @@
 from aesara.tensor.math import dot, exp, mean, sigmoid
 from aesara.tensor.math import sum as at_sum
 from aesara.tensor.math import tanh
-from aesara.tensor.nnet import categorical_crossentropy
 from aesara.tensor.random import normal
 from aesara.tensor.random.utils import RandomStream
 from aesara.tensor.shape import Shape_i, reshape, specify_shape

@@ -58,7 +57,6 @@
     fscalar,
     ftensor3,
     fvector,
-    imatrix,
     iscalar,
     ivector,
     lscalar,

@@ -3810,36 +3808,6 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):

         # TODO FIXME: What is this testing? At least assert something.

-    def test_grad_two_scans(self):
-
-        # data input & output
-        x = tensor3("x")
-        t = imatrix("t")
-
-        # forward pass
-        W = shared(
-            np.random.default_rng(utt.fetch_seed()).random((2, 2)).astype("float32"),
-            name="W",
-            borrow=True,
-        )
-
-        def forward_scanner(x_t):
-            a2_t = dot(x_t, W)
-            y_t = softmax_graph(a2_t)
-            return y_t
-
-        y, _ = scan(fn=forward_scanner, sequences=x, outputs_info=[None])
-
-        # loss function
-        def error_scanner(y_t, t_t):
-            return mean(categorical_crossentropy(y_t, t_t))
-
-        L, _ = scan(fn=error_scanner, sequences=[y, t], outputs_info=[None])
-        L = mean(L)
-
-        # backward pass
-        grad(L, [W])
-
     def _grad_mout_helper(self, n_iters, mode):
         rng = np.random.default_rng(utt.fetch_seed())
         n_hid = 3
