深度学习--tensorflow中操作张量的高频率api--87

1. 创建张量

tf.constant(value, dtype=None, shape=None, name='Const')
tf.zeros(shape, dtype=tf.float32, name=None)
tf.ones(shape, dtype=tf.float32, name=None)
tf.fill(dims, value, name=None)
tf.random.normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
tf.random.uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)

2. shape操作

tf.reshape(tensor, shape, name=None)
tf.expand_dims(input, axis, name=None)
tf.squeeze(input, axis=None, name=None)
tf.transpose(a, perm=None, conjugate=False, name='transpose')

3. 数学运算

tf.add(x, y, name=None)
tf.subtract(x, y, name=None)
tf.multiply(x, y, name=None)
tf.divide(x, y, name=None)
tf.matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None)
tf.reduce_sum(input_tensor, axis=None, keepdims=False, name=None)
tf.reduce_mean(input_tensor, axis=None, keepdims=False, name=None)
tf.reduce_max(input_tensor, axis=None, keepdims=False, name=None)
tf.reduce_min(input_tensor, axis=None, keepdims=False, name=None)

4 逻辑运算

tf.equal(x, y, name=None)
tf.not_equal(x, y, name=None)
tf.greater(x, y, name=None)
tf.greater_equal(x, y, name=None)
tf.less(x, y, name=None)
tf.less_equal(x, y, name=None)

5. 张量之间的操作

tf.concat(values, axis, name='concat')
tf.split(value, num_or_size_splits, axis=0, num=None, name='split')
tf.stack(values, axis=0, name='stack')
tf.unstack(value, num=None, axis=0, name='unstack')

6. 数据类型的转换

tf.cast(x, dtype, name=None)

7. 聚合(规约)操作

tf.reduce_all(input_tensor, axis=None, keepdims=False, name=None)
tf.reduce_any(input_tensor, axis=None, keepdims=False, name=None)

8 argmax

tf.argmax 经常用于softmax之后 用于判断多分类问题 的输出标签

import tensorflow as tf
import numpy as np

# 假设我们有一个训练好的模型,并且输入数据已经通过 softmax 层进行了处理
# 这里我们手动创建一个 softmax 输出示例
softmax_output = np.array([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])

# 使用 tf.argmax 找到每个样本的最大概率对应的类别
predicted_labels = tf.argmax(softmax_output, axis=1)

# 打印预测的标签
print("Predicted labels:", predicted_labels.numpy())

Predicted labels: [2 0 1]

posted @ 2024-06-21 20:44  jack-chen666  阅读(4)  评论(0编辑  收藏  举报