20190707のTensorFlowに関する記事は2件です。

【Tensorflow・VGG16】転移学習による画像分類

やること(概要)

  • 1. 画像データの収集
  • 2. データセットの作成(画像データの変換)
  • 3. モデルの作成 & 学習
  • 4. 実行(コマンドライン)

動作環境

  • macOS Catalina 10.15 beta
  • Python 3.6.8
  • flickapi 2.4
  • pillow 6.0.0
  • scikit-learn 0.20.3
  • google colaboratory

実施手順

1. 画像データの収集

・3種類(りんご、トマト、いちご)の画像分類を実施するため、画像ファイルをflickrから取得
・flickrによる画像ファイルの取得方法は前回記事で書いたこちら
・それぞれ300枚の画像ファイルを取得
・検索キーワードは、「apple」、「tomato」、「strawberry」を指定
・flickrからダウンロードした不要なデータ(検索キーワードと関係ない画像ファイル)は目で見て除外しておく

download.py
from flickrapi import FlickrAPI
from urllib.request import urlretrieve
import os, time, sys

# Set your own API Key and Secret Key
key = "XXXXXXXXXX"
secret = "XXXXXXXXXX"
wait_time = 0.5

keyword = sys.argv[1]
savedir = "./data/" + keyword

flickr = FlickrAPI(key, secret, format='parsed-json')
result = flickr.photos.search(
    text = keyword,
    per_page = 300,
    media = 'photos',
    sort = 'relevance',
    safe_search = 1,
    extras = 'url_q, license'
)

photos = result['photos']

for i, photo in enumerate(photos['photo']):
    url_q = photo['url_q']
    filepath = savedir + '/' + photo['id'] + '.jpg'
    if os.path.exists(filepath): continue
    urlretrieve(url_q,filepath)
    time.sleep(wait_time)

2. データセットの作成(画像データの変換)

・取得した画像ファイルをnumpy形式(バイナリファイル -> .npy)で保存
・VGG16のデフォルトサイズの224にresize

generate_data.py
from PIL import Image
import os, glob
import numpy as np
from sklearn import model_selection

classes = ['apple', 'tomato', 'strawberry']
num_classes = len(classes)
IMAGE_SIZE = 224 # Specified size of VGG16 Default input size in VGG16

X = [] # image file
Y = [] # correct label

for index, classlabel in enumerate(classes):
    photo_dir = './data/' + classlabel
    files = glob.glob(photo_dir + '/*.jpg')
    for i, file in enumerate(files):
        image = Image.open(file)
        # standardize to 'RGB'
        image = image.convert('RGB')
        # to make image file all the same size
        image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
        data = np.asarray(image)
        X.append(data)
        Y.append(index)

X = np.array(X)
Y = np.array(Y)

X_train, X_test, y_train, y_test = model_selection.train_test_split(X, Y)
xy = (X_train, X_test, y_train, y_test)
np.save('./image_files.npy', xy)

3. モデルの作成 & 学習

1). Google Colaboratoryの利用

  • トレーニング処理に時間がかかるため、GPUが無料で利用可能なGoogle Colaboratoryを使用(環境構築不要・無料で使えるブラウザ上のPython実行環境)
  • 今回はGoogle Driveに「2.」で作成した「image_files.npy」をGoogle Driveへ格納し、ファイルをGoogle Colabからの読み込み
  • 読み込みするためにGoogle Driveのマウントが必要であるが、方法は下記の通り (Google Colabの詳しい使い方はこちらを参考にした)
マウント方法
from google.colab import drive
drive.mount('/content/gdrive')
# image_files.npyの格納先(My Drive直下に'hoge'フォルダを作成し、そこに格納)
PATH = '/content/gdrive/My Drive/hoge/'

2). データの読み込み & データ変換

  • google driveに格納した「image_files.npy」を読み込み、訓練データとテストデータに分割
  • 正解ラベルをone-hotベクトルへ変換(Ex:0 -> [1,0,0], 1 -> [0,1,0]のようなイメージ)
  • データを標準化(画像データを0~1の範囲に変換。RGB形式なので、(0,0,0)~(255,255,255)の範囲であるため、255で割る)
