20190228のTensorFlowに関する記事は1件です。

TensorFlowのsess.run()が遅かった

やりたいこと

色んな画像に対してObject Detection APIで
sess.run()したかったが、一枚5~8秒ほどかかった。

# イメージ
for i in range(10):
    tf.Session(graph=hoge) as sess:
        sess.run(hoge_op, feed_dict={})

実際はクラス化したり、sessionがネストしてたり、
外部から呼び出されたりでもっと複雑だったが、
まぁ雰囲気はこんな感じだった。

対処

要はこう。

# イメージ
tf.Session(graph=hoge) as sess:
    for i in range(10):
        sess.run(hoge_op, feed_dict={})

何故か。グラフとセッションの関係

この2つの概念を知っておけば、自分のような初心者でも何となくTensorflowのコードが理解出来るようになった。

  • グラフ:計算グラフ

情報系の人間には見覚えのあるグラフ。

TensorFlowで書く時は↓みたいな感じ。上の画像とは関係ないです。

add_graph = tf.Graph()
with add_graph.as_default():
    a = tf.placeholder(tf.int32, shape=[], name="a")
    b = tf.placeholder(tf.int32, shape=[], name="b")
    add_op = tf.add(a, b, name="add_op")
  • セッション:

グラフを実行するために存在する。
コードによっては省略されているが、tf.Session(graph=hoge)のように、実行対象のグラフは必ず指定されている。
tf.Session()でグラフを指定したsessを作成出来て、
sess.run()でその中の一部とか全部を出力出来る(最後のoperationとかが引数)。

with tf.Session(graph=add_graph) as sess:
    ret = sess.run(add_op, feed_dict={a:1,b:1})
    print ret
  • 実装の流れ:
  • グラフを作成
  • セッションを作成
  • sess.run()

何故遅かったのか

最初の例だと、毎回セッションの作成を行ってしまうから。
セッションは作成時にメモリ確保とかを行うので、作成にはかなり時間がかかる。

参考サイト

グラフとセッションの部分のコード参考:
http://docs.fabo.io/tensorflow/building_graph/tensorflow_graph_part2.html
グラフとセッションについての説明:
https://arakan-pgm-ai.hatenablog.com/entry/2017/05/04/173031

感想

Chainerとかは直感的だけど、TensorFlowは結構独特で、
ごくごく軽く触る程度の身としては困っている。

勉強(機械学習)のための勉強(Tensorflowへの慣れ)、
になってしまったので、
やはりChainerみたいなdefine by run系がとっつきやすいと思う。

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む