TL;DR
- tf1.xベースの所謂keras-GANベースの実装をtensorflow2.xで実行すると、学習失敗する
- BatchNormalizationの部分を
tf.compat.v1.keras.layers.BatchNormalization
に変更すると成功する - 原因はkeras-GANにおける
trainable
の切り替え。v2のBatchNormalizationはtrainable
で学習/推論モードを切り替えるので、Generatorの学習モデルcompile時点でDescriminatorにtrainable is False
を与えるkeras-GANの書き方はそぐわない。- (…と思われる,誤ってそうならウェルカム指摘)
背景
読書会で以下の本に取り組んでいるのだが、「なんかGoogle Colabで動かすとGANの学習失敗しない?」という話になった。
参加メンバが調査した所、
- tf1.xだと学習はうまくいく
- → マジックコマンド
%tensorlow version 1.x
- → マジックコマンド
- 更に、レイヤを部分的にすげ替えていくと、どうもBathNormalization()をv1にすると上手くいく
- → BatchNormalizationの部分だけを
tf.compat.v1.keras.layers.BatchNormalization
にする
- → BatchNormalizationの部分だけを
という所まで判明した。とはいえ、
(´・ω・`)「いやいやいや、そもそもBathNormの仕様なんて1と2で変わるんかいな???」
個人的にここが引っかかり続けたので調べることにした。結果、結構勉強になったので、本記事はその記録である。結論はTL;DRの通りなので、以下は読み物がてら。
tf1.x→2.xにおけるBatchNormの変更点
まずは順当に、tf1.xと2.xにおけるBatchNormalizationの仕様を確認する。
- 公式ドキュメントはこちら
- GitHubの実装ソースはこちら
- nkmk.me様の記事
一応ソースも見たが、先に白状しておくと、一番分かりやすかったのはnkmk.me様の記事だ。 note.nkmk.me
結論としては、v2のBatchNormはtrainable is False
ならtraining is False
に分岐する(v1は分岐しない)というのがポイント。
tf.compat.v1.keras.layers.BatchNormalization
にを呼ぶと、_USE_V2_BEHAVIOR is False
でBatchNormalizationBase
クラスが呼び出される_USE_V2_BEHAVIOR is False
、つまりv1を選ぶと、- fused(高速化された処理?)のコントロールが行われない
trainable is False
ならtraining is False
への分岐が行われない。- → つまり、v1のBatchNormは
trainable is False
でもtraining is True
になりうるということ
- → つまり、v1のBatchNormは
前提として、BathNormalization()にはtraining
とtrainable
2つのパラメータが有ることを抑えておく必要がある。
- BathNormalization()はNormalizationの単位をパラメータとして学習する
- 訓練時はミニバッチの平均と分散で正規化が行われる
- 推論時は、訓練時に得た正規化パラメータ(平均と分散)を元に、入力データをNormalizeする
- この切替は
BathNormalization()
のメンバ変数?であるtraining
にもっている - それとは別に、compileされたmodelとしては
trainable
をもっていて、これは明示的に設定もできるし、呼び出しメソッド(fit, predict)によっても変更される
keras-GANにおけるtrainableのコントロール
この時点で悪さをしてそうなのは多分trainable
だろうとあたりをつけつつ、じゃあDCGANの実装でtrainable
ってどうコントロールしてるの?と思い実装コードを確認していく。結論から言うと今回の現象の原因は、訓練ループ前にgan
モデルをcompile
する部分にある。
# Keep Discriminator’s parameters constant for Generator training discriminator.trainable = False # Build and compile GAN model with fixed Discriminator to train the Generator gan = build_gan(generator, discriminator) gan.compile(loss='binary_crossentropy', optimizer=Adam())
この辺りはDCGANの学習ループの流れとコードをちゃんと理解していないと混乱する(筆者は最初混乱した)。
学習for文に入る前のDescriminatorのtrainableの値はFalseになっている……のだが、そもそもkeras-GANの実装では、trainable
の変更による重みの凍結はmodel.compile
しないと反映されないという性質を利用している。上記の場合、discriminator.trainable = False
は、直後にcompile
されているgan
の計算グラフにのみ適用される(そしてgan
もといGeneratorの訓練時はDiscriminatorの学習を止める必要があつので、これで正しい)
またこの書き方はGans in Actionだけでなく、KerasのGAN実装で割とポピュラーなKeras-GANで採用されており、KerasのGAN実装では他にも見かけることが有る。
github.com
そして同じ疑問をこのRepoのIssueに挙げてる人も居た。気持ちはよくわかる https://github.com/eriklindernoren/Keras-GAN/issues/73#issuecomment-413105959
そしてBatchNormalization()はtf2.0から、trainable is False
ならtraining is False
に分岐する仕様に変わっている…ということで、tf2.0のままkeras-ganを実行すると、Discriminatorはモデル全体の重みは更新されるが、BathNormalization()だけは推論モードで実行される。
- 初期値がN(0, 1)なので、おそらく入力値に関わらずこの値で正規化される
- Generatorも同様。
- これによる学習のバランスが想定と変わり、学習失敗する。
まとめ
以上。ちなみに「じゃあBathNormalization()をGANで使うとき、Generatorの学習時はどうやってDescriminatorの重みをfreezeするの?」という疑問が残るが、個人的にはtf2.0だとGradiantTapeを使った書き方で割と自然にGとDを個別で計算して損失値をループの中で渡せるので、問題ないんではと思っている。公式実装参照。
なんでまぁ、tf1.xの実装を2.xで動かさないようにしましょうね、というだけの話なのだが、途中でも書いた通り個人的に勉強になったので、経過を残したかった次第。
以上