From 62395b76c3f7012a4c63118849254d923026bc3b Mon Sep 17 00:00:00 2001 From: Fangyuanziti Date: Tue, 12 Jun 2018 18:02:18 +0800 Subject: [PATCH] Add additional dimensions support in Dense layer. By supporting additional dimensions, the interface of Dense layer would be like this. * Input Shape: (in_features, d1, d2, d3, ..., N) * Ouput Shape: (out_features, d2, d2, d3,..., N) --- src/layers/basic.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ad37464351..4a8effa19b 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -75,7 +75,27 @@ treelike(Dense) function (a::Dense)(x) W, b, σ = a.W, a.b, a.σ - @fix σ.(W*x .+ b) + + in_size = size(x) + + if length(in_size) <= 2 + return σ.(W*x .+ b) + else + extra_dims = in_size[2:end] + out_feature_size = size(b)[1] + + out_size = (out_feature_size, extra_dims...) + + A = zeros(out_size) + + dim_range = Iterators.map(x -> 1:x, extra_dims) + + for d in Iterators.product(dim_range...) + A[:, d...] = W * x[:, d...] .+ b + end + + return σ.(A) + end end function Base.show(io::IO, l::Dense)