TensorFlow 学习笔记 10

神经网络模型架构（1）

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import input_data

# Layer sizes of the fully connected network.
n_input = 784      # features per sample (presumably 28x28 MNIST pixels, flattened)
n_hidden_1 = 256   # units in the first hidden layer
n_hidden_2 = 128   # units in the second hidden layer
n_classes = 10     # width of each label vector / number of output units

# Load the dataset from 'data/'. one_hot=True requests labels encoded as
# 0/1 indicator vectors, consistent with `y` below having n_classes columns.
mnist = input_data.read_data_sets('data/', one_hot=True)

# Graph inputs: a batch of feature rows and the matching label rows.
# The leading None leaves the batch dimension unspecified.
x = tf.placeholder("float", shape=[None, n_input])
y = tf.placeholder("float", shape=[None, n_classes])

# Trainable parameters of a two-hidden-layer MLP:
#   n_input -> n_hidden_1 -> n_hidden_2 -> n_classes
# Weight matrices are drawn from a normal distribution with std-dev `stddev`.
stddev = 0.1
weights = {
    'w1': tf.Variable(tf.random.normal([n_input, n_hidden_1], stddev=stddev)),
    'w2': tf.Variable(tf.random.normal([n_hidden_1, n_hidden_2], stddev=stddev)),
    'out': tf.Variable(tf.random.normal([n_hidden_2, n_classes], stddev=stddev)),
}
# One bias vector per layer, keyed to match `weights` ('b1' pairs with 'w1',
# 'b2' with 'w2', 'out' with 'out'). The first key was previously misspelled
# 'bi', which would make a forward pass using biases['b1'] raise KeyError.
# NOTE: biases use tf.random.normal's default stddev (1.0), not `stddev`.
biases = {
    'b1': tf.Variable(tf.random.normal([n_hidden_1])),
    'b2': tf.Variable(tf.random.normal([n_hidden_2])),
    'out': tf.Variable(tf.random.normal([n_classes])),
}
print("NETWORK READY")

 

posted @ 2021-02-06 21:52  藻类植物  阅读(38)  评论(0)  编辑  收藏  举报