Note
Go to the end to download the full example code.
6.4.1 Generating complex distributions from celebA dataset
We reproduce Figure 6.9 from the book by generating images from the CelebA dataset. We generate new images from latent spaces of different dimensions and compare each of them to the closest image in the dataset, where “closest” is defined by a distance metric between images in the latent space.
Necessary Imports
import math
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io
import skimage.transform
from scipy import optimize
from codpy import core
from codpy.kernel import Kernel, Sampler
from codpy.permutation import scipy_lsap, lsap, map_invertion
import codpy.conditioning
# Locate the folder of this example. ``__file__`` is not defined in an
# interactive session, so the current working directory is used there.
try:
    current_dir = os.path.dirname(__file__)
except NameError:
    current_dir = os.getcwd()

# Input data lives in a "data" folder next to the script; the same folder is
# used as the project folder.
data_path = os.path.join(current_dir, "data")
proj_path = data_path
Data Generator
The CelebA dataset comes with a .csv file containing attributes for each image, and corresponding file names. The actual image files can be found in a subfolder, where the file names match those in the .csv file. We define a class to handle the loading and processing of this dataset.
class CelebA_data_generator:
    """Load the CelebA attribute table and the image files it refers to.

    The attribute table is read from ``list_attr_celeba.csv`` and the pictures
    from ``img_align_celeba/img_align_celeba/``, both located under
    ``main_folder``. When any of these is missing, the dataset is fetched with
    ``kagglehub`` and the three locations are re-pointed at the download.

    Keyword Args:
        main_folder: Root folder of the dataset. Defaults to
            ``<data_path>/celebA``.
        selected_features: Attribute names an image must have (value ``1``).
            ``None`` (the default) means "no requirement".
        drop_features: Attribute names an image must not have (value ``-1``).
            ``None`` (the default) means "no requirement".

    Attributes:
        attributes: DataFrame indexed by ``image_id`` holding the remaining
            attribute columns plus a ``path`` column with the picture path.
        features_name: Attribute column names left after the selected and
            dropped features were removed (before constant columns are purged).
        num_features: Number of columns of ``attributes`` (``path`` included).
    """

    def __init__(self, **kwargs):
        # Only touch the module-level default location when no folder is given.
        if "main_folder" in kwargs:
            root = kwargs["main_folder"]
        else:
            root = os.path.join(data_path, "celebA")
        self._set_main_folder(root)

        required = (self.main_folder, self.attributes_path, self.images_folder)
        if not all(os.path.exists(location) for location in required):
            # Local import: kagglehub is only needed when the data is missing.
            import kagglehub

            self._set_main_folder(
                kagglehub.dataset_download("jessicali9530/celeba-dataset")
            )

        # ``None`` used to crash in the filtering loops; treat it as "no filter".
        selected = kwargs.get("selected_features", None)
        dropped = kwargs.get("drop_features", None)
        self.selected_features = [] if selected is None else selected
        self.drop_features = [] if dropped is None else dropped
        self.features_name = []
        self.__prepare()

    def _set_main_folder(self, main_folder):
        """Point the dataset root, image folder and CSV path at ``main_folder``."""
        self.main_folder = main_folder
        self.images_folder = os.path.join(
            self.main_folder, "img_align_celeba/img_align_celeba/"
        )
        self.attributes_path = os.path.join(self.main_folder, "list_attr_celeba.csv")

    def __prepare(self):
        """Read the CSV, filter the rows and attach the picture path to each row."""
        # Read the CSV, and only keep files matching selected features
        self.attributes = pd.read_csv(self.attributes_path, sep=",")
        for feat in self.selected_features:
            self.attributes = self.attributes.loc[self.attributes[feat] == 1]
        for feat in self.drop_features:
            self.attributes = self.attributes.loc[self.attributes[feat] == -1]
        # The filtered columns are constant by construction, so remove them.
        self.attributes = self.attributes.drop(self.selected_features, axis=1)
        self.attributes = self.attributes.drop(self.drop_features, axis=1)

        self.attributes = self.attributes.set_index("image_id")
        self.features_name = list(self.attributes.columns)
        # Build the paths from ``images_folder`` so they are valid on every OS
        # (the previous hard-coded backslashes only worked on Windows).
        self.attributes["path"] = [
            os.path.join(self.images_folder, image_id)
            for image_id in self.attributes.index
        ]

        # Attributes taking a single value carry no information: drop them.
        # "path" is never dropped, otherwise a one-row table loses its paths.
        constant_columns = [
            col
            for col in self.attributes.columns
            if col != "path" and len(pd.unique(self.attributes[col])) == 1
        ]
        self.attributes = self.attributes.drop(constant_columns, axis=1)
        self.num_features = self.attributes.shape[1]

    def _select_image_paths(
        self, exclude_image_ids=None, conditionned_features=None, **kwargs
    ):
        """Return the ``path`` Series of the images to load.

        Reads ``Nx`` (number of images, default: all) and ``seed`` (default
        42) from ``kwargs``. When ``Nx`` is not smaller than the table, the
        whole table is returned and the other arguments are ignored.
        """
        total = self.attributes.shape[0]
        Nx = kwargs.get("Nx", total)  # Number of total images
        if Nx >= total:
            # Get entire dataset
            return self.attributes["path"]

        random_state = kwargs.get("seed", 42)
        if conditionned_features is None:
            pool = self.attributes
            if exclude_image_ids is not None:
                pool = pool.loc[~pool.index.isin(exclude_image_ids)]
            return pool.sample(Nx, random_state=random_state)["path"]

        # Half of the sample has every conditioned feature, the other half has
        # none of them. The filters are cumulative over the features (the
        # previous code only applied the last feature of the list).
        positives = self.attributes
        negatives = self.attributes
        for feat in conditionned_features:
            positives = positives.loc[positives[feat] == 1]
            negatives = negatives.loc[negatives[feat] == -1]

        half = Nx // 2
        if positives.shape[0] < half:
            chosen = positives["path"]
        else:
            chosen = positives.sample(half, random_state=random_state)["path"]
        # Never ask for more negative rows than exist.
        missing = min(Nx - len(chosen), negatives.shape[0])
        return pd.concat(
            [chosen, negatives.sample(missing, random_state=random_state)["path"]]
        )

    def get_images(
        self,
        image_ids=None,
        exclude_image_ids=None,
        conditionned_features=None,
        **kwargs,
    ):
        """Load pictures through ``get_images_list``.

        Args:
            image_ids: Series of picture paths or file names to load, indexed
                by image id. When ``None`` the images are chosen from the
                attribute table (see ``Nx``/``seed`` below).
            exclude_image_ids: Image ids that must not be sampled. Only used
                when sampling without ``conditionned_features``.
            conditionned_features: Attribute names; half of the sample has all
                of them, the other half none of them.
            **kwargs: ``Nx`` (number of images), ``seed`` (sampling seed) and
                ``PIC_DIR`` (folder prepended to the paths); all of them are
                also forwarded to ``get_images_list``.

        Returns:
            Whatever ``get_images_list`` returns for the selected paths, built
            with the index of the selected images.
        """
        chosen_here = image_ids is None
        if chosen_here:
            image_ids = self._select_image_paths(
                exclude_image_ids, conditionned_features, **kwargs
            )

        pic_dir = kwargs.get("PIC_DIR")
        if chosen_here and pic_dir is None:
            # Paths from the attribute table are already complete; joining
            # them again would duplicate a relative folder prefix.
            paths_images = list(image_ids)
        else:
            if pic_dir is None:
                pic_dir = self.images_folder
            paths_images = [os.path.join(pic_dir, image_id) for image_id in image_ids]

        return get_images_list(list_pics=paths_images, index=image_ids.index, **kwargs)

    def get_data(self, N=0, **kwargs):
        """Return ``(images, image index, attribute rows)`` for a selection.

        ``N`` is accepted for compatibility and is not used; ``kwargs`` are
        forwarded to ``get_images``.
        """
        images = self.get_images(**kwargs)
        attributes = self.attributes.loc[images.index]
        return images, images.index, attributes
