Harukaのnote

Linuxやプログラミング,写真,旅行等の記録帳

AttnGANを動かすメモ

最近というか以前からText2Imageに興味があったので、留学や学会が一段落した今AttnGANを試して見ようと思いました。

だいたいは以下のGitHubの手順に従うと大丈夫です。
github.com


しかし、いくつか実行にあたり、修正した点があったので記載しておこうと思います。

Pytorchのインストール

以下のサイトから自身の環境に適切なものを選びます。
pytorch.org

pip install torch で入るものとは微妙に異なるようで、エラーが出ました。


その他必要ライブラリのインストール

AttnGAN公式には以下のPythonライブラリが必要と書かれています。

python-dateutil
easydict
pandas
torchfile
nltk
scikit-image

加えて、実行時にpyyamlが必要と言われましたので、追加しました。

$ pip install pyyaml

Pytorchのバージョンに合わせた修正

AttnGANのコードは既に消えたPytorchの関数を含んでいるため修正します。

具体的にはゼロ次元テンソルの扱いです。

code/pretrain_DAMSM.py 103~107行目

変更前
s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL
s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL

w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL
w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL
変更後
s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

code/pretrain_DAMSM.py 160, 161行目

変更前
s_cur_loss = s_total_loss[0] / step
w_cur_loss = w_total_loss[0] / step
変更後
s_cur_loss = s_total_loss.item() / step
w_cur_loss = w_total_loss.item() / step

code/miscc/utils.py 35行目

fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)

を自身のTTFフォントのパスに変更します。
これは実行時にIOErrorが出た場合だけでいいと思います。



これでおそらく事前学習までは動くかと思います。

qiita.com