A Neural Network with a Single Hidden Layer

Import the required libraries

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

Initialize the parameters

def init_(n_x, n_h, n_y):
    # Weight matrices, stored as (inputs, units) and transposed inside propagate().
    # Note: all-zero weights leave every hidden unit identical (see the sketch below).
    w1 = np.zeros((n_x, n_h))
    w2 = np.zeros((n_h, n_y))

    # Bias vectors, one entry per unit.
    b1 = np.zeros((n_h, 1))
    b2 = np.zeros((n_y, 1))

    return w1, w2, b1, b2
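
Initializing both weight matrices to zeros keeps every hidden unit identical: they all compute the same activation and receive the same gradient, so the hidden layer never differentiates (this is exactly why the predictions in the test run below collapse to a single constant). A minimal symmetry-breaking sketch, assuming the same shapes as init_ above; the name init_random and the 0.01 scale are illustrative, not part of the original post:

def init_random(n_x, n_h, n_y, seed=0):
    # Small random weights break the symmetry between hidden units;
    # the biases can stay at zero.
    rng = np.random.default_rng(seed)
    w1 = rng.standard_normal((n_x, n_h)) * 0.01
    w2 = rng.standard_normal((n_h, n_y)) * 0.01
    b1 = np.zeros((n_h, 1))
    b2 = np.zeros((n_y, 1))
    return w1, w2, b1, b2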

Define the propagation function

def propagate(x, y, w1, w2, b1, b2):
    m = x.shape[1]  # number of samples; x has shape (n_x, m)

    # Forward propagation with tanh activations in both layers.
    z1 = np.dot(w1.T, x) + b1
    a1 = np.tanh(z1)
    z2 = np.dot(w2.T, a1) + b2
    a2 = np.tanh(z2)

    # Cross-entropy cost. Note: tanh outputs lie in (-1, 1), so log(a2) is not
    # always defined; this is the source of the divide-by-zero warning in the test run.
    cost = (- 1 / m) * np.sum(y * np.log(a2) + (1 - y) * (np.log(1 - a2)))

    # Backward propagation. dz2 = a2 - y is the sigmoid/cross-entropy
    # simplification; with a tanh output it is only an approximation.
    dz2 = a2 - y
    dw2 = (1 / m) * np.dot(dz2, a1.T)
    db2 = (1 / m) * np.sum(dz2, axis=1, keepdims=True)
    dz1 = np.multiply(np.dot(w2, dz2), 1 - np.power(a1, 2))
    dw1 = (1 / m) * np.dot(dz1, x.T)
    db1 = (1 / m) * np.sum(dz1, axis=1, keepdims=True)

    # Transpose so the gradients match the shapes of w1 and w2.
    dw1 = dw1.T
    dw2 = dw2.T

    # cost is a scalar; ravel() turns it into a length-1 array for logging.
    cost = cost.ravel()

    grads = {'dw1': dw1, 'dw2': dw2, 'db1': db1, 'db2': db2}

    return grads, cost
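
Because propagate transposes the weights internally, it is easy to mix up shapes. A quick sanity check on a toy problem confirms that each gradient matches its parameter; this is a sketch with illustrative sizes (3 features, 4 hidden units, 1 output, 10 samples), not part of the original post:

rng = np.random.default_rng(0)
w1 = rng.standard_normal((3, 4)) * 0.01
w2 = rng.standard_normal((4, 1)) * 0.01
b1 = np.zeros((4, 1))
b2 = np.full((1, 1), 0.5)              # keeps a2 = tanh(z2) inside (0, 1) so the log terms stay finite
x_toy = rng.random((3, 10))            # shape (n_x, m)
y_toy = rng.integers(0, 2, (1, 10))    # shape (n_y, m), binary labels
grads, cost = propagate(x_toy, y_toy, w1, w2, b1, b2)
assert grads['dw1'].shape == w1.shape  # (3, 4)
assert grads['dw2'].shape == w2.shape  # (4, 1)
assert grads['db1'].shape == b1.shape  # (4, 1)
assert grads['db2'].shape == b2.shape  # (1, 1)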

Define the optimization function

def optimizer(x, y, w1, w2, b1, b2, num_iterations, learning_rate, print_cost=False):
    costs = []

    for i in range(num_iterations):
        # Compute the gradients and the current cost.
        grads, cost = propagate(x, y, w1, w2, b1, b2)
        dw1 = grads['dw1']
        dw2 = grads['dw2']
        db1 = grads['db1']
        db2 = grads['db2']

        # Gradient-descent update.
        w1 = w1 - learning_rate * dw1
        w2 = w2 - learning_rate * dw2
        b1 = b1 - learning_rate * db1
        b2 = b2 - learning_rate * db2

        # Record (and optionally print) the cost every 100 iterations.
        if i % 100 == 0:
            costs.append(cost)
            if print_cost:
                print("Iteration: %i , cost: %f" % (i, cost))

    params = {'w1': w1, 'w2': w2, 'b1': b1, 'b2': b2}

    return params, costs
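
optimizer records the cost every 100 iterations, so the learning curve is easy to inspect. A minimal plotting sketch, assuming matplotlib is available (it is not imported at the top of this post):

import matplotlib.pyplot as plt

def plot_costs(costs, learning_rate):
    # One recorded cost per 100 iterations, as collected in optimizer().
    plt.plot(np.squeeze(costs))
    plt.xlabel('iterations (per hundreds)')
    plt.ylabel('cost')
    plt.title('learning rate = ' + str(learning_rate))
    plt.show()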

The model prediction function

def predict(x, w1, w2, b1, b2):
    # Forward pass only, using the trained parameters.
    z1 = np.dot(w1.T, x) + b1
    a1 = np.tanh(z1)
    z2 = np.dot(w2.T, a1) + b2
    a2 = np.tanh(z2)
    return a2.ravel()  # flatten to a 1-D array of predictions
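
predict returns the raw tanh activation of the output layer, so the values are continuous scores in (-1, 1). If the network were used as a binary classifier, the scores would need a cut-off; a minimal sketch, assuming trained parameters w1, w2, b1, b2 and an input matrix x of shape (n_x, m), with an illustrative threshold:

scores = predict(x, w1, w2, b1, b2)   # continuous scores from the tanh output
labels = (scores > 0.5).astype(int)   # illustrative cut-off for the positive class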

Assemble the model

def model(x_train, x_test, y_train, y_test, num_iterations, learning_rate, print_cost=False):
    # Layer sizes: inputs taken from the data, 4 hidden units, 1 output unit.
    n_x = x_train.shape[0]
    n_y = 1
    n_h = 4
    w1, w2, b1, b2 = init_(n_x=n_x, n_h=n_h, n_y=n_y)

    # Train with gradient descent.
    params, costs = optimizer(x_train, y_train, w1, w2, b1, b2, num_iterations=num_iterations, learning_rate=learning_rate, print_cost=print_cost)

    w1 = params['w1']
    w2 = params['w2']
    b1 = params['b1']
    b2 = params['b2']

    # Evaluate on both splits; mean_squared_error is an error, not an accuracy.
    y_pred_train = predict(x_train, w1, w2, b1, b2)
    y_pred_test = predict(x_test, w1, w2, b1, b2)

    print("Training set MSE: {}".format(mean_squared_error(y_train.ravel(), y_pred_train)))
    print("Test set MSE: {}".format(mean_squared_error(y_test.ravel(), y_pred_test)))

    return y_pred_train, y_pred_test

Test

# Note: np.random.randint's upper bound is exclusive, so randint(0, 1) yields only zeros.
x = np.random.randint(0, 1, size=(1000, 5))
y = np.random.rand(1000)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
# Transpose to the (n_features, m_samples) layout the network expects.
# (reshape(5, -1) gives the right shape but scrambles the feature/sample pairing.)
x_train = x_train.T
x_test = x_test.T
y_train = y_train.reshape(1, -1)
y_test = y_test.reshape(1, -1)
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
(5, 800)
(5, 200)
(1, 800)
(1, 200)
y_pred_train,y_pred_test=model(x_train,x_test,y_train,y_test,num_iterations=2000,learning_rate=0.03,print_cost=True)
<ipython-input-3-c1c2a4656491>:9: RuntimeWarning: divide by zero encountered in log
  cost = (- 1 / m) * np.sum(y * np.log(a2) + (1 - y) * (np.log(1 - a2)))


Iteration: 0 , cost: inf
Iteration: 100 , cost: 0.695073
Iteration: 200 , cost: 0.692820
Iteration: 300 , cost: 0.692800
Iteration: 400 , cost: 0.692800
Iteration: 500 , cost: 0.692800
Iteration: 600 , cost: 0.692800
Iteration: 700 , cost: 0.692800
Iteration: 800 , cost: 0.692800
Iteration: 900 , cost: 0.692800
Iteration: 1000 , cost: 0.692800
Iteration: 1100 , cost: 0.692800
Iteration: 1200 , cost: 0.692800
Iteration: 1300 , cost: 0.692800
Iteration: 1400 , cost: 0.692800
Iteration: 1500 , cost: 0.692800
Iteration: 1600 , cost: 0.692800
Iteration: 1700 , cost: 0.692800
Iteration: 1800 , cost: 0.692800
Iteration: 1900 , cost: 0.692800
Training set MSE: 0.08246015676106996
Test set MSE: 0.08054889557439511
list(zip(y_test.ravel(),y_pred_test))
[(0.37107927115563777, 0.486817796537678),
 (0.8688615427200882, 0.486817796537678),
 (0.5355109993899763, 0.486817796537678),
 (0.22114449525092184, 0.486817796537678),
 (0.26770938312157333, 0.486817796537678),
 (0.729257582032334, 0.486817796537678),
 (0.3399104324910738, 0.486817796537678),
 (0.4711199761843007, 0.486817796537678),
 (0.06615983012845861, 0.486817796537678),
 (0.5447015532351007, 0.486817796537678),
 (0.6973357234497007, 0.486817796537678),
 (0.6030496566499163, 0.486817796537678),
 (0.823599574490197, 0.486817796537678),
 (0.11234537770443376, 0.486817796537678),
 (0.05722470908774224, 0.486817796537678),
 (0.8200842216707958, 0.486817796537678),
 (0.15658994420620342, 0.486817796537678),
 (0.5837259386549575, 0.486817796537678),
 (0.22226594923860277, 0.486817796537678),
 (0.1590071505357048, 0.486817796537678),
 (0.5032566392866032, 0.486817796537678),
 (0.022920835271538098, 0.486817796537678),
 (0.5817983022707389, 0.486817796537678),
 (0.24772095988927478, 0.486817796537678),
 (0.948484760737989, 0.486817796537678),
 (0.6038365220782025, 0.486817796537678),
 (0.2969715598717865, 0.486817796537678),
 (0.9083963402080432, 0.486817796537678),
 (0.5716514701181704, 0.486817796537678),
 (0.8426959561467403, 0.486817796537678),
 (0.5719467667417913, 0.486817796537678),
 (0.432785502639644, 0.486817796537678),
 (0.054162520180720986, 0.486817796537678),
 (0.8462259582215704, 0.486817796537678),
 (0.12254820452367743, 0.486817796537678),
 (0.8306126159438691, 0.486817796537678),
 (0.5303435769090119, 0.486817796537678),
 (0.3665551637533727, 0.486817796537678),
 (0.8154728543227319, 0.486817796537678),
 (0.23877360233181177, 0.486817796537678),
 (0.5987897273701399, 0.486817796537678),
 (0.937093383506485, 0.486817796537678),
 (0.343082759167032, 0.486817796537678),
 (0.8660848972610833, 0.486817796537678),
 (0.621727294931208, 0.486817796537678),
 (0.7953558777194885, 0.486817796537678),
 (0.09709280358412864, 0.486817796537678),
 (0.5593649343624151, 0.486817796537678),
 (0.33142794793072106, 0.486817796537678),
 (0.00839373657122755, 0.486817796537678),
 (0.03617147560847633, 0.486817796537678),
 (0.3680468805868833, 0.486817796537678),
 (0.9397360645836482, 0.486817796537678),
 (0.4087808550731754, 0.486817796537678),
 (0.35553328569321296, 0.486817796537678),
 (0.9559196604232578, 0.486817796537678),
 (0.8937777695186071, 0.486817796537678),
 (0.3497703412892025, 0.486817796537678),
 (0.9332290994696136, 0.486817796537678),
 (0.09683341570423865, 0.486817796537678),
 (0.1160521878423636, 0.486817796537678),
 (0.34766029690928413, 0.486817796537678),
 (0.29791012247268034, 0.486817796537678),
 (0.32215786512257505, 0.486817796537678),
 (0.8485553107775969, 0.486817796537678),
 (0.9094723558550617, 0.486817796537678),
 (0.27606115477105164, 0.486817796537678),
 (0.6316130851600054, 0.486817796537678),
 (0.5454360230711929, 0.486817796537678),
 (0.989002309786144, 0.486817796537678),
 (0.6002858926834858, 0.486817796537678),
 (0.2375103136622274, 0.486817796537678),
 (0.3743602103782735, 0.486817796537678),
 (0.15928910962328247, 0.486817796537678),
 (0.9219885632426765, 0.486817796537678),
 (0.08437546071257096, 0.486817796537678),
 (0.30670178038288465, 0.486817796537678),
 (0.8048662589077502, 0.486817796537678),
 (0.5457548497866417, 0.486817796537678),
 (0.30634658223679534, 0.486817796537678),
 (0.3790268482589678, 0.486817796537678),
 (0.4870117088575727, 0.486817796537678),
 (0.14409653622537577, 0.486817796537678),
 (0.310644945659476, 0.486817796537678),
 (0.31234232719538957, 0.486817796537678),
 (0.10547311697709494, 0.486817796537678),
 (0.8722719445052738, 0.486817796537678),
 (0.2532889319668499, 0.486817796537678),
 (0.26993929821784346, 0.486817796537678),
 (0.9980576488545955, 0.486817796537678),
 (0.01987983156128814, 0.486817796537678),
 (0.7558953099047069, 0.486817796537678),
 (0.685052431367328, 0.486817796537678),
 (0.7784070808137151, 0.486817796537678),
 (0.3518866988820638, 0.486817796537678),
 (0.12148718773717548, 0.486817796537678),
 (0.5546076694142167, 0.486817796537678),
 (0.8531629268755221, 0.486817796537678),
 (0.5045478267345075, 0.486817796537678),
 (0.7110396619474573, 0.486817796537678),
 (0.39430550246622353, 0.486817796537678),
 (0.11500693577012833, 0.486817796537678),
 (0.3286727340036747, 0.486817796537678),
 (0.4768145916020252, 0.486817796537678),
 (0.5173012338725859, 0.486817796537678),
 (0.9245222662330027, 0.486817796537678),
 (0.19300983737644895, 0.486817796537678),
 (0.9698389777147209, 0.486817796537678),
 (0.7682304091915088, 0.486817796537678),
 (0.6347136837607619, 0.486817796537678),
 (0.7878830998606767, 0.486817796537678),
 (0.7222472113736585, 0.486817796537678),
 (0.8843379925588427, 0.486817796537678),
 (0.2553769007094344, 0.486817796537678),
 (0.9410669563186016, 0.486817796537678),
 (0.8739863220898394, 0.486817796537678),
 (0.2530325387266127, 0.486817796537678),
 (0.033190536066682874, 0.486817796537678),
 (0.7632052863904151, 0.486817796537678),
 (0.9320420108405786, 0.486817796537678),
 (0.7514422243535841, 0.486817796537678),
 (0.9259538006000805, 0.486817796537678),
 (0.28283320421315006, 0.486817796537678),
 (0.8220450668071817, 0.486817796537678),
 (0.7208351481514301, 0.486817796537678),
 (0.35547171517475473, 0.486817796537678),
 (0.06263577662903907, 0.486817796537678),
 (0.4892355806070845, 0.486817796537678),
 (0.03299834177133876, 0.486817796537678),
 (0.40807109978170764, 0.486817796537678),
 (0.3295219454649767, 0.486817796537678),
 (0.37113760665654294, 0.486817796537678),
 (0.47156490478659696, 0.486817796537678),
 (0.42643675945737836, 0.486817796537678),
 (0.05485765651289365, 0.486817796537678),
 (0.42395603612987276, 0.486817796537678),
 (0.917130805303434, 0.486817796537678),
 (0.3031430642558812, 0.486817796537678),
 (0.30122689807624325, 0.486817796537678),
 (0.7994275567548964, 0.486817796537678),
 (0.4181152755610468, 0.486817796537678),
 (0.16240123491920444, 0.486817796537678),
 (0.8982835305876111, 0.486817796537678),
 (0.6875516598895547, 0.486817796537678),
 (0.533950713108464, 0.486817796537678),
 (0.5356945009516196, 0.486817796537678),
 (0.5414664083813571, 0.486817796537678),
 (0.7864817097747936, 0.486817796537678),
 (0.5882623634886606, 0.486817796537678),
 (0.777476594771057, 0.486817796537678),
 (0.32119964385782773, 0.486817796537678),
 (0.35564169999837714, 0.486817796537678),
 (0.07864004919390899, 0.486817796537678),
 (0.533136283882996, 0.486817796537678),
 (0.7895210322611858, 0.486817796537678),
 (0.8066665313247496, 0.486817796537678),
 (0.6060819175051647, 0.486817796537678),
 (0.7104770517758254, 0.486817796537678),
 (0.6325847553602532, 0.486817796537678),
 (0.8723694063368937, 0.486817796537678),
 (0.8979560742204824, 0.486817796537678),
 (0.3967862959350066, 0.486817796537678),
 (0.5981850800512435, 0.486817796537678),
 (0.4370388272555601, 0.486817796537678),
 (0.3310897196957263, 0.486817796537678),
 (0.48842174292915386, 0.486817796537678),
 (0.7249534839625978, 0.486817796537678),
 (0.8266621271109336, 0.486817796537678),
 (0.1407653413443799, 0.486817796537678),
 (0.33556718986486367, 0.486817796537678),
 (0.3815093436113274, 0.486817796537678),
 (0.7947442612614635, 0.486817796537678),
 (0.9346676672122953, 0.486817796537678),
 (0.1121682091626316, 0.486817796537678),
 (0.4382760104245176, 0.486817796537678),
 (0.06628542155083139, 0.486817796537678),
 (0.5985140090708447, 0.486817796537678),
 (0.32963880693608594, 0.486817796537678),
 (0.9019410468840745, 0.486817796537678),
 (0.11618929919767906, 0.486817796537678),
 (0.10113849620853255, 0.486817796537678),
 (0.6094595038892973, 0.486817796537678),
 (0.3234052377040215, 0.486817796537678),
 (0.05758385452393955, 0.486817796537678),
 (0.186128460883762, 0.486817796537678),
 (0.48858950348521324, 0.486817796537678),
 (0.1517620245279041, 0.486817796537678),
 (0.9517112861726339, 0.486817796537678),
 (0.5317301288423155, 0.486817796537678),
 (0.6032375408059465, 0.486817796537678),
 (0.6735230350543311, 0.486817796537678),
 (0.5062604862610698, 0.486817796537678),
 (0.1753443285148224, 0.486817796537678),
 (0.5588264751846398, 0.486817796537678),
 (0.6672728166482843, 0.486817796537678),
 (0.24322423640884572, 0.486817796537678),
 (0.038997951437352074, 0.486817796537678),
 (0.7403158320562031, 0.486817796537678),
 (0.021050147295514243, 0.486817796537678),
 (0.7964427279087964, 0.486817796537678)]
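
Two things stand out in this run. First, the cost is inf at iteration 0 and the divide-by-zero warning appears because the output activation is tanh: it starts at tanh(0) = 0 and can go negative, so log(a2) is not always defined; the usual pairing for this cross-entropy cost is a sigmoid output. Second, every prediction is the same constant (about 0.487): randint(0, 1) produces an all-zero x, and with all-zero weights the hidden layer stays symmetric, so the only parameter that ever receives a non-zero gradient is the output bias b2, which converges to roughly the mean of y_train; that is also why the cost stalls at a fixed value. A minimal corrected sketch, assuming the imports at the top of the post and the init_random sketched earlier; the sigmoid helper and the label rule below are illustrative, not part of the original post:

def sigmoid(z):
    # Sigmoid keeps the output in (0, 1), matching the cross-entropy cost.
    return 1 / (1 + np.exp(-z))

# Inputs with real structure and binary labels that depend on them.
rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(1000, 5)).astype(float)  # 0/1 features (upper bound is exclusive)
y = (x.sum(axis=1) >= 3).astype(float)                # label 1 when at least 3 features are 1

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
x_train, x_test = x_train.T, x_test.T                 # (n_features, m_samples)
y_train, y_test = y_train.reshape(1, -1), y_test.reshape(1, -1)

With data like this, replacing a2 = np.tanh(z2) by a2 = sigmoid(z2) in propagate and predict, and initializing the weights randomly as in init_random above, lets the hidden units specialize, so the predictions no longer collapse to a single value.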