MJUN Tech Note

Saving GPU Memory with TorchScript (torch.jit)

Hello. This time I’d like to introduce an example where PyTorch’s JIT compilation, which converts code to TorchScript, helps save GPU memory.

What is JIT Compilation

JIT stands for Just In Time.

As the word “compilation” suggests, JIT compilation converts a program to machine code. Ordinary compilation happens before the program runs (“ahead-of-time compilation”), whereas JIT compilation happens while the program is running.

In PyTorch, the JIT converts code to an intermediate representation called TorchScript. By converting a model trained in Python to TorchScript, you can call it from C++ or run it in a deployment environment that does not depend on Python.

You could say this is PyTorch’s own deployment feature, positioned against TensorFlow’s deployment functionality.
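As a rough illustration of the conversion described above, here is a minimal sketch that scripts a module, saves it, and loads it back. The Scale module is a hypothetical toy invented for this example, and an in-memory buffer stands in for a file on disk.

```python
import io

import torch


class Scale(torch.nn.Module):
    """Hypothetical toy module, used only for this illustration."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


# Convert the Python module to TorchScript.
scripted = torch.jit.script(Scale(2.0))

# Serialize to an in-memory buffer; passing a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Reload the serialized module; the Scale class is not needed at this point.
restored = torch.jit.load(buffer)
out = restored(torch.ones(3))
print(out)
```

A file saved this way is what a C++ program would load through LibTorch with torch::jit::load.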

Tutorials are available here:

Please also refer to the documentation:

Memory Saving Example

TorchScript was originally used for purposes such as embedded applications, but converting to TorchScript also keeps intermediate calculation results from being materialized in GPU memory, which saves memory.

Let’s look at a memory-saving example using the IoU calculation that appears in object detection.

import torch

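def intersection(boxes1, boxes2):
    """Compute the intersection area for every pair of boxes.

    Each box in boxes1 is paired with each box in boxes2.

    Args:
        boxes1 (torch.Tensor): N boxes, shape [N, 4], as (x_min, y_min, x_max, y_max)
        boxes2 (torch.Tensor): M boxes, shape [M, 4], in the same format

    Returns:
        torch.Tensor: intersection areas, shape [N, M]
    """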
    x_min1, y_min1, x_max1, y_max1 = boxes1.chunk(4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = boxes2.chunk(4, dim=1)
    # Height of overlap region
    all_pairs_min_ymax = torch.min(y_max1, y_max2.t())
    all_pairs_max_ymin = torch.max(y_min1, y_min2.t())
    intersect_heights = torch.clamp(all_pairs_min_ymax - all_pairs_max_ymin, min=0)
    # Width of overlap region
    all_pairs_min_xmax = torch.min(x_max1, x_max2.t())
    all_pairs_max_xmin = torch.max(x_min1, x_min2.t())
    intersect_widths = torch.clamp(all_pairs_min_xmax - all_pairs_max_xmin, min=0)
    return intersect_heights * intersect_widths

This function computes the intersection area between two sets of boxes. It takes N boxes and M boxes as arguments and computes the intersection area for every combination, so it produces an N x M tensor.
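To make the output shape concrete, here is a small sketch that runs the same computation on hand-written boxes. The box coordinates are made up for illustration, and the function body is condensed slightly from the version above.

```python
import torch


def intersection(boxes1, boxes2):
    # Each chunk has shape [N, 1] (or [M, 1]); transposing the second set
    # to [1, M] makes broadcasting produce [N, M] results.
    x_min1, y_min1, x_max1, y_max1 = boxes1.chunk(4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = boxes2.chunk(4, dim=1)
    heights = torch.clamp(
        torch.min(y_max1, y_max2.t()) - torch.max(y_min1, y_min2.t()), min=0
    )
    widths = torch.clamp(
        torch.min(x_max1, x_max2.t()) - torch.max(x_min1, x_min2.t()), min=0
    )
    return heights * widths


# N = 2 boxes and M = 3 boxes in (x_min, y_min, x_max, y_max) format.
boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                       [1.0, 1.0, 3.0, 3.0]])
boxes2 = torch.tensor([[1.0, 1.0, 2.0, 2.0],
                       [0.0, 0.0, 1.0, 1.0],
                       [2.0, 2.0, 4.0, 4.0]])

areas = intersection(boxes1, boxes2)
print(areas.shape)  # torch.Size([2, 3])
print(areas)        # row i, column j: overlap of boxes1[i] with boxes2[j]
```

Boxes that only touch at a corner or do not overlap at all get an area of 0 thanks to the clamp.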

If this code is executed as-is with large M and N, intermediate results such as all_pairs_min_ymax, all_pairs_max_ymin, and intersect_heights consume a significant amount of GPU memory. For example, with N=242991 and M=500 (in object detection terms, 242991 anchors and 500 GT boxes), GPU memory consumption reaches about 460MB.
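As a back-of-the-envelope check on that figure, the sketch below computes the size of a single [N, M] tensor. It assumes float32 elements (4 bytes each), which the text above does not state explicitly.

```python
# Size of one [N, M] intermediate tensor, assuming float32 (an assumption).
N = 242991  # anchors
M = 500     # GT boxes
BYTES_PER_FLOAT32 = 4

num_elements = N * M
size_bytes = num_elements * BYTES_PER_FLOAT32
size_mib = size_bytes / 2**20

print(f"{num_elements:,} elements")
print(f"{size_bytes:,} bytes = {size_mib:.1f} MiB")  # about 463.5 MiB
```

Under that assumption, one tensor of this shape is already in the same range as the number quoted above, and the eager version creates several of them along the way.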

This is where the JIT becomes useful. Because the intermediate results are no longer written out to GPU memory, memory is saved.

A decorator is provided, so JIT compilation only requires writing @torch.jit.script directly above the function definition.

@torch.jit.script
def intersection(boxes1, boxes2):
    # The function body is unchanged from the version above.
    pass

This eliminates the GPU memory consumed by intermediate results and lets you avoid those annoying CUDA out of memory errors.
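To try this yourself, here is a sketch that scripts the full function, checks it against the eager version, and, if a GPU is available, compares peak memory. The names intersection_eager, random_boxes, and peak_mib are made up for this example, the tensor sizes are kept small, and how much memory is actually saved likely depends on the PyTorch version and device, so treat any numbers as indicative only.

```python
import torch


def intersection_eager(boxes1, boxes2):
    x_min1, y_min1, x_max1, y_max1 = boxes1.chunk(4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = boxes2.chunk(4, dim=1)
    all_pairs_min_ymax = torch.min(y_max1, y_max2.t())
    all_pairs_max_ymin = torch.max(y_min1, y_min2.t())
    intersect_heights = torch.clamp(all_pairs_min_ymax - all_pairs_max_ymin, min=0)
    all_pairs_min_xmax = torch.min(x_max1, x_max2.t())
    all_pairs_max_xmin = torch.max(x_min1, x_min2.t())
    intersect_widths = torch.clamp(all_pairs_min_xmax - all_pairs_max_xmin, min=0)
    return intersect_heights * intersect_widths


# Calling torch.jit.script on the function is equivalent to using the decorator.
intersection_scripted = torch.jit.script(intersection_eager)


def random_boxes(n, device):
    # Hypothetical helper: random valid boxes with x_max >= x_min, y_max >= y_min.
    xy_min = torch.rand(n, 2, device=device)
    wh = torch.rand(n, 2, device=device)
    return torch.cat([xy_min, xy_min + wh], dim=1)


def peak_mib(fn, *args):
    # Hypothetical helper: peak GPU memory allocated while running fn, in MiB.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20


device = "cuda" if torch.cuda.is_available() else "cpu"
boxes1 = random_boxes(1000, device)
boxes2 = random_boxes(50, device)

eager_out = intersection_eager(boxes1, boxes2)
scripted_out = intersection_scripted(boxes1, boxes2)
results_match = torch.allclose(eager_out, scripted_out)
print("results match:", results_match)

if device == "cuda":
    # A few warm-up calls, assuming the JIT may only optimize after profiling runs.
    for _ in range(3):
        intersection_scripted(boxes1, boxes2)
    print(f"eager peak:    {peak_mib(intersection_eager, boxes1, boxes2):.1f} MiB")
    print(f"scripted peak: {peak_mib(intersection_scripted, boxes1, boxes2):.1f} MiB")
```

Increasing the sizes passed to random_boxes toward the N and M from the example above should make any difference easier to see.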

Also, note that some operations are not supported by JIT compilation:
