Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -837,15 +837,12 @@ defmodule Axon.Compiler do
# parameter map, so we just need to extract them and then apply
# freezing and dtype policy
parameter_inputs =
Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} ->
Enum.map(layer_params, fn %{name: v, frozen: frz} ->
param = params[name][v]

cond do
param != nil and should_cast?(type, compute) ->
safe_as_type(maybe_freeze(param, frz), compute)

param != nil ->
maybe_freeze(param, frz)
safe_as_type(maybe_freeze(param, frz), compute)

true ->
raise ArgumentError,
Expand Down Expand Up @@ -936,8 +933,11 @@ defmodule Axon.Compiler do
out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
%{stateful | output: out}

out ->
%Nx.Tensor{} = out ->
Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})

out ->
out
end
rescue
exception ->
Expand Down Expand Up @@ -1082,17 +1082,23 @@ defmodule Axon.Compiler do
none

%Nx.Tensor{} = tensor ->
Nx.as_type(tensor, type)
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end

container ->
deep_new(container, &Nx.as_type(&1, type))
deep_new(container, fn tensor ->
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end
end)
end
end

defp should_cast?(type1, type2) do
not Nx.Type.integer?(type1) and not Nx.Type.integer?(type2)
end

defp safe_shape(container_or_tensor) do
case container_or_tensor do
%Axon.None{} = none ->
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ defmodule Axon.Layers do
@doc type: :linear
defn embedding(input, kernel, _opts \\ []) do
assert_rank!("Axon.Layers.embedding", "kernel", kernel, 2)
Nx.take(kernel, Nx.as_type(input, {:s, 64}), axis: 0)
Nx.take(kernel, input, axis: 0)
end

## Shape
Expand Down
4 changes: 2 additions & 2 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ defmodule CompilerTest do
test "initializes in default case" do
model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding")

input = random({1, 1})
input = random({1, 1}) |> Nx.as_type(:s64)

assert {init_fn, _predict_fn} = Axon.build(model)
assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
Expand All @@ -615,7 +615,7 @@ defmodule CompilerTest do
Axon.input("input_0", shape: {nil, 1})
|> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros)

input = random({1, 1})
input = random({1, 1}) |> Nx.as_type(:s64)

assert {init_fn, _predict_fn} = Axon.build(model1)
assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
Expand Down