def get_images_list(list_pics, index=None, **kwargs):
    """Read one picture, or a list of pictures, as float arrays.

    Args:
        list_pics: Either a single file name, or a list/tuple of file names.
        index: Optional index; when given together with a list of pictures,
            the result is wrapped in a ``pandas.DataFrame`` using this index.
        **kwargs: ``output_shape`` (optional size passed to
            ``skimage.transform.resize``) and ``flat`` (flatten each picture,
            default ``True``). Other keys are ignored.

    Returns:
        For a single file name: the picture as a float array (flattened
        unless ``flat`` is false). For a list: a 2-D array with one flattened
        picture per row, or a DataFrame when ``index`` is given. An empty list
        gives an empty ``(0, 0)`` result.
    """
    if isinstance(list_pics, (list, tuple)):
        if len(list_pics) == 0:
            # ``reshape(0, -1)`` is ambiguous for numpy, so build it directly.
            out = np.empty((0, 0))
        else:
            out = np.asarray([get_images_list(pic, **kwargs) for pic in list_pics])
            out = np.reshape(out, [out.shape[0], -1])
        if index is not None:
            out = pd.DataFrame(out, index=index)
        return out

    output_shape = kwargs.get("output_shape", None)
    pic = skimage.io.imread(fname=list_pics)
    if pic.dtype != float:
        # Scale by the picture's own maximum; an all-zero picture is left at
        # zero instead of turning into NaN through a division by zero.
        peak = np.max(pic)
        pic = pic.astype(float)
        if peak != 0:
            pic = pic / peak
    if output_shape is not None:
        pic = skimage.transform.resize(
            pic, output_shape=output_shape, anti_aliasing=True
        )
    if kwargs.get("flat", True):
        pic = pic.flatten()
    return pic
