ほろ酔い開発日誌

AI企業のエンジニアのブログです。機械学習、Web開発の技術的お話、ビジネスチックなお話、日常のお役立ち情報など雑多な内容でお送りします。

Tensorflow run() vs eval() と InteractiveSession() vs Session()

はじめに

Tensorflowを使う際にコードによって若干の違いが見られたのでその点を理解しておきたいと思います。

  • run() と eval()
  • InteractiveSession() と Session()

この2点に違いについて説明します。

run() vs eval()

例えば、以下のような簡単なMLPの実装の一部を見て下さい。

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=h_fc))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cost)
correct_prediction = tf.equal(tf.argmax(h_fc, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)

n_epochs = 10
batch_size = 100
n_batches = train_X.shape[0] // batch_size
train_X, train_y = shuffle(train_X, train_y)

for epoch in range(n_epochs):
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        train_step.run(feed_dict={x: train_X[start:end], t: train_y[start:end]})
    train_accuracy = accuracy.eval(feed_dict={x: valid_X, t: valid_y})
    print("EPOCH::%i, training_accuracy %g" % (epoch+1, train_accuracy))

print("test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, t: mnist.test.labels}))
sess.close()

上のコードの中で、例えば、

train_step.run(feed_dict={x: train_X[start:end], t: train_y[start:end]})

の部分では run が使われているのに

accuracy.eval(feed_dict={x: valid_X, t: valid_y})

の部分では eval が使われているじゃないですか。 evalrun って何が違うのでしょうか?

stackoverflow.com

上記のAnswerとして以下のようにあります。

op.run() is a shortcut for calling tf.get_default_session().run(op)
t.eval() is a shortcut for calling tf.get_default_session().run(t)

ここでいう tf.get_default_session().run()

sess = tf.Session()
sess.run()

の sess = tf.get_default_session と考えれば分かりやすいと思います。 じゃ、「結局どっちも同じじゃん」って感じですけど、run は Operation クラスで evalTensor クラスに属するのでオブジェクトに応じてメソッドを変える必要があるということです。これが結論です。

ここで、「あれれ、じゃあ sess.run ってどういうやつだっけ?」ともなっているかもしれません。次で説明します。

InteractiveSession() vs Session()

https://www.tensorflow.org/versions/r0.11/api_docs/python/client/session_management#InteractiveSession

InteractiveSession() がTensorflowの公式に載っていました。Session()に対してInteractiveSession() は何が違うのでしょうか?

A TensorFlow Session for use in interactive contexts, such as a shell.

The only difference with a regular Session is that an InteractiveSession installs itself as the default session on construction. The methods Tensor.eval() and Operation.run() will use that session to run ops.

This is convenient in interactive shells and IPython notebooks, as it avoids having to pass an explicit Session object to run ops.

つまり、 InteractiveSessonを使うとsess = Session() のようにして指定したsessを明示的に指定しなくてもよくなるよ、ということです。IPython notebookで使うときとかに便利だということですね。

examples/faq.md at master · tensorflow/examples · GitHub

以下は上記リンク先のコード例です。わざわざ sess.run() のような記述はいらなくなります。

sess = tf.InteractiveSession()
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# We can just use 'c.eval()' without passing 'sess'
print(c.eval())
sess.close()

ちなみに with 公文を使えば、tf.Session() を使っても同様の記述が出来るようです。こちらも上記リンク先のコードです。

a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
with tf.Session():
  # We can also use 'c.eval()' here.
  print(c.eval())

sess.run を使って書いてみます。

a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
sess = tf.Session()
sess.run(c)
sess.close()

以上で違いが理解出来たのではないでしょうか?

おわりに

IPython notebookを使うときはInteractiveSession が便利のような気もしますが(ちょっと楽)、Sessionrun のメソッドで(eval とごちゃごちゃにならず)統一的に書けるので良いなと思ったりもしました。