Converting CycleGAN to a Core ML model

MLBoy
2 min read · Jun 23, 2020

--

Example: horse2zebra. You can apply the same approach to any image-to-image translation model.

In this story, we use the model from the CycleGAN tutorial in TensorFlow Core. First, train the tutorial model in Colaboratory.

# Run all cells of the Colab tutorial up to and including this training loop.
for epoch in range(EPOCHS):
    start = time.time()

    # Print a progress dot every 10 training steps.
    for n, (image_x, image_y) in enumerate(tf.data.Dataset.zip((train_horses, train_zebras))):
        train_step(image_x, image_y)
        if n % 10 == 0:
            print('.', end='')

    clear_output(wait=True)
    # Using a consistent image (sample_horse) so that the progress of the model
    # is clearly visible.
    generate_images(generator_g, sample_horse)

    # Checkpoint every 5 epochs, and always after the final epoch so the
    # latest checkpoint (which the restore step loads) holds the final weights
    # even when EPOCHS is not a multiple of 5.
    if (epoch + 1) % 5 == 0 or epoch + 1 == EPOCHS:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))

Then, insert new cells and run the converter.

1. Install coremltools and tfcoreml.

# Colab/IPython shell escapes: install the Core ML converter packages.
# NOTE(review): `--upgrade` pulls the newest releases. tfcoreml is no longer
# maintained, and the newest coremltools may not work with it. Consider pinning
# the versions that were current when this was written (mid-2020) — confirm.
!pip install --upgrade coremltools
!pip install --upgrade tfcoreml

2. Restore the checkpoints.

# Checkpoint directory; presumably the same one the training loop's
# ckpt_manager saved to — confirm against the tutorial setup cell.
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# If a checkpoint exists, restore the latest checkpoint. Otherwise say so
# explicitly: generator_g then keeps its current in-memory weights, and that is
# what gets exported and converted below.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')
else:
    print('No checkpoint found in {}; generator_g was not restored.'.format(checkpoint_path))

3. Temporarily save generator_g (the “horse2zebra” generator) in the SavedModel format.

# Temporary SavedModel export of generator_g (horse -> zebra); the converter
# step reads the model back from this directory.
saved_model_dir = './savedmodel'
generator_g.save(saved_model_dir)

4. Run the converter.

import tfcoreml

# Graph-level tensor names of the generator, with the ":<index>" suffix removed.
input_name = generator_g.inputs[0].name.split(':')[0]
print(input_name)  # Check input_name.
keras_output_node_name = generator_g.outputs[0].name.split(':')[0]
# The converter expects the final op name, without the Keras layer scope prefix.
graph_output_node_name = keras_output_node_name.split('/')[-1]

mlmodel = tfcoreml.convert(
    './savedmodel',
    input_name_shape_dict={input_name: (1, 256, 256, 3)},
    output_feature_names=[graph_output_node_name],
    minimum_ios_deployment_target='13',
    image_input_names=input_name,
    # pixel * 2/255 - 1 maps input pixel values 0..255 to -1..1; presumably this
    # matches the tutorial's input normalization — confirm.
    image_scale=2 / 255.0,
    red_bias=-1,
    green_bias=-1,
    blue_bias=-1,
)
mlmodel.save('./cyclegan.mlmodel')

Now you can use CycleGAN in your iOS project: add cyclegan.mlmodel to Xcode and run it through Vision.

import Vision
/// Vision request that runs the converted CycleGAN generator.
///
/// Xcode generates the `cyclegan` class from `cyclegan.mlmodel` (the file saved
/// by the converter). Use the matching class name if the file was renamed.
lazy var coreMLRequest: VNCoreMLRequest = {
    let model: VNCoreMLModel
    do {
        model = try VNCoreMLModel(for: cyclegan().model)
    } catch {
        // The model is bundled with the app, so failing to load it is a packaging bug.
        fatalError("Failed to load the CycleGAN Core ML model: \(error)")
    }
    // The request is stored on `self`. Capturing `self` weakly here avoids the
    // retain cycle created by passing `self.coreMLCompletionHandler0` directly.
    return VNCoreMLRequest(model: model) { [weak self] request, error in
        guard let self = self else { return }
        let completion: VNRequestCompletionHandler = self.coreMLCompletionHandler0
        completion(request, error)
    }
}()

// Read the lazy property on the current thread so that its first
// initialization does not happen on a background queue.
let request = coreMLRequest
let handler = VNImageRequestHandler(ciImage: ciimage, options: [:])
DispatchQueue.global(qos: .userInitiated).async {
    do {
        try handler.perform([request])
    } catch {
        // Surface the failure instead of silently dropping it.
        print("Vision request failed: \(error)")
    }
}

To visualize the output MLMultiArray as an image, Matthijs Hollemans’s (hollance) “CoreMLHelpers” library is very convenient.

Please follow me on Twitter: https://twitter.com/JackdeS11. And please clap 👏.

--

--