Achieve an (Almost) Dependency-free Kedro-Viz Pipeline

python
kedro
Published

July 8, 2024

import ast
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List

from kedro.pipeline.modular_pipeline import pipeline as ModularPipeline
from kedro.pipeline.pipeline import Node, Pipeline

logger = logging.getLogger(__name__)

The Problem

Kedro-Viz is a visualisation tool for Kedro. It creates an interactive flowchart that visualises Kedro pipelines in a web app. One issue is that Kedro-Viz requires loading a Kedro project; this creates friction because Kedro-Viz is often used for onboarding, and installing all dependencies correctly can be a big challenge.

(Note: this is now available in Kedro-Viz as kedro viz --lite.)

Solution

If we focus only on the interactive flowchart of Kedro-Viz, it is possible to get rid of the dependencies. The key is to use the Abstract Syntax Tree (AST) to parse the Kedro pipeline instead of actually importing the module.

How does this work?

Python as an Interpreted Language

Python is often considered “interpreted” rather than “compiled”. In fact, compilation still happens in Python, but it is a lot simpler compared to other languages like C++.

What usually happens is:

Parsing a text file -> AST -> Bytecode (i.e. the .pyc file) -> executed by the Python virtual machine
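You can observe the first two of these stages directly from Python: ast.parse gives you the AST, while compile plus the standard library dis module let you inspect the bytecode. A small illustrative sketch, unrelated to Kedro itself:

import ast
import dis

source = "x = 1 + 2"

tree = ast.parse(source)                   # text -> AST
code = compile(source, "<demo>", "exec")   # text -> bytecode (a code object)
print(type(tree), type(code))
dis.dis(code)                              # show the bytecode instructions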

Before going further, we need to understand what an AST is and how we can leverage Python’s ast library. An AST is a data structure that represents your code in a tree-like structure. For example, consider the snippet below:

import time

# A simple time program
start_time = time.time()
time.sleep(1)
now = time.time()
print("Time spent:", now - start_time)
snippet = """import time

# A simple time program
start_time = time.time()
time.sleep(1)
now = time.time()
print("Time spent:", now - start_time)"""
import ast

parsed = ast.parse(snippet)
print(parsed.body)
[<ast.Import object at 0x127acdea0>, <ast.Assign object at 0x127acd960>, <ast.Expr object at 0x127c2ee00>, <ast.Assign object at 0x127c2ed40>, <ast.Expr object at 0x1270ba020>]

We can use the function ast.dump to visualise the tree better.

print(ast.dump(parsed, indent=4))
Module(
    body=[
        Import(
            names=[
                alias(name='time')]),
        Assign(
            targets=[
                Name(id='start_time', ctx=Store())],
            value=Call(
                func=Attribute(
                    value=Name(id='time', ctx=Load()),
                    attr='time',
                    ctx=Load()),
                args=[],
                keywords=[])),
        Expr(
            value=Call(
                func=Attribute(
                    value=Name(id='time', ctx=Load()),
                    attr='sleep',
                    ctx=Load()),
                args=[
                    Constant(value=1)],
                keywords=[])),
        Assign(
            targets=[
                Name(id='now', ctx=Store())],
            value=Call(
                func=Attribute(
                    value=Name(id='time', ctx=Load()),
                    attr='time',
                    ctx=Load()),
                args=[],
                keywords=[])),
        Expr(
            value=Call(
                func=Name(id='print', ctx=Load()),
                args=[
                    Constant(value='Time spent:'),
                    BinOp(
                        left=Name(id='now', ctx=Load()),
                        op=Sub(),
                        right=Name(id='start_time', ctx=Load()))],
                keywords=[]))],
    type_ignores=[])

For example, this part corresponds to start_time = time.time():

        Assign(
            targets=[
                Name(id='start_time', ctx=Store())],
            value=Call(
                func=Attribute(
                    value=Name(id='time', ctx=Load()),
                    attr='time',
                    ctx=Load()),
                args=[],
                keywords=[]))
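Rather than reading the dump, you can also navigate these fields programmatically:

assign = parsed.body[1]            # the `start_time = time.time()` statement
print(assign.targets[0].id)        # start_time
print(assign.value.func.value.id)  # time (the module)
print(assign.value.func.attr)      # time (the attribute being called)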
    

One thing is missing from the parsed tree: the comment. The interpreter does not care about this information, so it is thrown away during parsing. If you care about preserving comments, you may consider a CST (Concrete Syntax Tree) or another parser that keeps this information.
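A quick way to see this is to unparse the tree back into source code (ast.unparse, available since Python 3.9): the comment is gone.

print(ast.unparse(parsed))
import time
start_time = time.time()
time.sleep(1)
now = time.time()
print('Time spent:', now - start_time)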

Problem - Creating a flowchart with missing dependencies

Consider this pipeline, which requires spark as a dependency.

# from nodes.py
import spark

def my_spark_etl_func():
    spark...
# from pipeline.py
from kedro.pipeline import pipeline, node
from .nodes import my_spark_etl_func

def create_pipeline():
    return pipeline(node(my_spark_etl_func,
                         inputs=["dataset_1","dataset_2"],
                         outputs=["output_dataset_1"]
                        )
                   )

Parsing with AST

From Kedro-Viz’s perspective, this is the problematic part because it will cause an ImportError:

from .nodes import my_spark_etl_func

Since Kedro-Viz does not execute these functions, it would be nice if we could parse out just the create_pipeline definition below and ignore the rest of the file. This is where ast becomes useful.

def create_pipeline():
    return pipeline(node(my_spark_etl_func,
                         inputs=["dataset_1","dataset_2"],
                         outputs=["output_dataset_1"]
                        )
                   )
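Crucially, ast.parse only reads the text; it never executes or imports anything, so the missing spark dependency is not a problem at parse time:

tree = ast.parse("import spark\nfrom .nodes import my_spark_etl_func")
print([type(stmt).__name__ for stmt in tree.body])  # ['Import', 'ImportFrom']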

Implement a KedroPipelineFinder to find the pipeline definition

kedro_pipeline_text = """from kedro.pipeline import pipeline, node
from .nodes import my_spark_etl_func

def create_pipeline():
    return pipeline(node(my_spark_etl_func,
                         inputs=["dataset_1","dataset_2"],
                         outputs=["output_dataset_1"]
                        )
                   )"""

The ast library provides a useful class, ast.NodeVisitor. Instead of handling the entire AST, you only need to implement the parts that you care about. It provides a visit method, and you only implement handlers for the node types you care about in your class, i.e. visit_<class_name>. You can find the full list of <class_name> in the AST Grammar.

class FunctionDefPrinter(ast.NodeVisitor):
#     def generic_visit(self, node):
#         print(type(node).__name__)
#         super().generic_visit(node)

    def visit_Import(self, node):
        print(node.names)
        print("print everytime something is imported")

#         print(dir(node))
v = FunctionDefPrinter()
parsed = ast.parse(snippet)
v.visit(parsed)
[<ast.alias object at 0x132fec1c0>]
print everytime something is imported
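The same idea works for the problematic from .nodes import ... statement: a visit_ImportFrom handler can inspect it without ever executing it. A small sketch (the class name is just for illustration):

class ImportFromPrinter(ast.NodeVisitor):
    def visit_ImportFrom(self, node):
        # node.module is the source module, node.names holds the imported aliases
        print(node.module, [alias.name for alias in node.names])

ImportFromPrinter().visit(ast.parse("from .nodes import my_spark_etl_func"))
nodes ['my_spark_etl_func']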

Step 1 - Find the function named create_pipeline

class KedroPipelineFinder(ast.NodeVisitor):
    def __init__(self):
        self.pipeline_def = []

    def generic_visit(self, node):
        # keep walking into child nodes
        super().generic_visit(node)

    def visit_FunctionDef(self, node):
        if node.name == "create_pipeline":
            print("found a create_pipeline()")
            self.pipeline_def.append(node)
#             return node
kpf = KedroPipelineFinder()
parsed = ast.parse(kedro_pipeline_text)
kpf.visit(parsed)
found a create_pipeline()
create_pipeline_def = parsed.body[2]
print(ast.dump(create_pipeline_def, indent=4))
FunctionDef(
    name='create_pipeline',
    args=arguments(
        posonlyargs=[],
        args=[],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
    body=[
        Return(
            value=Call(
                func=Name(id='pipeline', ctx=Load()),
                args=[
                    Call(
                        func=Name(id='node', ctx=Load()),
                        args=[
                            Name(id='my_spark_etl_func', ctx=Load())],
                        keywords=[
                            keyword(
                                arg='inputs',
                                value=List(
                                    elts=[
                                        Constant(value='dataset_1'),
                                        Constant(value='dataset_2')],
                                    ctx=Load())),
                            keyword(
                                arg='outputs',
                                value=List(
                                    elts=[
                                        Constant(value='output_dataset_1')],
                                    ctx=Load()))])],
                keywords=[]))],
    decorator_list=[])
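Note that the visitor has already collected this node, so kpf.pipeline_def[0] is the same object as parsed.body[2]; indexing body directly above just keeps the exploration explicit.

kpf.pipeline_def[0] is parsed.body[2]
True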

Step 2 - Build Kedro Pipeline object

class KedroPipelineBuilder(ast.NodeVisitor):
    def __init__(self, pipeline_def: list):
        self.pipeline_def = pipeline_def

    def build(self, node):
        self.generic_visit(node)
        return ...

    def visit_Call(self, node):
        """Assume the call comes from the return statement of create_pipeline:

        def create_pipeline():
            return pipeline(node(...), node(...), node(...))

        Pipeline objects that are imported from other modules won't be captured.
        """
create_pipeline_def
<ast.FunctionDef at 0x132fb3b20>
create_pipeline_def.body[0].value
<ast.Call at 0x132fb3bb0>
call = create_pipeline_def.body[0].value
print(call.args)
[<ast.Call object at 0x132fb24a0>]
call_args = call.args
i = 0
call_arg = call_args[i]
fun_name = call_arg.args[0].id
call_arg.args[0]
<ast.Name at 0x132fb2a70>
ast.unparse(create_pipeline_def)
"def create_pipeline():\n    return pipeline(node(my_spark_etl_func, inputs=['dataset_1', 'dataset_2'], outputs=['output_dataset_1']))"
print(ast.dump(call_arg, indent=3))
Call(
   func=Name(id='node', ctx=Load()),
   args=[
      Name(id='my_spark_etl_func', ctx=Load())],
   keywords=[
      keyword(
         arg='inputs',
         value=List(
            elts=[
               Constant(value='dataset_1'),
               Constant(value='dataset_2')],
            ctx=Load())),
      keyword(
         arg='outputs',
         value=List(
            elts=[
               Constant(value='output_dataset_1')],
            ctx=Load()))])
# inputs/outputs are passed to node() as keyword arguments; args only holds the function
inputs = [kw.value for kw in call_arg.keywords if kw.arg == "inputs"][0]
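ast.literal_eval can turn these literal AST nodes back into plain Python values (it accepts a node as well as a string), so the inputs and outputs become ordinary lists again:

keyword_args = {kw.arg: ast.literal_eval(kw.value) for kw in call_arg.keywords}
print(fun_name, keyword_args)
my_spark_etl_func {'inputs': ['dataset_1', 'dataset_2'], 'outputs': ['output_dataset_1']}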
print(ast.dump(create_pipeline_def, indent=2))
FunctionDef(
  name='create_pipeline',
  args=arguments(
    posonlyargs=[],
    args=[],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
  body=[
    Return(
      value=Call(
        func=Name(id='pipeline', ctx=Load()),
        args=[
          Call(
            func=Name(id='node', ctx=Load()),
            args=[
              Name(id='my_spark_etl_func', ctx=Load())],
            keywords=[
              keyword(
                arg='inputs',
                value=List(
                  elts=[
                    Constant(value='dataset_1'),
                    Constant(value='dataset_2')],
                  ctx=Load())),
              keyword(
                arg='outputs',
                value=List(
                  elts=[
                    Constant(value='output_dataset_1')],
                  ctx=Load()))])],
        keywords=[]))],
  decorator_list=[])
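Putting the pieces together, here is a rough sketch of what a KedroPipelineBuilder could look like. It only handles the simple case explored above (node(...) calls written inline with literal inputs/outputs, no namespaces and no pipelines imported from other modules), and both the dummy_function stand-in and the choice to reuse the function name as the node name are assumptions made for this sketch, not necessarily how Kedro-Viz does it:

def dummy_function(*args, **kwargs):
    """Stand-in for the real node function, which we never import."""


class KedroPipelineBuilder(ast.NodeVisitor):
    def __init__(self):
        self.nodes: List[Node] = []

    def build(self, create_pipeline_def) -> Pipeline:
        self.visit(create_pipeline_def)
        return ModularPipeline(self.nodes)

    def visit_Call(self, node):
        # Only handle node(...) calls; the surrounding pipeline(...) call is skipped
        if isinstance(node.func, ast.Name) and node.func.id == "node":
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
            self.nodes.append(
                Node(
                    func=dummy_function,
                    inputs=kwargs.get("inputs"),
                    outputs=kwargs.get("outputs"),
                    name=node.args[0].id,  # reuse the function name as the node name
                )
            )
        self.generic_visit(node)


builder = KedroPipelineBuilder()
dummy_pipeline = builder.build(create_pipeline_def)
print(dummy_pipeline)

The resulting Pipeline object carries the dataset and node names, which is all the flowchart needs; only the callable inside each node is a placeholder.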