Adding Unit Tests to our Project - pathfinder-analytics-uk/dab_project GitHub Wiki

Project Code

tests/test_citibike_utils.py

# test_citibike_utils.py
import sys
import os

# Ensure the project root is on sys.path so `src.*` imports resolve when pytest runs from the repo root
sys.path.append(os.getcwd())

import datetime
from src.citibike.citibike_utils import get_trip_duration_mins
from pyspark.sql import SparkSession

# Note: this sys.path tweak is better placed once in conftest.py so every test file picks it up automatically

def test_get_trip_duration_mins():
    """Trip duration column should hold the start→end difference in whole minutes."""
    spark = SparkSession.builder.getOrCreate()

    # Two rides sharing a start time, lasting 10 and 30 minutes respectively.
    start = datetime.datetime(2025, 4, 10, 10, 0, 0)
    rows = [
        (start, start + datetime.timedelta(minutes=10)),
        (start, start + datetime.timedelta(minutes=30)),
    ]
    trips_df = spark.createDataFrame(
        rows, schema="start_timestamp timestamp, end_timestamp timestamp"
    )

    # Function under test: appends the computed duration column.
    with_duration = get_trip_duration_mins(
        spark, trips_df, "start_timestamp", "end_timestamp", "trip_duration_mins"
    )

    # Pull the computed values back to the driver and check them.
    durations = [
        row["trip_duration_mins"]
        for row in with_duration.select("trip_duration_mins").collect()
    ]
    assert durations == [10, 30]

tests/test_datetime_utils.py

# test_datetime_utils.py
import sys
import os

# Ensure the project root is on sys.path so `src.*` imports resolve when pytest runs from the repo root
sys.path.append(os.getcwd())

import datetime
from src.utils.datetime_utils import timestamp_to_date_col
from pyspark.sql import SparkSession

def test_timestamp_to_date_col():
    """The derived date column should equal the timestamp's calendar date."""
    spark = SparkSession.builder.getOrCreate()

    # Single row with a known timestamp; its date part is the expected output.
    ride_ts = datetime.datetime(2025, 4, 10, 10, 30, 0)
    source_df = spark.createDataFrame(
        [(ride_ts,)], schema="ride_timestamp timestamp"
    )

    # Function under test: appends a date column derived from the timestamp.
    with_date = timestamp_to_date_col(spark, source_df, "ride_timestamp", "ride_date")

    # Expected: 2025-04-10, i.e. the calendar date of the input timestamp.
    first_row = with_date.select("ride_date").first()
    assert first_row["ride_date"] == ride_ts.date()