MJUN Tech Note

TorchScript(torch.jit)でGPUメモリを節約する

こんにちは.今回は,PyTorchでサポートされているTorchScriptへの変換を 行うJITコンパイル機能が,GPUのメモリの節約になる例を紹介したいと思います.

JITコンパイルとは

JITはJust In Timeの略です.

JITコンパイルでは,コンパイルと名がついているように, 機械語への変換を行います.
通常のコンパイルでは,プログラムを実行する前, つまり”事前コンパイル”を行いますが, JITコンパイルはプログラム”実行時”にコンパイルを行います.

PyTorchではTorchScriotという中間表現に 変換されるので,Pythonで学習させたモデルをTorchScriptに変換することで C++から呼び出したりデプロイ先からPythonに依存せずに呼び出すことが 可能となります.

Tensorflowのデプロイ機能に対抗したPyTorch独自のデプロイ機能と言えます.

チュートリアルは以下にあります.

Introduction to TorchScript — PyTorch Tutorials 2.5.0+cu124 documentation
Introduction to TorchScript — PyTorch Tutorials 2.5.0+cu124 documentation favicon https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

ドキュメントも参考にして下さい.

TorchScript — PyTorch 2.5 documentation
TorchScript — PyTorch 2.5 documentation favicon https://pytorch.org/docs/stable/jit.html

メモリの節約例

本来は組み込み等に使いますが,TorchScriptに変換することで 計算の中間結果をGPUメモリに展開することがなくなり, メモリの節約に繋がります.

物体検出で見られるIoUの計算でメモリの節約例を見ていきます.

import torch

def intersection(boxes1, boxes2):
    """intersection: 領域の共通部分の面積
       boxes1に対してすべてのboxes2のペアをつくる

    Args:
        boxes1 ([type]): N boxes
        boxes2 ([type]): M boxes

    Returns:
        tensor: shape [N, M]
    """
    x_min1, y_min1, x_max1, y_max1 = boxes1.chunk(4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = boxes2.chunk(4, dim=1)
    # 重なり部分の高さ
    all_pairs_min_ymax = torch.min(y_max1, y_max2.t())
    all_pairs_max_ymin = torch.max(y_min1, y_min2.t())
    intersect_heights = torch.clamp(all_pairs_min_ymax - all_pairs_max_ymin, min=0)
    # 重なり部分の幅
    all_pairs_min_xmax = torch.min(x_max1, x_max2.t())
    all_pairs_max_xmin = torch.max(x_min1, x_min2.t())
    intersect_widths = torch.clamp(all_pairs_min_xmax - all_pairs_max_xmin, min=0)
    return intersect_heights * intersect_widths

このコードは2つのBoxの共通面積を計算する関数です.
引数として,N個のBoxとM個のBoxを取り,全ての組み合わせについて 共通領域の面積を求めます.つまり,M x NのTensorが生成されることになります.

このコードをこのまま実行すると,MとNの数が巨大な時に, all_pairs_min_ymax, all_pairs_max_ymin, intersect_heightsといった計算の中間結果で GPUメモリを大幅に消費してしまうことになります.
例えば,N=242991, M=500(物体検出的に言うと,anchorが242991個,GT Boxが500個の想定)のとき, GPUメモリは約460MBも消費されてしまいます.

ここで役立つのがjitです.中間結果を出力しないためGPUメモリを節約できます.

デコレータが用意されているので,@torch.jit.scriptを関数の先頭に書くだけで JITコンパイルが行われます.

@torch.jit.script
def intersection(boxes1, boxes2):
    pass

これで中間結果にGPUメモリを取られることがなくなり, 忌々しいCUDA out of memoryを回避できます.

また,以下の通りJITコンパイルに対応していない操作もあるので注意して下さい.

TorchScript Unsupported PyTorch Constructs — PyTorch 2.5 documentation
TorchScript Unsupported PyTorch Constructs — PyTorch 2.5 documentation favicon https://pytorch.org/docs/stable/jit_unsupported.html#jit-unsupported

参考

[Feature request] Make IoU computation more memory efficient · Issue #18 · facebookresearch/maskrcnn-benchmark
❓ Questions and Help I'm experiencing high GPU memory usage. I made my own COCO dataset and started training with 2 separate models: e2e_faster_rcnn_R_50_FPN_1x.yaml, e2e_faster_rcnn_R_101_FPN_1x.y...
[Feature request] Make IoU computation more memory efficient · Issue #18 · facebookresearch/maskrcnn-benchmark favicon https://github.com/facebookresearch/maskrcnn-benchmark/issues/18
[Feature request] Make IoU computation more memory efficient · Issue #18 · facebookresearch/maskrcnn-benchmark