第五次作业 训练一个逻辑与门和逻辑或门

作业五:训练一个逻辑与门和逻辑或门

项目 内容
这个作业属于的课程 人工智能实战2019(北京航空航天大学)
这个作业的要求 训练一个逻辑与门和逻辑或门
我在这个课程的目标是 学习算法,积累项目经验,锻炼coding能力
这个作业在哪个具体方面帮助我实现目标 非线性模型的优势
作业正文 见下文
其他参考文献 微软示例代码
  1. 训练数据
Example 1 2 3 4
x 0 0 1 1
y 0 1 0 1
逻辑与门 0 0 0 1
逻辑或门 0 1 1 1
  1. 检测数据
# test AND gate
input number one:1
input number two:1
[[0.99672156]]
True
# test OR gate
input number one:1
input number two:0
[[0.99822654]]
True
  1. 代码
  • gate.py
import numpy as np  
import matplotlib.pyplot as plt  
from base import *  
  
# x1=0,x2=0,y=0  
# x1=0,x2=1,y=0  
# x1=1,x2=0,y=0  
# x1=1,x2=1,y=1  
def Read_AND_Data(gate):  
    X = np.array([0, 0, 1, 1, 0, 1, 0, 1]).reshape(2, 4)  
    if gate == 'and':  
        Y = np.array([0, 0, 0, 1]).reshape(1, 4)  
    elif gate == 'or':  
        Y = np.array([0, 1, 1, 1]).reshape(1, 4)  
    return X,Y  
  
def Test(W,B):  
    n1 = input("input number one:")  
    x1 = float(n1)  
    n2 = input("input number two:")  
    x2 = float(n2)  
    a = ForwardCalculationBatch(W, B, np.array([x1,x2]).reshape(2,1))  
    print(a)  
    y = x1 or x2  
    if np.abs(a-y) < 1e-2:  
        print("True")  
    else:  
        print("False")  
  
  
if __name__ == '__main__':  
    # SGD, MiniBatch, FullBatch  
 # read data  X,Y = Read_AND_Data('or')  
    W, B = train(X, Y, ForwardCalculationBatch, CheckLoss)  
  
    print("w=",W)  
    print("b=",B)  
    ShowResult(W,B,X,Y,"AND")  
    # test  
  while True:  
        Test(W,B)
  • base.py
import numpy as np  
import matplotlib.pyplot as plt  
  
def Sigmoid(x):  
  s = 1 / (1 + np.exp(-x))  
  return s  
 
# 前向计算  
def ForwardCalculationBatch(W, B, batch_X):  
  Z = np.dot(W, batch_X) + B  
  A = Sigmoid(Z)  
  return A  


# 反向计算  
def BackPropagationBatch(batch_X, batch_Y, A):  
  m = batch_X.shape[1]  
  dZ = A - batch_Y  
  # dZ列相加,即一行内的所有元素相加  
dB = dZ.sum(axis=1, keepdims=True) / m  
  dW = np.dot(dZ, batch_X.T) / m  
  return dW, dB  


# 更新权重参数  
def UpdateWeights(W, B, dW, dB, eta):  
  W = W - eta * dW  
  B = B - eta * dB  
  return W, B  


# 计算损失函数值  
def CheckLoss(W, B, X, Y):  
  m = X.shape[1]  
  A = ForwardCalculationBatch(W, B, X)  

  p4 = np.multiply(1 - Y, np.log(1 - A))  
  p5 = np.multiply(Y, np.log(A))  

  LOSS = np.sum(-(p4 + p5))  # binary classification  
loss = LOSS / m  
  return loss  


# 初始化权重值  
def InitialWeights(num_input, num_output, method):  
   W = np.zeros((num_output, num_input))  
   B = np.zeros((num_output, 1))  
   return W, B  


def train(X, Y, ForwardCalculationBatch, CheckLoss):  
  num_example = X.shape[1]  
  num_feature = X.shape[0]  
  num_category = Y.shape[0]  
  # hyper parameters  
  eta = 0.5  
max_epoch = 10000  
# W(num_category, num_feature), B(num_category, 1)  
W, B = InitialWeights(num_feature, num_category, "zero")  
  # calculate loss to decide the stop condition  
loss = 5 # initialize loss (larger than 0)  
error = 2e-3 # stop condition  

# if num_example=200, batch_size=10, then iteration=200/10=20  for epoch in range(max_epoch):  
      for i in range(num_example):  
          # get x and y value for one sample  
x = X[:, i].reshape(num_feature, 1)  
          y = Y[:, i].reshape(1, 1)  
          # get z from x,y  
batch_a = ForwardCalculationBatch(W, B, x)  
          # calculate gradient of w and b  
dW, dB = BackPropagationBatch(x, y, batch_a)  
          # update w,b  
W, B = UpdateWeights(W, B, dW, dB, eta)  
          # end if  
# end for # calculate loss for this batch  loss = CheckLoss(W, B, X, Y)  
      print(epoch, i, loss, W, B)  
      # end if  
if loss < error:  
          break  
# end for  

return W, B  


def ShowResult(W, B, X, Y, title):  
  w = -W[0, 0] / W[0, 1]  
  b = -B[0, 0] / W[0, 1]  
  x = np.array([0, 1])  
  y = w * x + b  
  plt.plot(x, y)  

  for i in range(X.shape[1]):  
      if Y[0, i] == 0:  
          plt.scatter(X[0, i], X[1, i], marker="o", c='b', s=64)  
      else:  
          plt.scatter(X[0, i], X[1, i], marker="^", c='r', s=64)  
  plt.axis([-0.1, 1.1, -0.1, 1.1])  
  plt.title(title)  
  plt.show()
  1. 结果
  • 与门
w= [[11.76694002 11.76546912]]
b= [[-17.81530488]]
  • 或门
 w= [[11.74573383 11.74749036]]
 b= [[-5.41268583]]
posted @ 2019-04-07 19:28  知也遇兮  阅读(256)  评论(0编辑  收藏  举报