def basic_filter(x):
    """Rescale every channel of a flattened 3-channel picture to ``[0, 1]``.

    The input is read as consecutive triplets (one value per channel). Each
    channel is shifted so its minimum is 0 and divided by its range. A
    constant channel has a zero range and is mapped to 0 instead of NaN.

    Args:
        x: 1-D array whose length is a multiple of 3. It is not modified.

    Returns:
        A new 1-D float array of the same length.
    """
    # Work on a float copy: the caller's buffer stays untouched and integer
    # input no longer fails on the in-place division.
    channels = np.array(x, dtype=float)
    channels = channels.reshape([channels.shape[0] // 3, 3])
    channels -= channels.min(0)
    span = channels.max(0)
    span[span == 0] = 1.0  # constant channel: avoid 0 / 0
    channels /= span
    return channels.ravel()
def tiles(x, **kwargs):
    """Arrange flattened pictures on a grid and return the mosaic.

    Args:
        x: 2-D array-like with one flattened picture per row. It is copied,
            so the caller's data is left untouched.
        **kwargs:
            pic_shape (required): ``[height, width]`` of one picture.
            tile_shape: ``[rows, columns]`` of the grid. Defaults to a square
                grid of side ``N``.
            N: Side of the default grid. Defaults to ``int(sqrt(len(x)))``.
            filter: Callable applied to every flattened picture before
                tiling, or ``None`` to skip filtering. Defaults to
                ``basic_filter``.

    Returns:
        Float array of shape ``(rows * height, columns * width, channels)``.
        Grid cells without a picture are left at zero; pictures beyond the
        grid capacity are ignored.
    """
    pic_shape = kwargs["pic_shape"]
    height, width = pic_shape[0], pic_shape[1]

    # Float copy: filtering must not alter the caller's array, and this also
    # accepts DataFrames and integer arrays.
    pics = np.array(x, dtype=float)
    n_pics = pics.shape[0]
    side = kwargs.get("N", int(np.sqrt(n_pics)))
    tile_shape = kwargs.get("tile_shape", [side, side])
    Dz = pics.shape[1] // (height * width)  # number of channels per pixel

    # Resolve the default lazily so ``basic_filter`` is only needed when used.
    pixel_filter = kwargs["filter"] if "filter" in kwargs else basic_filter
    if pixel_filter is not None:
        for n in range(n_pics):
            pics[n] = pixel_filter(pics[n])

    img = pics.reshape(n_pics, height, width, Dz)
    out = np.zeros([tile_shape[0] * height, tile_shape[1] * width, Dz])
    for ind in range(min(n_pics, tile_shape[0] * tile_shape[1])):
        j, k = divmod(ind, tile_shape[1])  # row-major position on the grid
        out[j * height : (j + 1) * height, k * width : (k + 1) * width] = img[ind]
    return out
Configuration
We define parameters for the experiment. Input shape is the original image size. Rescale shape is the size to which images will be resized for processing, before being flattened and used in the model.
# Experiment settings shared by the launchers below.
config = dict(
    input_shape=[218, 178],  # Size of the original images
    rescale_shape=[50, 50],  # Size the images are meant to be resized to
    Nx=1000,  # Number of images to use for training
    Nz=16,  # Number of images to generate
    seed=43,  # Seed used when sampling images from the dataset
    main_folder=os.path.join(data_path, "celebA"),
)
def launch_celebA_generator(
    kwargs=config,
    selected_features=["Blond_Hair", "Attractive", "Smiling"],
    drop_features=["Male"],
    dimensions=[1, 2, 3, 10, 40],
):
    """Generate pictures from latent samples and save them next to database pictures.

    For every latent dimension ``d`` a ``Sampler`` is fitted on the loaded
    images, ``Nz`` latent points are drawn from a standard normal law, and two
    mosaics are written to ``proj_path``: the pictures generated from the
    drawn points (``pic_generated_N<Nx>D_<d>.png``) and the pictures generated
    from the sampler's latent points closest to them
    (``pic_database_N<Nx>D_<d>.png``).

    Args:
        kwargs: Settings dictionary; reads ``Nz``, ``Nx``, ``input_shape`` and,
            when present, ``main_folder``. The whole dictionary is also
            forwarded to ``CelebA_data_generator.get_data``.
        selected_features: Attributes every loaded image must have.
        drop_features: Attributes no loaded image may have.
        dimensions: Latent dimensions to try, one pair of pictures each.

    Raises:
        AssertionError: If ``Nz`` is not a perfect square.

    Note:
        The dictionary/list defaults are shared between calls; they are only
        read here, never modified.
    """
    Nz = kwargs["Nz"]  # Number of images to generate
    N = int(np.sqrt(Nz))  # Width of the tiles
    if N * N != Nz:
        # Raised explicitly (same exception type as before) so that the check
        # is not stripped when Python runs with -O.
        raise AssertionError(
            "Number of generated images (Nz) must be a perfect square for plotting on tiles"
        )

    # Load images. The dataset folder of the settings is forwarded so that the
    # attribute table and the pictures come from the same place.
    dataset_options = {
        "selected_features": selected_features,
        "drop_features": drop_features,
    }
    if "main_folder" in kwargs:
        dataset_options["main_folder"] = kwargs["main_folder"]
    celeba = CelebA_data_generator(**dataset_options)
    x_target, _, _ = celeba.get_data(**kwargs)
    print(f"Loaded {x_target.shape[0]} images")

    pic_shape = kwargs["input_shape"]
    for d in dimensions:
        # We make a sampler mapping to a latent space of dimension d, and draw
        # the latent points from a normal distribution.
        sampler = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)
        sampled_latent = np.random.normal(size=(Nz, d))
        # For each drawn point, the closest latent point known to the sampler.
        closest_indices = sampler.dnm(x=sampled_latent).argmin(1)
        closest_latent = sampler.get_x()[closest_indices]
        generated_pics = sampler(z=sampled_latent)
        database_pics = sampler(z=closest_latent)

        pic_generated = tiles(generated_pics, pic_shape=pic_shape, tile_shape=[N, N])
        pic_database = tiles(database_pics, pic_shape=pic_shape, tile_shape=[N, N])

        pic_name = str(kwargs["Nx"]) + "D_" + str(d) + ".png"
        plt.imsave(
            os.path.join(proj_path, "pic_generated_N" + pic_name),
            pic_generated,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        plt.imsave(
            os.path.join(proj_path, "pic_database_N" + pic_name),
            pic_database,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        print("Saved", pic_name)
def launch_celebA_Wasserstein_generator(
    kwargs=config,
    selected_features=["Blond_Hair", "Attractive", "Smiling"],
    drop_features=["Male"],
    dimensions=[1, 2, 3, 10, 40],
):
    """Reconstruct held-out pictures by descending in the latent space.

    For every latent dimension ``d`` a ``Sampler`` is fitted on the training
    images, ``Nz`` pictures that are not part of the training set are encoded
    into the latent space, and their latent values are refined by a
    line-search gradient descent. Two mosaics are written to ``proj_path``:
    the refined generated pictures (``pic_generated_N<Nx>D_<d>.png``) and the
    training pictures closest to the initial generations
    (``pic_database_N<Nx>D_<d>.png``).

    Args:
        kwargs: Settings dictionary; reads ``Nz``, ``Nx``, ``input_shape`` and,
            when present, ``main_folder``. The whole dictionary is also
            forwarded to ``CelebA_data_generator.get_data``.
        selected_features: Attributes every loaded image must have.
        drop_features: Attributes no loaded image may have.
        dimensions: Latent dimensions to try, one pair of pictures each.

    Raises:
        AssertionError: If ``Nz`` is not a perfect square.

    Note:
        The dictionary/list defaults are shared between calls; they are only
        read here, never modified.
    """
    Nz = kwargs["Nz"]  # Number of images to generate
    N = int(np.sqrt(Nz))  # Width of the tiles
    if N * N != Nz:
        # Raised explicitly (same exception type as before) so that the check
        # is not stripped when Python runs with -O.
        raise AssertionError(
            "Number of generated images (Nz) must be a perfect square for plotting on tiles"
        )

    def celebA_Wasserstein_descent(pics, database_pics, latent_values, generator):
        """Refine ``latent_values`` so the generated pictures approach ``pics``.

        Runs at most 10 steps of gradient descent on the squared error between
        ``generator(z=latent_values)`` and ``pics``; the step length of each
        iteration is found with ``scipy.optimize.brent``.

        Returns:
            ``(generated_pics, closest_images)``: the pictures generated from
            the refined latent values, and the rows of ``database_pics``
            closest (``norm1`` distance matrix) to the initial generations.
        """

        def error(a, b):
            return np.linalg.norm(a - b) ** 2

        def descent_direction(latent, generated):
            # Minus the generator gradient contracted with the residual.
            jacobian = generator.grad(z=latent)
            return -np.einsum("ijk,ik->ij", jacobian, generated - pics)

        # Database pictures closest to the initial generations.
        generated_pics = generator(z=latent_values)
        distances = core.KerOp.dnm(x=generated_pics, y=database_pics, distance="norm1")
        closest_images = database_pics[distances.argmin(1)]

        nablas = descent_direction(latent_values, generated_pics)
        error_ = error(pics, generated_pics)

        def f(x):
            # Error after moving the latent values by ``x`` along ``nablas``.
            return error(generator(z=latent_values + x * nablas), pics)

        for _ in range(10):
            # Finite-difference estimates of f'(0) and f''(0).
            a, b, c = f(0.0), f(1e-8), f(2e-8)
            fprime = (b - a) / 1e-8
            if fprime >= 0.0:
                # Not a descent direction (converged, or numerical noise):
                # stop instead of failing or dividing by zero.
                break
            bound = -a / fprime
            fsec = (a + c - 2 * b) / (1e-16)
            if fsec > 0:
                bound = max(bound, -fprime / fsec)
            xmin, fval, _, _ = optimize.brent(
                f, brack=(0.0, bound), maxiter=5, full_output=True
            )
            if fval >= error_:
                break
            latent_values, error_ = latent_values + xmin * nablas, fval
            generated_pics = generator(z=latent_values)
            nablas = descent_direction(latent_values, generated_pics)

        generated_pics = generator(z=latent_values)
        return generated_pics, closest_images

    # Load images. The dataset folder of the settings is forwarded so that the
    # attribute table and the pictures come from the same place.
    dataset_options = {
        "selected_features": selected_features,
        "drop_features": drop_features,
    }
    if "main_folder" in kwargs:
        dataset_options["main_folder"] = kwargs["main_folder"]
    celeba = CelebA_data_generator(**dataset_options)
    x_target, images_index, _ = celeba.get_data(**kwargs)
    print(f"Loaded {x_target.shape[0]} images")

    pic_shape = kwargs["input_shape"]
    for d in dimensions:
        # Pictures outside of the training set. Their number follows Nz (it
        # was hard-coded to 16) so that it always matches the N x N mosaic.
        new_images = celeba.get_images(Nx=Nz, exclude_image_ids=images_index)
        # Sampler mapping a latent space of dimension d to the pictures, and a
        # kernel mapping pictures back to latent values.
        generator = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)
        encoder = Kernel(x=generator.get_fx(), fx=generator.get_x())
        latent_values = encoder(new_images)
        generated_pics, database_pics = celebA_Wasserstein_descent(
            new_images, generator.get_fx(), latent_values, generator
        )

        pic_generated = tiles(generated_pics, pic_shape=pic_shape, tile_shape=[N, N])
        pic_database = tiles(database_pics, pic_shape=pic_shape, tile_shape=[N, N])

        pic_name = str(kwargs["Nx"]) + "D_" + str(d) + ".png"
        plt.imsave(
            os.path.join(proj_path, "pic_generated_N" + pic_name),
            pic_generated,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        plt.imsave(
            os.path.join(proj_path, "pic_database_N" + pic_name),
            pic_database,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        print("Saved", pic_name)
if __name__ == "__main__":
    # Script entry point. The plain sampler experiment is kept for reference
    # but is not run by default:
    # launch_celebA_generator(dimensions = [40])
    latent_dimensions = [10]
    launch_celebA_Wasserstein_generator(dimensions=latent_dimensions)
Loaded 1000 images
Saved 1000D_10.png
Total running time of the script: (4 minutes 12.934 seconds)