.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_ch6\ch6_9_CelebA.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_ch6_ch6_9_CelebA.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_ch6_ch6_9_CelebA.py:


=============================================================
6.4.1 Generating complex distributions from celebA dataset
=============================================================
We reproduce Figure 6.9 from the book, generating images from the CelebA dataset.
We generate new images from latent spaces of several different dimensions, and compare each of them to the closest image in the dataset.
"Closest" is defined by a distance between images measured in the latent space.

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Necessary Imports
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 14-39

.. code-block:: Python


    import math
    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import skimage.io
    import skimage.transform
    from scipy import optimize

    from codpy import core
    from codpy.kernel import Kernel, Sampler
    from codpy.permutation import scipy_lsap, lsap, map_invertion
    import codpy.conditioning

    try:
        current_dir = os.path.dirname(__file__)
        data_path = os.path.join(current_dir, "data")
    except NameError:
        current_dir = os.getcwd()
        data_path = os.path.join(current_dir, "data")
    proj_path = data_path


.. GENERATED FROM PYTHON SOURCE LINES 40-45

Data Generator
------------------------
The CelebA dataset comes with a .csv file containing the attributes of each image and the corresponding file names.
The actual image files are stored in a subfolder, under file names that match those listed in the .csv file.
We define a class to handle the loading and processing of this dataset.

.. GENERATED FROM PYTHON SOURCE LINES 45-227

