思考の本棚

機械学習のことや読んだ本の感想を整理するところ

PyTorch-Lightningでの再現性

はじめに

PyTorchで再現性を持たせたい場合、lightningのseed_everything関数や以下のような自作関数を使うケースが多いと思います。

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

以前まではこれでseedを固定できていたのですが、PyTorch-Lightningに切り替えた場合、同じコードを実行してもlossや予測値がブレる現象が起きたのでその原因と対策を記録しておきます。

結論

pytorch_lightning.Trainerクラスの引数でdeterministic=Trueとする。

pytorch_lightning.Trainerはデフォルトではdeterministic=Falseとなっています。 これにより上記関数でtorch.backends.cudnn.deterministic = Trueと設定していてもTrainerによりFalseに上書きされていました。 厳密な再現性を望まない場合は不要な設定かもしれませんが、比較実験を行う場合はこの設定をしておく必要があるかと思います。 pytorch-lightning.readthedocs.io

その他

今回はlightning周りで再現性がとれない現象が起きていましたが、Pytorch側に問題がある場合は公式の以下のリンクをまず確認した方が良さそうです。

pytorch.org