PyTorchでDCGANを実装する
はじめに
この記事ではPyTorchを使ってDCGANの解説および実装を行います。 今回はMNISTのデータセットを利用して、手書き数字の0~2の画像生成を行います。 DCGANの解説には元論文とDCGANの解説が非常にわかりやすい以下のスライドを使用します。 DCGANの実装には書籍『PyTorchで作る発展ディープラーニング』を使用しています。なお本記事でのコードの紹介はモデルの定義と訓練の課程のみとさせていただきます。詳細はGitHubをご確認ください。
www.slideshare.net
本記事で紹介するコードの全体 github.com
参考にした本の著者のGitHub github.com
DCGANの概要
DCGANは2016年に出たGANの派生系です。オリジナルのGANが多層パーセプトロンでモデルを作成していたのに対してDCGANでは画像認識タスクで有効とされているDeepなCNN構造をGANに適応することでGANの表現力をあげたGANの歴史の中でも非常に重要な存在です。 ただしGANの構造自体はそれほど変わりません。生成モデルであるGeneratorは多次元ノイズを入力とし、画像を生成し出力します。識別モデルであるDiscriminatorは、Generatorによって生成された画像と本物の画像を入力とし、各画像の識別結果を出力します。GeneratorはDiscriminatorを騙すために本物そっくりの画像を生成するように、DiscriminatorはGeneratorが生成した画像と本物画像をしっかり識別するように、2つのモデルが互いに敵対しながら学習するのがGANの基本的な構造です。
話をDCGANに戻します。GANにCNNを適応しようという取り組みはDCGAN以前にも行われていましたが、GANの学習における不安定性がCNNの適応を困難なものにしていたようです。学習の安定化のためにDCGANでは以下の事項に変更が加えられています。
- GANにPooling層のないCNNを導入
- Batch Normalizationを導入
- 隠れ層では全結合層を使用しない
- Generator/Discriminatorの活性化関数にはReLU/LeakyReLUをそれぞれ使用(ただしGeneratorの出力層はTanh)
以上がDCGANの概要となります。
Generator と Disicriminatorの解説と実装
それではGeneratorとDiscriminatorの解説をしていきます。
Generator(生成モデル)
こちらの図がGeneratorの構造を示したものです。
Generatorは100次元のノイズzを入力し4層のCNNによって64*64の画像G(z)を生成します。
CNNの1つの層は①転置畳み込み層 + ②Batch Norm + ③ReLUの3つで構成されています。
転置畳み込み層はCNNで一般的に用いられる畳み込み層の逆っぽい操作をします。
こちらのサイトが畳み込み層と転置畳み込み層をアニメーションで示しており非常にわかりやすいです。
(上記サイトより引用)
このアニメーションではpadding=1, stride=1, kernel=(4*4)の条件で転置畳み込みを行なっていますが、今回のGeneratorで用いる転置畳み込み層はstride=2です。
以下、Generatorの実装になります。
class Generator(nn.Module): # nzは入力ノイズの次元数 def __init__(self, nz, image_size): super(Generator, self).__init__() self.layer1 = nn.Sequential( nn.ConvTranspose2d(nz, image_size*8, kernel_size=4, stride=1), nn.BatchNorm2d(image_size*8), nn.ReLU(inplace=True)) self.layer2 = nn.Sequential( nn.ConvTranspose2d(image_size*8, image_size*4, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(image_size * 4), nn.ReLU(inplace=True)) self.layer3 = nn.Sequential( nn.ConvTranspose2d(image_size*4, image_size*2, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(image_size * 2), nn.ReLU(inplace=True)) self.layer4 = nn.Sequential( nn.ConvTranspose2d(image_size*2, image_size, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(image_size), nn.ReLU(inplace=True)) # MNISTなので出力のチャンネルは1 self.last = nn.Sequential( nn.ConvTranspose2d(image_size, 1, kernel_size=4, stride=2, padding=1), nn.Tanh()) # zは乱数で生成するノイズ def forward(self, z): out = self.layer1(z) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = self.last(out) return out
zを入力としlayer1から出力層まで順に流していく実装になります。 4つのlayerを通ったあと、活性化関数Tanhを使用して-1から1の出力結果になるようにします。
Discriminator(識別モデル)
こちらの図は元論文には掲載されていなかったのですがこちらのサイトより引用したDiscriminatorの構造を示した図です。
Discriminatorは一般的に画像分類タスクに使用されるCNNと同様のものと考えてもらって構いません。
異なる点としては各layerの活性化関数ReLUの代わりにleakyReLUを使用していることです。
以下、 Discriminatorの実装です。
class Discriminator(nn.Module): def __init__(self, nz, image_size): super(Discriminator, self).__init__() # MNISTなので入力チャンネルは1 self.layer1 = nn.Sequential( nn.Conv2d(1, image_size, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.1, inplace=True)) self.layer2 = nn.Sequential( nn.Conv2d(image_size, image_size*2, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.1, inplace=True)) self.layer3 = nn.Sequential( nn.Conv2d(image_size*2, image_size*4, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.1, inplace=True)) self.layer4 = nn.Sequential( nn.Conv2d(image_size*4, image_size*8, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.1, inplace=True)) self.last = nn.Conv2d(image_size*8, 1, kernel_size=4, stride=1) # 入力xは本物画像or生成画像 def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = self.last(out) return out
損失関数の解説と実装
ここではDCGANの損失関数について解説していきます。なお実装では訓練の過程やデータの受け渡しも行っておりますが、解説は損失関数のみに絞らさせていただきます。 はじめにDiscriminatorの損失関数について考えていきます。 画像を入力としたときで表せます。また正しいラベルはGが生成した画像を、訓練データセット画像をと定義します。 そうした場合Discriminatorの出力は
で表せます。実際のDiscriminatorの出力はバッチサイズ個分の同時確率となるので、これの対数をとって
この式を最大化するように最適化することがDiscriminatorの学習です。 ちなみにこの式を具体的に考えたとき、Discriminatorの学習とは のラベルを持つ本物画像を入力したとき、を最大化することであり、 のラベルを持つ偽物画像を入力したとき、を最大化することです。 これは本物画像を与えたときは1(本物ラベルと同じ)を出力し、偽物画像を入力したときは0(偽物ラベルと同じ)を出力するように学習することを示しており、直感的にもわかりやすくDiscriminatorの学習を示しています。
次にGeneratorの損失関数について考えていきます。 GeneratorはDiscriminatorをだましたいのでDiscriminatorが最大化しようとしている式を最小化すればよいです。 つまりDiscriminatorが最大化しようとしていた以下の式、
この式を最小化することがGeneratorの学習となります。
Generatorの損失関数を考える際は入力はGeneratorが生成した画像のみで良いので、の時を考えます。 このとき
となります。ただしこの式はGeneratorの学習が進みずらいことが分かっています。そこで要はが1と判定してくれればよいだろうと考えてDCGANではGの損失関数を
としています。
両者の損失関数は2値の交差エントロピーを示しているため、PyTorchの実装ではnn.BCEWithLoss()を使用することで簡単に記述できます。 以下損失関数を含めた訓練課程の実装です。 なおここではDatasetやDataloaderの定義の実装コードは掲載しておらず、あくまで訓練の流れを理解するために訓練課程のコードを載せています。
def train_model(G, D, dataloader, num_epochs, nz, mini_batch_size, device): # 最適化手法の設定 G_lr, D_lr = 0.0001, 0.0004 beta1, beta2 = 0.0, 0.9 optimizerG = torch.optim.Adam(G.parameters(), G_lr, [beta1, beta2]) optimizerD = torch.optim.Adam(D.parameters(), D_lr, [beta1, beta2]) # 誤差関数の定義 criterion = nn.BCEWithLogitsLoss(reduction="mean") # ネットワークをGPUへ G.to(device) D.to(device) # 訓練モードへ切り替え G.train() D.train() num_train_imgs = len(dataloader.dataset) batch_size = dataloader.batch_size iteration = 1 # 各epochでの損失を記録 G_losses = [] D_losses = [] print("Start training!") for epoch in tqdm(range(num_epochs)): for data in dataloader: # -------------------- # 1. Update D network # -------------------- # ミニバッチが1だとBatchNormでエラーが出るので回避 if data.size()[0] == 1: continue # GPU使えるならGPUにデータを送る data = data.to(device) # ラベルの作成 mini_batch_size = data.size()[0] # smoothing label を使って学習の安定化を図る real_label = torch.full((mini_batch_size,), 0.8).to(device) fake_label = torch.full((mini_batch_size,), 0).to(device) # 真の画像を判定 D_real_output = D(data) # 偽画像を生成して判定 z = torch.randn(mini_batch_size, nz).to(device) z = z.view(z.size(0), z.size(1), 1, 1) fake_imgs = G(z) D_fake_output = D(fake_imgs) # 誤差を計算 lossD_real = criterion(D_real_output.view(-1), real_label) lossD_fake = criterion(D_fake_output.view(-1), fake_label) lossD = lossD_real + lossD_fake # 誤差逆伝播 optimizerG.zero_grad() optimizerD.zero_grad() lossD.backward() optimizerD.step() # -------------------- # 2. Update G network # -------------------- # 偽画像を生成して判定 z = torch.randn(mini_batch_size, nz).to(device) z = z.view(z.size(0), z.size(1), 1, 1) fake_imgs = G(z) D_fake_output = D(fake_imgs) # 誤差を計算 lossG = criterion(D_fake_output.view(-1), real_label) # 誤差逆伝播 optimizerG.zero_grad() optimizerD.zero_grad() lossG.backward() optimizerG.step() # ----------- # 3. 記録 # ----------- D_losses.append(lossD.item()) G_losses.append(lossG.item()) iteration += 1 # 画像生成 if epoch % 20 == 0: # 画像表示用の自作関数 generate_img(G, dataloader, epoch, batch_size=16, nz=nz, device=device) return G, D, G_losses, D_losses
流れとしてはDataloaderからデータをバッチサイズごとに読み込みDiscriminatorの学習、Generatorの学習を行い、これを200epoch繰り返します。
画像は64px四方の0~2のMNISTの画像を各クラス200枚ずつ用意しています。最適化手法はAdamです。
訓練過程の実装で分かりにくいのはGeneratorの学習部分の損失関数の定義のところです。 Generatorの損失関数を先程のように少し変形しているためコードに起こす際はこのように記述します。
lossG = criterion(D_fake_output.view(-1), real_label)
画像生成結果
それでは学習をしてみた結果を示します。 20epochごとに画像生成したものがこちらです。 画像は学習が進むにつれて綺麗に数字を表現しています。しかし数字の2を訓練データに含んでいたにも関わらず0と1しか生成されていません。 これがモード崩壊と呼ばれる現象で、簡単に生成できる0と1のみを学習してしまった過学習のような状態です。
こちらは学習の損失関数の様子をプロットしたものです。 学習が上手くいっていると2つの損失関数が競合しながら減少していきます。このバランスが崩れるとモード崩壊となって全く同じ画像しか生成しなくなったり、綺麗な画像が生成されなかったりします。今回の結果を見ると途中からGeneratorの損失が増加しているのがみて取れます。このような場合は、パラメータを調整したりGANの構造に変更を加えることでDiscriminatorを弱くする必要があります。少し昔ではありますがDCGANなどの基本的な構造のGANで学習を安定化させる方法はこちらの論文が参考になります。 arxiv.org
おわりに
本記事ではDCGANの構造の解説とPyTorchを用いて簡単に実装の説明を行いました。 実際に自分で画像生成をしてみるとGANの不安定性が明らかになったように感じます。オリジナルのGANよりも安定性が向上したDCGANにおいても, 学習の安定性への工夫は重要そうです。