TensorFlow 控制函数

139 阅读3分钟

目录

  • tf.constant
  • tf.cast
  • tf.reduce_sum
  • tf.reduce_mean
  • tf.argmax
  • tf.cond
  • tf.equal
  • tf.where
  • tf.concat
  • tf.one_hot
import tensorflow as tf

tf.constant

numpy.array有着异曲同工之妙,创建一个张量

tf.constant(
    value, shape=None, dtype=None
) 
  • value:定义数据
  • shape:定义形状
  • dtype:定义类型
# 创建一个 int32 的张量
tensor = tf.constant(value=[1, 2, 3, 4, 5, 6, 7, 8, 9], shape=(3, 3), dtype='int32')
print(tensor)
==============================
输出:
tf.Tensor([[1 2 3]
          [4 5 6]
          [7 8 9]], shape=(3, 3), dtype=int32)

tf.cast

转变张量的类型

tf.cast(
    x, dtype
)
  • x:指定张量
  • dtype:目标类型
# 把 int32 转为 float32
tensor = tf.constant(value=[1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int32')
tensor = tf.cast(tensor, dtype='float32')
print(tensor)
==============================
输出:
tf.Tensor([1. 2. 3. 4. 5. 6. 7. 8. 9.], shape=(9,), dtype=float32)

tf.reduce_sum

numpy.sum有异曲同工之妙

tf.reduce_sum,
tf.math.reduce_sum (
    input_tensor, axis=0, keepdims=False
)
  • input_tensor:要计算的张量
  • axis:要计算的轴线
  • keepdims:是否保留一个维度
mean_1 = tf.reduce_sum(tensor, axis=0, keepdims=True)
mean_2 = tf.reduce_sum(tensor, axis=0, keepdims=False)

# 注意看 shape
print('sum_1:', mean_1)
print('sum_2:', mean_2)
==============================
输出:
sum_1: tf.Tensor([45.], shape=(1,), dtype=float32)
sum_2: tf.Tensor(45.0, shape=(), dtype=float32)

tf.reduce_mean

numpy.mean有异曲同工之妙

tf.reduce_mean,
tf.math.reduce_mean(
    input_tensor, axis=0, keepdims=False
)
  • input_tensor:要计算的张量
  • axis:要计算的轴线
  • keepdims:是否保留一个维度
mean_1 = tf.reduce_mean(tensor, axis=0, keepdims=True)
mean_2 = tf.reduce_mean(tensor, axis=0, keepdims=False)

# 注意看 shape
print('mean_1:', mean_1)
print('mean_2:', mean_2)
==============================
输出:
mean_1: tf.Tensor([5.], shape=(1,), dtype=float32)
mean_2: tf.Tensor(5.0, shape=(), dtype=float32)

tf.argmax

numpy.argmax有异曲同工之妙,返回张量中最大值的索引

tf.math.argmax(
    input,
    axis=0,
    output_type=tf.dtypes.int64,
)
  • input:输入张量
  • axis:要计算的轴
  • output_type:结果输出的类型
tensor = tf.constant([1, 2, 4, 3, 5, 6, 8, 7, 9, 10], shape=(5, 2))
print(tensor)
new_tensor = tf.argmax(tensor, axis=-1, output_type='int32')
print(new_tensor)
==============================
输出:
tf.Tensor([[1  2]
          [4  3]
          [5  6]
          [8  7]
          [9 10]], shape=(5, 2), dtype=int32) tf.Tensor([1 0 1 0 1], shape=(5,), dtype=int32)

tf.cond

相当于python中的if语句,只是相当于,在tensorflow内部运行的时候还是需要用到tf.cond

tf.cond(
    pred, true_fn=None, false_fn=None
)
  • pred:要判断的条件
  • true_fn:若predTrue,运行true_fn函数
  • false_fn:若predFalse,运行false_fn函数
pred = 1 != 2
true_fn = lambda: print('is True')
false_fn = lambda: print('is False')
tf.cond(pred=pred, true_fn=true_fn, false_fn=false_fn)

# 与上面的等效
if 1!=2:
    print('is True')
else:
    print('is False')
==============================
输出:
is True
is True

tf.equal

numpy.equal有异曲同工之妙,就判断俩张量的元素是否相同

tf.equal,
tf.math.equal(
    x, y
  • x:用作比较的一个张量
  • y:用作比较的一个张量
x = [[1, 2],
     [3, 4]]
y = [[1, 1],
     [1, 4]]
# True 就是元素相同,False 反之
print(tf.equal(x=x, y=y))
==============================
输出:
tf.Tensor([[True False]
          [False True]], shape=(2, 2), dtype=bool)

tf.where

numpy.where有异曲同工之妙,根据condition,将x中对应为False的部分替换为y中的相应部分

tf.where(
    condition, x=None, y=None
)
# 将 x 中的所有 2 替换为 1
x = [1, 2, 3]
y = [6, 6, 6]

# 将 2 替换为 6
print(tf.where([True, False, True], x, y))

x = [1, 2, 2]
# 将不等于 2 的部分替换为 1
print(tf.where(tf.equal(x, 2), x, 1))
==============================
输出:
tf.Tensor([1 6 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 2], shape=(3,), dtype=int32)

tf.concat

沿指定轴连接一堆张量

tf.concat(
    values, axis
)
  • values:一堆张量
  • axis:指定轴
a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8, 9], [10, 11, 12]]
print(tf.concat([a, b], axis=-1))
==============================
输出:
tf.Tensor([[ 1 2 3 7 8 9]
          [ 4 5 6 10 11 12]], shape=(2, 6), dtype=int32)

tf.one_hot

进行 One-Hot 处理,一般在生成标签的时候用到

tf.one_hot(
    indices,
    depth,
    dtype=None,
)
  • indices:标签的类别(得是int or float)
  • depth:标签的个数
  • dtype:返回的类型
labels = [0, 1, 2, 3]
num = 4
labels = tf.one_hot(indices=labels, depth=num, dtype='int32')
print(labels)
==============================
输出:
tf.Tensor([[1 0 0 0]
          [0 1 0 0]
          [0 0 1 0]
          [0 0 0 1]], shape=(4, 4), dtype=int32)