20200222のTensorFlowに関する記事は1件です。

TensorFlowのinput_fnを雑に定義したらコケた件

概要

DNNClassifierを試したときにinput_fnでハマったのでメモ。

バージョン

>>> tf.__version__
'2.1.0'

症状

少量のデータで学習できるか試してみていたら、以下のコードでエラーが起きた。

import tensorflow as tf

# データ
features = {'feature': [1, 2, 3]}
labels = [0, 0, 1]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
def train_input_fn():
    return dataset.batch(1)

# DNNClassifier
hidden_units=[32, 16]
feature_columns = [tf.feature_column.numeric_column('feature')]
estimator = tf.estimator.DNNClassifier(feature_columns=feature_columns, hidden_units=hidden_units)

# 学習
estimator.train(input_fn=train_input_fn, steps=10)

次のようなエラーがでた。

RuntimeError: Attempting to capture an EagerTensor without building a function.

解決策

input_fnでdatasetを用意するべきだった模様。

def train_input_fn():
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

エラーメッセージにピンと来なかったが、今読み返すとそれっぽいこと言ってるかも(適当)。

修正後コード

import tensorflow as tf

# データ
features = {'feature': [1, 2, 3]}
labels = [0, 0, 1]
# dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# def train_input_fn():
#     return dataset.batch(1)
def train_input_fn():
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

# DNNClassifier
hidden_units=[32, 16]
feature_columns = [tf.feature_column.numeric_column('feature')]
estimator = tf.estimator.DNNClassifier(feature_columns=feature_columns, hidden_units=hidden_units)

# 学習
estimator.train(input_fn=train_input_fn, steps=10)
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む