X_train, X_test, y_train, y_test = np.load(PATH + 'image_files.npy', allow_pickle=True)

# convert one-hot vector
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

# normalization
X_train = X_train.astype('float') / 255.0
X_test = X_test.astype('float') / 255.0

3). モデルの作成

  • VGG16を利用
  • 3つのパラメータは下記の通り。
    • include_top : ネットワークの出力層側にある3つの全結合層(Fully Connected層)を含むかどうか。今回はFC層を独自に計算するため、Falseを指定。
    • weights : VGG16の重みの種類を指定する。None(ランダム初期化)か'imagenet' (ImageNetで学習した重み)のどちらか一方
    • input_shape : オプショナルなshapeのタプル。include_topがFalseの場合のみ指定可能 (そうでないときは入力のshapeは(224, 224, 3)。正確に3つの入力チャンネルをもつ必要があり、width とheightは48以上にする必要がある
モデルの作成
vgg16_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)   
)
  • FC層を構築
  • input_shapeには上記modelのoutputの形状で、1番目以降を指定(0番目は個数が入っている)
FC層の構築
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(num_classes, activation='softmax'))
  • vgg16_modelとtop_modelを結合してモデルを作成
モデルの結合
# combine models
model = Model(
    inputs=vgg16_model.input,
    outputs=top_model(vgg16_model.output)
)
model.summary()
model.summaryの出力結果
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 2)                 6423298   
=================================================================
Total params: 21,137,986
Trainable params: 21,137,986
Non-trainable params: 0
_________________________________________________________________

4). 重みの固定

  • 上記で作成したモデルは下記2つを結合したもの
    • vgg16_model:FC層を除いたVGG16
    • top_model:多層パーセプトロン
  • この内、vgg16_modelの'block4_pool'(model.summary参照)までの重みを固定(VGG16の高い特徴量抽出を継承するため)
重みの固定
for layer in model.layers[:15]:
    layer.trainable = False

5).モデルの学習

  • optimizerはSGDを指定
  • 多クラス分類を指定
モデルの学習
opt = SGD(lr=1e-4, momentum=0.9)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=32, epochs=10)

6).テストデータでの評価

テストデータでの評価
score = model.evaluate(X_test, y_test, batch_size=32)
print('loss: {0} - acc: {1}'.format(score[0], score[1]))

7).モデルの保存

モデルの保存
model.save(PATH + 'vgg16_transfer.h5')

4. 実行(コマンドライン)

  • 作成したモデル(vgg16_transfer.h5)を使って、画像ファイルの推定を行う
predict.py
import numpy as np
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from PIL import Image
import sys

classes = ['apple', 'tomato', 'strawberry']
num_classes = len(classes)
IMAGE_SIZE = 224

# convert data by specifying file from terminal
image = Image.open(sys.argv[1])
image = image.convert('RGB')
image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
data = np.asarray(image)
X = []
X.append(data)
X = np.array(X)

# load model
model = load_model('./vgg16_transfer.h5')

# estimated result of the first data (multiple scores will be returned)
result = model.predict([X])[0]
predicted = result.argmax()
percentage = int(result[predicted] * 100)

print(classes[predicted], percentage)
  • 実行は下記の通り(引数に推定する画像ファイル名を指定)
実行
$ python predict.py XXXX.jpeg
結果例
strawberry 100

ソースコード

https://github.com/hiraku00/vgg16_transfer
('image_files.npy'と'vgg16_transfer.h5'は100MB超過のため除外)

参考文献

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Android(JAVA)でTensorFlowLiteを使って画像分類をやってみる

初めに

今回は、Android(Java)でTensorFlowLiteを使って、画像分類をしようと思います!
もし、コード等に間違え、改善点があれば、教えてください!

