
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/compose/plot_column_transformer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_compose_plot_column_transformer.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_column_transformer.py:


==================================================
Column Transformer with Heterogeneous Data Sources
==================================================

Datasets can often contain components that require different feature
extraction and processing pipelines. This scenario might occur when:

1. your dataset consists of heterogeneous data types (e.g. raster images and
   text captions),
2. your dataset is stored in a :class:`pandas.DataFrame` and different columns
   require different processing pipelines.

This example demonstrates how to use
:class:`~sklearn.compose.ColumnTransformer` on a dataset containing
different types of features. The choice of features is not particularly
helpful, but serves to illustrate the technique.
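
To make the second scenario concrete, the short sketch below (not part of the
original example; the column names and data are made up for illustration)
applies a :class:`~sklearn.compose.ColumnTransformer` to a toy
:class:`pandas.DataFrame` with one numeric and one text column.

.. code-block:: default

    import pandas as pd

    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.preprocessing import StandardScaler

    df = pd.DataFrame({'age': [23, 31, 47],
                       'bio': ['likes cats', 'rides bikes', 'reads books']})
    ct = ColumnTransformer([
        # scale the numeric column (list selector -> 2D input)
        ('num', StandardScaler(), ['age']),
        # bag-of-words on the text column (string selector -> 1D input)
        ('txt', CountVectorizer(), 'bio'),
    ])
    # one scaled numeric feature plus one count column per vocabulary word
    print(ct.fit_transform(df).shape)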

.. GENERATED FROM PYTHON SOURCE LINES 19-36

.. code-block:: default


    # Author: Matt Terry <matt.terry@gmail.com>
    #
    # License: BSD 3 clause

    import numpy as np

    from sklearn.preprocessing import FunctionTransformer
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics import classification_report
    from sklearn.pipeline import Pipeline
    from sklearn.compose import ColumnTransformer
    from sklearn.svm import LinearSVC








.. GENERATED FROM PYTHON SOURCE LINES 37-45

20 newsgroups dataset
---------------------

We will use the :ref:`20 newsgroups dataset <20newsgroups_dataset>`, which
comprises posts from newsgroups on 20 topics. The dataset is split into
train and test subsets based on whether a message was posted before or
after a specific date. To keep the running time short, we only use posts
from two categories.

.. GENERATED FROM PYTHON SOURCE LINES 45-58

.. code-block:: default


    categories = ['sci.med', 'sci.space']
    X_train, y_train = fetch_20newsgroups(random_state=1,
                                          subset='train',
                                          categories=categories,
                                          remove=('footers', 'quotes'),
                                          return_X_y=True)
    X_test, y_test = fetch_20newsgroups(random_state=1,
                                        subset='test',
                                        categories=categories,
                                        remove=('footers', 'quotes'),
                                        return_X_y=True)




.. GENERATED FROM PYTHON SOURCE LINES 59-61

Each sample is the raw text of one post, which comprises header metadata
(such as the subject line) followed by the body of the message.

.. GENERATED FROM PYTHON SOURCE LINES 61-64

.. code-block:: default


    print(X_train[0])


.. GENERATED FROM PYTHON SOURCE LINES 65-74

Creating transformers
---------------------

First, we would like a transformer that extracts the subject and
body of each post. Since this is a stateless transformation (it does not
require state information from the training data), we can define a function
that performs the data transformation and then use
:class:`~sklearn.preprocessing.FunctionTransformer` to create a scikit-learn
transformer.

.. GENERATED FROM PYTHON SOURCE LINES 74-100

.. code-block:: default



    def subject_body_extractor(posts):
        # construct object dtype array with two columns
        # first column = 'subject' and second column = 'body'
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            # temporary variable `_` stores '\n\n'
            headers, _, body = text.partition('\n\n')
            # store body text in second column
            features[i, 1] = body

            prefix = 'Subject:'
            sub = ''
            # save text after 'Subject:' in first column
            for line in headers.split('\n'):
                if line.startswith(prefix):
                    sub = line[len(prefix):]
                    break
            features[i, 0] = sub

        return features


    subject_body_transformer = FunctionTransformer(subject_body_extractor)
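
As a quick sanity check on a made-up post (a hedged sketch, not part of the
original example), the transformer can be applied directly;
:class:`~sklearn.preprocessing.FunctionTransformer` simply forwards its input
to our function:

.. code-block:: default

    toy_post = ['Subject: hello\nFrom: someone\n\nThis is the body.']
    # a (1, 2) object array: subject text in column 0, body in column 1
    print(subject_body_transformer.transform(toy_post))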


.. GENERATED FROM PYTHON SOURCE LINES 101-103

We will also create a transformer that extracts the
length of the text and the number of sentences.

.. GENERATED FROM PYTHON SOURCE LINES 103-113

.. code-block:: default



    def text_stats(posts):
        return [{'length': len(text),
                 'num_sentences': text.count('.')}
                for text in posts]


    text_stats_transformer = FunctionTransformer(text_stats)
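
On made-up input (again a hedged sketch), the stats transformer returns one
dict per post, which :class:`~sklearn.feature_extraction.DictVectorizer` will
later turn into a numeric matrix. Note that counting ``'.'`` is only a rough
proxy for the number of sentences:

.. code-block:: default

    # prints [{'length': 21, 'num_sentences': 2}]
    print(text_stats_transformer.transform(['Short. Two sentences.']))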


.. GENERATED FROM PYTHON SOURCE LINES 114-123

Classification pipeline
-----------------------

The pipeline below extracts the subject and body from each post using
``subject_body_transformer``, producing an ``(n_samples, 2)`` array. This
array is then used to compute standard bag-of-words features for the subject
and body, as well as text length and number of sentences on the body, using
``ColumnTransformer``. We combine all the features, with weights, then train
a classifier on the combined set of features.

.. GENERATED FROM PYTHON SOURCE LINES 123-154

.. code-block:: default


    pipeline = Pipeline([
        # Extract subject & body
        ('subjectbody', subject_body_transformer),
        # Use ColumnTransformer to combine the subject and body features
        ('union', ColumnTransformer(
            [
                # bag-of-words for subject (col 0)
                ('subject', TfidfVectorizer(min_df=50), 0),
                # bag-of-words with decomposition for body (col 1)
                ('body_bow', Pipeline([
                    ('tfidf', TfidfVectorizer()),
                    ('best', TruncatedSVD(n_components=50)),
                ]), 1),
                # Pipeline for pulling text stats from post's body
                ('body_stats', Pipeline([
                    ('stats', text_stats_transformer),  # returns a list of dicts
                    ('vect', DictVectorizer()),  # list of dicts -> feature matrix
                ]), 1),
            ],
            # weight components in ColumnTransformer
            transformer_weights={
                'subject': 0.8,
                'body_bow': 0.5,
                'body_stats': 1.0,
            }
        )),
        # Use an SVC classifier on the combined features
        ('svc', LinearSVC(dual=False)),
    ], verbose=True)
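
The ``transformer_weights`` mapping rescales each branch's output before the
horizontal concatenation. A minimal sketch with made-up numbers (identity
transformers on a tiny array, reusing the imports from the top of the script)
shows the effect:

.. code-block:: default

    X_toy = np.array([[1.0, 10.0],
                      [2.0, 20.0]])
    toy_union = ColumnTransformer(
        [('a', FunctionTransformer(), [0]),   # identity on column 0
         ('b', FunctionTransformer(), [1])],  # identity on column 1
        transformer_weights={'a': 1.0, 'b': 0.5},
    )
    # column 1 is halved by its weight: [[1., 5.], [2., 10.]]
    print(toy_union.fit_transform(X_toy))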


.. GENERATED FROM PYTHON SOURCE LINES 155-157

Finally, we fit our pipeline on the training data and use it to predict
topics for ``X_test``. Performance metrics of our pipeline are then printed.

.. GENERATED FROM PYTHON SOURCE LINES 157-163

.. code-block:: default


    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    print('Classification report:\n\n{}'.format(
        classification_report(y_test, y_pred))
    )
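
To see the size of the combined feature space, the fitted steps can also be
run individually. This is a hedged sketch, assuming the ``fit`` call above
succeeded:

.. code-block:: default

    # apply the fitted extraction and union steps on their own
    extracted = pipeline.named_steps['subjectbody'].transform(X_test)
    combined = pipeline.named_steps['union'].transform(extracted)
    # total = subject tf-idf terms + 50 SVD components + 2 text stats
    print('Combined feature matrix shape:', combined.shape)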




.. _sphx_glr_download_auto_examples_compose_plot_column_transformer.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_column_transformer.py <plot_column_transformer.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_column_transformer.ipynb <plot_column_transformer.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
