tensorflow常用函数

227 阅读1分钟

tf.gather_nd

tf.gather_nd 是 TensorFlow 中的一个函数,用于根据给定的索引从张量中收集元素。它可以用于从多维张量中选择特定位置的元素。 tf.gather_nd 的基本用法是提供一个索引张量,该索引张量指定了要收集的元素的位置。索引张量的形状决定了输出张量的形状。 下面是一个示例代码,演示如何使用 tf.gather_nd 函数:

import tensorflow as tf
# 创建一个输入张量
input_tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 创建一个索引张量
indices = tf.constant([[0, 0], [1, 2]])
# 使用 tf.gather_nd 收集元素
output_tensor = tf.gather_nd(input_tensor, indices)
with tf.Session() as sess:
  print(sess.run(output_tensor))

在这个示例中,我们首先创建一个输入张量 input_tensor,其形状为 (3, 3)。然后,我们创建一个索引张量 indices,其形状为 (2, 2)。索引张量中的每个元素指定了要从输入张量中收集的元素的位置。 最后,我们使用 tf.gather_nd 函数将输入张量中指定位置的元素收集到输出张量中。输出张量的形状与索引张量的形状相关。 输出将是一个包含收集到的元素的张量。在这个示例中,输出为:

[1 6]