
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neighbors/plot_nca_illustration.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_neighbors_plot_nca_illustration.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neighbors_plot_nca_illustration.py:


=============================================
Neighborhood Components Analysis Illustration
=============================================

This example illustrates a distance metric learned by Neighborhood Components
Analysis (NCA) to maximize a stochastic variant of the nearest neighbors
classification accuracy. It visualizes this metric by comparing the learned
embedding with the original point space. Please refer to the
:ref:`User Guide <nca>` for more information.
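
Concretely, NCA learns a linear map :math:`L` (a fitted
:class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` exposes it as its
``components_`` attribute), and the learned metric is the squared Euclidean
distance after applying that map. Writing :math:`d_L` for this distance,
expanding it shows that it is a Mahalanobis-type distance with matrix
:math:`L^\top L`:

.. math::

    d_L(x, x') = \lVert L x - L x' \rVert^2 = (x - x')^\top L^\top L \, (x - x')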

.. GENERATED FROM PYTHON SOURCE LINES 11-23

.. code-block:: default


    # License: BSD 3 clause

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_classification
    from sklearn.neighbors import NeighborhoodComponentsAnalysis
    from matplotlib import cm
    from scipy.special import logsumexp

    print(__doc__)








.. GENERATED FROM PYTHON SOURCE LINES 24-30

Original points
---------------
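First we create a data set of 9 samples from 3 classes and plot the points
in the original space. For this example, we focus on the classification of
point no. 3. The thickness of a link between point no. 3 and another point
reflects how close they are: it is proportional to the softmax of the negative
squared distance, i.e. the probability that point no. 3 picks that point as its
neighbor, so nearer points get thicker links.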
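
Written out, with :math:`x_i` standing for the coordinates of point :math:`i`,
the weight that the function ``link_thickness_i`` in the following code assigns
to the link from point :math:`i` to point :math:`j \neq i` is

.. math::

    p_{ij} = \frac{\exp\left(-\lVert x_i - x_j \rVert^2\right)}{\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2\right)}, \qquad p_{ii} = 0,

and each link from point no. 3 is drawn with line width :math:`5\,p_{3j}`. The
weights leaving a given point always sum to 1.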

.. GENERATED FROM PYTHON SOURCE LINES 30-74

.. code-block:: default


    X, y = make_classification(n_samples=9, n_features=2, n_informative=2,
                               n_redundant=0, n_classes=3, n_clusters_per_class=1,
                               class_sep=1.0, random_state=0)

    plt.figure(1)
    ax = plt.gca()
    for i in range(X.shape[0]):
        ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center')
        ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

    ax.set_title("Original points")
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.axis('equal')  # so that boundaries are displayed correctly as circles


    def link_thickness_i(X, i):
        diff_embedded = X[i] - X
        dist_embedded = np.einsum('ij,ij->i', diff_embedded,
                                  diff_embedded)
        dist_embedded[i] = np.inf

        # compute exponentiated distances (use the log-sum-exp trick to
        # avoid numerical instabilities)
        exp_dist_embedded = np.exp(-dist_embedded -
                                   logsumexp(-dist_embedded))
        return exp_dist_embedded


    def relate_point(X, i, ax):
        pt_i = X[i]
        # link thicknesses from point i to every point, computed once
        thickness = link_thickness_i(X, i)
        for j, pt_j in enumerate(X):
            if i != j:
                line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
                ax.plot(*line, c=cm.Set1(y[j]),
                        linewidth=5*thickness[j])


    i = 3
    relate_point(X, i, ax)
    plt.show()




.. image-sg:: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_001.png
   :alt: Original points
   :srcset: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_001.png
   :class: sphx-glr-single-img
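
The log-sum-exp step in ``link_thickness_i`` matters once squared distances
grow large: a naive softmax then underflows every weight to zero and ends up
dividing 0 by 0. The sketch below copies that function and compares it with a
hypothetical ``naive_link_thickness`` helper, written only for this comparison,
on made-up points that are far apart.

.. code-block:: python

    import numpy as np
    from scipy.special import logsumexp


    def link_thickness_i(X, i):
        # Same computation as in the example: softmax of negative squared distances
        diff = X[i] - X
        dist = np.einsum('ij,ij->i', diff, diff)
        dist[i] = np.inf  # point i never links to itself
        return np.exp(-dist - logsumexp(-dist))


    def naive_link_thickness(X, i):
        # Hypothetical helper: the same softmax without the log-sum-exp shift
        diff = X[i] - X
        dist = np.einsum('ij,ij->i', diff, diff)
        dist[i] = np.inf
        weights = np.exp(-dist)  # exp(-900) and exp(-1600) underflow to 0.0
        with np.errstate(invalid='ignore'):  # silence the 0/0 warning
            return weights / weights.sum()


    # Made-up points: squared distances of 900 and 1600 from point 0
    X_far = np.array([[0.0, 0.0], [30.0, 0.0], [40.0, 0.0]])
    print(naive_link_thickness(X_far, 0))  # every entry is nan
    print(link_thickness_i(X_far, 0))      # finite, sums to 1

    # On points 10 times closer together both versions agree
    X_near = X_far / 10
    print(np.allclose(naive_link_thickness(X_near, 0), link_thickness_i(X_near, 0)))

Subtracting the log-sum-exp term shifts all exponents by the same constant, so
the normalized weights are unchanged while the largest one becomes exactly
:math:`\exp(0) = 1` before normalization.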





.. GENERATED FROM PYTHON SOURCE LINES 75-80

Learning an embedding
---------------------
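We use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an
embedding and plot the points after the transformation. We then draw the links
from point no. 3 again, now computed from distances in the embedded space, to
compare its nearest neighbors with those in the original space.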
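
NCA chooses the embedding by maximizing the expected number of points that are
classified correctly when each point picks a neighbor with the softmax
probabilities used for the link thicknesses. The sketch below computes that
objective with NumPy and SciPy; ``nca_objective`` is a hypothetical helper, not
part of scikit-learn's API, and the toy data and hand-picked map ``L`` are made
up for illustration.

.. code-block:: python

    import numpy as np
    from scipy.special import softmax


    def nca_objective(X_embedded, y):
        """Expected number of points whose stochastic neighbor shares their class."""
        diff = X_embedded[:, None, :] - X_embedded[None, :, :]
        sq_dist = np.einsum('ijk,ijk->ij', diff, diff)
        np.fill_diagonal(sq_dist, np.inf)       # a point never picks itself
        p_ij = softmax(-sq_dist, axis=1)        # row i: neighbor probabilities of point i
        same_class = y[:, None] == y[None, :]
        p_i = np.sum(p_ij * same_class, axis=1)  # probability point i is classified correctly
        return p_i.sum()


    # Toy data: the class depends only on the first feature
    X_toy = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    y_toy = np.array([0, 0, 1, 1])

    # Hand-picked linear map: stretch the informative axis, shrink the other one
    L = np.diag([3.0, 0.1])

    print(nca_objective(X_toy, y_toy))        # identity map, about 1.69
    print(nca_objective(X_toy @ L.T, y_toy))  # close to 4, the number of points

With the arrays from this example, ``nca_objective(X, y)`` and
``nca_objective(nca.transform(X), y)`` compare the objective in the original
space and in the learned embedding. The value can never exceed the number of
samples.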

.. GENERATED FROM PYTHON SOURCE LINES 80-100

.. code-block:: default


    nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
    nca = nca.fit(X, y)

    plt.figure(2)
    ax2 = plt.gca()
    X_embedded = nca.transform(X)
    relate_point(X_embedded, i, ax2)

    for i in range(len(X)):
        ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i),
                 va='center', ha='center')
        ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]),
                    alpha=0.4)

    ax2.set_title("NCA embedding")
    ax2.axes.get_xaxis().set_visible(False)
    ax2.axes.get_yaxis().set_visible(False)
    ax2.axis('equal')
    plt.show()



.. image-sg:: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_002.png
   :alt: NCA embedding
   :srcset: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_002.png
   :class: sphx-glr-single-img
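
Because the embedding is linear, finding a point's nearest neighbor in the
embedded space amounts to applying the map to every point and taking the
smallest squared distance. The sketch below uses a hypothetical
``nearest_neighbor`` helper with a made-up map ``L`` to show how stretching one
axis can change which point is nearest.

.. code-block:: python

    import numpy as np


    def nearest_neighbor(X, i, L=None):
        """Index of the nearest neighbor of X[i], optionally after the linear map L."""
        Z = X if L is None else X @ L.T  # each row becomes z = L x
        sq_dist = np.sum((Z - Z[i]) ** 2, axis=1)
        sq_dist[i] = np.inf  # exclude the query point itself
        return int(np.argmin(sq_dist))


    # Made-up points: point 1 is closest to point 0 in the original space
    X_toy = np.array([[0.0, 0.0], [0.0, 0.9], [1.0, 0.0]])
    L = np.array([[1.0, 0.0], [0.0, 3.0]])  # hand-picked map stretching the second axis

    print(nearest_neighbor(X_toy, 0))     # 1
    print(nearest_neighbor(X_toy, 0, L))  # 2

For a fitted estimator, ``nca.transform(X)`` applies the map stored in
``nca.components_``, so ``nearest_neighbor(X, 3)`` and
``nearest_neighbor(X, 3, nca.components_)`` would give the nearest neighbor of
point no. 3 before and after the transformation.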






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.131 seconds)


.. _sphx_glr_download_auto_examples_neighbors_plot_nca_illustration.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_nca_illustration.py <plot_nca_illustration.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_nca_illustration.ipynb <plot_nca_illustration.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
