- 投稿日:2021-08-27T19:54:56+09:00
tf.data、TFRecordsを使った画像読み込み (tensorflow 2.6.0~)
チュートリアルとかにある全てnumpy配列に格納するタイプではなく、大規模データなどで使われる、バッチ毎に読み込む方法。 あまり最近のが無かったので、Qiitaで記事にした。 tensorflowは2.6.0を前提。 import glob import math import os from typing import List, Tuple, Union import albumentations as albu # 必要に応じて import tensorflow as tf import numpy as np # 必要に応じて AUTOTUNE = tf.data.experimental.AUTOTUNE tf.data 画像の読み込み tf.dataを使う場合、まずは読み込む用の関数を用意する。 簡単に書けば、こういうものになる。 @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) return img 高速化したいので、tf.functionでラッパーする。 (tf.dataを使うからargは絶対Tensorになるんだけど)、experimental_follow_type_hints=Trueを入れてtype hintsにtf.Tensorを入れておくと、たとえpythonのstringが入ってきても自動的にTensor型に変換してくれるので、再トレーシングとかが気にしなくても良くなる。これは今回のようなtf.dataに限らず、いろんな場面で使えるので、覚えておいて損はない。 tf.decode_imageは自動的にファイル形式を識別して読み込んでくれるので、decode_jpegとかdecode_pngとかよりも使いやすい。expand_animations=Falseでgifとか入ってきてもちゃんと3次元で返ってくれるようになる。 前処理 やり方は主に2種類。KerasのPreprocessing Layerも入るかもしれない。 1. tensorflow公式が準備している前処理関数を使う(tf.image、tensorflow addonsのtfa.image ) 2. 他モジュールやImageDataGeneratorを、tf.py_function経由で前処理をする(※) バリエーションが多い順に並べると後者が強い。ただし、TPUやマルチGPU環境ではうまく動作しないらしい。 1の例で、ランダムクロップと正規化を反映させたい場合、こんな感じ。 def preprocess(img: tf.Tensor) -> tf.Tensor: crop_size = tf.constant((128,128, img.shape[-1])) cropped = tf.image.random_crop(img, crop_size) normalized = tf.cast(cropped, tf.float32) / 255. return normalized @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) return preprocess(img) 他にも、いろいろな前処理があるので、公式ドキュメントや画像の水増し方法をTensorFlowのコードから学ぶなどを参照。 2の場合、albumentationsを使ってみるとこんな感じ。transformsはだいぶ昔に使ってたものを引っ張ってきた。 # callの際に、適当な前処理をどれか1つ選んで行う。 transforms = albu.OneOf([ albu.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90), albu.GaussNoise(), albu.Equalize(), albu.ElasticTransform(), albu.GaussianBlur(), albu.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20), albu.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5), albu.HorizontalFlip()]) def preprocess(img) -> tf.Tensor: augmented = transforms(image=img.numpy())['image'] return tf.convert_to_tensor(augmented, tf.float32) @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) img = tf.py_function(preprocess, [img], [tf.float32]) return preprocess(img) Datasetへ ここは公式のチュートリアルの方が良いので、公式を見るべし。 自分の場合、こういう形をよく作っている。 def generate_dataset(image_list: List[str], batch_size: int = 32, name: str = 'train'): """ Parameters -------- image_list: list (str, ...) 読み込む画像のパスが含まれるリスト batch_size: int, default = 32 バッチサイズ name: str, default = `train` データセットの名前。trainだと、シャッフルされる。 """ ds = tf.data.Dataset.from_tensor_slices(image_list) ds = ds.map(read_and_preprocess, num_parallel_calls=AUTOTUNE) if name == 'train': ds = ds.shuffle(2048) ds = ds.batch(batch_size) # check size size = ds.cardinality().numpy() if size > 0: print(f'{name} steps per epoch: {size}') ds = ds.prefetch(AUTOTUNE) return ds 例: p = glob.glob('samples/*') # 適当なフォルダを指定 dataset = generate_dataset(image_list = p, batch_size = 2) # データセット生成 x = iter(dataset) print(next(x)) 出力: train steps per epoch: 15 <tf.Tensor: shape=(2, 128, 128, 3), dtype=float32, numpy= array([[[[0.9490196 , 0.93333334, 0.92941177], [0.9529412 , 0.9372549 , 0.93333334], [0.95686275, 0.9411765 , 0.9372549 ], ..., [0.9137255 , 0.91764706, 0.89411765], [0.90588236, 0.9098039 , 0.8862745 ], [0.9019608 , 0.90588236, 0.88235295]]]], dtype=float32)> これでデータセットができた。あとはkerasのmodel.fitに入れるなり、train loopで使うなり。 TFRecord こちらは、「ネットワーク経由で、画像を1枚1枚読み込むのは非効率なため、バッチ毎にデータを取得して読み込みたい」といったことをしたいときに使える(ストレージサーバからのデータ取得等)。分散学習(Tensorflow Federated)だったり、TPUだと良く使われるらしい。 データの書き出し 前処理とシリアライズ まずは前処理とシリアライズ用の関数を準備。今回はSemantic Segmentationなどで使われることを想定して書く(画像とマスクがセットになる)。 以下は例で、前処理をカスタマイズしたい場合は、preprocess_imageとpreprocess_maskを好きに変える。学習はしないので、型はnumpyでもok。よってalubumentationやcv2等が使える。 # 学習はしないので、tf.functionはつけなくて良い def as_tensor(func): def function(*args, **kwargs) -> tf.Tensor: ret = func(*args, **kwargs) if not isinstance(ret, tf.Tensor): return tf.convert_to_tensor(ret) return ret return function @as_tensor def preprocess_image(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)): """ 画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) img = tf.image.resize(img, resize_shape) # ここに水増し用の関数を入れる # e.g. # img = transforms(image=img.numpy())['image'] img = tf.convert_to_tensor(img, tf.float32) img /= 255. # 水増し用の関数で正規化した場合、ここは不要 return img @as_tensor def preprocess_mask(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)): """ マスク画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) img = tf.image.resize(img, resize_shape) img = tf.cast(img, tf.float32) img /= 255. return img def serialize(base_path: tf.Tensor, mask_path: tf.Tensor): # 前処理とシリアライズ base_img = preprocess_image(base_path) base_img = tf.io.serialize_tensor(base_img) mask_img = preprocess_mask(mask_path) mask_img = tf.io.serialize_tensor(mask_img) bytes_base = tf.train.BytesList(value=[base_img.numpy()]) bytes_mask = tf.train.BytesList(value=[mask_img.numpy()]) features = tf.train.Features( feature={ 'base': tf.train.Feature(bytes_list = bytes_base), 'mask': tf.train.Feature(bytes_list = bytes_mask) } ) proto = tf.train.Example(features=features) return proto.SerializeToString() 特に前処理をしないなら、こういうのでok. def simply_load(img_path: Union[tf.Tensor, str]) -> tf.Tensor: """ マスク画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) return img def serialize(base_path: tf.Tensor, mask_path: tf.Tensor): # 前処理とシリアライズ base_img = simply_load(base_path) base_img = tf.io.serialize_tensor(base_img) mask_img = simply_load(mask_path) mask_img = tf.io.serialize_tensor(mask_img) ... 書き出し バッチサイズを指定して、(全ファイル数/バッチサイズ)個のファイルを作成する。 公式チュートリアルでも使われているtf.data.experimental.TFRecordWriterは2.6.0以降Deprecatedになったようで、tf.io.TFRecordWriterか、tf.data.experimental.save/tf.data.experimental.loadを使ってやってくれとのこと。後者は仕組みがよくわかっていないので、今回は前者で行う。 def split_path(filepathes: List[str], batch_size: int) -> List[List[str]]: total = math.ceil(len(filepathes)/batch_size) ret = [filepathes[i*batch_size:(i+1)*batch_size] for i in range(total)] if len(ret[-1]) == 0: ret = ret[:-1] return ret def write_to_tfrecord( image_list: List[str], mask_list: List[str], batch_size: int = 32, output_path: str = 'tfr_outputs', output_path_exist_ok: bool=False): """ Parameters -------- image_list: list (str, ...) 読み込む画像のパスが含まれるリスト mask_list: list (str, ...) 読み込む画像(マスク)のパスが含まれるリスト。image_listと同じ長さでなければならない。 batch_size: int, default = 32 バッチサイズ output_path: str, default = `tfr_outputs/` tfrecordファイルの保存先 output_path_exist_ok: bool, default = `False` output_pathの重複作成を許可するかどうか。Falseかつすでにフォルダが存在する場合。エラーが出る。 """ if len(image_list) != len(mask_list): raise ValueError('The sizes of image_list and mask_list do not match({} vs {}).'.format(len(image_list), len(mask_list))) if len(image_list) == 0: raise ValueError('Empty dataset.') os.makedirs(output_path, exist_ok=output_path_exist_ok) filepathes = list(zip(image_list, mask_list)) split_fs = split_path(filepathes, batch_size) size = len(split_fs) adjust_size = len(list(str(size))) for i, fs in enumerate(split_fs): output = os.path.join(output_path, 'batch_{}.tfrecord'.format(str(i+1).rjust(adjust_size, '0'))) print('\rWriting {}/{} to {}...'.format(str(i+1).rjust(adjust_size), size, output), end='') with tf.io.TFRecordWriter(output) as writer: for targets in fs: writer.write(serialize(*targets)) print('\nDone.') 例: write_to_tfrecord(image_list = p, mask_list = p, batch_size = 2) 出力: Writing 15/15 to tfr_outputs/batch_15.tfrecord... Done. データの読み込み tf.recordのファイルを読み込んでdeserializeする。 ここで前処理したいなら、deserializeの最後で前処理する関数などを入れる。 def deserialize(proto): parsed = tf.io.parse_example( proto, { 'base': tf.io.FixedLenFeature([], tf.string), 'mask': tf.io.FixedLenFeature([], tf.string) }) base_img = tf.io.parse_tensor(parsed['base'], out_type=tf.float32) mask_img = tf.io.parse_tensor(parsed['mask'], out_type=tf.float32) # 前処理したいならここにその操作を挿入 return base_img, mask_img def generate_dataset_from_tfrecord(path: str, batch_size: int = 32, name: str = 'train'): """ Parameters -------- path: str 読み込むtfrecordがあるフォルダ batch_size: int, default = 32 バッチサイズ name: str, default = `train` データセットの名前。trainだと、シャッフルされる。 """ file_list = [p for p in glob.glob(path + '/*') if os.path.isfile(p) and os.path.splitext(p)[1] == '.tfrecord'] if len(file_list) == 0: raise ValueError('Empty Dataset.') else: print(f'{name} steps per epoch: {len(file_list)}') ds = tf.data.TFRecordDataset(file_list) ds = ds.map(deserialize, num_parallel_calls=AUTOTUNE).batch(batch_size) if name == 'train': ds = ds.shuffle(2048) ds = ds.prefetch(AUTOTUNE) return ds 例: ds = generate_dataset_from_tfrecord('tfr_outputs', batch_size=2) x = next(iter(ds)) print(len(x)) print(x[0]) 出力: train steps per epoch: 15 2 tf.Tensor( [[[[0.775337 0.802788 0.8263174 ] [0.75114125 0.7785922 0.80212164] [0.7315334 0.7589844 0.7825138 ] ... [0.84117645 0.8372549 0.8215686 ] [0.84117645 0.8372549 0.8215686 ] [0.84432447 0.8404029 0.8247166 ]]]], shape=(2, 256, 256, 3), dtype=float32) ちゃんと型が復元されていることが分かる。
- 投稿日:2021-08-27T19:54:56+09:00
tf.data、TFRecordsを使った画像読み込み
チュートリアルとかにある全てnumpy配列に格納するタイプではなく、大規模データなどで使われる、バッチ毎に読み込む方法。 あまり最近のが無かったので、Qiitaで記事にした。 tensorflowは2.6.0を前提。 import glob import math import os from typing import List, Tuple, Union import albumentations as albu # 必要に応じて import tensorflow as tf import numpy as np # 必要に応じて AUTOTUNE = tf.data.experimental.AUTOTUNE tf.data 画像の読み込み tf.dataを使う場合、まずは読み込む用の関数を用意する。 簡単に書けば、こういうものになる。 @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) return img 高速化したいので、tf.functionでラッパーする。 (tf.dataを使うからargは絶対Tensorになるんだけど)、experimental_follow_type_hints=Trueを入れてtype hintsにtf.Tensorを入れておくと、たとえpythonのstringが入ってきても自動的にTensor型に変換してくれるので、再トレーシングとかが気にしなくても良くなる。これは今回のようなtf.dataに限らず、いろんな場面で使えるので、覚えておいて損はない。 tf.decode_imageは自動的にファイル形式を識別して読み込んでくれるので、decode_jpegとかdecode_pngとかよりも使いやすい。expand_animations=Falseでgifとか入ってきてもちゃんと3次元で返ってくれるようになる。ただし、jpegは注意点があるので、こちらを参照。 前処理 やり方は主に2種類。KerasのPreprocessing Layerも入るかもしれない。 1. tensorflow公式が準備している前処理関数を使う(tf.image、tensorflow addonsのtfa.image ) 2. 他モジュールやImageDataGeneratorを、tf.py_function経由で前処理をする(※) バリエーションが多い順に並べると後者が強い。ただし、TPUやマルチGPU環境ではうまく動作しないらしい。 1の例で、ランダムクロップと正規化を反映させたい場合、こんな感じ。 def preprocess(img: tf.Tensor) -> tf.Tensor: crop_size = tf.constant((128,128, img.shape[-1])) cropped = tf.image.random_crop(img, crop_size) normalized = tf.cast(cropped, tf.float32) / 255. return normalized @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) return preprocess(img) 他にも、いろいろな前処理があるので、公式ドキュメントや画像の水増し方法をTensorFlowのコードから学ぶなどを参照。 2の場合、albumentationsを使ってみるとこんな感じ。transformsはだいぶ昔に使ってたものを引っ張ってきた。 # callの際に、適当な前処理をどれか1つ選んで行う。 transforms = albu.OneOf([ albu.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90), albu.GaussNoise(), albu.Equalize(), albu.ElasticTransform(), albu.GaussianBlur(), albu.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20), albu.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5), albu.HorizontalFlip()]) def preprocess(img) -> tf.Tensor: augmented = transforms(image=img.numpy())['image'] return tf.convert_to_tensor(augmented, tf.float32) @tf.function(experimental_follow_type_hints=True) def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor: # tf.ioでファイルを読み込んでから画像にする _read = tf.io.read_file(img_path) img = tf.image.decode_image(_read, channels=3, expand_animations=False) img = tf.py_function(preprocess, [img], [tf.float32]) return preprocess(img) Datasetへ ここは公式のチュートリアルの方が良いので、公式を見るべし。 自分の場合、こういう形をよく作っている。 def generate_dataset(image_list: List[str], batch_size: int = 32, name: str = 'train'): """ Parameters -------- image_list: list (str, ...) 読み込む画像のパスが含まれるリスト batch_size: int, default = 32 バッチサイズ name: str, default = `train` データセットの名前。trainだと、シャッフルされる。 """ ds = tf.data.Dataset.from_tensor_slices(image_list) ds = ds.map(read_and_preprocess, num_parallel_calls=AUTOTUNE) if name == 'train': ds = ds.shuffle(2048) ds = ds.batch(batch_size) # check size size = ds.cardinality().numpy() if size > 0: print(f'{name} steps per epoch: {size}') ds = ds.prefetch(AUTOTUNE) return ds 例: p = glob.glob('samples/*') # 適当なフォルダを指定 dataset = generate_dataset(image_list = p, batch_size = 2) # データセット生成 x = iter(dataset) print(next(x)) 出力: train steps per epoch: 15 <tf.Tensor: shape=(2, 128, 128, 3), dtype=float32, numpy= array([[[[0.9490196 , 0.93333334, 0.92941177], [0.9529412 , 0.9372549 , 0.93333334], [0.95686275, 0.9411765 , 0.9372549 ], ..., [0.9137255 , 0.91764706, 0.89411765], [0.90588236, 0.9098039 , 0.8862745 ], [0.9019608 , 0.90588236, 0.88235295]]]], dtype=float32)> これでデータセットができた。あとはkerasのmodel.fitに入れるなり、train loopで使うなり。 TFRecord こちらは、「ネットワーク経由で、画像を1枚1枚読み込むのは非効率なため、バッチ毎にデータを取得して読み込みたい」といったことをしたいときに使える(ストレージサーバからのデータ取得等)。分散学習(Tensorflow Federated)だったり、TPUだと良く使われるらしい。 データの書き出し 前処理とシリアライズ まずは前処理とシリアライズ用の関数を準備。今回はSemantic Segmentationなどで使われることを想定して書く(画像とマスクがセットになる)。 以下は例で、前処理をカスタマイズしたい場合は、preprocess_imageとpreprocess_maskを好きに変える。学習はしないので、型はnumpyでもok。よってalubumentationやcv2等が使える。 # 学習はしないので、tf.functionはつけなくて良い def as_tensor(func): def function(*args, **kwargs) -> tf.Tensor: ret = func(*args, **kwargs) if not isinstance(ret, tf.Tensor): return tf.convert_to_tensor(ret) return ret return function @as_tensor def preprocess_image(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)): """ 画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) img = tf.image.resize(img, resize_shape) # ここに水増し用の関数を入れる # e.g. # img = transforms(image=img.numpy())['image'] img = tf.convert_to_tensor(img, tf.float32) img /= 255. # 水増し用の関数で正規化した場合、ここは不要 return img @as_tensor def preprocess_mask(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)): """ マスク画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) img = tf.image.resize(img, resize_shape) img = tf.cast(img, tf.float32) img /= 255. return img def serialize(base_path: tf.Tensor, mask_path: tf.Tensor): # 前処理とシリアライズ base_img = preprocess_image(base_path) base_img = tf.io.serialize_tensor(base_img) mask_img = preprocess_mask(mask_path) mask_img = tf.io.serialize_tensor(mask_img) bytes_base = tf.train.BytesList(value=[base_img.numpy()]) bytes_mask = tf.train.BytesList(value=[mask_img.numpy()]) features = tf.train.Features( feature={ 'base': tf.train.Feature(bytes_list = bytes_base), 'mask': tf.train.Feature(bytes_list = bytes_mask) } ) proto = tf.train.Example(features=features) return proto.SerializeToString() 特に前処理をしないなら、こういうのでok. def simply_load(img_path: Union[tf.Tensor, str]) -> tf.Tensor: """ マスク画像の前処理関数 """ img = tf.io.read_file(img_path) img = tf.image.decode_image(img, channels=3, expand_animations=False) return img def serialize(base_path: tf.Tensor, mask_path: tf.Tensor): # 前処理とシリアライズ base_img = simply_load(base_path) base_img = tf.io.serialize_tensor(base_img) mask_img = simply_load(mask_path) mask_img = tf.io.serialize_tensor(mask_img) ... 書き出し バッチサイズを指定して、(全ファイル数/バッチサイズ)個のファイルを作成する。 公式チュートリアルでも使われているtf.data.experimental.TFRecordWriterは2.6.0以降Deprecatedになったようで、tf.io.TFRecordWriterか、tf.data.experimental.save/tf.data.experimental.loadを使ってやってくれとのこと。後者は仕組みがよくわかっていないので、今回は前者で行う。 def split_path(filepathes: List[str], batch_size: int) -> List[List[str]]: total = math.ceil(len(filepathes)/batch_size) ret = [filepathes[i*batch_size:(i+1)*batch_size] for i in range(total)] if len(ret[-1]) == 0: ret = ret[:-1] return ret def write_to_tfrecord( image_list: List[str], mask_list: List[str], batch_size: int = 32, output_path: str = 'tfr_outputs', output_path_exist_ok: bool=False): """ Parameters -------- image_list: list (str, ...) 読み込む画像のパスが含まれるリスト mask_list: list (str, ...) 読み込む画像(マスク)のパスが含まれるリスト。image_listと同じ長さでなければならない。 batch_size: int, default = 32 バッチサイズ output_path: str, default = `tfr_outputs/` tfrecordファイルの保存先 output_path_exist_ok: bool, default = `False` output_pathの重複作成を許可するかどうか。Falseかつすでにフォルダが存在する場合。エラーが出る。 """ if len(image_list) != len(mask_list): raise ValueError('The sizes of image_list and mask_list do not match({} vs {}).'.format(len(image_list), len(mask_list))) if len(image_list) == 0: raise ValueError('Empty dataset.') os.makedirs(output_path, exist_ok=output_path_exist_ok) filepathes = list(zip(image_list, mask_list)) split_fs = split_path(filepathes, batch_size) size = len(split_fs) adjust_size = len(list(str(size))) for i, fs in enumerate(split_fs): output = os.path.join(output_path, 'batch_{}.tfrecord'.format(str(i+1).rjust(adjust_size, '0'))) print('\rWriting {}/{} to {}...'.format(str(i+1).rjust(adjust_size), size, output), end='') with tf.io.TFRecordWriter(output) as writer: for targets in fs: writer.write(serialize(*targets)) print('\nDone.') 例: write_to_tfrecord(image_list = p, mask_list = p, batch_size = 2) 出力: Writing 15/15 to tfr_outputs/batch_15.tfrecord... Done. データの読み込み tf.recordのファイルを読み込んでdeserializeする。 ここで前処理したいなら、deserializeの最後で前処理する関数などを入れる。 def deserialize(proto): parsed = tf.io.parse_example( proto, { 'base': tf.io.FixedLenFeature([], tf.string), 'mask': tf.io.FixedLenFeature([], tf.string) }) base_img = tf.io.parse_tensor(parsed['base'], out_type=tf.float32) mask_img = tf.io.parse_tensor(parsed['mask'], out_type=tf.float32) # 前処理したいならここにその操作を挿入 return base_img, mask_img def generate_dataset_from_tfrecord(path: str, batch_size: int = 32, name: str = 'train'): """ Parameters -------- path: str 読み込むtfrecordがあるフォルダ batch_size: int, default = 32 バッチサイズ name: str, default = `train` データセットの名前。trainだと、シャッフルされる。 """ file_list = [p for p in glob.glob(path + '/*') if os.path.isfile(p) and os.path.splitext(p)[1] == '.tfrecord'] if len(file_list) == 0: raise ValueError('Empty Dataset.') else: print(f'{name} steps per epoch: {len(file_list)}') ds = tf.data.TFRecordDataset(file_list) ds = ds.map(deserialize, num_parallel_calls=AUTOTUNE).batch(batch_size) if name == 'train': ds = ds.shuffle(2048) ds = ds.prefetch(AUTOTUNE) return ds 例: ds = generate_dataset_from_tfrecord('tfr_outputs', batch_size=2) x = next(iter(ds)) print(len(x)) print(x[0]) 出力: train steps per epoch: 15 2 tf.Tensor( [[[[0.775337 0.802788 0.8263174 ] [0.75114125 0.7785922 0.80212164] [0.7315334 0.7589844 0.7825138 ] ... [0.84117645 0.8372549 0.8215686 ] [0.84117645 0.8372549 0.8215686 ] [0.84432447 0.8404029 0.8247166 ]]]], shape=(2, 256, 256, 3), dtype=float32) ちゃんと型が復元されていることが分かる。