Code that raises the error:
fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize)
print(axs[0, 0])
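This fails because n_filters = 1 and n_in_channels = 1. plt.subplots uses squeeze=True by default, so for a 1-by-1 grid it returns a single Axes object rather than an array of Axes, and indexing that object with axs[0, 0] raises a TypeError. (If only one of the two dimensions is 1, axs is a 1-D array, so axs[0, 0] still fails.)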
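A quick way to see the squeeze behavior is to check what plt.subplots returns for a few grid sizes. This is a small sketch: describe_subplots is a hypothetical helper made up for illustration, and the non-interactive Agg backend is used only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np


def describe_subplots(nrows, ncols, squeeze=True):
    """Describe the second value returned by plt.subplots (hypothetical helper)."""
    fig, axs = plt.subplots(nrows, ncols, squeeze=squeeze)
    plt.close(fig)  # free the figure; only the return type matters here
    if isinstance(axs, np.ndarray):
        return f"ndarray with shape {axs.shape}"
    return type(axs).__name__  # a single Axes object, not an array


for shape in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    print(shape, "->", describe_subplots(*shape))
```

With the default squeeze=True, the 1-by-1 grid gives a lone Axes, the 1-by-3 and 3-by-1 grids both give a 1-D array of shape (3,), and only the 2-by-2 grid gives a 2-D array.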
To always get a 2-D array of Axes, pass squeeze=False:
fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize, squeeze=False)
print(axs[0, 0])
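For context, here is a minimal sketch of a filter-plotting loop that relies on squeeze=False. It assumes the filters are stored in a NumPy array of shape (n_filters, n_in_channels, kh, kw); that layout and the plot_filters name are assumptions for illustration, not taken from the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np


def plot_filters(weights, figsize=(4, 4)):
    """Draw one image per (filter, input channel) pair (hypothetical helper).

    Assumes `weights` has shape (n_filters, n_in_channels, kh, kw).
    """
    n_filters, n_in_channels = weights.shape[:2]
    # squeeze=False keeps axs 2-D even when either dimension is 1
    fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize, squeeze=False)
    for i in range(n_filters):
        for j in range(n_in_channels):
            axs[i, j].imshow(weights[i, j], cmap="gray")
            axs[i, j].axis("off")
    return fig, axs


# One 3x3 filter with one input channel: the case that failed with the default squeeze=True.
fig, axs = plot_filters(np.random.rand(1, 1, 3, 3))
print(axs.shape)  # (1, 1)
```

Because axs is always 2-D, the same axs[i, j] indexing works for every grid size without special cases.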
Alternatively, keep the default and convert axs to a 2-D NumPy array before indexing or slicing it:
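import numpy as np
axs = np.atleast_2d(axs)  # a single Axes becomes a 1x1 array, a 1-D array becomes a single row
axs = axs.reshape(n_filters, n_in_channels)  # restore the intended grid shape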
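To check that the atleast_2d + reshape combination gives the right layout for every squeezed case, the sketch below wraps it in a hypothetical helper, as_axes_grid, and runs it over a few grid sizes. The reshape step matters for a single column: np.atleast_2d alone turns a (3,) array into a (1, 3) row, while reshape restores (3, 1).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np


def as_axes_grid(axs, nrows, ncols):
    """Return axs as an (nrows, ncols) object array (hypothetical helper)."""
    # A lone Axes becomes a 1x1 array; a 1-D array becomes a single row...
    grid = np.atleast_2d(axs)
    # ...so reshape restores the intended layout, e.g. (3,) -> (3, 1) rather than (1, 3)
    return grid.reshape(nrows, ncols)


for nrows, ncols in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    fig, axs = plt.subplots(nrows, ncols)  # default squeeze=True
    grid = as_axes_grid(axs, nrows, ncols)
    print((nrows, ncols), "->", grid.shape)
    plt.close(fig)
```

Since reshape keeps row-major order, each Axes ends up at the same [row, column] position it would have had with squeeze=False.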