
Visualizing the filters of the first CNN layer

조마조마 2021. 1. 6. 09:48

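First, get the weights of the first layer. For a ResNet-18, the first tensor yielded by parameters() is the weight of the first convolution:

for w in model_resnet18.parameters():
    w = w.data.cpu()  # drop autograd tracking and move to the CPU
    print(w.shape)
    break  # stop after the first parameter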

 

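Then normalize the weights. The scaling below maps the minimum weight to 0 and a weight of zero to 0.5; it assumes the minimum weight is negative: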

min_w = torch.min(w)
w1 = (-1 / (2 * min_w)) * w + 0.5
print(torch.min(w1).item(), torch.max(w1).item())
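To see what this scaling does, here is a small worked example on a hand-made tensor. The function names are hypothetical, and `min_max_scale` is an alternative that is not from the post: it is included only because the post's formula reaches exactly 1 at the top only when the largest weight equals the negated minimum.

```python
import torch

def scale_like_post(w):
    # Formula from the post: min(w) -> 0 and 0 -> 0.5 (assumes min(w) < 0)
    min_w = torch.min(w)
    return (-1 / (2 * min_w)) * w + 0.5

def min_max_scale(w):
    # Alternative (not from the post): min(w) -> 0 and max(w) -> 1
    w_min, w_max = torch.min(w), torch.max(w)
    return (w - w_min) / (w_max - w_min)

w = torch.tensor([-0.4, 0.0, 0.2, 0.8])
print(scale_like_post(w))  # roughly [0.0, 0.5, 0.75, 1.5]: the maximum exceeds 1
print(min_max_scale(w))    # roughly [0.0, 0.333, 0.5, 1.0]
```

The post's version keeps zero-valued weights at mid-gray, which makes positive and negative weights easy to tell apart. The min-max version guarantees the full [0, 1] range but moves the zero point.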

 

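Next, make a grid and display it:

from torchvision import utils
import matplotlib.pyplot as plt

grid_size = len(w1)  # number of filters
x_grid = [w1[i] for i in range(grid_size)]
x_grid = utils.make_grid(x_grid, nrow=8, padding=1)  # 8 filters per row
print(x_grid.shape)

plt.figure(figsize=(10, 10))
show(x_grid)  # show() is a plotting helper that this post does not define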

 

※ Reference: PyTorch Computer Vision Cookbook