TensorFlowLiteとは

TensorFlowLiteのガイドによると、、、

TensorFlowLiteはスマートフォンやIotデバイスなどでTensorFlowモデルを使用するためのツールセットです。

TensorFlow Liteのインタプリタ携帯電話、組み込みLinuxデバイス、およびマイクロコントローラを含む多くの異なるハードウェアの種類、に特別に最適化されたモデルを実行します。

TensorFlowライトコンバータインタプリタによって使用するための効率的な形式にTensorFlowモデルを変換し、バイナリサイズとパフォーマンスを向上させるために最適化を導入することができます。

⇒つまり、PCだけじゃなくて、スマートフォンやIotデバイスなどでも簡単に実行できる、TensorFlowの軽量版的なやつか!
将来的には、スマートフォンだけで学習までできるとか!
すごい!

TensorFlowLiteを使用した開発手順

1.TensorFlowモデルを用意する

TensorFlowで学習済みのモデルを用意します。
今回は、ホストされたモデルを使用するので、割愛します!

2.TensorFlowのモデルを変換する

TensorFlowLiteではTensorFlowのモデルをそのまま使用することができないので、専用の形式(tflite)に変換します。
変換方法等はこちらの記事がわかりやすいので、参考にしてください。

3.組み込む!

今回は、Android(Java)の組み込み方法を解説します!

組み込む!

新規プロジェクトの作成

プロジェクト名等は任意の名前にしてください!
image.png
今回は、AndroidXを使用します。
「Use androidx.* artifacts」にチェックすれば、OKです。
AndroidXの使用については任意なので、使わなくても大丈夫です。

依存関係の追加

appディレクトリ下のbuild.gradleに

build.gradle(app)
dependencies {
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly'
}

を追加します。
このままだと、すべてのCPUと命令セット用のABIが含まれていますが、「armeab-v7a」と「arm64-v8a」が含まれていれば、ほとんどのAndroidデバイスをカバーできるので、ほかのABIは含めいないように設定します。
含まれていても問題はないですが、アプリのサイズが減るので、おすすめです。

build.gradle(app)
android {
    defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}

ABIについてはこちらの記事がわかりやすいので参考にしてください。

Androidではassetフォルダーに入れられたものを圧縮してしまうので、モデルをassetフォルダーに入れると、圧縮されて読み込むことができなくなってしまいます。そこで、tfliteファイルを圧縮しないように指定してあげます。

build.gradle(app)
android {
    defaultConfig {
        aaptOptions {
            noCompress "tflite"
        }
    }
}

モデルの設置

モデルとlabel_textをassetフォルダーに設置します。
こちらよりモデルをダウンロードしてください。
image.png

まずは、assetフォルダーフォルダーを作成します。
image.png
解凍したフォルダの中から、ファイルをコピーします。
image.png
コピーしたら、名前を「model.tflite」と「labels.txt」に変更します。
image.png

これでモデルの設置は完了です。

クラスのコピー&カスタマイズ

TensorFlowLiteのAndroidSampleの3つのクラスと、こちらのLogger.javaをコピーします。
image.png
コピーしただけだと、エラーが発生します。
Classifier.javaでLoggerクラスのインポート先を書き換えます。

Classfier.java
import org.tensorflow.lite.Interpreter;
//ここを削除するimport org.tensorflow.lite.examples.classification.env.Logger;
import org.tensorflow.lite.gpu.GpuDelegate;

