TensorFlow 2.0 Custom Layers
TensorFlow 2.0 recommends using tf.keras as the high-level API for building neural networks. Moreover, most TensorFlow APIs can be used under eager execution.
```python
from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf
print(tf.__version__)
```
## 1. Common operations on layers
A machine learning model can usually be expressed as a stack and composition of simple layers, and TensorFlow provides implementations of the common ones, which makes writing neural-network programs much easier.
TensorFlow 2 recommends building layers with tf.keras, which comes from the original Keras; networks built with it are more readable and easier to use.
For example, to construct a simple fully connected layer, you only need to specify its number of neurons:
```python
layer = tf.keras.layers.Dense(100)
# The input dimension can also be specified
layer = tf.keras.layers.Dense(100, input_shape=(None, 20))
```
The complete list of predefined layers is in the [documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers); it includes Dense, Conv2D, LSTM, BatchNormalization, Dropout, and many more.
Each layer can be used like a function, taking the input data as its argument:
```python
layer(tf.ones([6, 6]))
```
You can also access the layer's variables, such as its weight matrix and bias:

```python
print(layer.variables)  # contains both the weights and the biases
```
```
[<tf.Variable 'dense_1/kernel:0' shape=(6, 100) dtype=float32, numpy=
array([[-0.18237606,  0.16415142,  0.20687856,  0.23396944,  0.09779547,
        -0.14794639, -0.10231382, -0.22263053, -0.0950674 ,  0.18697281,
        ...
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>]
```

(output truncated)
```python
print(layer.kernel, layer.bias)  # the weights and bias can also be accessed separately
```
```
<tf.Variable 'dense_1/kernel:0' shape=(6, 100) dtype=float32, numpy=
array([[-0.18237606,  0.16415142,  0.20687856,  0.23396944,  0.09779547,
        -0.14794639, -0.10231382, -0.22263053, -0.0950674 ,  0.18697281,
        ...
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>
```

(output truncated)
## 2. Implementing a custom layer
The best way to implement your own layer is to subclass tf.keras.layers.Layer and implement:
(1) __init__(), where you can do all input-independent initialization;
(2) build(), which receives the shape of the input tensor and performs the rest of the initialization;
(3) call(), which defines the forward computation.
You do not actually have to wait until build() is called to create your variables; they can also be created in __init__(). However, creating them in build() lets the layer build itself late, based on the shape of the inputs it will actually operate on, whereas creating variables in __init__() means the required shapes must be specified explicitly up front.
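To make this trade-off concrete, here is a minimal sketch (EagerDense is a hypothetical name) of the __init__-time alternative, where the input dimension has to be hard-coded:

```python
import tensorflow as tf

# Hypothetical counterpart to a build()-based layer: the kernel is created
# in __init__, so the input dimension must be specified up front.
class EagerDense(tf.keras.layers.Layer):
    def __init__(self, n_inputs, n_outputs):
        super(EagerDense, self).__init__()
        # The input shape is not available yet, so n_inputs is explicit.
        self.kernel = self.add_weight('kernel',
                                      shape=[n_inputs, n_outputs])

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

layer = EagerDense(5, 10)
print(layer(tf.ones([6, 5])).shape)  # (6, 10)
```

Passing an input with a last dimension other than 5 would fail in tf.matmul, which is exactly the rigidity that build() avoids.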
```python
class MyDense(tf.keras.layers.Layer):
    def __init__(self, n_outputs):
        super(MyDense, self).__init__()
        self.n_outputs = n_outputs

    def build(self, input_shape):
        self.kernel = self.add_weight('kernel',
                                      shape=[int(input_shape[-1]),
                                             self.n_outputs])

    def call(self, input):
        return tf.matmul(input, self.kernel)

layer = MyDense(10)
print(layer(tf.ones([6, 5])))
print(layer.trainable_variables)
```
```
tf.Tensor(
[[ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]
 [ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]
 [ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]
 [ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]
 [ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]
 [ 1.0200843  -0.42590106 -0.92992705  0.46160045  0.7518406   0.32543844
   0.34020287  0.08215448  0.22044104 -0.5337319 ]], shape=(6, 10), dtype=float32)
[<tf.Variable 'my_dense/kernel:0' shape=(5, 10) dtype=float32, numpy=
array([[ 0.54810244,  0.042225  ,  0.25634396,  0.1677258 , -0.0361526 ,
         0.32831818,  0.17709464,  0.46625894,  0.29662275, -0.32920587],
       [ 0.30925363, -0.426274  , -0.49862564,  0.3068235 ,  0.29526353,
         0.50076336,  0.17321467,  0.21151704, -0.26317668, -0.2006711 ],
       [ 0.10354012, -0.3258371 , -0.12274069, -0.33250242,  0.46343058,
        -0.45535576,  0.5332853 , -0.37351888, -0.00410944,  0.16418225],
       [-0.4515978 ,  0.04706419, -0.42583126, -0.19347438,  0.54246336,
         0.57910997,  0.01877069,  0.01255274, -0.14176458, -0.6309683 ],
       [ 0.5107859 ,  0.23692083, -0.13907343,  0.51302797, -0.5131643 ,
        -0.6273973 , -0.56216246, -0.23465535,  0.332869  ,  0.4629311 ]],
      dtype=float32)>]
```
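A practical layer usually also needs a bias and an activation. The following sketch (MyDenseV2 is a hypothetical name) shows one way the MyDense layer above could be extended, still creating all weights lazily in build():

```python
import tensorflow as tf

# Hypothetical extension of MyDense: adds a trainable bias and an optional
# activation; weights are still created lazily in build().
class MyDenseV2(tf.keras.layers.Layer):
    def __init__(self, n_outputs, activation=None):
        super(MyDenseV2, self).__init__()
        self.n_outputs = n_outputs
        # activations.get(None) resolves to the identity (linear) function
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            'kernel', shape=[int(input_shape[-1]), self.n_outputs])
        self.bias = self.add_weight(
            'bias', shape=[self.n_outputs], initializer='zeros')

    def call(self, inputs):
        return self.activation(tf.matmul(inputs, self.kernel) + self.bias)

layer = MyDenseV2(10, activation='relu')
out = layer(tf.ones([6, 5]))
print(out.shape)                       # (6, 10)
print(len(layer.trainable_variables))  # 2: kernel and bias
```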
## 3. Composing layers
Many machine learning models are built by stacking and combining different layers: for example, each residual block in ResNet is a combination of convolutions, batch normalization, and a shortcut connection.
To create a structure that contains multiple layers in TensorFlow 2, you generally subclass tf.keras.Model.
```python
# Residual block
class ResnetBlock(tf.keras.Model):
    def __init__(self, kernel_size, filters):
        super(ResnetBlock, self).__init__(name='resnet_block')
        # number of filters for each sub-layer
        filter1, filter2, filter3 = filters
        # three sub-layers, each a convolution followed by batch normalization
        # first sub-layer: 1x1 convolution
        self.conv1 = tf.keras.layers.Conv2D(filter1, (1, 1))
        self.bn1 = tf.keras.layers.BatchNormalization()
        # second sub-layer: uses the given kernel_size
        self.conv2 = tf.keras.layers.Conv2D(filter2, kernel_size, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()
        # third sub-layer: 1x1 convolution
        self.conv3 = tf.keras.layers.Conv2D(filter3, (1, 1))
        self.bn3 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=False):
        # stack the sub-layers
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        # residual (shortcut) connection
        x += inputs
        outputs = tf.nn.relu(x)
        return outputs

resnetBlock = ResnetBlock(2, [6, 4, 9])
# run some dummy data through the block
print(resnetBlock(tf.ones([1, 3, 9, 9])))
# list the names of the trainable variables
print([x.name for x in resnetBlock.trainable_variables])
```
```
tf.Tensor(
[[[[0.79764616 1.0550306  0.9386751  1.1079601  0.9402881  0.99479383
    0.9072118  0.5618475  0.9134829 ]
   ...
   [0.87889266 0.9541194  0.8929231  0.96703756 1.0905087  1.0646607
    0.9235744  0.9829142  1.1302696 ]]]], shape=(1, 3, 9, 9), dtype=float32)
['resnet_block/conv2d_12/kernel:0', 'resnet_block/conv2d_12/bias:0', 'resnet_block/batch_normalization_v2_12/gamma:0', 'resnet_block/batch_normalization_v2_12/beta:0', 'resnet_block/conv2d_13/kernel:0', 'resnet_block/conv2d_13/bias:0', 'resnet_block/batch_normalization_v2_13/gamma:0', 'resnet_block/batch_normalization_v2_13/beta:0', 'resnet_block/conv2d_14/kernel:0', 'resnet_block/conv2d_14/bias:0', 'resnet_block/batch_normalization_v2_14/gamma:0', 'resnet_block/batch_normalization_v2_14/beta:0']
```

(output truncated)
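One detail worth noting about the residual connection: `x += inputs` only works when the last convolution's filter count matches the input's channel count, which is why filters[-1] is 9 for the 9-channel input above. A minimal sketch of this constraint, independent of the ResnetBlock class:

```python
import tensorflow as tf

# The convolutional branch's output must have the same channel count as the
# block input, otherwise the element-wise add in the shortcut fails.
inputs = tf.ones([1, 3, 9, 9])  # channels-last: 9 input channels

ok_branch = tf.keras.layers.Conv2D(9, 1)   # filters == input channels
x = ok_branch(inputs) + inputs             # shapes match: works
print(x.shape)                             # (1, 3, 9, 9)

bad_branch = tf.keras.layers.Conv2D(4, 1)  # filters != input channels
mismatched = False
try:
    _ = bad_branch(inputs) + inputs        # (1,3,9,4) + (1,3,9,9): incompatible
except (tf.errors.InvalidArgumentError, ValueError):
    mismatched = True
print('channel mismatch detected:', mismatched)
```

Real ResNet implementations handle the mismatched case with a 1x1 projection convolution on the shortcut path; this block simply requires matching shapes.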
If the model is a plain linear stack of layers, it can be built directly with tf.keras.Sequential:
```python
seq_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(1, 1, input_shape=(None, None, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(2, 1, padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(3, 1),
    tf.keras.layers.BatchNormalization(),
])
seq_model(tf.ones([1, 2, 3, 3]))
```
```
<tf.Tensor: id=1354, shape=(1, 2, 3, 3), dtype=float32, numpy=
array([[[[-0.36850607, -0.60731524,  1.2792252 ],
         [-0.36850607, -0.60731524,  1.2792252 ],
         [-0.36850607, -0.60731524,  1.2792252 ]],

        [[-0.36850607, -0.60731524,  1.2792252 ],
         [-0.36850607, -0.60731524,  1.2792252 ],
         [-0.36850607, -0.60731524,  1.2792252 ]]]], dtype=float32)>
```
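As a quick sanity check (a sketch; the model is rebuilt here so the snippet runs on its own), the Sequential model exposes the same variable collections as a subclassed tf.keras.Model: each Conv2D contributes a kernel and a bias, and each BatchNormalization contributes trainable gamma/beta plus two non-trainable moving statistics:

```python
import tensorflow as tf

# Same stack as above, rebuilt so this snippet is self-contained.
seq_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(1, 1, input_shape=(None, None, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(2, 1, padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(3, 1),
    tf.keras.layers.BatchNormalization(),
])

# 3 Conv2D layers x (kernel + bias) + 3 BatchNorm layers x (gamma + beta)
print(len(seq_model.trainable_variables))      # 12
# 3 BatchNorm layers x (moving_mean + moving_variance)
print(len(seq_model.non_trainable_variables))  # 6
```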