Transformations for all schemas with a given column using DataSetImplements

Let’s illustrate this with an example! First, we’ll define some data.

[1]:
from pyspark.sql import SparkSession

spark = SparkSession.Builder().config("spark.ui.showConsoleProgress", "false").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
[2]:
from pyspark.sql.types import LongType, StringType
from typedspark import (
    Schema,
    Column,
    create_empty_dataset,
)


class Person(Schema):
    name: Column[StringType]
    age: Column[LongType]
    job: Column[StringType]


class Pet(Schema):
    name: Column[StringType]
    age: Column[LongType]
    type: Column[StringType]


class Fruit(Schema):
    type: Column[StringType]


person = create_empty_dataset(spark, Person)
pet = create_empty_dataset(spark, Pet)
fruit = create_empty_dataset(spark, Fruit)

Now, suppose we want to define a function birthday() that works on all schemas that contain the column age. With DataSet, we'd have to explicitly list which schemas contain the age column, for example:

[3]:
from typing import TypeVar, Union

from typedspark import DataSet, transform_to_schema

T = TypeVar("T", bound=Union[Person, Pet])


def birthday(df: DataSet[T]) -> DataSet[T]:
    return transform_to_schema(
        df,
        df.typedspark_schema,
        {Person.age: Person.age + 1},
    )

This can get tedious if the set of schemas with an age column changes, for example because new schemas are added, or because the age column is removed from a schema! It's also not great that we're using Person.age to refer to the age column, even when df is a DataSet[Pet].

Fortunately, we can do better! Consider the following example:

[4]:
from typing import Protocol

from typedspark import DataSetImplements


class Age(Schema, Protocol):
    age: Column[LongType]


T = TypeVar("T", bound=Schema)


def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
    return transform_to_schema(
        df,
        df.typedspark_schema,
        {Age.age: Age.age + 1},
    )

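Here, we define Age to be both a Schema and a Protocol (PEP 544), so any schema that contains an age column of type LongType structurally satisfies Age, without having to inherit from it.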
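If structural typing (PEP 544) is new to you, the sketch below shows the idea in plain Python, independent of typedspark: a class satisfies a protocol simply by having the right attributes. `HasAge`, `PersonRecord` and `FruitRecord` are hypothetical names chosen for illustration. Note that typedspark's check happens statically (in your linter), whereas `runtime_checkable` is used here only to make conformance visible at runtime, and it checks that the attribute exists, not its type.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasAge(Protocol):
    """Hypothetical protocol: anything with an ``age`` attribute."""

    age: int


@dataclass
class PersonRecord:
    name: str
    age: int
    job: str


@dataclass
class FruitRecord:
    type: str


# Neither class inherits from HasAge; conformance is purely structural.
alice = PersonRecord(name="Alice", age=30, job="engineer")
apple = FruitRecord(type="apple")

print(isinstance(alice, HasAge))  # True: PersonRecord has an ``age`` attribute
print(isinstance(apple, HasAge))  # False: FruitRecord has no ``age`` attribute
```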

We then define birthday() to:

  1. Take as input a DataSetImplements[Age, T]: a DataSet of some schema T that implements the protocol Age, i.e. whose schema contains an age column of type LongType.

  2. Return a DataSet[T]: a DataSet with the same schema T as the one that was provided.
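To make the shape of this signature concrete without Spark, here is a plain-Python analogue of the idea (not typedspark's implementation): a `TypeVar` bound to a protocol, so the function accepts any type with an `age` attribute and is declared to return that same type. `HasAge`, `birthday_record`, `PersonRecord` and `PetRecord` are hypothetical names.

```python
import copy
from dataclasses import dataclass
from typing import Protocol, TypeVar


class HasAge(Protocol):
    """Hypothetical stand-in for the ``Age`` schema: anything with ``age: int``."""

    age: int


# T can be any type that structurally satisfies HasAge (cf. DataSetImplements[Age, T]).
T = TypeVar("T", bound=HasAge)


def birthday_record(record: T) -> T:
    """Return a shallow copy of ``record`` with ``age`` incremented by one.

    Because the return type is T rather than HasAge, a type checker knows that
    a PersonRecord goes in and a PersonRecord comes out (cf. DataSet[T]).
    """
    older = copy.copy(record)
    older.age = record.age + 1
    return older


@dataclass
class PersonRecord:
    name: str
    age: int
    job: str


@dataclass
class PetRecord:
    name: str
    age: int
    type: str


alice = PersonRecord(name="Alice", age=30, job="engineer")
older_alice = birthday_record(alice)  # inferred as PersonRecord
older_rex = birthday_record(PetRecord(name="Rex", age=3, type="dog"))  # inferred as PetRecord

print(older_alice)  # PersonRecord(name='Alice', age=31, job='engineer')
print(alice.age)  # 30: the input object is left unchanged
```

Passing an object without an `age` attribute would be rejected by a type checker such as mypy or pyright, mirroring `birthday(fruit)`; at runtime it would raise `AttributeError`.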

Let’s see this in action!

[5]:
# returns a DataSet[Person]
happy_person = birthday(person)

# returns a DataSet[Pet]
happy_pet = birthday(pet)

try:
    # A type checker such as pyright reports:
    # Argument of type "DataSet[Fruit]" cannot be assigned to
    # parameter "df" of type "DataSetImplements[Age, T@birthday]"
    # The call also fails at runtime, since Fruit has no age column.
    birthday(fruit)
except Exception:
    pass