26tensorflow基本使用

tensorflow基本使用

想来想去,在实战前,还是先看一下tensorflow的基本使用。



理解TensorFlow

  • 使用图(graph)来表示计算任务;
  • 在被称之为会话(Session)的上下文(context)中执行图;
  • 使用tensor(张量)表示数据;
  • 通过变量(Variable)维护状态;
  • 使用feed和fetch可以为任意的操作(arbitrary operation)赋值或者从其中获取数据。
  • TensorFlow是一个编程系统,使用图来表示计算任务。图中的节点被称作op(Operation),op可以获得0个或多个tensor,产生0个或多个tensor。每个tensor是一个类型化的多维数组。例如:可以将一组图像集表示成一个四维的浮点数组,四个维度分别是[batch, height, weight, channels]
  • 图(graph)描述了计算的过程。为了进行计算,图必须在会话中启动,会话负责将图中的op分发到CPU或GPU上进行计算,然后将产生的tensor返回。在Python中,tensor就是numpy.ndarray对象。

构建阶段和执行阶段

TensorFlow程序通常被组织成两个阶段:构建阶段执行阶段

  • 构建阶段:op的执行顺序被描述成一个图;
  • 执行阶段:使用会话执行图中的op;
  • 例如:通常在构建阶段创建一个图来表示神经网络,在执行阶段反复执行图中的op训练神经网络。

一个例子:

import tensorflow.compat.v1 as tf       # 导入tensorflow库
# import tensorflow as tf
tf.compat.v1.disable_eager_execution()
mat1 = tf.constant([[3., 3.]])          # 创建一个1x2的矩阵
mat2 = tf.constant([[2.], [2.]])        # 创建一个2x2的矩阵
produce = tf.matmul(mat1, mat2)         # 创建op执行两个矩阵的乘法
# sess = tf.compat.v1.Session()
sess = tf.Session()                     # 启动默认图
res = sess.run(produce)                 # 在默认图中执行op操作
print(res)                              # 输出乘积结果

输出结果:

[[12.]]

交互式会话(InteractiveSession)

为了方便使用Ipython之类的Python交互环境,可以使用交互式会话(InteractiveSession)来代替Session,使用类似Tensor.run()和Operation.eval()来代替Session.run(),避免使用一个变量来持有会话。

一个例子:

import tensorflow as tf  # 导入tensorflow库
# sess = tf.InteractiveSession()
sess = tf.compat.v1.InteractiveSession()  # 创建交互式会话
tf.compat.v1.disable_eager_execution()    # 保证sess.run()能够正常运行
a = tf.Variable([1.0, 2.0])  # 创建常量数组
b = tf.constant([3.0, 4.0])  # 创建常量数组
# sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())  # 变量初始化
res = tf.add(a, b)  # 创建加法操作
print(res.eval())  # 执行操作并输出结果

输出结果:

[4. 6.]

Feed操作

前面的例子中,数据均以变量或常量的形式进行存储。Tensorflow还提供了Feed机制,该机制可以临时替代图中任意操作中的tensor。

最常见的用例是使用tf.placeholder()创建占位符,相当于是作为图中的输入,然后使用Feed机制向图中占位符提供数据进行计算,具体使用方法见接下来的样例。

一个例子:

import tensorflow as tf  # 导入tensorflow库
tf.compat.v1.disable_eager_execution()  # 创建交互式会话
sess = tf.compat.v1.InteractiveSession()
input1 = tf.compat.v1.placeholder(tf.float32)  # 创建占位符
input2 = tf.compat.v1.placeholder(tf.float32)  # 创建占位符
res = tf.multiply(input1, input2)  # 创建乘法操作
# res.eval(feed_dict={input1: [7.], input2: [2.]})  # 求值
print(type(res), res)
print(sess.run(res, feed_dict={input1: [7.], input2: [2.]}))

输出结果:

<class 'tensorflow.python.framework.ops.Tensor'> Tensor("Mul:0", dtype=float32)
[14.]



最后的思考

1、sess = tf.Session()

在运行程序时,用到了sess = tf.Session(),会报错

AttributeError: module 'tensorflow' has no attribute 'Session'

理由:错误的意思是tensortflow模块没有Session属性,后来查阅资料发现,tensorflow2.0版本中的确没有Session这个属性,如果安装的是tensorflow2.0版本又想利用Session属性,可以将tf.Session()更改为:

tf.compat.v1.Session()

这个方法可以解决此类问题,不仅仅适用于Session属性。

再次运行时,程序又报了另一个错误:

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

查阅资料发现,原因是2.0与1.0版本不兼容,在程序开始部分添加以下代码,就可以正常运行了。

tf.compat.v1.disable_eager_execution()

ensorflow的官网对disable_eager_execution()方法是这样解释的:

This function can only be called before any Graphs, Ops, or Tensors have been created. 
It can be used at the beginning of the program for complex migration projects from TensorFlow 1.x to 2.x.

翻译:此函数只能在创建任何图、运算或张量之前调用。它可以用于从TensorFlow 1.x到2.x的复杂迁移项目的程序开头。

找到了一个更简单的方法,在引用tensorflow时,直接用:

import tensorflow.compat.v1 as tf

2、tf.InteractiveSession()

在运行程序时,用到了sess = tf.InteractiveSssion(),会报错:

AttributeError: module 'tensorflow' has no attribute 'InteractiveSession'

理由:在新的Tensorflow 2.0版本中已经移除了Session这一模块,改换运行代码:

sess = tf.compat.v1.InteractiveSession()

3、tf.global_variables_initializer()

在运行程序时,用到了tf.global_variables_initializer(),会报错:

AttributeError: module 'tensorflow' has no attribute 'global_variables_initializer'

错误原因同上,把这一句改成:

sess.run(tf.compat.v1.global_variables_initializer())

然后有有了新的错误:

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

保证sess.run()能够正常运行,新增代码:

sess = tf.compat.v1.InteractiveSession()

4、tf.placeholder()

同理改成

tf.compat.v1.disable_eager_execution()
tf.compat.v1.placeholder

5、multiply

tf.mul已经在新版本中被移除,使用 tf.multiply 代替



我恨版本错误。

posted @ 2021-11-18 21:03  奶酥  阅读(195)  评论(0编辑  收藏  举报