
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neural_networks/plot_mnist_filters.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_neural_networks_plot_mnist_filters.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neural_networks_plot_mnist_filters.py:


=====================================
Visualization of MLP weights on MNIST
=====================================

Looking at the learned coefficients of a neural network can sometimes provide
insight into its learning behavior. For example, if the weights look
unstructured, some units may not have been used at all; if very large
coefficients exist, regularization may have been too weak or the learning rate
too high.
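
As a rough illustration of this kind of inspection, the sketch below trains a
small network and summarizes its first layer weights. It is only a sketch under
stated assumptions: it uses the bundled ``load_digits`` dataset (8x8 images)
instead of MNIST so that it runs offline, and the layer size and iteration
count are arbitrary choices, not values from this example.

.. code-block:: python

    import warnings

    import numpy as np
    from sklearn.datasets import load_digits
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.neural_network import MLPClassifier

    # Assumption: the small bundled digits dataset stands in for MNIST.
    X_small, y_small = load_digits(return_X_y=True)
    X_small = X_small / 16.0  # pixel values range from 0 to 16

    small_mlp = MLPClassifier(hidden_layer_sizes=(16,), max_iter=20,
                              random_state=0)
    with warnings.catch_warnings():
        # Such a short run may stop before converging; ignore that warning.
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        small_mlp.fit(X_small, y_small)

    first_layer = small_mlp.coefs_[0]  # shape (n_features, n_hidden)
    largest = np.abs(first_layer).max()
    # One norm per hidden unit: a unit whose norm stays close to its initial
    # value may have contributed little during training.
    unit_norms = np.linalg.norm(first_layer, axis=0)

    print("largest absolute weight: %.3f" % largest)
    print("per-unit weight norms:", np.round(unit_norms, 2))

What counts as "very large" depends on the data scaling and the architecture,
so these numbers are best compared across runs rather than against a fixed
threshold.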

This example shows how to plot some of the first layer weights of an
MLPClassifier trained on the MNIST dataset.

The input data consists of 28x28 pixel handwritten digits, leading to 784
features in the dataset. The first layer weight matrix therefore has the shape
(784, hidden_layer_sizes[0]), and each of its columns holds the weights
connecting all 784 input pixels to one hidden unit. We can thus visualize a
single column of the weight matrix as a 28x28 pixel image.
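
The reshaping step can be sketched with NumPy alone. The random matrix below
is a hypothetical stand-in for ``mlp.coefs_[0]``, and the unit and pixel
indices are arbitrary; the point is only to show how a column maps back onto
the image grid.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    n_features, n_hidden = 28 * 28, 50

    # Hypothetical stand-in for the trained first layer weights.
    first_layer_weights = rng.normal(size=(n_features, n_hidden))

    # Column j holds the weights from every input pixel to hidden unit j.
    unit_index = 3
    unit_image = first_layer_weights[:, unit_index].reshape(28, 28)

    # NumPy reshapes in row-major order by default, so image position
    # (row, col) corresponds to feature index row * 28 + col.
    row, col = 5, 7
    same_weight = (unit_image[row, col]
                   == first_layer_weights[row * 28 + col, unit_index])
    print(unit_image.shape, same_weight)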

To make the example run faster, we use very few hidden units and train only
for a very short time. Training longer would result in weights with a much
smoother spatial appearance. The example emits a convergence warning because
the optimizer stops before it has converged. This is intentional here, given
CI's time constraints.
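
The warning itself can be observed with a minimal sketch such as the following.
It again assumes the bundled ``load_digits`` dataset rather than MNIST, and it
uses a deliberately tiny ``max_iter`` so that the iteration limit is reached.

.. code-block:: python

    import warnings

    from sklearn.datasets import load_digits
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.neural_network import MLPClassifier

    # Assumption: a small offline dataset is enough to trigger the warning.
    X_demo, y_demo = load_digits(return_X_y=True)
    X_demo = X_demo / 16.0

    clf = MLPClassifier(hidden_layer_sizes=(8,), max_iter=2, solver="sgd",
                        random_state=0)

    # Record warnings instead of printing them so they can be inspected.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        clf.fit(X_demo, y_demo)

    hit_limit = any(issubclass(w.category, ConvergenceWarning)
                    for w in caught)
    print("ConvergenceWarning raised:", hit_limit)
    print("iterations run:", clf.n_iter_)

Recording the warning, as done here, is useful for diagnosis; when the early
stop is deliberate, filtering it out with ``warnings.filterwarnings`` keeps
the output clean.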

.. GENERATED FROM PYTHON SOURCE LINES 25-67








|

.. code-block:: default


    import warnings

    import matplotlib.pyplot as plt
    from sklearn.datasets import fetch_openml
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.neural_network import MLPClassifier

    print(__doc__)

    # Load data from https://www.openml.org/d/554 as NumPy arrays
    X, y = fetch_openml('mnist_784', version=1, return_X_y=True,
                        as_frame=False)
    # rescale the pixel values from [0, 255] to [0, 1]
    X = X / 255.

    # use the traditional train/test split: first 60000 samples for training
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                        solver='sgd', verbose=10, random_state=1,
                        learning_rate_init=.1)

    # this example won't converge because of CI's time constraints, so we catch
    # the warning and ignore it here
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning,
                                module="sklearn")
        mlp.fit(X_train, y_train)

    print("Training set score: %f" % mlp.score(X_train, y_train))
    print("Test set score: %f" % mlp.score(X_test, y_test))

    fig, axes = plt.subplots(4, 4)
    # use global min / max to ensure all weights are shown on the same scale
    vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
    for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
        ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
                   vmax=.5 * vmax)
        ax.set_xticks(())
        ax.set_yticks(())

    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.026 seconds)


.. _sphx_glr_download_auto_examples_neural_networks_plot_mnist_filters.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_mnist_filters.py <plot_mnist_filters.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_mnist_filters.ipynb <plot_mnist_filters.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
