This is the second post about my frontend project, in which we execute an ONNX model in React. We take an existing ONNX model and run it inside a React app, so the model runs entirely on the client side. We build on the app from the first blog post, which you can find here. However, I try to make this post as self-contained as possible and only explain at the end how I integrated the model into my current app. You can test the app here.

Setup
Regarding the setup, I expect you to have a ReactJS app running with at least Material-UI (MUI) installed. If you need to set up a new project or have any questions about the setup, look up this post, where I explain the individual steps of creating a React app.
In addition to the basic setup, we need something to execute neural networks. For that, we use ONNXRuntime, which provides a library for JS. We also need JIMP for preprocessing images. Just run yarn add onnxruntime-web jimp in the top-level folder of the app, and you are good to go.
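With both packages installed, the imports used throughout this post look roughly like this (a minimal sketch; the exact import style, especially for Jimp, may differ depending on your TypeScript configuration):

// ONNXRuntime Web provides the session and tensor types used in this post
import { InferenceSession, Tensor, TypedTensor } from "onnxruntime-web";
// JIMP handles image loading and preprocessing; the namespace import matches
// the Jimp.default.read(...) calls that appear later
import * as Jimp from "jimp";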
Catching JIMP import error
The installation should work fine. However, if you import jimp in your app, you will most likely see an error like this: "Can't resolve 'util' in '/.../node_modules/pngjs/lib'". This is due to a breaking change from webpack 4 to 5: webpack 5 no longer ships polyfills for Node core modules. We could eliminate the error by overriding the webpack config. Unfortunately, this is not straightforward with create-react-app, and we would need something like craco to rewrite the webpack config. It is easier to downgrade react-scripts to major version four with yarn upgrade react-scripts@4.0.3, which also eliminates the error.
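If you prefer to stay on a current react-scripts version, another option (an assumption on my side; I have not used it in this app) is to keep webpack 5 and re-add the missing Node polyfill through craco. Assuming craco and the util package are installed, a craco.config.js could look roughly like this:

// craco.config.js -- hypothetical override, not used in this post
module.exports = {
  webpack: {
    configure: (webpackConfig) => {
      // Re-add the 'util' polyfill that webpack 5 no longer provides
      webpackConfig.resolve.fallback = {
        ...webpackConfig.resolve.fallback,
        util: require.resolve("util/"),
      };
      return webpackConfig;
    },
  },
};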
ONNX models in React in one component
Before we dive into the code, let’s look at the structure of the component we write. The cleanest way to add inference functionality is a single component that executes inference when the user triggers it. Therefore, we create a component called <InferenceMenu>. This component is the only user-facing interface. Besides that, we wrap the inference logic inside a JavaScript class. We create a base class, called InferenceBase, that holds the shared logic; every new model then inherits from it.
Besides the main component and class, we need some data science utils for pre- and postprocessing. These all live inside our util folder.
Inference base class
For our models, we use an object-oriented approach with one base class. The class needs three input parameters: modelSrc, dims, and executionProvider. The first one is a URL to the binary of the model, usually a .onnx file. The dims are an array of numbers that define the input tensor’s dimensions. We later implement TinyYoloV2, which has dims of [1, 3, 416, 416]: the first number is the batch size, the second the number of channels (RGB in this case), and the last two the spatial dimensions (416 × 416 pixels). The executionProvider is specific to ONNXRuntime. For ONNXRuntime Web, we can choose between webgl and wasm (WebAssembly).
If you like a more function-based style, check out the official tutorial, which I also used as the basis for my code.
WebGL integrates easily with create-react-app. WebAssembly, however, needs the .wasm files to be served with a specific MIME type (application/wasm). Unfortunately, webpack serves them as text/plain, which leads to an error during loading. This can be fixed by changing the webpack config, but that is not intended when you use create-react-app. Again, you could solve it with craco, but for now we stick to WebGL. Keep in mind, though, that the WebGL backend does not yet implement all ONNX operators.
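If you do want to try the wasm backend without touching the webpack config, one possible workaround (an assumption on my side, not used in this app) is to point ONNXRuntime Web at externally hosted .wasm binaries via ort.env.wasm.wasmPaths, so they are no longer served by the dev server:

import * as ort from "onnxruntime-web";

// Hypothetical workaround: serve the .wasm binaries from a CDN instead of
// the create-react-app dev server, which sends them with the wrong MIME type.
ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";

async function createWasmSession(modelUrl: string) {
  // "wasm" can then be listed as a fallback behind "webgl"
  return ort.InferenceSession.create(modelUrl, {
    executionProviders: ["webgl", "wasm"],
  });
}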
Besides these public props, we also have some private ones. We have two booleans that store the loaded and error state. Another private property stores the InferenceSession object itself (session). The last private property is modelFile, which stores the binary content of the .onnx file. With this set, we can look at the constructor of the class:
export default class InferenceBase {
  modelSrc: string;
  modelFile: ArrayBuffer;
  executionProvider: string[];
  dims: number[];
  session: InferenceSession | undefined;
  loaded: boolean;
  error: boolean;

  constructor(modelSrc: string, dims: number[], executionProvider: string[]) {
    this.modelSrc = modelSrc;
    this.dims = dims;
    this.executionProvider = executionProvider;
    this.modelFile = new ArrayBuffer(0);
    this.session = undefined;
    this.loaded = false;
    this.error = false;
  }
Since JavaScript has no built-in notion of abstract methods, we emulate them: the base class implements the methods and raises an error on invocation. The error class looks like this:
export class NotImplementedError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "NotImplementedError";
  }
}
As abstract methods, we implement preprocess and postprocess:
// Abstract methods: to be implemented by each specific model
preprocess(jimpImage: Jimp): Jimp {
  throw new NotImplementedError("The fn preprocess needs to be implemented");
}

postprocess(
  output: TypedTensor<"float32">,
  inputImage: Jimp
): anyAnnoationObject[] {
  throw new NotImplementedError("The fn postprocess needs to be implemented");
}
The preprocess function expects a Jimp image as input and returns one as output. The postprocess function expects a TypedTensor as input, which is a Float32Array with some extra information about the tensor’s shape, and returns the annotations. A new model must implement these functions because they differ vastly between models.
Class utils
Besides our abstract methods, we need some util functions. The loadModel function fetches the .onnx model from the server and instantiates the ONNXRuntime InferenceSession:
async loadModel() {
  let modelResponse = await fetch(this.modelSrc);
  this.modelFile = await modelResponse.arrayBuffer();
  console.log("Loading model file:");
  console.log(this.modelFile);
  this.session = await InferenceSession.create(this.modelFile, {
    executionProviders: this.executionProvider,
    graphOptimizationLevel: "all",
  });
}
The following function is called loadImageToJimp, and it takes an image URL as input and loads it into a JIMP image:
async loadImageToJimp(src: string) {
  // Jimp.default.read returns a Promise<Jimp>
  const imageData = await Jimp.default.read(src);
  return imageData;
}
The last function converts a given Jimp image to an ort (ONNXRuntime) tensor: it splits each color channel of the image into its own array and converts these arrays into a single Float32Array with a util function. We then use the Float32Array as input for the ort.Tensor constructor:
imageToTensor(imageData: Jimp) {
  const [redArray, greenArray, blueArray] = jimpToImageArrays(imageData);
  const float32Data = arraysToFloat32Data(
    redArray,
    greenArray,
    blueArray,
    this.dims
  );
  const inputTensor: TypedTensor<"float32"> = new Tensor(
    "float32",
    float32Data,
    this.dims
  );
  return inputTensor;
}
The functions jimpToImageArrays and arraysToFloat32Data are imported from ./src/util/imageProcess.ts. The first one creates three empty arrays and fills them incrementally by looping over the Jimp.bitmap.data object, which contains the pixel values. This splits each color channel into a separate array in preparation for constructing an ort.Tensor:
export const jimpToImageArrays = (image: Jimp) => {
  var imageBufferData = image.bitmap.data;
  const [redArray, greenArray, blueArray] = [
    new Array<number>(),
    new Array<number>(),
    new Array<number>(),
  ];
  // The bitmap data is stored as RGBA, so we step through it in chunks of four
  for (let i = 0; i < imageBufferData.length; i += 4) {
    redArray.push(imageBufferData[i]);
    greenArray.push(imageBufferData[i + 1]);
    blueArray.push(imageBufferData[i + 2]);
  }
  return [redArray, greenArray, blueArray];
};
The second function (arraysToFloat32Data) concatenates the three arrays and writes every value into a Float32Array to convert the data type. The array is filled in the order expected by an ort.Tensor:
export const arraysToFloat32Data = (
  redArray: number[],
  greenArray: number[],
  blueArray: number[],
  dims: number[]
) => {
  // Concatenate the channels in the order R, G, B (channel-first layout)
  const transposedData = redArray.concat(greenArray).concat(blueArray);
  const float32Data = new Float32Array(dims[1] * dims[2] * dims[3]);
  for (let i = 0; i < transposedData.length; i++) {
    float32Data[i] = transposedData[i];
  }
  return float32Data;
};
The inference function
The function inference is the main entry point of the class. On invocation, it resets the error state and declares the variable processedOutput. This declaration has to happen at the top level of the function because we assign the variable inside a try-catch block and return it afterwards. Next, we check whether the model has already been loaded. If not, the function throws this error:
export class ModelStillLoadingError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "ModelStillLoading";
  }
}
Next, the function enters a try block, which executes the model. The first three lines use the functions above to convert an image URL into an ort.Tensor. After that, we create the object feeds, which is used as the input for the ONNX model. Its keys are strings corresponding to the input names of the ONNX model, and its values are ort.Tensor objects. We insert our created tensor into feeds and retrieve the input name from our InferenceSession object with this.session.inputNames[0]. A typical computer vision model has only one input, so we can retrieve the name by index. The way you use ONNXRuntime is very similar to the Python API, which I think is great for development.
To execute inference, we use the run method of the InferenceSession object and provide the feeds as input. This method is async, so we await it. The output is again an ort.Tensor, which is processed by the postprocess function. The result is returned as a Promise that resolves to a list of annotations on success.
// Standard functions
async inference(src: string) {
  this.error = false;
  let processedOutput: anyAnnoationObject[];
  if (this.session === undefined) {
    throw new ModelStillLoadingError(
      `Model ${this.modelSrc} is still loading`
    );
  }
  try {
    const inputImage = await this.loadImageToJimp(src);
    const preprocessedImage = this.preprocess(inputImage.clone());
    const inputTensor = this.imageToTensor(preprocessedImage);
    const feeds: Record<string, TypedTensor<"float32">> = {};
    feeds[this.session.inputNames[0]] = inputTensor;
    console.log(`Loaded this feeds:`);
    console.log(feeds);
    const outputData = await this.session.run(feeds);
    const output: any = outputData[this.session.outputNames[0]];
    processedOutput = this.postprocess(output, inputImage);
    console.log("Inference result");
    console.log(processedOutput);
  } catch (e) {
    console.log(e);
    this.error = true;
  }
  return new Promise<anyAnnoationObject[]>((resolve, reject) => {
    this.error ? reject([]) : resolve(processedOutput);
  });
}
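Putting the pieces together, a concrete model class (such as the tinyYoloV2Executor introduced in the next section) would be used roughly like this; a minimal sketch under the assumption that loadModel has finished before inference is called:

// Hypothetical usage of a concrete InferenceBase subclass
async function runExample() {
  const executor = new TinyYoloV2();
  await executor.loadModel();

  // "/images/example.jpg" is a placeholder image URL
  const annotations = await executor.inference("/images/example.jpg");
  console.log(annotations); // list of annotation objects, e.g. bounding boxes
}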
TinyYoloV2
Now that we know our base inference class, we can add new models by implementing the pre- and postprocess methods and the constructor. We use the TinyYoloV2 ONNX model for our React app. This section is more involved because we need to implement the mathematical postprocessing in pure JS, since JS has no lightweight libraries for computer vision. Although we could use TensorFlow.js for this task, I wanted to do it by hand to reduce the dependencies of the app.
I will not explain how YoloV2 works in great detail because that would be a post of its own. If you are not familiar with Yolo, check out the original publication. I mostly followed this post and give you some hints for the implementation in JS below. The code for the TinyYoloV2 model lives here.
Constructor
The constructor adds four new properties to the class. Three of them (classes, classesColor and anchors) are simply lookup tables for the postprocessing. The fourth property is the model name, which is later attached to each annotation. The anchors are specific to YOLO and describe the initial sizes of the bounding boxes. For modelSrc, we save the model to our public folder and fetch it from there. For a development setup, this is good enough.
class TinyYoloV2 extends InferenceBase {
  classes: string[];
  anchors: number[];
  modelName: string;
  classesColor: string[];

  constructor() {
    const modelSrc = process.env.PUBLIC_URL + "/tinyyolov2-8.onnx";
    const dims = [1, 3, 416, 416];
    const executionProvider = ["webgl"];
    super(modelSrc, dims, executionProvider);
    this.modelName = "Tiny Yolo V2";
    this.classes = [
      "aeroplane", "bicycle", "bird", "boat", "bottle",
      "bus", "car", "cat", "chair", "cow",
      "diningtable", "dog", "horse", "motorbike", "person",
      "pottedplant", "sheep", "sofa", "train", "tvmonitor",
    ];
    this.classesColor = [
      "#e6194b", "#3cb44b", "#ffe119", "#4363d8", "#f58231",
      "#911eb4", "#46f0f0", "#f032e6", "#bcf60c", "#fabebe",
      "#008080", "#e6beff", "#9a6324", "#fffac8", "#800000",
      "#aaffc3", "#808000", "#ffd8b1", "#000075", "#109101",
      "#ffffff", "#000000",
    ];
    this.anchors = [
      1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52,
    ];
  }
Preprocess
Most computer vision networks work best when the aspect ratio is preserved. Thus, resizing and padding is the standard way to bring an image to the input size: the longer side (height or width) is resized to the target dimension, and the shorter side is then padded to match it. Luckily, Jimp ships with the resize method contain, which does exactly this resize-and-pad (letterboxing). For the postprocessing, we need to calculate the paddings manually, which is not a big deal. And this already concludes the preprocessing: TinyYoloV2 expects the original pixel values and does not require normalization.
preprocess(jimpImage: Jimp): Jimp {
  return jimpImage.contain(this.dims[2], this.dims[3]);
}
Postprocess
Postprocessing for TinyYoloV2 is much more complicated than preprocessing because many high-level math and tensor operations are not available in JS. Thus, we go through the code in more detail and explain how ONNXRuntime stores its output and how you can work with it. Here is the code:
postprocess(
  output: TypedTensor<"float32">,
  inputImage: Jimp
): anyAnnoationObject[] {
  // Expected dims of output: [1, 125, 13, 13]
  console.time("Postprocess time");
  // Prepare variables for loops and processing
  var [offset, gridX, gridY, stepId, bboxIdx] = [0, 0, 0, 0, 0];
  const stepSize = 13 * 13;
  const paddings = calcPaddingAfterResize(
    [this.dims[2], this.dims[3]],
    [inputImage.getWidth(), inputImage.getHeight()]
  );
  var outputBoxes: bboxAnnotationObject[] = [];
  // We will use the indexes to splice the FloatArray and extract all bboxes
  for (gridY; gridY < 13; gridY++) {
    gridX = 0;
    for (gridX; gridX < 13; gridX++) {
      var currentBboxList: number[] = [];
      stepId = 0;
      for (stepId; stepId < 125; stepId++) {
        currentBboxList.push(output.data[stepId * stepSize + offset]);
      }
      bboxIdx = 0;
      for (bboxIdx; bboxIdx < 5; bboxIdx++) {
        var currentBbox = currentBboxList.splice(0, 25);
        const currentC = sigmoid(currentBbox[4]);
        if (currentC > 0.3) {
          // Actual processing of bbox
          const softmaxOut = softmax(
            currentBbox.slice(5, currentBbox.length)
          );
          const className =
            this.classes[softmaxOut.indexOf(Math.max(...softmaxOut))];
          const classColor =
            this.classesColor[softmaxOut.indexOf(Math.max(...softmaxOut))];
          const width =
            //@ts-ignore
            Math.exp(currentBbox[2]) * this.anchors[bboxIdx * 2] * 32;
          const height =
            //@ts-ignore
            Math.exp(currentBbox[3]) * this.anchors[bboxIdx * 2 + 1] * 32;
          var bbox = [
            ((sigmoid(currentBbox[0]) + gridX) * 32 - width / 2) /
              this.dims[2],
            ((sigmoid(currentBbox[1]) + gridY) * 32 + height / 2) /
              this.dims[3],
            width / (this.dims[2] - paddings[0] - paddings[1]),
            height / (this.dims[3] - paddings[2] - paddings[3]),
          ];
          bbox = [
            Math.min(Math.max(bbox[0], 0), 1),
            Math.min(Math.max(bbox[1], 0), 1),
            Math.min(Math.max(bbox[2], 0), 1),
            Math.min(Math.max(bbox[3], 0), 1),
          ];
          outputBoxes.push({
            className: className,
            color: classColor,
            id: uuidv4(),
            score: currentC,
            model: this.modelName,
            type: "bbox",
            box: bbox,
          });
        }
      }
      offset++;
    }
  }
  console.timeEnd("Postprocess time");
  return nonMaximumSupression(outputBoxes);
}
TinyYoloV2 output
The postprocess method receives an ort.Tensor with the shape [1, 125, 13, 13]. From this tensor, we can derive the bounding boxes. The first dimension represents the batch size, which is always one in our model. The third and fourth dimensions define the number of grid cells along each image dimension: Yolo splits the image into cells of 32 × 32 pixels (416 / 32 = 13 cells per dimension) and predicts bounding boxes for each cell according to its anchors.
The second dimension holds the five bounding boxes. One bounding box is defined by x1, y1, width, height, score, and the probability distribution over all classes. Thus, the number of data points per bounding box is 5 + number_classes. In our case, we have 20 classes, resulting in an array of length 25. The number of anchors (in our case 5) defines the number of possible bounding boxes per grid cell, so we can calculate the number of values per cell as follows: 25 values per bounding box * 5 anchors = 125 values.
Accessing the correct values of the tensor
In Python, we could now select the array of 125 values (five bounding boxes) for the first grid cell by indexing a NumPy array, for example resultArray[0, :, 0, 0]. In JS, we have to do this manually because the tensor is stored in a single flat array.
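Concretely, for a tensor of shape [1, 125, 13, 13] stored in row-major order, the flat index of element [0, c, y, x] can be computed like this (a small sketch to illustrate the indexing used below; the helper names are mine):

// Flat index into output.data for a [1, 125, 13, 13] tensor in row-major order
const flatIndex = (c: number, y: number, x: number) =>
  c * 13 * 13 + y * 13 + x;

// Example: all 125 values belonging to the grid cell at (gridX, gridY)
const cellValues = (data: Float32Array, gridX: number, gridY: number) =>
  Array.from({ length: 125 }, (_, c) => data[flatIndex(c, gridY, gridX)]);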
Before we process the output tensor from Yolo, we first need to define some variables. An ort.Tensor is stored as a flat array plus meta-information about the tensor shape. If we want to access specific parts of the tensor, we have to loop over this array and build the subsets ourselves. Luckily, the Java API docs for ONNXRuntime specify the layout, stating that an “array is stored in n-dimensional row-major order”, so it is very likely that this also holds for the JS API. If you want to know more about row-major order, look up the wiki article, which explains it very well. In our case, this means that we loop from the last dimension to the first. We want to extract the 13 * 13 subsets of the array, each with a length of 125, and therefore need to select every 169th (13 * 13) value of the main array. For this, we create the mutable variables offset, gridX, gridY, stepId and bboxIdx, which are used in the for-loops later and start at zero. Besides this, we define stepSize, which is simply 13 * 13 and corresponds to the stride of 169 values just described. Finally, we need to calculate the paddings of the image, which we use later to correct the bounding boxes. For this, we use the util function calcPaddingAfterResize from ./src/util/imageProcess.ts:
export const calcPaddingAfterResize = (
  targetDims: number[],
  originalDims: number[]
) => {
  const resizeFactorX = targetDims[0] / originalDims[0];
  const resizeFactorY = targetDims[1] / originalDims[1];
  var paddings: number[];
  if (resizeFactorX < resizeFactorY) {
    const fullPaddings = targetDims[0] - resizeFactorX * originalDims[1];
    const addOdd = Math.floor(fullPaddings) % 2 === 0 ? 0 : 1;
    paddings = [
      0,
      0,
      Math.floor(fullPaddings / 2),
      Math.floor(fullPaddings / 2) + addOdd,
    ];
  } else {
    const fullPaddings = targetDims[1] - resizeFactorY * originalDims[0];
    const addOdd = Math.floor(fullPaddings) % 2 === 0 ? 0 : 1;
    paddings = [
      Math.floor(fullPaddings / 2),
      Math.floor(fullPaddings / 2) + addOdd,
      0,
      0,
    ];
  }
  return paddings;
};
The function determines which side of the image is the limiting one and calculates the rescale factor accordingly. We then apply this rescale factor to the shorter side and take the difference to the target dimension as the padding.
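As a quick sanity check (my own example, not from the original post): a 640 × 480 image letterboxed into 416 × 416 is scaled by 416 / 640 = 0.65 to 416 × 312, which leaves 104 pixels of vertical padding, split evenly:

// Hypothetical example values for a 640 x 480 input image
const paddings = calcPaddingAfterResize([416, 416], [640, 480]);
// paddings === [0, 0, 52, 52]
// -> no horizontal padding, 52 px of padding at the top and at the bottom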
Extracting the bounding boxes
With all the static variables in place and the correct workflow in mind for extracting the bounding boxes, we can dive into the code in more detail. First, we iterate over the grid cells of both dimensions and reset the gridX value at the start of each iteration of gridY. We use gridX and gridY later to derive the position of the bounding box:
for (gridY; gridY < 13; gridY++) {
  gridX = 0;
  for (gridX; gridX < 13; gridX++) {
    ...
  }
}
Inside these two loops, we create the variable currentBboxList, which collects the 125 data points defining the five bounding boxes of the current grid cell. Furthermore, we reset the variable stepId to zero and loop from zero to 124 to gather all data points for currentBboxList:
for (stepId; stepId < 125; stepId++) {
  currentBboxList.push(output.data[stepId * stepSize + offset]);
}
The offset is a running variable that we increment at the end of the grid loop; it is simply the offset for indexing the main array (output.data).
After successfully constructing the subset of length 125, we iterate over bboxIdx (which runs from 0 to 4). As mentioned before, each bounding box has a length of 25, so we splice the first 25 items out of the array and obtain a single bounding box:
bboxIdx = 0;
for (bboxIdx; bboxIdx < 5; bboxIdx++) {
  var currentBbox = currentBboxList.splice(0, 25);
  ...
}
Convert the bounding box to coordinates
Until now, all steps were necessary to access the correct subset of our output tensor. Now we focus on converting the YoloV2 output into bounding box coordinates in the image. Yolo predicts many candidate boxes, discards those with a low confidence, and removes overlapping duplicates with non-maximum suppression. The default threshold for the minimum confidence is 0.3. We squash the objectness score with a sigmoid to obtain the confidence and only process bounding boxes whose confidence exceeds 0.3:
const currentC = sigmoid(currentBbox[4]);
if (currentC > 0.3) {
  ...
}
Inside the if-clause, we compute the softmax over the class scores to obtain the predicted class. Softmax is a data science standard; you can find my JS implementation under ./src/util/math.ts. From the softmax output, we can look up the corresponding class name and color:
const softmaxOut = softmax(
  currentBbox.slice(5, currentBbox.length)
);
const className =
  this.classes[softmaxOut.indexOf(Math.max(...softmaxOut))];
const classColor =
  this.classesColor[softmaxOut.indexOf(Math.max(...softmaxOut))];
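Since ./src/util/math.ts is only referenced and not shown in this post, here is a minimal sketch of how sigmoid and softmax could be implemented (my own version, possibly differing in detail from the original):

// Logistic sigmoid for a single value
export const sigmoid = (x: number) => 1 / (1 + Math.exp(-x));

// Numerically stable softmax over an array of scores
export const softmax = (scores: number[]) => {
  const max = Math.max(...scores);
  const exps = scores.map((s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
};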
Lastly, we can calculate the bounding box. For the width and height, we apply the exponential function to the raw values and multiply them with the corresponding anchor and the grid cell size (32). If you read the preceding post about the React app, I mentioned that I like the FiftyOne approach to annotations, where points are stored as relative values from 0 to 1. Thus, we divide x1, y1, width, and height by the corresponding image dimension minus the paddings applied to the image.
To obtain the x1 and y1 values of a bounding box, we apply a sigmoid to the raw output, add the current grid position, and multiply the result with the grid cell size. Because these coordinates describe the center of the box, we also have to subtract half of the width/height of the bounding box.
const width =
  //@ts-ignore
  Math.exp(currentBbox[2]) * this.anchors[bboxIdx * 2] * 32;
const height =
  //@ts-ignore
  Math.exp(currentBbox[3]) * this.anchors[bboxIdx * 2 + 1] * 32;
var bbox = [
  ((sigmoid(currentBbox[0]) + gridX) * 32 - width / 2) / this.dims[2],
  ((sigmoid(currentBbox[1]) + gridY) * 32 + height / 2) / this.dims[3],
  width / (this.dims[2] - paddings[0] - paddings[1]),
  height / (this.dims[3] - paddings[2] - paddings[3]),
];
The last step is to clip all values to the range from 0 to 1. We then construct an annotation object that contains all values needed to describe and display the annotation, and finally run non-maximum suppression over all collected boxes.
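The nonMaximumSupression helper is also not shown in this post; a simple greedy variant (my own sketch, assuming the box format [x1, y1, width, height] in relative coordinates and a fixed IoU threshold) could look like this:

// Intersection over union for boxes in [x1, y1, width, height] format
const iou = (a: number[], b: number[]) => {
  const x1 = Math.max(a[0], b[0]);
  const y1 = Math.max(a[1], b[1]);
  const x2 = Math.min(a[0] + a[2], b[0] + b[2]);
  const y2 = Math.min(a[1] + a[3], b[1] + b[3]);
  const inter = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
  const union = a[2] * a[3] + b[2] * b[3] - inter;
  return union > 0 ? inter / union : 0;
};

// Greedy non-maximum suppression over the collected bounding boxes (sketch)
export const nonMaximumSupression = (
  boxes: bboxAnnotationObject[],
  iouThreshold = 0.5
) => {
  // Sort by confidence and keep a box only if it does not overlap a kept one
  const sorted = [...boxes].sort((a, b) => b.score - a.score);
  const kept: bboxAnnotationObject[] = [];
  for (const candidate of sorted) {
    const overlaps = kept.some(
      (k) => iou(k.box, candidate.box) > iouThreshold
    );
    if (!overlaps) kept.push(candidate);
  }
  return kept;
};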
Inference component for ONNX models in React
Finally, we need a React component that lets a user trigger our ONNX model. We call the component <InferenceMenu>, and it lives under ./src/components/InferenceMenue.tsx. As imports, we only need some basic MUI components, the useState hook, our types, the theme, and our TinyYoloV2 executor:
//MUI
import Button from "@mui/material/Button";
import LinearProgress from "@mui/material/LinearProgress";
import Box from "@mui/material/Box";
import { useState } from "react";
//Custom
import { tinyYoloV2Executor } from "../models/tinyYoloV2";
import { theme } from "../App";
import { anyAnnoationObject } from "../util/types";
As input, the component needs an image source (src) and a function that receives the annotations and updates the surrounding state accordingly:
interface InferenceMenuProps {
  src: string;
  updateAnnotation: (annotations: anyAnnoationObject[]) => void;
}

export default function InferenceMenu({
  src,
  updateAnnotation,
}: InferenceMenuProps) {
The component has three states: loading, success and inferenceTime. All are self-explanatory and are used to make the inference process visible to the user. Besides the states, the component defines a function called inference. It measures the time by storing a startTime, sets loading to true, and executes the promise from the tinyYoloV2Executor. If the promise resolves, loading is set to false, the inference time is calculated, and the result is passed to the provided update function. In case of an error, we set loading and success to false. Obviously, it would be better for the user to receive a <Snackbar> message, but that is out of scope for this tutorial.
const [loading, setLoading] = useState(false);
const [success, setSuccess] = useState(false);
const [inferenceTime, setInferenceTime] = useState<null | number>(null);

const inference = () => {
  console.log("Start inference");
  setLoading(true);
  const startTime = new Date();
  tinyYoloV2Executor
    .inference(src)
    .then((result) => {
      updateAnnotation(result);
      setInferenceTime(
        Math.round((new Date().getTime() - startTime.getTime()) / 100) / 10
      );
      setLoading(false);
      setSuccess(true);
    })
    .catch((e) => {
      setLoading(false);
      setSuccess(false);
    });
};
Lastly, we define the UI. We render a button; if the loading state is true, we disable it to avoid several concurrent calls to inference. After a successful inference, we change the color of the button to green and write the inference time in seconds onto the button as feedback. During loading, we also render a <LinearProgress> MUI component that visualizes the loading state:
return (
  <Box>
    <Button
      sx={{
        background: success
          ? theme.palette.success.dark
          : theme.palette.primary.main,
      }}
      variant="contained"
      disabled={loading}
      onClick={inference}
    >
      {inferenceTime ? `Inference (${inferenceTime} Sec)` : "Inference"}
    </Button>
    {loading && <LinearProgress />}
  </Box>
);
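In a parent component, <InferenceMenu> can then be used like this (a hypothetical sketch; imageUrl, the import paths, and the annotation state stand in for whatever your app already provides):

import { useState } from "react";
import Box from "@mui/material/Box";
import InferenceMenu from "./components/InferenceMenue";
import { anyAnnoationObject } from "./util/types";

// Hypothetical parent component: imageUrl is a placeholder for whatever
// image source your app already manages.
function AnnotationView({ imageUrl }: { imageUrl: string }) {
  const [annotations, setAnnotations] = useState<anyAnnoationObject[]>([]);
  return (
    <Box>
      {/* render the image and the annotations here */}
      <InferenceMenu src={imageUrl} updateAnnotation={setAnnotations} />
    </Box>
  );
}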
Outlook
We successfully executed an ONNX model in React. Nice! However, we experienced some pitfalls of JS for computer vision along the way.