How to Orchestrate our dbt Batch Jobs to be Resilient to Single Data Source Failures

If I understand your problem correctly, I believe you’re looking for an orchestrator. The most popular ones right now seem to be Airflow, Prefect, and Dagster. We went with Prefect for its ease of use and flexibility. As a proof of concept, I wrote some code that parses a dbt manifest into a Prefect execution graph and runs each model with dbt rpc (it assumes a dbt rpc server is already running, e.g. started with dbt rpc --host 0.0.0.0 --port 8580). We don’t use this in production yet, so use at your own risk. :slight_smile:
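In case you haven’t used dbt rpc before: the server speaks JSON-RPC 2.0 over HTTP, and a raw call looks something like this (a minimal sketch, assuming the server is listening on 0.0.0.0:8580; the DbtRpcCommand task below just wraps this pattern):

import requests
import uuid

response = requests.post(
    "http://0.0.0.0:8580/jsonrpc",
    json={
        "jsonrpc": "2.0",
        "method": "status",  # "status" answers synchronously with the server state
        "id": str(uuid.uuid1()),
        "params": {},
    },
)
print(response.json()["result"]["state"])  # e.g. "ready"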

Here’s how the execution graph generated from our manifest behaves: each model runs as its own task, and only once all of its upstream tasks have completed successfully. On the other hand, a failure in one part of the graph doesn’t keep the independent parts from executing.
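That second property is just Prefect’s default trigger behavior: a failed task blocks its own downstream tasks, but independent branches keep running. Here’s a tiny standalone sketch of that behavior (made-up task names, nothing dbt-specific):

from prefect import Flow, task

@task
def load_source_a():
    raise RuntimeError("source A is down")

@task
def load_source_b():
    return "rows from B"

@task
def transform(data):
    return f"transformed {data}"

with Flow("failure isolation demo") as demo:
    transform(load_source_a())  # fails, so this branch stops here...
    transform(load_source_b())  # ...but this independent branch still completes

And here’s the full proof of concept: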

import json
import requests
import time
import uuid

from prefect import Flow
from prefect.core.task import Task
from prefect.executors import LocalDaskExecutor
from prefect.utilities.tasks import defaults_from_attrs

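# A Prefect Task wrapping a single dbt rpc command. Synchronous methods
# ("status", "poll", "ps", "kill") return a response immediately; anything
# else returns a request token that run() polls until the work finishes.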
class DbtRpcCommand(Task):
    def __init__(
        self,
        method: str,
        host: str = "0.0.0.0",
        port: int = 8580,
        jsonrpc_version: str = "2.0",
        additional_args: dict = None,  # None rather than {} to avoid a mutable default
        **kwargs
    ):
        super().__init__(**kwargs)
        self.method = method
        self.host = host
        self.port = port
        self.jsonrpc_version = jsonrpc_version
        self.additional_args = additional_args or {}
    
    @property
    def url(self) -> str:
        return f"http://{self.host}:{self.port}/jsonrpc"
    
    @defaults_from_attrs(
        "method",
        "host",
        "port",
        "jsonrpc_version",
        "additional_args",
    )
    def _post(
        self,
        *,
        method: str,
        host: str,
        port: int,
        jsonrpc_version: str,
        additional_args: dict,
    ) -> dict:
        headers = requests.utils.default_headers()
        headers["Content-Type"] = "application/json"
        headers["Accept"] = "application/json"
        data = {
            "jsonrpc": jsonrpc_version,
            "method": method,
            "id": str(uuid.uuid1()),
            "params": additional_args,
        }
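        # e.g. {"jsonrpc": "2.0", "method": "run", "id": "<uuid>",
        #       "params": {"models": "my_model"}}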
        response = requests.post(f"http://{host}:{port}/jsonrpc",
                                 headers=headers, data=json.dumps(data))
        response.raise_for_status()
        response_json = response.json()
        if "error" in response_json:
            raise RuntimeError(
                (
                    f"dbt rpc {method} request raised "
                    f"{response_json['error']['data']['type']}: "
                    f"{response_json['error']['message']}. Full message:\n"
                    f"{response_json['error']['data']['message']}"
                )
            )
        return response_json
    
    @defaults_from_attrs(
        "method",
        "host",
        "port",
        "jsonrpc_version",
        "additional_args",
    )
    def run(
        self,
        *,
        method: str,
        host: str,
        port: int,
        jsonrpc_version: str,
        additional_args: dict,
    ) -> dict:
        response_dict = self._post(
            method=method,
            host=host,
            port=port,
            jsonrpc_version=jsonrpc_version,
            additional_args=additional_args
        )
        if method in ["status", "poll", "ps", "kill"]:
            return response_dict
        else:
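            # Asynchronous methods hand back a request token; poll every
            # couple of seconds until the task leaves the "running" state.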
            request_token = response_dict["result"]["request_token"]
            while True:
                status_dict = self._post(
                    method="poll",
                    host=host,
                    port=port,
                    jsonrpc_version=jsonrpc_version,
                    additional_args={
                        "request_token": request_token,
                        "logs": True,
                    }
                )
                current_state = status_dict["result"]["state"]
                if current_state != "running":
                    break
                time.sleep(2)
            if current_state != "success":
                logs = "\n".join(
                    [
                        f"{l['timestamp']} [{l['levelname']}] {l['message']}"
                        for l in status_dict["result"]["logs"]
                        if l['levelname'] != 'DEBUG'
                    ]
                )
                raise RuntimeError(
                    (
                        f"dbt rpc {method} request finished with state "
                        f"{current_state}. Logs:\n{logs}"
                    )
                )
            return status_dict

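# Build the flow: compile the project to discover the manifest's nodes, then
# create one run task and one test task per model and wire the dependencies.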
with Flow("dbt rpc flow") as flow:
    compile_task = DbtRpcCommand(method="compile")
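    # A direct .run() call executes the compile immediately, outside the
    # flow's graph, so the manifest nodes are available at build time.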
    nodes = [r["node"] for r in compile_task.run()["result"]["results"]]
    model_names = {
        node["unique_id"]: node["unique_id"].split(".", 1)[1]
        for node in nodes if node["resource_type"] == "model"
    }
    models = {
        model_name: {
            "run": DbtRpcCommand(
                method="run",
                additional_args={
                    "models": model_name
                },
                name=f"Run {model_name.split('.')[1]}",
            ),
            "test": DbtRpcCommand(
                method="test",
                additional_args={
                    "models": model_name
                },
                name=f"Test {model_name.split('.')[1]}",
            ),
        }
        for model_name in model_names.values()
    }
    for node in nodes:
        if node["resource_type"] == "model":
            model_name = model_names[node['unique_id']]
            # Get the dependencies
            dependencies = node["depends_on"]["nodes"]
            upstream_tasks = [
                models[d.split(".", 1)[1]]["test"] for d in dependencies
                if d in model_names
            ]
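            # A model runs only after every upstream model's test has
            # passed, and its own test runs right after its run.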
            flow.set_dependencies(
                task=models[model_name]["run"],
                upstream_tasks=upstream_tasks,
            )
            flow.set_dependencies(
                task=models[model_name]["test"],
                upstream_tasks=[models[model_name]["run"]],
            )
        elif node["resource_type"] == "test":
            test_nodes = node["depends_on"]["nodes"]
            if len(test_nodes) > 1:
                # Each node's test depends on all the nodes' runs
                test_models = [
                    models[model_names[test_node]] for test_node in test_nodes
                ]
                for test_model in test_models:
                    flow.set_dependencies(
                        task = test_model["test"],
                        upstream_tasks=[model["run"] for model in test_models],
                    )


flow.executor = LocalDaskExecutor(num_workers=4)

flow.register(project_name='test')
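If you just want to test the flow locally before registering it with a Prefect backend, you can swap the register call for flow.run() (a suggestion on top of the original snippet):

state = flow.run()  # executes the whole graph locally and returns the final state
print(state)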

Hope this helps!
