.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_ch6\ch6_iris.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_ch6_ch6_iris.py: =============================== 6.7. Conditional Sampling =============================== In this tutorial, we introduce the problem of conditional sampling, i.e., generating samples from a distribution \( p(y|x) \) using kernel-based models. We’ll use the classic Iris dataset to demonstrate two approaches: - **Nadaraya-Watson Estimator** - **Optimal Transport-based Conditioner** .. GENERATED FROM PYTHON SOURCE LINES 12-118 .. image-sg:: /auto_ch6/images/sphx_glr_ch6_iris_001.png :alt: sep.len | pet.wid., sep.wid. | pet.wid., pet.len | pet.wid., sep.len | pet.wid., sep.wid. | pet.wid., pet.len | pet.wid. :srcset: /auto_ch6/images/sphx_glr_ch6_iris_001.png :class: sphx-glr-single-img .. code-block:: Python import numpy as np import pandas as pd import seaborn as sns from sklearn.datasets import load_iris from codpy import conditioning import codpy.core from codpy.core import get_matrix from codpy.plot_utils import multi_plot import matplotlib.pyplot as plt def load_iris_data(threshold=0.2): """ Loads the Iris dataset and prepares conditional data: using 'petal width' to condition other features. """ data = load_iris() feature_names = ["sep.len", "sep.wid.", "pet.len", "pet.wid."] df = pd.DataFrame(data.data, columns=feature_names) x_cols = ["pet.wid."] y_cols = [col for col in df.columns if col not in x_cols] x = df[x_cols] y = df[y_cols] # Identify reference distribution: near the mean of x threshold = np.power(threshold, 1.0 / len(x_cols)) mean_vals = df.mean() var_vals = df.var() mask = np.ones(len(df), dtype=bool) for col in x_cols: mask &= np.abs(df[col] - mean_vals[col]) < var_vals[col] * threshold ref_dist = df.loc[mask, y_cols] return x, y, ref_dist, get_matrix(mean_vals[x_cols]).T def sample_and_plot(samplers, N=500): """ Sample conditionally using provided samplers and plot marginals. """ x, y, ref_dist, cond_point = load_iris_data() results = [] for SamplerClass in samplers: # Initialize and sample model = SamplerClass(x=x, y=y) sampled = model.sample(x=cond_point, n=N,iter=0).squeeze() # Prepare DataFrame df_sampled = pd.DataFrame(sampled, columns=y.columns) df_sampled["Type"] = SamplerClass.__name__ # Combine with reference df_ref = ref_dist.copy() df_ref["Type"] = "Reference" results.append(pd.concat([df_sampled, df_ref])) # Plotting plot_marginals(results, condition_col=x.columns[0]) def plot_marginals(results, condition_col): """ Plot cumulative histograms for each feature conditioned on the same input. """ plots = [] for df in results: for col in df.columns: if col != "Type": plots.append({"data": df, "col": col, "cond_col": condition_col}) def plot_func(params, **kwargs): ax = params["ax"] data, col, cond = params["data"], params["col"], params["cond_col"] sns.histplot( data=data, x=col, hue="Type", bins=100, cumulative=True, ax=ax, stat="density", common_norm=False, fill=False, element="step", ) ax.set_title(f"{col} | {cond}") multi_plot(plots, fun_plot=plot_func, mp_ncols=3, mp_nrows=len(results)) codpy.core.KerInterface.set_verbose() sample_and_plot( samplers=[ conditioning.ConditionerKernel, conditioning.NadarayaWatsonKernel, ] ) plt.show() pass .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 0.270 seconds) .. _sphx_glr_download_auto_ch6_ch6_iris.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: ch6_iris.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: ch6_iris.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: ch6_iris.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_