/** A classifier specialized to label images using TensorFlow Lite. */
public abstract class Classifier {
    private static final Logger LOGGER = new Logger();

image.png
削除すると、AndroidStudioがこんなことを聞いてくるので、「Alt+Enter」を押せば、自動でインポートしてくれます。
image.png
インポートする際に、
image.png

2種類出てくると思いますので、(android.jar)と書かれていないほうを選択します。

これで、エラーがすべて消えたと思います。

モデルの読み込み

ClassifierFloatMobileNet.java
ClassifierQuantizedMobileNet.java
の2つに共通しているモデルの読み込み部分を変更

ClassifierFloatMobileNet.java,ClassifierQuantizedMobileNet.java
  @Override
  protected String getModelPath() {
    // you can download this file from
    // see build.gradle for where to obtain this file. It should be auto
    // downloaded into assets.
    return "mobilenet_v1_1.0_224.tflite";
  }

変更後

ClassifierFloatMobileNet.java,ClassifierQuantizedMobileNet
  @Override
  protected String getModelPath() {
    // you can download this file from
    // see build.gradle for where to obtain this file. It should be auto
    // downloaded into assets.
    return "model.tflite";
  }


Viewの配置

こんな感じにTextView,Button,ImageViewを配置します。
ButtonにはonClickを設定押しておきます。
↑ onClickを設定する方法ってリスナーのほうがいいのかな。詳しい人おしえてくださいな
image.png

activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:orientation="vertical">

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="horizontal">

            <TextView
                android:id="@+id/textView"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_weight="1"
                android:text="TextView" />
        </LinearLayout>

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="horizontal">

            <Button
                android:id="@+id/button"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_weight="1"
                android:onClick="select"
                android:text="画像を選択" />
        </LinearLayout>

        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="horizontal">

