mgrdvdyn
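PySpark job that reads a MongoDB collection, lands it in a Hive staging table, flattens nested structures into the RDV target table, records load-status and audit rows, and then invalidates Impala metadata for the affected tables.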
```python
import pyspark
import json
import argparse
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.utils import *
import os
from datetime import datetime
import logging
import sys
import subprocess
import re
from bson import *

os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.8"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3.8"

logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

parser = argparse.ArgumentParser()
parser.add_argument("--uri", help="MongoDB connection URI", required=True)
parser.add_argument("--db", help="Name of the MongoDB database", required=True)
parser.add_argument("--collection", help="Name of the MongoDB collection", required=True)
parser.add_argument("--stgname", help="Name of the staging table", required=True)
parser.add_argument("--targetname", help="Name of the target table", required=True)
parser.add_argument("--audittable", help="Audit table name", required=True)
parser.add_argument("--loadstatustable", help="Load status table name", required=True)
parser.add_argument("--writetype", help="append or overwrite?", required=True)
args = parser.parse_args()
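# Spark session on YARN with Hive support; the MongoDB Spark connector reads the
# connection string from spark.mongodb.read.connection.uri.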
spark = SparkSession.builder \
    .master("yarn") \
    .appName("HM01_MongoDB_to_RDV_pyspark") \
    .config("spark.mongodb.read.connection.uri", args.uri) \
    .config("hive.exec.dynamic.partition", "true") \
    .config("hive.exec.dynamic.partition.mode", "nonstrict") \
    .enableHiveSupport() \
    .getOrCreate()
database_name = args.db
collection_name = args.collection
start_time = datetime.now()
starting_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
spark_app_id = spark.sparkContext.applicationId
logging.info(f"Starting Spark Application ID: {spark_app_id}")
target_table = args.targetname
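# Record a STARTED row in the load-status table before any heavy work begins.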
spark.sql(f"""
INSERT INTO {args.loadstatustable}
(spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage)
VALUES ('{spark_app_id}', '{database_name}', '{collection_name}', 'STARTED', '{starting_time}', NULL, NULL)
""")
def sanitize_names_with_uniqueness(schema):
    unknown_counter = [0]
    top_level_names = {}
    sanitized_top_names = set()
    sanitized_fields = []

    def make_unique(name, existing_names):
        base = name
        counter = 1
        while name in existing_names:
            name = f"{base}_{counter}"
            counter += 1
        existing_names.add(name)
        return name

    for field in schema.fields:
        raw_name = field.name
        base_name = re.sub(r'\W+', '', raw_name).strip('_').lower()
        if not base_name:
            base_name = f"col_unknown_{unknown_counter[0]}"
            unknown_counter[0] += 1
        elif base_name[0].isdigit():
            base_name = f"col_{base_name}"
        unique_name = make_unique(base_name, sanitized_top_names)
        top_level_names[raw_name] = unique_name
        sanitized_fields.append(
            StructField(unique_name, field.dataType, field.nullable)
        )

    def recursive_sanitize(field, parent_path=None, nested_tracker=None):
        if nested_tracker is None:
            nested_tracker = {}
        raw_name = field.name
        base_name = re.sub(r'\W+', '', raw_name).strip('_').lower()
        if not base_name:
            base_name = f"col_unknown_{unknown_counter[0]}"
            unknown_counter[0] += 1
        elif base_name[0].isdigit():
            base_name = f"col_{base_name}"
        prefix = '_'.join(parent_path) if parent_path else ''
        full_name = f"{prefix}_{base_name}" if prefix else base_name
        full_name = make_unique(full_name, sanitized_top_names)
        dtype = field.dataType
        if isinstance(dtype, StructType):
            new_fields = [
                recursive_sanitize(f, (parent_path or []) + [base_name], nested_tracker)
                for f in dtype.fields
            ]
            return StructField(full_name, StructType(new_fields), field.nullable)
        elif isinstance(dtype, ArrayType):
            element_field = recursive_sanitize(
                StructField("element", dtype.elementType),
                (parent_path or []) + [base_name],
                nested_tracker
            )
            return StructField(full_name, ArrayType(element_field.dataType), field.nullable)
        elif isinstance(dtype, MapType):
            value_field = recursive_sanitize(
                StructField("value", dtype.valueType),
                (parent_path or []) + [base_name],
                nested_tracker
            )
            return StructField(full_name, MapType(StringType(), value_field.dataType), field.nullable)
        else:
            return StructField(full_name, dtype, field.nullable)

    final_fields = []
    for field in sanitized_fields:
        dtype = field.dataType
        if isinstance(dtype, StructType):
            new_field = recursive_sanitize(field)
            final_fields.append(new_field)
        elif isinstance(dtype, ArrayType) and isinstance(dtype.elementType, StructType):
            new_field = recursive_sanitize(field)
            final_fields.append(new_field)
        else:
            final_fields.append(field)
    return StructType(final_fields)
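# build_explicit_schema copies the sanitized schema onto explicit Spark SQL types
# (decimals pinned to DecimalType(30, 10), maps coerced to string->string);
# flatten_dataframe repeatedly expands struct fields and positional array elements
# until only flat columns remain.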
def build_explicit_schema(spark_schema):
    def sanitize_and_copy(field):
        name = field.name
        dtype = field.dataType
        nullable = field.nullable
        if isinstance(dtype, StringType):
            return StructField(name, StringType(), nullable)
        elif isinstance(dtype, IntegerType):
            return StructField(name, IntegerType(), nullable)
        elif isinstance(dtype, BooleanType):
            return StructField(name, BooleanType(), nullable)
        elif isinstance(dtype, DoubleType):
            return StructField(name, DoubleType(), nullable)
        elif isinstance(dtype, DecimalType):
            return StructField(name, DecimalType(30, 10), nullable)
        elif isinstance(dtype, TimestampType):
            return StructField(name, TimestampType(), nullable)
        elif isinstance(dtype, DateType):
            return StructField(name, DateType(), nullable)
        elif isinstance(dtype, LongType):
            return StructField(name, LongType(), nullable)
        elif isinstance(dtype, ArrayType):
            return StructField(name, ArrayType(build_explicit_schema(StructType([StructField("element", dtype.elementType)]))[0].dataType), nullable)
        elif isinstance(dtype, MapType):
            return StructField(name, MapType(StringType(), StringType()), nullable)
        elif isinstance(dtype, StructType):
            return StructField(name, build_explicit_schema(dtype), nullable)
        else:
            return StructField(name, StringType(), nullable)
    return StructType([sanitize_and_copy(field) for field in spark_schema.fields])

def flatten_dataframe(df):
    def get_array_fields_and_lengths(df):
        array_fields = []
        for field in df.schema.fields:
            if isinstance(field.dataType, ArrayType):
                max_len = df.select(size(col(field.name)).alias("len")).agg(max("len")).collect()[0][0]
                array_fields.append((field.name, max_len or 0))
        return array_fields

    def flatten_once(input_df, array_info):
        fields = []
        for field in input_df.schema.fields:
            dtype = field.dataType
            name = field.name
            if isinstance(dtype, StructType):
                for subfield in dtype.fields:
                    fields.append(col(f"{name}.{subfield.name}").alias(f"{name}_{subfield.name}"))
            elif isinstance(dtype, ArrayType):
                max_len = dict(array_info).get(name, 0)
                for i in range(max_len):
                    fields.append(col(name).getItem(i).alias(f"{name}_{i}"))
            else:
                fields.append(col(name))
        return input_df.select(fields)

    prev_cols = []
    while True:
        array_info = get_array_fields_and_lengths(df)
        df = flatten_once(df, array_info)
        struct_cols_exist = any(isinstance(df.schema[field].dataType, StructType) for field in df.columns)
        if df.columns == prev_cols and not struct_cols_exist:
            break
        prev_cols = df.columns
    return df
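# Runtime parameters and counters used for audit / load-status bookkeeping.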
staging_table_name = args.stgname
audit_table_name = args.audittable
write_mode = args.writetype.lower()
load_status = "SUCCESS"
error_message = None
source_count = 0
target_count = 0
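# Main load: read the collection, land it unchanged in the staging table, verify row counts,
# then flatten and write to the target table (aligning to its columns when it already exists).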
try:
    raw_df = spark.read.format("mongodb") \
        .option("database", database_name) \
        .option("collection", collection_name) \
        .load()
    if raw_df.rdd.isEmpty():
        raise ValueError("MongoDB collection is empty. Nothing to process.")
    source_count = raw_df.count()
    logging.info(f"MongoDB document count: {source_count}")

    sanitized_schema = sanitize_names_with_uniqueness(raw_df.schema)
    if len(sanitized_schema.fields) == 0:
        raise ValueError("Sanitized schema has no valid fields. Possibly all column names were invalid.")
    explicit_schema = build_explicit_schema(sanitized_schema)
    sanitized_df = spark.createDataFrame(raw_df.rdd, schema=explicit_schema)
    sanitized_df.write.mode("overwrite").format("parquet").saveAsTable(staging_table_name)

    staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
    logging.info(f"Hive staging count: {staging_table_count}")
    if source_count != staging_table_count:
        raise ValueError(f"Row count mismatch: MongoDB={source_count}, Hive={staging_table_count}")

    nested_df = spark.sql(f"SELECT * FROM {staging_table_name}")
    flattened_df = flatten_dataframe(nested_df)
    try:
        spark.catalog.refreshTable(target_table)
        target_df = spark.table(target_table)
        target_schema = target_df.schema
        target_columns = target_df.columns
        for field in target_schema:
            if field.name not in flattened_df.columns:
                flattened_df = flattened_df.withColumn(field.name, lit(None).cast(field.dataType))
        write_df = flattened_df.select(*target_columns)
        write_df.write.mode(write_mode).format("parquet").saveAsTable(target_table)
        logging.info("Overwrote target table with matching columns only.")
        target_table_count = write_df.count()
        #staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
    except AnalysisException:
        flattened_df.write.mode(write_mode).format("parquet").saveAsTable(target_table)
        logging.info("Target table did not exist. Created it with all columns.")
        target_table_count = flattened_df.count()
        #staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
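# Any failure is captured here so the load-status and audit rows below are still written.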
except Exception as e:
    load_status = "FAILURE"
    error_message = str(e)
    logging.error("Error: %s", error_message)

ending_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
spark.sql(f""" INSERT INTO {args.loadstatustable} (spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage) VALUES ( '{spark_app_id}', '{database_name}', '{collection_name}', '{load_status}', '{starting_time}', '{ending_time}', {f"'{error_message}'" if error_message else "NULL"}) """)
end_time = datetime.now()
staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
load_duration = str(end_time - start_time)
target_table_count = spark.sql(f"SELECT COUNT(*) FROM {target_table}").first()[0]
audit_schema = StructType([
    StructField("spark_app_id", StringType(), True),
    StructField("database_name", StringType(), True),
    StructField("collection_name", StringType(), True),
    StructField("start_time", StringType(), True),
    StructField("end_time", StringType(), True),
    StructField("load_status", StringType(), True),
    StructField("load_duration", StringType(), True),
    StructField("error_message", StringType(), True),
    StructField("source_count", IntegerType(), True),
    StructField("staging_table_count", IntegerType(), True),
    StructField("target_table_count", IntegerType(), True)
])
audit_df = spark.createDataFrame([(
    spark_app_id,
    database_name,
    collection_name,
    start_time.strftime('%Y-%m-%d %H:%M:%S'),
    end_time.strftime('%Y-%m-%d %H:%M:%S'),
    load_status,
    load_duration,
    error_message,
    source_count,
    staging_table_count,
    target_table_count
)], schema=audit_schema)
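# Append this run's audit record to the audit table.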
audit_df.write.mode("append").format("parquet").saveAsTable(audit_table_name)
load_status_table = args.loadstatustable
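# Invalidate Impala metadata so the freshly written Hive tables are visible to Impala clients.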
impala_stg_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {staging_table_name};'"""
impala_target_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {target_table};'"""
impala_audit_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {audit_table_name};'"""
impala_loadstatus_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {load_status_table};'"""
try:
    subprocess.run(impala_stg_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {staging_table_name}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (staging): {e.stderr}")

try:
    subprocess.run(impala_target_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {target_table}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (target): {e.stderr}")

try:
    subprocess.run(impala_loadstatus_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {load_status_table}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (load status): {e.stderr}")

try:
    subprocess.run(impala_audit_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {audit_table_name}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (audit): {e.stderr}")
spark.stop()
```