Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ iccv21

# data
data
stats
*.png
_cache_*
figures
Expand Down
52 changes: 3 additions & 49 deletions Readme.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
# Gradient Normalization for Generative Adversarial Networks

Yi-Lun Wu, Hong-Han Shuai, Zhi-Rui Tam, Hong-Yu Chiu

Paper: [https://arxiv.org/abs/2109.02235](https://arxiv.org/abs/2109.02235)

This is the official implementation of Gradient Normalized GAN (GN-GAN).

## NOTE
THIS IS A FORK OF A FORK TO TEST FOR SN
THE FILES ARE IN SN_TESTING
## Requirements
- Python 3.8.9
- Python packages
Expand Down Expand Up @@ -116,44 +111,3 @@ All the reported values (Inception Score and FID) in our paper are calculated by
--eval \
--save path/to/generated/images
```

## How to integrate Gradient Normalization into your work?
The function `normalize_gradient` is implemented based on `torch.autograd` module, which can easily normalize your forward propagation of discriminator by updating a single line.
```python
from torch.nn import BCEWithLogitsLoss
from models.gradnorm import normalize_gradient

net_D = ... # discriminator
net_G = ... # generator
loss_fn = BCEWithLogitsLoss()

# Update discriminator
x_real = ... # real data
x_fake = net_G(torch.randn(64, 3, 32, 32)) # fake data
pred_real = normalize_gradient(net_D, x_real) # net_D(x_real)
pred_fake = normalize_gradient(net_D, x_fake) # net_D(x_fake)
loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
(loss_real + loss_fake).backward() # backward propagation
...

# Update generator
x_fake = net_G(torch.randn(64, 3, 32, 32)) # fake data
pred_fake = normalize_gradient(net_D, x_fake) # net_D(x_fake)
loss_fake = loss_fn(pred_fake, torch.ones_like(pred_fake))
loss.backward() # backward propagation
...

```

## Citation
If you find our work is relevant to your research, please cite:
```
@InProceedings{GNGAN_2021_ICCV,
author = {Yi-Lun Wu, Hong-Han Shuai, Zhi Rui Tam, Hong-Yu Chiu},
title = {Gradient Normalization for Generative Adversarial Networks},
booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2021}
}
```
2 changes: 1 addition & 1 deletion config/GN-GAN-CR_CIFAR10_BIGGAN.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=125000
--batch_size_D=50
--batch_size_G=50
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0001
--betas=0.0
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN-CR_CIFAR10_CNN.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN-CR_CIFAR10_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=0
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0004
--lr_G=0.0002
--n_dis=5
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN-CR_STL10_CNN.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
Expand Down
25 changes: 25 additions & 0 deletions config/GN-GAN-CR_STL10_CNN_MODULE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
--dataset=stl10.48
--arch=dcgan.48
--loss=bce
--total_steps=200000
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
--z_dim=128
--cr=5
--n_classes=1

--ema_decay=0.9999
--ema_start=0

--sample_step=500
--sample_size=64
--eval_step=5000
--save_step=20000
--num_images=50000
--fid_stats=./stats/stl10.unlabeled.48.npz
--logdir=./logs/GN-GAN-CR_STL10_CNN_0
2 changes: 1 addition & 1 deletion config/GN-GAN-CR_STL10_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=0
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0004
--lr_G=0.0002
--n_dis=5
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN_CELEBAHQ128_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--batch_size_D=64
--batch_size_G=128
--accumulation=1
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=5
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN_CELEBAHQ256_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--batch_size_D=64
--batch_size_G=128
--accumulation=1
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=5
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN_CHURCH256_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--batch_size_D=64
--batch_size_G=128
--accumulation=1
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=5
Expand Down
4 changes: 2 additions & 2 deletions config/GN-GAN_CIFAR10_BIGGAN.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
--dataset=cifar10.32
--arch=biggan.32
--loss=hinge
--total_steps=125000
--lr_decay_start=125000
--total_steps=75000
--lr_decay_start=75000
--batch_size_D=50
--batch_size_G=50
--num_workers=8
Expand Down
27 changes: 27 additions & 0 deletions config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
--dataset=cifar10.32
--arch=biggan.32
--loss=hinge
--total_steps=75000
--lr_decay_start=75000
--batch_size_D=50
--batch_size_G=50
--num_workers=8
--lr_D=0.0002
--lr_G=0.0001
--betas=0.0
--betas=0.999
--n_dis=4
--z_dim=128
--cr=0
--n_classes=10

--ema_decay=0.9999
--ema_start=1000

--sample_step=500
--sample_size=64
--eval_step=5000
--save_step=20000
--num_images=50000
--fid_stats=./stats/cifar10.train.npz
--logdir=./logs/GN-GAN_CIFAR10_BIGGAN_MODULE_0
2 changes: 1 addition & 1 deletion config/GN-GAN_CIFAR10_CNN.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
Expand Down
25 changes: 25 additions & 0 deletions config/GN-GAN_CIFAR10_CNN_MODULE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
--dataset=cifar10.32
--arch=dcgan.32
--loss=bce
--total_steps=200000
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
--z_dim=128
--cr=0
--n_classes=1

--ema_decay=0.9999
--ema_start=0

--sample_step=500
--sample_size=64
--eval_step=5000
--save_step=20000
--num_images=50000
--fid_stats=./stats/cifar10.train.npz
--logdir=./logs/GN-GAN_CIFAR10_CNN_MODULE_0
2 changes: 1 addition & 1 deletion config/GN-GAN_CIFAR10_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=0
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0004
--lr_G=0.0002
--n_dis=5
Expand Down
2 changes: 1 addition & 1 deletion config/GN-GAN_STL10_CNN.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
Expand Down
25 changes: 25 additions & 0 deletions config/GN-GAN_STL10_CNN_MODULE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
--dataset=stl10.48
--arch=dcgan.48
--loss=bce
--total_steps=200000
--lr_decay_start=200000
--batch_size_D=64
--batch_size_G=128
--num_workers=10
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
--z_dim=128
--cr=0
--n_classes=1

--ema_decay=0.9999
--ema_start=0

--sample_step=500
--sample_size=64
--eval_step=5000
--save_step=20000
--num_images=50000
--fid_stats=./stats/stl10.unlabeled.48.npz
--logdir=./logs/GN-GAN_STL10_CNN_MODULE_0
2 changes: 1 addition & 1 deletion config/GN-GAN_STL10_RES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--lr_decay_start=0
--batch_size_D=64
--batch_size_G=128
--num_workers=8
--num_workers=10
--lr_D=0.0004
--lr_G=0.0002
--n_dis=5
Expand Down
1 change: 1 addition & 0 deletions generated/GN-GAN_CIFAR10_CNN_0/Results/results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
IS: 7.683(0.082), FID: 21.932
Loading