PyTorch provides two main ways of saving a trained model and loading it back:
- The less recommended way is to save the entire model object as follows:
torch.save(model, PATH_TO_MODEL)
The saved model can then be loaded later as follows:
model = torch.load(PATH_TO_MODEL, weights_only=False)  # PyTorch 2.6+ defaults to weights_only=True, which rejects full model objects
Although this approach looks the most straightforward, it can be problematic. Python's pickle module, which torch.save relies on, does not store the model class itself; it stores the object's data together with a reference to the class and the module path where that class is defined in our source code. If our class signatures or directory structure change later, loading the model can fail in ways that are hard to fix.
- The second, recommended way is to save only the model parameters (the model's state_dict) as follows:
torch.save(model.state_dict(), PATH_TO_MODEL)
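Later, when we need to restore the model, we first instantiate a new model object of the same class and then load the saved parameters into it as follows:
model = Net()  # Net must be the same model class that was used for training
model.load_state_dict(torch.load(PATH_TO_MODEL))
model.eval()  # put dropout/batch-norm layers in inference mode before evaluating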
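The fragility described for the first approach can be reproduced in a few lines. The following is a minimal sketch, not the book's code: it writes a hypothetical module named `hypothetical_models` containing a toy `Net` (a single `nn.Linear` layer, an assumption for illustration) into a temporary directory, saves the model both ways, and then deletes the module to mimic a refactor that moves the class. It assumes PyTorch 1.13 or later, where `torch.load` accepts the `weights_only` argument.

```python
import importlib
import os
import sys
import tempfile

import torch

# A throwaway directory standing in for the project's source tree.
src_dir = tempfile.mkdtemp()
module_path = os.path.join(src_dir, "hypothetical_models.py")

# Hypothetical module that defines the model class used for training.
with open(module_path, "w") as f:
    f.write(
        "import torch.nn as nn\n"
        "\n"
        "class Net(nn.Module):\n"
        "    def __init__(self):\n"
        "        super().__init__()\n"
        "        self.fc = nn.Linear(4, 2)\n"
        "\n"
        "    def forward(self, x):\n"
        "        return self.fc(x)\n"
    )

sys.path.insert(0, src_dir)
importlib.invalidate_caches()
models = importlib.import_module("hypothetical_models")

model = models.Net()
full_path = os.path.join(src_dir, "full_model.pth")
params_path = os.path.join(src_dir, "params.pth")

# Approach 1: pickles the object, which refers to its class as hypothetical_models.Net.
torch.save(model, full_path)
# Approach 2: stores only the parameter tensors, keyed by name.
torch.save(model.state_dict(), params_path)

# Simulate a refactor: keep the class object around (as if it now lived at a new
# import path), but remove the old module so "hypothetical_models" no longer exists.
RelocatedNet = models.Net
os.remove(module_path)
del sys.modules["hypothetical_models"]
importlib.invalidate_caches()

full_load_failed = False
try:
    torch.load(full_path, weights_only=False)
except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
    full_load_failed = True
    print("Loading the full model failed:", err)

# The state_dict only needs a class whose parameter names and shapes match.
restored = RelocatedNet()
restored.load_state_dict(torch.load(params_path, weights_only=True))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print("state_dict restore reproduces outputs:", same_output)
```

The full-model file fails because unpickling has to re-import `hypothetical_models.Net` from its old location, whereas the state_dict file can be loaded into the class wherever it now lives.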
We will use the recommended approach to save our model, as shown in the following code:
PATH_TO_MODEL = "./convnet.pth"
torch.save(model.state_dict(), PATH_TO_MODEL)
The convnet.pth file is essentially a pickle file containing model parameters.
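To see what the saved file holds, the sketch below saves the state_dict of a small hypothetical convolutional network (an `nn.Sequential` stand-in that assumes 1x28x28 inputs, not the model from the book) and reads it back. In PyTorch 1.6 and later, `torch.save` by default writes a zip archive that wraps the pickled data, so "pickle file" here means a pickle-based archive.

```python
import zipfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the article's ConvNet (assumes 1x28x28 inputs).
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),  # 28x28 -> 26x26 feature maps
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 26 * 26, 10),
)

PATH_TO_MODEL = "./convnet.pth"
torch.save(model.state_dict(), PATH_TO_MODEL)

# The file on disk is a zip archive containing the pickled state_dict.
print("zip archive:", zipfile.is_zipfile(PATH_TO_MODEL))

# Loading it back yields an ordered mapping from parameter names to tensors.
state_dict = torch.load(PATH_TO_MODEL, weights_only=True)
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```

Each key names a parameter by its module's position in the `Sequential` container; `ReLU` and `Flatten` have no parameters, so they contribute no entries.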
※ Reference: Mastering PyTorch