动手学深度学习v2——第六章predict_ch6

在QA环节,有位同学问了第六章的predict函数在哪,书中没有给出,使用predict_ch3稍作更改可得。

def predict_ch6(net, test_iter, device, n=6):  #@save
    """预测标签(定义见第3章)"""

    for X, y in test_iter:
        X, y = X.to(device), y.to(device)#放到GPU
        break

    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    X = X.cpu().numpy()
    y = y.cpu().numpy()#把torch.tensor改为numpy并放在CPU

    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(X[0:n].reshape((n, 224, 224)), 1, n, titles=titles[0:n])

predict_ch6(net, test_iter, d2l.try_gpu())
posted @ 2023-03-16 20:42  胡不归来  阅读(114)  评论(0编辑  收藏  举报