Fine-tune Segment Anything (SAM) for images with multiple masks
TL;DR
If you want to jump straight into the code, you can find the notebook on Colab here. However, I advise you to read this blog post first.
WHAT TO EXPECT
If you have a dataset of annotated images (possibly with multiple objects per image), you can fine-tune SAM. For my use case, fine-tuning vastly improved SAM's performance, and it even worked better than Detectron2 or U-Net. Note that I fine-tuned on images with multiple masks, but all masks belonged to the same category. This article comes with a Jupyter Notebook to follow along: the article gives a conceptual overview, and the notebook has all the details.
Let’s start!
REQUIREMENTS
To follow along you need:
- a dataset of images (jpg, png, etc.)
- annotations of the objects in COCO format (annotations.json)
- and of course, all the necessary Python packages installed (covered in the notebook)
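Because the approach described later trains against one merged ground-truth mask per image, the COCO annotations first have to be rasterized. Below is a minimal, dependency-free NumPy sketch, assuming polygon-style segmentations (`[x1, y1, x2, y2, ...]`); RLE entries (`iscrowd=1`) are simply skipped here, whereas pycocotools' `COCO.annToMask` handles both formats in practice. `polygon_to_mask` and `coco_to_masks` are hypothetical helpers, not part of any library.

```python
from typing import Dict, List

import numpy as np


def polygon_to_mask(polygon: List[float], height: int, width: int) -> np.ndarray:
    """Rasterize one COCO polygon [x1, y1, x2, y2, ...] with the even-odd rule."""
    xs = np.asarray(polygon[0::2], dtype=float)
    ys = np.asarray(polygon[1::2], dtype=float)
    # coordinates of all pixel centers
    yy, xx = np.mgrid[0:height, 0:width] + 0.5
    inside = np.zeros((height, width), dtype=bool)
    x_j, y_j = xs[-1], ys[-1]
    for x_i, y_i in zip(xs, ys):
        # does a ray going left from the pixel center cross the edge (j -> i)?
        crosses = (y_i > yy) != (y_j > yy)
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = x_i + (yy - y_i) * (x_j - x_i) / (y_j - y_i)
        inside ^= crosses & (xx < x_cross)
        x_j, y_j = x_i, y_i
    return inside


def coco_to_masks(coco: dict) -> Dict[str, np.ndarray]:
    """Merge all polygon annotations of each image into one boolean mask."""
    images = {img["id"]: img for img in coco["images"]}
    masks = {
        img["file_name"]: np.zeros((img["height"], img["width"]), dtype=bool)
        for img in coco["images"]
    }
    for ann in coco["annotations"]:
        seg = ann["segmentation"]
        if not isinstance(seg, list):
            continue  # RLE (iscrowd=1) is not handled in this sketch
        img = images[ann["image_id"]]
        for polygon in seg:
            masks[img["file_name"]] |= polygon_to_mask(polygon, img["height"], img["width"])
    return masks
```

Usage would be `coco_to_masks(json.load(open("annotations.json")))`, which returns one merged mask per file name.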
SAM ARCHITECTURE OVERVIEW
Segment Anything is an image segmentation model from Meta AI Research (formerly Facebook AI Research). Its goal is to be a foundation model for image segmentation (similar to GPT for text). In fact, its promptable design is inspired by language models such as ChatGPT: you can input prompts (points, bounding boxes, rough masks, and in the future possibly also text) that help the model segment an object. The Meta team built a data engine and trained the model on millions of images (see paper). Figure 2 shows the basic idea. The image encoder (a Vision Transformer) turns the image into an image embedding. The prompt encoder turns the user's prompts into embeddings that tell the model where it is most likely to find an object. The lightweight mask decoder combines both embeddings and predicts mask logits plus a predicted IoU score per mask; thresholding the logits gives binary masks (which the automatic mask generator can also return RLE-encoded). We go over this later in a code example.
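To make the prompt idea concrete: SAM's `SamPredictor.predict` takes a box in XYXY pixel format and points as (x, y) pixel coordinates with labels (1 = foreground, 0 = background). The sketch below derives both kinds of prompt from a binary ground-truth mask. `mask_to_prompts` is a hypothetical helper written for illustration, and picking the mask pixel closest to the centroid as the single foreground point is just one simple choice.

```python
from typing import Tuple

import numpy as np


def mask_to_prompts(mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Derive a box prompt and one foreground point prompt from a binary mask."""
    ys, xs = np.nonzero(mask)
    # bounding box in XYXY pixel coordinates (inclusive max)
    box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
    # one foreground point: the mask pixel closest to the mask centroid
    cy, cx = ys.mean(), xs.mean()
    i = np.argmin((ys - cy) ** 2 + (xs - cx) ** 2)
    point_coords = np.array([[xs[i], ys[i]]])  # shape (N, 2), (x, y) order
    point_labels = np.array([1])  # 1 = foreground, 0 = background
    return box, point_coords, point_labels
```

After `predictor.set_image(image)`, these could be passed as `predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box, multimask_output=False)`.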
Three levels of usage:
- Noob level: Online tool: SAM has an online version where you can upload an image, click on an object in it, and SAM will find the segmentation. Of course, for us computer scientists this is not how we want to work. Luckily, Facebook open-sourced the code and provided interfaces for using the model.
- Mortal level: Facebook provides code snippets on how to use the model. You need to download a model checkpoint, read an image (with cv2.imread, for example) and start. You can read all about this on SAM's GitHub page (see), so we won't go into detail here.
- God level: Special snowflakes that we are, we want to use SAM on images it has never seen before, i.e. images that do not come from its training distribution. To achieve this, we need to understand a few things about the SAM model architecture and make some adjustments.
FINE-TUNE SAM
To understand how to fine-tune SAM, we check which tools SAM provides out of the box and why we cannot fine-tune these tools directly. If we want to generate masks with SAM, we can simply run the following code:
```python
import cv2
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# this is classic PyTorch: use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
sam = sam_model_registry["vit_h"](checkpoint="../model.pth")
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
# read in an image: generate() expects an HWC uint8 array in RGB order
image = cv2.cvtColor(cv2.imread("your_image.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path
mask = mask_generator.generate(image)
# OUTPUT ----------------------------------------------------------
# LIST of dicts where each dict represents one mask
# dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
```
We can merge all the single masks into one output mask if we want to and visualize it:
```python
from typing import Any, Dict, List

import numpy as np

def build_totalmask(pred: List[Dict[str, Any]]) -> np.ndarray:
    """Builds a total mask from a list of segmentations
    ARGS:
        pred (list): list of dicts with keys 'segmentation' and others
    RETURNS:
        total_mask (np.ndarray): black-and-white mask (0 = background, 255 = object)
    """
    total_mask = np.zeros(pred[0]['segmentation'].shape, dtype=bool)
    for seg in pred:
        # union of all boolean masks, so overlapping masks are neither lost nor overflow
        total_mask |= seg['segmentation']
    return total_mask.astype(np.uint8) * 255

tmask = build_totalmask(mask)
```
What’s the problem?
It seems very easy to use SAM and get all the masks with the SamAutomaticMaskGenerator class, so why not use this class and fine-tune it, like so:
```python
# PSEUDO CODE TO ILLUSTRATE (this does NOT work, see below):
sam = sam_model_registry["vit_h"](checkpoint="../model.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
optimizer = torch.optim.Adam(sam.parameters(), lr=0.01)
sam.train()
for img, groundtruth_mask in dataloader:
    optimizer.zero_grad()
    pred_mask = mask_generator.generate(img)
    loss = criterion(pred_mask, groundtruth_mask)
    loss.backward()
    optimizer.step()
# and so on ....
```
This does not work at all, because SamAutomaticMaskGenerator() does a ton of work under the hood: its generate() method runs under @torch.no_grad(), so no gradients are recorded, and it returns a list of dicts with thresholded binary masks rather than a differentiable probability map.
What’s the solution?
Solution 1: Adjust the SAM code itself, so we can use SamAutomaticMaskGenerator() directly.
- I've tried this and I don't recommend it. There is a lot of potential to break things, and it is clearly not how the creators intended the code to be used. If you decide to give it a try anyway, you can follow this protocol:
```
1. Clone the repo and install it locally:
   git clone git@github.com:facebookresearch/segment-anything.git
   cd segment-anything; pip install -e .
2. Re-enable gradients in these two files by removing the
   @torch.no_grad() decorators:
   https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
   https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/predictor.py
3. Modify SamAutomaticMaskGenerator() so that it returns
   a mask probability map with gradients directly.
4. Re-install locally with pip install -e . and train (see the pseudo code above).
```
Why would you even try such a hacky solution? SamAutomaticMaskGenerator() provides a lot of useful functionality that we otherwise have to implement ourselves. For a given prompt, SAM itself predicts only one mask (three if you set multimask_output=True). That mask covers the most prominent object in the image or the object the user prompted via points or bounding boxes. To get all masks, SamAutomaticMaskGenerator() does the following:
- lays a point grid (serves as a prompt) over each image (see Figure 5)
- predicts a mask for each point
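- removes duplicate masks (non-maximum suppression) and low-quality masks (thresholds on predicted_iou and stability_score)
- optionally removes small disconnected regions and holes (min_mask_region_area)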
- takes care of output format (binary mask or RLE, instead of a probability map)
- preprocesses the images for use (we need to do this manually for our next approach)
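To make the first point concrete, here is a minimal NumPy sketch of such a point grid: evenly spaced points with a half-cell offset from the border, scaled to pixel coordinates. SamAutomaticMaskGenerator builds its own grid internally (32 points per side by default); `point_grid_px` is a hypothetical helper that only illustrates the idea. Each point then serves as a single foreground-point prompt.

```python
import numpy as np


def point_grid_px(n_per_side: int, height: int, width: int) -> np.ndarray:
    """Return an (n_per_side**2, 2) array of (x, y) pixel coordinates."""
    offset = 1 / (2 * n_per_side)  # keep points half a cell away from the border
    side = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(side, side)  # xs varies along columns, ys along rows
    points = np.stack([xs, ys], axis=-1).reshape(-1, 2)  # normalized (x, y)
    return points * np.array([width, height])
```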
Solution 2:
For this clean solution, we work with the SAM model directly and update its weights ourselves. To achieve this, we need to perform the following steps:
- Preprocess our image data. SAM expects square 1024x1024 inputs: the image is resized so that its longest side is 1024 pixels, normalized, and zero-padded to a square. We can find guidance in the SAM code directly (see). You can also find the code in my notebook at the end.
- Build our own PyTorch wrapper model in which we connect the image_encoder, prompt_encoder and mask_decoder in our forward pass.
- Understand the output structure of the mask_decoder. It returns low-resolution mask logits of shape (256, 256) per mask (plus IoU predictions); upsampled to the 1024x1024 input and passed through a sigmoid, they become a (1024, 1024) probability map.
- Define a mask probability threshold and loss functions. See:
## Loss Functions
Quote from the SAM paper:
```
Losses. We supervise mask prediction with a linear combination of focal loss [65] and dice loss [73] in a 20:1 ratio of
focal loss to dice loss, following [20, 14]. Unlike [20, 14],
we observe that auxiliary deep supervision after each decoder layer is unhelpful. The IoU prediction head is trained
with mean-square-error loss between the IoU prediction and
the predicted mask’s IoU with the ground truth mask. It is
added to the mask loss with a constant scaling factor of 1.0.
```
We won't train the IoU prediction head, so only the mask loss (focal + dice) is relevant for us.
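Following the quote, the mask loss can be sketched in PyTorch as 20 x focal loss + 1 x dice loss, both computed on the mask logits so that gradients flow. The focal-loss parameters alpha = 0.25 and gamma = 2 (the usual defaults from the focal loss paper) and the dice smoothing term eps = 1 are assumptions; the quote does not specify them. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    """Sigmoid focal loss averaged over all pixels (alpha, gamma are assumptions)."""
    prob = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)  # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * ce * (1 - p_t) ** gamma).mean()


def dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """Soft dice loss per sample, averaged over the batch."""
    prob = torch.sigmoid(logits).flatten(1)
    targets = targets.flatten(1)
    intersection = (prob * targets).sum(1)
    total = prob.sum(1) + targets.sum(1)
    return (1 - (2 * intersection + eps) / (total + eps)).mean()


def sam_mask_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Focal and dice loss combined in a 20:1 ratio, as in the SAM paper."""
    return 20.0 * focal_loss(logits, targets) + dice_loss(logits, targets)
```

`targets` is the float ground-truth mask (0/1) with the same shape as the logits. At inference, a threshold on `torch.sigmoid(logits)` (0.5 is a common, but assumed, choice) turns the probability map into a binary mask.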
- We are lazy and don't want to add a point grid as a prompt and deal with all the postprocessing, so we just feed the whole ground-truth mask (all masks of one image merged into one) and let the loss function take care of it. This forces the model to output high probabilities wherever there are objects.
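The preprocessing step from the list above can be sketched in NumPy: resize so that the longest side is 1024 pixels, normalize with SAM's default pixel mean and standard deviation, and zero-pad to 1024x1024. One simplification is an assumption of this sketch: it uses nearest-neighbour resizing to stay dependency-free, instead of the smoother interpolation of SAM's ResizeLongestSide transform, so for real images you would rather use that transform or cv2.resize. Masks get the same resize and padding but no normalization. All function names are hypothetical.

```python
from typing import Tuple

import numpy as np

# SAM's default normalization constants (RGB, 0-255 scale)
PIXEL_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
PIXEL_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)


def longest_side_shape(h: int, w: int, target: int = 1024) -> Tuple[int, int]:
    """New (h, w) so that the longest side equals `target`."""
    scale = target / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def resize_nearest(arr: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array."""
    rows = (np.arange(new_h) * arr.shape[0] / new_h).astype(int)
    cols = (np.arange(new_w) * arr.shape[1] / new_w).astype(int)
    return arr[rows[:, None], cols[None, :]]


def preprocess_image(image: np.ndarray, target: int = 1024) -> np.ndarray:
    """(H, W, 3) uint8 RGB -> (3, target, target) normalized float32, padded bottom/right."""
    new_h, new_w = longest_side_shape(*image.shape[:2], target)
    x = (resize_nearest(image, new_h, new_w).astype(np.float32) - PIXEL_MEAN) / PIXEL_STD
    out = np.zeros((target, target, 3), dtype=np.float32)
    out[:new_h, :new_w] = x  # padding is applied after normalization
    return out.transpose(2, 0, 1)


def preprocess_mask(mask: np.ndarray, target: int = 1024) -> np.ndarray:
    """(H, W) 0/1 mask -> (target, target) float32 mask, padded with zeros."""
    new_h, new_w = longest_side_shape(*mask.shape, target)
    out = np.zeros((target, target), dtype=np.float32)
    out[:new_h, :new_w] = resize_nearest(mask, new_h, new_w)
    return out
```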
These are the conceptual steps to fine-tune SAM. Because Medium is not the ideal place to share code, you'll find the notebook with code here. Reading the article alongside the code should make all of this much clearer.
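For a rough picture of the wrapper model and of one training step with the merged ground-truth mask as target, here is a minimal PyTorch sketch. It calls SAM's modules the way segment_anything's Sam class exposes them (image_encoder, prompt_encoder with get_dense_pe(), mask_decoder), but the class SAMFineTuner, the helper train_step, and the choice to freeze the image encoder are assumptions of this sketch, not necessarily what the notebook does. With no prompts at all, the prompt encoder contributes only its learned "no mask" embedding, so the decoder has to find the objects on its own. Upsampling with F.interpolate keeps the logits aligned with the padded 1024x1024 targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SAMFineTuner(nn.Module):
    """Hypothetical wrapper chaining SAM's image encoder, prompt encoder and mask decoder."""

    def __init__(self, sam: nn.Module, freeze_image_encoder: bool = True):
        super().__init__()
        self.sam = sam
        self.freeze_image_encoder = freeze_image_encoder
        if freeze_image_encoder:
            # assumption: only the prompt encoder and mask decoder are trained
            for p in self.sam.image_encoder.parameters():
                p.requires_grad = False

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """images: preprocessed (B, 3, 1024, 1024) -> mask logits (B, 1, 1024, 1024)."""
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_image_encoder):
            image_embeddings = self.sam.image_encoder(images)  # (B, 256, 64, 64) for SAM
        logits = []
        for emb in image_embeddings:  # the mask decoder handles one image at a time
            # no prompts: empty sparse embeddings + learned "no mask" dense embedding
            sparse, dense = self.sam.prompt_encoder(points=None, boxes=None, masks=None)
            low_res, _iou_pred = self.sam.mask_decoder(
                image_embeddings=emb.unsqueeze(0),
                image_pe=self.sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )  # (1, 1, 256, 256) logits for SAM
            logits.append(low_res)
        low_res = torch.cat(logits, dim=0)
        # upsample the low-resolution logits to the padded input size
        return F.interpolate(low_res, size=images.shape[-2:], mode="bilinear", align_corners=False)


def train_step(model, optimizer, criterion, images, targets) -> float:
    """One update; `targets` is the merged 0/1 ground-truth mask, shape (B, 1, H, W)."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()
    return loss.item()
```

With segment_anything installed, usage would look like `model = SAMFineTuner(sam_model_registry["vit_h"](checkpoint="../model.pth"))`, an optimizer over `[p for p in model.parameters() if p.requires_grad]`, and a focal/dice combination in the 20:1 ratio from the paper as the criterion.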
ACKNOWLEDGEMENTS
I mainly used these three GitHub repos as sources:
- Original SAM Code (see)
- Finetune anything (see)
- Lightning-Sam (see). Special thanks to Luca Medeiros, whose code I reused quite a bit.
- The images come from our lab partner Dr. Daniel Freund (TU Dresden)
Unfortunately, a video by DigitalSreeni on how to fine-tune SAM came out only after I had struggled through my own implementation, so I could not use it, but I'll leave the link here. I have not watched it yet, but his videos are usually very good.
Notebook:
Colab here