Building an image classifier as easily as possible with a custom dataset: transfer learning with PyTorch

Train your Image Classification model on your own dataset

Colab: based on the PyTorch tutorial

Building image classification tailored to the use case

To tailor image classification to a use case, train the model on a dataset that fits that use case.
For example, to identify traffic signs in images, collect images of each sign and train the model on them.

Let’s try it with PyTorch

Let’s train a model with PyTorch.
This time, we take a model that was pre-trained on the ImageNet dataset, attach a new, untrained layer to the end of it, and train the result on our own data.
This technique is called fine-tuning.
Later in the article, we also lock (freeze) the pre-trained layers so that only the newly added last layer is trained.
Both approaches are forms of transfer learning, as opposed to training the entire model from random weights.
Transfer learning is common practice for use cases where a large dataset cannot be prepared.


Import the required modules

Prepare the dataset

Inside the data directory, create train and val directories, create one directory per class name inside each of them, and put the images for that class in it.
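The layout described above can be sketched with a few lines of standard-library Python. This is only an illustration: the function name `make_dataset_skeleton` and the class names `class_a` and `class_b` are placeholders of mine, not names from the original notebook.

```python
from pathlib import Path

# Target layout (class names are placeholders):
#   data/train/class_a/  data/train/class_b/
#   data/val/class_a/    data/val/class_b/


def make_dataset_skeleton(root, class_names, splits=("train", "val")):
    """Create root/<split>/<class_name>/ folders and return them as sorted relative paths."""
    root = Path(root)
    created = []
    for split in splits:
        for name in class_names:
            folder = root / split / name
            # exist_ok=True makes the call safe to repeat on an existing tree.
            folder.mkdir(parents=True, exist_ok=True)
            created.append(folder.relative_to(root).as_posix())
    return sorted(created)
```

After calling `make_dataset_skeleton("data", ["class_a", "class_b"])`, copy the images of each class into the matching train and val folders.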

Convert the dataset into data loaders that feed the model

Display a batch from the data loader as images to check it.

Set up a training loop
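The loop below is a sketch modeled on the general shape of the PyTorch transfer learning tutorial, not necessarily the exact code used for the results in this article. It assumes `dataloaders` and `dataset_sizes` are dictionaries keyed by "train" and "val"; the function name and its defaults are my choices.

```python
import copy
import time

import torch


def train_model(model, dataloaders, dataset_sizes, criterion, optimizer,
                scheduler=None, num_epochs=25, device="cpu"):
    """Train and validate each epoch; return the model with the best validation weights."""
    since = time.time()
    model = model.to(device)
    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ("train", "val"):
            if phase == "train":
                model.train()   # enable dropout / batch-norm updates
            else:
                model.eval()    # use inference behavior for validation

            running_loss, running_corrects = 0.0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Track gradients only while training.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if phase == "train" and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"epoch {epoch + 1}/{num_epochs} {phase} "
                  f"Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            # Keep a copy of the weights that score best on the validation set.
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

    elapsed = time.time() - since
    print(f"Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s")
    print(f"Best val Acc: {best_acc:4f}")
    model.load_state_dict(best_weights)
    return model
```

Restoring the best validation weights at the end means a late epoch that overfits does not degrade the returned model.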

Define model and training settings

Download a pre-trained model from torchvision and attach a new, untrained layer to train.
Set the loss function and optimizer.

Pre-trained models are provided in torchvision.models as follows:

import torchvision.models as models

# Note: `pretrained=True` is the older torchvision API; releases from 0.13 on
# deprecate it in favor of the `weights` argument.
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_32 = models.vit_b_32(pretrained=True)
vit_l_16 = models.vit_l_16(pretrained=True)
vit_l_32 = models.vit_l_32(pretrained=True)
convnext_tiny = models.convnext_tiny(pretrained=True)
convnext_small = models.convnext_small(pretrained=True)
convnext_base = models.convnext_base(pretrained=True)
convnext_large = models.convnext_large(pretrained=True)


Best val Acc: 0.934641

Training on about 200 images in 2 classes took 33 minutes on Colab’s CPU (without a GPU).

Run inference with the model

Lock all but the last layer (fixed feature extractor)

Training will be faster.

If the training data is essentially similar to the data the network was pre-trained on, training a new classifier on top of a fixed feature extractor may be sufficient (and is arguably the fastest approach). However, if the data is quite different (for example, if the conv net was pre-trained on photos and you are adapting it to paintings or illustrations), it may make sense to fine-tune the earlier layers as well.

Best val Acc: 0.967320

Training on about 200 images in 2 classes took 16 minutes on Colab’s CPU (without a GPU).

Let’s build a classifier that fits your use case

Once you have the data, training takes only a short time, so try applying it to your business.


I’m a freelance engineer.
Work inquiries
Please feel free to contact me with a brief description of the development work.

I build apps that use Core ML and ARKit.
I share information related to machine learning and AR.





