Type checking

Runtime

Every time a DataSet is initialized, it checks at runtime whether the data matches the schema, and raises a TypeError if it does not.

[1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.config("spark.ui.showConsoleProgress", "false").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
[2]:
import pandas as pd
from typedspark import Column, DataSet, Schema
from pyspark.sql.types import LongType, StringType


class Person(Schema):
    id: Column[LongType]
    name: Column[StringType]
    age: Column[LongType]


df = spark.createDataFrame(
    pd.DataFrame(
        dict(
            id=[1, 2, 3],
            name=["John", "Jane", "Jack"],
            age=[20, 30, 40],
        )
    )
)
# no errors raised
df = DataSet[Person](df)
df.show()
+---+----+---+
| id|name|age|
+---+----+---+
|  1|John| 20|
|  2|Jane| 30|
|  3|Jack| 40|
+---+----+---+

As a convention, we ignore any columns that start with __ during this check.

[3]:
df = spark.createDataFrame(
    pd.DataFrame(
        dict(
            id=[1, 2, 3],
            name=["John", "Jane", "Jack"],
            age=[20, 30, 40],
            __extra_column=[1, 2, 3],
        )
    )
)
# no errors raised because __extra_column is ignored during the check
df = DataSet[Person](df)
df.show()
+---+----+---+--------------+
| id|name|age|__extra_column|
+---+----+---+--------------+
|  1|John| 20|             1|
|  2|Jane| 30|             2|
|  3|Jack| 40|             3|
+---+----+---+--------------+

If a column defined in the schema is missing from the data, or the data contains a column that the schema does not define, a TypeError is raised.

[4]:
df = spark.createDataFrame(
    pd.DataFrame(
        dict(
            id=[1, 2, 3],
            name=["John", "Jane", "Jack"],
        )
    )
)
try:
    DataSet[Person](df)
except TypeError as e:
    print(e)
Schema Person contains the following columns not present in data: {'age'}
[5]:
df = spark.createDataFrame(
    pd.DataFrame(
        dict(
            id=[1, 2, 3],
            name=["John", "Jane", "Jack"],
            age=[20, 30, 40],
            gender=["male", "female", "male"],
        )
    )
)
try:
    DataSet[Person](df)
except TypeError as e:
    print(e)
Data contains the following columns not present in schema Person: {'gender'}.

If you believe these columns should be part of the schema, consider adding the following lines to it.

class Person(Schema):
    gender: Column[StringType]

Assuming your code runs regularly (e.g. through unit tests or scheduled pipelines), you can safely assume that any DataSet[Person] object you come across on the master branch actually follows the indicated schema.
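
The column-name part of the check shown in the cells above can be sketched in plain Python. This is only an illustration, not typedspark's actual implementation: `check_columns` is a hypothetical helper, it compares column names only (not data types), and the error messages are modelled on the output printed above.

```python
from typing import Iterable


def check_columns(
    schema_name: str,
    schema_columns: Iterable[str],
    data_columns: Iterable[str],
) -> None:
    """Hypothetical helper: raise TypeError if the column names do not match."""
    expected = set(schema_columns)
    # Columns starting with "__" are ignored, following the convention above.
    observed = {c for c in data_columns if not c.startswith("__")}

    missing = expected - observed
    if missing:
        raise TypeError(
            f"Schema {schema_name} contains the following columns "
            f"not present in data: {missing}"
        )

    unexpected = observed - expected
    if unexpected:
        raise TypeError(
            f"Data contains the following columns "
            f"not present in schema {schema_name}: {unexpected}"
        )


person_columns = ["id", "name", "age"]

# Passes: the "__"-prefixed column is skipped during the comparison.
check_columns("Person", person_columns, ["id", "name", "age", "__extra_column"])

# Raises: "age" is defined in the schema but missing from the data.
try:
    check_columns("Person", person_columns, ["id", "name"])
except TypeError as e:
    print(e)
# prints: Schema Person contains the following columns not present in data: {'age'}
```

Using set differences keeps the two failure modes (missing columns and unexpected columns) separate, which mirrors the two distinct error messages in the examples above.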

Linting

Additionally, while writing code, we can use linting (e.g. mypy, pyright) to check that schemas are used consistently before the code is run. For instance:

[6]:
class Person(Schema):
    id: Column[LongType]
    name: Column[StringType]
    age: Column[LongType]


class Address(Schema):
    street: Column[StringType]
    number: Column[LongType]


def birthday(df: DataSet[Person]) -> DataSet[Person]:
    return DataSet[Person](
        df.withColumn(Person.age.str, Person.age + 1),
    )


df_1 = DataSet[Person](
    spark.createDataFrame(
        pd.DataFrame(
            dict(
                id=[1, 2, 3],
                name=["John", "Jane", "Jack"],
                age=[20, 30, 40],
            )
        )
    )
)
# no linting error
birthday(df_1)

df_2 = DataSet[Address](
    spark.createDataFrame(
        pd.DataFrame(
            dict(
                street=["Lynton Walk", "Canada Square", "Chapelside Avenue"],
                number=[1, 2, 3],
            )
        )
    )
)
try:
    # linting error: expected DataSet[Person], observed DataSet[Address]
    birthday(df_2)
except Exception:
    # only the linting error matters here; any runtime error is ignored
    pass

df_3 = spark.createDataFrame(
    pd.DataFrame(
        dict(
            id=[1, 2, 3],
            name=["John", "Jane", "Jack"],
            age=[20, 30, 40],
        )
    )
)
try:
    # linting error: expected DataSet[Person], observed DataFrame
    birthday(df_3)
except Exception:
    # only the linting error matters here; any runtime error is ignored
    pass
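
To see why a linter can tell these calls apart, the same idea can be reproduced without Spark. The sketch below is an assumption-laden stand-in, not typedspark code: `TypedFrame` is a hypothetical generic container, and `Person` and `Address` are empty marker classes. Because the type parameter is invariant by default, a checker such as mypy or pyright treats `TypedFrame[Address]` as incompatible with `TypedFrame[Person]`.

```python
from typing import Any, Dict, Generic, List, TypeVar

T = TypeVar("T")


class Person:
    """Marker class standing in for a schema (hypothetical)."""


class Address:
    """Marker class standing in for a schema (hypothetical)."""


class TypedFrame(Generic[T]):
    """Hypothetical stand-in for DataSet: rows tagged with a schema type."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self.rows = rows


def birthday(df: TypedFrame[Person]) -> TypedFrame[Person]:
    # Return a new frame in which every age is increased by one.
    return TypedFrame[Person]([{**row, "age": row["age"] + 1} for row in df.rows])


people = TypedFrame[Person]([{"id": 1, "name": "John", "age": 20}])
addresses = TypedFrame[Address]([{"street": "Lynton Walk", "number": 1}])

# Accepted by a type checker: the argument is a TypedFrame[Person].
older = birthday(people)

# A type checker would report an incompatible argument type for the call below,
# so it is left commented out:
# birthday(addresses)
```

The type parameter has no effect at runtime here; it only gives the static checker something to compare against the function signature.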