动手学深度学习v2——第六章predict_ch6
在QA环节,有位同学问了第六章的predict函数在哪,书中没有给出,使用predict_ch3稍作更改可得。
def predict_ch6(net, test_iter, device, n=6): #@save
"""预测标签(定义见第3章)"""
for X, y in test_iter:
X, y = X.to(device), y.to(device)#放到GPU
break
trues = d2l.get_fashion_mnist_labels(y)
preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
X = X.cpu().numpy()
y = y.cpu().numpy()#把torch.tensor改为numpy并放在CPU
titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
d2l.show_images(X[0:n].reshape((n, 224, 224)), 1, n, titles=titles[0:n])
predict_ch6(net, test_iter, d2l.try_gpu())