머신러닝/Pytorch
store best weights
조마조마
2021. 1. 4. 14:34
# best loss seen so far; starts at +infinity
best_loss = float("inf")
# snapshot (deep copy) of the model's current weights, kept as the best so far
best_model_wts = copy.deepcopy(model.state_dict())
# main loop
# ... (loop body omitted in the original post) ...
# store best model: runs only when val_loss is lower than the best loss so far
if val_loss < best_loss:
    # remember the new best loss (original had the typo `val loss`)
    best_loss = val_loss
    # deep copy so the snapshot does not change if the model's weights
    # are updated afterwards
    best_model_wts = copy.deepcopy(model.state_dict())
    # store weights into a local file
    torch.save(model.state_dict(), path2weights)
    print("Copied best model weights!")