..
.. code-block:: Python


    class CelebA_data_generator:
        """Load the CelebA attribute table and the matching image files.

        Keyword arguments read by the constructor:

        * ``main_folder``: dataset root; defaults to ``<data_path>/celebA``.
          When the root, the attribute CSV or the image folder is missing,
          the dataset is fetched with ``kagglehub`` and that download
          location is used instead.
        * ``selected_features``: attribute columns that must equal ``1``.
        * ``drop_features``: attribute columns that must equal ``-1``.

        Both feature lists default to "no filter"; the columns they name are
        removed from ``self.attributes`` once the rows have been filtered.
        """

        def __init__(self, **kwargs):
            self._set_paths(kwargs.get("main_folder", os.path.join(data_path, "celebA")))
            if not (
                os.path.exists(self.main_folder)
                and os.path.exists(self.attributes_path)
                and os.path.exists(self.images_folder)
            ):
                # Local copy incomplete: fall back to a download.
                import kagglehub

                self._set_paths(kagglehub.dataset_download("jessicali9530/celeba-dataset"))

            # ``None`` (the former default) is normalised to an empty list so
            # that the loops in ``__prepare`` never iterate over ``None``.
            self.selected_features = list(kwargs.get("selected_features") or [])
            self.drop_features = list(kwargs.get("drop_features") or [])
            self.features_name = []
            self.__prepare()

        def _set_paths(self, main_folder):
            """Derive the image folder and attribute CSV paths from the dataset root."""
            self.main_folder = main_folder
            self.images_folder = os.path.join(
                self.main_folder, "img_align_celeba/img_align_celeba/"
            )
            self.attributes_path = os.path.join(self.main_folder, "list_attr_celeba.csv")

        def __prepare(self):
            """Read the CSV, filter rows, and attach the path of every image."""
            # Read the CSV, and only keep files matching selected features
            attributes = pd.read_csv(self.attributes_path, sep=",")
            for feat in self.selected_features:
                attributes = attributes.loc[attributes[feat] == 1]
            for feat in self.drop_features:
                attributes = attributes.loc[attributes[feat] == -1]
            # The filter columns are constant after filtering, so they are removed.
            attributes = attributes.drop(
                columns=self.selected_features + self.drop_features
            )

            # One row per image, indexed by file name, with the full path in "path".
            # ``self.images_folder`` is used instead of a hard-coded
            # backslash-separated string, which only worked on Windows.
            attributes = attributes.set_index("image_id")
            attributes["path"] = [
                os.path.join(self.images_folder, name) for name in attributes.index
            ]
            self.attributes = attributes
            # Every column except the trailing "path" column, as it stands
            # *before* constant columns are removed below.
            self.features_name = list(self.attributes.columns)[:-1]

            # Columns holding a single distinct value carry no information.
            constant_columns = [
                col
                for col in self.attributes.columns
                if self.attributes[col].nunique(dropna=False) == 1
            ]
            self.attributes = self.attributes.drop(columns=constant_columns)
            # Counts every remaining column, "path" included.
            self.num_features = self.attributes.shape[1]

        def get_images(
            self,
            image_ids=None,
            exclude_image_ids=None,
            conditionned_features=None,
            **kwargs,
        ):
            """Load images and return them as a DataFrame of flattened pixels.

            Parameters
            ----------
            image_ids : pandas.Series, optional
                Image paths (or file names) indexed by image id. When omitted,
                the images are chosen from ``self.attributes``.
            exclude_image_ids : optional
                Index labels that must not be drawn. Only applied when a
                subset is sampled without ``conditionned_features``.
            conditionned_features : list of str, optional
                Draw about half of the sample among rows where the feature
                equals ``1`` and the rest among rows where it equals ``-1``.
            **kwargs
                ``Nx`` (number of images, default: all), ``seed`` (default 42),
                ``PIC_DIR`` (folder used for entries that are bare file names);
                everything is forwarded to ``get_images_list``.
            """
            total = self.attributes.shape[0]
            Nx = kwargs.get("Nx", total)  # Number of total images
            if image_ids is None:
                if Nx >= total:
                    # Get entire dataset
                    image_ids = self.attributes["path"]
                else:
                    # Sample small portion of the dataset
                    random_state = kwargs.get("seed", 42)
                    if conditionned_features is None:
                        pool = self.attributes
                        if exclude_image_ids is not None:
                            pool = pool.loc[~pool.index.isin(exclude_image_ids)]
                        image_ids = pool.sample(Nx, random_state=random_state)["path"]
                    else:
                        # NOTE(review): each pass of the two loops below
                        # restarts from ``self.attributes``, so only the LAST
                        # conditioned feature takes effect. Kept as in the
                        # original; confirm whether a cumulative filter was meant.
                        test = self.attributes.copy()
                        for c in conditionned_features:
                            test = self.attributes.loc[self.attributes[c] == 1]
                        if test.shape[0] < Nx // 2:
                            image_ids = test["path"]
                        else:
                            image_ids = test.sample(
                                Nx // 2, random_state=random_state
                            )["path"]
                        for c in conditionned_features:
                            test = self.attributes.loc[self.attributes[c] == -1]
                        remainder = test.sample(
                            Nx - len(image_ids), random_state=random_state
                        )["path"]
                        image_ids = pd.concat([image_ids, remainder])

            PIC_DIR = kwargs.get("PIC_DIR", self.images_folder)
            # Entries of the "path" column are already complete paths; only
            # entries that do not resolve on their own are prefixed by PIC_DIR.
            paths_images = [
                image_id if os.path.exists(image_id) else os.path.join(PIC_DIR, image_id)
                for image_id in image_ids
            ]
            images = get_images_list(
                list_pics=paths_images, index=image_ids.index, **kwargs
            )
            return images

        def get_data(self, N=0, **kwargs):
            """Return ``(images, image index, attribute rows)`` for a selection.

            ``N`` is accepted but not used; the selection is driven by the
            keyword arguments forwarded to ``get_images``.
            """
            images = self.get_images(**kwargs)
            attributes = self.attributes.loc[images.index]
            return images, images.index, attributes


    def get_images_list(list_pics, index=None, **kwargs):
        """Read one image, or a list of images, into float arrays.

        Given a list, every file is read and the result is a 2-D array with
        one row per image (wrapped in a DataFrame when ``index`` is given).
        Given a single file name, the image is read, converted to float and
        divided by its own maximum when it is not already float, optionally
        resized to ``output_shape``, and flattened unless ``flat=False``.
        """
        if isinstance(list_pics, list):
            if list_pics:
                out = np.asarray([get_images_list(pic, **kwargs) for pic in list_pics])
                out = np.reshape(out, [out.shape[0], -1])
            else:
                # ``reshape(0, -1)`` is ambiguous for numpy and raises.
                out = np.empty((0, 0))
            if index is not None:
                out = pd.DataFrame(out, index=index)
            return out

        output_shape = kwargs.get("output_shape", None)
        pic = skimage.io.imread(fname=list_pics)
        if pic.dtype != float:
            peak = np.max(pic)
            pic = pic.astype(float)
            if peak > 0:  # an all-black image would otherwise become NaN
                pic = pic / peak
        if output_shape is not None:
            pic = skimage.transform.resize(
                pic, output_shape=output_shape, anti_aliasing=True
            )
        if kwargs.get("flat", True):
            pic = pic.flatten()
        return pic


    def basic_filter(x):
        """Rescale each of the 3 interleaved channels of a flat image to [0, 1].

        Works on a float copy, so the argument is left untouched. A channel
        whose values are all equal is mapped to 0 instead of dividing by zero.
        """
        channels = np.array(x, dtype=float).reshape(-1, 3)
        channels -= channels.min(0)
        span = channels.max(0)
        span[span == 0] = 1.0
        channels /= span
        return channels.ravel()


    def tiles(x, **kwargs):
        """Arrange flattened pictures on a grid and return the mosaic array.

        Parameters
        ----------
        x : array-like, one flattened picture per row.
        pic_shape : (height, width) of a single picture (required keyword).
        tile_shape : (rows, columns) of the grid; defaults to ``[N, N]`` where
            ``N`` is the ``N`` keyword or ``int(sqrt(number of pictures))``.
        filter : callable applied to each row beforehand (default
            ``basic_filter``); pass ``None`` to disable.

        The input is copied, so the caller's array is not modified. Grid cells
        beyond the number of pictures are left at zero.
        """
        x = np.array(x, dtype=float)
        n_pics = x.shape[0]
        SQ = kwargs.get("N", int(np.sqrt(n_pics)))
        n_rows, n_cols = kwargs.get("tile_shape", [SQ, SQ])
        height, width = kwargs["pic_shape"][0], kwargs["pic_shape"][1]
        Dz = x.shape[1] // (height * width)  # number of channels per pixel

        pixel_filter = kwargs.get("filter", basic_filter)
        if pixel_filter is not None:
            for n in range(n_pics):
                x[n] = pixel_filter(x[n])

        img = x.reshape(n_pics, height, width, Dz)
        out = np.zeros([n_rows * height, n_cols * width, Dz])
        for ind in range(min(n_pics, n_rows * n_cols)):
            j, k = divmod(ind, n_cols)
            out[
                j * height : (j + 1) * height,
                k * width : (k + 1) * width,
            ] = img[ind]
        return out


