Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## v0.12.0

* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405)
* Excise datasets in favour of other providers in the julia ecosystem.
* other new features and bug fixes (see GitHub's releases page)
* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminate `bias` from being trained.
Expand Down
19 changes: 14 additions & 5 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ extraChain(::Tuple{}, x) = ()
"""
Dense(in::Integer, out::Integer, σ = identity; bias=true)

Create a traditional `Dense` layer with parameters `W` and `b`.
Create a traditional `Dense` layer with `out × in` weight matrix `W` and
bias vector `b` of length `out`. The forward pass is given by:

y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
The input `x` must be a vector of length `in`, a batch of vectors represented
as an `in × N` matrix, or a higher order tensor where all dimensions
after the first one will be treated as batch dimensions.

Setting `bias` to `false` will switch bias off for the layer.
The output `y` will be a vector of length `out`, or a batch whose first
dimension is `out` and whose remaining dimensions match those of the input.

Setting `bias` to `false` will switch the bias off for the layer.

# Example
```
Expand Down Expand Up @@ -125,7 +130,11 @@ end

function (a::Dense)(x::AbstractArray)
    W, b, σ = a.W, a.b, a.σ
    # Collapse every dimension after the first into a single batch
    # dimension, so the matrix multiply applies to arbitrary-rank input.
    dims = size(x)
    flat = reshape(x, dims[1], :)
    out = σ.(W * flat .+ b)
    # Restore the trailing batch dimensions; the first dimension of the
    # result is the layer's output size.
    return reshape(out, :, Base.tail(dims)...)
end

function Base.show(io::IO, l::Dense)
Expand Down
28 changes: 17 additions & 11 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,23 @@ import Flux: activations
@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
end

@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting

@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
@test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
@testset "dimensions" begin
    # Invalid calls: wrong input length, and broadcasting over a layer
    # (scalar application) must not be allowed.
    @test_throws DimensionMismatch Dense(10, 5)(randn(1))
    @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
    @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
    # A vector input maps to a vector output of the layer's `out` size.
    @test length(Dense(10, 5)(randn(10))) == 5
    @test size(Dense(10, 5)(randn(10))) == (5,)
    # All dimensions after the first are treated as batch dimensions and
    # are preserved in the output.
    for batch in ((2,), (2, 3), (2, 3, 4))
        @test size(Dense(10, 5)(randn(10, batch...))) == (5, batch...)
    end
end
@testset "zeros" begin
    # With all-ones weights and zero bias, the layer computes column sums
    # of the input, replicated across the `out` rows.
    sumlayer(out) = Dense(10, out, identity, initW = ones, initb = zeros)
    @test sumlayer(1)(ones(10, 1)) == 10 * ones(1, 1)
    @test sumlayer(1)(ones(10, 2)) == 10 * ones(1, 2)
    @test sumlayer(2)(ones(10, 1)) == 10 * ones(2, 1)
    twocols = [ones(10, 1) 2 * ones(10, 1)]
    @test sumlayer(2)(twocols) == [10 20; 10 20]
    # Disabling the bias entirely must give the same result as a zero bias.
    @test Dense(10, 2, identity, initW = ones, bias = false)(twocols) == [10 20; 10 20]
end
end

@testset "Diagonal" begin
Expand Down