Project Code
tests/test_citibike_utils.py
# test_citibike_utils.py
import sys
import os
# Run the tests from the root directory
sys.path.append(os.getcwd())
import datetime
from src.citibike.citibike_utils import get_trip_duration_mins
from pyspark.sql import SparkSession
# The sys.path tweak above could instead live in a shared conftest.py so each test file need not repeat it.
def test_get_trip_duration_mins():
    """Verify get_trip_duration_mins adds a per-row trip duration in minutes.

    Spark does not guarantee the order of rows returned by ``collect()``
    unless the DataFrame is explicitly ordered. The assertions therefore
    compare the sorted collection of computed durations instead of indexing
    rows by position. This also checks that exactly one result row exists
    per input row.
    """
    spark = SparkSession.builder.getOrCreate()

    # Known start/end pairs; the trailing comments give the expected duration.
    start = datetime.datetime(2025, 4, 10, 10, 0, 0)
    data = [
        (start, datetime.datetime(2025, 4, 10, 10, 10, 0)),  # 10 minutes
        (start, datetime.datetime(2025, 4, 10, 10, 30, 0)),  # 30 minutes
    ]
    schema = "start_timestamp timestamp, end_timestamp timestamp"
    df = spark.createDataFrame(data, schema=schema)

    result_df = get_trip_duration_mins(
        spark, df, "start_timestamp", "end_timestamp", "trip_duration_mins"
    )

    # Only the output column is read, so the test does not depend on which
    # other columns the helper keeps in its result.
    durations = [
        row["trip_duration_mins"]
        for row in result_df.select("trip_duration_mins").collect()
    ]

    # Order-independent check of both the row count and the values
    # (10.0 == 10 in Python, so a double-typed result also passes).
    assert sorted(durations) == [10, 30]
tests/test_datetime_utils.py
# test_datetime_utils.py
import sys
import os
# Run the tests from the root directory
sys.path.append(os.getcwd())
import datetime
from src.utils.datetime_utils import timestamp_to_date_col
from pyspark.sql import SparkSession
def test_timestamp_to_date_col():
    """Check that timestamp_to_date_col derives the calendar date of a timestamp.

    A single timestamp with a non-midnight time of day is converted, and the
    new column is expected to hold just the date part.
    """
    session = SparkSession.builder.getOrCreate()

    source_df = session.createDataFrame(
        [(datetime.datetime(2025, 4, 10, 10, 30, 0),)],
        schema="ride_timestamp timestamp",
    )

    converted_df = timestamp_to_date_col(
        session, source_df, "ride_timestamp", "ride_date"
    )
    first_row = converted_df.select("ride_date").first()

    # The 10:30 time of day should be dropped, leaving 2025-04-10.
    assert first_row["ride_date"] == datetime.date(2025, 4, 10)