【深度学习】TensorFlow会话中的主要方法

77 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第14天,点击查看活动详情

1. 会话中的关键操作

1.1 run

run()方法

  • 通过使用sess.run()来运行operation
  • fetches:可以是单一的operation操作、或者是列表、也可以是元组类型
  • feed_dict:参数允许调用者覆盖图中张量的值,运行时赋值。
  • feed_dict不能单独使用,需要搭配placeholder使用

注意:run()和tf.operation.eval()都可以运行operation操作,但是tf.operation.eval()需要在会话中运行。

代码演示:

步骤1:导入所需要的库

  • 此处我们想用tensorflow1.x版本进行演示,但是我们安装的是tensorflow2.x版本
  • 因此,想使用tensorflow1.x的语法,需要开启兼容模式
  • 使用如下语法调用tensorflow库并开启兼容
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

步骤2:运行operation

  • 只有开启会话我们才能看到定义常量的结果
  • run()中传入c,表示运行c定义的操作,我们想看到c的值
  • 同样,使用.eval方法也可以运行操作
# 创建图
a = tf.constant(1.0)
b = tf.constant(2.0)
c = a * b

# 创建会话
sess = tf.Session()

# 计算C的值
print(sess.run(c))
print(c.eval(session = sess))

代码运行结果如下图所示:

image.png

如果我们想查看所有的值,可以将其都传入run方法中:具体操作如下所示

  • 使用列表将a,b,c都传入其中进行查看
print(sess.run([a,b,c]))

代码运行结果如下图所示:

image.png

当然,也可以使用元组的形式将其传入:

  • 也会返回元组
print(sess.run((a,b,c)))

1.2 feed操作

placeholder提供占位符,run时候通过feed_dict指定参数

  • 有的时候,在定义张量的时候,我们并不确定具体的值是多少,这个时候就可以使用paceholder去定义。
  • 这就相当于先占了一个位置
  • 在开启会话,运行操作的时候就需要传入赋值了,否则就会报错
  • 在会话中传入的值就需要和占位值的信息相符合

代码演示:

  • 通过定义占位符的方式
  • 将占位符传入feed_dict,同时将其赋值
a_ph = tf.placeholder(tf.float32)
b_ph = tf.placeholder(tf.float32)
c_add = tf.add(a_ph, b_ph)
print("a_ph:\n", a_ph)
print("b_ph:\n", a_ph)
print("c_add:\n", a_ph)

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=True)) as sess:
    c_ph_value = sess.run(c_add, feed_dict={a_ph:2.2, b_ph:3.3})
    print("c_ph_value:\n", c_ph_value)

运行结果如下图所示:

image.png