tensorflow的tile使用
当你需要按照矩阵维度复制数据时候,可以使用tensorflow的tile函数
a1 = tf.tile(a, [2, 2]) 表示把a的第一个维度复制两次,第二个维度复制2次。
注意使用tf.nn.softmax(r, axis=0),表示对每一列取softmax,一定要注意维度,axis=0表示对列取softmax,不然数据会出错
1 def tensoflow_test(): 2 # 一个batch有20个样本,每个样本的长度为5,每一个为200维度 3 lstm_outpus = tf.truncated_normal(shape=[2, 5, 4], mean=0, stddev=1) 4 # 变形成二维 5 lstm_o = tf.reshape(lstm_outpus, shape=[-1, 4]) 6 # 经过非线性 7 M = tf.tanh(lstm_o) 8 # 初始化权重信息 9 w = tf.truncated_normal(shape=[4,1], mean=0, stddev=1) 10 # 权重tf.matmul(M, w) 11 r = tf.matmul(M, w) 12 a = tf.nn.softmax(r, axis=0) 13 alpha = tf.tile(a, (1, 4)) 14 # attention_res = lstm_o * alpha 15 16 # M = tf.reshape(t, shape=[-1, 200]) 17 # o = tf.Variable(tf.truncated_normal([1, 200]), name='w', dtype=tf.float32) 18 # a = tf.Variable(tf.truncated_normal([2,3]), dtype=tf.float32) 19 # b = tf.Variable(tf.truncated_normal([2,3]), dtype=tf.float32) 20 # a_b = tf.multiply(a,b) 21 # # a_b = a * b 22 # w = tf.transpose(o) 23 # res = tf.matmul(M, w) 24 # res2 = tf.reshape(res, shape=[-1, 5]) 25 # copy_res = tf.tile(res2, (3,1)) 26 # init_op = tf.global_variables_initializer() 27 28 with tf.Session() as sess: 29 # sess.run(init_op) 30 # print(sess.run(res)) 31 # print(sess.run(res2)) 32 # print(res2) 33 # print(sess.run(copy_res)) 34 # print(copy_res) 35 # print(sess.run(lstm_o)) 36 # print(sess.run(lstm_outpus)) 37 # print(sess.run(w)) 38 print(lstm_outpus) 39 print(lstm_o) 40 print(alpha) 41 # print(sess.run(lstm_outpus)) 42 print(sess.run([a, alpha])) 43 # print(sess.run(alpha)) 44 # print(sess.run(alpha)) 45 # print(sess.run(attention_res))
时刻记着自己要成为什么样的人!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)