diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ad37464351..4a8effa19b 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -75,7 +75,27 @@ treelike(Dense) function (a::Dense)(x) W, b, σ = a.W, a.b, a.σ - @fix σ.(W*x .+ b) + + in_size = size(x) + + if length(in_size) <= 2 + return σ.(W*x .+ b) + else + extra_dims = in_size[2:end] + out_feature_size = size(b)[1] + + out_size = (out_feature_size, extra_dims...) + + A = zeros(out_size) + + dim_range = Iterators.map(x -> 1:x, extra_dims) + + for d in Iterators.product(dim_range...) + A[:, d...] = W * x[:, d...] .+ b + end + + return σ.(A) + end end function Base.show(io::IO, l::Dense)