第一章: PyTorch计算机视觉实战:目标检测、图像处理与深度学习

 

第一章

复制代码
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch


###################  Chapter One #######################################
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
x = np.array([[1,1]])
y = np.array([[0]])

from copy import deepcopy
import numpy as np
def feed_forward(inputs, outputs, weights):
    pre_hidden = np.dot(inputs,weights[0])+ weights[1]
    hidden = 1/(1+np.exp(-pre_hidden))
    out = np.dot(hidden, weights[2]) + weights[3]
    mean_squared_error = np.mean(np.square(out - outputs))
    return mean_squared_error

def update_weights(inputs, outputs, weights, lr):
    original_weights = deepcopy(weights)
    temp_weights = deepcopy(weights)
    updated_weights = deepcopy(weights)
    original_loss = feed_forward(inputs, outputs, original_weights)
    for i, layer in enumerate(original_weights):
        print(i,layer)
        for index, weight in np.ndenumerate(layer):
            print("**",index, weight)
            temp_weights = deepcopy(weights)
            temp_weights[i][index] += 0.0001  # 每次+0.0001方式,改变一个参数,计算新的输出值,与原输出值比较,得到梯度
            print("**@",temp_weights[i][index])
            _loss_plus = feed_forward(inputs, outputs, temp_weights)
            grad = (_loss_plus - original_loss)/(0.0001)
            updated_weights[i][index] -= grad*lr
    return updated_weights, original_loss

W = [
    np.array([[-0.0053, 0.3793],
              [-0.5820, -0.5204],
              [-0.2723, 0.1896]], dtype=np.float32).T,
    np.array([-0.0140, 0.5607, -0.0628], dtype=np.float32),
    np.array([[ 0.1528, -0.1745, -0.1135]], dtype=np.float32).T,
    np.array([-0.5516], dtype=np.float32)
]



losses = []
for epoch in range(100):
    W, loss = update_weights(x,y,W,0.01)
    losses.append(loss)
    print(epoch,loss)
plt.plot(losses)
plt.title('Loss over increasing number of epochs')
plt.show()
########################################################################
复制代码

 



posted @   辛河  阅读(145)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示