{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "a6b265df", "metadata": {}, "source": [ "# Type checking\n", "## Runtime\n", "\n", "Every time a `DataSet` is initialized, it checks whether the data matches the schema and raises a `TypeError` if it does not." ] }, { "cell_type": "code", "execution_count": 1, "id": "24a21def", "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "\n", "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n", "spark.sparkContext.setLogLevel(\"ERROR\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "fd37a1d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---+----+---+\n", "| id|name|age|\n", "+---+----+---+\n", "|  1|John| 20|\n", "|  2|Jane| 30|\n", "|  3|Jack| 40|\n", "+---+----+---+\n", "\n" ] } ], "source": [ "import pandas as pd\n", "from typedspark import Column, DataSet, Schema\n", "from pyspark.sql.types import LongType, StringType\n", "\n", "\n", "class Person(Schema):\n", "    id: Column[LongType]\n", "    name: Column[StringType]\n", "    age: Column[LongType]\n", "\n", "\n", "df = spark.createDataFrame(\n", "    pd.DataFrame(\n", "        dict(\n", "            id=[1, 2, 3],\n", "            name=[\"John\", \"Jane\", \"Jack\"],\n", "            age=[20, 30, 40],\n", "        )\n", "    )\n", ")\n", "# no errors raised\n", "df = DataSet[Person](df)\n", "df.show()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c27a99fe", "metadata": {}, "source": [ "By convention, columns whose names start with `__` are ignored during this check."
] }, { "cell_type": "code", "execution_count": 3, "id": "b99e82f8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---+----+---+--------------+\n", "| id|name|age|__extra_column|\n", "+---+----+---+--------------+\n", "|  1|John| 20|             1|\n", "|  2|Jane| 30|             2|\n", "|  3|Jack| 40|             3|\n", "+---+----+---+--------------+\n", "\n" ] } ], "source": [ "df = spark.createDataFrame(\n", "    pd.DataFrame(\n", "        dict(\n", "            id=[1, 2, 3],\n", "            name=[\"John\", \"Jane\", \"Jack\"],\n", "            age=[20, 30, 40],\n", "            __extra_column=[1, 2, 3],\n", "        )\n", "    )\n", ")\n", "# no errors raised because __extra_column is ignored during the check\n", "df = DataSet[Person](df)\n", "df.show()" ] }, { "cell_type": "code", "execution_count": 4, "id": "baa96cfa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Schema Person contains the following columns not present in data: {'age'}\n" ] } ], "source": [ "df = spark.createDataFrame(\n", "    pd.DataFrame(\n", "        dict(\n", "            id=[1, 2, 3],\n", "            name=[\"John\", \"Jane\", \"Jack\"],\n", "        )\n", "    )\n", ")\n", "# the age column required by Person is missing, so a TypeError is raised\n", "try:\n", "    DataSet[Person](df)\n", "except TypeError as e:\n", "    print(e)" ] }, { "cell_type": "code", "execution_count": 5, "id": "71bfc3b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data contains the following columns not present in schema Person: {'gender'}.\n", "\n", "If you believe these columns should be part of the schema, consider adding the following lines to it.\n", "\n", "class Person(Schema):\n", "    gender: Column[StringType]\n", "\n" ] } ], "source": [ "df = spark.createDataFrame(\n", "    pd.DataFrame(\n", "        dict(\n", "            id=[1, 2, 3],\n", "            name=[\"John\", \"Jane\", \"Jack\"],\n", "            age=[20, 30, 40],\n", "            gender=[\"male\", \"female\", \"male\"],\n", "        )\n", "    )\n", ")\n", "# gender is not part of Person, so a TypeError is raised\n", "try:\n", "    DataSet[Person](df)\n", "except TypeError as e:\n", "    print(e)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "357c3f96", "metadata":
{}, "source": [ "Assuming your code is run regularly (e.g. through unit tests or scheduled pipelines), you can safely assume that any `DataSet[Person]` object you come across on the master branch indeed follows the indicated schema.\n", "\n", "## Linting\n", "Additionally, while coding, we can use linting (e.g. mypy, pyright) to check that functions receive the schemas they expect. For instance:" ] }, { "cell_type": "code", "execution_count": 6, "id": "7e34671a", "metadata": {}, "outputs": [], "source": [ "class Person(Schema):\n", "    id: Column[LongType]\n", "    name: Column[StringType]\n", "    age: Column[LongType]\n", "\n", "\n", "class Address(Schema):\n", "    street: Column[StringType]\n", "    number: Column[LongType]\n", "\n", "\n", "def birthday(df: DataSet[Person]) -> DataSet[Person]:\n", "    return DataSet[Person](\n", "        df.withColumn(Person.age.str, Person.age + 1),\n", "    )\n", "\n", "\n", "df_1 = DataSet[Person](\n", "    spark.createDataFrame(\n", "        pd.DataFrame(\n", "            dict(\n", "                id=[1, 2, 3],\n", "                name=[\"John\", \"Jane\", \"Jack\"],\n", "                age=[20, 30, 40],\n", "            )\n", "        )\n", "    )\n", ")\n", "# no linting error\n", "birthday(df_1)\n", "\n", "df_2 = DataSet[Address](\n", "    spark.createDataFrame(\n", "        pd.DataFrame(\n", "            dict(\n", "                street=[\"Lynton Walk\", \"Canada Square\", \"Chapelside Avenue\"],\n", "                number=[1, 2, 3],\n", "            )\n", "        )\n", "    )\n", ")\n", "try:\n", "    # linting error: expected DataSet[Person], observed DataSet[Address]\n", "    birthday(df_2)\n", "except Exception:\n", "    pass  # ignore any runtime failure; the point here is the linting error\n", "\n", "df_3 = spark.createDataFrame(\n", "    pd.DataFrame(\n", "        dict(\n", "            id=[1, 2, 3],\n", "            name=[\"John\", \"Jane\", \"Jack\"],\n", "            age=[20, 30, 40],\n", "        )\n", "    )\n", ")\n", "try:\n", "    # linting error: expected DataSet[Person], observed DataFrame\n", "    birthday(df_3)\n", "except Exception:\n", "    pass  # ignore any runtime failure; the point here is the linting error" ] }, { "attachments": {}, "cell_type": "markdown", "id": "30738286", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "typedspark", "language": "python", "name": "python3" },
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }