Export Segment Anything neural network to ONNX: the missing parts
Andrey Germanov
Posted on November 15, 2023
Table of Contents
Introduction
What is the problem?
Diving into the SAM model structure
Export SAM to ONNX - the right way
Export the image encoder
Export the mask decoder
Produce image segmentation masks using ONNX
Preprocess input image
Generate embeddings from input image
Encode the prompt
Run the mask decoder
Post-process and visualize segmentation mask
Conclusion
Introduction
Hello all!
In this article, I am going to talk about Segment Anything, the neural network for instance segmentation that can segment any object in an image without knowing its type. However, this is not a tutorial on how to use it, because that is already described in the official repository and in other articles like this one. Here I will explain how to solve a problem that is not described anywhere: the problem with the export to ONNX function.
What is the problem?
If you export the Segment Anything model to ONNX following the guide in the official notebook and then try to deploy it to production, you'll find that the exported ONNX model is not enough on its own. You still need the Segment Anything package with PyTorch to create embeddings from the input image, and you still need a function from this package to encode the prompt.
When I experienced this for the first time, I asked myself: "Why should I export the model to ONNX if I still need to use the original PyTorch model?"
One of the main benefits of ONNX is the ability to run a model in environments without Python and PyTorch. However, according to the official documentation, I can't do that with Segment Anything: even with ONNX, I need to install the whole PyTorch environment on my production server or device.
I was not alone with this problem: a lot of people asked for a solution on forums and in the project's GitHub repository, but there were no clear answers. Finally, I decided to dive into the Segment Anything source code myself and fill this gap.
In this article, I am going to show how to export a complete SAM model and how to segment the image using only ONNX model and without other heavy dependencies.
Diving into the SAM model structure
Before going to ONNX, let's understand the SAM model structure by using its official API.
Segment Anything has a transformer-based architecture and consists of three parts: the image encoder, the prompt encoder and the mask decoder.
This picture from the official SAM paper shows the segmentation mask inference process. Now let's look at code that implements this flow using the official API.
All code examples in this article use the following image, named cat_dog.jpg, which you can download here:
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2
# 1. Load the image
img = cv2.imread("cat_dog.jpg")
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
# 2. Load the Segment anything model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
# 3. Put the model to the SamPredictor helper object
predictor = SamPredictor(sam)
# 4. Encode the image to embeddings.
predictor.set_image(img)
# 5. Prepare the prompt
input_point = np.array([[321,230]])
input_label = np.array([1])
# 6. Decode masks (predict returns the masks, their quality scores and low-resolution logits)
masks, scores, logits = predictor.predict(input_point, input_label)
Here is a breakdown of this flow:
- First, it loads the image as a NumPy array of HWC shape (Height, Width, Channels) using OpenCV. You can do this with any other library, like Pillow, as well.
- Then it loads the SAM model into the sam variable. sam is an object of the Sam class, defined in the sam.py file. This class contains both the image encoder and the mask decoder parts. If you open this file and look at the __init__ constructor, you'll find that the encoder is initialized in the image_encoder property and the decoder in the mask_decoder property. Both of them are standard PyTorch neural network modules.
- Then, the code initializes the helper SamPredictor object, which is used as a wrapper for the created Sam model. It contains helper methods to prepare the input image, encode the image to embeddings, encode the prompt, and pass both of them to the mask_decoder to get segmentation masks.
- The most important line of the whole code is predictor.set_image(img). This method preprocesses the input image and runs the SAM encoder network on it. Under the hood, it runs the following line with the preprocessed image: predictor.features = sam.image_encoder(input_image). This line passes the image through the encoder neural network to get embeddings and saves them to the features property of the SamPredictor object. The official export to ONNX function does not export this neural network, so you still need to run it even if you use the exported ONNX model.
- Then, you define the point on the image that will be used as a prompt to decode the segmentation mask, and a label for this point: 1 means that the point belongs to the object that you want to extract, 0 means that the point does not belong to that object.
- Finally, you execute the predictor.predict(input_point, input_label) method. At this moment, the predictor encodes the prompt and passes both the image embeddings, saved in the features property, and the encoded prompt to the mask decoder, which is the sam.mask_decoder neural network. The method then post-processes the decoder output and returns the masks, along with their predicted quality scores and low-resolution logits.
This is how the official API works. Segment Anything is actually two neural networks, image_encoder and mask_decoder, that are executed one after another. First it runs the sam.image_encoder network to encode the image to embeddings, and then it runs the sam.mask_decoder network to decode the embeddings to masks using the prompt. The prompt is also encoded, using the prompt encoder, but in many cases the prompt can be encoded without a neural network. However, when you export the sam model to ONNX using the official guide, it exports only the mask decoder, so you still need the official API to prepare the image embeddings for the exported ONNX model and to encode the prompt.
Fortunately, the image_encoder is an ordinary PyTorch neural network module that you can export to ONNX yourself using the standard PyTorch feature, described here. The prompt can also be encoded using only NumPy. I will fill these gaps for you in the next sections.
Export SAM to ONNX - the right way
To use the Segment Anything network independently of PyTorch and/or Python, you need to export two models to ONNX: the image encoder and the mask decoder. The official documentation shows how to export only the mask decoder. In this tutorial, I will show you how to export and use both parts without depending on PyTorch or the official SAM API.
Export the image encoder
To export any PyTorch model to ONNX, you need to know the shape of the input tensor or tensors that this model requires. The image encoder used in Segment Anything is a modified encoder part of the ViT (Vision Transformer) neural network. It is defined in the ImageEncoderViT class in image_encoder.py. By analyzing the source code of this file, it's easy to see that this module requires an input tensor of shape (1,3,1024,1024), which is a batch of one 3-channel image of 1024x1024 size. So, to pass a single image to the image encoder, you need to convert it to a float tensor of this shape.
Here is the full code to export the image encoder to ONNX. I assume that you'll run it in Jupyter Notebook:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install onnx
!pip install torch
from segment_anything import sam_model_registry
import torch
# Download SAM model checkpoint
!pip install wget
!python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
# Load SAM model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
# Export images encoder from SAM model to ONNX
torch.onnx.export(
f="vit_b_encoder.onnx",
model=sam.image_encoder,
args=torch.randn(1, 3, 1024, 1024),
input_names=["images"],
output_names=["embeddings"],
export_params=True
)
- This code installs and imports all required packages first. Perhaps you already have them all, but I included these lines just in case.
- Then it downloads the model weights and loads the sam model with them. I used the smallest ViT-B version, but you can replace it with ViT-L or ViT-H (the "vit_l" and "vit_h" registry keys) and download the appropriate weights from here.
- Finally, the standard torch.onnx.export function is used to export sam.image_encoder to the vit_b_encoder.onnx file. The resulting ONNX model has a single input, named images, which accepts input tensors of shape (1,3,1024,1024). It also has a single output, named embeddings, which contains the embeddings for the provided input image.
Great! After running this, you'll have the vit_b_encoder.onnx file. The biggest part of the export work is done!
Export the mask decoder
In this section, I can only repeat the code that is already written in the official notebook. I modified it a little bit for consistency:
!pip3 install git+https://github.com/facebookresearch/segment-anything.git
!pip3 install onnx
!pip3 install torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import torch
# Download SAM model checkpoint
!pip install wget
!python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
# Load SAM model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
# Export masks decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
f="vit_b_decoder.onnx",
model=onnx_model,
args=tuple(dummy_inputs.values()),
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes={
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"}
},
export_params=True,
opset_version=17,
do_constant_folding=True
)
- This code installs and imports all required packages first. Perhaps you already have them all, but I included these lines just in case.
- Then it downloads the model weights and loads the sam model with them. I used the smallest ViT-B version, but you can replace it with ViT-L or ViT-H (the "vit_l" and "vit_h" registry keys) and download the appropriate weights from here.
- Finally, it uses the standard torch.onnx.export function to export the mask decoder (wrapped in the SamOnnxModel helper class) to the vit_b_decoder.onnx file. The resulting ONNX model has six inputs. The most important of them are image_embeddings, which receives the output of the vit_b_encoder.onnx model, and point_coords and point_labels, which receive the encoded prompt. The decoder model also requires orig_im_size, the original input image size as an array with two items, [height, width], which it uses to scale the resulting masks correctly.
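The shapes of the dummy inputs above follow from a few constants. As far as I can tell from the SAM source code (build_sam.py), all three model sizes use a 1024x1024 encoder input, a ViT patch size of 16 and a prompt embedding dimension of 256. The sketch below derives the decoder's fixed input shapes from those values; the constant and function names are my own, not part of the SAM API.

```python
# Assumed constants, matching the values SAM uses when it builds its models
IMAGE_SIZE = 1024  # encoder input is IMAGE_SIZE x IMAGE_SIZE pixels
PATCH_SIZE = 16    # ViT patch size
EMBED_DIM = 256    # number of channels in the image embeddings


def decoder_input_shapes(image_size=IMAGE_SIZE, patch_size=PATCH_SIZE, embed_dim=EMBED_DIM):
    """Return the expected shapes of the fixed-size decoder inputs."""
    grid = image_size // patch_size                # 64 patches per side
    embed_size = (grid, grid)                      # like sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]  # low-res mask is 4x the embedding grid
    return {
        "image_embeddings": (1, embed_dim, *embed_size),
        "mask_input": (1, 1, *mask_input_size),
        "has_mask_input": (1,),
        "orig_im_size": (2,),
    }


for name, shape in decoder_input_shapes().items():
    print(name, shape)
```

The (1, 256, 64, 64) image_embeddings shape is also what the exported encoder should return. point_coords and point_labels are left out because they were exported with a dynamic num_points axis.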
Wonderful! Now you have all the pieces of the puzzle:
- vit_b_encoder.onnx: creates the image embeddings
- vit_b_decoder.onnx: decodes segmentation masks using the embeddings and the prompts
For your convenience, I put all the ONNX export code into the sam_onnx_export.ipynb notebook in the article's repository.
However, using these models without the official API is a little bit complicated, because you need to preprocess the input image and encode the prompt on your own, and there is no documentation for these steps. I will show how to do this in the next section.
Produce image segmentation masks using ONNX
To get segmentation masks for objects of interest in your image using the ONNX models exported above, you need to do the following:
- Preprocess the input image
- Pass the preprocessed image to the vit_b_encoder.onnx model to generate image embeddings
- Create a prompt and encode it
- Pass the image embeddings and the prompt to the vit_b_decoder.onnx model and receive a segmentation mask
- Post-process the mask and optionally visualize it
In the next sections, I am going to implement these steps one by one. I assume that you will follow my code using Jupyter Notebook and that you have the vit_b_encoder.onnx and vit_b_decoder.onnx files in the folder with your notebook. Also, in the examples I will use the cat_dog.jpg image, which you can download at the beginning of this article; place it in the same folder.
Preprocess input image
As mentioned above, the encoder model requires an input tensor of shape (1,3,1024,1024). Therefore, you need to resize the input image so that its longest side is 1024 while preserving the aspect ratio, convert it to a tensor of numbers, normalize this tensor and pad it to 1024x1024.
Let's load the image first. You will use the Pillow package for this:
!pip install Pillow
from PIL import Image
img = Image.open("cat_dog.jpg")
img = img.convert("RGB")
img.size
orig_width, orig_height = img.size
print(img.size)
(612, 415)
This code loads the image, converts it to RGB and saves the original size, which you will need later.
Then, you need to resize this image while preserving the aspect ratio, using 1024 as the long side. This means that you set the long side to 1024 and then scale the short side by the same factor. The following code can be used for this:
resized_width, resized_height = img.size
if orig_width > orig_height:
resized_width = 1024
resized_height = int(1024 / orig_width * orig_height)
else:
resized_height = 1024
resized_width = int(1024 / orig_height * orig_width)
img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
print(img.size)
(1024, 694)
This code determines which side is the longest and, based on that, calculates the new size of the shortest side. In this case, the longest side is the width and the shortest is the height, so the image is scaled to (1024,694). The new dimensions are saved to the resized_width and resized_height variables.
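The resize logic can also be wrapped in a small helper. This sketch follows SAM's ResizeLongestSide.get_preprocess_shape, which rounds to the nearest pixel instead of truncating with int() as the code above does. For cat_dog.jpg both approaches give (1024, 694), but for other sizes they can differ by one pixel (for example, a 1000x333 image). The helper name resized_size is hypothetical.

```python
def resized_size(orig_width, orig_height, long_side=1024):
    """Return (width, height) with the longest side scaled to long_side.

    Rounds to the nearest pixel, like SAM's ResizeLongestSide helper.
    """
    scale = long_side / max(orig_width, orig_height)
    new_width = int(orig_width * scale + 0.5)
    new_height = int(orig_height * scale + 0.5)
    return new_width, new_height


print(resized_size(612, 415))   # (1024, 694) for the landscape cat_dog.jpg
print(resized_size(415, 612))   # (694, 1024) for a portrait image of the same size
print(resized_size(1000, 333))  # (1024, 341); int() truncation would give 340
```

Whichever rounding you pick, use the same resized_width and resized_height later when you scale the prompt coordinates, so that points and pixels stay aligned.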
Then, you need to convert the image to a tensor. NumPy lets you do this in a single line:
!pip install numpy
import numpy as np
input_tensor = np.array(img)
input_tensor.shape
(694, 1024, 3)
The input_tensor is a (height, width, 3) array: for every pixel, its last axis holds the red, green and blue color components, each in the range from 0 to 255. However, the Segment Anything model requires normalized numbers. To get a normalized number, you subtract the mean color from each number and then divide it by the standard deviation. There are different ways to calculate the mean color and standard deviation, but the Segment Anything package provides precalculated means and deviations for each color component. You need to initialize them:
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
So, now you need to subtract 123.675 from each red color component and then divide it by 58.395. Similarly, you need to subtract 116.28 from each green color component and divide it by 57.12, and so on for blue. You can do all this in a single line of code using NumPy:
input_tensor = (input_tensor - mean) / std
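To see what the broadcasting in this line does, here is a small check on a synthetic image. NumPy aligns the (3,) mean and std arrays with the last axis of the (H, W, 3) image, so each color channel uses its own statistics: a pixel equal to the mean becomes 0, and a pixel one standard deviation above it becomes 1. The function name normalize is mine.

```python
import numpy as np

# Per-channel statistics used by the Segment Anything preprocessing (RGB order)
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])


def normalize(hwc_image):
    """Normalize an (H, W, 3) RGB array channel by channel.

    The (3,) mean and std arrays broadcast over the last axis, so every
    pixel's R, G and B values use their own channel statistics.
    """
    return (hwc_image.astype(np.float64) - mean) / std


# Synthetic 1x2 image: the first pixel equals the mean, the second is mean + 1 std
image = np.stack([mean, mean + std])[None, :, :]  # shape (1, 2, 3)
normalized = normalize(image)
print(normalized.shape)
print(normalized)
```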
Now you have a normalized input tensor, but it has an incorrect shape: (694, 1024, 3). You need to change it to the form (1, color_channels, height, width). In this case, it should be (1, 3, 694, 1024):
input_tensor = input_tensor.transpose(2,0,1)[None,:,:,:].astype(np.float32)
input_tensor.shape
(1, 3, 694, 1024)
The final step is to transform it to (1, 3, 1024, 1024). To do this, you need to pad the short side with zeros:
if resized_height < resized_width:
input_tensor = np.pad(input_tensor,((0,0),(0,0),(0,1024-resized_height),(0,0)))
else:
input_tensor = np.pad(input_tensor,((0,0),(0,0),(0,0),(0,1024-resized_width)))
input_tensor.shape
(1, 3, 1024, 1024)
The np.pad function receives the input tensor to pad with zeros and then, for each axis, how many zeros to add before and after the existing values. In this case, you need to add 1024-resized_height rows of zeros to the end of the height axis. If the shortest side were the width, this would have to be done for the last axis instead.
That is it: now you have the correct input_tensor for the image encoder model.
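Because the long side is already 1024 after resizing, one of the two padding amounts is always zero, so the if/else above can be collapsed into a single np.pad call that pads both spatial axes. The sketch below combines the transpose and the padding; the helper name to_encoder_input is hypothetical, and it assumes its input has already been resized and normalized.

```python
import numpy as np


def to_encoder_input(normalized_hwc, size=1024):
    """Convert a normalized (H, W, 3) array into a (1, 3, size, size) float32 tensor.

    Assumes the longest side has already been resized to `size`.
    Zeros are added after the existing rows and columns (bottom and right),
    so the image stays anchored at the top-left corner.
    """
    height, width, _ = normalized_hwc.shape
    chw = normalized_hwc.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)
    # (before, after) padding for the batch, channel, height and width axes
    return np.pad(chw, ((0, 0), (0, 0), (0, size - height), (0, size - width)))


landscape = to_encoder_input(np.ones((694, 1024, 3)))
portrait = to_encoder_input(np.ones((1024, 694, 3)))
print(landscape.shape, landscape.dtype)  # (1, 3, 1024, 1024) float32
print(portrait.shape, portrait.dtype)    # (1, 3, 1024, 1024) float32
```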
Generate embeddings from input image
The first thing you need to do is import the onnxruntime library and load the vit_b_encoder.onnx model with it:
!pip install onnxruntime
import onnxruntime as ort
encoder = ort.InferenceSession("vit_b_encoder.onnx")
Then, run the model with input_tensor as the images input to generate embeddings:
outputs = encoder.run(None, {"images": input_tensor})
embeddings = outputs[0]
embeddings.shape
(1, 256, 64, 64)
If you remember, when you exported the image encoder to ONNX, you specified that this model should have a single input named "images" and a single output named "embeddings". Here, you've passed input_tensor as the "images" input. The run method of an ONNX Runtime session returns the outputs as a list, even if there is only one output. That is why the embeddings are located in the first item of this list.
Great, now you have the embeddings. This is the first input that you will need for the mask decoder model. The next input is the prompt, which you also need to prepare.
Encode the prompt
The prompt helps the model find the segmentation mask of the required object. The prompt can be a single point on the image that belongs to the object, several points, or a bounding box around the object. Segment Anything encodes all of these options in a similar way. Let's start with a single point:
input_point = np.array([[321,230]])
input_label = np.array([1])
In this code, you defined a point with x=321 and y=230. You also defined a label for this point, which is 1. This label means that the point belongs to the object. Using this definition, the mask decoder will try to find the segmentation mask for the object that contains this point. However, you need to encode this point to the format that the mask decoder requires. Use the following lines of code for this:
from copy import deepcopy
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])])[None, :].astype(np.float32)
coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)
onnx_coord = coords.astype("float32")
onnx_coord
array([[[537.098 , 384.6265],
[ 0. , 0. ]]], dtype=float32)
The SAM mask decoder requires the input point to be scaled to the coordinate system of the resized image (the one whose long side is 1024) and converted to a tensor of floats. Here I used the orig_width, orig_height, resized_width and resized_height of the image to scale the coordinates. The extra [0.0, 0.0] point with the label -1 is a padding point: the official ONNX example notebook appends it when the prompt does not contain a bounding box.
I won't give a detailed explanation of each line of this code, because I just reused it from the transform.apply_coords function of the SAM source code with a few modifications to make it simpler. It's just a requirement of the mask decoder model.
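The essential part of this code is the scaling: x is multiplied by resized_width / orig_width and y by resized_height / orig_height. Since np.array makes a copy by default, deepcopy isn't strictly needed. Here is the same transformation as a reusable sketch; the function name scale_coords is hypothetical, and both sizes are passed as (width, height).

```python
import numpy as np


def scale_coords(coords, orig_size, resized_size):
    """Map (x, y) points from original image pixels to the resized-image frame.

    coords: array-like of shape (..., 2) with x in column 0 and y in column 1
    orig_size, resized_size: (width, height) tuples
    """
    orig_w, orig_h = orig_size
    new_w, new_h = resized_size
    scaled = np.array(coords, dtype=np.float64)  # a copy, so the input is not modified
    scaled[..., 0] *= new_w / orig_w
    scaled[..., 1] *= new_h / orig_h
    return scaled.astype(np.float32)


# Single positive point plus the (0, 0) padding point (label -1)
point_coords = np.array([[[321, 230], [0, 0]]])
point_labels = np.array([[1, -1]], dtype=np.float32)
print(scale_coords(point_coords, (612, 415), (1024, 694)))
```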
If you need to send bounding box as a prompt, you can use similar code:
input_box = np.array([132, 157, 256, 325]).reshape(2,2)
input_labels = np.array([2,3])
onnx_coord = input_box[None, :, :]
onnx_label = input_labels[None, :].astype(np.float32)
coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)
onnx_coord = coords.astype("float32")
onnx_coord
array([[[220.86275, 262.5494 ],
[428.33987, 543.49396]]], dtype=float32)
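This code encodes a prompt to get the mask for the object located inside the box with the top-left corner at x=132,y=157 and the bottom-right corner at x=256,y=325. The box is passed to the decoder as two points: the top-left corner with label 2 and the bottom-right corner with label 3.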
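Boxes often come from other tools in [x1, y1, x2, y2] form. A small helper can turn such a box into the two corner points with labels 2 and 3 used above. This is a sketch; the name box_to_prompt is hypothetical, and reordering the corners into top-left/bottom-right order is my own defensive addition, not something SAM does.

```python
import numpy as np


def box_to_prompt(box):
    """Turn an [x1, y1, x2, y2] box into SAM corner points and labels.

    Returns coords of shape (2, 2) and labels of shape (2,):
    label 2 marks the top-left corner, label 3 the bottom-right corner.
    """
    x1, y1, x2, y2 = box
    # Reorder in case the corners were given swapped (not done by SAM itself)
    coords = np.array([[min(x1, x2), min(y1, y2)],
                       [max(x1, x2), max(y1, y2)]], dtype=np.float32)
    labels = np.array([2, 3], dtype=np.float32)
    return coords, labels


coords, labels = box_to_prompt([132, 157, 256, 325])
print(coords)
print(labels)
```

The resulting coords still need to be scaled to the resized image and given a batch axis, exactly as in the code above.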
If you want to encode a prompt, that contains both bounding box and point, you can use the following code:
input_box = np.array([132, 157, 256, 325]).reshape(2,2)
box_labels = np.array([2,3])
input_point = np.array([[140, 160]])
input_label = np.array([0])
onnx_coord = np.concatenate([input_point, input_box], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, box_labels], axis=0)[None, :].astype(np.float32)
coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)
onnx_coord = coords.astype("float32")
onnx_coord
This code includes both input_box and input_point, along with labels for them. Notice that input_label here contains 0, which means that the point (140,160) does not belong to the object that you want to extract. This prompt will guide the model to segment the object located inside the (132,157,256,325) box, but not at the (140,160) point.
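The three cases above (points, a box, or both) can be folded into one function. This sketch follows the convention of the examples in this article: the padding point (0, 0) with label -1 is appended only when there is no box. The name build_prompt and its parameters are my own; sizes are (width, height) tuples.

```python
import numpy as np


def build_prompt(orig_size, resized_size, points=None, point_labels=None, box=None):
    """Build (point_coords, point_labels) inputs for the exported decoder.

    points: (N, 2) array-like of (x, y); point_labels: (N,) with 1 = object, 0 = background
    box: optional [x1, y1, x2, y2]
    orig_size, resized_size: (width, height) tuples
    """
    coords = np.zeros((0, 2))
    labels = np.zeros((0,))
    if points is not None:
        coords = np.concatenate([coords, np.asarray(points, dtype=np.float64)])
        labels = np.concatenate([labels, np.asarray(point_labels, dtype=np.float64)])
    if box is not None:
        # Box corners: top-left gets label 2, bottom-right gets label 3
        coords = np.concatenate([coords, np.asarray(box, dtype=np.float64).reshape(2, 2)])
        labels = np.concatenate([labels, [2, 3]])
    else:
        # Without a box, append the padding point (0, 0) with label -1
        coords = np.concatenate([coords, [[0.0, 0.0]]])
        labels = np.concatenate([labels, [-1]])
    # Scale x and y into the resized (long side = 1024) image frame
    coords[:, 0] *= resized_size[0] / orig_size[0]
    coords[:, 1] *= resized_size[1] / orig_size[1]
    # Add the batch axis and cast to the float32 dtype the decoder expects
    return coords[None].astype(np.float32), labels[None].astype(np.float32)


coords, labels = build_prompt((612, 415), (1024, 694),
                              points=[[140, 160]], point_labels=[0],
                              box=[132, 157, 256, 325])
print(coords.shape, labels)  # (1, 3, 2) [[0. 2. 3.]]
```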
You can construct very specific prompts to get desired results (just like with ChatGPT ;) ).
So, now you have correctly encoded onnx_coord and onnx_label to pass to the mask decoder. Let's do this right now.
Run the mask decoder
Now that you have the embeddings, onnx_coord and onnx_label, nothing can stop you from running the mask decoder model to get the segmentation mask.
Let's load the model first:
decoder = ort.InferenceSession("vit_b_decoder.onnx")
and pass all encoded data to it:
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
outputs = decoder.run(None,{
"image_embeddings": embeddings,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})
masks = outputs[0]
masks.shape
(1, 1, 415, 612)
This code runs the model with the encoded image_embeddings, point_coords and point_labels. I also provided a dummy all-zero mask_input, set has_mask_input to 0 so that the decoder ignores it, and passed the original image size to the orig_im_size parameter.
The model returns 3 outputs (masks, iou_predictions and low_res_masks), and the array of segmentation masks is the first of them. For the input image, it returned a tensor of shape (1, 1, 415, 612), which is a single-channel segmentation mask at the original image size.
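ONNX Runtime checks input types strictly and raises an error if, for example, it receives a float64 array where the model expects float32. Building the feed dictionary in one place makes that harder to get wrong. The sketch below assumes the decoder was exported as shown earlier; decoder_feed is a hypothetical helper name, and the dummy embeddings are only there to check shapes and dtypes.

```python
import numpy as np


def decoder_feed(embeddings, point_coords, point_labels, orig_height, orig_width):
    """Assemble the input dict for the exported vit_b_decoder.onnx model.

    mask_input is an all-zero placeholder and has_mask_input = 0 tells the
    decoder to ignore it. Every input is cast to float32, the dtype used
    for all inputs at export time.
    """
    return {
        "image_embeddings": np.asarray(embeddings, dtype=np.float32),
        "point_coords": np.asarray(point_coords, dtype=np.float32),
        "point_labels": np.asarray(point_labels, dtype=np.float32),
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
    }


# Dummy embeddings with the encoder's output shape, just to check the feed
feed = decoder_feed(np.zeros((1, 256, 64, 64)),
                    [[[537.098, 384.6265], [0.0, 0.0]]], [[1, -1]], 415, 612)
for name, value in feed.items():
    print(name, value.shape, value.dtype)

# With onnxruntime and the exported model available:
# masks, iou_predictions, low_res_masks = decoder.run(None, feed)
```

As I understand the official example notebook, a result can also be refined by passing a previous low_res_masks output as mask_input and setting has_mask_input to 1.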
The only step left is to post process it.
Post-process and visualize segmentation mask
The segmentation mask is an array of pixels; however, each pixel contains not a color but a number. If this number is greater than 0, the pixel belongs to the object; otherwise it does not. So, to convert it to real pixel colors, you can run the following code:
mask = masks[0][0]
mask = (mask > 0).astype('uint8')*255
This code extracts the pixel matrix from the mask (415x612) and converts all positive values to True and all other values to False. Then it converts them to 8-bit integers, so all True values become 1 and all False values become 0. Finally, it multiplies the matrix by 255 to turn all object pixels white. As a result, you have a single-channel black-and-white image that can easily be visualized by many image libraries. For example, you can visualize it this way using Pillow:
img = Image.fromarray(mask,'L')
img
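For a visual check, it is often more useful to draw the mask on top of the original image than to look at it alone. The sketch below does this with NumPy only, working directly on the raw decoder output (masks[0][0]) rather than on the 0/255 image; the function name overlay_mask, the color and the opacity are arbitrary choices, not part of SAM.

```python
import numpy as np


def overlay_mask(image_rgb, mask_logits, color=(30, 144, 255), alpha=0.6):
    """Blend a colored mask over an (H, W, 3) uint8 RGB image.

    mask_logits: (H, W) decoder output; pixels with a value > 0 belong to the object.
    """
    mask = mask_logits > 0.0
    result = image_rgb.astype(np.float32)
    # Mix the original color and the mask color only where the mask is set
    result[mask] = (1 - alpha) * result[mask] + alpha * np.array(color, dtype=np.float32)
    return result.round().astype(np.uint8)


image = np.full((2, 3, 3), 200, dtype=np.uint8)
logits = np.array([[1.5, -2.0, 0.3], [-0.1, 4.0, -3.0]])
blended = overlay_mask(image, logits)
print(blended[:, :, 0])
```

In the notebook, you could call it as overlay_mask(np.array(Image.open("cat_dog.jpg").convert("RGB")), masks[0][0]) and display the result with Image.fromarray.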
Hooray! Now you can do Segment Anything image segmentation using only ONNX.
This is the end of our journey. You can find all the source code for this section in the sam_onnx_inference.ipynb notebook in the repository.
Conclusion
In this article, I showed how to fill the gap in the official implementation of the Segment Anything Model's ONNX export function. Then I showed you how to perform prompt-based image segmentation using the exported ONNX models.
You can find all the source code in this repository: https://github.com/AndreyGermanov/sam_onnx_full_export.
Here I used only Python, but with complete ONNX models you can do much more. You can run the Segment Anything model in any programming language supported by ONNX Runtime. If you know how to preprocess the input and post-process the output, you can integrate this model into most production systems, whatever language they are written in. For example, you can embed it into software written in C/C++, Go or Rust, or into websites written in JavaScript.
Thank you and until next time!
Follow me on LinkedIn, Twitter, and Facebook to know first about new articles like this one and other software development news.