
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/linear_model/plot_sparse_logistic_regression_20newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_linear_model_plot_sparse_logistic_regression_20newsgroups.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_linear_model_plot_sparse_logistic_regression_20newsgroups.py:


=====================================================
Multiclass sparse logistic regression on 20newsgroups
=====================================================

Comparison of multinomial and one-versus-rest L1-penalized logistic regression
for classifying documents from the 20newsgroups dataset. Multinomial logistic
regression yields more accurate results and is faster to train on this
larger-scale dataset.

Here we use the L1 penalty, which induces sparsity by driving the weights of
uninformative features to zero. This is useful if the goal is to extract the
strongly discriminative vocabulary of each class. If the goal is to get the
best predictive accuracy, it is better to use the non-sparsity-inducing L2
penalty instead.
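
As a rough illustration of the difference between the two penalties, the
sketch below fits both on a small synthetic problem and reports the percentage
of non-zero coefficients. The data from ``make_classification`` only stands in
for the text features, and the sample sizes, ``C=0.1`` and ``max_iter=2000``
are illustrative assumptions rather than values taken from this example.

.. code-block:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # Synthetic stand-in for the text data (assumption): 3 classes,
    # 100 features, of which only 5 carry signal.
    X, y = make_classification(n_samples=600, n_features=100,
                               n_informative=5, n_redundant=0,
                               n_classes=3, n_clusters_per_class=1,
                               random_state=0)

    density = {}
    for penalty in ("l1", "l2"):
        # SAGA supports both penalties, so only the penalty differs here.
        clf = LogisticRegression(penalty=penalty, solver="saga", C=0.1,
                                 max_iter=2000, random_state=0)
        clf.fit(X, y)
        # Percentage of coefficients that are not exactly zero
        density[penalty] = np.mean(clf.coef_ != 0) * 100
        print("%s penalty: %.1f%% non-zero coefficients"
              % (penalty, density[penalty]))

The L1 fit is expected to leave many coefficients at exactly zero, while the
L2 fit shrinks coefficients without zeroing them out.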

A more traditional (and possibly better) way to predict on a sparse subset of
input features would be to use univariate feature selection followed by a
standard L2-penalized logistic regression model.
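
A minimal sketch of that alternative is shown here, assuming a synthetic
dataset in place of 20newsgroups. The choice of ``f_classif`` as the scoring
function and ``k=10`` retained features are illustrative assumptions; for
non-negative text features such as TF-IDF, ``chi2`` is another common scoring
function.

.. code-block:: python

    from sklearn.datasets import make_classification
    from sklearn.feature_selection import SelectKBest, f_classif
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline

    # Synthetic stand-in for the text data (assumption)
    X, y = make_classification(n_samples=600, n_features=100,
                               n_informative=5, n_redundant=0,
                               n_classes=3, n_clusters_per_class=1,
                               random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=0)

    # Step 1: keep the k features with the highest univariate scores.
    # Step 2: fit a logistic regression on them (its default penalty is L2).
    pipe = make_pipeline(SelectKBest(f_classif, k=10),
                         LogisticRegression(max_iter=1000))
    pipe.fit(X_train, y_train)

    n_selected = pipe.named_steps["selectkbest"].get_support().sum()
    coef_shape = pipe.named_steps["logisticregression"].coef_.shape
    test_accuracy = pipe.score(X_test, y_test)
    print("Selected features: %i, test accuracy: %.3f"
          % (n_selected, test_accuracy))

Putting the selector inside the pipeline means the feature scores are computed
on the training split only, so the test split does not influence which
features are kept.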

.. GENERATED FROM PYTHON SOURCE LINES 21-118


.. rst-class:: sphx-glr-script-out

.. code-block:: pytb

    Traceback (most recent call last):
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/examples/linear_model/plot_sparse_logistic_regression_20newsgroups.py", line 45, in <module>
        X, y = fetch_20newsgroups_vectorized(subset='all', return_X_y=True)
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/utils/validation.py", line 72, in inner_f
        return f(**kwargs)
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/datasets/_twenty_newsgroups.py", line 419, in fetch_20newsgroups_vectorized
        data_train = fetch_20newsgroups(data_home=data_home,
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/utils/validation.py", line 72, in inner_f
        return f(**kwargs)
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/datasets/_twenty_newsgroups.py", line 258, in fetch_20newsgroups
        cache = _download_20newsgroups(target_dir=twenty_home,
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/datasets/_twenty_newsgroups.py", line 74, in _download_20newsgroups
        archive_path = _fetch_remote(ARCHIVE, dirname=target_dir)
      File "/build/scikit-learn-ZSX7SD/scikit-learn-0.23.2/.pybuild/cpython3_3.10/build/sklearn/datasets/_base.py", line 1181, in _fetch_remote
        urlretrieve(remote.url, file_path)
      File "/usr/lib/python3.10/urllib/request.py", line 241, in urlretrieve
        with contextlib.closing(urlopen(url, data)) as fp:
      File "/usr/lib/python3.10/urllib/request.py", line 216, in urlopen
        return opener.open(url, data, timeout)
      File "/usr/lib/python3.10/urllib/request.py", line 519, in open
        response = self._open(req, data)
      File "/usr/lib/python3.10/urllib/request.py", line 536, in _open
        result = self._call_chain(self.handle_open, protocol, protocol +
      File "/usr/lib/python3.10/urllib/request.py", line 496, in _call_chain
        result = func(*args)
      File "/usr/lib/python3.10/urllib/request.py", line 1391, in https_open
        return self.do_open(http.client.HTTPSConnection, req,
      File "/usr/lib/python3.10/urllib/request.py", line 1351, in do_open
        raise URLError(err)
    urllib.error.URLError: <urlopen error [Errno -2] Name or service not known>






|

.. code-block:: default

    import timeit
    import warnings

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.datasets import fetch_20newsgroups_vectorized
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.exceptions import ConvergenceWarning

    print(__doc__)
    # Author: Arthur Mensch

    warnings.filterwarnings("ignore", category=ConvergenceWarning,
                            module="sklearn")
    t0 = timeit.default_timer()

    # We use the SAGA solver, which supports the L1 penalty
    solver = 'saga'

    # Turn down for faster run time
    n_samples = 10000

    X, y = fetch_20newsgroups_vectorized(subset='all', return_X_y=True)
    X = X[:n_samples]
    y = y[:n_samples]

    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        random_state=42,
                                                        stratify=y,
                                                        test_size=0.1)
    train_samples, n_features = X_train.shape
    n_classes = np.unique(y).shape[0]

    print('Dataset 20newsgroups, train_samples=%i, n_features=%i, n_classes=%i'
          % (train_samples, n_features, n_classes))

    models = {'ovr': {'name': 'One versus Rest', 'iters': [1, 2, 4]},
              'multinomial': {'name': 'Multinomial', 'iters': [1, 3, 7]}}

    for model in models:
        # Add initial chance-level values for plotting purpose
        accuracies = [1 / n_classes]
        times = [0]
        densities = [1]

        model_params = models[model]

        # Small number of epochs for fast runtime
        for this_max_iter in model_params['iters']:
            print('[model=%s, solver=%s] Number of epochs: %s' %
                  (model_params['name'], solver, this_max_iter))
            lr = LogisticRegression(solver=solver,
                                    multi_class=model,
                                    penalty='l1',
                                    max_iter=this_max_iter,
                                    random_state=42,
                                    )
            t1 = timeit.default_timer()
            lr.fit(X_train, y_train)
            train_time = timeit.default_timer() - t1

            y_pred = lr.predict(X_test)
            accuracy = np.sum(y_pred == y_test) / y_test.shape[0]
            density = np.mean(lr.coef_ != 0, axis=1) * 100
            accuracies.append(accuracy)
            densities.append(density)
            times.append(train_time)
        models[model]['times'] = times
        models[model]['densities'] = densities
        models[model]['accuracies'] = accuracies
        print('Test accuracy for model %s: %.4f' % (model, accuracies[-1]))
        print('%% non-zero coefficients for model %s, '
              'per class:\n %s' % (model, densities[-1]))
        print('Run time (%i epochs) for model %s: '
              '%.2f s' % (model_params['iters'][-1], model, times[-1]))

    fig = plt.figure()
    ax = fig.add_subplot(111)

    for model in models:
        name = models[model]['name']
        times = models[model]['times']
        accuracies = models[model]['accuracies']
        ax.plot(times, accuracies, marker='o',
                label='Model: %s' % name)
    # Axis labels are shared by both curves, so set them once after the loop
    ax.set_xlabel('Train time (s)')
    ax.set_ylabel('Test accuracy')
    ax.legend()
    fig.suptitle('Multinomial vs One-vs-Rest Logistic L1\n'
                 'Dataset %s' % '20newsgroups')
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)
    run_time = timeit.default_timer() - t0
    print('Example run in %.3f s' % run_time)
    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.017 seconds)


.. _sphx_glr_download_auto_examples_linear_model_plot_sparse_logistic_regression_20newsgroups.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_sparse_logistic_regression_20newsgroups.py <plot_sparse_logistic_regression_20newsgroups.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_sparse_logistic_regression_20newsgroups.ipynb <plot_sparse_logistic_regression_20newsgroups.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
