20200730のTensorFlowに関する記事は2件です。

"categorical_crossentropy"と"sparse_categorical_crossentropy"の違い

結論

  • 使用するラベルが違います。違いはそれだけです。"categorical_crossentropy"にはonehot(どこか1つが1で他は全て0)のラベルを使用します。"sparse_categorical_crossentropy"のラベルには整数を使用します。

one-hotと整数表現の違い

例)10分類の場合

one-hot表現 整数表現
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [9]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [2]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [1]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [5]

整数ラベルをone-hotラベルに変換

データセットでは整数ラベルのものが多い印象ですが,損失関数の多くは整数ラベルではなく,one-hotラベルを与えてあげないと動きません.そういう場合は変換する必要があります.(というか,むしろ"sparse_categorical_crossentropy"のように整数ラベルのまま学習できる損失関数が少数派だと感じます.)

以下にコードを記します.

import numpy as np

n_labels = len(np.unique(train_labels))
train_labels_onehot = np.eye(n_labels)[train_labels]

n_labels = len(np.unique(test_labels))
test_labels_onehot = np.eye(n_labels)[test_labels]
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Tensorflow 2.x でもFIDが計算したい!

Tensorflow で FID を計算したい!楽に

Tensorflow を使って FID を計算したいなーって思ったけど色々調べていたけど、以下の点があんまり良くなかったです

  • 一旦画像データを全部メモリに乗せるのが大変
    FID スコアは沢山データがあればあるほど安定します。というかデータ数によってスコアが変わります。うっかり numpy で画像を全部持ってくると、メモリに乗せてしまうので、地獄です。
  • tensorflow データセットと上手く合わせたい
    せっかく tensorflow は色々前処理をさせてくれるんだし、利用できるといいなーと思いました。

コード

そのままファイルにして保存すればいいんじゃないっすかね。

#!/usr/bin/env python3
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm
from typing import Dict

EPS_VAL = 1e-6
LENGTH_FEATURE_VEC = 2048
AUTOTUNE = tf.data.experimental.AUTOTUNE


class FID:
    def __init__(
        self,
        scaling_func: tf.image.ResizeMethod = "nearest",
        batch_size: int = 128,
        num_samples: int = 100000,
    ):
        assert num_samples >= 2048, "invalid sample size"
        self.scaling_func = scaling_func
        self.model = tf.keras.applications.InceptionV3(
            include_top=False, pooling="avg", input_shape=(299, 299, 3)
        )
        self.num_samples = num_samples
        while num_samples % batch_size != 0:
            batch_size = batch_size - 1
        self.batch_size = batch_size

    def rescale_img_size(self, img: tf.Tensor):
        return tf.image.resize(img, size=(299, 299), method=self.scaling_func)

    def gray_to_color(self, img):
        return tf.tile(img, multiples=[1, 1, 1, 3])

    @tf.function
    def calculate_fid(self, feat1, feat2):
        mu1, sigma1 = tf.reduce_mean(feat1, axis=0), tfp.stats.covariance(feat1)
        mu2, sigma2 = tf.reduce_mean(feat2, axis=0), tfp.stats.covariance(feat2)
        ssdiff = tf.reduce_sum(tf.square(mu1 - mu2))

        mu1 = tf.cast(mu1, tf.float64)
        mu2 = tf.cast(mu2, tf.float64)
        sigma1 = tf.cast(sigma1, tf.float64)
        sigma2 = tf.cast(sigma2, tf.float64)
        ssdiff = tf.cast(ssdiff, tf.float64)

        eps = tf.constant(EPS_VAL, dtype=tf.float64)
        offset = tf.eye(LENGTH_FEATURE_VEC, dtype=tf.float64) * eps
        tdot = tf.tensordot(sigma1 + offset, sigma2 + offset, axes=1)
        covmean = tf.linalg.sqrtm(tdot)
        covmean = tf.math.real(covmean)
        fid = ssdiff + tf.linalg.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    @tf.function
    def process(self, img: tf.Tensor):
        """
        Args:
            img (tf.Tensor): [B, H, W, C] where each element R[0, 255]
        """
        if img.shape[-1] == 1:
            img = self.gray_to_color(img)
        img = tf.cast(img, tf.float64)
        img = self.rescale_img_size(img)
        img = tf.keras.applications.inception_v3.preprocess_input(img)
        return img

    @tf.function
    def get_feature(self, img: tf.Tensor):
        return self.model(img)

    def calculate_fid_with_ds(self, ds1: tf.data.Dataset, ds2: tf.data.Dataset):

        ds1 = ds1.map(self.process).batch(self.batch_size).prefetch(AUTOTUNE)
        ds2 = ds2.map(self.process).batch(self.batch_size).prefetch(AUTOTUNE)

        feature1 = tf.zeros(shape=(0, 2048), dtype=tf.float32)
        feature2 = tf.zeros(shape=(0, 2048), dtype=tf.float32)
        for idx, batch1, batch2 in tqdm(
            zip(range(self.num_samples // self.batch_size), ds1, ds2)
        ):
            feat1 = self.get_feature(batch1)
            feat2 = self.get_feature(batch2)
            feature1 = tf.concat([feature1, feat1], axis=0)
            feature2 = tf.concat([feature2, feat2], axis=0)
        fid = self.calculate_fid(feature1, feature2)
        return fid

使い方

if __name__ == "__main__":
    import pathlib
    import tensorflow_datasets as tfds

    @tf.function
    def extract_image(sample: Dict):
        img = tf.cast(sample["image"], tf.float32)
        shapes = tf.shape(img)
        h, w = shapes[-3], shapes[-2]
        small = tf.minimum(h, w)
        img = tf.image.resize_with_crop_or_pad(img, small, small)
        return img

    import matplotlib.pyplot as plt

    ds1 = tfds.load("cifar10")["train"].map(extract_image)
    ds2 = tfds.load("cifar10")["test"].map(extract_image)

    fid = FID(num_samples=10000)
    print(fid.calculate_fid_with_ds(ds1, ds2))

output

tf.Tensor(5.46545230432653, shape=(), dtype=float64)

ちなみに、 num_samples は、FID 計算に用いる画像の数です。データセットに含まれる画像数全部を入れるといいんじゃないっすかね。

なお、サンプルの数とFIDスコアの関係は次のとおりです。

FIDvsSample.png

使ったコードは以下です。

if __name__ == "__main__":
    import pathlib
    import tensorflow_datasets as tfds

    @tf.function
    def extract_image(sample: Dict):
        img = tf.cast(sample["image"], tf.float32)
        shapes = tf.shape(img)
        h, w = shapes[-3], shapes[-2]
        small = tf.minimum(h, w)
        img = tf.image.resize_with_crop_or_pad(img, small, small)
        return img

    import matplotlib.pyplot as plt

    ds1 = tfds.load("cifar10")["train"].map(extract_image)
    ds2 = tfds.load("cifar10")["test"].map(extract_image)

    idxs = []
    fids = []
    for i in range(3, 11):
        fid = FID(num_samples=i * 1000, batch_size=100)
        score = fid.calculate_fid_with_ds(ds1, ds2)
        print("samples {}: {}".format(i * 1000, score))
        fids.append(score.numpy())
        idxs.append(i * 1000)
    plt.plot(idxs, fids, marker="o")
    plt.title("FID Score vs Sample Size")
    plt.ylabel("FID Score")
    plt.xlabel("Sample Size")
    plt.show()
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む