diff --git a/NEWS.md b/NEWS.md
index 69211a9ea4..4768abea2b 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,7 @@
 
 ## v0.12.0
 
+* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
 * Excise datasets in favour of other providers in the julia ecosystem.
 * other new features and bug fixes (see GitHub's releases page)
 * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained.
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 1aa81895e3..300cece909 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -85,14 +85,19 @@ extraChain(::Tuple{}, x) = ()
 """
     Dense(in::Integer, out::Integer, σ = identity; bias=true)
 
-Create a traditional `Dense` layer with parameters `W` and `b`.
+Create a traditional `Dense` layer with out×in weight matrix `W` and
+bias vector `b` of length `out`. The forward pass is given by:
 
     y = σ.(W * x .+ b)
 
-The input `x` must be a vector of length `in`, or a batch of vectors represented
-as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
+The input `x` must be a vector of length `in`, a batch of vectors represented
+as an `in × N` matrix, or a higher-order tensor where all dimensions
+after the first are treated as batch dimensions.
 
-Setting `bias` to `false` will switch bias off for the layer.
+The output `y` will be a vector of length `out`, or a batch whose first
+dimension is `out` and whose remaining dimensions match those of the input.
+
+Setting `bias` to `false` will switch the bias off for the layer.
 
 # Example
 ```
@@ -125,7 +130,11 @@ end
 
 function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
-  σ.(W*x .+ b)
+  # reshape to treat all dimensions after the first as batch dimensions
+  sz = size(x)
+  x = reshape(x, sz[1], :)
+  x = σ.(W*x .+ b)
+  return reshape(x, :, sz[2:end]...)
 end
 
 function Base.show(io::IO, l::Dense)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 6b082f7454..40afee5668 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -35,17 +35,23 @@ import Flux: activations
       @test_throws MethodError Dense(10, 10.5)
       @test_throws MethodError Dense(10, 10.5, tanh)
     end
-
-    @test length(Dense(10, 5)(randn(10))) == 5
-    @test_throws DimensionMismatch Dense(10, 5)(randn(1))
-    @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
-    @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
-
-    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
-    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
-    @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
-    @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
-    @test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+    @testset "dimensions" begin
+      @test length(Dense(10, 5)(randn(10))) == 5
+      @test_throws DimensionMismatch Dense(10, 5)(randn(1))
+      @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
+      @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
+      @test size(Dense(10, 5)(randn(10))) == (5,)
+      @test size(Dense(10, 5)(randn(10,2))) == (5,2)
+      @test size(Dense(10, 5)(randn(10,2,3))) == (5,2,3)
+      @test size(Dense(10, 5)(randn(10,2,3,4))) == (5,2,3,4)
+    end
+    @testset "zeros" begin
+      @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
+      @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
+      @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
+      @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+      @test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+    end
   end
 
   @testset "Diagonal" begin
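
For reviewers, a minimal sketch of the behaviour this patch enables; it assumes a Flux build that includes the change, and the layer/input sizes here are arbitrary:

```julia
using Flux

d = Dense(10, 5)              # weight W is 5×10, bias b has length 5

x = rand(Float32, 10, 2, 3)   # length-10 features with a 2×3 batch
y = d(x)
@assert size(y) == (5, 2, 3)  # first dim is `out`; batch dims are preserved

# The higher-order call is equivalent to flattening the batch dimensions,
# applying the layer to the resulting 10×6 matrix, and reshaping back:
y2 = reshape(d(reshape(x, 10, :)), 5, 2, 3)
@assert y ≈ y2
```

Since `reshape` on a dense array shares memory rather than copying and is differentiable under Zygote, the extra handling should add essentially no overhead to the existing vector and matrix code paths.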