diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 2a46520818..28f7592d61 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -60,50 +60,93 @@ end
 extraChain(::Tuple{}, x) = ()
 
-
+using Base: depwarn
 """
-    Dense(in::Integer, out::Integer, σ = identity)
+    Dense(in => out, σ) = Dense(in, out, σ)
+    Dense(W::AbstractMatrix, b, σ)
 
-Creates a traditional `Dense` layer with parameters `W` and `b`.
+Creates a traditional `Dense` layer with parameters `W` and `b`,
+and by default `σ = identity`. This maps `x` to
 
     y = σ.(W * x .+ b)
 
 The input `x` must be a vector of length `in`, or a batch of vectors represented
 as an `in × N` matrix. The out `y` will be a vector or batch of length `out`.
 
+If `in` or `out` is a tuple of dimensions, then reshaping is inserted to allow input
+with `size(x) == (in..., batch...)`, and produce output `size(y) == (out..., batch...)`.
+
+Keyword `init = glorot_uniform` is the default function which generates `W` and `b`,
+and giving `bias = false` will omit the parameter `b`.
+
 ```julia
 julia> d = Dense(5, 2)
-Dense(5, 2)
+Dense(5 => 2)
 
-julia> d(rand(5))
-Tracked 2-element Array{Float64,1}:
+julia> d(rand(Float32, 5))
+2-element Array{Float32,1}:
   0.00257447
  -0.00449443
+
+julia> d2 = Dense(5 => (2,2), tanh)
+Dense(5 => (2, 2), tanh)
+
+julia> size(d2(ones(5, 3, 7)))
+(2, 2, 3, 7)
 ```
 """
-struct Dense{F,S,T}
+struct Dense{F,S,T,D}
   W::S
   b::T
   σ::F
+  shapes::D
 end
 
-Dense(W, b) = Dense(W, b, identity)
+Dense(W::AbstractMatrix, b, σ = identity) = Dense(W, b, σ, reverse(size(W)))
+
+Dense(p::Pair, σ = identity; kw...) = Dense(p.first, p.second, σ; kw...)
+
+function Dense(in::Union{Integer,Tuple}, out::Union{Integer,Tuple}, σ = identity;
+               init = glorot_uniform, bias = true,
+               initW = nothing, initb = nothing)
+
+  # depwarn as in https://github.com/FluxML/Flux.jl/pull/722
+  if initb === nothing
+    initb = init
+  else
+    depwarn("keyword argument `initb` is deprecated; use `init` or explicit `Dense(W,b)` to initialise", :Dense)
+  end
+
+  # optional bias as in https://github.com/FluxML/Flux.jl/issues/868
+  b = (bias === true) ? initb(prod(out)) : bias
+
+  if initW === nothing
+    W = init(prod(out), prod(in))
+  else
+    depwarn("keyword argument `initW` is deprecated; use `init` or explicit `Dense(W,b)` to initialise", :Dense)
+    W = initW(prod(out), prod(in))
+  end
 
-function Dense(in::Integer, out::Integer, σ = identity;
-               initW = glorot_uniform, initb = zeros)
-  return Dense(initW(out, in), initb(out), σ)
+  return Dense(W, b, σ, (in, out))
 end
 
-@functor Dense
+@functor Dense (W,b,σ)
 
 function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
-  σ.(W*x .+ b)
+  if a.shapes isa Tuple{Integer, Integer} && x isa AbstractVecOrMat
+    return σ.(W*x .+ b)
+  else
+    in, out = a.shapes
+    xin = reshape(x, prod(ntuple(d -> size(x,d), length(in))), :)
+    y = σ.(W*xin .+ b)
+    return reshape(y, out..., ntuple(d -> size(x,length(in)+d), ndims(x)-length(in))...)
+  end
 end
 
 function Base.show(io::IO, l::Dense)
-  print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
+  print(io, "Dense(", l.shapes[1], " => ", l.shapes[2])
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 0ff1776db8..da03304a17 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -38,6 +38,14 @@ import Flux: activations
     @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
     @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
 
+    @test size(Dense(10 => 5)(randn(10))) == (5,)
+    @test size(Dense(10 => (5,))(randn(10))) == (5,)
+    @test size(Dense(10 => (5,))(randn(10,7))) == (5,7)
+    @test size(Dense(10 => (5,3))(randn(10,7))) == (5,3,7)
+    @test size(Dense((10,7) => 5)(randn(10,7))) == (5,)
+
+    @test Dense(10, (2,2), identity, initW=ones, initb=ones)(ones(10,3)) == 11 .* ones(2,2,3)
+    @test Dense((3,3) => (2,2), sqrt, initW=ones, initb=zeros)(ones(3,3,5)) == 3 .* ones(2,2,5)
   end
 
   @testset "Diagonal" begin
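For review, here is a minimal usage sketch of the behaviour this patch enables, distilled from the docstring and tests above. It assumes a Flux build with this patch applied; the variable names (`d`, `d2`, `d3`) are illustrative only, not part of the patch:

```julia
using Flux

# Integer sizes: the new Pair syntax builds the same layer as the positional form.
d = Dense(10 => 5, relu)           # equivalent to Dense(10, 5, relu)
size(d(randn(Float32, 10, 32)))    # (5, 32): vector/matrix input takes the fast path

# Tuple sizes: W is prod(out) × prod(in), with reshapes inserted around the multiply.
d2 = Dense((3, 3) => (2, 2))       # W is 4×9 internally
size(d2(ones(Float32, 3, 3, 5)))   # (2, 2, 5): trailing dimensions are the batch

# bias = false stores `false` in place of the bias vector `b`.
d3 = Dense(10 => 5, tanh; bias = false)
```

Storing `(in, out)` in the new `shapes` field, and listing only `(W,b,σ)` in `@functor`, keeps plain vector and matrix inputs on the original `σ.(W*x .+ b)` path; the reshaping branch runs only when tuple shapes or higher-dimensional inputs are involved.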