[PyTorch] TorchScript


PyTorch's default model format comes with quite a few inherent drawbacks.
It runs as Python code, is simple to use, and is easy to train and deploy, which are all advantages, but its chronic weakness is that it does not deliver optimal performance.

That is where TorchScript, a model file format, comes in.

With the TorchScript support that torch provides, a model can go through a compile step that removes its dependency on Python and on the original PyTorch model code, leaving a script that carries only the minimal semantics.
Keep in mind, though, that the result is not a finished binary. Like bytecode, it is a conversion into an intermediate representation.

To actually run TorchScript, you either load it back into torch and execute it, or compile it further into a truly native form with something like TensorRT.
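To get a feel for what "intermediate representation" means here, the following is a minimal sketch that scripts a built-in layer and prints what was compiled. Using `nn.Linear(2, 3)` directly is my own shortcut so the snippet needs no extra files; it is not the article's model.

```python
import torch
import torch.nn as nn

# Script a built-in layer so the example needs no user-defined class.
scripted = torch.jit.script(nn.Linear(2, 3))

# .code is a Python-like rendering of the compiled forward();
# .graph is the lower-level intermediate representation behind it.
print(scripted.code)
print(scripted.graph)
```

Neither output is machine code. Both are descriptions of the computation that a TorchScript runtime still has to execute.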




Example

First, let's prepare an example model.
It is just a simple model that takes an input of length 2 and returns an output of length 3.

# model.py
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)  # 2 input features -> 3 outputs

    def forward(self, x):
        return self.linear(x)

# Training script: a separate file that imports SimpleModel from model.py
from model import SimpleModel
import torch
import torch.nn as nn
import torch.optim as optim

model = SimpleModel()

print(model)

x = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
y = torch.tensor([[3.0, 5.0, 7.0], [5.0, 7.0, 9.0], [7.0, 9.0, 11.0], [9.0, 11.0, 13.0]])

criterion = nn.MSELoss()  # mean squared error
optimizer = optim.SGD(model.parameters(), lr=0.01)  # stochastic gradient descent

# Train for 1000 iterations (epochs)
for epoch in range(1000):
    optimizer.zero_grad()         # reset gradients
    outputs = model(x)            # forward pass
    loss = criterion(outputs, y)  # compute the loss
    loss.backward()               # backpropagation
    optimizer.step()              # update parameters

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}')

# Check the results
print("Learned weights:", model.linear.weight.tolist())
print("Learned bias:", model.linear.bias.tolist())

# Test
print("Expected output", y.tolist())
print("Actual output",  model(torch.tensor([[1.0, 1.0]])).tolist())

# Save the model in three ways
torch.save(model.state_dict(), "state_dict.pth")
torch.save(model, "model_full.pth")
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... any other information you need
}
torch.save(checkpoint, "checkpoint.pth")

So the model is set up so that feeding in 1.0, 1.0 produces values close to 3, 5, 7.

Let's compile this into TorchScript.




Two ways to compile: trace vs script

TorchScript has two compilation methods.
They differ in more than procedure: the quality of the result can differ too, so you need to understand the trade-offs and characteristics of each before choosing. There does not seem to be a general rule that says exactly which one to use.

script is the most straightforward and general of the two. It simply reads the Python code and compiles it statically, which is the classic approach.

trace, on the other hand, actually runs the model and generates code optimized for that particular execution flow. You could say it is the one closer to true JIT compilation.




trace compilation

Trace compilation is not that hard.
Make up any input the model can actually run on, feed it in, and run the trace compile.

# Trace script: load the trained weights, then compile with torch.jit.trace
from model import SimpleModel
import torch

model = SimpleModel()

model.load_state_dict(torch.load("state_dict.pth"))

print(model)

# Switch to inference mode
model.eval()

# Tracing: follow the model's execution path using an example input
example_input = torch.tensor([[1.0, 2.0]])
traced_model = torch.jit.trace(model, example_input)

traced_model.save("traced_model.pt")

The returned value, traced_model, is the model compiled in trace mode.
Save it and use that file from then on.

At run time, load it with jit.load.

# Loading a TorchScript file does not need the SimpleModel class definition
import torch

loaded_model = torch.jit.load("traced_model.pt")

print(loaded_model)

print(loaded_model(torch.tensor([[1.0, 1.0]])).tolist())

Values can come out slightly different in some cases, but here the result was identical.
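If you would rather check this than eyeball it, something like the following works. It is a sketch under a few assumptions of mine: `nn.Linear(2, 3)` stands in for SimpleModel, an in-memory buffer replaces the .pt file, and the `atol=1e-6` tolerance is an arbitrary choice.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 3).eval()  # stand-in for SimpleModel

# Trace with one example input, as in the article.
example_input = torch.tensor([[1.0, 2.0]])
traced = torch.jit.trace(model, example_input)

# Round-trip through save/load using an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

test_input = torch.tensor([[1.0, 1.0]])
with torch.no_grad():
    eager_out = model(test_input)
    traced_out = reloaded(test_input)

# allclose tolerates tiny floating-point differences between the two paths.
outputs_match = torch.allclose(eager_out, traced_out, atol=1e-6)
print(outputs_match)
```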




script compilation

Script compilation is comparatively simpler.
That is because there is no need to actually run the model.
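The compile step itself is a single call to `torch.jit.script`. Below is a minimal sketch of it. The model class is repeated inline so the snippet is self-contained, the file goes to a temporary directory rather than the working directory, and loading the trained weights with `load_state_dict` is left out; those choices are mine, not the article's.

```python
import os
import tempfile

import torch
import torch.nn as nn

class SimpleModel(nn.Module):  # same structure as the model defined earlier
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
model.eval()  # switch to inference mode

# Scripting compiles forward() from its source; no example input is needed.
scripted_model = torch.jit.script(model)

# Written to a temp dir here; the article's file name is "scripted_model.pt".
path = os.path.join(tempfile.mkdtemp(), "scripted_model.pt")
scripted_model.save(path)
print(path)
```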

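# Load the file written by torch.jit.script(model).save("scripted_model.pt");
# the SimpleModel class definition is not needed for loading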
import torch

loaded_model = torch.jit.load("scripted_model.pt")

print(loaded_model)

print(loaded_model(torch.tensor([[1.0, 1.0]])).tolist())

Here too, the result came out equivalent.




trace vs script

So which of the two should you pick?
Honestly, the only thing I can say here too is that it depends.

trace performs a static build optimized for the computation flow it observed. Any part of the graph that the example run did not use is dropped entirely.
So if there is only one execution flow, it can give very fast, clean results.
The downside is that it is hard to make it cover multiple execution flows, and any Python logic that depends on the input shape can end up frozen to the shape of the example input.
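The single-flow limitation is easy to reproduce. In this sketch, `Gate` is a hypothetical module I made up for illustration. It is traced with a positive input and then called with a negative one.

```python
import torch
import torch.nn as nn

class Gate(nn.Module):  # hypothetical module with data-dependent branching
    def forward(self, x):
        if x.sum() > 0:
            return x + 1
        return x - 1

gate = Gate()
positive = torch.tensor([1.0, 2.0])
negative = torch.tensor([-1.0, -2.0])

# Tracing with a positive input records only the "x + 1" branch.
# PyTorch emits a TracerWarning about the tensor-to-bool conversion here.
traced_gate = torch.jit.trace(gate, positive)

print(gate(negative))         # eager: takes the "x - 1" branch
print(traced_gate(negative))  # traced: replays the recorded "x + 1" branch
```

The traced module silently gives a different answer from the eager one, which is why the warning at trace time is worth taking seriously.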

script analyzes the code statically and tries to compile it into an optimized form. It works especially well for patterns where the model code contains complex conditionals or loops.
For example, irregular code like this.

import torch
import torch.nn as nn

class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)

    def forward(self, x):
        # Example of dynamic control flow
        if x.sum() > 0:
            return self.linear(x)
        else:
            return torch.zeros_like(x)

However, the static analysis behind script is limited. It does not work for every piece of Python syntax or every code pattern. Having to write code obsessively within the bounds of what will compile is a fairly serious drawback.
Inheritance for TorchScript classes, lambdas, and dynamic typing are off the table, and Union support is limited; it can feel like almost nothing works.
Getting both good code quality and script-compilability at the same time is close to impossible.
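Both sides of this can be seen in a few lines. In the sketch below, `Gate` and `LambdaGate` are hypothetical modules of my own: the first shows that scripting keeps the if/else, the second shows a lambda being rejected. The exact exception type may differ between PyTorch versions, so it is caught broadly.

```python
import torch
import torch.nn as nn

class Gate(nn.Module):  # hypothetical module, not from the article
    def forward(self, x):
        if x.sum() > 0:
            return x + 1
        else:
            return x - 1

class LambdaGate(nn.Module):  # hypothetical; uses a lambda inside forward()
    def forward(self, x):
        shift = lambda t: t + 1
        return shift(x)

# Scripting keeps both branches of the if/else.
scripted_gate = torch.jit.script(Gate())
print(scripted_gate.code)
print(scripted_gate(torch.tensor([1.0, 2.0])))
print(scripted_gate(torch.tensor([-1.0, -2.0])))

# A lambda in forward() is outside the supported Python subset.
try:
    torch.jit.script(LambdaGate())
    lambda_scripted = True
except Exception as err:  # exact exception type may vary by version
    lambda_scripted = False
    print(type(err).__name__)
```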





trace with script

If you want the advantages of both, you can also mix trace and script together.
For example, like this.

import torch

# MyCell, scripted_gate, x, and h are assumed to be defined earlier (not shown here).
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

The idea is to compile the fine-grained model operations with trace and the branching control with script.
When the outer module is compiled, script inlines the traced code and handles it appropriately.
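For a version that runs on its own, here is a sketch that fills in the pieces the snippet relies on. The definitions of `MyDecisionGate` and `MyCell`, the 4-feature linear layer, and the tensor shapes are assumptions on my part, chosen to fit the `torch.zeros(3, 4)` state used in the loop. Passing `scripted_gate`, `x`, and `h` into the constructor is also my change, made to avoid globals.

```python
import torch
import torch.nn as nn

class MyDecisionGate(nn.Module):  # assumed definition: data-dependent branch
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(nn.Module):  # assumed definition: one recurrent step
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

class MyRNNLoop(nn.Module):
    def __init__(self, scripted_gate, x, h):
        super().__init__()
        # The cell is traced; the gate inside it was scripted beforehand.
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):  # the loop is preserved by scripting
            y, h = self.cell(xs[i], h)
        return y, h

scripted_gate = torch.jit.script(MyDecisionGate())
x, h = torch.rand(3, 4), torch.rand(3, 4)
rnn_loop = torch.jit.script(MyRNNLoop(scripted_gate, x, h))

xs = torch.rand(5, 3, 4)  # 5 time steps, batch of 3, 4 features
y, h_out = rnn_loop(xs)
print(y.shape, h_out.shape)
```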


The full codebase used in the examples is available at the following link.
https://github.com/myyrakle/ml_examples/tree/master/pytorch/torchscript-basic



References
https://tutorials.pytorch.kr/recipes/torchscript_inference.html
https://happy-jihye.github.io/dl/torch-2/
https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/
https://docs.pytorch.org/docs/stable/jit.html
https://docs.pytorch.org/docs/stable/jit_language_reference_v2.html