TorchScript(torch.jit)でGPUメモリを節約する
こんにちは.今回は,PyTorchでサポートされているTorchScriptへの変換を 行うJITコンパイル機能が,GPUのメモリの節約になる例を紹介したいと思います.
JITコンパイルとは
JITはJust In Timeの略です.
JITコンパイルでは,コンパイルと名がついているように,
機械語への変換を行います.
通常のコンパイルでは,プログラムを実行する前,
つまり”事前コンパイル”を行いますが,
JITコンパイルはプログラム”実行時”にコンパイルを行います.
PyTorchではTorchScriotという中間表現に 変換されるので,Pythonで学習させたモデルをTorchScriptに変換することで C++から呼び出したりデプロイ先からPythonに依存せずに呼び出すことが 可能となります.
Tensorflowのデプロイ機能に対抗したPyTorch独自のデプロイ機能と言えます.
チュートリアルは以下にあります.
ドキュメントも参考にして下さい.
メモリの節約例
本来は組み込み等に使いますが,TorchScriptに変換することで 計算の中間結果をGPUメモリに展開することがなくなり, メモリの節約に繋がります.
物体検出で見られるIoUの計算でメモリの節約例を見ていきます.
import torch
def intersection(boxes1, boxes2):
"""intersection: 領域の共通部分の面積
boxes1に対してすべてのboxes2のペアをつくる
Args:
boxes1 ([type]): N boxes
boxes2 ([type]): M boxes
Returns:
tensor: shape [N, M]
"""
x_min1, y_min1, x_max1, y_max1 = boxes1.chunk(4, dim=1)
x_min2, y_min2, x_max2, y_max2 = boxes2.chunk(4, dim=1)
# 重なり部分の高さ
all_pairs_min_ymax = torch.min(y_max1, y_max2.t())
all_pairs_max_ymin = torch.max(y_min1, y_min2.t())
intersect_heights = torch.clamp(all_pairs_min_ymax - all_pairs_max_ymin, min=0)
# 重なり部分の幅
all_pairs_min_xmax = torch.min(x_max1, x_max2.t())
all_pairs_max_xmin = torch.max(x_min1, x_min2.t())
intersect_widths = torch.clamp(all_pairs_min_xmax - all_pairs_max_xmin, min=0)
return intersect_heights * intersect_widths
このコードは2つのBoxの共通面積を計算する関数です.
引数として,N個のBoxとM個のBoxを取り,全ての組み合わせについて
共通領域の面積を求めます.つまり,M x NのTensorが生成されることになります.
このコードをこのまま実行すると,MとNの数が巨大な時に,
all_pairs_min_ymax, all_pairs_max_ymin, intersect_heights
といった計算の中間結果で
GPUメモリを大幅に消費してしまうことになります.
例えば,N=242991, M=500(物体検出的に言うと,anchorが242991個,GT Boxが500個の想定)のとき,
GPUメモリは約460MBも消費されてしまいます.
ここで役立つのがjitです.中間結果を出力しないためGPUメモリを節約できます.
デコレータが用意されているので,@torch.jit.script
を関数の先頭に書くだけで
JITコンパイルが行われます.
@torch.jit.script
def intersection(boxes1, boxes2):
pass
これで中間結果にGPUメモリを取られることがなくなり, 忌々しいCUDA out of memoryを回避できます.
また,以下の通りJITコンパイルに対応していない操作もあるので注意して下さい.