# Track the best-performing model (lowest validation loss) during training.

# A deep copy of the weights for the best-performing model so far. A deep copy
# is needed because state_dict() hands back references to the live parameter
# tensors, which keep changing as training continues.
best_model_wts = copy.deepcopy(model.state_dict())

# Initialize best loss to +inf so the first validation loss always improves on it.
best_loss = float('inf')

# main loop
# NOTE(review): the loop header and the per-epoch training/validation code were
# elided in the original snippet; `num_epochs` is assumed to be defined
# elsewhere -- confirm against the surrounding training code.
for epoch in range(num_epochs):
    # ... training and validation steps go here (elided); they must set
    # `val_loss` for the current epoch ...

    # store best model
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        # Persist the snapshot just taken, so the weights file on disk always
        # matches best_model_wts.
        torch.save(best_model_wts, path2weights)
        print("Copied best model weights!")

 

+ Recent posts