- 投稿日:2020-02-22T13:01:47+09:00
TensorFlowのinput_fnを雑に定義したらコケた件
概要
DNNClassifierを試したときにinput_fnでハマったのでメモ。
バージョン
>>> tf.__version__ '2.1.0'症状
少量のデータで学習できるか試してみていたら、以下のコードでエラーが起きた。
import tensorflow as tf # データ features = {'feature': [1, 2, 3]} labels = [0, 0, 1] dataset = tf.data.Dataset.from_tensor_slices((features, labels)) def train_input_fn(): return dataset.batch(1) # DNNClassifier hidden_units=[32, 16] feature_columns = [tf.feature_column.numeric_column('feature')] estimator = tf.estimator.DNNClassifier(feature_columns=feature_columns, hidden_units=hidden_units) # 学習 estimator.train(input_fn=train_input_fn, steps=10)次のようなエラーがでた。
RuntimeError: Attempting to capture an EagerTensor without building a function.解決策
input_fn
でdatasetを用意するべきだった模様。def train_input_fn(): return tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)エラーメッセージにピンと来なかったが、今読み返すとそれっぽいこと言ってるかも(適当)。
修正後コード
import tensorflow as tf # データ features = {'feature': [1, 2, 3]} labels = [0, 0, 1] # dataset = tf.data.Dataset.from_tensor_slices((features, labels)) # def train_input_fn(): # return dataset.batch(1) def train_input_fn(): return tf.data.Dataset.from_tensor_slices((features, labels)).batch(1) # DNNClassifier hidden_units=[32, 16] feature_columns = [tf.feature_column.numeric_column('feature')] estimator = tf.estimator.DNNClassifier(feature_columns=feature_columns, hidden_units=hidden_units) # 学習 estimator.train(input_fn=train_input_fn, steps=10)