[Pytorch] 모델의 저장

[원본 링크]

파이토치에서 모델을 저장하는 방식에 대해 다뤄본다.




pickle 포맷

파이토치 모델은 pickle이라는 python용 바이너리 직렬화 포맷을 통해 저장된다.
대단한 목적이 있어서는 아니고, 파이썬 코드를 포함해서 이런저런 모델 메타데이터 정보들을 효율적으로 압축해서 저장하기 위한 것이다.

torch.save는 pickle로 파일을 압축해서 저장하고, torch.load는 pickle 파일을 읽어서 압축해제하고 메모리에 로드한다.




state_dict 저장

가장 흔하게 사용하는 방법 중 하나다.
이 방식은 모델의 파라미터만 저장한다. 그러니까, 모델 자체의 타입 정보 등은 저장하지 않는다는 것이다.

그래서 사용할 때도 모델의 타입정보를 가져다가 로드해야 한다.

파일 크기가 작다는 장점이 있다.




전체 저장

그냥 통째로 모든 정보를 저장할 수도 있다.

그러면 모델 타입정보를 알 필요는 없고, 그냥 그대로 로드해서 쓸 수 있다.

이건 모델 클래스 코드를 함께 전달하지 않아도 된다는 장점이 있으나, 그만큼 모델 크기가 좀 커진다.




체크포인트 저장

이건 최종 배포가 아닌 학습 과정에 주로 쓰이는 저장 방식이다.

막 그렇게 대단한 것은 아니고, 학습 과정과 관련된 컨텍스트 정보를 딕셔너리로 구겨넣어서 저장하는 것이다.

그러면 load할때 그대로 꺼내와서 쓸 수 있다.

모델 동작 자체는 잘 된다.

빠른 피드백이나, 학습 중간 과정 진행함에 있어서는 가장 유리한 방식이다.
이게 저장 방식 중에 모델 크기가 가장 크다.



참조
https://tutorials.pytorch.kr/beginner/saving_loading_models.html
https://docs.python.org/3/library/pickle.html