[Pytorch] TorchScript
Pytorch ๊ธฐ๋ณธ ๋ชจ๋ธ ํ์์ ํ์์ ์ธ ๋จ์ ์ ๋ง์ด ๊ฐ์ง๊ณ ์๋ ํธ์ด๋ค.
ํ์ด์ฌ ์ฝ๋๋ก ์คํ๋๊ณ , ์ฌ์ฉ๋ฒ์ด ๊ฐ๋จํ๊ณ , ํ์ต์ด๋ ๋ฐฐํฌ๋ ๊ฐ๋จํ ๊ฒ์ด ์ฅ์ ์ด๋. ์ต์ ์ ์ฑ๋ฅ์ ๋ด์ง ๋ชปํ๋ค๋๊ฒ ๊ณ ์ง์ ์ธ ๋จ์ ์ด ์์๋ค.
๊ทธ๋์ ๋์จ ๊ฒ์ด Torchscript๋ผ๋ ๋ชจ๋ธ ํ์ผ ํ์์ด๋ค.
torch์์ ์ ๊ณตํ๋ torchscript ๊ธฐ๋ฅ ์ง์์ ํ์ฉํ๋ฉด, ์ปดํ์ผ ๊ณผ์ ์ ๊ฑฐ์ณ์ python&pytorch ์ข
์์ฑ์ ์ ๊ฑฐํ๊ณ , ์ต์ํ์ ์๋ฏธ๋ง์ ๋ด๋ script๋ฅผ ๋ง๋ค ์ ์๋ค.
๋ค๋ง, ์ด ์์ฒด๋ก๋ ์์ฑ๋ ๋ฐ์ด๋๋ฆฌ๊ฐ ์๋์ ์ ์ํ๋ค. ๋ฐ์ดํธ์ฝ๋์ฒ๋ผ ์ค๊ฐ ํํ ์ธ์ด๋ก ๋ณํํ๋ ๊ฒ์ด๋ค.
์ค์ ๋ก torchscript๋ฅผ ๋๋ฆฌ๋ ค๋ฉด ์ด๊ฑธ ๋ค์ torch๋ก ๋ก๋ํด์ ์คํํ๊ฑฐ๋, TensorRT ๊ฐ์ ์ง์ง native ๋ฐ์ด๋๋ฆฌ ํํ๋ก ์ปดํ์ผํด์ ์คํํด์ผ ํ๋ค.
์์ ์ํ
๋จผ์ ์์ ๋ชจ๋ธ์ ํ๋ ์ค๋นํด๋ณด๊ฒ ๋ค.
๊ทธ๋ฅ ๊ธธ์ด 2์ง๋ฆฌ ์
๋ ฅ์ ๋ฐ์์ ๊ธธ์ด 3์ง๋ฆฌ๋ฅผ ๋ฐํํ๋ ๊ฐ๋จํ ๋ชจ๋ธ์ด๋ค.
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 3)
def forward(self, x):
return self.linear(x)from model import SimpleModel
import torch
import torch.nn as nn
import torch.optim as optim
model = SimpleModel()
print(model)
x = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
y = torch.tensor([[3.0, 5.0, 7.0], [5.0, 7.0, 9.0], [7.0, 9.0, 11.0], [9.0, 11.0, 13.0]])
criterion = nn.MSELoss() # ํ๊ท ์ ๊ณฑ์ค์ฐจ
optimizer = optim.SGD(model.parameters(), lr=0.01) # ํ๋ฅ ์ ๊ฒฝ์ฌํ๊ฐ๋ฒ
# 10๋ฒ ๋ฐ๋ณตํด์ ํ์ต (epoch)
for epoch in range(1000):
optimizer.zero_grad() # ๊ธฐ์ธ๊ธฐ ์ด๊ธฐํ
outputs = model(x) # ๋ชจ๋ธ์ ์
๋ ฅ๊ฐ ์ ๋ฌ
loss = criterion(outputs, y) # ์์ค ๊ณ์ฐ
loss.backward() # ์ญ์ ํ
optimizer.step() # ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ
if (epoch+1) % 100 == 0:
print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}')
# ๊ฒฐ๊ณผ ํ์ธ
print("ํ์ต๋ ๊ฐ์ค์น:", model.linear.weight.tolist())
print("ํ์ต๋ ํธํฅ:", model.linear.bias.tolist())
# ํ
์คํธ
print("๊ธฐ๋ํ ์ถ๋ ฅ", y.tolist())
print("์ค์ ์ ์ถ๋ ฅ", model(torch.tensor([[1.0, 1.0]])).tolist())
# 3๊ฐ์ง ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ ์ ์ฅ
torch.save(model.state_dict(), "state_dict.pth")
torch.save(model, "model_full.pth")
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# ... ๊ธฐํ ํ์ํ ์ ๋ณด
}
torch.save(checkpoint, "checkpoint.pth")
๊ทธ๋์ 1.0, 1.0์ ๋ฃ์ผ๋ฉด 3,5,7์ ๊ทผ์ฌํ๋ ๊ฐ์ด ๋์ค๋๋ก ์ธํ
ํด๋๋ค.
์ด๊ฑธ torchscript๋ก ์ปดํ์ผํด๋ณด์.
2๊ฐ์ง ์ปดํ์ผ ๋ฐฉ๋ฒ: trace vs script
torchscript์๋ 2๊ฐ์ง ์ปดํ์ผ ๋ฐฉ๋ฒ์ด ์๋ค.
๊ทผ๋ฐ ์ด๊ฒ ๋ฐฉ๋ฒ๋ง ๋ค๋ฅธ๊ฒ ์๋๋ผ ์ค์ ๊ฒฐ๊ณผ๋ฌผ์ ์์ค๋ ๋ค๋ฅผ ์ ์๊ธฐ ๋๋ฌธ์ ๊ทธ ์ฅ๋จ์ ๊ณผ ํน์ง์ ์ ์๊ณ ์ ํํด์ผ ํ๋ค. ๋ฑ ๋ญ๋ฅผ ์จ์ผํ๋ค๋ ์ผ๋ฐ์ ์ธ ์ง์นจ์ ์กด์ฌํ์ง ์๋ ๊ฒ ๊ฐ๋ค.
์ผ๋จ script๋ ๊ฐ์ฅ ์ ์งํ๊ณ ์ผ๋ฐ์ ์ธ ํํ์ ์ปดํ์ผ ๋ฐฉ๋ฒ์ด๋ค. ๊ทธ๋ฅ python ์ฝ๋๋ฅผ ์ฝ์ด์, ์ ์ ์ผ๋ก ์ปดํ์ผํ๋ ์ ํ์ ์ธ ๋ฐฉ๋ฒ๋ก ์ ์ทจํ๋ค.
๋ฐ๋ฉด์ trace๋ ๋ชจ๋ธ์ ์ค์ ๋ก ์คํํ๋ฉด์ ๊ทธ ์คํ ํ๋ฆ์ ๋ง๊ฒ ์ต์ ํ๋ ์ฝ๋๋ฅผ ๋ง๋ค์ด๋ด๋ ๋ฐฉ์์ ์ทจํ๋ค. ์ง์ง JIT ์ปดํ์ผ์ ๊ฐ๊น์ด ๋ฐฉ๋ฒ์ด๋ผ๊ณ ํ ์ ์๊ฒ ๋ค.
trace ์ปดํ์ผ
trace ์ปดํ์ผ์ ๊ทธ๋ ๊ฒ ์ด๋ ต์ง ์๋ค.
์ค์ ๋ชจ๋ธ ์คํ์ ํ์ํ ์
๋ ฅ๊ฐ์ ๋์ถฉ ํ๋ ๋ง๋ค๊ณ , ๋ฐ์ด๋ฃ์ด์ trace ์ปดํ์ผ์ ๋๋ฆฌ๋ฉด ๋๋ค.
import trace
from model import SimpleModel
import torch
model = SimpleModel()
model.load_state_dict(torch.load("state_dict.pth"))
print(model)
# ์ถ๋ก ๋ชจ๋ ์ ํ
model.eval()
# Tracing ๋ฐฉ์ - ์์ ์
๋ ฅ์ผ๋ก ๋ชจ๋ธ ์คํ ๊ฒฝ๋ก ์ถ์
example_input = torch.tensor([[1.0, 2.0]])
traced_model = torch.jit.trace(model, example_input)
traced_model.save("traced_model.pt")
๊ทธ๋ฌ๊ณ ๋ง๋ค์ด์ง๋ ๋ฐํ๊ฐ traced_model์ด trace ๋ชจ๋๋ก ์ปดํ์ผ๋ ๋ชจ๋ธ ๋ฐ์ดํฐ๋ค.
์ ๊ฑธ ์ ์ฅํด์ ์ฌ์ฉํ๋ฉด ๋๋ค.
์คํํ ๋๋ jit.load๋ฅผ ํตํด์ ๋ก๋ํด์ ์ฌ์ฉํ ์ ์๋ค.
from model import SimpleModel
import torch
loaded_model = torch.jit.load("traced_model.pt")
print(loaded_model)
print(loaded_model(torch.tensor([[1.0, 1.0]])).tolist())
๊ฒฝ์ฐ์ ๋ฐ๋ผ ๊ฐ์ด ์ฝ๊ฐ ์๊ณก๋ ์๋ ์๊ธด ํ๋ฐ, ์ด ๊ฒฝ์ฐ์๋ ๋์ผํ ๊ฐ์ด ๋์๋ค.
script ์ปดํ์ผ
script ์ปดํ์ผ์ ๊ฒฝ์ฐ์๋ ๋น๊ต์ ๋ ๊ฐ๋จํ๋ค.
์ค์ ๋ก ๋ชจ๋ธ์ ๋๋ฆด ํ์๋ ์๊ธฐ ๋๋ฌธ์ด๋ค.
from model import SimpleModel
import torch
loaded_model = torch.jit.load("scripted_model.pt")
print(loaded_model)
print(loaded_model(torch.tensor([[1.0, 1.0]])).tolist())
์ด ๊ฒฝ์ฐ์๋ ์ผ๋จ์ ๋๋ฑํ ๊ฒฐ๊ณผ๊ฐ ๋์๋ค.
trace vs script
๊ทธ๋์, 2๊ฐ์ง ๋ฐฉ๋ฒ ์ค ๋ฌด์์ ์ ํํด์ผํ ๊น?
์ฌ์ค ์ด๊ฒ๋ ๊ฒฝ์ฐ์ ๋ฐ๋ผ ๋ค๋ฅด๋ค๋ ๋ง์ ํ ์๋ฐ์ ์์ ๊ฒ ๊ฐ๋ค.
trace๋ ํด๋น ๊ณ์ฐ ํ๋ฆ์ ๋ง์ถฐ์ ์ต์ ํ๋ ํํ๋ก ์ ์ ๋น๋๋ฅผ ์ํํ๋ค. ํด๋น ์์ ์คํ์ ์ฌ์ฉ๋์ง ์์ ๋ถํ์ํ ๊ทธ๋ํ๋ฅผ ์ ๋ถ ์ ๊ฑฐํ๋ค.
๊ทธ๋์ ์คํ ํ๋ฆ์ด ํ๊ฐ์ง๋ผ๋ฉด ๋ฌด์๋ณด๋ค ๋น ๋ฅด๊ณ ๋ฐ์ด๋ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ ์ ์๋ค.
ํ์ง๋ง ๋ค์ํ ํํ์ ์คํ ํ๋ฆ์ ์ ๊ณตํ๊ฒ๋ ๋ง๋ค๊ธฐ ์ด๋ ต๊ณ , ์
๋ ฅ ํฌ๊ธฐ ๋ํ ํ๊ฐ์ง๋ก ์ ํ๋๋ค๋ ๋จ์ ์ด ์๋ค.
script๋ ์ฝ๋๋ฅผ ์ ์ ์ผ๋ก ๋ถ์ํด์ ์ต์ ํ๋ ํํ๋ก ์ปดํ์ผ์ ์๋ํ๋ค. ์ด๊ฑด ํนํ, ๋ชจ๋ธ ์ฝ๋์ ๋ณต์กํ ์กฐ๊ฑด๋ฌธ์ด๋ ๋ฐ๋ณต๋ฌธ ๊ฐ์ ๊ฒ์ด ์๋ ํจํด์ ๋ํด์๋ ๋์ฑ ์ ํจํ๊ฒ ๋์ํ๋ค.
์๋ฅผ ๋ค๋ฉด, ์ด๋ฐ ์์ ๋ณ์น์ ์ธ ์ฝ๋ ๋ง์ด๋ค.
import torch
import torch.nn as nn
class SampleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 3)
def forward(self, x):
# ๋์ ์ ์ด ํ๋ฆ ์์
if x.sum() > 0:
return self.linear(x)
else:
return torch.zeros_like(x)
ํ์ง๋ง script ์ฝ๋์ ์ ์ ๋ถ์ ์์ค์ ๋ฎ๋ค. ๋ชจ๋ python ์ ํ์ค๋ ์ฝ๋ ํจํด์ ์ ํจํ๊ฒ ๋์ํ๋ ๊ฒ์ด ์๋๊ธฐ ๋๋ฌธ์ด๋ค. ์ต์ ํ๊ฐ ๋๋ ํ๋ ๋ด์์ ๊ฐ๋ฐ์ ์ผ๋ก ์ฝ๋๋ฅผ ์์ฑํด์ผ ํ๋ค๋๊ฒ ์ข ์น๋ช
์ ์ธ ๋ถ๋ถ์ด๋ค.
์์์ ์จ๋ ์๋๊ณ , lamdba๋ฅผ ์จ๋ ์๋๊ณ , union๋ ์๋๊ณ , ๋์ ํ์
๋ ์๋๊ณ , ์ฌ์ค ๋ญ ๋๋๊ฒ ์๋ค.
์ฝ๋ ํ๋ฆฌํฐ์ script ์ปดํ์ผ ๊ฐ๋ฅ์ฑ์ ๋์์ ๊ฐ์ ธ๊ฐ๋๊ฒ ๊ฑฐ์ ๋ถ๊ฐ๋ฅํ๋ค.
trace with script
์์ชฝ ๋ชจ๋์ ์ฅ์ ์ ๊ฐ์ ธ๊ฐ๊ณ ์ถ๋ค๋ฉด, trace์ script๋ฅผ ๋์์ ์์ด์ ์ฌ์ฉํ๋ ๊ฒ๋ ๊ฐ๋ฅํ๋ค.
์๋ฅผ ๋ค๋ฉด, ์ด๋ฐ ์์ด๋ค.
class MyRNNLoop(torch.nn.Module):
def __init__(self):
super(MyRNNLoop, self).__init__()
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
์ธ๋ถ ๋ชจ๋ธ ๋์๋ค์ ๋ํด์๋ trace๋ก ์ปดํ์ผ์ ํ๊ณ , ๋ถ๊ธฐ ์ ์ด์ ๋ํด์๋ script ์ปดํ์ผ์ ํ๋๋ก ํ๋ ๊ฒ์ด๋ค.
์ด๋ฌ๋ฉด ์ปดํ์ผ์ ํ ๋ script๊ฐ trace๋ ์คํฌ๋ฆฝํธ๋ฅผ ์ธ๋ผ์ธํด์ ์ ์ ํ ์ฒ๋ฆฌํ๋ค.
์์ ์ ์ฌ์ฉํ ์ ์ฒด ์ฝ๋๋ฒ ์ด์ค๋ ๋ค์ ๋งํฌ์์ ํ์ธํ ์ ์๋ค.
https://github.com/myyrakle/ml_examples/tree/master/pytorch/torchscript-basic
์ฐธ์กฐ
https://tutorials.pytorch.kr/recipes/torchscript_inference.html
https://happy-jihye.github.io/dl/torch-2/
https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/
https://docs.pytorch.org/docs/stable/jit.html
https://docs.pytorch.org/docs/stable/jit_language_reference_v2.html