            <ImageView
                android:id="@+id/imageView"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_weight="1"
                tools:srcCompat="@tools:sample/avatars" />
        </LinearLayout>
    </LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

コードを書く

まず、使用する変数を宣言しましょう!

MainActivity.java
    ImageView imageView;
    TextView textView;
    Classifier classifier;
    private static final int RESULT_IMAGEFILE = 1001;  //画像取得時に使用するリクエストコード

onCreate内でtextview,ImageViewの紐づけを行います。

MainActivity.java
        imageView = findViewById(R.id.imageView);
        textView = findViewById(R.id.textView);

次に、Classfierの呼び出しを行います。

MainActivity.java
        try {
            classifier = Classifier.create(this,QUANTIZED,CPU,2);
        } catch (IOException e) {
            e.printStackTrace();
        }

引数は、Acritivy、Modelの種類、演算に使用するデバイスの指定、使用するスレッド数を指定します。
基本はこの設定で動くと思いますが、臨機応変に変更しましょう。

ボタンの動作を書く

ボタンを押したら、ギャラリーを開いて画像が選択できるようにIntentを飛ばします。

MainAcritivy.java
public void image(View V){
        Intent intent = new Intent(Intent.ACTION_OPEN_DOCUMENT);
        intent.addCategory(Intent.CATEGORY_OPENABLE);
        intent.setType("image/*");
        startActivityForResult(intent, RESULT_IMAGEFILE);
}

これについて詳しくはこちら

ギャラリーから戻ってきてからの処理

ギャラリーから戻て来たら、画像を取得して、処理します。

MainAcritivty.java
    @Override
    public void onActivityResult(int requestCode, int resultCode, Intent resultData) {
        super.onActivityResult(requestCode, resultCode, resultData);
        if (requestCode == RESULT_IMAGEFILE && resultCode == Activity.RESULT_OK) {
            if (resultData.getData() != null) {
                ParcelFileDescriptor pfDescriptor = null;
                try {
                    Uri uri = resultData.getData();
                    pfDescriptor = getContentResolver().openFileDescriptor(uri, "r");
                    if (pfDescriptor != null) {
                        FileDescriptor fileDescriptor = pfDescriptor.getFileDescriptor();
                        Bitmap bmp = BitmapFactory.decodeFileDescriptor(fileDescriptor);
                        pfDescriptor.close();
                        int height = bmp.getHeight();
                        int width = bmp.getWidth();
                        while (true) {
                            int i = 2;
                            if (width < 500 && height < 500) {
                                break;
                            } else {
                                if (width > 500 || height > 500) {
                                    width = width / i;
                                    height = height / i;
                                } else {
                                    break;
                                }
                                i++;
                            }
                        }

                        Bitmap croppedBitmap = Bitmap.createScaledBitmap(bmp, width, height, false);
                        imageView.setImageBitmap(croppedBitmap);
                        List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap,classfier);
                        String text;
                        for (Classifier.Recognition result : results) {
                            RectF location = result.getLocation();
                            Float conf = result.getConfidence();
                            String title = result.getTitle();
                            text += title + "\n";
                        }
                         textView.setText(text);
                    }
                } catch (IOException e) {
                    e.printStackTrace();
                } finally {
                    try {
                        if (pfDescriptor != null) {
                            pfDescriptor.close();
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }

            }
        }
    }

長いので、区切って説明します。

このコードは、アクティビティに戻ってきたときに呼ばれ、戻ってきたのが、ギャラリーからのものかを判定しています。

MainAcrivity.java
   @Override
    public void onActivityResult(int requestCode, int resultCode, Intent resultData) {
        super.onActivityResult(requestCode, resultCode, resultData);
        if (requestCode == RESULT_IMAGEFILE && resultCode == Activity.RESULT_OK) {
        }
    }

このコードは、戻り値からURIを取得し、ParceFileDescriptorでファイルデータをとっています。
こんなURI「content://com.android.providers.media.documents/document/image%3A325268」が取得できるので、ここから画像を取得しています。

MainAcrivity.java
            if (resultData.getData() != null) {
                ParcelFileDescriptor pfDescriptor = null;
                try {
                    Uri uri = resultData.getData();
                    pfDescriptor = getContentResolver().openFileDescriptor(uri, "r");
                    if (pfDescriptor != null) {
                        FileDescriptor fileDescriptor = pfDescriptor.getFileDescriptor();

このコードは先ほど取得した画像をbitmapに変換し、画像のサイズを300より小さくなるようにしています。
300よりでかい画像だと、正常に判定することができず、エラーで落ちてしまいます。
Caused by: java.lang.ArrayIndexOutOfBoundsException
そのため、縦横比を維持しつつ、縦横が300以内に収まるようにしています。

MainAcrivity.java
                        Bitmap bmp = BitmapFactory.decodeFileDescriptor(fileDescriptor);
                        pfDescriptor.close();

                        if (!bmp.isMutable()) {
                            bmp = bmp.copy(Bitmap.Config.ARGB_8888, true);
                        }
                        int height = bmp.getHeight();
                        int width = bmp.getWidth();
                        while (true) {
                            int i = 2;
                            if (width < 300 && height < 300) {
                                break;
                            } else {
                                if (width > 300 || height > 300) {
                                    width = width / i;
                                    height = height / i;
                                } else {
                                    break;
                                }
                                i++;
                            }
                        }
                        Bitmap croppedBitmap = Bitmap.createScaledBitmap(bmp, width, height, false);

いよいよ判別です。
このコードでは、加工した画像で判別をし、独自のリストで受け取っています。
そして、リストをforで回して、結果を取得し、textViewに表示させています。
今回は、判別された品目名のみ出力していますが、品目である可能性がどれくらいかなども取得することができます。

MainAcrivity.java
                        List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                        String text="";
                        for (Classifier.Recognition result : results) {
                            /*
                            RectF location = result.getLocation();
                            Float conf = result.getConfidence();
                            */
                            String title = result.getTitle();
                            text += title + "\n";
                        } 
                        textView.setText(text);

以上で完成です!!

実際に動かしてみる

それでは、実際に動かしてみたいと思います!
まずは、犬の画像
image.png
公園のベンチ、、
爪、、、
アメリカンカメレオン、、、
んー
精度は微妙ですね
次は、美しい景色の画像。
デルフトの街並みです

image.png

ウィンドウスクリーン、、、
ドアマット、、、
ブラインド、、、
んー
ダメやん!

まとめ

精度は微妙でしたが、うまく?画像を分類することができました!
今度は、リアルタイムで分類をしてみたいと思います!
ではでは!

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む