lib/axon.ex (61 additions, 0 deletions)
@@ -214,6 +214,67 @@ defmodule Axon do

See `Axon.Updates` and `Axon.Loop` for a more in-depth treatment of
model optimization and model training.

## Using with `Nx.Serving`

When deploying an `Axon` model to production, you usually want to batch
multiple prediction requests and run the inference for all of them at
once. Conveniently, `Nx` already has an abstraction for this task in the
form of `Nx.Serving`. Here's how you could define a serving for an `Axon`
model:

    def build_serving() do
      # Configuration
      batch_size = 4
      defn_options = [compiler: EXLA]

      Nx.Serving.new(
        # This function runs on the serving startup
        fn ->
          # Build the Axon model and load params (usually from file)
          model = build_model()
          params = load_params()

          # Build the prediction defn function
          {_init_fun, predict_fun} = Axon.build(model)

          inputs_template = %{"pixel_values" => Nx.template({batch_size, 224, 224, 3}, :f32)}
          template_args = [Nx.to_template(params), inputs_template]

          # Compile the prediction function upfront for the configured batch_size
          predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)

          # The returned function is called for every accumulated batch
          fn inputs ->
            inputs = Nx.Batch.pad(inputs, batch_size - inputs.size)
            predict_fun.(params, inputs)
          end
        end,
        batch_size: batch_size
      )
    end
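
The example above assumes `build_model/0` and `load_params/0` helpers. As a
minimal sketch (the architecture and the `"params.nx"` path are purely
illustrative), they could look like:

    defp build_model() do
      Axon.input("pixel_values", shape: {nil, 224, 224, 3})
      |> Axon.conv(8, kernel_size: 3, activation: :relu)
      |> Axon.flatten()
      |> Axon.dense(10, activation: :softmax)
    end

    defp load_params() do
      # For example, parameters previously saved with Nx.serialize/2
      "params.nx" |> File.read!() |> Nx.deserialize()
    end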

Then you would start the serving process as part of your application's
supervision tree:

    children = [
      ...,
      {Nx.Serving, serving: build_serving(), name: MyApp.Serving, batch_timeout: 100}
    ]
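
The `batch_timeout` option specifies how long (in milliseconds) the serving
waits for the batch to fill up before running inference on whatever requests
have accumulated so far.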

With that in place, you can now ask the serving for predictions from anywhere
in your application (controllers, live views, async jobs, etc.). Given a
tensor input, you would do:

    inputs = %{"pixel_values" => ...}
    batch = Nx.Batch.concatenate([inputs])
    result = Nx.Serving.batched_run(MyApp.Serving, batch)

Usually you also want to pre/post-process the model input/output. You could
perform those steps directly before/after `Nx.Serving.batched_run/2`; however,
you can also use `Nx.Serving.client_preprocessing/2` and
`Nx.Serving.client_postprocessing/2` to encapsulate that logic as part of
the serving, as sketched below.
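
Here is a minimal sketch of such hooks (`preprocess_input/1` and
`postprocess_output/1` are hypothetical helpers, and the postprocessing
callback is assumed to receive the output, metadata, and client info as
separate arguments):

    serving =
      build_serving()
      |> Nx.Serving.client_preprocessing(fn input ->
        # Hypothetical helper converting raw input into model-ready tensors
        batch = Nx.Batch.concatenate([%{"pixel_values" => preprocess_input(input)}])
        {batch, :client_info}
      end)
      |> Nx.Serving.client_postprocessing(fn output, _metadata, _client_info ->
        # Hypothetical helper mapping model output to application-level results
        postprocess_output(output)
      end)

You would then register this extended serving in the supervision tree instead,
and callers could pass their raw input directly to `Nx.Serving.batched_run/2`.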
"""
alias __MODULE__, as: Axon
alias Axon.Parameter