- 投稿日:2019-02-28T12:24:28+09:00
TensorFlowのsess.run()が遅かった
やりたいこと
色んな画像に対してObject Detection APIで
sess.run()したかったが、一枚5~8秒ほどかかった。# イメージ for i in range(10): tf.Session(graph=hoge) as sess: sess.run(hoge_op, feed_dict={})実際はクラス化したり、sessionがネストしてたり、
外部から呼び出されたりでもっと複雑だったが、
まぁ雰囲気はこんな感じだった。対処
要はこう。
# イメージ tf.Session(graph=hoge) as sess: for i in range(10): sess.run(hoge_op, feed_dict={})何故か。グラフとセッションの関係
この2つの概念を知っておけば、自分のような初心者でも何となくTensorflowのコードが理解出来るようになった。
- グラフ:計算グラフ
情報系の人間には見覚えのあるグラフ。
TensorFlowで書く時は↓みたいな感じ。上の画像とは関係ないです。add_graph = tf.Graph() with add_graph.as_default(): a = tf.placeholder(tf.int32, shape=[], name="a") b = tf.placeholder(tf.int32, shape=[], name="b") add_op = tf.add(a, b, name="add_op")
- セッション:
グラフを実行するために存在する。
コードによっては省略されているが、tf.Session(graph=hoge)のように、実行対象のグラフは必ず指定されている。
tf.Session()でグラフを指定したsessを作成出来て、
sess.run()でその中の一部とか全部を出力出来る(最後のoperationとかが引数)。with tf.Session(graph=add_graph) as sess: ret = sess.run(add_op, feed_dict={a:1,b:1}) print ret
- 実装の流れ:
- グラフを作成
- セッションを作成
- sess.run()
何故遅かったのか
最初の例だと、毎回セッションの作成を行ってしまうから。
セッションは作成時にメモリ確保とかを行うので、作成にはかなり時間がかかる。参考サイト
グラフとセッションの部分のコード参考:
http://docs.fabo.io/tensorflow/building_graph/tensorflow_graph_part2.html
グラフとセッションについての説明:
https://arakan-pgm-ai.hatenablog.com/entry/2017/05/04/173031感想
Chainerとかは直感的だけど、TensorFlowは結構独特で、
ごくごく軽く触る程度の身としては困っている。勉強(機械学習)のための勉強(Tensorflowへの慣れ)、
になってしまったので、
やはりChainerみたいなdefine by run系がとっつきやすいと思う。