shine-Notes

ゆるふわ思考ダンプ

tensorflow v2 のBatchNormalizationを使うと参考書のGAN訓練(tf1.x)が失敗した話

TL;DR

  • tf1.xベースの所謂keras-GANベースの実装をtensorflow2.xで実行すると、学習失敗する
  • BatchNormalizationの部分をtf.compat.v1.keras.layers.BatchNormalizationに変更すると成功する
  • 原因はkeras-GANにおけるtrainableの切り替え。v2のBatchNormalizationはtrainableで学習/推論モードを切り替えるので、Generatorの学習モデルcompile時点でDescriminatorにtrainable is Falseを与えるkeras-GANの書き方はそぐわない。
    • (…と思われる,誤ってそうならウェルカム指摘)

背景

読書会で以下の本に取り組んでいるのだが、「なんかGoogle Colabで動かすとGANの学習失敗しない?」という話になった。

参加メンバが調査した所、

  • tf1.xだと学習はうまくいく
    • → マジックコマンド%tensorlow version 1.x
  • 更に、レイヤを部分的にすげ替えていくと、どうもBathNormalization()をv1にすると上手くいく
    • → BatchNormalizationの部分だけをtf.compat.v1.keras.layers.BatchNormalizationにする

という所まで判明した。とはいえ、

(´・ω・`)「いやいやいや、そもそもBathNormの仕様なんて1と2で変わるんかいな???」

個人的にここが引っかかり続けたので調べることにした。結果、結構勉強になったので、本記事はその記録である。結論はTL;DRの通りなので、以下は読み物がてら。

tf1.x→2.xにおけるBatchNormの変更点

まずは順当に、tf1.xと2.xにおけるBatchNormalizationの仕様を確認する。

一応ソースも見たが、先に白状しておくと、一番分かりやすかったのはnkmk.me様の記事だ。 note.nkmk.me



結論としては、v2のBatchNormはtrainable is Falseならtraining is Falseに分岐する(v1は分岐しない)というのがポイント

  • tf.compat.v1.keras.layers.BatchNormalizationにを呼ぶと、_USE_V2_BEHAVIOR is FalseBatchNormalizationBaseクラスが呼び出される
    • _USE_V2_BEHAVIOR is False、つまりv1を選ぶと、
      • fused(高速化された処理?)のコントロールが行われない
      • trainable is Falseならtraining is Falseへの分岐が行われない。
        • → つまり、v1のBatchNormはtrainable is Falseでもtraining is Trueになりうるということ

前提として、BathNormalization()にはtrainingtrainable2つのパラメータが有ることを抑えておく必要がある。

  • BathNormalization()はNormalizationの単位をパラメータとして学習する
    • 訓練時はミニバッチの平均と分散で正規化が行われる
    • 推論時は、訓練時に得た正規化パラメータ(平均と分散)を元に、入力データをNormalizeする
  • この切替はBathNormalization()のメンバ変数?であるtrainingにもっている
  • それとは別に、compileされたmodelとしてはtrainableをもっていて、これは明示的に設定もできるし、呼び出しメソッド(fit, predict)によっても変更される

keras-GANにおけるtrainableのコントロール

この時点で悪さをしてそうなのは多分trainableだろうとあたりをつけつつ、じゃあDCGANの実装でtrainableってどうコントロールしてるの?と思い実装コードを確認していく。結論から言うと今回の現象の原因は、訓練ループ前にganモデルをcompileする部分にある。

# Keep Discriminator’s parameters constant for Generator training   
discriminator.trainable = False

# Build and compile GAN model with fixed Discriminator to train the Generator
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

この辺りはDCGANの学習ループの流れとコードをちゃんと理解していないと混乱する(筆者は最初混乱した)。

学習for文に入る前のDescriminatorのtrainableの値はFalseになっている……のだが、そもそもkeras-GANの実装では、trainableの変更による重みの凍結はmodel.compileしないと反映されないという性質を利用している。上記の場合、discriminator.trainable = Falseは、直後にcompileされているganの計算グラフにのみ適用される(そしてganもといGeneratorの訓練時はDiscriminatorの学習を止める必要があつので、これで正しい)

またこの書き方はGans in Actionだけでなく、KerasのGAN実装で割とポピュラーなKeras-GANで採用されており、KerasのGAN実装では他にも見かけることが有る。 github.com

  • そして同じ疑問をこのRepoのIssueに挙げてる人も居た。気持ちはよくわかる https://github.com/eriklindernoren/Keras-GAN/issues/73#issuecomment-413105959

    そしてBatchNormalization()はtf2.0から、trainable is Falseならtraining is Falseに分岐する仕様に変わっている…ということで、tf2.0のままkeras-ganを実行すると、

  • Discriminatorはモデル全体の重みは更新されるが、BathNormalization()だけは推論モードで実行される。

    • 初期値がN(0, 1)なので、おそらく入力値に関わらずこの値で正規化される
  • Generatorも同様。
  • これによる学習のバランスが想定と変わり、学習失敗する。

まとめ

以上。ちなみに「じゃあBathNormalization()をGANで使うとき、Generatorの学習時はどうやってDescriminatorの重みをfreezeするの?」という疑問が残るが、個人的にはtf2.0だとGradiantTapeを使った書き方で割と自然にGとDを個別で計算して損失値をループの中で渡せるので、問題ないんではと思っている。公式実装参照。

www.tensorflow.org

なんでまぁ、tf1.xの実装を2.xで動かさないようにしましょうね、というだけの話なのだが、途中でも書いた通り個人的に勉強になったので、経過を残したかった次第。


以上