기계는 거짓말하지 않는다

PyTorch model save and load

AI

PyTorch model save and load

KillinTime 2021. 8. 8. 16:32
import torch

def save_model(model, path):
    """Save only the model's learned parameters (its ``state_dict``) to *path*.

    The architecture itself is not saved: to restore, build the same model
    class again and call :func:`load_model`.
    """
    torch.save(model.state_dict(), path)


def load_model(model, path, map_location=None):
    """Load parameters written by :func:`save_model` into *model* in place.

    Args:
        model: An already-constructed ``torch.nn.Module`` whose ``state_dict``
            keys match the saved file (``load_state_dict`` is strict by
            default and raises on missing/unexpected keys).
        path: File produced by :func:`save_model`.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` lets a file
            saved from GPU tensors be loaded on a CPU-only machine.

    Returns:
        The same *model* object, for convenience.
    """
    # NOTE: torch.load unpickles the file -- only load files you trust.
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model


def save_checkpoint(path, model, optimizer, lr_scheduler, epoch):
    """Save a full training checkpoint so training can be resumed later.

    Stores the model, optimizer and LR-scheduler ``state_dict``s plus the
    current epoch in a single dict under the keys ``'model'``,
    ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'``.
    """
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, path)


def load_checkpoint(path, model, optimizer, lr_scheduler, map_location=None):
    """Restore a checkpoint written by :func:`save_checkpoint`.

    *model*, *optimizer* and *lr_scheduler* must already be constructed
    (with the same structure as when saved); their states are restored in
    place.

    Args:
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``).

    Returns:
        The epoch value stored in the checkpoint.
    """
    # NOTE: torch.load unpickles the file -- only load files you trust.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    return checkpoint['epoch']


if __name__ == "__main__":
    import os

    # Placeholder objects so the example runs end to end; replace them with
    # your own model, optimizer, scheduler and current epoch.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    epoch = 0

    directory = "Directory"
    # torch.save does not create missing parent folders.
    os.makedirs(directory, exist_ok=True)

    # Save / load the model weights only.
    save_name = "model_final"
    save_model(model, f"{directory}/{save_name}.pth")

    load_name = "model_final"
    load_model(model, f"{directory}/{load_name}.pth")  # directory, saved model name

    # Save / load the full training state. This uses the same file name as
    # above, so it overwrites the weights-only file.
    save_checkpoint(f"{directory}/{save_name}.pth", model, optimizer, lr_scheduler, epoch)
    epoch = load_checkpoint(f"{directory}/{load_name}.pth", model, optimizer, lr_scheduler)
Comments