# 第三讲 神经网络八股——自定义 class Model 分类 iris (Lecture 3: classify iris with a custom Keras Model subclass)
"""Classify the iris dataset with a custom subclassed Keras ``Model``.

Trains a single softmax ``Dense`` layer (with L2 kernel regularization)
using SGD, validates on a held-out 20% split every 20 epochs, then prints
the model summary.
"""
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense

SEED = 116

# Load the dataset once and reuse it for both features and labels.
iris = datasets.load_iris()
x_train = iris.data
y_train = iris.target

# Shuffle features and labels with ONE shared permutation so every row keeps
# its label. Keras' ``validation_split`` takes the *last* fraction of the
# arrays as validation data (before any per-epoch shuffling), so shuffling up
# front is what makes that split a random sample rather than a fixed tail.
np.random.seed(SEED)
perm = np.random.permutation(len(x_train))
x_train = x_train[perm]
y_train = y_train[perm]

# Seed TensorFlow's global RNG (e.g. weight initialisation) for reproducible runs.
tf.random.set_seed(SEED)


class IrisModel(Model):
    """Single-layer classifier mapping input features to 3 class probabilities."""

    def __init__(self):
        super().__init__()
        # 3 softmax output units (one per class); L2 penalty on the kernel
        # with the regularizer's default factor to discourage large weights.
        self.d1 = Dense(3, activation='softmax',
                        kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        """Return the softmax class probabilities for the input batch ``x``."""
        return self.d1(x)


model = IrisModel()

# Use ``learning_rate``: the ``lr`` alias is deprecated and rejected by newer
# Keras optimizers. The model already ends in softmax, so the loss receives
# probabilities (from_logits=False). Labels are integer class ids, hence the
# "sparse" loss and metric variants.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Hold out the last 20% of the (pre-shuffled) data; evaluate it every 20 epochs.
model.fit(x_train, y_train, batch_size=32, epochs=500,
          validation_split=0.2, validation_freq=20)

# A subclassed model learns its layer shapes only once it has been called,
# so summary() has to run after fit() has built it.
model.summary()