.. GENERATED FROM PYTHON SOURCE LINES 228-232

Configuration
------------------------
We define parameters for the experiment.
Input shape is the original image size. Rescale shape is the size to which images will be resized for processing, before being flattened and used in the model.

.. GENERATED FROM PYTHON SOURCE LINES 232-432

..
.. code-block:: Python


    config = {
        "input_shape": [218, 178],
        "rescale_shape": [50, 50],
        "Nx": 1000,  # Number of images to use for training
        "Nz": 16,  # Number of images to generate
        "seed": 43,
        "main_folder": os.path.join(data_path, "celebA"),
    }


    def _tile_width(Nz):
        """Return the side of the square grid holding ``Nz`` pictures.

        Raises ``ValueError`` when ``Nz`` is not a perfect square (an
        ``assert`` would be stripped when Python runs with ``-O``).
        """
        N = int(round(math.sqrt(Nz)))
        if N * N != Nz:
            raise ValueError(
                "Number of generated images (Nz) must be a perfect square for plotting on tiles"
            )
        return N


    def _load_training_set(kwargs, selected_features, drop_features):
        """Build the data generator and load the training images.

        Returns ``(celeba, x_target, images_index)``. ``main_folder`` is
        forwarded to the generator when the configuration provides it.
        """
        generator_kwargs = {
            "selected_features": list(selected_features),
            "drop_features": list(drop_features),
        }
        if "main_folder" in kwargs:
            generator_kwargs["main_folder"] = kwargs["main_folder"]
        celeba = CelebA_data_generator(**generator_kwargs)
        x_target, images_index, _images_attributes = celeba.get_data(**kwargs)
        print(f"Loaded {x_target.shape[0]} images")
        return celeba, x_target, images_index


    def _save_tile_pair(generated_pics, database_pics, pic_shape, N, Nx, d):
        """Save the generated and the database mosaics as two PNG files."""
        pic_generated = tiles(generated_pics, pic_shape=pic_shape, tile_shape=[N, N])
        pic_database = tiles(database_pics, pic_shape=pic_shape, tile_shape=[N, N])
        pic_name = str(Nx) + "D_" + str(d) + ".png"
        os.makedirs(proj_path, exist_ok=True)  # imsave does not create folders
        plt.imsave(
            os.path.join(proj_path, "pic_generated_N" + pic_name),
            pic_generated,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        plt.imsave(
            os.path.join(proj_path, "pic_database_N" + pic_name),
            pic_database,
            vmin=0.0,
            vmax=1.0,
        )
        plt.show()
        print("Saved", pic_name)


    def launch_celebA_generator(
        kwargs=config,
        selected_features=("Blond_Hair", "Attractive", "Smiling"),
        drop_features=("Male",),
        dimensions=(1, 2, 3, 10, 40),
    ):
        """Generate ``Nz`` images per latent dimension and save them as tiles.

        For every ``d`` in ``dimensions`` a ``Sampler`` with ``latent_dim=d``
        is built on the training images, ``Nz`` latent points are drawn from
        a normal distribution (seeded with ``kwargs["seed"]`` when present),
        and two mosaics are saved: the images generated from those points,
        and the images generated from the sampler's ``get_x()`` entries
        selected by ``sampler.dnm(...).argmin(1)``.
        """
        Nz = kwargs["Nz"]  # Number of images to generate
        N = _tile_width(Nz)  # Width of the tiles

        # Load images
        _celeba, x_target, _images_index = _load_training_set(
            kwargs, selected_features, drop_features
        )

        rng = np.random.default_rng(kwargs.get("seed"))
        for d in dimensions:
            # We make a sampler mapping to a latent space of dimension d
            sampler = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)
            sampled_latent = rng.normal(size=(Nz, d))
            closest_indices = sampler.dnm(x=sampled_latent).argmin(1)
            closest_latent = sampler.get_x()[closest_indices]

            generated_pics = sampler(z=sampled_latent)
            database_pics = sampler(z=closest_latent)

            _save_tile_pair(
                generated_pics, database_pics, kwargs["input_shape"], N, kwargs["Nx"], d
            )


    def _wasserstein_descent(pics, database_pics, latent_values, generator, max_steps=10):
        """Move ``latent_values`` so that the generated images approach ``pics``.

        Each step follows ``-grad(z) . (generator(z) - pics)`` and picks the
        step length with ``scipy.optimize.brent`` on the squared norm of
        ``generator(z) - pics``. The loop stops after ``max_steps`` steps, as
        soon as a step no longer reduces the error, or when the
        finite-difference slope along the direction is not negative.

        Returns ``(generated_pics, closest_images)`` where ``closest_images``
        are the rows of ``database_pics`` closest (``"norm1"`` distance) to
        the images generated from the *initial* latent values.
        """
        pics = np.asarray(pics)
        eps = 1e-8  # finite-difference step along the descent direction

        def error(a, b):
            return np.linalg.norm(a - b) ** 2

        def direction(latent, generated):
            return -np.einsum("ijk,ik->ij", generator.grad(z=latent), generated - pics)

        def f(x):
            # Reads the current ``latent_values`` / ``nablas`` of the enclosing scope.
            return error(generator(z=latent_values + x * nablas), pics)

        # We compute the distance between generated and database images
        generated_pics = generator(z=latent_values)
        distances = core.KerOp.dnm(x=generated_pics, y=database_pics, distance="norm1")
        closest_images = database_pics[distances.argmin(1)]

        nablas = direction(latent_values, generated_pics)
        error_ = error(pics, generated_pics)

        for _ in range(max_steps):
            a, b, c = f(0.0), f(eps), f(2 * eps)
            fprime = (b - a) / eps
            if not fprime < 0.0:
                # Not a descent direction (or a stationary point): the bracket
                # ``-a / fprime`` below would be negative or infinite.
                break
            bound = -a / fprime
            fsec = (a + c - 2 * b) / eps**2
            if fsec > 0:  # strictly positive: avoids dividing by zero
                bound = max(bound, -fprime / fsec)
            xmin, fval, _n_iter, _n_calls = optimize.brent(
                f, brack=(0.0, bound), maxiter=5, full_output=True
            )
            if fval >= error_:
                break
            latent_values, error_ = latent_values + xmin * nablas, fval
            generated_pics = generator(z=latent_values)
            nablas = direction(latent_values, generated_pics)

        generated_pics = generator(z=latent_values)
        return generated_pics, closest_images


    def launch_celebA_Wasserstein_generator(
        kwargs=config,
        selected_features=("Blond_Hair", "Attractive", "Smiling"),
        drop_features=("Male",),
        dimensions=(1, 2, 3, 10, 40),
    ):
        """Reconstruct held-out images through the sampler and save the tiles.

        ``Nz`` images that are not part of the training set are encoded to
        the latent space with a ``Kernel`` fitted from ``generator.get_fx()``
        to ``generator.get_x()``, refined with ``_wasserstein_descent``, and
        saved next to the closest training images, for every latent
        dimension in ``dimensions``.
        """
        Nz = kwargs["Nz"]  # Number of images to generate
        N = _tile_width(Nz)  # Width of the tiles

        # Load images
        celeba, x_target, images_index = _load_training_set(
            kwargs, selected_features, drop_features
        )
        # Held-out images do not depend on the latent dimension: load them once.
        new_images = celeba.get_images(Nx=Nz, exclude_image_ids=images_index)

        for d in dimensions:
            # We make a sampler mapping to a latent space of dimension d
            generator = Sampler(x=x_target.values, latent_dim=d, iter=0, reg=0.0)
            encoder = Kernel(x=generator.get_fx(), fx=generator.get_x())
            latent_values = encoder(new_images)
            generated_pics, database_pics = _wasserstein_descent(
                new_images, generator.get_fx(), latent_values, generator
            )

            _save_tile_pair(
                generated_pics, database_pics, kwargs["input_shape"], N, kwargs["Nx"], d
            )


    if __name__ == "__main__":
        # launch_celebA_generator(dimensions = [40])
        launch_celebA_Wasserstein_generator(dimensions=[10])


..
.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loaded 1000 images
    Saved 1000D_10.png


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (4 minutes 12.934 seconds)


.. _sphx_glr_download_auto_ch6_ch6_9_CelebA.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: ch6_9_CelebA.ipynb <ch6_9_CelebA.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: ch6_9_CelebA.py <ch6_9_CelebA.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: ch6_9_CelebA.zip <ch6_9_CelebA.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_