- 投稿日:2020-10-06T17:02:49+09:00
kerasのRNN系APIの引数return_state, return_sequencesについて
kerasのRNN系APIのGRUの引数return_state, return_sequencesについて
大雑把に書きました。
環境
python3使用する擬似データ
B = 1 #バッチサイズ T = 10 #時系列長 N = 1000 #特徴量 data = np.random.randn(B, T, N)使用するRNN系インターフェース
tf.keras.layers.GRUreturn_state=True, return_sequences=True
赤丸 : return_sequencesをTrue時
緑丸 : return_statesをTrue時gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=True) B = 1 T = 10 N = 1000 data = np.random.randn(B, T, N) outputs, states = gru(data) print("赤丸:", outputs.shape) print("緑丸:", states.shape)赤丸: (1, 10, 256) 緑丸: (1, 256)return_state=True, return_sequences=False
赤丸 : return_sequencesをFalse時
緑丸 : return_statesをTrue時gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=False) B = 1 T = 10 N = 1000 data = np.random.randn(B, T, N) outputs, states = gru(data) print("赤丸:", outputs.shape) print("緑丸:", states.shape)赤丸: (1, 256) 緑丸: (1, 256)return_state=False, return_sequences=True
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=True) B = 1 T = 10 N = 1000 data = np.random.randn(B, T, N) outputs = gru(data) print("赤丸:", outputs.shape) print("緑丸なし")赤丸: (1, 10, 256) 緑丸なしreturn_state=False, return_sequences=False
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=False) B = 1 T = 10 N = 1000 data = np.random.randn(B, T, N) outputs = gru(data) print("赤丸:", outputs.shape) print("緑丸なし")赤丸: (1, 256) 緑丸なし



