6.4.1 Generating complex distributions from the CelebA dataset

We reproduce Figure 6.9 from the book, generating images from the CelebA dataset. We generate new images from latent spaces of several different dimensions, and compare each of them to the closest image in the dataset. Here, “closest” is defined by a distance between images, measured in the latent space.

Necessary Imports

import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io
import skimage.transform
from scipy import optimize

from codpy import core
from codpy.kernel import Kernel, Sampler
from codpy.permutation import scipy_lsap, lsap, map_invertion
import codpy.conditioning
# Locate the folder holding this example; fall back to the working directory
# when ``__file__`` is undefined (e.g. in an interactive session).
try:
    current_dir = os.path.dirname(__file__)
except NameError:
    current_dir = os.getcwd()

# The dataset is looked up under ``<current_dir>/data`` and the generated
# pictures are saved to the same folder.
data_path = os.path.join(current_dir, "data")
proj_path = data_path

Data Generator

The CelebA dataset comes with a .csv file containing attributes for each image, and corresponding file names. The actual image files can be found in a subfolder, where the file names match those in the .csv file. We define a class to handle the loading and processing of this dataset.

class CelebA_data_generator:
    """Load CelebA attribute rows and the image files they point to.

    Keyword arguments accepted by the constructor:

    * ``main_folder``: folder containing ``list_attr_celeba.csv`` and the
      ``img_align_celeba/img_align_celeba/`` image sub-folder. Defaults to
      ``<data_path>/celebA``. When any of these paths is missing the dataset
      is fetched with ``kagglehub`` and ``main_folder`` points to that
      download instead.
    * ``selected_features``: attribute columns that must equal ``1``
      (``None`` or empty means no filtering).
    * ``drop_features``: attribute columns that must equal ``-1``
      (``None`` or empty means no filtering).

    After construction ``self.attributes`` is a DataFrame indexed by
    ``image_id`` holding the remaining attribute columns plus a ``path``
    column with the location of each image file.
    """

    def __init__(self, **kwargs):
        self.main_folder = kwargs.get("main_folder", os.path.join(data_path, "celebA"))
        self.__resolve_paths()

        required = (self.main_folder, self.attributes_path, self.images_folder)
        if not all(os.path.exists(path) for path in required):
            # Local copy incomplete: download the dataset and re-anchor all
            # paths on the downloaded folder.
            import kagglehub

            self.main_folder = kagglehub.dataset_download("jessicali9530/celeba-dataset")
            self.__resolve_paths()

        self.selected_features = kwargs.get("selected_features", None)
        self.drop_features = kwargs.get("drop_features", None)
        self.features_name = []
        self.__prepare()

    def __resolve_paths(self):
        """Derive the image folder and attribute CSV path from ``main_folder``."""
        self.images_folder = os.path.join(
            self.main_folder, "img_align_celeba/img_align_celeba/"
        )
        self.attributes_path = os.path.join(self.main_folder, "list_attr_celeba.csv")

    def __prepare(self):
        """Read the attribute CSV, apply the feature filters and attach paths."""
        # ``None`` (the constructor default) means "no filter on this side".
        selected = list(self.selected_features or [])
        dropped = list(self.drop_features or [])

        # Read the CSV, and only keep rows matching the requested features
        attributes = pd.read_csv(self.attributes_path, sep=",")
        for feat in selected:
            attributes = attributes.loc[attributes[feat] == 1]
        for feat in dropped:
            attributes = attributes.loc[attributes[feat] == -1]
        # The filter columns are constant after filtering, so remove them.
        attributes = attributes.drop(selected, axis=1)
        attributes = attributes.drop(dropped, axis=1)

        # Index by image file name and record where each file lives. The
        # already-resolved image folder is reused so the separator is valid on
        # every platform.
        attributes = attributes.set_index("image_id")
        attributes["path"] = [
            os.path.join(self.images_folder, name) for name in attributes.index
        ]
        self.features_name = list(attributes.columns)[:-1]

        # Columns holding a single distinct value carry no information; drop
        # them, but never the "path" column, which get_images relies on (it
        # would otherwise disappear when only one row is left).
        constant_columns = [
            col
            for col in attributes.columns
            if col != "path" and attributes[col].nunique(dropna=False) == 1
        ]
        self.attributes = attributes.drop(constant_columns, axis=1)
        self.num_features = self.attributes.shape[1]

    def __sample_conditioned(self, conditionned_features, Nx, random_state):
        """Pick about half the paths with all features set, the rest with none.

        Rows where every listed feature equals ``1`` supply up to ``Nx // 2``
        paths (all of them when fewer are available); the remainder is sampled
        from rows where every listed feature equals ``-1``.
        """
        positives = self.attributes
        negatives = self.attributes
        for feat in conditionned_features:
            # Filter cumulatively so that every listed feature is honoured,
            # not only the last one.
            positives = positives.loc[positives[feat] == 1]
            negatives = negatives.loc[negatives[feat] == -1]

        half = Nx // 2
        if positives.shape[0] < half:
            chosen = positives["path"]
        else:
            chosen = positives.sample(half, random_state=random_state)["path"]
        remainder = negatives.sample(Nx - len(chosen), random_state=random_state)[
            "path"
        ]
        return pd.concat([chosen, remainder])

    def get_images(
        self,
        image_ids=None,
        exclude_image_ids=None,
        conditionned_features=None,
        **kwargs,
    ):
        """Load images and return them as produced by ``get_images_list``.

        Parameters
        ----------
        image_ids : pandas.Series, optional
            Image file names or paths to load; its index becomes the index of
            the result. When omitted, images are chosen from
            ``self.attributes`` as described below.
        exclude_image_ids : iterable, optional
            Index labels that must not be sampled. Only used when a subset is
            sampled without ``conditionned_features``.
        conditionned_features : list of str, optional
            When given, about half of the sample has all these features equal
            to ``1`` and the rest has them all equal to ``-1``.
        **kwargs
            ``Nx`` (number of images, default: all rows), ``seed`` (random
            state for sampling, default 42), ``PIC_DIR`` (folder to read the
            files from, default ``self.images_folder``). All keyword arguments
            are also forwarded to ``get_images_list``.
        """
        total = self.attributes.shape[0]
        Nx = kwargs.get("Nx", total)  # Number of total images
        if image_ids is None:
            if Nx < total:
                # Sample a small portion of the dataset
                random_state = kwargs.get("seed", 42)
                if conditionned_features is not None:
                    image_ids = self.__sample_conditioned(
                        conditionned_features, Nx, random_state
                    )
                elif exclude_image_ids is None:
                    image_ids = self.attributes.sample(Nx, random_state=random_state)[
                        "path"
                    ]
                else:
                    allowed = ~self.attributes.index.isin(exclude_image_ids)
                    image_ids = self.attributes.loc[allowed].sample(
                        Nx, random_state=random_state
                    )["path"]
            else:
                # Get entire dataset
                image_ids = self.attributes["path"]

        pic_dir = kwargs.get("PIC_DIR", self.images_folder)
        # Entries may already be full paths (the "path" column); keep only the
        # file name so the folder is not prepended twice.
        paths_images = [
            os.path.join(pic_dir, os.path.basename(image_id)) for image_id in image_ids
        ]
        images = get_images_list(
            list_pics=paths_images, index=image_ids.index, **kwargs
        )

        return images

    def get_data(self, N=0, **kwargs):
        """Return ``(images, index, attributes)`` for the selected images.

        ``kwargs`` are forwarded to ``get_images``; ``N`` is accepted but not
        used. ``attributes`` holds the rows of ``self.attributes`` matching
        the index of the loaded images.
        """
        images = self.get_images(**kwargs)
        attributes = self.attributes.loc[images.index]
        return images, images.index, attributes


def get_images_list(list_pics, index=None, **kwargs):
    """Read one image, or a list of images, as float arrays.

    Parameters
    ----------
    list_pics : str or list of str
        A single image file name, or a list of them.
    index : optional
        Only used for a list input: when given, the result is wrapped in a
        ``pandas.DataFrame`` carrying this index.
    **kwargs
        ``output_shape``: when given, each image is resized to this shape and,
        unless ``flat=False`` is passed, flattened. Without ``output_shape``
        a single image keeps the shape returned by the reader.

    Returns
    -------
    For a list: a 2-D array (or DataFrame) with one flattened image per row;
    an empty list gives an empty ``(0, 0)`` result. For a single file name:
    the image as a float array.
    """
    if isinstance(list_pics, list):
        if not list_pics:
            # reshape(..., -1) cannot infer a width from zero elements.
            out = np.empty((0, 0))
        else:
            out = np.asarray([get_images_list(pic, **kwargs) for pic in list_pics])
            out = np.reshape(out, [out.shape[0], -1])
        if index is not None:
            out = pd.DataFrame(out, index=index)
        return out

    output_shape = kwargs.get("output_shape", None)
    pic = skimage.io.imread(fname=list_pics)
    if pic.dtype != float:
        # Scale by the image's own maximum; an all-zero image is left at zero
        # instead of being turned into NaN by a division by zero.
        peak = np.max(pic)
        pic = pic.astype(float)
        if peak != 0:
            pic = pic / peak
    if output_shape is not None:
        pic = skimage.transform.resize(
            pic, output_shape=output_shape, anti_aliasing=True
        )
        if kwargs.get("flat", True):
            pic = pic.flatten()
    return pic


def basic_filter(x):
    """Rescale each of the 3 interleaved channels of a flat image to [0, 1].

    ``x`` is a 1-D array whose length is a multiple of 3; it is viewed as
    rows of 3 values and every column is min-max normalised independently.
    A column whose values are all equal is mapped to 0. The input is left
    untouched and a new flat float array is returned.
    """
    flat = np.array(x, dtype=float)  # private float copy: never edit the caller's data
    if flat.size == 0:
        return flat
    channels = flat.reshape([flat.shape[0] // 3, 3])
    channels -= channels.min(0)
    span = channels.max(0)
    # A constant channel has zero range; dividing by 1 keeps it at 0 rather
    # than producing NaN.
    span[span == 0] = 1.0
    channels /= span
    return channels.ravel()


def tiles(x, **kwargs):
    """Arrange flattened pictures on a rectangular grid.

    Parameters
    ----------
    x : 2-D array
        One flattened picture per row. It is not modified.
    **kwargs
        ``pic_shape`` (required): ``[height, width]`` of a single picture;
        the number of channels is the row length divided by
        ``height * width``.
        ``tile_shape``: ``[rows, columns]`` of the grid. Defaults to a square
        grid of side ``N``, itself defaulting to ``int(sqrt(number of rows))``.
        ``filter``: callable applied to each flattened picture before tiling
        (default ``basic_filter``); pass ``None`` to disable.

    Returns
    -------
    Array of shape ``(rows * height, columns * width, channels)``. Pictures
    are laid out row by row; grid cells without a picture stay at zero and
    pictures beyond the grid capacity are ignored.
    """
    x = np.asarray(x)
    count = x.shape[0]
    side = kwargs.get("N", int(np.sqrt(count)))
    tile_shape = kwargs.get("tile_shape", [side, side])
    height, width = kwargs["pic_shape"][0], kwargs["pic_shape"][1]
    depth = x.shape[1] // (height * width)

    # Look the default up lazily so an explicit ``filter=None`` needs nothing else.
    pixel_filter = kwargs["filter"] if "filter" in kwargs else basic_filter
    if pixel_filter is not None and count > 0:
        # Each picture is filtered on its own copy so the caller's array is
        # left as it was passed in.
        pictures = np.asarray(
            [pixel_filter(np.array(row)) for row in x], dtype=float
        )
    else:
        pictures = np.asarray(x, dtype=float)
    pictures = pictures.reshape(count, height, width, depth)

    out = np.zeros([tile_shape[0] * height, tile_shape[1] * width, depth])
    for ind in range(min(count, tile_shape[0] * tile_shape[1])):
        j, k = divmod(ind, tile_shape[1])
        out[j * height : (j + 1) * height, k * width : (k + 1) * width] = pictures[ind]

    return out

Configuration

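We define parameters for the experiment. Input shape is the original image size. Rescale shape is the size to which images are meant to be resized for processing, before being flattened and used in the model. Note that the code on this page only resizes an image when an `output_shape` argument is passed to `get_images_list`; the `rescale_shape` entry is not read anywhere below, so images are processed at their original size.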

# Experiment settings; the launch functions use this dict as their default
# `kwargs` and forward it to `CelebA_data_generator.get_data`.
config = {
    "input_shape": [218, 178],  # passed to `tiles` as `pic_shape` when plotting
    "rescale_shape": [50, 50],  # NOTE(review): not read by any code visible in this file - confirm
    "Nx": 1000,  # Number of images to use for training
    "Nz": 16,  # Number of images to generate
    "seed": 43,  # read by `get_images` as the random state of the image sampling
    "main_folder": os.path.join(data_path, "celebA"),
}


def launch_celebA_generator(
    kwargs=config,
    selected_features=["Blond_Hair", "Attractive", "Smiling"],
    drop_features=["Male"],
    dimensions=[1, 2, 3, 10, 40],
):
    """Generate pictures from random latent samples and save them as tiles.

    For every latent dimension in ``dimensions`` a ``Sampler`` is built on the
    loaded images, ``Nz`` latent points are drawn from a standard normal
    distribution, and two tiled pictures are written to ``proj_path``: the
    sampler output at the drawn points, and the sampler output at the
    sampler's own latent points closest to them.

    Parameters
    ----------
    kwargs : dict
        Settings; must contain ``Nz`` (a perfect square), ``Nx`` and
        ``input_shape``. It is forwarded to ``CelebA_data_generator.get_data``
        and, when present, its ``main_folder`` entry is passed to the data
        generator.
    selected_features, drop_features : list of str
        Attribute filters passed to ``CelebA_data_generator``.
    dimensions : list of int
        Latent dimensions to run.

    The default arguments are shared objects and are only read, never
    modified, by this function.
    """
    Nz = kwargs["Nz"]  # Number of images to generate
    N = int(np.sqrt(Nz))  # Width of the tiles
    assert (
        N * N == Nz
    ), "Number of generated images (Nz) must be a perfect square for plotting on tiles"

    # Load images, honouring a dataset location given in the settings
    generator_args = {
        "selected_features": selected_features,
        "drop_features": drop_features,
    }
    if "main_folder" in kwargs:
        generator_args["main_folder"] = kwargs["main_folder"]
    celeba = CelebA_data_generator(**generator_args)
    x_target, _, _ = celeba.get_data(**kwargs)
    print(f"Loaded {x_target.shape[0]} images")

    # The output folder may not exist yet (e.g. when the dataset was
    # downloaded elsewhere), and imsave does not create it.
    os.makedirs(proj_path, exist_ok=True)

    for d in dimensions:
        # Build a sampler with a latent space of dimension d, then draw the
        # latent points from a normal distribution.
        sampler = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)

        sampled_latent = np.random.normal(size=(Nz, d))
        # For each drawn point, the closest latent point held by the sampler
        closest_indices = sampler.dnm(x=sampled_latent).argmin(1)
        closest_latent = sampler.get_x()[closest_indices]

        # presumably one flattened picture per row - confirm against Sampler
        generated_pics = sampler(z=sampled_latent)
        database_pics = sampler(z=closest_latent)

        pic_generated = tiles(
            generated_pics, pic_shape=kwargs["input_shape"], tile_shape=[N, N]
        )
        pic_database = tiles(
            database_pics, pic_shape=kwargs["input_shape"], tile_shape=[N, N]
        )

        pic_name = str(kwargs["Nx"]) + "D_" + str(d) + ".png"
        plt.imsave(
            os.path.join(proj_path, "pic_generated_N" + pic_name),
            pic_generated,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        plt.imsave(
            os.path.join(proj_path, "pic_database_N" + pic_name),
            pic_database,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        print("Saved", pic_name)


def launch_celebA_Wasserstein_generator(
    kwargs=config,
    selected_features=["Blond_Hair", "Attractive", "Smiling"],
    drop_features=["Male"],
    dimensions=[1, 2, 3, 10, 40],
):
    """Fit latent codes to held-out pictures by descent and save the tiles.

    For every latent dimension in ``dimensions`` a ``Sampler`` is built on the
    loaded images, ``Nz`` additional images (excluded from the loaded set) are
    encoded into the latent space, and their latent codes are refined by a
    line-search descent so that the generated pictures approach them. Two
    tiled pictures are written to ``proj_path``: the generated pictures, and
    the database pictures closest to the initial generation.

    Parameters
    ----------
    kwargs : dict
        Settings; must contain ``Nz`` (a perfect square), ``Nx`` and
        ``input_shape``. It is forwarded to ``CelebA_data_generator.get_data``
        and, when present, its ``main_folder`` entry is passed to the data
        generator.
    selected_features, drop_features : list of str
        Attribute filters passed to ``CelebA_data_generator``.
    dimensions : list of int
        Latent dimensions to run.

    The default arguments are shared objects and are only read, never
    modified, by this function.
    """
    Nz = kwargs["Nz"]  # Number of images to generate
    N = int(np.sqrt(Nz))  # Width of the tiles
    assert (
        N * N == Nz
    ), "Number of generated images (Nz) must be a perfect square for plotting on tiles"

    def celebA_Wasserstein_descent(pics, database_pics, latent_values, generator):
        """Refine ``latent_values`` so that ``generator`` reproduces ``pics``.

        Runs at most 10 descent steps. Each step moves the latent values along
        the negated gradient of the squared error, with a step length chosen
        by ``scipy.optimize.brent``, and stops early when a step does not
        reduce the error.

        Returns the pictures generated at the final latent values, and the
        rows of ``database_pics`` closest (``norm1`` distance) to the pictures
        generated at the *initial* latent values.
        """

        def error(a, b):
            return np.linalg.norm(a - b) ** 2

        def descent_direction(latent, generated):
            # Contract the generator gradient with the residual; the minus
            # sign turns the result into a descent direction.
            jacobian = generator.grad(z=latent)
            return -np.einsum("ijk,ik->ij", jacobian, generated - pics)

        # Closest database pictures, measured on the initial generation
        generated_pics = generator(z=latent_values)
        closest_images_indice = core.KerOp.dnm(
            x=generated_pics, y=database_pics, distance="norm1"
        ).argmin(1)
        closest_images = database_pics[closest_images_indice]

        nablas = descent_direction(latent_values, generated_pics)
        error_ = error(pics, generated_pics)

        def f(x):
            # Error after a step of length x along the current direction;
            # reads the current ``latent_values`` / ``nablas`` of the loop.
            return error(generator(z=latent_values + x * nablas), pics)

        for _ in range(10):
            # Finite-difference first and second derivatives of f at 0
            a, b, c = f(0.0), f(1e-8), f(2e-8)
            fprime = (b - a) / 1e-8
            if fprime == 0.0:
                # Flat direction: nothing to gain, and the bracket below
                # would divide by zero.
                break
            assert fprime <= 0.0, "search direction is not a descent direction"
            bound = -a / fprime
            fsec = (a + c - 2 * b) / (1e-16)
            if fsec > 0:  # strictly positive: avoids an infinite bracket
                bound = max(bound, -fprime / fsec)
            xmin, fval, _, _ = optimize.brent(
                f, brack=(0.0, bound), maxiter=5, full_output=True
            )
            if fval >= error_:
                break
            latent_values, error_ = latent_values + xmin * nablas, fval
            generated_pics = generator(z=latent_values)
            nablas = descent_direction(latent_values, generated_pics)

        generated_pics = generator(z=latent_values)

        return generated_pics, closest_images

    # Load images, honouring a dataset location given in the settings
    generator_args = {
        "selected_features": selected_features,
        "drop_features": drop_features,
    }
    if "main_folder" in kwargs:
        generator_args["main_folder"] = kwargs["main_folder"]
    celeba = CelebA_data_generator(**generator_args)
    x_target, images_index, _ = celeba.get_data(**kwargs)
    print(f"Loaded {x_target.shape[0]} images")

    # The output folder may not exist yet (e.g. when the dataset was
    # downloaded elsewhere), and imsave does not create it.
    os.makedirs(proj_path, exist_ok=True)

    for d in dimensions:
        # Nz pictures that are not part of the loaded set; the count must
        # match the N x N tile grid.
        new_images = celeba.get_images(Nx=Nz, exclude_image_ids=images_index)

        # Build a sampler with a latent space of dimension d, and a kernel
        # mapping its pictures back to its latent points.
        generator = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)
        encoder = Kernel(x=generator.get_fx(), fx=generator.get_x())
        latent_values = encoder(new_images)

        generated_pics, database_pics = celebA_Wasserstein_descent(
            new_images, generator.get_fx(), latent_values, generator
        )

        pic_generated = tiles(
            generated_pics, pic_shape=kwargs["input_shape"], tile_shape=[N, N]
        )
        pic_database = tiles(
            database_pics, pic_shape=kwargs["input_shape"], tile_shape=[N, N]
        )

        pic_name = str(kwargs["Nx"]) + "D_" + str(d) + ".png"
        plt.imsave(
            os.path.join(proj_path, "pic_generated_N" + pic_name),
            pic_generated,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        plt.imsave(
            os.path.join(proj_path, "pic_database_N" + pic_name),
            pic_database,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        print("Saved", pic_name)


if __name__ == "__main__":
    # Alternative run (plain latent sampling): launch_celebA_generator(dimensions=[40])
    launch_celebA_Wasserstein_generator(dimensions=[10])
Loaded 1000 images
Saved 1000D_10.png

Total running time of the script: (4 minutes 12.934 seconds)

Gallery generated by Sphinx-Gallery