
Playing FizzBuzz with Machine Learning, PyTorch Edition

FizzBuzz is a simple game. The rules: count upward from 1; when you hit a multiple of 3, say "fizz"; a multiple of 5, say "buzz"; a multiple of 15, say "fizzbuzz"; otherwise, just say the number.

We can write a small program that decides whether to return the plain number or fizz, buzz, or fizzbuzz:

def fizz_buzz_encode(i):
    # class labels: 0 = plain number, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

def helper(i):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
    
    
for i in range(1,15):
    helper(i)
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
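One thing worth noticing before training: the four classes are imbalanced. A quick count over 1 to 100 (repeating fizz_buzz_encode so this snippet runs on its own):

```python
from collections import Counter

def fizz_buzz_encode(i):
    # class labels: 0 = plain number, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

counts = Counter(fizz_buzz_encode(i) for i in range(1, 101))
print(counts)  # 53 plain numbers, 27 fizz, 14 buzz, 6 fizzbuzz
```

A model that always answers with the plain number would already score 53% on 1 to 100, so accuracy well above that is needed before the model can be said to have learned anything.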

First we define the model's inputs and outputs (the training data):

import numpy as np
import torch

NUM_DIGITS = 10

def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

# train on 101..1023; the numbers 1..100 are held out as the test set
trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])
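Each input number is represented by its 10-bit binary expansion, most-significant bit first, giving the network 10 binary features to work with. A standalone sketch of what binary_encode produces:

```python
import numpy as np

def binary_encode(i, num_digits):
    # bit d of i, least-significant bit first, then reversed so the
    # most-significant bit comes first
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

print(binary_encode(3, 10))    # 3   = 0b0000000011
print(binary_encode(901, 10))  # 901 = 0b1110000101
```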

Then we define the model with PyTorch:

NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
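The Sequential above is a two-layer perceptron: 10 input bits, 100 hidden ReLU units, 4 output logits (one per class). A NumPy sketch of the same forward pass, with hypothetical random weights standing in for the parameters torch would learn, makes the shapes explicit:

```python
import numpy as np

NUM_DIGITS, NUM_HIDDEN = 10, 100
rng = np.random.default_rng(0)

# hypothetical random parameters; in the real model these are learned
W1 = rng.standard_normal((NUM_DIGITS, NUM_HIDDEN))
b1 = np.zeros(NUM_HIDDEN)
W2 = rng.standard_normal((NUM_HIDDEN, 4))
b2 = np.zeros(4)

def forward(x):
    # Linear -> ReLU -> Linear, mirroring the Sequential above
    hidden = np.maximum(x @ W1 + b1, 0)
    return hidden @ W2 + b2

batch = rng.standard_normal((5, NUM_DIGITS))
print(forward(batch).shape)  # (5, 4): one logit per class for each input
```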
  • To teach our model the FizzBuzz game, we need to define a loss function and an optimization algorithm.
  • The optimizer repeatedly minimizes that loss function, so that the model achieves as low a loss as possible on the task.
  • A low loss usually means the model is doing well; a high loss means it is doing poorly.
  • Since FizzBuzz is essentially a classification problem, we use the cross-entropy loss.
  • For the optimizer we use Adam, an adaptive variant of stochastic gradient descent.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
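Cross-entropy turns the 4 raw logits into a probability distribution via softmax and then penalizes low probability on the true class. A minimal NumPy sketch of what torch.nn.CrossEntropyLoss computes for a single example (the torch version additionally averages over the batch):

```python
import numpy as np

def cross_entropy(logits, target):
    # softmax with the usual max-subtraction for numerical stability
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    return -np.log(probs[target])

confident = np.array([5.0, 0.0, 0.0, 0.0])
print(cross_entropy(confident, 0))  # small loss: confident and right
print(cross_entropy(confident, 1))  # large loss: confident and wrong
```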

Here is the training code:

BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)          # forward pass
        loss = loss_fn(y_pred, batchY)

        print("Epoch", epoch, loss.item())

        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # backward pass
        optimizer.step()                # gradient descent step
        
Epoch 0 0.026707472279667854
Epoch 0 0.7354339957237244
Epoch 0 0.26077598333358765
Epoch 0 0.5795876383781433
Epoch 0 0.17743852734565735
Epoch 0 0.19959229230880737
Epoch 0 0.3708484172821045
Epoch 0 0.1135769709944725
Epoch 1 0.06608207523822784
Epoch 1 0.8781152367591858
Epoch 1 0.4952445328235626
...         # many intermediate epochs omitted
Epoch 9998 4.164749498158926e-06
Epoch 9998 4.6815180212433916e-06
Epoch 9998 1.7969510963666835e-06
Epoch 9999 3.5864200071955565e-06
Epoch 9999 6.42869508737931e-06
Epoch 9999 1.4137273183223442e-06
Epoch 9999 3.3005555906129302e-06
Epoch 9999 1.8169919258070877e-06
Epoch 9999 4.157300281804055e-06
Epoch 9999 4.6666186790389474e-06
Epoch 9999 1.7925359543369268e-06
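The inner loop walks through the training set in slices of BATCH_SIZE. With 2**10 − 101 = 923 training examples, that is seven full batches of 128 plus one final batch of 27; the slicing pattern itself needs nothing from torch (a standalone sketch):

```python
BATCH_SIZE = 128
n_examples = 2 ** 10 - 101  # 923, the size of trX above

# size of each mini-batch produced by the inner loop
sizes = [min(start + BATCH_SIZE, n_examples) - start
         for start in range(0, n_examples, BATCH_SIZE)]
print(sizes)  # seven batches of 128, then a final partial batch of 27
```

Python slicing (and torch tensor slicing) clamps an end index past the sequence length, so the final partial batch needs no special handling in the training loop.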

Finally, we let the trained model play FizzBuzz on the numbers 1 to 100:

testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])

with torch.no_grad():
    testY = model(testX)

# testY.max(1)[1] is the argmax over the 4 class logits for each number
predictions = zip(range(1, 101), testY.max(1)[1].cpu().data.tolist())
print([fizz_buzz_decode(i, x) for i, x in predictions])
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', '33', 'buzz', 'buzz', 'fizz', '37', 'buzz', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', 'buzz', 'fizz', 'buzz']

The model is close but not perfect: it gets, for example, 33 wrong (it answers "33" instead of "fizz") and 34 wrong (it answers "buzz" instead of "34").

From

https://www.bilibili.com/video/BV12741177Cu?p=2

posted @ 2020-10-24 15:33  coderge