ONNX-GoMLX converts ONNX models (`.onnx` files) to GoMLX, an accelerated machine learning framework for Go.
One can also fine-tune models with GoMLX and then save the updated weights back to the ONNX model.
Future plans include an ONNX backend for GoMLX, so one can execute or export GoMLX models using ONNX.
The main use cases so far are:
- Fine-tuning: import an inference-only ONNX model into GoMLX and use its auto-differentiation and training loop to fine-tune it. The fine-tuned model can then be saved as a GoMLX checkpoint, or its weights can be exported back to the ONNX model. This can also be used to expand or combine models.
- Inference: run an ONNX model from Go without having to include ONNX Runtime (or Python), at the cost of including XLA/PJRT (currently the only backend for GoMLX; see the backend-selection sketch after this list). It also allows one to extend the model with extra ML pre-/post-processing using GoMLX (image transformations, normalization, combining models, building ensembles, etc.). This may be interesting for large/expensive models, or for high throughput on large batches.
- Note: if you simply want pure Go inference of ONNX models, see
github.com/AdvancedClimateSystems/gonnx or
github.com/oramasearch/onnx-go. They will be slower than XLA inference (or onnxruntime) for large projects (~8x, based on a BERT-based SentenceEncoder model, gonnx vs ONNX Runtime), but for many use cases that doesn't matter, and they are a much smaller, pure Go dependency. CPU only (no GPU support).
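Since XLA/PJRT is the main cost of the GoMLX route, here is a minimal sketch of how a backend is selected. It assumes GoMLX's documented behavior; the `GOMLX_BACKEND` values and the plugin import path are assumptions, so check the GoMLX documentation for what is available in your build:

```go
package main

import (
	"github.com/gomlx/gomlx/backends"
	_ "github.com/gomlx/gomlx/backends/xla" // Registers the XLA/PJRT backend (assumed path).
)

func main() {
	// backends.New() returns the default backend, typically XLA/PJRT.
	// The GOMLX_BACKEND environment variable can select a specific plugin
	// and device, e.g. "xla:cpu" or "xla:cuda" (illustrative values).
	backend := backends.New()
	_ = backend // Pass it to context.MustExecOnceN(), as in the example below.
}
```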
Some 20 or so models are working so far, and the list is growing:
- Sentence encoding with all-MiniLM-L6-v2 has been working perfectly; see the example below.
- The ONNX-GoMLX demo/development notebook, which serves both as a functional test and as a demo of what the converter can do.
Not all operations ("ops") are converted yet, though. If you try it and find an operation that is not yet converted, please let us know (create an issue) and we will be happy to try to convert it. Generally, all the required scaffolding and tooling are already in place, and converting ops has been straightforward.
We download (and cache) the all-MiniLM-L6-v2 model
using github.com/gomlx/go-huggingface.
The tokens in the example are hardcoded for simplicity; see github.com/gomlx/go-huggingface for tokenization of various models.
```go
import (
	"fmt"
	"os"

	// GoMLX subpackage paths below assume the current GoMLX layout.
	"github.com/gomlx/go-huggingface/hub"
	"github.com/gomlx/gomlx/backends"
	. "github.com/gomlx/gomlx/graph" // Dot-import provides *Node, as in GoMLX examples.
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/types/tensors"
	"github.com/gomlx/onnx-gomlx/onnx"
	"github.com/janpfeifer/must"
)
...
// Download and cache the ONNX model from HuggingFace.
hfAuthToken := os.Getenv("HF_TOKEN")
hfModelID := "sentence-transformers/all-MiniLM-L6-v2"
repo := hub.New(hfModelID).WithAuth(hfAuthToken)
modelPath := must.M1(repo.DownloadFile("onnx/model.onnx"))
// Parse ONNX model.
model := must.M1(onnx.ReadFile(modelPath))
// Convert ONNX variables (model weights) to GoMLX Context -- which stores variables and can be checkpointed (saved):
ctx := context.New()
must.M(model.VariablesToContext(ctx))
// Execute it with GoMLX/XLA:
sentences := []string{
"This is an example sentence",
"Each sentence is converted"}
//... tokenize ...
inputIDs := [][]int64{
{101, 2023, 2003, 2019, 2742, 6251, 102},
{101, 2169, 6251, 2003, 4991, 102, 0}}
tokenTypeIDs := [][]int64{
{0, 0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0}}
attentionMask := [][]int64{
{1, 1, 1, 1, 1, 1, 1},
{1, 1, 1, 1, 1, 1, 0}}
var embeddings []*tensors.Tensor
embeddings = context.MustExecOnceN( // Execute a GoMLX computation graph with a context.
	backends.New(), // GoMLX backend to use (defaults to XLA).
	ctx,            // Context stores the model variables/weights and optional hyperparameters.
	func(ctx *context.Context, inputs []*Node) []*Node {
		// Convert the ONNX model (in `model`) to a GoMLX computation graph.
		// It returns a slice of outputs (only one for this model).
		return model.CallGraph(ctx, inputs[0].Graph(), map[string]*Node{
			"input_ids":      inputs[0],
			"attention_mask": inputs[1],
			"token_type_ids": inputs[2]})
	},
	inputIDs, attentionMask, tokenTypeIDs) // Inputs to the GoMLX function.
fmt.Printf("Embeddings: %s", embeddings)
```

The output looks like:
```
Embeddings: [2][7][384]float32{
{{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
{-0.0200, -0.0014, -0.0177, ..., 0.0204, 0.0522, 0.1991},
{-0.0196, -0.0336, -0.0319, ..., 0.0203, 0.0709, 0.0644},
...,
{-0.0253, 0.0408, 0.0125, ..., -0.0270, 0.0377, 0.1133},
{-0.0140, -0.0275, 0.0796, ..., -0.0748, 0.0774, -0.0657},
{0.0318, -0.0032, -0.0210, ..., 0.0387, 0.0191, -0.0059}},
{{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
{0.0304, 0.0531, -0.0238, ..., -0.1011, 0.0218, 0.0473},
{-0.0027, -0.0508, 0.0805, ..., -0.0777, 0.0881, -0.0560},
...,
{0.0928, 0.0165, -0.0976, ..., 0.0449, 0.0390, -0.0182},
{0.0231, 0.0090, -0.0213, ..., 0.0232, 0.0191, -0.0066},
{-0.0213, 0.0019, 0.0043, ..., 0.0561, 0.0170, 0.0256}}}
```
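Note that the model returns one 384-dimensional embedding per token. Sentence-transformers models typically apply mean pooling over the attention mask, followed by L2 normalization, to get a single embedding per sentence. This is exactly the kind of post-processing one can append with GoMLX; here is a hedged sketch (the `meanPooling` helper is ours, and the op names are based on GoMLX's `graph` package; verify them against your GoMLX version):

```go
// meanPooling reduces per-token embeddings [batch, tokens, dim] to one
// L2-normalized embedding per sentence [batch, dim], averaging only over
// unmasked tokens.
func meanPooling(embeddings, attentionMask *Node) *Node {
	mask := ConvertDType(attentionMask, embeddings.DType()) // [batch, tokens]
	mask = InsertAxes(mask, -1)                             // [batch, tokens, 1], broadcastable.
	summed := ReduceSum(Mul(embeddings, mask), 1)           // Sum over tokens: [batch, dim].
	counts := ReduceSum(mask, 1)                            // Count of unmasked tokens: [batch, 1].
	pooled := Div(summed, counts)                           // Mean over unmasked tokens.
	norm := Sqrt(ReduceSum(Mul(pooled, pooled), -1))        // L2 norm per sentence: [batch].
	return Div(pooled, InsertAxes(norm, -1))                // Normalize to unit length.
}
```

Inside the model function above, one would then return `meanPooling(outputs[0], inputs[1])`, where `outputs` is the slice returned by `model.CallGraph()`.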
- Extract the ONNX model's weights to a GoMLX `Context`: see `Model.VariablesToContext()`.
- Use `Model.CallGraph()` in your GoMLX model function (see the example just above).
- Train the model as usual in GoMLX.
- Depending on how you are going to use the model:
  - Save the model as a GoMLX checkpoint, as usual.
  - Save the model by updating the ONNX model: after training, use `Model.ContextToONNX()` to copy the updated variable values from the GoMLX `Context` back to the ONNX model (in-memory), and then use `Model.Write()` or `Model.SaveToFile()` to save the updated ONNX model to disk (see the sketch below).
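For that last step, a minimal sketch, reusing the `model` (*onnx.Model) and `ctx` (*context.Context) from the inference example above (the output file name is hypothetical, and the error-only return signatures wrapped in `must.M` are an assumption):

```go
// After training, copy the updated variable values from the GoMLX Context
// back into the in-memory ONNX model, then write the model to disk.
must.M(model.ContextToONNX(ctx))
must.M(model.SaveToFile("fine-tuned.onnx")) // "fine-tuned.onnx" is an illustrative path.
```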
We have some GoMLX/XLA and ONNX Runtime (Microsoft) benchmarks in this spreadsheet, tested on the sentence encoder model we were interested in. The spreadsheet was used during development to track performance improvements; the numbers at the bottom of the sheets are the currently accurate ones.
See docs/benchmarks.md for more information.
Collaboration is very welcome, whether in the form of code or of ideas with real applicability. Don't hesitate to start a discussion or an issue in the repository.
You can find the author and other interested parties in the Slack channel #gomlx (you can join the Slack server here).
If you are interested, we have two notebooks we use to compare results. They are a good starting point for anyone curious:
- This project was born from brainstorming with the talented folks at KnightAnalytics. Without their insight and enthusiasm, it wouldn't have gotten off the ground.
- ONNX is such a nice open-source standard for communicating models across different implementations.
- OpenXLA/XLA, the open-source backend engine by Google that powers this implementation.
- Sources of inspiration: