Converting RMBG-1.4 to a Core ML model

MLBoy
3 min readFeb 18, 2024

--

High quality background removal on iOS

Installing the original model and CoreMLTools

# Clone the RMBG-1.4 model repository from Hugging Face and enter it.
git clone https://huggingface.co/briaai/RMBG-1.4
cd RMBG-1.4/
# Install the packages listed in the repository's requirements.txt.
pip install -r requirements.txt
# Install coremltools, which is used for the Core ML conversion.
pip install coremltools

Approach

The original model takes a normalized input and outputs a grayscale mask image.
We create a wrapper model that includes the Python post-processing, and
run a Core ML conversion that includes the pre-processing.

The Python preprocess and postprocess functions are shown below.

def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Turn an image array into a batched, resized, normalized tensor.

    Args:
        im: Image array; a 2-D array gets a trailing channel axis appended.
        model_input_size: Target spatial size passed to ``F.interpolate``.

    Returns:
        The result of ``normalize`` applied to the resized image after it
        has been divided by 255.
    """
    # A 2-D input has no channel axis yet; add one at the end.
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    # orig_im_size=im.shape[0:2]
    # Move the last axis to the front, then add a leading batch axis.
    channels_first = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    batched = torch.unsqueeze(channels_first, 0)
    resized = F.interpolate(batched, size=model_input_size, mode='bilinear')
    # The interpolated values are cast to uint8 before scaling.
    resized = resized.type(torch.uint8)
    scaled = torch.divide(resized, 255.0)
    # NOTE(review): `normalize` is defined outside this snippet; presumably
    # the arguments are per-channel mean and std — confirm.
    return normalize(scaled, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a model output and rescale it to an 8-bit array.

    Args:
        result: Tensor passed to ``F.interpolate`` (batch axis first).
        im_size: Target spatial size for the bilinear resize.

    Returns:
        A ``uint8`` NumPy array holding the min-max rescaled output
        multiplied by 255, with all length-1 axes squeezed away.
    """
    resized = F.interpolate(result, size=im_size, mode='bilinear')
    # Drop the leading batch axis.
    resized = torch.squeeze(resized, 0)
    lowest = torch.min(resized)
    highest = torch.max(resized)
    # Min-max rescale to 0..1, then stretch to 0..255.
    unit_range = (resized - lowest) / (highest - lowest)
    stretched = unit_range * 255
    # Move the first axis to the end before converting to NumPy.
    as_bytes = stretched.permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    return np.squeeze(as_bytes)

preprocess normalizes the input to the range -0.5 to 0.5.
postprocess rescales the output to 0–1 and then to 0–255.
The permute operation (reordering the dimensions of an array) is only there so the result can be handled with cv2, so we will not include it this time.

Create a wrapped model that includes the original model and postprocess

import torch

from briarmbg import BriaRMBG

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# from_pretrained is called on the class, so no throwaway BriaRMBG()
# instance needs to be constructed first.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
net.to(device)

# wrap class including the postprocess
class CoreMLRMBG(torch.nn.Module):
    """Wraps a network and min-max rescales its first output to 0..255.

    Args:
        net: The module to wrap. ``net(image)[0][0]`` is used as the
            prediction that gets rescaled.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, image):
        """Run the wrapped network and return its output scaled to 0..255."""
        result = self.net(image)[0][0]
        ma = torch.max(result)
        mi = torch.min(result)
        # A constant prediction makes ma == mi; clamp the denominator so
        # the division yields 0 instead of NaN. Non-degenerate outputs
        # are unaffected because their range is far above the floor.
        value_range = torch.clamp(ma - mi, min=1e-8)
        result = (result - mi) / value_range
        return result * 255

# Build the wrapper and switch it (and the wrapped net) to inference mode
# so that tracing does not record training-time layer behavior.
model = CoreMLRMBG(net).eval()

Wrap the entire original model and add the post-processing to forward.
Subtracting the minimum value from the output and dividing by the difference between the maximum and minimum values gives a range of 0 to 1; multiplying by 255 then gives 0 to 255.

Conversion including preprocess

import torch
import coremltools as ct
from coremltools.converters.mil.input_types import ColorLayout

# Example input used only to trace the model. Create it on the same
# device the network was moved to, rather than hard-coding CUDA, so the
# script also runs on machines without a GPU.
ex = torch.randn((1, 3, 1024, 1024)).to(device)
jit_model = torch.jit.trace(model, ex)

coreml_model = ct.convert(
    jit_model,
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
    inputs=[
        # The preprocessing is baked into the image input:
        # pixel * (1/255) + (-0.5) maps 0..255 to -0.5..0.5.
        ct.ImageType(name="image",
                     shape=ex.shape,
                     bias=[-0.5, -0.5, -0.5],
                     scale=1 / 255.0)
    ],
    # Expose the mask as a grayscale image output.
    outputs=[ct.ImageType(name="output", color_layout=ColorLayout.GRAYSCALE)])
coreml_model.save("RMBG.mlpackage")

With these settings, the input image (0 to 255) is multiplied by 1/255 and then 0.5 is subtracted, so it is normalized to -0.5 to 0.5.

Use with iOS

// Wrap the generated RMBG Core ML model so Vision can drive it.
// A load failure is still fatal, but the underlying error is now reported
// instead of being discarded by `try?`.
let model: VNCoreMLModel
do {
    model = try VNCoreMLModel(for: RMBG().model)
} catch {
    fatalError("Failed to create VNCoreMLModel for RMBG: \(error)")
}

// Vision request that runs the model and publishes the resulting mask.
let request = VNCoreMLRequest(model: model) { [weak self] request, _ in
    // Only proceed when the first result is a pixel-buffer observation.
    guard let observation = request.results?.first as? VNPixelBufferObservation else {
        return
    }
    // Render the output pixel buffer into a CGImage.
    let maskImage = CIImage(cvPixelBuffer: observation.pixelBuffer)
    let renderer = CIContext()
    guard let renderedMask = renderer.createCGImage(maskImage, from: maskImage.extent) else { return }
    // Assign the result on the main queue.
    DispatchQueue.main.async {
        self?.resultImage = UIImage(cgImage: renderedMask)
    }
}

request.imageCropAndScaleOption = .scaleFill

// NOTE(review): `ciImage` is presumably the input image prepared elsewhere — confirm.
let handler = VNImageRequestHandler(ciImage: ciImage, options: [:])
// Run the request off the main thread.
DispatchQueue.global(qos: .userInteractive).async {
    do {
        // `request` is the local constant created for this model; the
        // snippet defines no `self.request` property.
        try handler.perform([request])
    } catch {
        // A failed inference should not bring the app down; report it instead.
        print("RMBG request failed: \(error)")
    }
}

I want the whole image to be used as input, so
imageCropAndScaleOption is set to .scaleFill.
The output is a 1024×1024 square, so resize it back to the original image size.

Conversion demo.

Original model.

I’m a freelance engineer.
Work consultation
Please feel free to contact me with a brief description of your project.
rockyshikoku@gmail.com

I am creating applications using machine learning and AR technology.

I share information related to machine learning and AR.

GitHub

Twitter
Medium

--

--