Quick implementation of Kedro DebugRunner

Quick implementation of a debug runner
europython
Published

November 1, 2022

core

A custom Kedro runner that keeps selected datasets in memory during a pipeline run so they can be inspected afterwards.

::: {.cell 0='h' 1='i' 2='d' 3='e'}

%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *

:::

::: {.cell 0='e' 1='x' 2='p' 3='o' 4='r' 5='t'}

from collections import Counter
from itertools import chain
from typing import Any, Dict, List, Optional

from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.runner import SequentialRunner
from kedro.runner.runner import run_node
from pluggy import PluginManager


class DebugRunner(SequentialRunner):
    def run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        dataset_names: Optional[List[str]] = None,
        hook_manager: Optional[PluginManager] = None,
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the ``Pipeline`` using the datasets provided by ``catalog``
        and save results back to the same objects.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data.
            dataset_names: Datasets to keep in memory during the run and
                include in the returned dictionary.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            ValueError: Raised when ``Pipeline`` inputs cannot be satisfied.

        Returns:
            A dictionary, keyed by dataset name, containing the pipeline's
            terminal outputs plus any datasets requested via ``dataset_names``.

        """
        if not dataset_names:
            dataset_names = []
        hook_manager = hook_manager or _NullPluginManager()
        catalog = catalog.shallow_copy()

        unsatisfied = pipeline.inputs() - set(catalog.list())
        if unsatisfied:
            raise ValueError(
                f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
            )

        free_outputs = (
            pipeline.outputs()
        )  # return all pipeline outputs, whether or not they are in the catalog
        unregistered_ds = pipeline.data_sets() - set(catalog.list())
        for ds_name in unregistered_ds:
            catalog.add(ds_name, self.create_default_data_set(ds_name))

        if self._is_async:
            self._logger.info(
                "Asynchronous mode is enabled for loading and saving data"
            )
        self._run(pipeline, catalog, dataset_names, hook_manager, session_id)

        self._logger.info("Pipeline execution completed successfully.")
        
        free_outputs = free_outputs | set(dataset_names)  # also return the requested datasets

        return {ds_name: catalog.load(ds_name) for ds_name in free_outputs}

    def _run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        dataset_names: List[str],
        hook_manager: PluginManager,
        session_id: Optional[str] = None,
    ) -> None:
        """The method implementing sequential pipeline running.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data.
            dataset_names: Datasets that must not be released during the run.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            Exception: in case of any downstream node failure.
        """
        nodes = pipeline.nodes
        done_nodes = set()

        load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))

        for exec_index, node in enumerate(nodes):
            try:
                run_node(node, catalog, hook_manager, self._is_async, session_id)
                done_nodes.add(node)
            except Exception:
                self._suggest_resume_scenario(pipeline, done_nodes, catalog)
                raise

            # Decrement load counts and release any datasets we have finished
            # with, but never release datasets that were explicitly requested
            # via ``dataset_names`` -- they must stay in memory so ``run`` can
            # load and return them at the end.
            for data_set in node.inputs:
                load_counts[data_set] -= 1
                if load_counts[data_set] < 1 and data_set not in pipeline.inputs():
                    if data_set not in dataset_names:
                        catalog.release(data_set)
            for data_set in node.outputs:
                if load_counts[data_set] < 1 and data_set not in pipeline.outputs():
                    if data_set not in dataset_names:
                        catalog.release(data_set)

            self._logger.info(
                "Completed %d out of %d tasks", exec_index + 1, len(nodes)
            )

:::
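
Compared with the stock `SequentialRunner`, `DebugRunner` makes three changes: `run` accepts an extra `dataset_names` argument; datasets named there are exempt from the usual `catalog.release` bookkeeping, so they stay in memory for the whole run; and the returned dictionary contains those datasets on top of the pipeline's terminal outputs, which are returned whether or not they are registered in the catalog.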

`DebugRunner` has to be used differently: `session.run` doesn't support additional arguments, so we take a lower-level approach and construct the `Runner`, `Pipeline`, and `DataCatalog` ourselves.
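
Outside a notebook, the same objects can be assembled via the session API instead of the `kedro.ipython` extension. A minimal sketch, assuming Kedro 0.18's `KedroSession`/`bootstrap_project` API and using the demo project path from below as a placeholder:

``` python
from pathlib import Path

from kedro.framework.project import pipelines
from kedro.framework.session import KedroSession
from kedro.framework.startup import bootstrap_project

# Placeholder: point this at your own Kedro project
project_path = Path("~/dev/kedro_gallery/kedro-debug-runner-demo").expanduser()
bootstrap_project(project_path)

with KedroSession.create(project_path=project_path) as session:
    catalog = session.load_context().catalog

runner = DebugRunner()
run = runner.run(pipelines["__default__"], catalog, dataset_names=["X_train"])
```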

# Testing Kedro Project: https://github.com/noklam/kedro_gallery/tree/master/kedro-debug-runner-demo
%load_ext kedro.ipython
%reload_kedro ~/dev/kedro_gallery/kedro-debug-runner-demo
The kedro.ipython extension is already loaded. To reload it, use:
  %reload_ext kedro.ipython
[10/06/22 14:45:20] INFO     Updated path to Kedro project:       __init__.py:54
                             /Users/Nok_Lam_Chan/dev/kedro_galler               
                             y/kedro-debug-runner-demo                          
[10/06/22 14:45:22] INFO     Kedro project                        __init__.py:77
                             kedro_debug_runner_demo                            
                    INFO     Defined global variable 'context',   __init__.py:78
                             'session', 'catalog' and 'pipelines'               
%reload_kedro ~/dev/kedro_gallery/kedro-debug-runner-demo
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_1 = runner.run(default_pipeline, catalog)
                    INFO     Updated path to Kedro project:       __init__.py:54
                             /Users/Nok_Lam_Chan/dev/kedro_galler               
                             y/kedro-debug-runner-demo                          
[10/06/22 14:45:24] INFO     Kedro project                        __init__.py:77
                             kedro_debug_runner_demo                            
                    INFO     Defined global variable 'context',   __init__.py:78
                             'session', 'catalog' and 'pipelines'               
                    INFO     Loading data from               data_catalog.py:343
                             'example_iris_data'                                
                             (CSVDataSet)...                                    
                    INFO     Loading data from 'parameters'  data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: split:                    node.py:327
                             split_data([example_iris_data,parameter            
                             s]) -> [X_train,X_test,y_train,y_test]             
                    INFO     Saving data to 'X_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'X_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: make_predictions:         node.py:327
                             make_predictions([X_train,X_test,y_trai            
                             n]) -> [y_pred]                                    
                    INFO     Saving data to 'y_pred'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_pred'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: report_accuracy:          node.py:327
                             report_accuracy([y_pred,y_test]) ->                
                             None                                               
                    INFO     Model has accuracy of 0.933 on test     nodes.py:74
                             data.                                              
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_2 = runner.run(default_pipeline, catalog, dataset_names=["example_iris_data"])
[10/06/22 14:45:27] INFO     Loading data from               data_catalog.py:343
                             'example_iris_data'                                
                             (CSVDataSet)...                                    
                    INFO     Loading data from 'parameters'  data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: split:                    node.py:327
                             split_data([example_iris_data,parameter            
                             s]) -> [X_train,X_test,y_train,y_test]             
                    INFO     Saving data to 'X_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'X_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: make_predictions:         node.py:327
                             make_predictions([X_train,X_test,y_trai            
                             n]) -> [y_pred]                                    
                    INFO     Saving data to 'y_pred'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_pred'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: report_accuracy:          node.py:327
                             report_accuracy([y_pred,y_test]) ->                
                             None                                               
                    INFO     Model has accuracy of 0.933 on test     nodes.py:74
                             data.                                              
                    INFO     Loading data from               data_catalog.py:343
                             'example_iris_data'                                
                             (CSVDataSet)...                                    
runner = DebugRunner()
default_pipeline = pipelines["__default__"]
run_3 = runner.run(default_pipeline, catalog, dataset_names=["X_train"]) # Input datasets
[10/06/22 14:46:01] INFO     Loading data from               data_catalog.py:343
                             'example_iris_data'                                
                             (CSVDataSet)...                                    
                    INFO     Loading data from 'parameters'  data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: split:                    node.py:327
                             split_data([example_iris_data,parameter            
                             s]) -> [X_train,X_test,y_train,y_test]             
                    INFO     Saving data to 'X_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'X_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_train'        data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Saving data to 'y_test'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'X_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: make_predictions:         node.py:327
                             make_predictions([X_train,X_test,y_trai            
                             n]) -> [y_pred]                                    
                    INFO     Saving data to 'y_pred'         data_catalog.py:382
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_pred'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Loading data from 'y_test'      data_catalog.py:343
                             (MemoryDataSet)...                                 
                    INFO     Running node: report_accuracy:          node.py:327
                             report_accuracy([y_pred,y_test]) ->                
                             None                                               
                    INFO     Model has accuracy of 0.933 on test     nodes.py:74
                             data.                                              
                    INFO     Loading data from 'X_train'     data_catalog.py:343
                             (MemoryDataSet)...                                 
run_1
{}
run_2
{'example_iris_data':      sepal_length  sepal_width  petal_length  petal_width    species
 0             5.1          3.5           1.4          0.2     setosa
 1             4.9          3.0           1.4          0.2     setosa
 2             4.7          3.2           1.3          0.2     setosa
 3             4.6          3.1           1.5          0.2     setosa
 4             5.0          3.6           1.4          0.2     setosa
 ..            ...          ...           ...          ...        ...
 145           6.7          3.0           5.2          2.3  virginica
 146           6.3          2.5           5.0          1.9  virginica
 147           6.5          3.0           5.2          2.0  virginica
 148           6.2          3.4           5.4          2.3  virginica
 149           5.9          3.0           5.1          1.8  virginica
 
 [150 rows x 5 columns]}
run_3
{'X_train':      sepal_length  sepal_width  petal_length  petal_width
 47            4.6          3.2           1.4          0.2
 3             4.6          3.1           1.5          0.2
 31            5.4          3.4           1.5          0.4
 25            5.0          3.0           1.6          0.2
 15            5.7          4.4           1.5          0.4
 ..            ...          ...           ...          ...
 28            5.2          3.4           1.4          0.2
 78            6.0          2.9           4.5          1.5
 146           6.3          2.5           5.0          1.9
 49            5.0          3.3           1.4          0.2
 94            5.6          2.7           4.2          1.3
 
 [120 rows x 4 columns]}
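
The three runs illustrate the behaviour. `run_1` is empty because the default pipeline has no terminal outputs: its final node, `report_accuracy`, returns `None`. `run_2` additionally returns `example_iris_data`, re-loaded from its `CSVDataSet` after the run finishes. `run_3` returns the intermediate `X_train` `MemoryDataSet`, which the stock runner would have released as soon as `make_predictions` consumed it.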

::: {.cell 0='e' 1='x' 2='p' 3='o' 4='r' 5='t'}

class GreedySequentialRunner(SequentialRunner):
    def run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        hook_manager: Optional[PluginManager] = None,
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the ``Pipeline`` using the datasets provided by ``catalog``
        and save results back to the same objects.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: The ``DataCatalog`` from which to fetch data.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            ValueError: Raised when ``Pipeline`` inputs cannot be satisfied.

        Returns:
            All of the pipeline's terminal outputs, whether or not they are
            registered in the ``DataCatalog``, keyed by dataset name.

        """

        hook_manager = hook_manager or _NullPluginManager()
        catalog = catalog.shallow_copy()

        unsatisfied = pipeline.inputs() - set(catalog.list())
        if unsatisfied:
            raise ValueError(
                f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
            )

        free_outputs = pipeline.outputs()  # return all pipeline outputs, whether or not they are in the catalog
        unregistered_ds = pipeline.data_sets() - set(catalog.list())
        for ds_name in unregistered_ds:
            catalog.add(ds_name, self.create_default_data_set(ds_name))

        if self._is_async:
            self._logger.info(
                "Asynchronous mode is enabled for loading and saving data"
            )
        self._run(pipeline, catalog, hook_manager, session_id)

        self._logger.info("Pipeline execution completed successfully.")

        return {ds_name: catalog.load(ds_name) for ds_name in free_outputs}

:::
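
Unlike the stock `SequentialRunner`, which only returns outputs that are *not* already registered in the catalog, this greedy variant loads and returns every terminal output of the pipeline. A hypothetical usage, reusing the `default_pipeline` and `catalog` from the demo above:

``` python
greedy_runner = GreedySequentialRunner()
# Returns every terminal output, even ones registered in the catalog
all_outputs = greedy_runner.run(default_pipeline, catalog)
```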

::: {.cell 0='h' 1='i' 2='d' 3='e'}

import nbdev; nbdev.nbdev_export()

:::