最近というか以前からText2Imageに興味があったので、留学や学会が一段落した今AttnGANを試して見ようと思いました。
だいたいは以下のGitHubの手順に従うと大丈夫です。
github.com
しかし、いくつか実行にあたり、修正した点があったので記載しておこうと思います。
その他必要ライブラリのインストール
AttnGAN公式には以下のPythonライブラリが必要と書かれています。
python-dateutil easydict pandas torchfile nltk scikit-image
加えて、実行時にpyyamlが必要と言われましたので、追加しました。
$ pip install pyyaml
Pytorchのバージョンに合わせた修正
AttnGANのコードは既に消えたPytorchの関数を含んでいるため修正します。
具体的にはゼロ次元テンソルの扱いです。
code/pretrain_DAMSM.py 103~107行目
変更前
s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL
変更後
s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL
code/pretrain_DAMSM.py 160, 161行目
変更前
s_cur_loss = s_total_loss[0] / step w_cur_loss = w_total_loss[0] / step
変更後
s_cur_loss = s_total_loss.item() / step w_cur_loss = w_total_loss.item() / step
code/miscc/utils.py 35行目
fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
を自身のTTFフォントのパスに変更します。
これは実行時にIOErrorが出た場合だけでいいと思います。
これでおそらく事前学習までは動くかと思います。