Updated On: Oct-02, 2022 | Time Investment: ~30 mins

PyTorch: Image Segmentation using Pre-Trained Models (torchvision)

What is Image Segmentation?

Image segmentation is the process of partitioning an image into segments (also referred to as objects). We detect the objects present in an image and color them to separate them from one another. The task mainly concentrates on detecting the boundaries of objects so that they can be easily separated. Quite often, we also label each detected segment/object.

Applications of Image Segmentation

  • Face detection
  • Video surveillance
  • Self-driving cars, which use it to detect objects (road signs, other cars, pedestrians, etc.)
  • Detecting objects in satellite images (roads, crops, buildings, etc.)
  • Medical imaging, e.g., to detect tumors
  • Content-based image retrieval, which searches the contents of images rather than metadata (name, tags, etc.) to retrieve data

Types of Image Segmentation

  • Instance Segmentation - All objects of the same type are marked with different colors/labels; each object gets its own color/label. E.g., individual persons in an image will have different colors/labels.
  • Semantic Segmentation - All objects of the same type are marked with one color/label. E.g., all persons in an image will have the same color/label.


How to perform Image Segmentation on Image?

Over the years, many approaches have been developed for solving image segmentation tasks. Some of them use machine learning (deep learning) whereas others use non-ML solutions. The Python library scikit-image implements the majority of the non-ML methods. We have listed below a few well-known non-ML approaches to image segmentation (a short thresholding example follows the list).

  • Thresholding-based methods
  • Clustering-based methods
  • Histogram-based methods
  • Region-growing methods
  • Edge detection
  • Watershed transformation
  • Graph-based methods
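As a quick illustration of the non-ML route, the sketch below applies Otsu thresholding from scikit-image to one of its bundled sample images; the sample image and the foreground/background interpretation are assumptions made purely for demonstration.

from skimage import data
from skimage.filters import threshold_otsu

# Sample grayscale image bundled with scikit-image (used here only for illustration).
image = data.camera()

# Otsu's method picks a threshold separating foreground from background intensities.
threshold = threshold_otsu(image)
binary_segmentation = image > threshold  # boolean mask: True = foreground pixels

print("Chosen threshold : {}".format(threshold))
print("Mask shape       : {}".format(binary_segmentation.shape))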

The majority of ML approaches involve deep neural networks consisting of layers like convolution, dense, etc. Below, we have listed some of the well-known neural network architectures that solve image segmentation tasks.

  • U-Net
  • FastFCN (Fully Convolutional Network)
  • Mask R-CNN
  • DeepLab
  • LRASPP
  • Gated-SCNN

What Can You Learn From This Tutorial?

As a part of this tutorial, we have explained how to use the pre-trained PyTorch models available from the torchvision module for image segmentation tasks. Torchvision is the computer vision toolkit of PyTorch and provides pre-trained models for many computer vision tasks like image classification, object detection, image segmentation, etc.

We have downloaded a few images from the internet and tried the pre-trained models on them. We have explained the usage of both instance and semantic segmentation models. Torchvision provides models trained on the COCO and Pascal VOC datasets and implements the majority of the deep learning models listed above.

Below, we have listed the essential sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Images
    • 1.1 Download Images
    • 1.2 Load Images in Memory using Pillow (PIL)
    • 1.3 Convert Pillow Images to PyTorch Tensors
  2. Load Models
    • 2.1 Semantic Segmentation Model
    • 2.2 Instance Segmentation Model
  3. Preprocess Images and Make Predictions
    • 3.1 Semantic Segmentation
    • 3.2 Instance Segmentation
  4. Visualize Results
    • 4.1 Semantic Segmentation
    • 4.2 Instance Segmentation
  5. Try Other Pre-Trained Image Segmentation Models

Install Latest Version of "torchvision"

  • pip install -U torchvision

Below, we have imported the necessary Python libraries used in this tutorial and printed the versions we have used.

import torch

print("PyTorch Version : {}".format(torch.__version__))
PyTorch Version : 1.12.1+cu102
import torchvision

print("TorchVision Version : {}".format(torchvision.__version__))
TorchVision Version : 0.13.1+cu102
import gc

1. Load Images

1.1 Download Images

In this section, we have downloaded three images from the internet that we'll use throughout the tutorial. We'll try the image segmentation models on these images.

The images contain objects like people, dogs, toys, etc. that we'll try to detect using the image segmentation algorithms.

!wget https://www.luxurytravelmagazine.com/files/593/2/80152/luxury-travel-instagram_bu.jpg
!wget https://www.akc.org/wp-content/uploads/2020/12/training-behavior.jpg
!wget https://images.squarespace-cdn.com/content/v1/519bd105e4b0c8ea540e7b36/1555002210238-V3YQS9DEYD2QLV6UODKL/The-Benefits-Of-Playing-Outside-For-Children.jpg
--2022-08-18 06:00:45--  https://www.luxurytravelmagazine.com/files/593/2/80152/luxury-travel-instagram_bu.jpg
Resolving www.luxurytravelmagazine.com (www.luxurytravelmagazine.com)... 108.61.242.74
Connecting to www.luxurytravelmagazine.com (www.luxurytravelmagazine.com)|108.61.242.74|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 65161 (64K) [image/jpeg]
Saving to: ‘luxury-travel-instagram_bu.jpg’

luxury-travel-insta 100%[===================>]  63.63K   331KB/s    in 0.2s

2022-08-18 06:00:46 (331 KB/s) - ‘luxury-travel-instagram_bu.jpg’ saved [65161/65161]

--2022-08-18 06:00:47--  https://www.akc.org/wp-content/uploads/2020/12/training-behavior.jpg
Resolving www.akc.org (www.akc.org)... 146.75.38.133, 2a04:4e42:78::645
Connecting to www.akc.org (www.akc.org)|146.75.38.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 558739 (546K) [image/jpeg]
Saving to: ‘training-behavior.jpg’

training-behavior.j 100%[===================>] 545.64K  1.52MB/s    in 0.4s

2022-08-18 06:00:48 (1.52 MB/s) - ‘training-behavior.jpg’ saved [558739/558739]

--2022-08-18 06:00:49--  https://images.squarespace-cdn.com/content/v1/519bd105e4b0c8ea540e7b36/1555002210238-V3YQS9DEYD2QLV6UODKL/The-Benefits-Of-Playing-Outside-For-Children.jpg
Resolving images.squarespace-cdn.com (images.squarespace-cdn.com)... 146.75.28.238
Connecting to images.squarespace-cdn.com (images.squarespace-cdn.com)|146.75.28.238|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 136668 (133K) [image/jpeg]
Saving to: ‘The-Benefits-Of-Playing-Outside-For-Children.jpg’

The-Benefits-Of-Pla 100%[===================>] 133.46K   728KB/s    in 0.2s

2022-08-18 06:00:50 (728 KB/s) - ‘The-Benefits-Of-Playing-Outside-For-Children.jpg’ saved [136668/136668]

%ls
The-Benefits-Of-Playing-Outside-For-Children.jpg
__notebook__.ipynb
luxury-travel-instagram_bu.jpg
training-behavior.jpg

1.2 Load Images in Memory using Pillow

After downloading the images, we loaded them into memory using the Python library Pillow.

from PIL import Image

holiday = Image.open("luxury-travel-instagram_bu.jpg")

holiday


kids_playing = Image.open("The-Benefits-Of-Playing-Outside-For-Children.jpg")
dog_kid_playing = Image.open("training-behavior.jpg")

1.3 Convert Pillow Images to PyTorch Tensors

Below, we have converted all our Pillow images to PyTorch tensors using the pil_to_tensor() function available from the torchvision module. All PyTorch models require their inputs to be tensors.

from torchvision.transforms.functional import pil_to_tensor

holiday_tensor_int = pil_to_tensor(holiday)
kids_playing_tensor_int = pil_to_tensor(kids_playing)
dog_kid_playing_tensor_int = pil_to_tensor(dog_kid_playing)

holiday_tensor_int.shape, kids_playing_tensor_int.shape, dog_kid_playing_tensor_int.shape
(torch.Size([3, 422, 750]),
 torch.Size([3, 667, 1000]),
 torch.Size([3, 486, 729]))
holiday_tensor_int.dtype
torch.uint8

2. Load Models

In this section, we'll load the image segmentation models that we'll use on our images. We have loaded one model to explain semantic segmentation and one for instance segmentation.

2.1 Semantic Segmentation Model

Below, we have loaded the FCN (with ResNet50 backbone) deep neural network model. The model is available through the fcn_resnet50() method from the segmentation sub-module of torchvision.

We need to provide the weights parameter to load the model with pre-trained weights. The segmentation module has an enum named FCN_ResNet50_Weights that lets us specify which weights to use. We have asked to load the model with COCO_WITH_VOC_LABELS_V1 weights, which come from a model trained on a subset of the COCO dataset using the Pascal VOC categories.

Currently, the only weight option available for this model is COCO_WITH_VOC_LABELS_V1, which is what we have used in our tutorial. Different weight options can be available when a model has been trained on different datasets; the short sketch below shows how to list them.
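If you want to inspect which weight options a weights enum exposes, a minimal sketch like the one below simply iterates over the enum; the DEFAULT alias shown here is part of the torchvision 0.13 weights API and points at the recommended weights.

from torchvision.models.segmentation import FCN_ResNet50_Weights

# Each weights enum is a regular Python Enum, so we can iterate over its members.
for weights in FCN_ResNet50_Weights:
    print(weights)

# The recommended weights are also exposed through the DEFAULT alias.
print(FCN_ResNet50_Weights.DEFAULT)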

After loading the model, we set it in evaluation mode by calling the eval() method on it.

from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

fcn_resnet = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)

fcn_resnet.eval();
Downloading: "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth" to /root/.cache/torch/hub/checkpoints/fcn_resnet50_coco-1167a1af.pth

2.2 Instance Segmentation Model

Below, we have loaded the Mask R-CNN model, which is available from the detection sub-module of torchvision. The model is loaded with COCO_V1 weights, which come from a model trained on the COCO dataset.

from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

maskrcnn_resnet = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1)

maskrcnn_resnet.eval();
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

3. Preprocess Images and Make Predictions

In this section, we'll make predictions on our images using the pre-trained models that we loaded in the previous section. We need to preprocess the images before making predictions on them.

3.1 Semantic Segmentation

The weights objects have a transforms() method that can be used to prepare images for the network. We have saved a reference to the preprocessing transform it returns.

By default, the transform resizes the image to size 520, rescales pixel values to the range [0.0, 1.0], and normalizes them using mean [0.485, 0.456, 0.406] & standard deviation [0.229, 0.224, 0.225].

In our case, we have prevented resizing by setting the resize_size parameter to None. This is done because we want to overlay the segmented image on the original image. A manual approximation of this preprocessing is sketched below.
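For readers who prefer to see the preprocessing spelled out, the sketch below approximates the same steps (rescale + normalize, no resize) using torchvision's functional API; treat it as an illustration rather than the transform's exact internal implementation.

import torch
from torchvision.transforms import functional as TF

def approx_preprocess(img_uint8):
    # Convert a uint8 [0, 255] image tensor to float values in [0.0, 1.0].
    x = TF.convert_image_dtype(img_uint8, torch.float)
    # Normalize with the mean and standard deviation mentioned above.
    x = TF.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return x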

preprocess_img = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.transforms(resize_size=None)

preprocess_img(holiday_tensor_int).unsqueeze(dim=0).shape
torch.Size([1, 3, 422, 750])

Below, we have made predictions on the images loaded earlier using our pre-trained model. We feed the network the preprocessed images and introduce a batch dimension using unsqueeze() because the model works on a batch of images (see the torch.no_grad() note after the next paragraph).

The output of the model is a dictionary with two keys.

  • out - Masks of detected objects.
  • aux - Output of the auxiliary classifier (mainly useful during training).

The value at the out key holds the segmented image, which we'll visualize in the next section.
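Since we only run inference here, it is common (though optional) to wrap the forward pass in torch.no_grad() so that no autograd graph is built and memory usage stays lower; a minimal sketch:

# Optional: disable gradient tracking during inference to save memory.
with torch.no_grad():
    preds = fcn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))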

holiday_preds1 = fcn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))
gc.collect();

holiday_preds1.keys()
odict_keys(['out', 'aux'])
kids_playing_preds1 = fcn_resnet(preprocess_img(kids_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
dog_kid_playing_preds1 = fcn_resnet(preprocess_img(dog_kid_playing_tensor_int).unsqueeze(dim=0))
gc.collect();

3.2 Instance Segmentation

Below, we have first retrieved the image preprocessing transform from the weights object.

Then, we made predictions on our images using the Mask R-CNN model that we loaded earlier. We feed the network the preprocessed images for making predictions.

The prediction of the network is a list with one dictionary per input image; each dictionary has 4 keys.

  • boxes - Bounding boxes around detected objects
  • labels - Labels of detected objects
  • scores - Predicted probability (confidence) of object presence
  • masks - Masks of detected objects

As this network is part of the detection module, it also returns bounding boxes around the objects detected in the images. A small sketch for filtering detections by confidence score follows.
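Detections come with confidence scores, so a common post-processing step is to keep only predictions above some threshold. The helper below is an illustrative sketch; the function name and the 0.75 cutoff are our own choices, not part of the tutorial's code.

def filter_by_score(pred, score_threshold=0.75):
    # Keep only detections whose confidence score exceeds the threshold.
    keep = pred["scores"] > score_threshold
    return {key: value[keep] for key, value in pred.items()}

# Usage sketch (once predictions are available):
# confident = filter_by_score(holiday_preds2[0], score_threshold=0.75)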

preprocess_img = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.transforms()

preprocess_img(holiday_tensor_int).shape
torch.Size([3, 422, 750])
holiday_preds2 = maskrcnn_resnet(preprocess_img(holiday_tensor_int).unsqueeze(dim=0))
gc.collect();

holiday_preds2[0].keys()
dict_keys(['boxes', 'labels', 'scores', 'masks'])
kids_playing_preds2 = maskrcnn_resnet(preprocess_img(kids_playing_tensor_int).unsqueeze(dim=0))
gc.collect();
dog_kid_playing_preds2 = maskrcnn_resnet(preprocess_img(dog_kid_playing_tensor_int).unsqueeze(dim=0))
gc.collect();

4. Visualize Results

In this section, we'll visualize the predictions made by our image segmentation models.

4.1 Semantic Segmentation

In this section, we'll visualize the predictions made by the semantic segmentation model.

In order to do that, we have first created a dictionary that maps category names to their indices. We'll use this mapping to retrieve segmentation results for a particular object category.

class_to_idx = {cls: idx for (idx, cls) in enumerate(FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.meta["categories"])}

class_to_idx
{'__background__': 0,
 'aeroplane': 1,
 'bicycle': 2,
 'bird': 3,
 'boat': 4,
 'bottle': 5,
 'bus': 6,
 'car': 7,
 'cat': 8,
 'chair': 9,
 'cow': 10,
 'diningtable': 11,
 'dog': 12,
 'horse': 13,
 'motorbike': 14,
 'person': 15,
 'pottedplant': 16,
 'sheep': 17,
 'sofa': 18,
 'train': 19,
 'tvmonitor': 20}

Below, we have first retrieved the predictions made on our holiday image into a variable. The prediction is present under the out key, as mentioned earlier.

The shape of the prediction at the out key is (batch, #classes, height, width), where #classes is the number of categories the model is trained to detect. In our case, it is 21 (20 object categories plus the background class), as we can see from the dictionary loaded in the previous cell; a quick sanity check follows below.

After retrieving the prediction, we have normalized the masks across classes using the softmax() function.

Then, we have retrieved the mask present at the person index and visualized it. We can notice from the image that the model is able to detect both persons present in the image.

from torchvision.transforms.functional import to_pil_image

prediction = holiday_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]

to_pil_image(normalized_masks[class_to_idx['person']])


The torchvision module provides a method named draw_segmentation_masks() that lets us overlay detected objects on the original image.

First, we have created a boolean mask from the normalized masks by marking pixels with a probability above 0.7 as True.

Then, we have called the draw_segmentation_masks() method with the original image and the person mask to overlay the detected person objects on the original image.

from torchvision.utils import draw_segmentation_masks

masks = normalized_masks > 0.7

out = draw_segmentation_masks(holiday_tensor_int, masks[class_to_idx['person']])

to_pil_image(out)


Below, we have overlaid the background mask on the original image using the same process as in the previous cell.

from torchvision.utils import draw_segmentation_masks

masks = normalized_masks > 0.7

background = draw_segmentation_masks(holiday_tensor_int, masks[class_to_idx['__background__']])

to_pil_image(background)


Below, we have retrieved the person mask present in our second image and visualized it. We can notice that the model correctly identifies all the kids present in the image.

from torchvision.transforms.functional import to_pil_image

prediction = kids_playing_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]

to_pil_image(normalized_masks[class_to_idx['person']])


Below, we have overlaid the detected kids on the original image.

from torchvision.utils import draw_segmentation_masks

masks = normalized_masks > 0.7

out = draw_segmentation_masks(kids_playing_tensor_int, masks[class_to_idx['person']])

to_pil_image(out)


Below, we have retrieved the person mask for the third image and visualized it.

In the next cell, we have retrieved the dog mask from the same image and visualized it as well.

In the later cells, we have overlaid the kid and dog masks on the original image.

from torchvision.transforms.functional import to_pil_image

prediction = dog_kid_playing_preds1['out']
normalized_masks = prediction.softmax(dim=1)[0]

to_pil_image(normalized_masks[class_to_idx['person']])


to_pil_image(normalized_masks[class_to_idx['dog']])


from torchvision.utils import draw_segmentation_masks

masks = normalized_masks > 0.7

person = draw_segmentation_masks(dog_kid_playing_tensor_int, masks[class_to_idx['person']])

to_pil_image(person)


from torchvision.utils import draw_segmentation_masks

masks = normalized_masks > 0.1

dog = draw_segmentation_masks(dog_kid_playing_tensor_int, masks=masks[class_to_idx['dog']])

to_pil_image(dog)


4.2 Instance Segmentation

In this section, we'll visualize predictions made by our Mask R-CNN instance segmentation model.

Below, we have first built a dictionary mapping category names to label indices from the categories available through the MaskRCNN_ResNet50_FPN_Weights.COCO_V1 object. The underlying meta["categories"] list maps a label index to its category name, and we'll use it to convert the labels predicted by the model into category names. Note that the categories list contains placeholder entries for unused COCO label ids, which is why the dictionary below has fewer entries than the list.

class_to_idx = {cls: idx for (idx, cls) in enumerate(MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"])}

len(class_to_idx)
82

Below, we have retrieved the predicted masks and their labels from the prediction dictionary. We have then converted the predicted label indices into category names and printed the detected object categories.

from torchvision.utils import draw_segmentation_masks

mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]

masks = holiday_preds2[0]['masks'].squeeze()
labels = holiday_preds2[0]['labels']
categories = [mapping[label] for label in labels[:15]]

print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
Detected Objects : ['person', 'person', 'kite', 'kite', 'kite', 'backpack', 'kite', 'kite', 'kite', 'kite', 'kite', 'kite', 'kite', 'kite', 'kite']
Unique Objects : ['backpack', 'kite', 'person']

Below, we have overlaid the predicted objects on the original image using the draw_segmentation_masks() method. We have provided the method with the original image and the predicted masks. We have not given all masks to the method, only the first 15, as we want to highlight a few important objects.

We have also passed per-object colors through the colors parameter to color the objects.

We can notice from the results how the model detects objects like persons, a backpack, etc.

from torchvision.utils import draw_segmentation_masks

color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange"}

colors = [color_mapping[mapping[label]] for label in labels[:15]]

output = draw_segmentation_masks(holiday_tensor_int, masks=masks[:15].to(torch.bool), colors=colors)

to_pil_image(output)
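The masks returned by Mask R-CNN are soft per-pixel probabilities, and casting them with .to(torch.bool) marks every nonzero pixel as part of an object. A stricter alternative, sketched below, is to threshold the soft masks; the 0.5 cutoff is our own choice, not part of the original tutorial.

# Alternative sketch: threshold the soft masks instead of casting all nonzero values to True.
bool_masks = masks[:15] > 0.5
output = draw_segmentation_masks(holiday_tensor_int, masks=bool_masks, colors=colors)

to_pil_image(output)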


Below, we have performed the same process to detect and visualize the objects present in our second image involving kids.

from torchvision.utils import draw_segmentation_masks

mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]

masks = kids_playing_preds2[0]['masks'].squeeze()
labels = kids_playing_preds2[0]['labels']
categories = [mapping[label] for label in labels[:5]]

print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
Detected Objects : ['person', 'person', 'person', 'person', 'sports ball']
Unique Objects : ['sports ball', 'person']
from torchvision.utils import draw_segmentation_masks

color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange", "frisbee": "pink"}

colors = [color_mapping[mapping[label]] for label in labels[:5]]

output = draw_segmentation_masks(kids_playing_tensor_int, masks=masks[:5].to(torch.bool), colors=colors)

to_pil_image(output)


Below, we have performed the same process to detect and visualize the objects present in our third image, where a kid and a dog are playing.

from torchvision.utils import draw_segmentation_masks

mapping = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]

masks = dog_kid_playing_preds2[0]['masks'].squeeze()
labels = dog_kid_playing_preds2[0]['labels']
categories = [mapping[label] for label in labels]

print("Detected Objects : {}".format(categories))
print("Unique Objects : {}".format(list(set(categories))))
Detected Objects : ['person', 'sports ball', 'dog', 'baseball glove', 'dog']
Unique Objects : ['sports ball', 'person', 'baseball glove', 'dog']
from torchvision.utils import draw_segmentation_masks

color_mapping = {"person": "tomato", "kite": "dodgerblue", "backpack": "yellow", "sports ball": "green", "dog": "orange", "frisbee": "pink", "baseball glove": "grey"}

colors = [color_mapping[mapping[label]] for label in labels[:3]]

output = draw_segmentation_masks(dog_kid_playing_tensor_int, masks=masks[:3].to(torch.bool), colors=colors)

to_pil_image(output)


5. Try Other Pre-Trained Image Segmentation Models

The torchvision module has other pre-trained models available for image segmentation tasks that can be tried to see how they perform. We have listed them below, followed by a short loading sketch.

  • Semantic Segmentation Models
    • fcn_resnet101
    • deeplabv3_mobilenet_v3_large
    • deeplabv3_resnet50
    • deeplabv3_resnet101
    • lraspp_mobilenet_v3_large
  • Instance Segmentation Models
    • maskrcnn_resnet50_fpn
    • maskrcnn_resnet50_fpn_v2
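Any of these models can be loaded the same way as the ones used above. As a minimal sketch, the example below loads DeepLabV3 with a ResNet50 backbone using the DEFAULT weights alias (which we assume points at the recommended pre-trained weights in your torchvision version).

from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

# Load DeepLabV3 (ResNet50 backbone) with its default pre-trained weights and set evaluation mode.
deeplab = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)

deeplab.eval();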

This ends our small tutorial explaining how we can use pre-trained PyTorch models for image segmentation tasks.
