BUG | ValueError: Shape mismatch: The shape of labels (received (320,)) should equal the shape of logits except for the last dimension (received (64, 5)).

1 TensorFlow报错

报错信息:

image

2 报错原因

字面原因:

这个问题是由于输出层的类别数和训练数据shape不同导致。

底层原因:

Step1 : 代码中,我通过ImageDataGenerator函数获取的图像生成器,会自动将图像label转为one-hot编码格式

train_image_generator = ImageDataGenerator(rescale=1./255, horizontal_flip=True)
val_image_generator = ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(directory = train_dir,
                                                           batch_size = batch_size,
                                                           shuffle=True,
                                                           target_size = (im_height, im_width),
                                                           class_mode=’categorical’)
val_data_gen = val_image_generator.flow_from_directory(directory = val_dir,
                                                       batch_size = batch_size,
                                                       shuffle=False,
                                                       target_size = (im_height, im_width),
                                                       class_mode=’categorical’)

train_imgs_batch, train_labels_batch = next(train_data_gen)
print(train_labels_batch[:5])

输出:

[[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]]

Step2 : 而在构造模型的loss函数和accuracy计算方法时,分别采用了SparseCategoricalCrossentropySparseCategoricalAccuracy

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name=’train_accuracy’)

而在TensorFlow官方文档有关tf.keras.losses.CategoricalCrossentropy函数中有说明:

image

accuracy也有类似说明:

image

image

输入的label经过了one hot编码,但是loss和accuracy却调错,使用了不采用one-hot编码的SparseCategoricalCrossentropy和SparseCategoricalAccuracy。

3 解决方法

直接改成对应的loss函数CategoricalCrossentropy和CategoricalAccuracy即可。

img

posted @ 2022-03-12 11:42  就良同学  阅读(135)  评论(0编辑  收藏  举报