High-quality background removal on iOS
Installing the original model and Core ML Tools
# Clone the RMBG-1.4 model repository; presumably it supplies the
# briarmbg module imported by the Python code (confirm).
git clone https://huggingface.co/briaai/RMBG-1.4
cd RMBG-1.4/
# Install the repository's own Python dependencies.
pip install -r requirements.txt
# Core ML Tools, used for the PyTorch -> Core ML conversion.
pip install coremltools
Approach
The model takes a normalized image as input and outputs a grayscale mask image.
We create a wrapper model that includes the Python postprocessing, and
perform a Core ML conversion that also includes the preprocessing.
The original Python preprocessing and postprocessing functions are shown below.
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Turn an image array into the normalized batch tensor the model consumes.

    Args:
        im: image array of shape (H, W) or (H, W, C). Values are presumably
            0-255 (e.g. from cv2/PIL); the uint8 cast below relies on that.
        model_input_size: target [height, width] handed to F.interpolate.

    Returns:
        float32 tensor of shape (1, C, *model_input_size) holding
        pixel / 255 - 0.5, i.e. values in [-0.5, 0.5] for 0-255 input
        (normalize with mean 0.5 and std 1.0).
    """
    # Give a 2-D grayscale array an explicit trailing channel axis.
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]

    # HWC -> CHW, then add the batch axis F.interpolate expects.
    chw = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    batched = chw.unsqueeze(0)
    resized = F.interpolate(batched, size=model_input_size, mode='bilinear')

    # Back to uint8 (fractional parts are truncated), then scale to [0, 1].
    quantized = resized.type(torch.uint8)
    unit = torch.divide(quantized, 255.0)

    # NOTE(review): a 2-D input yields C == 1 here while normalize() receives
    # 3-entry mean/std -- confirm single-channel input is actually supported.
    return normalize(unit, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a model output to the image size and convert it to a uint8 mask.

    Args:
        result: model output with a leading batch axis of size 1, shaped
            (1, C, h, w) as required by the bilinear F.interpolate and the
            later squeeze/permute.
        im_size: target [height, width] of the returned mask.

    Returns:
        uint8 array holding the output min-max scaled to 0-255, with size-1
        axes squeezed away (an (H, W) array for a single-channel output).
        A constant output (max == min) yields an all-zero mask instead of
        the NaNs a 0/0 division would produce.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # Constant map: (result - mi) is all zeros; avoid 0/0 = NaN, whose
        # cast to uint8 is undefined.
        result = torch.zeros_like(result)
    # CHW -> HWC for NumPy/cv2 consumers; truncating cast to uint8 as before.
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return np.squeeze(im_array)
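preprocess_image scales the input to 0–1 and then normalizes it to -0.5–0.5 (mean 0.5, standard deviation 1.0).
postprocess_image min-max scales the output to 0–1 and then maps it to 0–255.
The permute operation (reordering the array dimensions from CHW to HWC) is only needed for working with cv2, so we leave it out this time. The resize back to the original image size is also left out; it is done on the iOS side instead.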
Create a wrapper model that contains the original model and the postprocessing
from briarmbg import BriaRMBG

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# from_pretrained returns a new instance by itself, so constructing a
# separate BriaRMBG() first would only build a network that is thrown away.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
net.to(device)
# Wrapper that bundles the pretrained network with the min-max postprocessing,
# so the traced / converted model emits a 0-255 mask directly.
class CoreMLRMBG(torch.nn.Module):
    """Run ``net`` and rescale its first output to the 0-255 range.

    Args:
        net: the segmentation network to wrap (BriaRMBG in this post).
        eps: lower bound for the (max - min) denominator, so a constant
            output map produces zeros instead of 0/0 = NaN.
    """

    def __init__(self, net, eps=1e-6):
        super().__init__()
        self.net = net
        self.eps = eps

    def forward(self, image):
        # NOTE(review): assumes self.net(image)[0][0] is the full-resolution
        # mask shaped (1, 1, H, W), which the GRAYSCALE ImageType output
        # expects -- confirm against briarmbg.py.
        result = self.net(image)[0][0]
        ma = torch.max(result)
        mi = torch.min(result)
        # A tensor op rather than a Python `if`, so the guard survives
        # torch.jit.trace and the Core ML conversion.
        value_range = torch.clamp(ma - mi, min=self.eps)
        result = (result - mi) / value_range
        return result * 255
# Wrap the pretrained network so the traced graph also contains the
# min-max postprocessing done in CoreMLRMBG.forward.
model = CoreMLRMBG(net=net)
The wrapper contains the entire original model and adds the postprocessing to forward.
Subtracting the minimum value from the output and dividing by the difference between the maximum and minimum values maps it to 0–1; multiplying by 255 then maps it to 0–255.
Conversion including the preprocessing
import coremltools as ct
import torch
from coremltools.converters.mil.input_types import ColorLayout

# Trace in inference mode so train-only behavior (e.g. BatchNorm batch
# statistics or Dropout, if the network uses them) is not baked into the graph.
model.eval()

# Example input that fixes the 1x3x1024x1024 shape of the converted model.
# It is created on the device chosen earlier (`device`) instead of calling
# .cuda(), which fails on machines without CUDA.
ex = torch.randn((1, 3, 1024, 1024)).to(device)
with torch.no_grad():
    jit_model = torch.jit.trace(model, ex)

coreml_model = ct.convert(
    jit_model,
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
    inputs=[
        # Core ML computes scale * pixel + bias per channel, i.e.
        # pixel / 255 - 0.5: the same -0.5..0.5 normalization as
        # preprocess_image.
        ct.ImageType(name="image",
                     shape=ex.shape,
                     bias=[-0.5, -0.5, -0.5],
                     scale=1/255.0)
    ],
    # Emit the 0-255 mask as a single-channel image.
    outputs=[ct.ImageType(name="output", color_layout=ColorLayout.GRAYSCALE)])
coreml_model.save("RMBG.mlpackage")
The scale and bias of the input ImageType reproduce the preprocessing: Core ML divides the 0–255 input pixels by 255 and subtracts 0.5, normalizing them to -0.5–0.5.
Using the model on iOS
// Wrap the converted RMBG model for Vision. A bundled model that fails to
// load is a programming error, so stop with a descriptive message.
guard let model = try? VNCoreMLModel(for: RMBG().model) else {
    fatalError("Failed to create VNCoreMLModel for RMBG")
}
// One CIContext reused for every result instead of creating one per callback.
let context = CIContext()
let request = VNCoreMLRequest(model: model, completionHandler: { [weak self] request, error in
    if let error = error {
        print("RMBG request failed: \(error)")
        return
    }
    // The model's image output arrives as a pixel-buffer observation.
    guard let results = request.results,
          let firstResult = results.first as? VNPixelBufferObservation else {
        return
    }
    // Square grayscale mask at the model's output size (not yet resized to
    // the source image).
    let maskImage = CIImage(cvPixelBuffer: firstResult.pixelBuffer)
    guard let cgImage = context.createCGImage(maskImage, from: maskImage.extent) else { return }
    DispatchQueue.main.async {
        self?.resultImage = UIImage(cgImage: cgImage)
    }
})
// Stretch the whole input image to the model's square input instead of cropping.
request.imageCropAndScaleOption = .scaleFill
let handler = VNImageRequestHandler(ciImage: ciImage, options: [:])
DispatchQueue.global(qos: .userInteractive).async {
    do {
        // Perform the request configured above, so .scaleFill takes effect.
        try handler.perform([request])
    } catch {
        print("Failed to perform RMBG request: \(error)")
    }
}
We want the whole image as input (no cropping), so
imageCropAndScaleOption is set to .scaleFill.
The output is a 1024×1024 square mask, so resize it back to the original image size.
Conversion demo.
Original model.
I’m a freelance engineer.
Work consultation
Please feel free to contact me with a brief description of your project.
rockyshikoku@gmail.com
I build applications using machine learning and AR technology.
I share information related to machine learning and AR.