Using Python to update model_schema.yml files and create staging files (BigQuery Specific)

Howdy folks, this is a long post–I didn’t want to make multiple posts. Also, happy for input on the Python, in case you think it could be improved :slight_smile:

1. Updating model_schema.yml Files Using Python

Background

Analysts are required to stay on top of updating model_schema.yml files. I wrote a Python script that updates the model_schema.yml for the analyst after they have modified and run a query. This is very helpful, especially for documenting queries that are dynamically generated using Jinja loops or similar, where you can end up with a large number of relatively dynamic columns. I used ruamel.yaml–I found the more common PyYAML insufficient for my purposes, as ruamel's round-trip formatting and export are much superior (a minimal sketch of what I mean follows the install steps below).

This code uses ruamel.yaml v0.17.21, found here: ruamel.yaml · PyPI

  1. pip install -U pip setuptools wheel
  2. pip install ruamel.yaml==0.17.21
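
Here's a minimal sketch (separate from the main script) of the round-trip behaviour I mean: ruamel loads a YAML document into normal-feeling Python objects and dumps it back with comments, quoting, and indentation intact.

# Minimal round-trip sketch, illustrative only.
import io

from ruamel.yaml import YAML

yaml = YAML()                # round-trip mode is the default
yaml.preserve_quotes = True  # keep the double quotes around descriptions
yaml.indent(sequence=4, offset=2)

doc = """\
version: 2

models:
  - name: my_model  # this comment survives the round trip
    columns:
      - name: my_column
        description: "Hello world!"
"""

data = yaml.load(doc)
data["models"][0]["name"] = "my_renamed_model"  # edit like a normal dict/list

buf = io.StringIO()
yaml.dump(data, buf)
print(buf.getvalue())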

Some Important Notes

A few notes about the limitations of this code:

  1. This code requires the analyst to have run the query recently–it works off the columns it finds in the analyst’s development dataset in the database.
  2. This code is specific to BigQuery.
  3. This code assumes that all model_schema.yml files follow the same naming structure, e.g. models_schema__invoices.yml, models_schema__customers.yml, …
  4. This code preserves any existing description or other parameters present in YAML entries.
  5. It assumes that all description fields are double-quoted, e.g. description: "Hello world!" (a short sketch of how that quoting is preserved on write follows this list).
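
On point 5: when new description values are written, our production version wraps them in ruamel's DoubleQuotedScalarString so the output stays double-quoted (the demo script below imports it but doesn't use it). A tiny sketch of the pattern, with a made-up column entry:

# Sketch of keeping descriptions double-quoted when writing new values.
import sys

from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import DoubleQuotedScalarString

yaml = YAML()
yaml.preserve_quotes = True

# Hypothetical column entry; wrapping the value forces ruamel to emit
#   description: "No description set"
column_entry = {
    "name": "customer_id",
    "description": DoubleQuotedScalarString("No description set"),
}

yaml.dump({"columns": [column_entry]}, sys.stdout)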

How the code works

  1. Prompt the user for the name of the query they modified
  2. Search across the model_schema.yml files in the project
  3. Identify where the entry exists (note: there must already be an entry for the query, with at least one column defined)
  4. Query the analyst’s development output to identify the columns that will be seen in production
  5. Decide whether each existing column entry should be kept, dropped, or added (the comparison is sketched with sets below)
  6. Write the updated query and column fields to the appropriate YAML file
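
For illustration only, the keep/drop/add decision in step 5 boils down to comparing the set of column names from the dev run against the set already in the YAML–roughly this, with made-up column names (the script below does the same thing with explicit loops):

# Illustrative only: the keep/drop/add decision from step 5 expressed with sets.
# dev_columns would come from the BigQuery INFORMATION_SCHEMA query; yaml_columns
# from the existing model_schema.yml entry. Column names here are hypothetical.
dev_columns = {"invoice_id", "customer_id", "amount"}
yaml_columns = {"invoice_id", "customer_id", "old_field"}

to_keep = dev_columns & yaml_columns   # in both -> keep the existing yaml entry
to_add = dev_columns - yaml_columns    # new in dev -> add with default fields
to_drop = yaml_columns - dev_columns   # gone from dev -> drop from the yaml

print(sorted(to_keep), sorted(to_add), sorted(to_drop))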

As a side note, updating YAML in place was a massive pain, and there didn’t seem to be much written about it online. This is my hacky solution–happy to hear if folks have a better way of approaching it.

At the moment, this code must be manually triggered by the analyst. In a nice future state, this would be triggered by git-commit actions, and it would automagically search for modified queries, update the relevant yaml, and prompt the user to update relevant documentation parameters in the yaml.
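
As a very rough sketch of what that future state might look like (not implemented here, and assuming the dbt project at ~/dbt is its own git repo with models under models/), the list of modified models could be pulled from git instead of prompting the analyst:

# Rough sketch only: find modified model .sql files from git instead of prompting.
# Assumes ~/dbt is the git repo root and models live under models/.
import pathlib
import subprocess

dbt_repo_path = pathlib.Path.home() / "dbt"

# File paths changed since the last commit, relative to the repo root
result = subprocess.run(
    ["git", "diff", "--name-only", "HEAD"],
    cwd=dbt_repo_path,
    capture_output=True,
    text=True,
    check=True,
)

modified_models = [
    pathlib.Path(line).stem
    for line in result.stdout.splitlines()
    if line.startswith("models/") and line.endswith(".sql")
]
print(modified_models)  # each name could then be fed into the script below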

Minimum amount of yml necessary for the code to successfully execute

This code works for both net-new and existing yml entries. However, the analyst needs to create the net-new entry first–the intent is for the script to focus on dynamically updating existing entries, so the analyst doesn’t have to go digging through the folder structure as often.

version: 2

models:
  - name: query_you_modified
    columns:
      - name: need_at_least_1_field
        description: "If net-new, just put in a fake field name, as it will be dropped and replaced with the proper fields."

The Python code:

Any variables specific to your use case are flagged as ### UPDATE ME!!! in the code

import collections
import os
import pandas
import pathlib
import sys

from google.api_core.exceptions import BadRequest
from google.cloud import bigquery
from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import DoubleQuotedScalarString # Used in our production environment, but not used in this demo code

########## USER DEFINED CUSTOM CONFIGS FOR USE IN SCRIPT ##########
### NOTES ###
# This script requires that dbt be installed in your home directory, e.g. ~/dbt, and that your profiles.yml lives in ~/.dbt/profiles.yml
# This script requires that your model_schema.yml files are all quoted for descriptions
# This script works for BigQuery connections--not sure how it works with other databases
# This script requires that you already have the query defined in the model_schema.yml, with at least 1 column defined
### END NOTES ###

# Define your dataset name--current query works with Google BigQuery
client = bigquery.Client(location="US")
project_name = "your-project-name"  ### UPDATE ME!!!

# Requires that all of your model_schema.yml files start with the same naming prefix
yaml_file_prefix = "models_descr"  ### UPDATE ME!!!
# Default dbt column fields added for every new column entry; the OrderedDict keeps them in a consistent order
additional_default_fields = [("description", "No description set"), ("allowed_field", True), ("other_field", "yay")]  ### UPDATE ME!!! (field, default value) pairs expected for every column entry
additional_default_fields = collections.OrderedDict(additional_default_fields)

########## END USER DEFINED CONFIGS ##########

# RUAMEL.YAML CONFIGS
# ruamel.yaml v.0.17.21
yaml = YAML()
# control indentation of hierarchy in yaml
yaml.indent(sequence=4, offset=2)
# Preserve quotes in outputs
yaml.preserve_quotes = True
# Keep descriptions from line wrapping
yaml.width = 4096

# Pull table and column data from current dev output
def table_pull(project_name, dbt_dev_folder, query_name):
    query = f"""
        SELECT
        cols.column_name
        , cols.description AS column_description
        , tables.option_value AS table_description

        FROM `{project_name}.{dbt_dev_folder}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` AS cols

        LEFT JOIN `{project_name}.{dbt_dev_folder}.INFORMATION_SCHEMA.TABLE_OPTIONS` AS tables
        ON cols.table_name = tables.table_name
        AND tables.option_name = "description"

        WHERE cols.table_name = "{query_name}"

        ORDER BY 1
        """

    print(f"Querying {query_name} \n")

    query_job = client.query(query)

    try:
        query_job.result()
    except BadRequest:
        for e in query_job.errors:
            print(f"QUERY ERROR: {e['message']} \n")
        sys.exit(f"Beep boop query failed, exiting script... \n")

    columns = query_job.to_dataframe()
    print(f"The query used was: \n\n{query} \n")
    print(f"Tables loaded into dataframe \n")

    if columns.empty:
        sys.exit("Your query returned no result, exiting script \n")

    return columns

def main():
    # Define paths for use across script
    current_path = pathlib.Path().resolve()
    home_path = pathlib.Path.home()

    dbt_path = f"{home_path}/dbt/"
    dbt_profile_path = f"{home_path}/.dbt/profiles.yml"
    dbt_yml_locations_path = f"{home_path}/dbt/models/"

    actual_path = os.path.relpath(home_path, current_path)

    contyn = False

    # User must have run the modified query
    print("You need to have run the query you are trying to generate .yml for before running this script \n")

    has_run_query = input("Have you run the query in question? Enter y or n:  \n")

    if has_run_query == "n":
        sys.exit("You need to run your query before using this script. We will now exit. \n")

    # Prompt user to enter file name
    while contyn == False:
        modified_file_name = input("Enter the name of the file that you have modified, and want to generate .yml entries for: \n")
        yesno = input(f"You entered \n{modified_file_name}\nDo you want to use this? Enter y or n: \n")
        contyn = True if yesno == "y" else False

    # Grab the dev dataset for the analyst (assumes the profile in profiles.yml is named "default")
    with open(dbt_profile_path, "r") as profiles:
        data = yaml.load(profiles)
        dbt_dev_dataset = data.get("default").get("outputs").get("dev").get("dataset")

    # Query output of modified dbt query
    dbt_query_output = table_pull(project_name, dbt_dev_dataset, modified_file_name)

    # Identify all .yml files in dbt/models
    all_yml_files = []
    for (dirpath, dirnames, filenames) in os.walk(dbt_yml_locations_path):
        for file in filenames:
            if file.startswith(yaml_file_prefix) and file.endswith(".yml") and ".ipynb_checkpoints" not in dirpath:
                all_yml_files.append(os.path.join(dirpath, file))

    # Check if the query already has an entry in a model description .yml file
    query_yml_file_location = None

    for index, ymlfile in enumerate(all_yml_files):
        with open(ymlfile, "r") as models:
            read_data = yaml.load(models)

            # We only want to search across keys inside of the "models" key entry in the description yaml files
            # We need to search 1 level down (models -> "name" keys of queries)
            # dbt model_schema.yml is a dictionary(version/models) then a list (each query is a list) then a dictionary (name/description/columns) then a list for columns, then a dictionary for each column entry

            # Access version/models in .yml -- dictionary
            for models_key, models_values in read_data.items():
                # Only want to look at "models" dictionary
                if models_key == "models":
                    # Need to access list of queries in .yml
                    for queries_index, queries_values in enumerate(models_values):
                        # Need to access top-level entries in each query (name/description/columns)
                        for query_key, query_values in queries_values.items():
                            if query_key == "name":
                                if query_values == modified_file_name:
                                    query_yml_file_location = ymlfile

    # Exit script if query not located in model_schema.yml files
    if query_yml_file_location is None:
        sys.exit(f"We couldn't find an entry for {modified_file_name} , the script will now exit :( \n")

    # Create list of columns
    dev_column_list = dbt_query_output["column_name"].tolist()
    yaml_index = None

    # This code only works if you have created a yaml entry, with a dummy column.
    # Need to identify if existing column+values needs to be removed
    # Need to identify if column needs to be added
    # Write new columns, and assume PII flagging based off of column names
    with open(query_yml_file_location, "r") as descriptions:
        yaml_entries = yaml.load(descriptions)

        # Need to identify index position of query yaml entry we care about
        for models_key, models_values in yaml_entries.items():
            if models_key == "models":
                # Need to access list of queries in .yml
                for queries_index, queries_values in enumerate(models_values):
                    # Need to find index of list so that we can access other keys inside list
                    for query_key, query_values in queries_values.items():
                        if query_key == "name" and query_values == modified_file_name:
                            yaml_index = queries_index

        ####### THIS ONLY WORKS FOR EXPLICIT DBT MODEL SCHEMA YAML ENTRIES!!!!

        # Pull all columns from current yaml file
        prod_yaml_columns = []
        for index, value in enumerate(yaml_entries["models"][yaml_index]["columns"]):
            prod_yaml_columns.append(yaml_entries["models"][yaml_index]["columns"][index]["name"])

        # Create list of all unique values from dev output + current yaml
        all_columns = []
        for i in dev_column_list + prod_yaml_columns:
            if i not in all_columns:
                all_columns.append(i)

        # Identify which columns need to be added or removed
        add_keep_remove_entries = {}
        for all_index, all_value in enumerate(all_columns):
            dev_appears = 0
            prod_appears = 0

            # Check if prod or dev appears in all entries, which lets us know whether to keep, add, or remove
            # We create a dictionary that will allow us to update the actual yaml entries
            # Performance here is probably pretty garbage, but alas, that is a problem for another day
            for dev_index, dev_value in enumerate(dev_column_list):
                if all_value == dev_value:
                    dev_appears += 1
            for prod_index, prod_value in enumerate(prod_yaml_columns):
                if all_value == prod_value:
                    prod_appears += 1

            # Add to dictionary, and decide if we keep, add, or remove an entry
            if dev_appears == 1 and prod_appears == 1:
                add_keep_remove_entries[all_value] = "keep"
            elif dev_appears == 1 and prod_appears == 0:
                add_keep_remove_entries[all_value] = "add"
            else:
                add_keep_remove_entries[all_value] = "remove"

        # Now we compare our keep/add/remove dictionary vs the dictionary of columns, and update the yaml

        prod_yaml_entries = yaml_entries["models"][yaml_index]["columns"]
        new_query_yaml = []

        for key, value in add_keep_remove_entries.items():
            # Going to check if the prod yaml entries appear in the total list. Then, we can decide if we remove the item in prod_yaml_entries, or add the new thing
            # Find index of yaml entry
            location = None
            for index, v in enumerate(prod_yaml_entries):
                for yk, yv in v.items():
                    if yk == "name" and key == yv:
                        location = index

            mini_dict = {}
            if value == "keep":
                mini_dict = prod_yaml_entries[location]
                new_query_yaml.append(mini_dict)
            elif value == "add":
                mini_dict["name"] = key
                if len(additional_default_fields) > 0:
                    for key, value in additional_default_fields.items():
                        mini_dict[key] = value
                new_query_yaml.append(mini_dict)

        # Insert column info back into all yaml dictionary
        yaml_entries["models"][yaml_index]["columns"] = new_query_yaml

    # Write output to yaml file
    with open(query_yml_file_location, "w") as descriptions:
        yaml.dump(yaml_entries, descriptions)

    print(f"Yay, we updated your yaml!! You can find the yaml file at \n {query_yml_file_location}\n")

if __name__ == "__main__":
    main()

2. Importing Staging Files Using Python

Background

Creating staging files can be a pain in the neck when you have a lot to import, especially when you’re creating a new project, connecting to a new dataset, or exposing a bunch of new tables from an existing dataset. I know there are some Jinja-based utils to streamline this, but they still mean you have to manually create the files. We also run into the situation where a number of staging files are already imported, and a new product feature release means you need to add some new ones without overwriting the existing files.

How it works

Note that this script is specific to our idiosyncratic DBT setup–we have our staging files set up in a particular way to deal with some data concerns regarding PII/PHI, due to our work with healthcare-related data.
Our specific design choices:

  1. Each underlying table is initially declared as a staging file
  2. Each field is explicitly declared
  3. We only perform minor name and datatype changes in staging files
  4. Each staging file follows a particular pattern: stg_datasourcefolder__table_name.sql, e.g. stg_salesforce__user.sql or stg_internal_api_replica__user.sql. Feel free to modify the code to match your naming strategy.

Python Code

My design patterns/baked-in assumptions for this code:

  1. The user should be able to decide whether they’re overwriting existing staging files or only importing net-new ones
  2. Columns named after BigQuery reserved keywords should be re-aliased so they don’t cause problems downstream (a short sketch of the rule follows this list)
  3. Any id fields are PKs, and should be renamed to table_name_id
  4. We only want to allow specific datatypes
  5. Staging folders follow a recognizable, consistent structure
  6. Source YAML files will have to be manually edited–dynamically editing them was out of scope for this project, but future iterations will borrow concepts from the first part of this post for that.
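
To make points 2 and 3 concrete, here's a small illustrative helper (not part of the script itself) that mirrors the aliasing rule used inside the main loop below:

# Illustrative only: mirrors the column-aliasing rule from points 2 and 3.
# Columns named after BigQuery reserved keywords, or simply "id", are
# re-aliased as <table_name>_<column_name>; DATETIMEs are cast to TIMESTAMP in UTC.
def staging_column_sql(column_name, data_type, condensed_table_name, reserved_keywords):
    if column_name in reserved_keywords or column_name == "id":
        if data_type == "DATETIME":
            return f'TIMESTAMP({column_name}, "UTC") AS {condensed_table_name}_{column_name}'
        return f"{column_name} AS {condensed_table_name}_{column_name}"
    if data_type == "DATETIME":
        return f'TIMESTAMP({column_name}, "UTC") AS {column_name}'
    return column_name

# e.g. staging_column_sql("id", "INT64", "user", set()) -> "id AS user_id"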

This code assumes that your folder structure follows something similar to this:

  • models
    • staging
      • fivetran
        • fivetransource_1
          • stg_fivetransource_1__table1.sql
          • stg_fivetransource_1__table2.sql
        • fivetransource_2
      • internal_api
        • internal_apisource_1
          • stg_internal_apisource_1__table1.sql
          • stg_internal_apisource_1__table2.sql
        • internal_apisource_2

I know that this is perhaps not 100% the favored DBT structure, but this is what we have ended up with, due to some internal design and resource constraints.

I also had this code with user prompts, but that got a little complicated, so I have trimmed it down for this exercise–you have to manually update the variables.

Any variables specific to your use case are flagged as ### UPDATE ME!!! in the code

import os
import pandas
import pathlib
import sys

from google.cloud import bigquery
from google.api_core.exceptions import BadRequest
from google.api_core.exceptions import NotFound

########## USER DEFINED CUSTOM CONFIGS FOR USE IN SCRIPT ##########
### NOTES ###
# This script requires that dbt be installed in your home directory, e.g. ~/dbt, and that your profiles.yml lives in ~/.dbt/profiles.yml
# This script requires that your staging folders follow consistent structure, as well as your queries having consistent naming structure
# This script works for BigQuery connections--not sure how it works with other databases
### END NOTES ###

# Define your dataset name--current query works with Google BigQuery
client = bigquery.Client(location="US")
project_name = "your-project-name"  ### UPDATE ME!!!

dataset = "dataset_name" ### UPDATE ME!!!-- Name of actual BigQuery Dataset. Required.
table_prepend = None ### UPDATE ME!!! -- Naming pattern of tables in your dataset--do they all start with a common text pattern? If you don't have it, replace with None.
table_wildcard = None  ### UPDATE ME!!!-- Do the tables you want to import have a wildcard common text pattern, ex. you want all tables that match '%invoice_%'? If you don't have it, replace with None.
datasource = "internal_api"  ### UPDATE ME!!! -- Name of parent folder that you want staging files to appear in in DBT folder structure. Corresponds with say, Fivetran, your internal company data, etc. Required.
folder = "internal_apisource_1" ### UPDATE ME!!! -- Name of folder you want staging files to appear in in DBT folder structure. Corresponds with 'dataset' in BigQuery. Required.
# Flag if we only want to add non-existing staging files, or add new + overwrite existing staging files
import_type = "new"  ### UPDATE ME!!! -- Otherwise = "overwrite" Required.

########## END USER DEFINED CONFIGS ##########

# Some columns have BQ reserved keywords. We need to handle these when writing the queries
bq_reserved_keywords = ["all", "and", "any", "array", "as", "asc", "assert_rows_modified", "at", "between", "by",
                        "case", "cast", "collate", "contains", "create", "cross", "cube", "current", "default",
                        "define", "desc", "distinct", "else", "end", "enum", "escape", "except", "exclude", "exists",
                        "extract", "false", "fetch", "following", "for", "from", "full", "group", "grouping", "groups",
                        "hash", "having", "if", "ignore", "in", "inner", "intersect", "interval", "into", "is", "join",
                        "lateral", "left", "like", "limit", "lookup", "merge", "natural", "new", "no", "not", "null",
                        "nulls", "of", "on", "or", "order", "outer", "over", "partition", "preceding", "proto",
                        "qualify", "range", "recursive", "respect", "right", "rollup", "rows", "select", "set", "some",
                        "struct", "tablesample", "then", "to", "treat", "true", "unbounded", "union", "unnest", "using",
                        "when", "where", "window", "with", "within"]

# Standardize datatypes to only certain kinds, for consistency across the project
# (declared here for reference--this trimmed-down version of the script doesn't enforce it)
allowed_datatypes = ["INT64", "STRING", "TIMESTAMP", "DATE", "FLOAT"]

# Define paths for use across script
current_path = pathlib.Path().resolve()
home_path = pathlib.Path.home()
actual_path = os.path.relpath(home_path, current_path)

dbt_path = f"{home_path}/dbt/"
dbt_stg_locations_path = f"{home_path}/dbt/models/"

# Continue constant for user input loops (leftover from the prompt-based version of this script; unused below)
contyn = False

# Function that returns list of tables that we will create staging files for
def query_data_pull(project_name, dataset, table_prepend, table_wildcard):
    if dataset is None:
        print("You did not enter a dataset, exiting script...\n")
        sys.exit()

    # Build the optional table-name filters for the INFORMATION_SCHEMA query
    if table_prepend is None:
        table_cleanup = ""
        table_prepend_filter = "AND 1 = 1"
    else:
        table_cleanup = f"{table_prepend}_"
        table_prepend_filter = f"AND table_name LIKE '{table_prepend}%'"

    if table_wildcard is None:
        table_wildcard_filter = "AND 1 = 1"
    else:
        table_wildcard_filter = f"AND table_name LIKE '%{table_wildcard}%'"

    query = f"""
        SELECT
        table_name
        , REPLACE(table_name, '{table_cleanup}', '') AS condensed_table_name
        , column_name
        , data_type

        FROM `{project_name}.{dataset}.INFORMATION_SCHEMA.COLUMNS`


        WHERE column_name NOT IN ('_airbyte_ab_id', '_airbyte_normalized_at')
        AND column_name NOT LIKE '_airbyte%_hashid'
        AND table_name NOT LIKE '%__dbt_tmp'
        AND table_name NOT LIKE '_airbyte_%'
        {table_prepend_filter}
        {table_wildcard_filter}

        ORDER BY 1, LENGTH(column_name)
    """
    query_job = client.query(query)

    # Run the query, exit if we have a non-functional query
    try:
        query_job.result()
    except BadRequest:
        for e in query_job.errors:
            print(f"QUERY ERROR: {e['message']}\n")
        print(f"The query used was:\n\n{query}\n")
        sys.exit(f"Beep boop query failed, exiting script...\n")

    print(f"Query complete")

    # Load query into dataframe
    tables = query_job.to_dataframe()
    print(f"The query used was:\n\n{query}\n")
    print(f"Tables loaded into dataframe\n")

    if tables.empty:
        sys.exit("Your query returned no result, exiting script\n")

    return tables

def main():

    # Get list of existing queries in folder
    existing_staging_queries_dict = {}
    existing_staging_tables_list = []
    # Default to the expected folder path, in case no staging files exist yet for this source
    existing_staging_folder_path = f"{dbt_stg_locations_path}staging/{datasource}/{folder}"
    for (dirpath, dirnames, filenames) in os.walk(dbt_stg_locations_path):
        for file in filenames:
            if file.startswith(f"stg_{folder}") and file.endswith(".sql") and ".ipynb_checkpoints" not in dirpath:
                # Directory path to get us to the staging folder (the same for every matching file)
                existing_staging_folder_path = dirpath
                # Dictionary of staging file names and their path
                existing_staging_queries_dict[file.replace(".sql", "")] = os.path.join(dirpath, file)
                # List of current tables used in staging queries
                existing_staging_tables_list.append(file.replace(".sql", "").replace(f"stg_{folder}__", ""))

    # Create source yml output
    source_yml_path = f"{existing_staging_folder_path}/src__{folder}_temp_source_yaml.yml"
    actual_path_source_yml = os.path.relpath(source_yml_path, current_path)

    # Get query output
    table_column_df = query_data_pull(project_name, dataset, table_prepend, table_wildcard)

    # Need list of tables that we might want to create tables for
    unique_tables = table_column_df["condensed_table_name"].sort_values().unique().tolist()

    # Create list of tables to iterate over--decide if we're going to do only net-new tables, or existing+net-new
    if import_type == "new":
        create_table_list = list(set(unique_tables).difference(existing_staging_tables_list))
    else:
        create_table_list = unique_tables

    table_count = len(create_table_list)

    if len(create_table_list) == 0:
        sys.exit(f"No tables to create, exiting script\n")

    create_table_list = sorted(create_table_list)

    # Write source yaml to a temporary file--we don't feel like navigating yaml at the moment
    # To be replaced with yaml writing code from model schema yaml code for dynamic updates
    with open(actual_path_source_yml, "w+") as source_f:
        output = f"version: 2\n\nsources:\n  - name: {dataset}\n    tables:"
        source_f.write(output)
        for index, table in enumerate(create_table_list):
            if table_prepend is None:
                output = f"\n      - name: {table}"
            else:
                output = f"\n      - name: {table_prepend}_{table}"
            source_f.write(output)

    print(f"Created yaml source declaration file--go to this file and copy-paste the source declarations to your official source.yml file.\n You can find this at: \n {actual_path_source_yml}\n")

    # Decide if we're over-writing existing staging files, or only adding in new non-existent ones
    # For table names in tables, create a SQL query and print to output file that matches table name
    # Depending on the prior logic, we will either replace existing+add new queries, or just add new
    for index, table in enumerate(create_table_list):

        # Pull in column name rows for the table we are creating a file of
        table_index = index + 1
        truncated_dataframe = table_column_df.loc[table_column_df["condensed_table_name"] == table].reset_index(drop=True)

        staging_file_path = f"{dbt_stg_locations_path}/staging/{datasource}/{folder}/stg_{folder}__{table}.sql"
        actual_path_stg_file = os.path.relpath(staging_file_path, current_path)

        with open(actual_path_stg_file, "w+") as f:
            # Start SQL SELECT command
            output = "SELECT\n"
            f.write(output)

            # Generating field calls in main part of SELECT statement
            for row in truncated_dataframe.itertuples(index=True):

                # Identify if we need to have a comma in front of the field call
                commas = ", "
                if row[0] == 0:
                    commas = ""

                # Pulling out column values so things are a little more human-readable
                full_table_name = row[1]
                condensed_table_name = row[2]
                column_name = row[3]
                data_type = row[4]

                # We get fields back that match BQ reserved keywords, or that are just called 'id'--need to handle these
                if column_name in bq_reserved_keywords or column_name == "id":
                    if data_type == "DATETIME":
                        output = f"{commas}TIMESTAMP({column_name}, \"UTC\") AS {condensed_table_name}_{column_name}"
                    else:
                        output = f"{commas}{column_name} AS {condensed_table_name}_{column_name}"
                elif data_type == "DATETIME":
                    output = f"{commas}TIMESTAMP({column_name}, \"UTC\") AS {column_name}"
                else:
                    output = f"{commas}{column_name}"

                f.write(f"{output}\n")

            # Once we write all field calls, now we need to write the table call. Use a bunch of curly braces cause things are silly
            output = f"\nFROM {{{{ source('{dataset}', '{full_table_name}') }}}}"

            # Write the query
            f.write(output)

        percent_done = table_index / table_count * 100
        print(f"Processed {percent_done:.1f}% of tables")

    # Exit script after finishing
    sys.exit(f"\nFinished updating staging files! You can find the new source YAML entries at\n {actual_path_source_yml}")

if __name__ == "__main__":
    main()