Visualizing the filters of the first CNN layer
조마조마
2021. 1. 6. 09:48
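First, get the weights of the first layer (conv1 in ResNet-18). Its kernel is the first tensor yielded by parameters(), so the loop stops after one iteration: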
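```python
# model_resnet18 is the ResNet-18 model assumed to be defined earlier (not shown here)
for w in model_resnet18.parameters():
    w = w.detach().cpu()  # first parameter = conv1 kernel, detached from autograd
    print(w.shape)        # torch.Size([64, 3, 7, 7])
    break
```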
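Then, normalize the weights to [0, 1] so they can be displayed as an image. Dividing by twice the largest absolute weight and adding 0.5 maps zero to 0.5 (mid-gray), negative weights to darker values and positive weights to brighter ones, without any value leaving the range:
```python
import torch
# symmetric scaling around zero: every value lands in [0, 1]
max_abs = torch.max(torch.abs(w))
w1 = w / (2 * max_abs) + 0.5
print(torch.min(w1).item(), torch.max(w1).item())  # both within [0, 1]
```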
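A quick pure-Python check with made-up numbers shows how zero-centred scaling behaves. Scaling by the minimum (-w / (2 * min_w) + 0.5, assuming min_w < 0) sends the most negative weight to 0 and zero to 0.5, but a positive weight larger than |min_w| ends up above 1; scaling by the largest absolute value keeps every result in [0, 1]. The function names and sample values are illustrative only.

```python
def scale_by_min(ws):
    # w -> -w / (2 * min_w) + 0.5 ; assumes min_w < 0
    min_w = min(ws)
    return [-w / (2 * min_w) + 0.5 for w in ws]

def scale_by_max_abs(ws):
    # w -> w / (2 * max|w|) + 0.5 ; always within [0, 1]
    m = max(abs(w) for w in ws)
    return [w / (2 * m) + 0.5 for w in ws]

# made-up weights whose largest positive value exceeds |min|
ws = [-0.4, 0.0, 0.2, 0.8]
print(scale_by_min(ws))      # [0.0, 0.5, 0.75, 1.5]
print(scale_by_max_abs(ws))  # [0.25, 0.5, 0.625, 1.0]
```

Both maps send 0 to 0.5, so the sign of each weight stays readable as darker or brighter than mid-gray.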
Finally, arrange the 64 filters in an 8×8 grid and display it:
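```python
import matplotlib.pyplot as plt
from torchvision import utils
# make_grid accepts the (64, 3, 7, 7) tensor directly: 8 filters per row, 1-pixel padding
x_grid = utils.make_grid(w1, nrow=8, padding=1)
print(x_grid.shape)  # torch.Size([3, 65, 65])
plt.figure(figsize=(10, 10))
# imshow expects H x W x C, so permute the C x H x W grid before plotting
plt.imshow(x_grid.permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()
```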
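To see where the printed grid shape comes from, the layout arithmetic can be reproduced in plain Python. This is a sketch that mirrors how torchvision's make_grid sizes its canvas (each cell is the image size plus padding, with one extra padding strip on the far edges); grid_shape is a hypothetical helper, not part of torchvision.

```python
import math

def grid_shape(n, c, h, w, nrow=8, padding=2):
    """Hypothetical helper: (C, H, W) of the canvas make_grid builds
    for a batch of n images of shape (c, h, w)."""
    if c == 1:
        c = 3  # single-channel batches are repeated to 3 channels
    if n == 1:
        return (c, h, w)  # a single image is returned without padding
    xmaps = min(nrow, n)           # images per row
    ymaps = math.ceil(n / xmaps)   # number of rows
    cell_h, cell_w = h + padding, w + padding
    return (c, ymaps * cell_h + padding, xmaps * cell_w + padding)

print(grid_shape(64, 3, 7, 7, nrow=8, padding=1))  # (3, 65, 65)
```

For the 64 ResNet-18 filters: 8 rows × (7 + 1) pixels + 1 = 65 in each direction.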
※ Reference: PyTorch Computer Vision Cookbook