Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## v0.12.0

* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405)
* Excise datasets in favour of other providers in the Julia ecosystem.
* other new features and bug fixes (see GitHub's releases page)
* Added an option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to exclude `bias` from being trained.
Expand Down
19 changes: 14 additions & 5 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ extraChain(::Tuple{}, x) = ()
"""
Dense(in::Integer, out::Integer, σ = identity; bias=true)

Create a traditional `Dense` layer with parameters `W` and `b`.
Create a traditional `Dense` layer with an `out × in` weight matrix `W` and
a bias vector `b` of length `out`. The forward pass is given by:

y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The output `y` will be a vector of length `out`, or a batch of such vectors.
The input `x` must be a vector of length `in`, a batch of vectors represented
as an `in × N` matrix, or a higher order tensor where all dimensions
after the first one will be treated as batch dimensions.

Setting `bias` to `false` will switch bias off for the layer.
The output `y` will be a vector of length `out`, or a batch whose first
dimension is `out` and whose remaining dimensions are the same as in the input.

Setting `bias` to `false` will switch the bias off for the layer.

# Example
```
Expand Down Expand Up @@ -125,7 +130,11 @@ end

function (a::Dense)(x::AbstractArray)
    # Forward pass `σ.(W * x .+ b)`; every dimension of `x` after the first is
    # treated as a batch dimension and is preserved in the output.
    W, b, σ = a.W, a.b, a.σ
    sz = size(x)
    # Collapse all trailing (batch) dimensions into one so that a single
    # matrix product handles vectors, matrices and higher-rank inputs alike.
    xmat = reshape(x, sz[1], :)
    y = σ.(W * xmat .+ b)
    # Restore the batch dimensions; `:` lets the first dimension take the
    # layer's output size. For a vector input `sz[2:end]` is empty and the
    # result is a vector again.
    return reshape(y, :, sz[2:end]...)
end

function Base.show(io::IO, l::Dense)
Expand Down
28 changes: 17 additions & 11 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,23 @@ import Flux: activations
@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
end

# Output length and input validation for a freshly built 10 → 5 layer.
fresh_dense() = Dense(10, 5)
@test length(fresh_dense()(randn(10))) == 5
@test_throws DimensionMismatch fresh_dense()(randn(1))
# Neither a scalar argument nor element-wise application may be silently broadcast.
@test_throws MethodError fresh_dense()(1)
@test_throws MethodError fresh_dense().(randn(10))

# With all-ones weights, each output equals the sum of the ten inputs of its column.
all_ones_dense(out; kw...) = Dense(10, out, identity; initW = ones, kw...)
two_column_batch = [ones(10,1) 2*ones(10,1)]
@test all_ones_dense(1; initb = zeros)(ones(10,1)) == 10*ones(1, 1)
@test all_ones_dense(1; initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test all_ones_dense(2; initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test all_ones_dense(2; initb = zeros)(two_column_batch) == [10 20; 10 20]
@test all_ones_dense(2; bias = false)(two_column_batch) == [10 20; 10 20]
@testset "dimensions" begin
    # A vector of length `in` maps to a vector of length `out`.
    @test length(Dense(10, 5)(randn(10))) == 5
    # An input whose first dimension does not match `in` is rejected.
    @test_throws DimensionMismatch Dense(10, 5)(randn(1))
    # Avoid broadcasting: scalars and element-wise application are method errors.
    @test_throws MethodError Dense(10, 5)(1)
    @test_throws MethodError Dense(10, 5).(randn(10))
    # Trailing dimensions are batch dimensions and must come through unchanged.
    for batch_dims in ((), (2,), (2, 3), (2, 3, 4))
        @test size(Dense(10, 5)(randn(10, batch_dims...))) == (5, batch_dims...)
    end
end
@testset "zeros" begin
    # Each case is (output size, input, expected output) for a layer with
    # all-ones weights and a zero bias.
    zero_bias_cases = (
        (1, ones(10,1), 10*ones(1, 1)),
        (1, ones(10,2), 10*ones(1, 2)),
        (2, ones(10,1), 10*ones(2, 1)),
        (2, [ones(10,1) 2*ones(10,1)], [10 20; 10 20]),
    )
    for (n_out, input, expected) in zero_bias_cases
        @test Dense(10, n_out, identity, initW=ones, initb=zeros)(input) == expected
    end
    # Switching the bias off gives the same result as a zero bias.
    @test Dense(10, 2, identity, initW=ones, bias=false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
end
end

@testset "Diagonal" begin
Expand Down