简单实例：如何用 Torch7 建立神经网络
-- Minimal Torch7 example: build a small MLP with the `nn` package and train it
-- with plain SGD so that its single output moves toward a constant target.
require('nn')

-- Build a feed-forward network:
--   Reshape(nInput) -> Linear(nInput, nHidden) -> Sigmoid -> Linear(nHidden, 1)
-- nInput  (optional, default 10): number of input features.
-- nHidden (optional, default 32): width of the hidden layer.
-- Returns an nn.Sequential container.
local function createQNetwork(nInput, nHidden)
  nInput = nInput or 10
  nHidden = nHidden or 32
  local mlp = nn.Sequential()
  -- qinput produces an (n x 1) tensor; Reshape flattens it back to a vector
  -- of nInput elements so the first Linear layer receives a 1-D input.
  mlp:add(nn.Reshape(nInput))
  mlp:add(nn.Linear(nInput, nHidden))
  mlp:add(nn.Sigmoid())
  mlp:add(nn.Linear(nHidden, 1))
  return mlp
end

-- Convert a Lua array of numbers into an (n x 1) tensor.
-- obs_table: non-empty Lua array of numbers.
-- Raises an error if obs_table is empty (a 0-row view is not meaningful).
local function qinput(obs_table)
  local n = #obs_table
  assert(n > 0, "qinput: observation table must not be empty")
  -- torch.Tensor(table) copies the array elements into a new 1-D tensor.
  return torch.Tensor(obs_table):view(n, -1)
end

local qnn = createQNetwork()
-- Criterion, target and learning rate do not change between iterations,
-- so they are created once instead of on every pass through the loop.
local criterion = nn.MSECriterion()
local target = torch.Tensor(1):fill(5)
local learningRate = 0.3

for i = 1, 30 do
  -- The first 9 iterations feed all-2 observations, the rest feed all-3.
  local value = (i < 10) and 2 or 3
  local obs_table = {}
  for j = 1, 10 do
    obs_table[j] = value
  end

  local obs = qinput(obs_table)
  print(obs)

  local q = qnn:forward(obs)
  print("q", q)

  local qloss = criterion:forward(q, target)
  print("loss", qloss)

  -- Standard nn training step: clear accumulated gradients, backpropagate
  -- the criterion's gradient through the network, then take an SGD step.
  local dloss_dpredict = criterion:backward(q, target)
  qnn:zeroGradParameters()
  qnn:backward(obs, dloss_dpredict)
  qnn:updateParameters(learningRate)
end
posted on 2017-10-30 11:46 WegZumHimmel 阅读(149) 评论(0) 编辑 收藏 举报