# CcGAN Train

This notebook is used to train CcGAN.

## Step 1 - Import Libraries and Set Arguments

First we import the third-party libraries.

In [None]:
import copy
import gc
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import timeit
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision
from torchvision.utils import save_image
from tqdm import tqdm

Next import our own models.

In [None]:
from models import *

And then we set the arguments.

In [None]:
# Overall Settings
# Set the CCGAN_DATA_PATH environment variable to point at the dataset on another machine;
# the hard-coded path below is only the fallback.
DATA_PATH = os.environ.get(
    "CCGAN_DATA_PATH", "/home/ubuntu/Desktop/Ra/Ra_128_indexed_binned.h5"
)
SAVE_OUTPUTS_DIR = "./output/saved_outputs"
SAVE_IMAGES_DIR = "./output/saved_images"
EMBED_MODELS_DIR = "./output/embed_models"
CCGAN_MODELS_DIR = "./output/CcGAN_models"
SEED = 42
NUM_WORKERS = 0  # DataLoader worker processes for the embedding-CNN loader

# Dataset
DATA_SPLIT = "train"  # "train" -> use `index_train` from the h5 file; anything else -> all data
MIN_LABEL = 1.5  # only labels strictly inside (MIN_LABEL, MAX_LABEL) are kept
MAX_LABEL = 4.9
NUM_CHANNELS = 3
IMG_SIZE = 128
MAX_NUM_IMG_PER_LABEL = 25  # cap on training images per distinct label value
MAX_NUM_IMG_PER_LABEL_AFTER_REPLICA = 0
SHOW_REAL_IMGS = True
VISUALIZE_FAKE_IMAGES = True

# Embedding Settings
BASE_LR_X2Y = 0.01  # base learning rate of the image->label CNN (net_embed)
BASE_LR_Y2H = 0.01  # base learning rate of the label->embedding net (net_y2h)

# GAN Settings
GAN = "CcGAN"
GAN_ARCH = "SAGAN"
NET_EMBED = "ResNet34_embed"  # one of "ResNet18_embed", "ResNet34_embed", "ResNet50_embed"
EPOCH_CNN_EMBED = 200
RESUMEEPOCH_CNN_EMBED = 0
EPOCH_NET_Y2H = 500
DIM_EMBED = 128
BATCH_SIZE_EMBED = 256

LOSS_TYPE_GAN = "hinge"  # "hinge" or "vanilla"
NITERS_GAN = 30000
RESUME_NITERS_GAN = 0  # > 0 resumes CcGAN training from the checkpoint saved at that iteration
SAVE_NITERS_FREQ = 5000
LR_G = 1e-4
LR_D = 1e-4
DIM_GAN = 128  # dimension of the generator's noise vector z
BATCH_SIZE_DISC = 64
BATCH_SIZE_GENE = 64
NUM_D_STEPS = 2  # discriminator updates per generator update
CGAN_NUM_CLASSES = 20
VISUALIZE_FREQ = 1000

KERNEL_SIGMA = -1.0  # negative -> estimated from the normalized training labels
THRESHOLD_TYPE = "soft"  # "hard" or "soft" vicinity
KAPPA = -2.0  # negative -> computed from the label gaps, with |KAPPA| as the multiplier m_kappa
NONZERO_SOFT_WEIGHT_THRESHOLD = 1e-3  # soft vicinity: samples whose weight would fall below this are not drawn

# DiffAugment Settings
GAN_DIFFAUGMENT = True
GAN_DIFFAUGMENT_POLICY = "color,translation,cutout"

# Evaluation Settings
EVAL_MODE = 2
NUM_EVAL_LABELS = -1
SAMP_BATCH_SIZE = 1000
NFAKE_PER_LABEL = 200
NREAL_PER_LABEL = -1
COMP_FID = True
EPOCH_FID_CNN = 200
FID_RADIUS = 0
FID_NUM_CENTERS = -1
DUMP_FAKE_FOR_NIQE = True

## Step 2 - Define Basic Classes and Functions

We need to use some basic classes and functions later, so we define them here.

### Step 2.1 - Function `normalize_images`

Normalize the batch images' pixel values from $[0, 255]$ to $[-1, 1]$.

In [None]:
def normalize_images(batch_images):
    """Rescale pixel intensities from [0, 255] to [-1, 1].

    Args:
        batch_images: array-like of pixel values (anything supporting `/` and `-`).

    Returns:
        A new array of the same shape with values mapped linearly so 0 -> -1 and 255 -> 1.
    """
    # divide to [0, 1] first, then centre and stretch to [-1, 1]
    return (batch_images / 255.0 - 0.5) / 0.5

### Step 2.2 - Function `view_dataset`

In [None]:
def _print_hdf5(name, obj):
    """`visititems` callback: print one HDF5 node, indented by its depth in the file."""
    depth = name.count("/")
    prefix = "  " * depth
    if isinstance(obj, h5py.Dataset):
        line = f"{prefix}[Dataset] {name} shape={obj.shape} dtype={obj.dtype}"
    elif isinstance(obj, h5py.Group):
        line = f"{prefix}[Group]   {name}"
    else:
        # other node kinds are silently skipped
        return
    print(line)


def view_dataset(dataset_path):
    """Print the group/dataset tree of the HDF5 file at `dataset_path` (read-only)."""
    with h5py.File(dataset_path, "r") as h5_file:
        h5_file.visititems(_print_hdf5)

### Step 2.3 - Class `Label_dataset`

`Label_dataset` is derived from class `torch.utils.data.Dataset`. It serves labels only (no images) and is used to train `net_y2h`.

In [None]:
class Label_dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields only labels (no images).

    Args:
        labels: any indexable sequence supporting `len()`; item `i` is returned as-is.
    """

    def __init__(self, labels):
        super(Label_dataset, self).__init__()
        self.labels = labels
        # cached once; the dataset is not expected to change size
        self.n_samples = len(labels)

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return self.n_samples

### Step 2.4 - Class `Imgs_dataset`

`Imgs_dataset` is derived from class `torch.utils.data.Dataset`.

In [None]:
class Imgs_dataset(torch.utils.data.Dataset):
    """Map-style dataset over images, optionally paired with labels.

    Args:
        images: indexable collection of images supporting `len()`.
        labels: optional indexable collection with the same length as `images`.
        normalize: if True, each returned image is rescaled from [0, 255] to [-1, 1].

    Raises:
        Exception: if `labels` is given and its length differs from `images`.
    """

    def __init__(self, images, labels=None, normalize=False):
        super(Imgs_dataset, self).__init__()
        n_images = len(images)
        if labels is not None and n_images != len(labels):
            raise Exception(
                f"images ({n_images}) and labels ({len(labels)}) do not have the same length!!!"
            )
        self.images = images
        self.n_images = n_images
        self.labels = labels
        self.normalize = normalize

    def __getitem__(self, index):
        img = self.images[index]
        if self.normalize:
            # [0, 255] -> [0, 1] -> [-1, 1]
            img = (img / 255.0 - 0.5) / 0.5
        # unlabeled datasets return the bare image, labeled ones an (image, label) pair
        return img if self.labels is None else (img, self.labels[index])

    def __len__(self):
        return self.n_images

## Step 3 - Settings

In this part we seed the random number generators and create the output directories.

In [None]:
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) and make cuDNN deterministic
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
cudnn.deterministic = True  # `cudnn` is `torch.backends.cudnn`
cudnn.benchmark = False

# Create all output directories up front (no-op if they already exist)
for _out_dir in (SAVE_OUTPUTS_DIR, SAVE_IMAGES_DIR, EMBED_MODELS_DIR, CCGAN_MODELS_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

## Step 4 - Load and Process Data

In the next code block we will get the following `numpy.ndarray`: 

1. `labels_all`: all the original labels
2. `images_all`: all the original images, shape `(N, 3, H, W)`
3. `index_train`: indexes for training

### Step 4.1 - Load data from the `.h5` dataset

In [None]:
# View dataset structure
view_dataset(DATA_PATH)
print("")

# Load data from h5 file; the context manager closes the file even if a key is missing
with h5py.File(DATA_PATH, "r") as hf:
    labels_all = hf["labels"][:]
    images_all = hf["images"][:]
    index_train = hf["index_train"][:]

# Change the data type and layout to fit the model:
# labels become float64, images are moved from channels-last to channels-first (N, C, H, W)
labels_all = labels_all.astype(float)
images_all = images_all.transpose(0, 3, 1, 2)

print(f"`images_all` shape: {images_all.shape}, dtype: {images_all.dtype}")
print(f"`labels_all` shape: {labels_all.shape}, dtype: {labels_all.dtype}")
print(f"`index_train` shape: {index_train.shape}, dtype: {index_train.dtype}")

### Step 4.2 - Split the training dataset

We split the data. Rule:

- If `DATA_SPLIT` is `"train"`, use `index_train` as the set of training indexes.
- Otherwise, use all data as the training set.

After the next block we will get:

1. `images_train`: the training set images (subset of `images_all`)
2. `labels_train_raw`: the training set labels (subset of `labels_all`)

In [None]:
# Choose the training subset according to DATA_SPLIT
print(f"Data split: {DATA_SPLIT}")
print("")
if DATA_SPLIT != "train":
    # any other split trains on everything; deep copies keep `*_all` independent of later edits
    images_train, labels_train_raw = copy.deepcopy(images_all), copy.deepcopy(labels_all)
else:
    images_train, labels_train_raw = images_all[index_train], labels_all[index_train]

# Print the data split result
print(f"`images_train` shape: {images_train.shape}, dtype: {images_train.dtype}")
print(f"`labels_train_raw` shape: {labels_train_raw.shape}, dtype: {labels_train_raw.dtype}")

### Step 4.3 - Label range limitation

We only take images with label in $\left(q_1, q_2\right)$, where $q_1 = \text{min label}$ and $q_2 = \text{max label}$.

In the following block we restrict the training data to this range. If `VISUALIZE_FAKE_IMAGES` or `COMP_FID` is `True`, we apply the same restriction to all data.

This step updates `images_train` and `labels_train_raw` (training data), as well as `images_all` and `labels_all` (all data).

In [None]:
# Keep only samples whose label lies strictly inside (q1, q2) = (MIN_LABEL, MAX_LABEL)
q1, q2 = MIN_LABEL, MAX_LABEL

# -- training data --
indx = np.nonzero((labels_train_raw > q1) & (labels_train_raw < q2))[0]
images_train, labels_train_raw = images_train[indx], labels_train_raw[indx]
assert len(labels_train_raw) == len(images_train)
print(f"`images_train` shape: {images_train.shape}, dtype: {images_train.dtype}")
print(f"`labels_train_raw` shape: {labels_train_raw.shape}, dtype: {labels_train_raw.dtype}")
print("")

# -- all data (only filtered when fake-image visualisation or FID is requested) --
if VISUALIZE_FAKE_IMAGES or COMP_FID:
    indx = np.nonzero((labels_all > q1) & (labels_all < q2))[0]
    images_all, labels_all = images_all[indx], labels_all[indx]
    assert len(labels_all) == len(images_all)
    print(f"`images_all` shape: {images_all.shape}, dtype: {images_all.dtype}")
    print(f"`labels_all` shape: {labels_all.shape}, dtype: {labels_all.dtype}")

### Step 4.4 - Image number limitation

For each label value, we allow at most `MAX_NUM_IMG_PER_LABEL` images in the training set. When a label has more images than that, this step randomly selects `MAX_NUM_IMG_PER_LABEL` of them and updates `images_train` and `labels_train_raw` accordingly.

This step only affects the training set.

In [None]:
# For each distinct label value, keep at most MAX_NUM_IMG_PER_LABEL training images
image_num_threshold = MAX_NUM_IMG_PER_LABEL
print("Original set has {} images; For each label, take no more than {} images.".format(len(images_train), image_num_threshold))
unique_labels_tmp = np.sort(np.array(list(set(labels_train_raw))))
sel_indx_parts = []
for i in tqdm(range(len(unique_labels_tmp))):
    indx_i = np.where(labels_train_raw == unique_labels_tmp[i])[0]
    if len(indx_i) > image_num_threshold:
        # random subset drawn from the global NumPy RNG seeded in Step 3
        np.random.shuffle(indx_i)
        indx_i = indx_i[0:image_num_threshold]
    sel_indx_parts.append(indx_i)
# Concatenate once (linear) instead of growing the array inside the loop; if no labels
# survived the range filter, use an empty index array rather than leaving `sel_indx` undefined.
sel_indx = np.concatenate(sel_indx_parts) if sel_indx_parts else np.array([], dtype=np.int64)
images_train = images_train[sel_indx]
labels_train_raw = labels_train_raw[sel_indx]
print("{} images left and there are {} unique labels".format(len(images_train), len(set(labels_train_raw))))
print("")
print(f"`images_train` shape: {images_train.shape}, dtype: {images_train.dtype}")
print(f"`labels_train_raw` shape: {labels_train_raw.shape}, dtype: {labels_train_raw.dtype}")

### Step 4.5 - Normalization

From the following block we get `labels_train` — the normalized labels, ranging from $0$ to $1$.

$$
\text{normalized training labels} = \frac{\text{training labels} - \text{min label}}{\text{max label} - \text{min label}} \in \left[0, 1\right]
$$

In [None]:
# Normalize labels: map [MIN_LABEL, MAX_LABEL] linearly onto [0, 1]
label_range = MAX_LABEL - MIN_LABEL
labels_train = (labels_train_raw - MIN_LABEL) / label_range
print("Preset `min_label`: {}, `max_label`: {}".format(MIN_LABEL, MAX_LABEL))
print("Range of `labels_train_raw`: ({},{})".format(labels_train_raw.min(), labels_train_raw.max()))
print("Range of `labels_train`: ({},{})".format(labels_train.min(), labels_train.max()))

### Step 4.6 - Calculate $\sigma$ and $\kappa$

We calculate the $\sigma$ and $\kappa$ values from the dataset when `KERNEL_SIGMA` or `KAPPA`, respectively, is negative.

We calculate $\sigma$ by

$$
\sigma = \left(\frac{4 \hat{\sigma}_{y^r}^5}{3 N^r}\right)^{\frac{1}{5}}
$$

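where $\hat{\sigma}_{y^r}$ is the standard deviation (`np.std`, i.e. $\text{ddof}=0$) of the normalized labels in the training set and $N^r$ is the number of training samples. This is Silverman's rule of thumb; since $\left(4/3\right)^{1/5} \approx 1.06$, the code computes $\sigma = 1.06\, \hat{\sigma}_{y^r} \left(N^r\right)^{-1/5}$.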

In [None]:
# Set `kernel_sigma` by Silverman's rule of thumb unless it was preset (non-negative)
if KERNEL_SIGMA < 0:
    std_label = np.std(labels_train)
    n_train_labels = len(labels_train)
    KERNEL_SIGMA = 1.06 * std_label * n_train_labels ** (-1 / 5)
print("`kernel_sigma`: {}".format(KERNEL_SIGMA))

We calculate $\kappa$ as follows: Let
$$
\kappa_{\text{base}} = \max \left(y_{[2]}^r - y_{[1]}^r, y_{[3]}^r - y_{[2]}^r, \cdots, y_{\left[N_{\text{uy}}^r\right]}^r - y_{\left[N_{\text{uy}}^r - 1\right]}^r\right)
$$
where $y_{[l]}^r$ is the $l$-th smallest normalized distinct real label and $N_{\text{uy}}^r$ is the number of normalized distinct labels in the training set.

Then $\kappa$ is set as
$$
\kappa = m_{\kappa} \kappa_{\text{base}}
$$
where $m_{\kappa}$ stands for $50\%$ of the minimum number of neighboring labels used for estimating $p_r\left(x | y\right)$ given a label $y$.

For example, $m_{\kappa} = 1$ implies using $2$ neighboring labels (one on the left and the other one on the right).

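In experiments $m_{\kappa}$ is generally set to $1$ or $2$; in this notebook $m_{\kappa} = |\texttt{KAPPA}|$ when `KAPPA` is negative (the default `KAPPA = -2` gives $m_{\kappa} = 2$). In some extreme cases, when many distinct labels have too few real samples, we may consider increasing $m_{\kappa}$. We also found $\nu = \frac{1}{\kappa^2}$ works well in the experiments. Note that for the soft threshold the code stores $\nu = 1/\kappa^2$ (not $\kappa$) in `KAPPA`.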

In [None]:
# Sorted distinct normalized training labels. Defined unconditionally because later cells
# (training `net_y2h`, the embedding sanity check) use it even when KAPPA is preset.
unique_labels_norm = np.unique(labels_train)
n_unique = len(unique_labels_norm)

# Set `kappa` automatically when it is negative: |KAPPA| plays the role of m_kappa
if KAPPA < 0:
    if n_unique < 2:
        raise ValueError(
            "Automatic kappa needs at least two distinct training labels; set KAPPA >= 0 explicitly."
        )
    # kappa_base = m_kappa * (largest gap between consecutive distinct labels)
    kappa_base = np.abs(KAPPA) * np.max(np.diff(unique_labels_norm))

    # hard vicinity uses kappa itself; soft vicinity uses nu = 1 / kappa^2
    KAPPA = kappa_base if THRESHOLD_TYPE == "hard" else 1 / kappa_base**2

print("`kappa`: {}".format(KAPPA))

## Step 5 - Pre-trained CNN and `net_y2h` for Label Embedding

Define function `train_net_embed`.

In [None]:
def train_net_embed(
    net,
    net_name,
    trainloader,
    testloader,
    epochs=200,
    resume_epoch=0,
    lr_base=0.01,
    lr_decay_factor=0.1,
    lr_decay_epochs=(80, 140),
    weight_decay=1e-4,
    path_to_ckpt=None,
    device="cuda",
):
    """Train the image->label regression CNN whose hidden features provide the label embedding.

    `net(images)` must return a pair `(predicted_labels, features)`; only the first element
    enters the MSE loss against the (reshaped to `(-1, 1)`) labels.

    Args:
        net: model to train; moved to `device`.
        net_name: not used inside this function (kept for interface compatibility).
        trainloader: yields `(images, labels)` batches.
        testloader: optional loader in the same format; when given, the per-sample mean
            test MSE is printed after every epoch.
        epochs: total number of epochs; training runs epochs `resume_epoch .. epochs - 1`.
        resume_epoch: if > 0 and `path_to_ckpt` is set, model/optimizer/CPU-RNG state are
            restored from the checkpoint written after that epoch.
        lr_base, lr_decay_factor, lr_decay_epochs: step schedule; the LR is multiplied by
            `lr_decay_factor` once for every milestone in `lr_decay_epochs` already reached.
        weight_decay: SGD weight decay.
        path_to_ckpt: directory under which `ckpt_embed_x2y/ckpt_embed_x2y_epoch_{N}.pth` is
            written every 50 epochs and after the last epoch; None disables checkpoints.
        device: device to train on (default "cuda", matching the previous behaviour).

    Returns:
        The trained `net` (on `device`).
    """

    # Step learning-rate decay
    def adjust_learning_rate_1(optimizer, epoch):
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    # Path of the checkpoint written after epoch `epoch_num` (1-based)
    def _ckpt_file(epoch_num):
        return os.path.join(
            path_to_ckpt, "ckpt_embed_x2y", "ckpt_embed_x2y_epoch_{}.pth".format(epoch_num)
        )

    # Set up the model
    net = net.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=lr_base, momentum=0.9, weight_decay=weight_decay
    )

    # Load checkpoint if `resume_epoch` > 0. Loading to CPU keeps `rng_state` a CPU
    # ByteTensor (required by `torch.set_rng_state`); the state dicts are copied onto the
    # parameters' device by `load_state_dict`.
    if path_to_ckpt is not None and resume_epoch > 0:
        checkpoint = torch.load(_ckpt_file(resume_epoch), map_location="cpu")
        net.load_state_dict(checkpoint["net_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        torch.set_rng_state(checkpoint["rng_state"])

    # Start the timer
    start_tmp = timeit.default_timer()

    # Train the model
    for epoch in range(resume_epoch, epochs):

        net.train()
        train_loss = 0
        adjust_learning_rate_1(optimizer, epoch)

        for batch_train_images, batch_train_labels in trainloader:
            batch_train_images = batch_train_images.type(torch.float).to(device)
            batch_train_labels = (
                batch_train_labels.type(torch.float).reshape(-1, 1).to(device)
            )

            # Forward pass
            outputs, _ = net(batch_train_images)
            loss = criterion(outputs, batch_train_labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the per-sample loss (batches may differ in size)
            train_loss += loss.cpu().item() * batch_train_images.size(0)

        train_loss = train_loss / len(trainloader.dataset)

        if testloader is None:
            print(
                "Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f"
                % (epoch + 1, epochs, train_loss, timeit.default_timer() - start_tmp)
            )
        else:
            net.eval()
            test_loss = 0
            with torch.no_grad():
                for batch_test_images, batch_test_labels in testloader:
                    batch_test_images = batch_test_images.type(torch.float).to(device)
                    batch_test_labels = (
                        batch_test_labels.type(torch.float).reshape(-1, 1).to(device)
                    )
                    outputs, _ = net(batch_test_images)
                    loss = criterion(outputs, batch_test_labels)
                    # weight by batch size so test_loss is a per-sample mean like train_loss
                    test_loss += loss.cpu().item() * batch_test_images.size(0)
            test_loss = test_loss / len(testloader.dataset)

            print(
                "Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f"
                % (
                    epoch + 1,
                    epochs,
                    train_loss,
                    test_loss,
                    timeit.default_timer() - start_tmp,
                )
            )

        # Save checkpoint every 50 epochs and after the final epoch
        if path_to_ckpt is not None and (
            ((epoch + 1) % 50 == 0) or (epoch + 1 == epochs)
        ):
            save_file = _ckpt_file(epoch + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    "epoch": epoch,
                    "net_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "rng_state": torch.get_rng_state(),
                },
                save_file,
            )

    return net

Define function `train_net_y2h`.

In [None]:
def train_net_y2h(
    unique_labels_norm,
    net_y2h,
    net_embed,
    epochs=500,
    lr_base=0.01,
    lr_decay_factor=0.1,
    lr_decay_epochs=[150, 250, 350],
    weight_decay=1e-4,
    batch_size=128,
):
    """Train `net_y2h` so that `h2y(net_y2h(y))` reconstructs noisy labels y in [0, 1].

    `net_embed` is put in eval mode and its `.module.h2y` head (so `net_embed` must be
    wrapped, e.g. in `nn.DataParallel`) maps embeddings back to labels; only `net_y2h`'s
    parameters are passed to the optimizer. Each batch of labels is perturbed with
    N(0, 0.2) noise and clamped to [0, 1] before the round trip.

    Args:
        unique_labels_norm: normalized labels, all within [0, 1] (asserted).
        net_y2h: label -> embedding network being trained (expected on the GPU).
        net_embed: pretrained embedding CNN providing the `h2y` head.
        epochs, lr_base, lr_decay_factor, lr_decay_epochs, weight_decay: SGD + step schedule.
        batch_size: labels per batch.

    Returns:
        The trained `net_y2h`.
    """
    assert np.max(unique_labels_norm) <= 1 and np.min(unique_labels_norm) >= 0
    label_loader = torch.utils.data.DataLoader(
        Label_dataset(unique_labels_norm), batch_size=batch_size, shuffle=True
    )
    n_labels = len(label_loader.dataset)

    net_embed.eval()
    net_h2y = net_embed.module.h2y  # maps embeddings back to labels
    optimizer_y2h = torch.optim.SGD(
        net_y2h.parameters(), lr=lr_base, momentum=0.9, weight_decay=weight_decay
    )
    mse = nn.MSELoss()

    t_start = timeit.default_timer()

    for epoch in range(epochs):
        net_y2h.train()

        # Step schedule: decay once per milestone already reached
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for group in optimizer_y2h.param_groups:
            group["lr"] = lr

        running_loss = 0
        for batch_labels in label_loader:
            batch_labels = batch_labels.type(torch.float).reshape(-1, 1).cuda()
            n_batch = len(batch_labels)

            # jitter the labels with Gaussian noise, kept inside [0, 1]
            gamma = (
                torch.from_numpy(np.random.normal(0, 0.2, n_batch))
                .reshape(-1, 1)
                .type(torch.float)
                .cuda()
            )
            noisy_labels = torch.clamp(batch_labels + gamma, 0.0, 1.0)

            # y -> h -> y round trip
            loss = mse(net_h2y(net_y2h(noisy_labels)), noisy_labels)

            optimizer_y2h.zero_grad()
            loss.backward()
            optimizer_y2h.step()

            running_loss += loss.cpu().item() * n_batch

        train_loss = running_loss / n_labels
        print(
            "Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f"
            % (epoch + 1, epochs, train_loss, timeit.default_timer() - t_start)
        )

    return net_y2h

Use the two functions to train `net_embed` and `net_y2h`.

In [None]:
# Checkpoint files for the two embedding networks
net_embed_filename_ckpt = os.path.join(
    EMBED_MODELS_DIR,
    "embed_x2y_epoch_{}_seed_{}.pth".format(EPOCH_CNN_EMBED, SEED),
)
net_y2h_filename_ckpt = os.path.join(
    EMBED_MODELS_DIR, "y2h_epoch_{}_seed_{}.pth".format(EPOCH_NET_Y2H, SEED)
)

# Sorted distinct normalized training labels used to train net_y2h (and by the sanity check
# below). Computed here so this cell works even when KAPPA was preset and the kappa cell
# skipped its automatic branch.
unique_labels_norm = np.unique(labels_train)

# Prepare the dataset for training
trainset = Imgs_dataset(images_train, labels_train, normalize=True)

# Create the DataLoader for training
trainloader_embed_net = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE_EMBED, shuffle=True, num_workers=NUM_WORKERS
)

# Build the embedding CNN; fail loudly on an unknown name instead of a later NameError
if NET_EMBED == "ResNet18_embed":
    net_embed = ResNet18_embed(dim_embed=DIM_EMBED)
elif NET_EMBED == "ResNet34_embed":
    net_embed = ResNet34_embed(dim_embed=DIM_EMBED)
elif NET_EMBED == "ResNet50_embed":
    net_embed = ResNet50_embed(dim_embed=DIM_EMBED)
else:
    raise ValueError(
        "Unsupported NET_EMBED {!r}; expected ResNet18_embed, ResNet34_embed or ResNet50_embed".format(NET_EMBED)
    )
net_embed = net_embed.cuda()
net_embed = nn.DataParallel(net_embed)

# Get net y2h type
net_y2h = model_y2h(dim_embed=DIM_EMBED)
net_y2h = net_y2h.cuda()
net_y2h = nn.DataParallel(net_y2h)

# Train net_embed first: x2h+h2y (or load it if the final checkpoint exists)
if not os.path.isfile(net_embed_filename_ckpt):
    print("Start training CNN for label embedding.")
    net_embed = train_net_embed(
        net=net_embed,
        net_name=NET_EMBED,
        trainloader=trainloader_embed_net,
        testloader=None,
        epochs=EPOCH_CNN_EMBED,
        resume_epoch=RESUMEEPOCH_CNN_EMBED,
        lr_base=BASE_LR_X2Y,
        lr_decay_factor=0.1,
        lr_decay_epochs=[80, 140],
        weight_decay=1e-4,
        path_to_ckpt=EMBED_MODELS_DIR,
    )
    # save model
    torch.save(
        {
            "net_state_dict": net_embed.state_dict(),
        },
        net_embed_filename_ckpt,
    )
else:
    print("`net_embed` ckpt already exists, loading it.")
    print("")
    checkpoint = torch.load(net_embed_filename_ckpt)
    net_embed.load_state_dict(checkpoint["net_state_dict"])

# Train y2h (or load it if the checkpoint exists)
if not os.path.isfile(net_y2h_filename_ckpt):
    print("Start training `net_y2h`.")
    net_y2h = train_net_y2h(
        unique_labels_norm,
        net_y2h,
        net_embed,
        epochs=EPOCH_NET_Y2H,
        lr_base=BASE_LR_Y2H,
        lr_decay_factor=0.1,
        lr_decay_epochs=[150, 250, 350],
        weight_decay=1e-4,
        batch_size=128,
    )
    # save model
    torch.save(
        {
            "net_state_dict": net_y2h.state_dict(),
        },
        net_y2h_filename_ckpt,
    )
else:
    print("`net_y2h` ckpt already exists, loading it.")
    print("")
    checkpoint = torch.load(net_y2h_filename_ckpt)
    net_y2h.load_state_dict(checkpoint["net_state_dict"])

# Print the model summary
print("Net Embed Summary")
summary(net_embed, input_size=(NUM_CHANNELS, IMG_SIZE, IMG_SIZE))
print("")
print("Net Y2H Summary")
summary(net_y2h, input_size=(1,))

Now that we have `net_embed` and `net_y2h`, we will do some simple tests on them.

In [None]:
# Pick up to 10 distinct training labels at random and jitter them with N(0, 0.2) noise,
# clipped back into [0, 1] (the same perturbation used while training net_y2h)
indx_tmp = np.arange(len(unique_labels_norm))
np.random.shuffle(indx_tmp)
indx_tmp = indx_tmp[:10]

labels_tmp = (
    torch.from_numpy(unique_labels_norm[indx_tmp].reshape(-1, 1))
    .type(torch.float)
    .cuda()
)
epsilons_tmp = (
    torch.from_numpy(np.random.normal(0, 0.2, len(labels_tmp)))
    .view(-1, 1)
    .type(torch.float)
    .cuda()
)
labels_tmp = torch.clamp(labels_tmp + epsilons_tmp, 0.0, 1.0)

# Round trip y -> h (net_y2h) -> y (net_h2y) without gradients
net_embed.eval()
net_y2h.eval()
net_h2y = net_embed.module.h2y
with torch.no_grad():
    labels_rec_tmp = net_h2y(net_y2h(labels_tmp)).cpu().numpy().reshape(-1, 1)
results = np.concatenate((labels_tmp.cpu().numpy(), labels_rec_tmp), axis=1)

# Tabulate original vs reconstructed labels
results_df = pd.DataFrame(results, columns=["labels", "reconstructed_labels"])
print(results_df)

# Plot both series against the sample index
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(results_df["labels"], label="Original Labels", marker="o")
ax.plot(results_df["reconstructed_labels"], label="Reconstructed Labels", marker="x")
ax.set_title("Original vs Reconstructed Labels")
ax.set_xlabel("Sample Index")
ax.set_ylabel("Label Value")
ax.legend()
plt.show()

# Release the embedding CNN (train_ccgan below only takes net_y2h) and park net_y2h on the CPU
net_embed = net_embed.cpu()
net_h2y = net_h2y.cpu()
del net_embed, net_h2y
gc.collect()
net_y2h = net_y2h.cpu()

## Step 6 - GAN Training

Define `train_ccgan`. This is a long function.

In [None]:
def train_ccgan(
    kernel_sigma,
    kappa,
    train_images,
    train_labels,
    netG,
    netD,
    net_y2h,
    SAVE_IMAGES_DIR,
    CCGAN_MODELS_DIR=None,
    clip_label=False,
):

    # Define loss dataframe
    loss_df = pd.DataFrame(
        columns=["niter", "d_loss", "g_loss", "real_prob", "fake_prob"]
    )

    # Nets
    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    # Optimizers
    optimizerG = torch.optim.Adam(netG.parameters(), lr=LR_G, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=LR_D, betas=(0.5, 0.999))

    # Resume training if needed
    if CCGAN_MODELS_DIR is not None and RESUME_NITERS_GAN > 0:
        save_file = os.path.join(
            CCGAN_MODELS_DIR,
            "CcGAN_{}_nDsteps_{}".format(THRESHOLD_TYPE, NUM_D_STEPS),
            "ckpt_CcGAN_niters_{}.pth".format(RESUME_NITERS_GAN),
        )
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint["netG_state_dict"])
        netD.load_state_dict(checkpoint["netD_state_dict"])
        optimizerG.load_state_dict(checkpoint["optimizerG_state_dict"])
        optimizerD.load_state_dict(checkpoint["optimizerD_state_dict"])
        torch.set_rng_state(checkpoint["rng_state"])
        print("Resuming training from {} iterations.".format(RESUME_NITERS_GAN))
    else:
        print("Training from scratch, no resume.")

    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # Output parameters
    n_row = 10
    n_col = 10

    z_fixed = torch.randn(n_row * n_col, DIM_GAN, dtype=torch.float).cuda()

    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)

    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).reshape(-1, 1).cuda()

    # Start timer
    start_time = timeit.default_timer()

    # Start training
    for niter in range(RESUME_NITERS_GAN, NITERS_GAN):

        # === Train Distriminator ===

        for _ in range(NUM_D_STEPS):

            ## randomly draw batch_size_disc y's from unique_train_labels
            batch_target_labels_in_dataset = np.random.choice(
                unique_train_labels, size=BATCH_SIZE_DISC, replace=True
            )
            ## add Gaussian noise; we estimate image distribution conditional on these labels
            batch_epsilons = np.random.normal(0, kernel_sigma, BATCH_SIZE_DISC)
            batch_target_labels = batch_target_labels_in_dataset + batch_epsilons

            ## find index of real images with labels in the vicinity of batch_target_labels
            ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
            batch_real_indx = np.zeros(
                BATCH_SIZE_DISC, dtype=int
            )  # index of images in the datata; the labels of these images are in the vicinity
            batch_fake_labels = np.zeros(BATCH_SIZE_DISC)

            for j in range(BATCH_SIZE_DISC):
                ## index for real images
                if THRESHOLD_TYPE == "hard":
                    indx_real_in_vicinity = np.where(
                        np.abs(train_labels - batch_target_labels[j]) <= kappa
                    )[0]
                else:
                    # reverse the weight function for SVDL
                    indx_real_in_vicinity = np.where(
                        (train_labels - batch_target_labels[j]) ** 2
                        <= -np.log(NONZERO_SOFT_WEIGHT_THRESHOLD) / kappa
                    )[0]

                ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
                while len(indx_real_in_vicinity) < 1:
                    batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
                    batch_target_labels[j] = (
                        batch_target_labels_in_dataset[j] + batch_epsilons_j
                    )
                    if clip_label:
                        batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
                    ## index for real images
                    if THRESHOLD_TYPE == "hard":
                        indx_real_in_vicinity = np.where(
                            np.abs(train_labels - batch_target_labels[j]) <= kappa
                        )[0]
                    else:
                        # reverse the weight function for SVDL
                        indx_real_in_vicinity = np.where(
                            (train_labels - batch_target_labels[j]) ** 2
                            <= -np.log(NONZERO_SOFT_WEIGHT_THRESHOLD) / kappa
                        )[0]
                # end while len(indx_real_in_vicinity)<1

                assert len(indx_real_in_vicinity) >= 1

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                ## labels for fake images generation
                if THRESHOLD_TYPE == "hard":
                    lb = batch_target_labels[j] - kappa
                    ub = batch_target_labels[j] + kappa
                else:
                    lb = batch_target_labels[j] - np.sqrt(
                        -np.log(NONZERO_SOFT_WEIGHT_THRESHOLD) / kappa
                    )
                    ub = batch_target_labels[j] + np.sqrt(
                        -np.log(NONZERO_SOFT_WEIGHT_THRESHOLD) / kappa
                    )
                lb = max(0.0, lb)
                ub = min(ub, 1.0)
                assert lb <= ub
                assert lb >= 0 and ub >= 0
                assert lb <= 1 and ub <= 1
                batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(
                normalize_images(train_images[batch_real_indx])
            )
            batch_real_images = batch_real_images.type(torch.float).cuda()
            batch_real_labels = train_labels[batch_real_indx]
            batch_real_labels = (
                torch.from_numpy(batch_real_labels).type(torch.float).cuda()
            )

            ## generate the fake image batch
            batch_fake_labels = (
                torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            )
            z = torch.randn(BATCH_SIZE_DISC, DIM_GAN, dtype=torch.float).cuda()
            batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            ## target labels on gpu
            batch_target_labels = (
                torch.from_numpy(batch_target_labels).type(torch.float).cuda()
            )

            ## weight vector
            if THRESHOLD_TYPE == "soft":
                real_weights = torch.exp(
                    -kappa * (batch_real_labels - batch_target_labels) ** 2
                ).cuda()
                fake_weights = torch.exp(
                    -kappa * (batch_fake_labels - batch_target_labels) ** 2
                ).cuda()
            else:
                real_weights = torch.ones(BATCH_SIZE_DISC, dtype=torch.float).cuda()
                fake_weights = torch.ones(BATCH_SIZE_DISC, dtype=torch.float).cuda()
            # end if threshold type

            # forward pass
            if GAN_DIFFAUGMENT:
                real_dis_out = netD(
                    DiffAugment(batch_real_images, policy=GAN_DIFFAUGMENT_POLICY),
                    net_y2h(batch_target_labels),
                )
                fake_dis_out = netD(
                    DiffAugment(
                        batch_fake_images.detach(), policy=GAN_DIFFAUGMENT_POLICY
                    ),
                    net_y2h(batch_target_labels),
                )
            else:
                real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
                fake_dis_out = netD(
                    batch_fake_images.detach(), net_y2h(batch_target_labels)
                )

            if LOSS_TYPE_GAN == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = -torch.log(real_dis_out + 1e-20)
                d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
            elif LOSS_TYPE_GAN == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError("Not supported loss type!!!")

            d_loss = torch.mean(
                real_weights.reshape(-1) * d_loss_real.reshape(-1)
            ) + torch.mean(fake_weights.reshape(-1) * d_loss_fake.reshape(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        # === Train Generator ===

        netG.train()

        # Choose target labels for fake images
        batch_target_labels_in_dataset = np.random.choice(
            unique_train_labels, size=BATCH_SIZE_GENE, replace=True
        )

        # Add Gaussian noise
        batch_epsilons = np.random.normal(0, kernel_sigma, BATCH_SIZE_GENE)
        batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
        batch_target_labels = (
            torch.from_numpy(batch_target_labels).type(torch.float).cuda()
        )

        # Add random noise
        z = torch.randn(BATCH_SIZE_GENE, DIM_GAN, dtype=torch.float).cuda()

        # Generate fake images
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # Calculate the generator loss
        if GAN_DIFFAUGMENT:
            dis_out = netD(
                DiffAugment(batch_fake_images, policy=GAN_DIFFAUGMENT_POLICY),
                net_y2h(batch_target_labels),
            )
        else:
            dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if LOSS_TYPE_GAN == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = -torch.mean(torch.log(dis_out + 1e-20))
        elif LOSS_TYPE_GAN == "hinge":
            g_loss = -dis_out.mean()

        # Backward pass
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # === Logging ===

        # Every 20 iterations, print the loss
        if (niter + 1) % 20 == 0:
            new_row = pd.DataFrame(
                [
                    {
                        "niter": niter + 1,
                        "d_loss": d_loss.item(),
                        "g_loss": g_loss.item(),
                        "real_prob": real_dis_out.mean().item(),
                        "fake_prob": fake_dis_out.mean().item(),
                    }
                ]
            )

            loss_df = pd.concat([loss_df, new_row], ignore_index=True)

            print(
                "CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]"
                % (
                    GAN_ARCH,
                    niter + 1,
                    NITERS_GAN,
                    d_loss.item(),
                    g_loss.item(),
                    real_dis_out.mean().item(),
                    fake_dis_out.mean().item(),
                    timeit.default_timer() - start_time,
                )
            )

        # Every `VISUALIZE_FREQ` iterations, visualize the generated images
        if (niter + 1) % VISUALIZE_FREQ == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(
                    gen_imgs.data,
                    os.path.join(SAVE_IMAGES_DIR, "{}.png".format(niter + 1)),
                    nrow=n_row,
                    normalize=True,
                )

        # Every `SAVE_NITERS_FREQ` iterations, save the model
        if CCGAN_MODELS_DIR is not None and (
            (niter + 1) % SAVE_NITERS_FREQ == 0 or (niter + 1) == NITERS_GAN
        ):
            save_file = os.path.join(
                CCGAN_MODELS_DIR,
                "CcGAN_{}_nDsteps_{}".format(THRESHOLD_TYPE, NUM_D_STEPS),
                "ckpt_CcGAN_niters_{}.pth".format(niter + 1),
            )

            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    "netG_state_dict": netG.state_dict(),
                    "netD_state_dict": netD.state_dict(),
                    "optimizerG_state_dict": optimizerG.state_dict(),
                    "optimizerD_state_dict": optimizerD.state_dict(),
                    "rng_state": torch.get_rng_state(),
                },
                save_file,
            )

    # Save the training loss dataframe
    loss_df.to_csv(
        os.path.join(
            SAVE_OUTPUTS_DIR,
            "CcGAN_{}_nDsteps_{}_loss.csv".format(THRESHOLD_TYPE, NUM_D_STEPS),
        ),
        index=False,
    )

    # Return the trained networks
    return netG, netD

Print the basic run configuration and create the `save_images_in_train_folder` directory. The images generated periodically during training are saved there.

In [None]:
# Report the run configuration, then create the folder that receives the
# images generated during training (one sub-folder per threshold/sigma/kappa).
print(f"CcGAN: {GAN_ARCH}, {THRESHOLD_TYPE}, Sigma is {KERNEL_SIGMA}, Kappa is {KAPPA}.")
save_images_in_train_folder = os.path.join(
    SAVE_IMAGES_DIR,
    f"{THRESHOLD_TYPE}_{KERNEL_SIGMA:.3f}_{KAPPA:.3f}_in_train",
)
os.makedirs(save_images_in_train_folder, exist_ok=True)

Start the timer.

In [None]:
# Record the wall-clock start so the total training time can be reported
# once training (or loading) has finished.
start = timeit.default_timer()
print("Begin Training:", GAN)

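The following cell defines the DiffAugment helpers used by `train_ccgan`, then trains CcGAN. If a checkpoint with the same settings already exists, it loads the pre-trained generator instead.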

In [None]:
def rand_brightness(x):
    """Add a random per-sample brightness offset drawn uniformly from [-0.5, 0.5).

    One offset is sampled per batch element and broadcast over all remaining
    dimensions of ``x``.
    """
    offsets = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offsets


def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2).

    The mean is taken over dim 1; one factor is sampled per batch element.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    scale = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * scale + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5).

    The mean is taken over dims 1, 2 and 3; one factor is sampled per batch element.
    """
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Randomly translate every image in the batch by an integer pixel offset.

    Each sample gets an independent shift drawn uniformly from
    ``[-shift_x, shift_x]`` along dim 2 and ``[-shift_y, shift_y]`` along dim 3,
    where ``shift_* = int(size * ratio + 0.5)``. Pixels that would come from
    outside the image are taken from a one-pixel zero border, i.e. they are 0.

    Args:
        x: channels-first image batch (batch, channel, dim2, dim3).
        ratio: maximum shift as a fraction of the spatial size.

    Returns:
        Tensor with the same shape as ``x``.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # randint's upper bound is exclusive, hence the +1 to make the shift symmetric.
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # Explicit "ij" indexing (the historical default) avoids the deprecation
    # warning newer PyTorch emits on every call when `indexing` is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing="ij",
    )
    # +1 accounts for the one-pixel padding; clamping into [0, size + 1] makes
    # out-of-range source positions land on the zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in channels-last layout so one advanced-index moves whole pixels.
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangular patch per image (cutout).

    The patch size is ``int(size * ratio + 0.5)`` along dims 2 and 3. Its
    center is sampled independently per sample, and the patch is clipped at
    the image border. The same mask is applied to every channel.

    Args:
        x: channels-first image batch (batch, channel, dim2, dim3).
        ratio: patch size as a fraction of the spatial size.

    Returns:
        Tensor with the same shape as ``x``.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # Explicit "ij" indexing (the historical default) avoids the deprecation
    # warning newer PyTorch emits on every call when `indexing` is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing="ij",
    )
    # Center the patch on the offset and clip it to valid pixel indices.
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the per-sample mask across the channel dimension.
    x = x * mask.unsqueeze(1)
    return x


# Registry used by DiffAugment: each policy name maps to the augmentation
# functions it applies, in the listed order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)

def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentation groups named in ``policy`` to an image batch.

    Args:
        x: image batch, channels-first (batch, channel, dim2, dim3) by default,
            or channels-last (batch, dim2, dim3, channel) when
            ``channels_first`` is False.
        policy: comma-separated group names from ``AUGMENT_FNS``
            (e.g. ``"color,translation,cutout"``), applied in order.
            Whitespace around names and empty entries are ignored.
            An empty policy returns ``x`` itself, untouched.
        channels_first: layout of ``x``; the result uses the same layout.

    Returns:
        The augmented tensor (made contiguous when ``policy`` is non-empty).

    Raises:
        KeyError: if a policy name is not a key of ``AUGMENT_FNS``.
    """
    if policy:
        if not channels_first:
            # The augmentation functions expect channels-first input.
            x = x.permute(0, 3, 1, 2)
        for name in (token.strip() for token in policy.split(',')):
            if not name:
                continue  # tolerate "a, b" spacing and stray/trailing commas
            if name not in AUGMENT_FNS:
                raise KeyError(
                    "Unknown DiffAugment policy {!r}; expected one of {}".format(
                        name, sorted(AUGMENT_FNS)
                    )
                )
            for f in AUGMENT_FNS[name]:
                x = f(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
    return x

# Path of the final generator checkpoint. The file name encodes the iteration
# count, number of D steps, threshold type, sigma, kappa and seed.
Filename_GAN = os.path.join(
    CCGAN_MODELS_DIR,
    'ckpt_CcGAN_niters_{}_nDsteps_{}_{}_{:.3f}_{:.3f}_seed_{}.pth'.format(
        NITERS_GAN, NUM_D_STEPS, THRESHOLD_TYPE, KERNEL_SIGMA, KAPPA, SEED
    ),
)

if not os.path.isfile(Filename_GAN):
    # No checkpoint yet: build the networks and train from scratch.
    # nn.DataParallel requires the wrapped module's parameters to already be
    # on device_ids[0], so move both networks to the GPU before wrapping them
    # (the loading branch below does the same for the generator).
    netG = CcGAN_SAGAN_Generator(dim_z=DIM_GAN, dim_embed=DIM_EMBED).cuda()
    netD = CcGAN_SAGAN_Discriminator(dim_embed=DIM_EMBED).cuda()
    netG = nn.DataParallel(netG)
    netD = nn.DataParallel(netD)

    # Start training
    netG, netD = train_ccgan(
        KERNEL_SIGMA, KAPPA, images_train, labels_train, netG, netD, net_y2h,
        SAVE_IMAGES_DIR=save_images_in_train_folder,
        CCGAN_MODELS_DIR=CCGAN_MODELS_DIR,
    )

    # Store only the generator weights; ensure the target directory exists.
    os.makedirs(CCGAN_MODELS_DIR, exist_ok=True)
    torch.save({
        'netG_state_dict': netG.state_dict(),
    }, Filename_GAN)

else:
    # A checkpoint with identical settings exists: skip training and load the
    # generator (the discriminator is not stored in this checkpoint).
    print("Loading pre-trained generator.")
    checkpoint = torch.load(Filename_GAN)
    netG = CcGAN_SAGAN_Generator(dim_z=DIM_GAN, dim_embed=DIM_EMBED).cuda()
    netG = nn.DataParallel(netG)
    netG.load_state_dict(checkpoint['netG_state_dict'])
    netG.load_state_dict(checkpoint['netG_state_dict'])

Stop the timer and report the elapsed time.

In [None]:
# Elapsed wall-clock time since `start`, covering training or checkpoint loading.
stop = timeit.default_timer()
print(f"GAN training finished; Time elapses: {stop - start}s")