20200913のTensorFlowに関する記事は1件です。

MLパイプラインのためのTensorflow Data Validation

Tensorflow Data Validation (TFDV)

tensorflow-extended-tfx-logo-social.png
TFDVはTFXのライブラリの一つで、

  • データセットの統計量
  • データセットのスキーマ

の可視化・チェックを行うツールです。
この統計量の可視化の機能がなかなか優秀で、この機能を使ってEDAを行うという記事があったりします。

しかし、機械学習のライフサイクルの中でTFDVが有効なのはデータ探索フェーズだけでなく、本番環境フェーズでも大いに役立ちます。
この記事では本番フェーズ、MLパイプラインでのTFDVの役割とその効果に注目します。
なお、今回実行したソースコードはここにあります。
https://github.com/ken0407/tfdv_practice

TFDV体験

ともあれ、まずはどんな感じのツールなのかを体験したいと思います。

準備

仮想環境を作り、TFDVと、可視化を行うためにJupyter labをインストールします。2020年9月現在、TFDVはpython3.8ではpip installできないので注意してください。

$ python -m venv venv
$ . venv/bin/activate
(venv)$ pip install tensorflow-data-validation jupyterlab

次に、今回実験で使うデータを用意します。

$ mkdir data
$ wget -P data/ https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv

TFDV の基本操作

まずはTFDVの基本的な操作である統計量とスキーマの可視化を行います。以下のコードはnotebookで実行してください。

統計量の算出・可視化

import pandas as pd
import tensorflow_data_validation as tfdv

df = pd.read_csv('data/titanic.csv')
stats = tfdv.generate_statistics_from_dataframe(df)
tfdv.visualize_statistics(stats)

実行すると以下のようになります。
スクリーンショット 2020-09-10 19.33.49.png

スキーマの算出・可視化

TFDVはデータのスキーマを作成・利用することができます。
スキーマは先に算出した統計量を用いて作成することができます。

stats = tfdv.generate_statistics_from_dataframe(df)
schema = tfdv.infer_schema(stats)
tfdv.display_schema(schema)

実行すると以下のようになります。
スクリーンショット 2020-09-10 19.37.34.png

未知のデータに対して最初にこの処理を実行すれば大まかなデータの概要は掴むことはできるのではないでしょうか。

TFDVの役割

これだけだとpandas-profilingなどでもいいのではと思われるかもしれませんが、TFDVが本当に役に立つのはMLパイプラインの中で使用されている時だと思います。

まず前提として、サービスの中で機械学習モデルを使用する場合、データは常に変化し続けていくため、継続的に再学習を行なっていく必要があります。
そういった背景の中で、あえて抽象的な言い方をするとこのデータは使っていいデータなのか否かを常にチェックする必要があります。
TFDVはデータ同士を比較することで異常を検知してこの問題を解決するツールです。

データの異常検知

上であえて"このデータは使っていいデータなのか否か"と抽象的に書きました。
使ってはいけないというのはつまり、

  • データの形式(スキーマ)が変わっている
  • データの分布が変わっている

というような異常が、

  • 別の日のデータセット同士(昨日のデータセットと今日のデータセットなど)
  • 別種のデータセット同士(学習データセットと検証データセットなど)

で発生している状況を指します。
ちなみに、別の日のデータセット同士で異常が発生していた場合をドリフト、別種のデータセット同士で発生していた場合をスキューと呼ぶようです。

ガーベッジイン・ガーベッジアウトという言葉があるように、これら異なるデータセット間でスキーマとデータの分布が共に一貫性が担保されていることが良いモデルを作るには欠かせません。

先に紹介したようにTFDVにはスキーマとデータの統計量を可視化・チェックする機能があるため、データの異常を検知することができます。

準備

train = df.iloc[:600]
valid = df.iloc[600:].reset_index(drop=True)
schema = tfdv.infer_schema(tfdv.generate_statistics_from_dataframe(train))

スキーマの異常検知

カラムが欠けていた場合

validデータセットにPclassが欠けていた場合を想定します。

lack_col_valid = valid.copy().drop('Pclass', axis=1)
anomaly = tfdv.validate_statistics(tfdv.generate_statistics_from_dataframe(lack_col_valid),
                                   schema,
                                  )
tfdv.display_anomalies(anomary)

スクリーンショット 2020-09-12 21.55.53.png

未知の値が入っていた場合

validデータセットのSexカラムに、trainデータセットのSexには含まれていなかった値(dog)が入っていた場合を想定します。

new_val_in_sex_valid = valid.copy()
new_val_in_sex_valid.loc[0, 'Sex'] = 'dog'
anomaly = tfdv.validate_statistics(tfdv.generate_statistics_from_dataframe(new_val_in_sex_valid),
                                   schema,
                                  )
tfdv.display_anomalies(anomaly)

スクリーンショット 2020-09-12 21.58.19.png

分布の異常検知

Ageの分布が学習データと検証データで大きく異なる場合を想定します。

error_distribution_valid = valid.copy()
error_distribution_valid['Age'] = error_distribution_valid['Age'] * 2

tfdv.visualize_statistics(lhs_statistics=tfdv.generate_statistics_from_dataframe(train),
                          rhs_statistics=tfdv.generate_statistics_from_dataframe(error_distribution_valid),
                          lhs_name='TRAIN_DATASET',
                          rhs_name='VALID_DATASET'
)

スクリーンショット 2020-09-13 18.43.17.png

まとめ

TFDVにはEDAに有効な強力なデータの可視化機能以外にも、データの異常を検知し、詳細情報を得ることができます。

紹介した以外にも、

  • 自動で生成したスキーマをカスタマイズする。
  • Dataflowなどを使用して大規模データセットを処理する

などの機能もあります。

また、TFXにはこの他にも便利な機能があるため、試していきたいと思います。

参考

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む