Train/Fine-Tune Segment Anything 2 (SAM 2) in 60 Lines of Code | by Sagi Eppel | Aug, 2024

A step-by-step tutorial for fine-tuning SAM2 for custom segmentation tasks

SAM2 (Segment Anything 2) is a new model by Meta that aims to segment anything in an image without being restricted to specific classes or domains. What makes this model unique is the scale of the data on which it was trained: 11 million images and 1.1 billion masks. This extensive training makes SAM2 a strong starting point for training on new image segmentation tasks.

The question you might ask is: if SAM can segment anything, why do we even need to retrain it? The answer is that SAM is very good at common objects but can perform rather poorly on rare or domain-specific tasks.
However, even in cases where SAM gives insufficient results, it is still possible to significantly improve the model's ability by fine-tuning it on new data. In many cases, this will take less training data and give better results than training a model from scratch.

This tutorial demonstrates how to fine-tune SAM2 on new data in just 60 lines of code (excluding comments and imports).

The full training script can be found at:

SAM2 net diagram, taken from the SAM2 GitHub page

The main way SAM works is by taking an image and a point in the image, and predicting the mask of the segment that contains the point. This approach enables full image segmentation without human intervention and with no limits on the classes or types of segments (as discussed in a previous post).

The procedure for using SAM for full image segmentation:

  1. Select a set of points in the image
  2. Use SAM to predict the segment containing each point
  3. Combine the resulting segments into a single map

While SAM can also utilize other inputs like masks or bounding boxes, these are mainly relevant for interactive segmentation involving human input. For this tutorial, we'll focus on fully automatic segmentation and will only consider single-point inputs. A minimal sketch of this loop is shown right below.
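Here is a minimal sketch of that loop, written under stated assumptions: it assumes a SAM2ImagePredictor is already loaded (as shown later in this tutorial), uses a simple grid of prompt points, and merges overlapping masks naively. The actual stitching used in this tutorial appears in the inference section.

import numpy as np

# Minimal sketch (not the tutorial's final code): segment a whole image by prompting
# SAM2 with a grid of points and keeping the best mask per point.
def segment_everything(predictor, image, grid_step=64):
    predictor.set_image(image)  # run the image encoder once
    seg_map = np.zeros(image.shape[:2], dtype=np.uint16)
    seg_id = 0
    for y in range(grid_step // 2, image.shape[0], grid_step):
        for x in range(grid_step // 2, image.shape[1], grid_step):
            masks, scores, _ = predictor.predict(
                point_coords=np.array([[x, y]]),  # one prompt point (x, y)
                point_labels=np.array([1]),       # 1 = foreground point
            )
            best = masks[np.argmax(scores)].astype(bool)  # highest-scoring candidate mask
            if seg_map[best].any():  # naive merge: skip if the region is already covered
                continue
            seg_id += 1
            seg_map[best] = seg_id
    return seg_map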

More details on the model are available on the project website.

SAM2 can be downloaded from:

If you don't want to copy the training code, you can also download my forked version, which already contains the TRAIN.py script.

Follow the installation instructions in the GitHub repository.

In general, you need Python >= 3.11 and PyTorch.

In addition, we'll use OpenCV, which can be installed using:

pip install opencv-python

Downloading the pre-trained model

You also need to download the pre-trained model from:

https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints

There are several models you can choose from, all compatible with this tutorial. I recommend using the small model, which is the fastest to train.
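At the time of writing, the repository provides four checkpoints of different sizes. The checkpoint/config pairs below are my reading of the SAM2 download page and may change between releases, so verify them before use:

# Assumed checkpoint/config pairs (verify against the SAM2 download page):
sam2_models = {
    "tiny":      ("sam2_hiera_tiny.pt",      "sam2_hiera_t.yaml"),
    "small":     ("sam2_hiera_small.pt",     "sam2_hiera_s.yaml"),
    "base_plus": ("sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
    "large":     ("sam2_hiera_large.pt",     "sam2_hiera_l.yaml"),
}
sam2_checkpoint, model_cfg = sam2_models["small"]  # the small model used in this tutorial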

The next step is to download the dataset that will be used to fine-tune the model. For this tutorial, we'll use the LabPics1 dataset for segmenting materials and liquids. You can download the dataset from this URL:

https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1

The first thing we need to write is the data reader. This will read and prepare the data for the net.

The data reader needs to produce:

  1. An image
  2. Masks of all the segments in the image
  3. A random point inside each mask

Let's start by loading the dependencies:

import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

Next, we list all the images in the dataset:

data_dir = r"LabPicsV1//" # path to the LabPics1 dataset folder
data = [] # list of files in the dataset
for ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")): # go over all files in the image folder
    data.append({"image":data_dir+"Simple/Train/Image/"+name, "annotation":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"}) # image and its annotation file

Now for the main function that loads the training batch. The training batch includes: one random image, all the segmentation masks belonging to this image, and a random point in each mask:

def read_batch(data): # read a random image and its annotation from the dataset (LabPics)

    # select image

    ent = data[np.random.randint(len(data))] # choose a random entry
    Img = cv2.imread(ent["image"])[...,::-1] # read image
    ann_map = cv2.imread(ent["annotation"]) # read annotation

    # resize image

    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scaling factor
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

    # merge vessels and materials annotations

    mat_map = ann_map[:,:,0] # material annotation map
    ves_map = ann_map[:,:,2] # vessel annotation map
    mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map

    # get binary masks and points

    inds = np.unique(mat_map)[1:] # all segment indices (skip background 0)
    points = []
    masks = []
    for ind in inds:
        mask = (mat_map == ind).astype(np.uint8) # make binary mask for index ind
        masks.append(mask)
        coords = np.argwhere(mask > 0) # get all coordinates in the mask
        yx = np.array(coords[np.random.randint(len(coords))]) # choose a random point/coordinate
        points.append([[yx[1], yx[0]]])
    return Img, np.array(masks), np.array(points), np.ones([len(masks),1])

The first part of this function is choosing a random image and loading it:

ent = data[np.random.randint(len(data))] # choose a random entry
Img = cv2.imread(ent["image"])[...,::-1] # read image
ann_map = cv2.imread(ent["annotation"]) # read annotation

Note that OpenCV reads images as BGR while SAM expects RGB images. By using [...,::-1] we change the image from BGR to RGB.

SAM expects the image size not to exceed 1024 pixels, so we resize the image and the annotation map to this size.

r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scaling factor
Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

An important point here is that when resizing the annotation map (ann_map) we use INTER_NEAREST mode (nearest neighbor). In the annotation map, each pixel value is the index of the segment it belongs to. Consequently, it's important to use resizing methods that don't introduce new values to the map.
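A small toy example (not part of the tutorial code) shows the difference: resizing an index map with the default bilinear interpolation invents blended label values along segment borders, while nearest-neighbor keeps only the original indices.

import numpy as np
import cv2

# Toy label map with segment indices 0, 3 and 7
label_map = np.zeros((100, 100), dtype=np.uint8)
label_map[10:40, 10:40] = 3
label_map[60:90, 60:90] = 7

nearest = cv2.resize(label_map, (64, 64), interpolation=cv2.INTER_NEAREST)
bilinear = cv2.resize(label_map, (64, 64), interpolation=cv2.INTER_LINEAR)

print(np.unique(nearest))   # [0 3 7] - only the original segment indices survive
print(np.unique(bilinear))  # extra blended values appear along segment borders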

The next block is specific to the format of the LabPics1 dataset. The annotation map (ann_map) contains a segmentation map for the vessels in the image in one channel, and another map for the materials annotation in a different channel. We are going to merge them into a single map.

mat_map = ann_map[:,:,0] # material annotation map
ves_map = ann_map[:,:,2] # vessel annotation map
mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map

What this gives us is a map (mat_map) in which the value of each pixel is the index of the segment it belongs to (for example: all cells with value 3 belong to segment 3). We want to transform this into a set of binary masks (0/1), where each mask corresponds to a different segment. In addition, from each mask we want to extract a single point.

inds = np.unique(mat_map)[1:] # list of all indices in the map
points = [] # list of all points (one for each mask)
masks = [] # list of all masks
for ind in inds:
    mask = (mat_map == ind).astype(np.uint8) # make binary mask for index ind
    masks.append(mask)
    coords = np.argwhere(mask > 0) # get all coordinates in the mask
    yx = np.array(coords[np.random.randint(len(coords))]) # choose a random point/coordinate
    points.append([[yx[1], yx[0]]])
return Img, np.array(masks), np.array(points), np.ones([len(masks),1])

That's it! We got the image (Img), a list of binary masks corresponding to segments in the image (masks), and for each mask the coordinates of a single point inside the mask (points).

Example of a training batch: 1) An image. 2) A list of segment masks. 3) For each mask, a single point inside the mask (marked red for visualization only). Taken from the LabPics dataset.

Now let's load the net:

sam2_checkpoint = "sam2_hiera_small.pt" # path to model weights
model_cfg = "sam2_hiera_s.yaml" # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load model
predictor = SAM2ImagePredictor(sam2_model) # load net

First, we set the path to the model weights in the sam2_checkpoint parameter. We downloaded the weights earlier from here. "sam2_hiera_small.pt" refers to the small model, but the code will work for any model you choose. Whichever model you choose, you need to set the corresponding config file in the model_cfg parameter. The config files are already located in the subfolder "sam2_configs/" of the main repository.

Before setting the training parameters we need to understand the basic structure of the SAM model.

SAM consists of three parts:
1) Image encoder, 2) Prompt encoder, 3) Mask decoder.

The image encoder is responsible for processing the image and creating the embedding that represents the image. This part consists of a ViT transformer and is the largest component of the net. We usually don't want to train it, since it already gives a good representation and training it would demand a lot of resources.

The prompt encoder processes the additional input to the net, in our case the input point.

The mask decoder takes the output of the image encoder and prompt encoder and produces the final segmentation masks. In general, we want to train only the mask decoder and maybe the prompt encoder. These parts are lightweight and can be fine-tuned quickly on a modest GPU.
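To get a feel for these sizes, you can count the parameters of each component. This is a minimal sketch; it assumes the sub-module names image_encoder, sam_prompt_encoder, and sam_mask_decoder (the latter two are used later in this tutorial) and that predictor has already been created as shown above:

def count_params(module):
    return sum(p.numel() for p in module.parameters())

print("image encoder params: ", count_params(predictor.model.image_encoder))
print("prompt encoder params:", count_params(predictor.model.sam_prompt_encoder))
print("mask decoder params:  ", count_params(predictor.model.sam_mask_decoder))

# Optionally make the frozen part explicit, so the optimizer never updates the image encoder:
for p in predictor.model.image_encoder.parameters():
    p.requires_grad = False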

We can enable training of the mask decoder and prompt encoder by setting:

predictor.model.sam_mask_decoder.train(True) # enable training of the mask decoder
predictor.model.sam_prompt_encoder.train(True) # enable training of the prompt encoder

Next, we define the standard AdamW optimizer:

optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)

We are also going to use mixed precision training, which is simply a more memory-efficient training strategy:

scaler = torch.cuda.amp.GradScaler() # set mixed precision

Now let's build the main training loop. The first part is reading and preparing the data:

for itr in range(100000):
    with torch.cuda.amp.autocast(): # cast to mixed precision
        image, mask, input_point, input_label = read_batch(data) # load data batch
        if mask.shape[0]==0: continue # ignore empty batches
        predictor.set_image(image) # apply the SAM image encoder to the image

First, we cast the data to mixed precision for efficient training:

with torch.cuda.amp.autocast():

Next, we use the reader function we created earlier to read the training data:

image, mask, input_point, input_label = read_batch(data)

We take the image we loaded and pass it through the image encoder (the first part of the net):

predictor.set_image(image)

Next, we process the input points using the net's prompt encoder:

mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None,)

Note that in this part we could also input boxes or masks, but we are not going to use these options.

Now that we have encoded both the prompt (points) and the image, we can finally predict the segmentation masks:

batched_mode = unnorm_coords.shape[0] > 1 # multi-mask prediction
high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0), image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=True, repeat_image=batched_mode, high_res_features=high_res_features,)
prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1]) # upscale the masks to the original image resolution

The main part of this code is model.sam_mask_decoder, which runs the mask decoder part of the net and generates the segmentation masks (low_res_masks) and their scores (prd_scores).

These masks are at a lower resolution than the original input image and are resized to the original input size in the postprocess_masks function.

This gives us the final prediction of the net: 3 segmentation masks (prd_masks) for each input point we used, and the mask scores (prd_scores). prd_masks contains 3 predicted masks for each input point, but we are only going to use the first mask for each point. prd_scores contains a score for how good the net thinks each mask is (or how confident it is in the prediction).
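As a quick sanity check, you can print the tensor shapes. The exact shapes below are my assumption for N input points and an image of height H and width W after upscaling:

# Assumed shapes for N prompt points (one mask triplet per point):
print(prd_masks.shape)   # expected: torch.Size([N, 3, H, W]) - 3 candidate masks per point
print(prd_scores.shape)  # expected: torch.Size([N, 3])       - predicted quality of each candidate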

Segmentation loss

Now that we have the net's predictions, we can calculate the loss. First, we calculate the segmentation loss, which measures how good the predicted mask is compared to the ground truth mask. For this, we use the standard cross entropy loss.

First, we need to convert the predicted masks (prd_masks) from logits into probabilities using the sigmoid function:

prd_mask = torch.sigmoid(prd_masks[:, 0]) # turn logit map into probability map

Next, we convert the ground truth mask into a torch tensor:

gt_mask = torch.tensor(mask.astype(np.float32)).cuda()

Finally, we calculate the cross entropy loss (seg_loss) manually from the ground truth (gt_mask) and the predicted probability map (prd_mask):

seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # cross entropy loss

(We add a small epsilon, 0.00001, to prevent the log function from exploding for zero values.)
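If you prefer, the same loss can be written with PyTorch's built-in binary cross entropy. This is an equivalent formulation (up to the epsilon term), not the code used in the tutorial's training script:

import torch.nn.functional as F

# Equivalent (up to the epsilon term) to the manual cross entropy above
seg_loss_builtin = F.binary_cross_entropy(prd_mask, gt_mask)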

Score loss (optional)

In addition to the masks, the net also predicts a score for how good each predicted mask is. Training this part is less important but can be useful. To train this part we first need to know the true score of each predicted mask, that is, how good the predicted mask actually is. We do this by comparing the GT mask and the corresponding predicted mask using the intersection over union (IoU) metric. IoU is simply the overlap between the two masks divided by their combined area. First, we calculate the intersection between the predicted and GT masks (the area in which they overlap):

inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)

We use a threshold (prd_mask > 0.5) to turn the predicted mask from a probability map into a binary mask.

Next, we get the IoU by dividing the intersection by the combined area (union) of the predicted and GT masks:

iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

We are going to use the IoU as the true score for each mask, and compute the score loss as the absolute difference between the predicted scores and the IoU we just calculated:

score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

Finally, we merge the segmentation loss and the score loss (giving a much higher weight to the first):

loss = seg_loss + score_loss*0.05 # mix losses

Once we have the loss, everything is completely standard. We compute backpropagation and update the weights using the optimizer we created earlier:

predictor.model.zero_grad() # empty gradients
scaler.scale(loss).backward() # backpropagate
scaler.step(optimizer)
scaler.update() # mixed precision

We also want to save the trained model once every 1000 steps:

if itr%1000==0: torch.save(predictor.model.state_dict(), "model.torch") # save model

Since we already calculated the IoU, we can display it as a moving average to see how the model's predictions improve over time:

if itr==0: mean_iou=0
mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
print("step)", itr, "Accuracy(IOU)=", mean_iou)

And that's it: we have trained/fine-tuned Segment Anything 2 in fewer than 60 lines of code (not including comments and imports). After about 25,000 steps you should see a major improvement.

The model will be saved to "model.torch".

You can find the full training code at:

https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code/blob/main/TRAIN.py

To see how to load and use the model we just trained, check the next section.

Now that the model has been fine-tuned, let's use it to segment an image.

We are going to do this using the following steps:

  1. Load the model we just trained.
  2. Give the model an image and a bunch of random points. For each point, the net will predict the segment mask that contains this point and a score.
  3. Take these masks and stitch them together into one segmentation map.

The full code for doing this is available at:

First, we load the dependencies and cast the weights to bfloat16. This makes the model much faster to run (only possible for inference).

import numpy as np
import torch
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# use bfloat16 for the whole script (memory efficient)
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

Next, we load a sample image and a mask of the image region we want to segment (download the image/mask):

image_path = r"sample_image.jpg" # path to image
mask_path = r"sample_mask.png" # path to mask; the mask defines the image region to segment
def read_image(image_path, mask_path): # read and resize image and mask
    img = cv2.imread(image_path)[...,::-1] # read image as RGB
    mask = cv2.imread(mask_path, 0) # mask of the region we want to segment

    # resize image to a maximum size of 1024

    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask
image, mask = read_image(image_path, mask_path)

Sample 30 random points inside the region we want to segment:

num_samples = 30 # number of points to sample
def get_points(mask, num_points): # sample points inside the input mask
    points = []
    for i in range(num_points):
        coords = np.argwhere(mask > 0)
        yx = np.array(coords[np.random.randint(len(coords))])
        points.append([[yx[1], yx[0]]])
    return np.array(points)
input_points = get_points(mask, num_samples)

Load the standard SAM model (same as in training):

# load model; you need to have the pre-trained model downloaded already
sam2_checkpoint = "sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

Next, load the weights of the model we just trained (model.torch):

predictor.model.load_state_dict(torch.load("model.torch"))

Run the fine-tuned model to predict a segmentation mask for every point we selected earlier:

with torch.no_grad(): # prevent the net from calculating gradients (more efficient inference)
    predictor.set_image(image) # image encoder
    masks, scores, logits = predictor.predict(  # prompt encoder + mask decoder
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )

Now we have a list of predicted masks and their scores. We want to somehow stitch them into a single consistent segmentation map. However, many of the masks overlap and may be inconsistent with each other.
The approach to stitching is simple:

First, we sort the predicted masks according to their predicted scores:

masks = masks[:,0].astype(bool)
sorted_masks = masks[np.argsort(scores[:,0])][::-1].astype(bool) # sort masks from high to low score

Now let's create an empty segmentation map and an occupancy map:

seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

Next, we add the masks one by one (from high to low score) to the segmentation map. We only add a mask if it is consistent with the masks that were previously added, which means only if the mask we want to add overlaps less than 15% with already occupied areas.

for i in range(sorted_masks.shape[0]):
    mask = sorted_masks[i]
    if (mask*occupancy_mask).sum()/mask.sum() > 0.15: continue # skip masks that overlap occupied areas too much
    mask[occupancy_mask] = 0 # remove already-occupied pixels from the mask
    seg_map[mask] = i+1 # write the segment index into the segmentation map
    occupancy_mask[mask] = 1 # mark these pixels as occupied

And that's it.

seg_map now contains the predicted segmentation map, with a different value for each segment and 0 for the background.
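If you want to keep the result, you can, for example, write the raw index map to disk (the output file name here is just an illustrative choice):

cv2.imwrite("seg_map.png", seg_map)  # save the raw segmentation map (one integer per segment)
print("number of segments:", int(seg_map.max()))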

We can turn this into a color map using:

rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
for id_class in range(1, seg_map.max()+1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]

And display it:

cv2.imshow("annotation", rgb_image)
cv2.imshow("mix", (rgb_image/2 + image/2).astype(np.uint8))
cv2.imshow("image", image)
cv2.waitKey()

Example of segmentation results using the fine-tuned SAM2. Image from the LabPics dataset.

The full inference code is available at: