Segment Anything is amazing because it can cleanly segment anything (cut out objects)

MLBoy
Jun 15, 2023

Edit segments of any object at will

Segment Anything
Segment Anything (SAM) from Meta AI can segment any object in an image.
You tell it what to segment with simple prompts, such as points or boxes.
This lets you cut out objects much as you would in image editing software.

You can try it out first on the official demo site.

How to use

Install

git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
pip install opencv-python matplotlib
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Model instantiation

from segment_anything import sam_model_registry, SamPredictor
import torch

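sam_checkpoint = "sam_vit_h_4b8939.pth"  # checkpoint downloaded above
model_type = "vit_h"  # must match the checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU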

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
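SAM comes in three sizes, and model_type must match the checkpoint you load. As a sketch, the mapping below pairs each sam_model_registry key with the checkpoint filename listed in the repository README at the time of writing (these may change); the helper name checkpoint_for is hypothetical and not part of segment-anything.

```python
# Hypothetical helper: map a sam_model_registry key to its checkpoint file.
# Filenames as listed in the segment-anything README (may change over time).
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",  # largest model (used in this article)
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",  # smallest model
}

BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"


def checkpoint_for(model_type: str) -> tuple[str, str]:
    """Return (filename, download URL) for a given model type."""
    if model_type not in CHECKPOINTS:
        raise ValueError(
            f"unknown model_type {model_type!r}; expected one of {sorted(CHECKPOINTS)}"
        )
    name = CHECKPOINTS[model_type]
    return name, BASE_URL + name


name, url = checkpoint_for("vit_b")
print(name, url)
```

For example, with model_type = "vit_b" you would download the returned URL and pass the filename as checkpoint= to sam_model_registry["vit_b"].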

Execution

We’ll try it on a sample photo, images/kyoto.jpg.


import cv2
import numpy as np

image = cv2.imread('images/kyoto.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image) # Image to embedding
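OpenCV loads images in BGR channel order, while set_image expects RGB by default (its image_format argument defaults to "RGB"), which is why the cvtColor call is needed. As an illustration, the NumPy sketch below shows that this conversion simply reverses the channel axis; the tiny dummy array stands in for a real photo.

```python
import numpy as np

# A 1x2 dummy "image" in OpenCV's BGR order: a pure-blue pixel, then a pure-red pixel.
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the last (channel) axis turns BGR into RGB, which is what
# cv2.cvtColor(image, cv2.COLOR_BGR2RGB) does for 3-channel images.
rgb = bgr[..., ::-1]

print(rgb[0, 0], rgb[0, 1])
```

Alternatively, you can pass the BGR image as-is with predictor.set_image(image, image_format="BGR") and let the predictor flip the channels.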

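Next, get masks by passing prompts to predictor.predict.
Several kinds of prompt are supported: points, boxes, and a mask from a previous prediction, alone or combined.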

Specify a point

Run the predictor with a single point placed on an object.
In this example, the point (drawn as a star) is on the flag at the upper center of the image.

Each point has a label: 1 for foreground and 0 for background.

The model returns three candidate masks together with a quality score for each.

input_point = np.array([[4200, 2000]])
input_label = np.array([1])

masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
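point_coords is an (N, 2) array of (x, y) pixel coordinates, so x is the column and y is the row. This is the reverse of NumPy's image[row, col] indexing and is an easy thing to mix up. The hypothetical helper make_point_prompt below is a sketch (not part of segment-anything) that builds point_coords and point_labels from (row, col) positions and checks shapes, labels, and bounds; the image shape in the example is made up.

```python
import numpy as np


def make_point_prompt(rows_cols, labels, image_shape):
    """Hypothetical helper: convert (row, col) positions into SAM-style
    point_coords (x, y) and point_labels, with basic sanity checks."""
    rc = np.asarray(rows_cols, dtype=np.float64).reshape(-1, 2)
    lab = np.asarray(labels, dtype=np.int64).reshape(-1)
    if len(rc) != len(lab):
        raise ValueError("need exactly one label per point")
    if not np.isin(lab, [0, 1]).all():
        raise ValueError("labels must be 1 (foreground) or 0 (background)")
    h, w = image_shape[:2]
    rows, cols = rc[:, 0], rc[:, 1]
    if (rows < 0).any() or (rows >= h).any() or (cols < 0).any() or (cols >= w).any():
        raise ValueError("a point lies outside the image")
    coords = rc[:, ::-1].copy()  # (row, col) -> (x, y) = (col, row)
    return coords, lab


# Row 2000, column 4200 on a made-up 3000 x 5000 image becomes (x, y) = (4200, 2000).
coords, point_labels = make_point_prompt([(2000, 4200)], [1], image_shape=(3000, 5000, 3))
print(coords, point_labels)
```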

masks is a boolean NumPy array (True inside the mask) whose last two dimensions match the image: its shape is (number of masks, height, width). Printed, a mask looks like this:

array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],

Because multimask_output=True, masks holds all three candidates and has shape (3, height, width), and scores has one entry per candidate.
With multimask_output=False, you get a single mask of shape (1, height, width).
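With three candidates, you usually keep the one with the highest score. The sketch below uses made-up masks and scores arrays with the same shapes predict returns, (3, H, W) and (3,), picks the best candidate with np.argmax, and reports how much of the image it covers. The arrays are dummy data for illustration only.

```python
import numpy as np

# Dummy stand-ins for predictor.predict(..., multimask_output=True) outputs:
# masks is (3, H, W) boolean, scores is (3,) float.
H, W = 4, 5
masks = np.zeros((3, H, W), dtype=bool)
masks[0, :2, :2] = True   # pretend: a small part of the object
masks[1, :3, :4] = True   # pretend: the whole object
masks[2, :, :] = True     # pretend: object plus background
scores = np.array([0.71, 0.95, 0.42])

best = int(np.argmax(scores))   # index of the highest-scoring candidate
best_mask = masks[best]         # (H, W) boolean mask
coverage = best_mask.mean()     # fraction of pixels inside the mask

print(best, best_mask.shape, round(float(coverage), 2))
```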

Specify multiple points

Next, run the predictor with several foreground points on the object.
You can also reuse the previous result: pass the low-resolution logits of the highest-scoring mask from the single-point run as mask_input. This usually gives a more accurate mask. Because several points make the prompt less ambiguous, multimask_output=False is used to get a single mask.

input_point = np.array([[2000, 600], [3000, 400],  [2000, 190], [1000, 1000]])
input_label = np.array([1, 1, 1, 1])

mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask

masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
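logits holds one low-resolution 256 x 256 logit map per candidate, while mask_input must have shape (1, 256, 256); the [None, :, :] indexing adds that leading dimension. The shape-only sketch below uses random dummy logits instead of a real predict call to make this explicit. It also thresholds the logits at 0, which is the mask threshold SAM uses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Dummy stand-ins for the first predict() call's outputs.
logits = rng.normal(size=(3, 256, 256)).astype(np.float32)  # (C, 256, 256)
scores = np.array([0.80, 0.93, 0.61])

mask_input = logits[np.argmax(scores), :, :]   # (256, 256): best candidate's logits
batched = mask_input[None, :, :]                # (1, 256, 256): the shape predict expects

# Logits are unthresholded; values above 0 correspond to "inside the mask".
low_res_mask = batched[0] > 0.0

print(batched.shape, low_res_mask.dtype)
```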

Adding a background point (label 0) excludes that region, which narrows the foreground mask.

For example, starting from a mask of the whole cat, you can segment only the tail by placing a background point (label 0) on the body:

input_point = np.array([[390, 1000], [500, 700]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :] # Best mask's logits from an earlier predict() call on this same image

masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)

Give a box as a prompt

You can also prompt with a bounding box around an object, given as [x0, y0, x1, y1] (top-left and bottom-right corners) in pixel coordinates.

input_box = np.array([150, 400, 360, 580])  # x0, y0, x1, y1

masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
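Boxes from other tools often come as (x, y, width, height), and sometimes you want a box around a rough mask you already have. The hypothetical helpers below (not part of segment-anything) sketch both conversions to the (x0, y0, x1, y1) order used here; box_from_mask returns the inclusive pixel extent of the True region.

```python
import numpy as np


def xywh_to_xyxy(box):
    """Convert (x, y, width, height) to (x0, y0, x1, y1)."""
    x, y, w, h = box
    return np.array([x, y, x + w, y + h])


def box_from_mask(mask):
    """Tight (x0, y0, x1, y1) box around the True pixels of an (H, W) mask."""
    ys, xs = np.nonzero(mask)   # row indices are y, column indices are x
    if len(xs) == 0:
        raise ValueError("mask is empty")
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


mask = np.zeros((6, 8), dtype=bool)
mask[2:5, 3:7] = True  # rows 2-4, columns 3-6

print(box_from_mask(mask))                  # [3 2 6 4]
print(xywh_to_xyxy((150, 400, 210, 180)))   # [150 400 360 580]
```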

Use a combination of points and boxes

Combining a box with a negative point lets you exclude part of the boxed region; in this example, the berries are excluded so that only the cake is masked.

input_box = np.array([150, 400, 360, 580])
input_point = np.array([[290, 550]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)

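Prompt for multiple boxes

To segment several objects in one call, pass a batch of boxes to predict_torch. The boxes are first mapped to the model's input resolution with predictor.transform.apply_boxes_torch, and the result is a torch tensor of masks with shape (number of boxes, 1, height, width).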

input_boxes = torch.tensor([
[35,90,300,450],
[85,250,220,320],
[90,410,215,550],
[220,360,420,520],
], device=predictor.device)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks = masks.cpu().numpy()  # torch tensor (4, 1, H, W) -> NumPy boolean array for plotting
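Once the batched masks are a NumPy array of shape (B, 1, H, W), you can merge them into a single label image in which 0 is background and k marks the k-th box. The sketch below uses dummy data; where masks overlap, later boxes overwrite earlier ones, which is an arbitrary choice, and the helper name masks_to_label_map is hypothetical.

```python
import numpy as np


def masks_to_label_map(masks):
    """Merge a (B, 1, H, W) or (B, H, W) boolean mask stack into an (H, W)
    int label map: 0 = background, k = mask of the k-th box (1-based)."""
    masks = np.asarray(masks, dtype=bool)
    if masks.ndim == 4:
        masks = masks[:, 0]          # drop the singleton mask dimension
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for k, m in enumerate(masks, start=1):
        labels[m] = k                # later masks overwrite earlier ones on overlap
    return labels


# Dummy stand-in for the (B, 1, H, W) array with two boxes.
demo = np.zeros((2, 1, 3, 4), dtype=bool)
demo[0, 0, :, :2] = True
demo[1, 0, 1:, 1:3] = True

print(masks_to_label_map(demo))
```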

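Visualization functions

These matplotlib helpers draw a mask overlay, the prompt points (green stars for foreground, red for background), and a box on top of the image.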

import matplotlib.pyplot as plt
import numpy as np

def show_mask(mask, ax, random_color=False):
    # Draw a semi-transparent colored overlay for an (H, W) or (1, H, W) mask
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) as green stars, background points (label 0) as red stars
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # box is (x0, y0, x1, y1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

# Show each candidate from the single-point example (masks, scores, input_point, input_label)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
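To actually cut an object out, you can turn a mask into the alpha channel of an RGBA image so that everything outside the mask becomes transparent. A minimal NumPy sketch, assuming an (H, W, 3) uint8 RGB image and an (H, W) boolean mask such as masks[np.argmax(scores)]; the helper name cut_out is hypothetical and the example arrays are dummy data.

```python
import numpy as np


def cut_out(image, mask):
    """Return an (H, W, 4) uint8 RGBA image whose alpha is 255 inside the mask
    and 0 outside, i.e. the object on a transparent background."""
    if image.shape[:2] != mask.shape:
        raise ValueError("image and mask must have the same height and width")
    alpha = np.where(mask, 255, 0).astype(np.uint8)
    return np.dstack([image, alpha])   # append alpha as a 4th channel


image = np.full((2, 3, 3), 200, dtype=np.uint8)   # dummy gray image
mask = np.array([[True, False, False],
                 [True, True, False]])
rgba = cut_out(image, mask)
print(rgba.shape, rgba[..., 3].tolist())
```

To keep the transparency when saving, write a PNG, e.g. cv2.imwrite("cutout.png", cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)).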

Combined with inpainting, these masks make it easy to build an object removal feature: mask the object, then fill the masked region from its surroundings.
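As one simple option, OpenCV's classical cv2.inpaint can fill the masked region; note this is a swapped-in, non-learned inpainting method, not necessarily what a production object-removal feature would use. It expects an 8-bit single-channel mask whose non-zero pixels are filled, and growing the mask by a few pixels helps cover object edges the SAM mask may miss. The NumPy sketch below prepares such a mask; the helper name to_inpaint_mask and the 3-pixel default are assumptions.

```python
import numpy as np


def to_inpaint_mask(mask, grow=3):
    """Hypothetical helper: turn an (H, W) boolean SAM mask into a uint8
    0/255 mask, dilated by `grow` pixels with a square neighborhood."""
    m = np.asarray(mask, dtype=bool)
    h, w = m.shape
    padded = np.pad(m, grow, mode="constant", constant_values=False)
    grown = np.zeros_like(m)
    # A pixel is set if any pixel within `grow` steps in x and y is set.
    for dy in range(-grow, grow + 1):
        for dx in range(-grow, grow + 1):
            grown |= padded[grow + dy : grow + dy + h, grow + dx : grow + dx + w]
    return np.where(grown, 255, 0).astype(np.uint8)


mask = np.zeros((7, 7), dtype=bool)
mask[3, 3] = True
inpaint_mask = to_inpaint_mask(mask, grow=1)   # single pixel grows to a 3x3 block
print(inpaint_mask.dtype, int((inpaint_mask == 255).sum()))
```

With OpenCV installed, the result can then be used as cv2.inpaint(bgr_image, inpaint_mask, 3, cv2.INPAINT_TELEA), where bgr_image is the 8-bit 3-channel photo and 3 is the inpainting radius.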

🐣

I’m a freelance engineer.
Work consultation
Feel free to contact me with a brief description of your project.
rockyshikoku@gmail.com

I am creating applications using machine learning and AR technology.

I share information about machine learning and AR.
