Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 209 additions & 72 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,52 @@ defmodule Axon do
@doc """
Trainable Axon parameter used to create custom layers.

Parameters are specified in usages of `Axon.layer` and will be
automatically initialized and used in subsequent applications of
Axon models.

You must specify a parameter "template" which can be a static template
tensor or a function which takes model input templates and returns a
template. It's most common to use functions because most parameters'
shapes rely on input shape information.

## Options

  * `:initializer` - initializer for the parameter. Defaults to
    `:glorot_uniform`.

  * `:kind` - kind recorded on the resulting parameter. Defaults to
    `:parameter`.

Options are checked with `Keyword.validate!/2`, so any other option
key raises.
"""
@doc type: :special
def parameter(name, template, opts \\ [])

def parameter(name, %Nx.Tensor{} = template, opts) do
  # Unknown option keys raise here; missing ones receive their defaults.
  checked = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
  init = checked |> Keyword.get(:initializer) |> validate_initializer!()

  # Normalize the given tensor into a template before storing it.
  normalized = Nx.to_template(template)

  %Axon.Parameter{
    name: name,
    template: normalized,
    initializer: init,
    # An explicit `kind: nil` still falls back to :parameter.
    kind: Keyword.get(checked, :kind) || :parameter,
    # Legacy
    type: Nx.type(normalized),
    shape: Nx.shape(normalized)
  }
end

def parameter(name, function, opts) when is_function(function) do
  # Unknown option keys raise here; missing ones receive their defaults.
  checked = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
  init = checked |> Keyword.get(:initializer) |> validate_initializer!()

  # The template function is stored as-is and is not called in this clause,
  # so the legacy :type and :shape fields are not populated here.
  %Axon.Parameter{
    name: name,
    template: function,
    initializer: init,
    # An explicit `kind: nil` still falls back to :parameter.
    kind: Keyword.get(checked, :kind) || :parameter
  }
end

@doc """
Trainable Axon parameter used to create custom layers.

Parameters are specified in usages of `Axon.layer` and will
be automatically initialized and used in subsequent applications
of Axon models.
Expand All @@ -421,36 +467,35 @@ defmodule Axon do
@doc type: :special
def param(name, shape, opts \\ [])

# NOTE(review): this span appears to interleave lines from two versions of
# `param/3` (a diff view without +/- markers): a `{:map, ...}` composite clause
# and a tuple-shape clause. Comments below describe the lines as written.
def param(name, {:map, [_ | _] = inner_params}, opts) do
maybe_warn_on_param_opts(opts)
def param(name, shape, opts) when is_binary(name) and is_tuple(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
# :type is removed from opts here because parameter/3 only accepts
# :initializer and :kind and would raise on any other key.
{type, opts} = Keyword.pop(opts, :type, {:f, 32})

%Axon.Parameter{
name: name,
type: :map,
children: inner_params
}
# Static tuple shape: build a template of the requested type and
# delegate to parameter/3 with the remaining options.
template = Nx.template(shape, type)
parameter(name, template, opts)
end

# NOTE(review): this span appears to interleave lines from two versions of the
# function-shape clause of `param/3` (a diff view without +/- markers).
# Comments below describe the lines as written.
def param(name, shape, opts) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
def param(name, shape, opts) when is_binary(name) and is_function(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
initializer = validate_initializer!(opts[:initializer])
type = opts[:type] || {:f, 32}
kind = opts[:kind] || :parameter
# :type is removed from opts here because parameter/3 only accepts
# :initializer and :kind and would raise on any other key.
{type, opts} = Keyword.pop(opts, :type, {:f, 32})

%Axon.Parameter{
name: name,
shape: shape,
type: type,
initializer: initializer,
kind: kind
}
# Look up the arity of the caller's shape function so that a wrapper
# function of the same arity can be obtained from shape_fun/2.
{:arity, arity} = Function.info(shape, :arity)

template =
shape_fun(arity, fn templates ->
# Map every received template to its shape, call the caller's shape
# function with those shapes, and wrap the resulting shape in a template.
shapes = Enum.map(List.wrap(templates), &Nx.shape/1)
out_shape = apply(shape, shapes)
Nx.template(out_shape, type)
end)

# Delegate to parameter/3 with the template-returning wrapper.
parameter(name, template, opts)
end

defp maybe_warn_on_param_opts(opts) do
if :initializer in opts or :type in opts do
Logger.warning(
"Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
)
# Generates one private `shape_fun/2` clause for every arity from 0 through 128.
# Each clause returns an anonymous function of exactly that arity which gathers
# its arguments into a list and hands that list to the given one-argument function.
for arity <- 0..128 do
  vars = Macro.generate_arguments(arity, __MODULE__)

  defp shape_fun(unquote(arity), collect) do
    fn unquote_splicing(vars) -> collect.(unquote(vars)) end
  end
end

Expand Down Expand Up @@ -2583,25 +2628,63 @@ defmodule Axon do
activation = opts[:activation]
gate = opts[:gate]
unroll = opts[:unroll]

kernel_initializer = opts[:kernel_initializer]

input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end
input_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

wii = param("wii", input_kernel_shape, initializer: kernel_initializer)
wif = param("wif", input_kernel_shape, initializer: kernel_initializer)
wig = param("wig", input_kernel_shape, initializer: kernel_initializer)
wio = param("wio", input_kernel_shape, initializer: kernel_initializer)
hidden_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

bias_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

initializer = fn prefix, init ->
fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 4)

init =
if is_atom(init) do
apply(Axon.Initializers, init, [])
else
init
end

whi = param("whi", hidden_kernel_shape, initializer: kernel_initializer)
whf = param("whf", hidden_kernel_shape, initializer: kernel_initializer)
whg = param("whg", hidden_kernel_shape, initializer: kernel_initializer)
who = param("who", hidden_kernel_shape, initializer: kernel_initializer)
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"#{prefix}i" => fun.(split_key[0]),
"#{prefix}f" => fun.(split_key[1]),
"#{prefix}g" => fun.(split_key[2]),
"#{prefix}o" => fun.(split_key[3])
}
end
end

# Parameters
input_kernel = param("input_kernel", {:map, [wii, wif, wig, wio]})
hidden_kernel = param("hidden_kernel", {:map, [whi, whf, whg, who]})
input_kernel =
parameter("input_kernel", input_kernel_template,
initializer: initializer.("wi", kernel_initializer)
)

hidden_kernel =
parameter("hidden_kernel", hidden_kernel_template,
initializer: initializer.("wh", kernel_initializer)
)

hidden_state_name =
case opts[:name] do
Expand All @@ -2620,12 +2703,7 @@ defmodule Axon do
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]

bi = param("bi", bias_shape, initializer: bias_initializer)
bf = param("bf", bias_shape, initializer: bias_initializer)
bg = param("bg", bias_shape, initializer: bias_initializer)
bo = param("bo", bias_shape, initializer: bias_initializer)

bias = param("bias", {:map, [bi, bf, bg, bo]})
bias = parameter("bias", bias_template, initializer: initializer.("b", bias_initializer))

{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm}
else
Expand Down Expand Up @@ -2790,22 +2868,58 @@ defmodule Axon do
gate = opts[:gate]
unroll = opts[:unroll]

input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :gru) end
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :gru) end
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :gru) end
input_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

hidden_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

kernel_initializer = opts[:kernel_initializer]
bias_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

wir = param("wir", input_kernel_shape, initializer: kernel_initializer)
wiz = param("wiz", input_kernel_shape, initializer: kernel_initializer)
win = param("win", input_kernel_shape, initializer: kernel_initializer)
initializer = fn prefix, init ->
fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 3)

whr = param("whr", hidden_kernel_shape, initializer: kernel_initializer)
whz = param("whz", hidden_kernel_shape, initializer: kernel_initializer)
whn = param("whn", hidden_kernel_shape, initializer: kernel_initializer)
init =
if is_atom(init) do
apply(Axon.Initializers, init, [])
else
init
end

input_kernel = param("input_kernel", {:map, [wir, wiz, win]})
hidden_kernel = param("hidden_kernel", {:map, [whr, whz, whn]})
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"#{prefix}r" => fun.(split_key[0]),
"#{prefix}z" => fun.(split_key[1]),
"#{prefix}n" => fun.(split_key[2])
}
end
end

input_kernel =
parameter("input_kernel", input_kernel_template,
initializer: initializer.("wi", opts[:kernel_initializer])
)

hidden_kernel =
parameter("hidden_kernel", hidden_kernel_template,
initializer: initializer.("wh", opts[:kernel_initializer])
)

hidden_state_name =
case opts[:name] do
Expand All @@ -2822,14 +2936,34 @@ defmodule Axon do

inputs =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
bias_initializer = fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 4)

init =
if is_atom(opts[:bias_initializer]) do
apply(Axon.Initializers, opts[:bias_initializer], [])
else
opts[:bias_initializer]
end

br = param("br", bias_shape, initializer: bias_initializer)
bz = param("bz", bias_shape, initializer: bias_initializer)
bin = param("bin", bias_shape, initializer: bias_initializer)
bhn = param("bhn", bias_shape, initializer: bias_initializer)
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"br" => fun.(split_key[0]),
"bz" => fun.(split_key[1]),
"bin" => fun.(split_key[2]),
"bhn" => fun.(split_key[3])
}
end

bias = param("bias", {:map, [br, bz, bin, bhn]})
bias = parameter("bias", bias_template, initializer: bias_initializer)

[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias]
else
Expand Down Expand Up @@ -2983,23 +3117,26 @@ defmodule Axon do
unroll = opts[:unroll]
kernel_initializer = opts[:kernel_initializer]

hidden_kernel_shape = fn _, {inp, _}, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
hidden_kernel_template = fn _, {inp, _}, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

input_kernel_shape = fn inp, _, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
input_kernel_template = fn inp, _, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

bias_shape = fn inp, _, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
bias_template = fn inp, _, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

wi = param("input_kernel", input_kernel_shape, initializer: kernel_initializer)
wh = param("hidden_kernel", hidden_kernel_shape, initializer: kernel_initializer)
wi = parameter("input_kernel", input_kernel_template, initializer: kernel_initializer)
wh = parameter("hidden_kernel", hidden_kernel_template, initializer: kernel_initializer)

hidden_state_name =
case opts[:name] do
Expand All @@ -3017,7 +3154,7 @@ defmodule Axon do
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b = param("bias", bias_shape, initializer: bias_initializer)
b = parameter("bias", bias_template, initializer: bias_initializer)
{[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm}
else
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm}
Expand Down
Loading