v2

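# MongoDB -> Hive (RDV) ingestion job:
#   1. Read a MongoDB collection through the Spark MongoDB connector.
#   2. Sanitize and explicitly type the schema, then persist the raw data to a staging table.
#   3. Load the scalar columns into the target table and split nested structs/arrays into child tables.
#   4. Record load-status and audit rows, then refresh Impala metadata for the affected tables.
#
# Example submission (illustrative sketch only; the connector coordinates and
# script name are assumptions, not taken from this repo):
#   spark-submit --packages org.mongodb.spark:mongo-spark-connector_2.12:10.1.1 v2.py \
#       --uri "mongodb://host:27017" --db mydb --collection mycoll \
#       --stgname stg.mycoll --targetname rdv.mycoll \
#       --audittable audit.mongo_load_audit --loadstatustable audit.mongo_load_status \
#       --writetype append --jobid job_001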
import argparse, os, re, sys, subprocess, logging
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.8"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3.8"
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
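# Job parameters: Mongo connection/collection, Hive staging/target/audit/status tables,
# write mode for the target, and an external job id.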
parser = argparse.ArgumentParser()
parser.add_argument("--uri", required=True)
parser.add_argument("--db", required=True)
parser.add_argument("--collection", required=True)
parser.add_argument("--stgname", required=True)
parser.add_argument("--targetname", required=True)
parser.add_argument("--audittable", required=True)
parser.add_argument("--loadstatustable", required=True)
parser.add_argument("--writetype", required=True)
parser.add_argument("--jobid", required=True)
args = parser.parse_args()
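# Local-mode Spark session with Hive support; the MongoDB read URI and Hive dynamic
# partitioning are configured up front.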
spark = SparkSession.builder \
   .appName("HM01_MongoDB_to_RDV_pyspark") \
   .master("local[*]") \
   .enableHiveSupport() \
   .config("spark.mongodb.read.connection.uri", args.uri) \
   .config("hive.exec.dynamic.partition", "true") \
   .config("hive.exec.dynamic.partition.mode", "nonstrict") \
   .getOrCreate()
database_name, collection_name = args.db, args.collection
staging_table, target_table = args.stgname, args.targetname
audit_table, load_status_table, write_mode = args.audittable, args.loadstatustable, args.writetype.lower()
start_time = datetime.now()
spark_app_id = spark.sparkContext.applicationId
job_id = args.jobid.lower()
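# Lower-case a field name and replace non-alphanumeric characters with underscores
# so it becomes a valid Hive/Impala column name.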
def sanitize_name(name):
   name = re.sub(r'\W+', '_', name).lower()
   if not name or name[0].isdigit(): name = 'col_' + name
   return name
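# Recursively sanitize every field name in the schema, de-duplicating collisions
# with numeric suffixes; recurses into nested structs and arrays of structs.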
def sanitize_column_names_simple(schema):
   seen = set()
   new_fields = []
   for field in schema.fields:
       base = sanitize_name(field.name)
       original = base
       i = 1
       while base in seen:
           base = f"{original}_{i}"
           i += 1
       seen.add(base)
       if isinstance(field.dataType, StructType):
           new_fields.append(StructField(base, sanitize_column_names_simple(field.dataType), field.nullable))
       elif isinstance(field.dataType, ArrayType) and isinstance(field.dataType.elementType, StructType):
           new_fields.append(StructField(base, ArrayType(sanitize_column_names_simple(field.dataType.elementType)), field.nullable))
       else:
           new_fields.append(StructField(base, field.dataType, field.nullable))
   return StructType(new_fields)
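# Rebuild the schema with explicit Spark types: decimals are widened to DECIMAL(30,10)
# and any unrecognised type falls back to StringType.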
def build_explicit_schema(schema):
   def cast_type(t):
       if isinstance(t, StringType): return StringType()
       elif isinstance(t, IntegerType): return IntegerType()
       elif isinstance(t, LongType): return LongType()
       elif isinstance(t, DoubleType): return DoubleType()
       elif isinstance(t, DecimalType): return DecimalType(30, 10)
       elif isinstance(t, BooleanType): return BooleanType()
       elif isinstance(t, TimestampType): return TimestampType()
       elif isinstance(t, DateType): return DateType()
       elif isinstance(t, StructType): return StructType([StructField(f.name, cast_type(f.dataType), f.nullable) for f in t.fields])
       elif isinstance(t, ArrayType): return ArrayType(cast_type(t.elementType))
       return StringType()
   return StructType([StructField(f.name, cast_type(f.dataType), f.nullable) for f in schema.fields])
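# Repeatedly expand struct fields into <parent>_<child> columns and array fields into
# positional <name>_<i> columns until the column list stops changing.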
def flatten_dataframe(df):
   def get_array_lengths(df):
       return [(f.name, df.select(size(col(f.name)).alias("len")).agg(max("len")).collect()[0][0] or 0)
               for f in df.schema.fields if isinstance(f.dataType, ArrayType)]
   def flatten_once(df, array_info):
       fields = []
       for f in df.schema.fields:
           if isinstance(f.dataType, StructType):
               for sub in f.dataType.fields:
                   fields.append(col(f"{f.name}.{sub.name}").alias(f"{f.name}_{sub.name}"))
           elif isinstance(f.dataType, ArrayType):
               max_len = dict(array_info).get(f.name, 0)
               for i in range(max_len):
                   fields.append(col(f.name).getItem(i).alias(f"{f.name}_{i}"))
           else:
               fields.append(col(f.name))
       return df.select(fields)
   prev_cols = []
   while True:
       array_info = get_array_lengths(df)
       df = flatten_once(df, array_info)
       if df.columns == prev_cols: break
       prev_cols = df.columns
   return df
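# Write each struct / array-of-struct column to its own "<base_table>_<field>" child
# table keyed by _id, and return the list of child table names created.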
def extract_and_write_child_tables(df, base_table):
   children = []
   for field in df.schema.fields:
       if isinstance(field.dataType, StructType):
           flat = df.select(col("_id"), *[col(f"{field.name}.{f.name}").alias(f"{field.name}_{f.name}") for f in field.dataType.fields])
           name = f"{base_table}_{field.name}"
           flat.write.mode(write_mode).format("parquet").saveAsTable(name)
           children.append(name)
       elif isinstance(field.dataType, ArrayType) and isinstance(field.dataType.elementType, StructType):
           flat = df.select(col("_id"), col(field.name))
           flat = flatten_dataframe(flat)
           name = f"{base_table}_{field.name}"
           flat.write.mode(write_mode).format("parquet").saveAsTable(name)
           children.append(name)
   return children
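# Refresh Impala's metadata cache for a table via impala-shell so tables written by
# Spark/Hive become visible to Impala queries.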
def invalidate_impala_table(table_name):
   cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
       --ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem \
       -q 'INVALIDATE METADATA {table_name};'"""
   subprocess.run(cmd, shell=True, check=True)
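# Record the run as STARTED in the load-status table before any data is read.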
spark.sql(f"""INSERT INTO {load_status_table}
   (spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage)
   VALUES ('{spark_app_id}', '{database_name}', '{collection_name}', 'STARTED', '{start_time}', NULL, NULL)
""")
load_status = "SUCCESS"
error_message = None
# Defaults so the final status insert and audit row still work if the load fails early
source_count = staging_count = target_count = 0
child_tables = []
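# Main load: read -> sanitize schema -> stage -> validate counts -> load target -> child tables.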
try:
   raw_df = spark.read.format("mongodb").option("database", database_name).option("collection", collection_name).load()
   source_count = raw_df.count()
   if source_count == 0: raise ValueError("MongoDB collection is empty.")
   sanitized_schema = sanitize_column_names_simple(raw_df.schema)
   explicit_schema = build_explicit_schema(sanitized_schema)
   sanitized_df = spark.createDataFrame(raw_df.rdd, schema=explicit_schema)
   sanitized_df.write.mode("overwrite").format("parquet").saveAsTable(staging_table)
   staging_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table}").first()[0]
   if source_count != staging_count: raise ValueError("Row count mismatch: MongoDB vs staging")
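   # Keep only scalar columns for the target table; nested structs/arrays are handled
   # separately as child tables.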
   flat_columns = [f.name for f in sanitized_df.schema.fields if not isinstance(f.dataType, (StructType, ArrayType))]
   flat_df = sanitized_df.select(*flat_columns)
   try:
    spark.catalog.refreshTable(target_table)
    target_df = spark.table(target_table)
    for field in target_df.schema:
       if field.name not in flat_df.columns:
           flat_df = flat_df.withColumn(field.name, lit(None).cast(field.dataType))
    flat_df.select(*target_df.columns).write.mode(write_mode).format("parquet").saveAsTable(target_table)
   except AnalysisException:
    flat_df.write.mode(write_mode).format("parquet").saveAsTable(target_table)
   target_count = spark.table(target_table).count()
   # Child tables for nested struct/array columns
   child_tables = extract_and_write_child_tables(sanitized_df, target_table)
except Exception as e:
   load_status = "FAILURE"
   error_message = str(e)
   logging.error(error_message)
end_time = datetime.now()
load_duration = str(end_time - start_time)
# Escape single quotes so an error message cannot break the SQL string literal
error_sql = "'{}'".format(error_message.replace("'", "''")) if error_message else "NULL"
spark.sql(f"""INSERT INTO {load_status_table}
   (spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage)
   VALUES ('{spark_app_id}', '{database_name}', '{collection_name}', '{load_status}', '{start_time}', '{end_time}', {error_sql})
""")


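# Audit record: one row per run with identifiers, timings, status and row counts.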
audit_schema = StructType([
   StructField("job_id", StringType(), True), 
   StructField("spark_app_id", StringType(), True),
   StructField("database_name", StringType(), True),
   StructField("collection_name", StringType(), True),
   StructField("start_time", StringType(), True),
   StructField("end_time", StringType(), True),
   StructField("load_status", StringType(), True),
   StructField("load_duration", StringType(), True),
   StructField("error_message", StringType(), True),
   StructField("source_count", IntegerType(), True),
   StructField("staging_table_count", IntegerType(), True),
   StructField("target_table_count", IntegerType(), True)
])
audit_df = spark.createDataFrame([(
   job_id, 
   spark_app_id,
   database_name,
   collection_name,
   start_time.strftime('%Y-%m-%d %H:%M:%S'),
   end_time.strftime('%Y-%m-%d %H:%M:%S'),
   load_status,
   load_duration,
   error_message,
   source_count,
   staging_count,
   target_count
)], schema=audit_schema)
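# Persist the audit row, then refresh Impala metadata for every table touched by this run.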
audit_df.write.mode("append").format("parquet").saveAsTable(audit_table)
for tbl in [staging_table, target_table, audit_table, load_status_table] + child_tables:
   invalidate_impala_table(tbl)
spark.stop()