- 投稿日:2019-07-07T19:04:26+09:00
【Tensorflow・VGG16】転移学習による画像分類
やること(概要)
- 1. 画像データの収集
- 2. データセットの作成(画像データの変換)
- 3. モデルの作成 & 学習
- 4. 実行(コマンドライン)
動作環境
- macOS Catalina 10.15 beta
- Python 3.6.8
- flickapi 2.4
- pillow 6.0.0
- scikit-learn 0.20.3
- google colaboratory
実施手順
1. 画像データの収集
・3種類(りんご、トマト、いちご)の画像分類を実施するため、画像ファイルをflickrから取得
・flickrによる画像ファイルの取得方法は前回記事で書いたこちら
・それぞれ300枚の画像ファイルを取得
・検索キーワードは、「apple」、「tomato」、「strawberry」を指定
・flickrからダウンロードした不要なデータ(検索キーワードと関係ない画像ファイル)は目で見て除外しておくdownload.pyfrom flickrapi import FlickrAPI from urllib.request import urlretrieve import os, time, sys # Set your own API Key and Secret Key key = "XXXXXXXXXX" secret = "XXXXXXXXXX" wait_time = 0.5 keyword = sys.argv[1] savedir = "./data/" + keyword flickr = FlickrAPI(key, secret, format='parsed-json') result = flickr.photos.search( text = keyword, per_page = 300, media = 'photos', sort = 'relevance', safe_search = 1, extras = 'url_q, license' ) photos = result['photos'] for i, photo in enumerate(photos['photo']): url_q = photo['url_q'] filepath = savedir + '/' + photo['id'] + '.jpg' if os.path.exists(filepath): continue urlretrieve(url_q,filepath) time.sleep(wait_time)2. データセットの作成(画像データの変換)
・取得した画像ファイルをnumpy形式(バイナリファイル -> .npy)で保存
・VGG16のデフォルトサイズの224にresizegenerate_data.pyfrom PIL import Image import os, glob import numpy as np from sklearn import model_selection classes = ['apple', 'tomato', 'strawberry'] num_classes = len(classes) IMAGE_SIZE = 224 # Specified size of VGG16 Default input size in VGG16 X = [] # image file Y = [] # correct label for index, classlabel in enumerate(classes): photo_dir = './data/' + classlabel files = glob.glob(photo_dir + '/*.jpg') for i, file in enumerate(files): image = Image.open(file) # standardize to 'RGB' image = image.convert('RGB') # to make image file all the same size image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) data = np.asarray(image) X.append(data) Y.append(index) X = np.array(X) Y = np.array(Y) X_train, X_test, y_train, y_test = model_selection.train_test_split(X, Y) xy = (X_train, X_test, y_train, y_test) np.save('./image_files.npy', xy)3. モデルの作成 & 学習
1). Google Colaboratoryの利用
- トレーニング処理に時間がかかるため、GPUが無料で利用可能なGoogle Colaboratoryを使用(環境構築不要・無料で使えるブラウザ上のPython実行環境)
- 今回はGoogle Driveに「2.」で作成した「image_files.npy」をGoogle Driveへ格納し、ファイルをGoogle Colabからの読み込み
- 読み込みするためにGoogle Driveのマウントが必要であるが、方法は下記の通り (Google Colabの詳しい使い方はこちらを参考にした)
マウント方法from google.colab import drive drive.mount('/content/gdrive') # image_files.npyの格納先(My Drive直下に'hoge'フォルダを作成し、そこに格納) PATH = '/content/gdrive/My Drive/hoge/'2). データの読み込み & データ変換
- google driveに格納した「image_files.npy」を読み込み、訓練データとテストデータに分割
- 正解ラベルをone-hotベクトルへ変換(Ex:0 -> [1,0,0], 1 -> [0,1,0]のようなイメージ)
- データを標準化(画像データを0~1の範囲に変換。RGB形式なので、(0,0,0)~(255,255,255)の範囲であるため、255で割る)
X_train, X_test, y_train, y_test = np.load(PATH + 'image_files.npy', allow_pickle=True) # convert one-hot vector y_train = np_utils.to_categorical(y_train, num_classes) y_test = np_utils.to_categorical(y_test, num_classes) # normalization X_train = X_train.astype('float') / 255.0 X_test = X_test.astype('float') / 255.03). モデルの作成
- VGG16を利用
- 3つのパラメータは下記の通り。
include_top: ネットワークの出力層側にある3つの全結合層(Fully Connected層)を含むかどうか。今回はFC層を独自に計算するため、Falseを指定。weights: VGG16の重みの種類を指定する。None(ランダム初期化)か'imagenet' (ImageNetで学習した重み)のどちらか一方input_shape: オプショナルなshapeのタプル。include_topがFalseの場合のみ指定可能 (そうでないときは入力のshapeは(224, 224, 3)。正確に3つの入力チャンネルをもつ必要があり、width とheightは48以上にする必要があるモデルの作成vgg16_model = VGG16( weights='imagenet', include_top=False, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3) )
- FC層を構築
- input_shapeには上記modelのoutputの形状で、1番目以降を指定(0番目は個数が入っている)
FC層の構築top_model = Sequential() top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:])) top_model.add(Dense(256, activation='relu')) top_model.add(Dropout(0.5)) top_model.add(Dense(num_classes, activation='softmax'))
- vgg16_modelとtop_modelを結合してモデルを作成
モデルの結合# combine models model = Model( inputs=vgg16_model.input, outputs=top_model(vgg16_model.output) ) model.summary()model.summaryの出力結果_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 224, 224, 3) 0 _________________________________________________________________ block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 _________________________________________________________________ sequential_1 (Sequential) (None, 2) 6423298 ================================================================= Total params: 21,137,986 Trainable params: 21,137,986 Non-trainable params: 0 _________________________________________________________________4). 重みの固定
- 上記で作成したモデルは下記2つを結合したもの
- vgg16_model:FC層を除いたVGG16
- top_model:多層パーセプトロン
- この内、vgg16_modelの'block4_pool'(model.summary参照)までの重みを固定(VGG16の高い特徴量抽出を継承するため)
重みの固定for layer in model.layers[:15]: layer.trainable = False5).モデルの学習
- optimizerはSGDを指定
- 多クラス分類を指定
モデルの学習opt = SGD(lr=1e-4, momentum=0.9) model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy']) model.fit(X_train, y_train, batch_size=32, epochs=10)6).テストデータでの評価
テストデータでの評価score = model.evaluate(X_test, y_test, batch_size=32) print('loss: {0} - acc: {1}'.format(score[0], score[1]))7).モデルの保存
モデルの保存model.save(PATH + 'vgg16_transfer.h5')4. 実行(コマンドライン)
- 作成したモデル(vgg16_transfer.h5)を使って、画像ファイルの推定を行う
predict.pyimport numpy as np from tensorflow import keras from tensorflow.keras.models import Sequential, Model, load_model from PIL import Image import sys classes = ['apple', 'tomato', 'strawberry'] num_classes = len(classes) IMAGE_SIZE = 224 # convert data by specifying file from terminal image = Image.open(sys.argv[1]) image = image.convert('RGB') image = image.resize((IMAGE_SIZE, IMAGE_SIZE)) data = np.asarray(image) X = [] X.append(data) X = np.array(X) # load model model = load_model('./vgg16_transfer.h5') # estimated result of the first data (multiple scores will be returned) result = model.predict([X])[0] predicted = result.argmax() percentage = int(result[predicted] * 100) print(classes[predicted], percentage)
- 実行は下記の通り(引数に推定する画像ファイル名を指定)
実行$ python predict.py XXXX.jpeg結果例strawberry 100ソースコード
https://github.com/hiraku00/vgg16_transfer
('image_files.npy'と'vgg16_transfer.h5'は100MB超過のため除外)参考文献
- 投稿日:2019-07-07T13:49:57+09:00
Android(JAVA)でTensorFlowLiteを使って画像分類をやってみる
初めに
今回は、Android(Java)でTensorFlowLiteを使って、画像分類をしようと思います!
もし、コード等に間違え、改善点があれば、教えてください!TensorFlowLiteとは
TensorFlowLiteのガイドによると、、、
TensorFlowLiteはスマートフォンやIotデバイスなどでTensorFlowモデルを使用するためのツールセットです。
TensorFlow Liteのインタプリタ携帯電話、組み込みLinuxデバイス、およびマイクロコントローラを含む多くの異なるハードウェアの種類、に特別に最適化されたモデルを実行します。
TensorFlowライトコンバータインタプリタによって使用するための効率的な形式にTensorFlowモデルを変換し、バイナリサイズとパフォーマンスを向上させるために最適化を導入することができます。
⇒つまり、PCだけじゃなくて、スマートフォンやIotデバイスなどでも簡単に実行できる、TensorFlowの軽量版的なやつか!
将来的には、スマートフォンだけで学習までできるとか!
すごい!TensorFlowLiteを使用した開発手順
1.TensorFlowモデルを用意する
TensorFlowで学習済みのモデルを用意します。
今回は、ホストされたモデルを使用するので、割愛します!2.TensorFlowのモデルを変換する
TensorFlowLiteではTensorFlowのモデルをそのまま使用することができないので、専用の形式(tflite)に変換します。
変換方法等はこちらの記事がわかりやすいので、参考にしてください。3.組み込む!
今回は、Android(Java)の組み込み方法を解説します!
組み込む!
新規プロジェクトの作成
プロジェクト名等は任意の名前にしてください!
今回は、AndroidXを使用します。
「Use androidx.* artifacts」にチェックすれば、OKです。
AndroidXの使用については任意なので、使わなくても大丈夫です。依存関係の追加
appディレクトリ下のbuild.gradleに
build.gradle(app)dependencies { implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly' }を追加します。
このままだと、すべてのCPUと命令セット用のABIが含まれていますが、「armeab-v7a」と「arm64-v8a」が含まれていれば、ほとんどのAndroidデバイスをカバーできるので、ほかのABIは含めいないように設定します。
含まれていても問題はないですが、アプリのサイズが減るので、おすすめです。build.gradle(app)android { defaultConfig { ndk { abiFilters 'armeabi-v7a', 'arm64-v8a' } } }ABIについてはこちらの記事がわかりやすいので参考にしてください。
Androidではassetフォルダーに入れられたものを圧縮してしまうので、モデルをassetフォルダーに入れると、圧縮されて読み込むことができなくなってしまいます。そこで、tfliteファイルを圧縮しないように指定してあげます。
build.gradle(app)android { defaultConfig { aaptOptions { noCompress "tflite" } } }モデルの設置
モデルとlabel_textをassetフォルダーに設置します。
こちらよりモデルをダウンロードしてください。
まずは、assetフォルダーフォルダーを作成します。
解凍したフォルダの中から、ファイルをコピーします。
コピーしたら、名前を「model.tflite」と「labels.txt」に変更します。
これでモデルの設置は完了です。
クラスのコピー&カスタマイズ
TensorFlowLiteのAndroidSampleの3つのクラスと、こちらのLogger.javaをコピーします。
コピーしただけだと、エラーが発生します。
Classifier.javaでLoggerクラスのインポート先を書き換えます。Classfier.javaimport org.tensorflow.lite.Interpreter; //ここを削除するimport org.tensorflow.lite.examples.classification.env.Logger; import org.tensorflow.lite.gpu.GpuDelegate; /** A classifier specialized to label images using TensorFlow Lite. */ public abstract class Classifier { private static final Logger LOGGER = new Logger();
削除すると、AndroidStudioがこんなことを聞いてくるので、「Alt+Enter」を押せば、自動でインポートしてくれます。
インポートする際に、
2種類出てくると思いますので、(android.jar)と書かれていないほうを選択します。
これで、エラーがすべて消えたと思います。
モデルの読み込み
ClassifierFloatMobileNet.java
ClassifierQuantizedMobileNet.java
の2つに共通しているモデルの読み込み部分を変更元
ClassifierFloatMobileNet.java,ClassifierQuantizedMobileNet.java@Override protected String getModelPath() { // you can download this file from // see build.gradle for where to obtain this file. It should be auto // downloaded into assets. return "mobilenet_v1_1.0_224.tflite"; }変更後
ClassifierFloatMobileNet.java,ClassifierQuantizedMobileNet@Override protected String getModelPath() { // you can download this file from // see build.gradle for where to obtain this file. It should be auto // downloaded into assets. return "model.tflite"; }Viewの配置
こんな感じにTextView,Button,ImageViewを配置します。
ButtonにはonClickを設定押しておきます。
↑ onClickを設定する方法ってリスナーのほうがいいのかな。詳しい人おしえてくださいな
activity_main.xml<?xml version="1.0" encoding="utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" tools:context=".MainActivity"> <LinearLayout android:layout_width="match_parent" android:layout_height="match_parent" android:orientation="vertical"> <LinearLayout android:layout_width="match_parent" android:layout_height="wrap_content" android:orientation="horizontal"> <TextView android:id="@+id/textView" android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_weight="1" android:text="TextView" /> </LinearLayout> <LinearLayout android:layout_width="match_parent" android:layout_height="wrap_content" android:orientation="horizontal"> <Button android:id="@+id/button" android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_weight="1" android:onClick="select" android:text="画像を選択" /> </LinearLayout> <LinearLayout android:layout_width="match_parent" android:layout_height="wrap_content" android:orientation="horizontal"> <ImageView android:id="@+id/imageView" android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_weight="1" tools:srcCompat="@tools:sample/avatars" /> </LinearLayout> </LinearLayout> </androidx.constraintlayout.widget.ConstraintLayout>コードを書く
まず、使用する変数を宣言しましょう!
MainActivity.javaImageView imageView; TextView textView; Classifier classifier; private static final int RESULT_IMAGEFILE = 1001; //画像取得時に使用するリクエストコードonCreate内でtextview,ImageViewの紐づけを行います。
MainActivity.javaimageView = findViewById(R.id.imageView); textView = findViewById(R.id.textView);次に、Classfierの呼び出しを行います。
MainActivity.javatry { classifier = Classifier.create(this,QUANTIZED,CPU,2); } catch (IOException e) { e.printStackTrace(); }引数は、Acritivy、Modelの種類、演算に使用するデバイスの指定、使用するスレッド数を指定します。
基本はこの設定で動くと思いますが、臨機応変に変更しましょう。ボタンの動作を書く
ボタンを押したら、ギャラリーを開いて画像が選択できるようにIntentを飛ばします。
MainAcritivy.javapublic void image(View V){ Intent intent = new Intent(Intent.ACTION_OPEN_DOCUMENT); intent.addCategory(Intent.CATEGORY_OPENABLE); intent.setType("image/*"); startActivityForResult(intent, RESULT_IMAGEFILE); }これについて詳しくはこちら
ギャラリーから戻ってきてからの処理
ギャラリーから戻て来たら、画像を取得して、処理します。
MainAcritivty.java@Override public void onActivityResult(int requestCode, int resultCode, Intent resultData) { super.onActivityResult(requestCode, resultCode, resultData); if (requestCode == RESULT_IMAGEFILE && resultCode == Activity.RESULT_OK) { if (resultData.getData() != null) { ParcelFileDescriptor pfDescriptor = null; try { Uri uri = resultData.getData(); pfDescriptor = getContentResolver().openFileDescriptor(uri, "r"); if (pfDescriptor != null) { FileDescriptor fileDescriptor = pfDescriptor.getFileDescriptor(); Bitmap bmp = BitmapFactory.decodeFileDescriptor(fileDescriptor); pfDescriptor.close(); int height = bmp.getHeight(); int width = bmp.getWidth(); while (true) { int i = 2; if (width < 500 && height < 500) { break; } else { if (width > 500 || height > 500) { width = width / i; height = height / i; } else { break; } i++; } } Bitmap croppedBitmap = Bitmap.createScaledBitmap(bmp, width, height, false); imageView.setImageBitmap(croppedBitmap); List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap,classfier); String text; for (Classifier.Recognition result : results) { RectF location = result.getLocation(); Float conf = result.getConfidence(); String title = result.getTitle(); text += title + "\n"; } textView.setText(text); } } catch (IOException e) { e.printStackTrace(); } finally { try { if (pfDescriptor != null) { pfDescriptor.close(); } } catch (Exception e) { e.printStackTrace(); } } } } }長いので、区切って説明します。
このコードは、アクティビティに戻ってきたときに呼ばれ、戻ってきたのが、ギャラリーからのものかを判定しています。
MainAcrivity.java@Override public void onActivityResult(int requestCode, int resultCode, Intent resultData) { super.onActivityResult(requestCode, resultCode, resultData); if (requestCode == RESULT_IMAGEFILE && resultCode == Activity.RESULT_OK) { } }このコードは、戻り値からURIを取得し、ParceFileDescriptorでファイルデータをとっています。
こんなURI「content://com.android.providers.media.documents/document/image%3A325268」が取得できるので、ここから画像を取得しています。MainAcrivity.javaif (resultData.getData() != null) { ParcelFileDescriptor pfDescriptor = null; try { Uri uri = resultData.getData(); pfDescriptor = getContentResolver().openFileDescriptor(uri, "r"); if (pfDescriptor != null) { FileDescriptor fileDescriptor = pfDescriptor.getFileDescriptor();このコードは先ほど取得した画像をbitmapに変換し、画像のサイズを300より小さくなるようにしています。
300よりでかい画像だと、正常に判定することができず、エラーで落ちてしまいます。
Caused by: java.lang.ArrayIndexOutOfBoundsException
そのため、縦横比を維持しつつ、縦横が300以内に収まるようにしています。MainAcrivity.javaBitmap bmp = BitmapFactory.decodeFileDescriptor(fileDescriptor); pfDescriptor.close(); if (!bmp.isMutable()) { bmp = bmp.copy(Bitmap.Config.ARGB_8888, true); } int height = bmp.getHeight(); int width = bmp.getWidth(); while (true) { int i = 2; if (width < 300 && height < 300) { break; } else { if (width > 300 || height > 300) { width = width / i; height = height / i; } else { break; } i++; } } Bitmap croppedBitmap = Bitmap.createScaledBitmap(bmp, width, height, false);いよいよ判別です。
このコードでは、加工した画像で判別をし、独自のリストで受け取っています。
そして、リストをforで回して、結果を取得し、textViewに表示させています。
今回は、判別された品目名のみ出力していますが、品目である可能性がどれくらいかなども取得することができます。MainAcrivity.javaList<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap); String text=""; for (Classifier.Recognition result : results) { /* RectF location = result.getLocation(); Float conf = result.getConfidence(); */ String title = result.getTitle(); text += title + "\n"; } textView.setText(text);以上で完成です!!
実際に動かしてみる
それでは、実際に動かしてみたいと思います!
まずは、犬の画像
公園のベンチ、、
爪、、、
アメリカンカメレオン、、、
んー
精度は微妙ですね
次は、美しい景色の画像。
デルフトの街並みですウィンドウスクリーン、、、
ドアマット、、、
ブラインド、、、
んー
ダメやん!まとめ
精度は微妙でしたが、うまく?画像を分類することができました!
今度は、リアルタイムで分類をしてみたいと思います!
ではでは!











