Export Segment Anything neural network to ONNX: the missing parts

andreygermanov

Andrey Germanov

Posted on November 15, 2023


Table of Contents

Introduction
What is the problem?
Diving into the SAM model structure
Export SAM to ONNX - the right way
    Export the image encoder
    Export the mask decoder
Produce image segmentation masks using ONNX
    Preprocess input image
    Generate embeddings from input image
    Encode the prompt
    Run the mask decoder
    Post-process and visualize segmentation mask
Conclusion

Introduction

Hello all!

In this article, I am going to talk about Segment Anything, a neural network for image segmentation that can segment any object in an image without knowing its type. However, this is not a tutorial on how to use it, because that is already described in the official repository and in other articles like this one. Here I will explain how to solve a problem with it that is not described anywhere: the problem with its ONNX export function.

What is the problem?

If you try to export the Segment Anything model to ONNX and then deploy it to production using the guide in the official notebook, you'll see that the exported ONNX model alone is not enough: you still need the Segment Anything package with PyTorch to prepare embeddings from the input image, and you still need a function from this package to encode the prompt.

When I experienced this for the first time, I asked myself: "Why should I export the model to ONNX if I still need to use the original PyTorch model?"

One of the main benefits of ONNX is the ability to run a model in environments without Python and PyTorch. However, according to the official documentation, I can't do that with Segment Anything. Even with ONNX, I need to install the whole PyTorch environment on my production server or device.

I was not alone with this problem: a lot of people asked for a solution on forums and in the project's GitHub repository, but there were no clear answers. Finally, I decided to dive into the Segment Anything source code myself and fill this gap.

In this article, I am going to show how to export the complete SAM model and how to segment an image using only the ONNX models, without other heavy dependencies.

Diving into the SAM model structure

Before going to ONNX, let's understand the SAM model structure by using its official API.

Segment Anything has a transformer-based neural network architecture and contains the following parts: an image encoder, a prompt encoder and a mask decoder.

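[Figure from the official SAM paper: the segmentation mask inference flow through the image encoder, prompt encoder and mask decoder]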

This picture from the official SAM paper shows the segmentation mask inference process. Now let's look at the code that implements this flow using the official API.

All code examples in this article use the following image, named cat_dog.jpg, which you can download here:

[Image: cat_dog.jpg]



from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2

# 1. Load the image
img = cv2.imread("cat_dog.jpg")
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

# 2. Load the Segment anything model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")

# 3. Put the model to the SamPredictor helper object
predictor = SamPredictor(sam)

# 4. Encode the image to embeddings.
predictor.set_image(img)

# 5. Prepare the prompt
input_point = np.array([[321,230]])
input_label = np.array([1])

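# 6. Decode masks (predict returns the masks, their quality scores and low-resolution logits)
masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label)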



Here is a breakdown of this flow:

  1. First it loads the image as a Numpy array of HWC shape (Height, Width, Channels) using OpenCV and converts it from BGR to RGB. You can do this with any other library, like Pillow, as well.
  2. Then it loads the SAM model into the sam variable. The sam is an object of the Sam class, defined in the sam.py file. This class contains both the image encoder and the mask decoder. If you open this file and look at the __init__ constructor, you'll find that the encoder is initialized in the image_encoder property and the decoder in the mask_decoder property. Both of them are standard PyTorch neural network modules.
  3. Then the code initializes the SamPredictor helper object, which is used as a wrapper for the created Sam model. It contains helper methods to preprocess the input image, encode it to embeddings, encode the prompt and pass both of them to the mask_decoder to get segmentation masks.
  4. The most important line of the whole code is predictor.set_image(img). This method preprocesses the input image and runs the SAM encoder network on it. Under the hood, it runs the following line with the preprocessed image: predictor.features = sam.image_encoder(input_image). This line passes the image through the encoder neural network to get embeddings and saves them to the features property of the SamPredictor object. The official ONNX export function does not export this neural network, so you still need to run it with PyTorch even if you use the exported ONNX model.
  5. Then you define a point on the image that will be used as a prompt to decode the segmentation mask, and a label for this point: 1 means that the point belongs to the object that you want to extract, 0 means that the point does not belong to that object.
  6. Finally, you execute the predictor.predict(input_point, input_label) method. At this moment, the predictor encodes the prompt and passes both the image embeddings, saved in the features property, and the encoded prompt to the mask decoder, which is the sam.mask_decoder neural network. Then it post-processes the resulting output tensor and returns the masks, together with their predicted quality scores and low-resolution logits.

This is how the official API works. Segment Anything is actually two neural networks, image_encoder and mask_decoder, executed separately one after another. It runs the sam.image_encoder network first to encode the image to embeddings, and then it runs the sam.mask_decoder network to decode the embeddings to masks, using the prompt. The prompt is encoded by the prompt encoder, but in the official ONNX export this step is built into the exported decoder model, so outside of it the prompt only needs simple coordinate preprocessing. However, the official export produces only this decoder model, so you still need the official API to prepare the image embeddings for it and to preprocess the prompt.

Fortunately, the image_encoder is an ordinary PyTorch neural network module that you can export to ONNX yourself using the standard PyTorch feature, described here. The prompt preprocessing can also be done using only Numpy. I will fill these gaps for you in the next sections.

Export SAM to ONNX - the right way

To use the Segment Anything network independently of PyTorch and/or Python, you need to export two models to ONNX: the image encoder and the mask decoder. The official documentation shows how to export only the mask decoder. In this tutorial, I will show you how to export and use both parts without depending on PyTorch or the official SAM API.

Export the image encoder

To export any PyTorch model to ONNX, you need to know the shape of the input tensor or tensors that this model requires. The image encoder used in Segment Anything is a modified encoder part of the ViT (Vision Transformer) neural network. It is defined in the ImageEncoderViT class in image_encoder.py. By analyzing the source code of this file, it's easy to see that this module requires an input tensor of shape (1,3,1024,1024), which is a batch containing one 3-channel image of 1024x1024 size. So, to pass a single image to the image encoder, you need to convert it to a float tensor of this shape.

Here is the full code to export the image encoder to ONNX. I assume that you'll run it in a Jupyter Notebook:



!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install onnx
!pip install torch

from segment_anything import sam_model_registry
import torch

# Download SAM model checkpoint
!pip install wget
!python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

# Load SAM model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")

# Export images encoder from SAM model to ONNX
torch.onnx.export(
    f="vit_b_encoder.onnx",
    model=sam.image_encoder,
    args=torch.randn(1, 3, 1024, 1024),
    input_names=["images"],
    output_names=["embeddings"],
    export_params=True
)


  • This code installs and imports all the required packages first. You may already have all of them, but I included these lines just in case.
  • Then it downloads the model weights and loads the sam model with them. I used the smallest, ViT-B version, but you can replace it with ViT-L ("vit_l") or ViT-H ("vit_h") and download the appropriate weights from here.
  • Finally, the standard torch.onnx.export function is used to export sam.image_encoder to the vit_b_encoder.onnx file. The resulting ONNX model has a single input, named images, which accepts input tensors of (1,3,1024,1024) shape, and a single output, named embeddings, which will contain the embeddings for the provided input image.

Great! After running this, you'll have the vit_b_encoder.onnx file. The biggest part of the export work is done!

Export the mask decoder

In this section, I can only repeat the code that is already written in the official notebook. I modified it a little for consistency:



!pip3 install git+https://github.com/facebookresearch/segment-anything.git
!pip3 install onnx
!pip3 install torch

from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import torch

# Download SAM model checkpoint
!pip install wget
!python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

# Load SAM model
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")

# Export masks decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="vit_b_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"}
    },
    export_params=True,
    opset_version=17,
    do_constant_folding=True
)


  • This code installs and imports all the required packages first. You may already have all of them, but I included these lines just in case.
  • Then it downloads the model weights and loads the sam model with them. I used the smallest, ViT-B version, but you can replace it with ViT-L ("vit_l") or ViT-H ("vit_h") and download the appropriate weights from here.
  • Finally, it wraps the model in SamOnnxModel and uses the standard torch.onnx.export function to export it to the vit_b_decoder.onnx file. The resulting ONNX model has six inputs. The most important of them are image_embeddings, which will receive the output of the vit_b_encoder.onnx model, and point_coords and point_labels, which will receive the encoded prompt. The decoder model also requires orig_im_size, the original input image size as an array of two items, [height, width], to correctly scale the resulting masks.
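The shapes of the dummy inputs above follow from a few constants of the SAM configuration. As a rough sketch, assuming the values used in segment_anything's build_sam.py (1024-pixel input, 16-pixel ViT patches, 256 embedding channels), the expected decoder input shapes can be computed as follows; decoder_input_shapes is a hypothetical helper, not part of the SAM package:

```python
# Assumed constants from the SAM ViT configuration (not read from the checkpoint)
IMAGE_SIZE = 1024   # encoder input side, in pixels
PATCH_SIZE = 16     # ViT patch size, in pixels
EMBED_DIM = 256     # channels of the image embeddings (prompt_encoder.embed_dim)


def decoder_input_shapes(num_points: int) -> dict:
    """Return the expected shape of each decoder input for a prompt with num_points points."""
    grid = IMAGE_SIZE // PATCH_SIZE   # 64 embedding cells per side (image_embedding_size)
    mask_side = 4 * grid              # mask_input is 4x the embedding grid, as in mask_input_size
    return {
        "image_embeddings": (1, EMBED_DIM, grid, grid),
        "point_coords": (1, num_points, 2),
        "point_labels": (1, num_points),
        "mask_input": (1, 1, mask_side, mask_side),
        "has_mask_input": (1,),
        "orig_im_size": (2,),
    }


for name, shape in decoder_input_shapes(num_points=2).items():
    print(f"{name}: {shape}")
```

Only the number of points varies between calls, which matches the dynamic_axes passed to torch.onnx.export for point_coords and point_labels.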

Wonderful! Now you have all the pieces of the puzzle:

  • vit_b_encoder.onnx - to create the image embeddings
  • vit_b_decoder.onnx - to decode segmentation masks using the embeddings and the prompts.

For your convenience, I put all the ONNX export code into the sam_onnx_export.ipynb notebook in the article's repository.

However, using these models without the official API is a little complicated, because you need to preprocess the input image and encode the prompt on your own. There is no documentation about these steps. I will show how to do this in the next section.

Produce image segmentation masks using ONNX

To get segmentation masks for the objects of interest in your image using the ONNX models exported above, you need to do the following:

  • Preprocess the input image
  • Pass the preprocessed image to the vit_b_encoder.onnx model to generate image embeddings
  • Create a prompt and encode it
  • Pass the image embeddings and prompt to the vit_b_decoder.onnx model and receive segmentation mask
  • Post-process the mask and optionally visualize it

In the next sections, I am going to implement these steps one by one. I assume that you will follow my code in a Jupyter Notebook and that you have the vit_b_encoder.onnx and vit_b_decoder.onnx files in the folder with your notebook. Also, in the examples I will use the cat_dog.jpg image, which you can download at the beginning of this article and place in the same folder.

Preprocess input image

As mentioned above, the encoder model requires an input tensor of (1,3,1024,1024) shape. Therefore, you need to resize the input image to fit into 1024x1024 while preserving the aspect ratio, convert it to a tensor of numbers and normalize this tensor.

Let's load the image first, using the Pillow package:



!pip install Pillow

from PIL import Image
img = Image.open("cat_dog.jpg")
img = img.convert("RGB")
orig_width, orig_height = img.size
print(img.size)




(612, 415)



This code loads the image, converts it to RGB and saves the original size, which you will need later.

Then you need to resize this image, preserving the aspect ratio, with 1024 as the long side. This means setting the long side to 1024 and scaling the short side by the same factor. The following code can be used for this:



resized_width, resized_height = img.size

if orig_width > orig_height:
    resized_width = 1024
    resized_height = int(1024 / orig_width * orig_height)
else:
    resized_height = 1024
    resized_width = int(1024 / orig_height * orig_width)

img = img.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
print(img.size)




(1024, 694)



So, this code determines which side is longer and, based on this, calculates the new size of the shorter side. In this case, the longer side is the width and the shorter one is the height, so the image is scaled to (1024, 694), and the new sizes are saved to the resized_width and resized_height variables.
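The code above truncates the scaled short side with int(). The official preprocessing rounds to the nearest integer instead, so for some images the two can differ by one pixel. A minimal sketch mirroring the static method ResizeLongestSide.get_preprocess_shape from segment_anything/utils/transforms.py (reimplemented here so that it needs no PyTorch):

```python
def get_preprocess_shape(old_h: int, old_w: int, long_side: int = 1024) -> tuple[int, int]:
    """Return (new_h, new_w) so that the longer side equals long_side.

    Rounds to the nearest integer instead of truncating with int().
    """
    scale = long_side / max(old_h, old_w)
    new_h = int(old_h * scale + 0.5)
    new_w = int(old_w * scale + 0.5)
    return new_h, new_w


# cat_dog.jpg is 612x415 (width x height): 415 * 1024 / 612 = 694.38 -> 694
print(get_preprocess_shape(415, 612))  # (694, 1024)
```

For cat_dog.jpg both approaches give 694, but for a 2250x1500 image truncation yields a short side of 682, while rounding yields 683.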

Then you need to convert it to a tensor. Numpy allows doing this in a single line:



!pip install numpy
import numpy as np
input_tensor = np.array(img)
input_tensor.shape




(694, 1024, 3)



The input_tensor contains the pixels of the image, and each pixel has three color components: red, green and blue. Each component is in a range from 0 to 255. However, the Segment Anything model requires normalized numbers. To get a normalized number, you need to subtract the mean color from each component and then divide it by the standard deviation. There are different ways to calculate the mean color and standard deviation, but the Segment Anything package provides precalculated means and deviations for each color component (the pixel_mean and pixel_std values of the Sam class). You need to initialize them:



mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])



So, now you need to subtract 123.675 from each red color component and then divide it by 58.395. Similarly, you need to subtract 116.28 from each green color component and divide it by 57.12, and so on for blue. You can do all this in a single line of code using Numpy:



input_tensor = (input_tensor - mean) / std
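To see what this does to actual values, here is a tiny worked example on a hypothetical two-pixel image, using the same mean and std values: pure black ends up at roughly -2 in every channel, and pure white at roughly +2.2 to +2.6.

```python
import numpy as np

# Per-channel constants quoted above (SAM's pixel_mean / pixel_std)
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])

# A 1x2 RGB image: one black pixel and one white pixel (HWC layout)
pixels = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)

# Broadcasting applies each channel's mean/std to the last axis
normalized = (pixels - mean) / std
print(normalized.round(3))
```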



Now you have a normalized input tensor, but it has the wrong shape: (694, 1024, 3). You need to change it to the form (1, color_channels, height, width). In this case, it should be (1, 3, 694, 1024):



input_tensor = input_tensor.transpose(2,0,1)[None,:,:,:].astype(np.float32)
input_tensor.shape




(1, 3, 694, 1024)



The final step is to transform it to (1, 3, 1024, 1024). To do this, you need to pad the short side with zeros:



if resized_height < resized_width:
    input_tensor = np.pad(input_tensor,((0,0),(0,0),(0,1024-resized_height),(0,0)))
else:
    input_tensor = np.pad(input_tensor,((0,0),(0,0),(0,0),(0,1024-resized_width)))

input_tensor.shape




(1, 3, 1024, 1024)



The np.pad function receives the input tensor to pad with zeros and, for each axis, how many zeros to add before and after the existing values. In this case, you need to add 1024-resized_height rows of zeros to the end of the height axis. If the shorter side were the width, this would have to be done for the last axis instead.

That's it: now you have the correct input_tensor for the image encoder model.
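If you need to run this preprocessing repeatedly, the normalization, transposition and padding steps can be collected into one function. The sketch below assumes that the image has already been resized so that its longer side is 1024 pixels; to_encoder_input is a hypothetical name, not part of any library. It pads the bottom and right edges in a single np.pad call, which removes the if/else branch because the padding on the already-full side is zero.

```python
import numpy as np

# Assumed constants: SAM's per-channel pixel mean/std and the encoder input side
PIXEL_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
PIXEL_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
INPUT_SIZE = 1024


def to_encoder_input(resized_rgb: np.ndarray) -> np.ndarray:
    """Turn a resized HWC uint8 RGB image into a (1, 3, 1024, 1024) float32 tensor."""
    h, w, _ = resized_rgb.shape
    if max(h, w) != INPUT_SIZE:
        raise ValueError("resize the image so that its longer side is 1024 first")
    # Normalize each color channel
    x = (resized_rgb.astype(np.float32) - PIXEL_MEAN) / PIXEL_STD
    # HWC -> NCHW (add the batch axis)
    x = x.transpose(2, 0, 1)[None, :, :, :]
    # Zero-pad the bottom and right edges up to 1024x1024
    x = np.pad(x, ((0, 0), (0, 0), (0, INPUT_SIZE - h), (0, INPUT_SIZE - w)))
    return x.astype(np.float32)


# Example with a blank image of the resized cat_dog.jpg size
tensor = to_encoder_input(np.zeros((694, 1024, 3), dtype=np.uint8))
print(tensor.shape, tensor.dtype)  # (1, 3, 1024, 1024) float32
```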

Generate embeddings from input image

The first thing you need to do is import the onnxruntime library and load the vit_b_encoder.onnx model with it:



!pip install onnxruntime
import onnxruntime as ort
encoder = ort.InferenceSession("vit_b_encoder.onnx")



Then run the model with input_tensor as the "images" input to generate embeddings:



outputs = encoder.run(None, {"images": input_tensor})
embeddings = outputs[0]
embeddings.shape




(1, 256, 64, 64)



If you remember, when you exported the image encoder to ONNX, you specified that this model should have a single input named "images" and a single output named "embeddings". Here, you've passed input_tensor as the "images" input. Passing None as the first argument of the run method requests all outputs, and the method returns them as a list, even if there is only one output. That is why the embeddings are located in the first item of this list.

Great, now you have the embeddings. This is the first input that you will need for the mask decoder model. The next input is the prompt, which you also need to prepare.

Encode the prompt

The prompt helps the model find the segmentation mask of the required object. The prompt can be a single point on the image that belongs to the object, several points, or a bounding box around the object. Segment Anything encodes all of these options in a similar way. Let's start with a single point:



input_point = np.array([[321,230]])
input_label = np.array([1])



In this code, you defined a point with x=321 and y=230. You also defined a label for this point, which is 1. This label means that the point belongs to the object. Using this definition, the mask decoder will try to find the segmentation mask for the object that contains this point. However, you need to encode this point into the format that the mask decoder requires. Use the following lines of code for this:



from copy import deepcopy

onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])])[None, :].astype(np.float32)

coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)

onnx_coord = coords.astype("float32")
onnx_coord




array([[[537.098 , 384.6265],
        [  0.    ,   0.    ]]], dtype=float32)



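The SAM mask decoder requires the input points to be scaled to the coordinates of the resized image (with 1024 as the long side) and converted to a float32 tensor. Here I used orig_width, orig_height, resized_width and resized_height to scale the coordinates. The extra (0, 0) point with the label -1 is a padding point, which the official notebook adds when the prompt contains no bounding box.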

I won't explain each line of this code in detail, because I reused it from the transform.apply_coords function of the SAM source code, with a few modifications to make it simpler. It is just a requirement of the mask decoder model.
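The scaling step can also be written as a small reusable function. This is a sketch, not the official apply_coords: scale_coords is a hypothetical name, and it takes the original and resized sizes as (height, width) pairs. Because np.array copies its input, deepcopy is not needed here.

```python
import numpy as np


def scale_coords(coords, orig_size, resized_size) -> np.ndarray:
    """Scale (x, y) points from original image pixels to the resized image frame.

    orig_size and resized_size are (height, width) pairs.
    """
    orig_h, orig_w = orig_size
    new_h, new_w = resized_size
    scaled = np.array(coords, dtype=np.float64)  # np.array makes a copy
    scaled[..., 0] *= new_w / orig_w  # x follows the width ratio
    scaled[..., 1] *= new_h / orig_h  # y follows the height ratio
    return scaled.astype(np.float32)


# The single-point prompt from above: (321, 230) in a 612x415 image resized to 1024x694
print(scale_coords(np.array([[321, 230]]), (415, 612), (694, 1024)))
```

The same function works for the box corners below, since a box is passed to the decoder as two points.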

If you need to use a bounding box as a prompt, you can use similar code:



input_box = np.array([132, 157, 256, 325]).reshape(2,2)
input_labels = np.array([2,3])

onnx_coord = input_box[None, :, :]
onnx_label = input_labels[None, :].astype(np.float32)

coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)

onnx_coord = coords.astype("float32")
onnx_coord




array([[[220.86275, 262.5494 ],
        [428.33987, 543.49396]]], dtype=float32)



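This code encodes a prompt to get the mask for the object located inside the box with the top-left corner at x=132, y=157 and the bottom-right corner at x=256, y=325. The labels 2 and 3 mark these two points as the top-left and bottom-right corners of the box.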

If you want to encode a prompt, that contains both bounding box and point, you can use the following code:



input_box = np.array([132, 157, 256, 325]).reshape(2,2)
box_labels = np.array([2,3])
input_point = np.array([[140, 160]])
input_label = np.array([0])

onnx_coord = np.concatenate([input_point, input_box], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, box_labels], axis=0)[None, :].astype(np.float32)

coords = deepcopy(onnx_coord).astype(float)
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)

onnx_coord = coords.astype("float32")
onnx_coord



This code includes both input_box and input_point, along with labels for them. Notice that input_label here contains 0, which means that the point (140,160) does not belong to the object that you want to extract. This prompt guides the model to segment the object located inside the (132,157,256,325) box, but not at the (140,160) point.
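The three cases above differ only in which points and labels are concatenated. As a sketch, they can be folded into one helper; build_prompt is a hypothetical name, and it follows the same conventions as the code above: labels 1 and 0 for foreground and background points, 2 and 3 for the box corners, and a (0, 0) padding point with label -1 only when there is no box. The returned coordinates are still in original image pixels and must be scaled as shown earlier.

```python
import numpy as np

# Label codes used in the examples above
FOREGROUND, BACKGROUND = 1, 0
BOX_TOP_LEFT, BOX_BOTTOM_RIGHT = 2, 3
PADDING = -1


def build_prompt(points=None, point_labels=None, box=None):
    """Assemble unscaled (1, N, 2) coords and (1, N) labels for the SAM ONNX decoder.

    points: (N, 2) array of (x, y); point_labels: (N,) of 1/0; box: (x1, y1, x2, y2).
    """
    coords, labels = [], []
    if points is not None:
        if point_labels is None:
            raise ValueError("point_labels are required when points are given")
        coords.append(np.asarray(points, dtype=np.float32).reshape(-1, 2))
        labels.append(np.asarray(point_labels, dtype=np.float32).reshape(-1))
    if box is not None:
        coords.append(np.asarray(box, dtype=np.float32).reshape(2, 2))
        labels.append(np.array([BOX_TOP_LEFT, BOX_BOTTOM_RIGHT], dtype=np.float32))
    else:
        # No box: append the padding point, as in the single-point example
        coords.append(np.zeros((1, 2), dtype=np.float32))
        labels.append(np.array([PADDING], dtype=np.float32))
    # Add the batch dimension expected by the decoder
    return np.concatenate(coords)[None, :, :], np.concatenate(labels)[None, :]


coords, labels = build_prompt(points=[[140, 160]], point_labels=[BACKGROUND],
                              box=[132, 157, 256, 325])
print(coords.shape, labels)
```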

You can construct very specific prompts to get desired results (just like with ChatGPT ;) ).

So, now you have correctly encoded onnx_coord and onnx_label to pass to the mask decoder. Let's do this right now.

Run the mask decoder

Now that you have the embeddings, onnx_coord and onnx_label, nothing can stop you from running the mask decoder model to get the segmentation mask.

Let's load the model first:



decoder = ort.InferenceSession("vit_b_decoder.onnx")



and pass all encoded data to it:



onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

outputs = decoder.run(None,{
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})
masks = outputs[0]
masks.shape




(1, 1, 415, 612)



This code runs the model with the encoded image_embeddings, point_coords and point_labels. I also passed an empty (all-zero) mask_input with has_mask_input set to 0, which tells the decoder that no previous mask is provided, and the original image size in the orig_im_size parameter.
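All six decoder inputs are float32 tensors, and onnxruntime generally rejects inputs whose dtype does not match the model (for example, float64 arrays created by Numpy by default). A sketch of a helper that builds the input dictionary with consistent dtypes; make_decoder_feed is a hypothetical name, and the input names are the ones chosen in the export code above:

```python
import numpy as np


def make_decoder_feed(embeddings, point_coords, point_labels, orig_h, orig_w) -> dict:
    """Build the input dictionary for vit_b_decoder.onnx without a previous mask."""
    return {
        "image_embeddings": np.asarray(embeddings, dtype=np.float32),
        "point_coords": np.asarray(point_coords, dtype=np.float32),
        "point_labels": np.asarray(point_labels, dtype=np.float32),
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),  # no previous mask
        "has_mask_input": np.zeros(1, dtype=np.float32),              # 0 = ignore mask_input
        "orig_im_size": np.array([orig_h, orig_w], dtype=np.float32),  # [height, width]
    }


# In the notebook, this could be used as:
# outputs = decoder.run(None, make_decoder_feed(embeddings, onnx_coord, onnx_label,
#                                               orig_height, orig_width))
feed = make_decoder_feed(np.zeros((1, 256, 64, 64)), [[[537.1, 384.6], [0, 0]]], [[1, -1]], 415, 612)
print({name: value.shape for name, value in feed.items()})
```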

The model returns 3 outputs, and the array of segmentation masks is the first of them. For the input image, it returned a tensor of (1, 1, 415, 612) shape, which is a single segmentation mask at the original image size.

The only step left is to post process it.

Post-process and visualize segmentation mask

The segmentation mask is an array of pixels; however, each pixel contains not a color but a score. If this number is greater than 0, the pixel belongs to the object; otherwise, it does not. So, to convert it to real pixel colors, you can run the following code:



mask = masks[0][0]
mask = (mask > 0).astype('uint8')*255



This code extracts the pixel matrix from the mask (415x612) and converts all positive values to True and all other values to False. Then it converts the booleans to 8-bit integers, so all True values become 1 and all False values become 0, and multiplies the matrix by 255 to turn all True pixels white. The result is a single-channel black-and-white image that can be easily visualized by many image libraries. For example, you can visualize it with Pillow this way:



img = Image.fromarray(mask,'L')
img



[Image: the resulting black-and-white segmentation mask]
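A black-and-white mask is easy to check, but it is often clearer to see it on top of the original image. A minimal sketch that blends a color into the masked pixels using only Numpy; overlay_mask, the default color and the alpha value are illustrative choices, not part of SAM:

```python
import numpy as np


def overlay_mask(image: np.ndarray, mask: np.ndarray, color=(30, 144, 255), alpha=0.5) -> np.ndarray:
    """Blend color into the pixels of image where mask > 0.

    image: (H, W, 3) uint8 RGB; mask: (H, W) array, such as the 0/255 mask above.
    """
    blended = image.astype(np.float32) * (1 - alpha) + np.array(color, dtype=np.float32) * alpha
    out = np.where(mask[..., None] > 0, blended, image.astype(np.float32))
    return np.clip(out, 0, 255).astype(np.uint8)


# Tiny example: a gray 2x2 image with only the top-left pixel masked
image = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)
print(overlay_mask(image, mask, color=(200, 200, 200))[0, 0])
```

Because the mask already has the original image size, in the notebook this could be applied as Image.fromarray(overlay_mask(np.array(Image.open("cat_dog.jpg").convert("RGB")), mask)). The image is reopened because the img variable was overwritten by the resized version earlier.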

Hooray! Now you can do Segment Anything image segmentation using only ONNX.

This is the end of our journey. You can find all the source code of this section in the sam_onnx_inference.ipynb notebook in the repository.

Conclusion

In this article, I showed how to fill the gap in the official implementation of the Segment Anything Model's ONNX export function. Then I showed you how to do prompt-based image segmentation using the exported ONNX models.

You can find all the source code in this repository: https://github.com/AndreyGermanov/sam_onnx_full_export.

Here I used only Python, but with the complete ONNX models you can do much more. You can run the Segment Anything model in any programming language supported by ONNX Runtime. If you know how to pre-process the input and post-process the output, you can integrate this model into most production systems, whatever language they are written in. For example, you can embed it in software written in C/C++, Go or Rust, or in websites written in JavaScript.

Thank you and until next time!

Follow me on LinkedIn, Twitter, and Facebook to know first about new articles like this one and other software development news.
