ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
Note, however, that ONNX Script does not intend to support the entirety of the Python language.
ONNX Script provides a few major capabilities for authoring and debugging ONNX models and functions:
A converter that translates a Python ONNX Script function into an ONNX graph. It does this by traversing the Python Abstract Syntax Tree to build an ONNX graph equivalent of the function.
A converter that operates inversely, translating ONNX models and functions into ONNX Script. This capability can be used to fully round-trip ONNX Script ↔ ONNX graph.
A runtime shim that allows such functions to be evaluated in an "eager mode". This currently relies on ONNX Runtime to execute each ONNX operator. A Python-only reference runtime for ONNX is under development and will also be supported.
Note that the runtime is intended to help understand and debug function definitions. Performance is not a goal here.
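To illustrate what an eager-mode shim does, here is a minimal sketch in which each op call runs immediately and returns a concrete array. It is a toy under stated assumptions, not ONNX Script's implementation: NumPy kernels stand in for ONNX Runtime, only a handful of ops are covered, and the names EagerOpset, _KERNELS and eager_op are hypothetical. It does follow the calling convention used by ONNX Script: positional arguments are ONNX inputs, keyword arguments are ONNX attributes.

```python
import numpy as np

# Hypothetical NumPy kernels standing in for ONNX Runtime (toy subset only).
_KERNELS = {
    "Add": lambda a, b: np.add(a, b),
    "MatMul": lambda a, b: np.matmul(a, b),
    "Relu": lambda x: np.maximum(x, 0),
    # ONNX ArgMax defaults: axis=0, keepdims=1
    "ArgMax": lambda x, axis=0, keepdims=1: np.argmax(x, axis=axis, keepdims=bool(keepdims)),
}


class EagerOpset:
    """Toy eager shim: op.<OpType>(...) looks up a kernel and runs it right away."""

    def __getattr__(self, op_type):
        try:
            kernel = _KERNELS[op_type]
        except KeyError:
            raise AttributeError(f"no eager kernel for {op_type}") from None

        def run(*inputs, **attributes):
            # Positional args are ONNX inputs; keyword args are ONNX attributes.
            return kernel(*[np.asarray(x) for x in inputs], **attributes)

        return run


eager_op = EagerOpset()
x = np.array([[1.0, -2.0], [3.0, 4.0]], dtype=np.float32)
# Each call is evaluated immediately, so intermediate values can be printed.
doubled = eager_op.Add(x, x)
print(doubled)
print(eager_op.Relu(doubled))
```

Because every call returns a real array, intermediate results can be inspected with ordinary Python tools, which is the debugging benefit the note above describes; the cost is that nothing is optimized across ops.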
Install the latest release with pip:
pip install --upgrade onnxscript
To set up a development environment and run the unit tests:
pip install onnx onnxruntime pytest
git clone https://github.com/microsoft/onnxscript
cd onnxscript
pip install -e .
pytest onnxscript
import onnx

# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT, script
from onnxscript import opset15 as op


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
    """Hardmax is similar to ArgMax, with the result being encoded OneHot style."""
    # X has no type annotation, so it is treated as an ONNX input tensor
    # (here expected to be a float tensor of unknown rank). The int
    # annotation on axis indicates that it will be treated as an int
    # attribute in ONNX.
    #
    # Invoke ONNX opset 15 op ArgMax.
    # Use unnamed arguments for ONNX input parameters, and named
    # arguments for ONNX attribute parameters.
    argmax = op.ArgMax(X, axis=axis, keepdims=False)
    xshape = op.Shape(X, start=axis)
    # Use the Constant operator to create constant tensors.
    zero = op.Constant(value_ints=[0])
    depth = op.GatherElements(xshape, zero)
    empty_shape = op.Constant(value_ints=[0])
    depth = op.Reshape(depth, empty_shape)
    values = op.Constant(value_ints=[0, 1])
    cast_values = op.CastLike(values, X)
    return op.OneHot(argmax, depth, cast_values, axis=axis)


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
    matmul = op.MatMul(X, Wt) + Bias
    return onnx_hardmax(matmul, axis=1)


# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()

# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")

# Check the model
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")
else:
    print("The model is valid!")
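In sample_model, the + in `op.MatMul(X, Wt) + Bias` is translated to ONNX's Add, which uses NumPy-style (multidirectional) broadcasting. The shape arithmetic can therefore be sanity-checked in NumPy. The sketch below uses random data purely for illustration; it is not part of ONNX Script and does not call it.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 128)).astype(np.float32)
Wt = rng.standard_normal((128, 10)).astype(np.float32)
Bias = rng.standard_normal(10).astype(np.float32)

# MatMul: [64, 128] x [128, 10] -> [64, 10]
product = X @ Wt
# Add: Bias of shape [10] is broadcast across all 64 rows
logits = product + Bias

print(logits.shape)
# Every row received the same bias vector (up to float32 rounding)
print(np.allclose(logits - product, Bias, atol=1e-4))
```

The result has shape (64, 10), matching the declared return type FLOAT[64, 10]; the hardmax along axis=1 then keeps that shape.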
The @script() decorator parses the code of the function and converts it into an intermediate representation. If parsing fails, it produces an error message indicating the line where the error was detected. If it succeeds, the intermediate representation can be converted into an ONNX graph structure of type FunctionProto:
onnx_hardmax.to_function_proto()  # returns a FunctionProto
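The parse-then-convert flow can be illustrated with Python's standard ast module. The sketch below is a toy, not ONNX Script's converter: toy_convert is a hypothetical name, and it accepts only straight-line `name = module.OpType(name, ...)` assignments and a `return name`. It walks the function's AST and records one (op_type, inputs, outputs) triple per call, roughly the information an ONNX NodeProto carries, and it reports the line number of any statement it cannot handle.

```python
import ast
import textwrap

SOURCE = textwrap.dedent("""\
    def f(X, Y):
        t = op.MatMul(X, Y)
        z = op.Relu(t)
        return z
""")


def toy_convert(source):
    """Hypothetical toy converter returning (graph_inputs, nodes, graph_outputs)."""
    func = ast.parse(source).body[0]
    if not isinstance(func, ast.FunctionDef):
        raise TypeError("expected a single function definition")
    graph_inputs = [a.arg for a in func.args.args]
    nodes, graph_outputs = [], []
    for stmt in func.body:
        if (
            isinstance(stmt, ast.Assign)
            and isinstance(stmt.value, ast.Call)
            and isinstance(stmt.value.func, ast.Attribute)
            and all(isinstance(a, ast.Name) for a in stmt.value.args)
            and all(isinstance(t, ast.Name) for t in stmt.targets)
        ):
            call = stmt.value
            # The attribute name (e.g. MatMul in op.MatMul) becomes the op type.
            nodes.append((call.func.attr, [a.id for a in call.args], [t.id for t in stmt.targets]))
        elif isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
            graph_outputs.append(stmt.value.id)
        else:
            # Point at the offending source line, as a real converter should.
            raise ValueError(f"line {stmt.lineno}: unsupported {type(stmt).__name__} statement")
    return graph_inputs, nodes, graph_outputs


print(toy_convert(SOURCE))
```

A real converter must also handle attributes, constants, control flow, and type annotations, and emits actual ONNX protos; the toy only shows the AST walk and line-numbered error reporting.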
Eager mode is mostly used to debug and validate that intermediate results are as expected. The function defined above can be called as below, executing in an eager-evaluation mode:
import numpy as np

v = np.array([[0, 1], [2, 3]], dtype=np.float32)
# axis is a required attribute of onnx_hardmax, so it is passed explicitly.
result = onnx_hardmax(v, axis=1)
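When validating eager results, an independent reference can help. The sketch below is a NumPy re-implementation of what onnx_hardmax computes, written for illustration only; hardmax_reference is a hypothetical name, not part of ONNX Script. Each step is annotated with the ONNX op it mirrors, so intermediate values such as the argmax indices and the depth can be inspected directly.

```python
import numpy as np


def hardmax_reference(x: np.ndarray, axis: int) -> np.ndarray:
    """Hypothetical NumPy reference for onnx_hardmax (illustration only)."""
    axis = axis % x.ndim  # normalize a negative axis
    # ArgMax(keepdims=False): np.argmax returns the first maximal index,
    # matching ONNX ArgMax's default select_last_index=0.
    idx = np.argmax(x, axis=axis)
    # Shape(start=axis) followed by GatherElements(..., [0]): size of that axis
    depth = x.shape[axis]
    # OneHot(..., axis=axis): re-insert the reduced axis and compare each
    # index against the class positions 0..depth-1
    idx = np.expand_dims(idx, axis)
    class_shape = [1] * x.ndim
    class_shape[axis] = depth
    classes = np.arange(depth).reshape(class_shape)
    # CastLike(values=[0, 1], X): the off/on values take x's dtype
    return (idx == classes).astype(x.dtype)


v = np.array([[0, 1], [2, 3]], dtype=np.float32)
print(hardmax_reference(v, axis=1))
```

For the input v above with axis=1, the largest entry in each row is in column 1, so the expected result is [[0, 1], [0, 1]] as float32, which the eager result can be compared against.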
More examples can be found in the docs/examples directory.
Every change that affects the converter or the eager evaluation must be unit tested with the OnnxScriptTestCase class, to ensure that both systems return the same results for the same inputs.
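The idea behind such tests, running two execution paths on identical inputs and comparing the outputs, can be sketched with unittest and NumPy. This is not OnnxScriptTestCase's API, which is not shown here: assert_same_results, relu_eager and relu_converted are hypothetical, and the two relu functions are plain NumPy stand-ins for eager evaluation and for running the converted graph.

```python
import unittest

import numpy as np


def assert_same_results(reference_fn, candidate_fn, *inputs, rtol=1e-6, atol=0.0):
    """Hypothetical helper: evaluate both paths on identical inputs and compare."""
    expected = reference_fn(*inputs)
    actual = candidate_fn(*inputs)
    # Raises AssertionError with a detailed report if the outputs differ.
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)


def relu_eager(x):
    # Stand-in for evaluating a function eagerly, op by op.
    return np.maximum(x, 0)


def relu_converted(x):
    # Stand-in for running the converted ONNX graph through a runtime.
    return np.where(x > 0, x, 0)


class SameResultsTest(unittest.TestCase):
    def test_paths_agree(self):
        x = np.array([-1.0, 0.0, 2.5], dtype=np.float32)
        assert_same_results(relu_eager, relu_converted, x)
```

The test class can be collected by pytest or python -m unittest. A tolerance-based comparison is used because different runtimes may legitimately differ in the last bits of floating-point results.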
We use ruff, black, isort, mypy, and other tools to check code formatting, and lintrunner to run all linters. You can install the dependencies and initialize lintrunner with:
pip install lintrunner lintrunner-adapters
lintrunner init
This installs lintrunner on your system and downloads all the dependencies needed to run the linters locally. To see what lintrunner init will install, run lintrunner init --dry-run.
To lint local changes:
lintrunner
To format files:
lintrunner f
To lint all files:
lintrunner --all-files
Use --output oneline to produce a compact list of lint errors, which is useful when there are many errors to fix.
See all available options with lintrunner -h.
To read more about lintrunner, see the lintrunner wiki. To update an existing linting rule or create a new one, modify .lintrunner.toml or create a new adapter following the examples in https://github.com/justinchuby/lintrunner-adapters.
Download Details:
Author: microsoft
Official GitHub: https://github.com/microsoft/onnxscript
License: MIT