Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,50 +60,93 @@ end

extraChain(::Tuple{}, x) = ()


using Base: depwarn

"""
Dense(in::Integer, out::Integer, σ = identity)
Dense(in => out, σ) = Dense(in, out, σ)
Dense(W::AbstractMatrix, b, σ)

Creates a traditional `Dense` layer with parameters `W` and `b`.
Creates a traditional `Dense` layer with parameters `W` and `b`,
and by default `σ = identity`. This maps `x` to

y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.

If `in` or `out` is a tuple of dimensions, then reshaping is inserted to allow input
with `size(x) == (in..., batch...)`, and produce output `size(y) == (out..., batch...)`.

Keyword `init = glorot_uniform` is the default function which generates `W` and `b`,
and giving `bias = false` will omit the parameter `b`.

```julia
julia> d = Dense(5, 2)
Dense(5, 2)
Dense(5 => 2)

julia> d(rand(5))
Tracked 2-element Array{Float64,1}:
julia> d(rand(Float32, 5))
2-element Array{Float32,1}:
0.00257447
-0.00449443

julia> d2 = Dense(5 => (2,2), tanh)
Dense(5 => (2, 2), tanh)

julia> size(d2(ones(5, 3, 7)))
(2, 2, 3, 7)
```
"""
struct Dense{F,S,T}
struct Dense{F,S,T,D}
W::S
b::T
σ::F
shapes::D
end

Dense(W, b) = Dense(W, b, identity)
Dense(W::AbstractMatrix, b, σ = identity) = Dense(W, b, σ, reverse(size(W)))

Dense(p::Pair, σ = identity; kw...) = Dense(p.first, p.second, σ; kw...)

function Dense(in::Union{Integer,Tuple}, out::Union{Integer,Tuple}, σ = identity;
init = glorot_uniform, bias = true,
initW = nothing, initb = nothing)

# depwarn as in https://github.com/FluxML/Flux.jl/pull/722
if initb === nothing
initb = init
else
depwarn("keyword argument `initb` is deprecated; use `init` or explicit `Dense(W,b)` to initialise", :Dense)
end

# optional bias as in https://github.com/FluxML/Flux.jl/issues/868
b = (bias === true) ? initb(prod(out)) : bias

if initW === nothing
W = init(prod(out), prod(in))
else
depwarn("keyword argument `initW` is deprecated; use `init` or explicit `Dense(W,b)` to initialise", :Dense)
W = initW(prod(out), prod(in))
end

function Dense(in::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros)
return Dense(initW(out, in), initb(out), σ)
return Dense(W, b, σ, (in, out))
end

@functor Dense
@functor Dense (W,b,σ)

function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
if a.shapes isa Tuple{Integer, Integer} && x isa AbstractVecOrMat
return σ.(W*x .+ b)
else
in, out = a.shapes
xin = reshape(x, prod(ntuple(d -> size(x,d), length(in))), :)
y = σ.(W*xin .+ b)
return reshape(y, out..., ntuple(d -> size(x,length(in)+d), ndims(x)-length(in))...)
end
end

function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
print(io, "Dense(", l.shapes[1], " => ", l.shapes[2])
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
Expand Down
8 changes: 8 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ import Flux: activations
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]

@test size(Dense(10 => 5)(randn(10))) == (5,)
@test size(Dense(10 => (5,))(randn(10))) == (5,)
@test size(Dense(10 => (5,))(randn(10,7))) == (5,7)
@test size(Dense(10 => (5,3))(randn(10,7))) == (5,3,7)
@test size(Dense((10,7) => 5)(randn(10,7))) == (5,)

@test Dense(10, (2,2), identity, initW=ones, initb=ones)(ones(10,3)) == 11 .* ones(2,2,3)
@test Dense((3,3) => (2,2), sqrt, initW=ones, initb=zeros)(ones(3,3,5)) == 3 .* ones(2,2,5)
end

@testset "Diagonal" begin
Expand Down