mgrdvdyn
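PySpark job that reads a MongoDB collection, lands it in a Hive staging table, flattens nested structures into the RDV target table, records load-status and audit rows, and then invalidates Impala metadata for the affected tables.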
```python
import pyspark
import json
import argparse
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.utils import *
import os
from datetime import datetime
import logging
import sys
import subprocess
import re
from bson import *

os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.8"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3.8"

logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

parser = argparse.ArgumentParser()
parser.add_argument("--uri", help="MongoDB connection URI", required=True)
parser.add_argument("--db", help="Name of the MongoDB database", required=True)
parser.add_argument("--collection", help="Name of the MongoDB collection", required=True)
parser.add_argument("--stgname", help="Name of the staging table", required=True)
parser.add_argument("--targetname", help="Name of the target table", required=True)
parser.add_argument("--audittable", help="Audit table name", required=True)
parser.add_argument("--loadstatustable", help="Load status table name", required=True)
parser.add_argument("--writetype", help="append or overwrite?", required=True)
args = parser.parse_args()
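# Spark session on YARN with Hive support; the MongoDB Spark connector reads the
# connection string from spark.mongodb.read.connection.uri.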
spark = SparkSession.builder \
    .master("yarn") \
    .appName("HM01_MongoDB_to_RDV_pyspark") \
    .config("spark.mongodb.read.connection.uri", args.uri) \
    .config("hive.exec.dynamic.partition", "true") \
    .config("hive.exec.dynamic.partition.mode", "nonstrict") \
    .enableHiveSupport() \
    .getOrCreate()
database_name = args.db
collection_name = args.collection
start_time = datetime.now()
starting_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
spark_app_id = spark.sparkContext.applicationId
logging.info(f"Starting Spark Application ID: {spark_app_id}")
target_table = args.targetname
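# Record a STARTED row in the load-status table before any heavy work begins.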
spark.sql(f"""
INSERT INTO {args.loadstatustable}
(spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage)
VALUES ('{spark_app_id}', '{database_name}', '{collection_name}', 'STARTED', '{starting_time}', NULL, NULL)
""")
def sanitize_names_with_uniqueness(schema):
    unknown_counter = [0]
    top_level_names = {}
    sanitized_top_names = set()
    sanitized_fields = []

    def make_unique(name, existing_names):
        base = name
        counter = 1
        while name in existing_names:
            name = f"{base}_{counter}"
            counter += 1
        existing_names.add(name)
        return name

    for field in schema.fields:
        raw_name = field.name
        base_name = re.sub(r'\W+', '', raw_name).strip('_').lower()
        if not base_name:
            base_name = f"col_unknown_{unknown_counter[0]}"
            unknown_counter[0] += 1
        elif base_name[0].isdigit():
            base_name = f"col_{base_name}"
        unique_name = make_unique(base_name, sanitized_top_names)
        top_level_names[raw_name] = unique_name
        sanitized_fields.append(
            StructField(unique_name, field.dataType, field.nullable)
        )

    def recursive_sanitize(field, parent_path=None, nested_tracker=None):
        if nested_tracker is None:
            nested_tracker = {}
        raw_name = field.name
        base_name = re.sub(r'\W+', '', raw_name).strip('_').lower()
        if not base_name:
            base_name = f"col_unknown_{unknown_counter[0]}"
            unknown_counter[0] += 1
        elif base_name[0].isdigit():
            base_name = f"col_{base_name}"
        prefix = '_'.join(parent_path) if parent_path else ''
        full_name = f"{prefix}_{base_name}" if prefix else base_name
        full_name = make_unique(full_name, sanitized_top_names)
        dtype = field.dataType
        if isinstance(dtype, StructType):
            new_fields = [
                recursive_sanitize(f, (parent_path or []) + [base_name], nested_tracker)
                for f in dtype.fields
            ]
            return StructField(full_name, StructType(new_fields), field.nullable)
        elif isinstance(dtype, ArrayType):
            element_field = recursive_sanitize(
                StructField("element", dtype.elementType),
                (parent_path or []) + [base_name],
                nested_tracker
            )
            return StructField(full_name, ArrayType(element_field.dataType), field.nullable)
        elif isinstance(dtype, MapType):
            value_field = recursive_sanitize(
                StructField("value", dtype.valueType),
                (parent_path or []) + [base_name],
                nested_tracker
            )
            return StructField(full_name, MapType(StringType(), value_field.dataType), field.nullable)
        else:
            return StructField(full_name, dtype, field.nullable)

    final_fields = []
    for field in sanitized_fields:
        dtype = field.dataType
        if isinstance(dtype, StructType):
            new_field = recursive_sanitize(field)
            final_fields.append(new_field)
        elif isinstance(dtype, ArrayType) and isinstance(dtype.elementType, StructType):
            new_field = recursive_sanitize(field)
            final_fields.append(new_field)
        else:
            final_fields.append(field)
    return StructType(final_fields)
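# build_explicit_schema copies the sanitized schema onto explicit Spark SQL types
# (decimals pinned to DecimalType(30, 10), maps coerced to string->string);
# flatten_dataframe repeatedly expands struct fields and positional array elements
# until only flat columns remain.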
def build_explicit_schema(spark_schema):
    def sanitize_and_copy(field):
        name = field.name
        dtype = field.dataType
        nullable = field.nullable
        if isinstance(dtype, StringType):
            return StructField(name, StringType(), nullable)
        elif isinstance(dtype, IntegerType):
            return StructField(name, IntegerType(), nullable)
        elif isinstance(dtype, BooleanType):
            return StructField(name, BooleanType(), nullable)
        elif isinstance(dtype, DoubleType):
            return StructField(name, DoubleType(), nullable)
        elif isinstance(dtype, DecimalType):
            return StructField(name, DecimalType(30, 10), nullable)
        elif isinstance(dtype, TimestampType):
            return StructField(name, TimestampType(), nullable)
        elif isinstance(dtype, DateType):
            return StructField(name, DateType(), nullable)
        elif isinstance(dtype, LongType):
            return StructField(name, LongType(), nullable)
        elif isinstance(dtype, ArrayType):
            return StructField(name, ArrayType(build_explicit_schema(StructType([StructField("element", dtype.elementType)]))[0].dataType), nullable)
        elif isinstance(dtype, MapType):
            return StructField(name, MapType(StringType(), StringType()), nullable)
        elif isinstance(dtype, StructType):
            return StructField(name, build_explicit_schema(dtype), nullable)
        else:
            return StructField(name, StringType(), nullable)
    return StructType([sanitize_and_copy(field) for field in spark_schema.fields])

def flatten_dataframe(df):
    def get_array_fields_and_lengths(df):
        array_fields = []
        for field in df.schema.fields:
            if isinstance(field.dataType, ArrayType):
                max_len = df.select(size(col(field.name)).alias("len")).agg(max("len")).collect()[0][0]
                array_fields.append((field.name, max_len or 0))
        return array_fields

    def flatten_once(input_df, array_info):
        fields = []
        for field in input_df.schema.fields:
            dtype = field.dataType
            name = field.name
            if isinstance(dtype, StructType):
                for subfield in dtype.fields:
                    fields.append(col(f"{name}.{subfield.name}").alias(f"{name}_{subfield.name}"))
            elif isinstance(dtype, ArrayType):
                max_len = dict(array_info).get(name, 0)
                for i in range(max_len):
                    fields.append(col(name).getItem(i).alias(f"{name}_{i}"))
            else:
                fields.append(col(name))
        return input_df.select(fields)

    prev_cols = []
    while True:
        array_info = get_array_fields_and_lengths(df)
        df = flatten_once(df, array_info)
        struct_cols_exist = any(isinstance(df.schema[field].dataType, StructType) for field in df.columns)
        if df.columns == prev_cols and not struct_cols_exist:
            break
        prev_cols = df.columns
    return df
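# Runtime parameters and counters used for audit / load-status bookkeeping.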
staging_table_name = args.stgname
audit_table_name = args.audittable
write_mode = args.writetype.lower()
load_status = "SUCCESS"
error_message = None
source_count = 0
target_count = 0
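# Main load: read the collection, land it unchanged in the staging table, verify row counts,
# then flatten and write to the target table (aligning to its columns when it already exists).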
try:
    raw_df = spark.read.format("mongodb") \
        .option("database", database_name) \
        .option("collection", collection_name) \
        .load()
    if raw_df.rdd.isEmpty():
        raise ValueError("MongoDB collection is empty. Nothing to process.")
    source_count = raw_df.count()
    logging.info(f"MongoDB document count: {source_count}")

    sanitized_schema = sanitize_names_with_uniqueness(raw_df.schema)
    if len(sanitized_schema.fields) == 0:
        raise ValueError("Sanitized schema has no valid fields. Possibly all column names were invalid.")
    explicit_schema = build_explicit_schema(sanitized_schema)
    sanitized_df = spark.createDataFrame(raw_df.rdd, schema=explicit_schema)
    sanitized_df.write.mode("overwrite").format("parquet").saveAsTable(staging_table_name)

    staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
    logging.info(f"Hive staging count: {staging_table_count}")
    if source_count != staging_table_count:
        raise ValueError(f"Row count mismatch: MongoDB={source_count}, Hive={staging_table_count}")

    nested_df = spark.sql(f"SELECT * FROM {staging_table_name}")
    flattened_df = flatten_dataframe(nested_df)
    try:
        spark.catalog.refreshTable(target_table)
        target_df = spark.table(target_table)
        target_schema = target_df.schema
        target_columns = target_df.columns
        for field in target_schema:
            if field.name not in flattened_df.columns:
                flattened_df = flattened_df.withColumn(field.name, lit(None).cast(field.dataType))
        write_df = flattened_df.select(*target_columns)
        write_df.write.mode(write_mode).format("parquet").saveAsTable(target_table)
        logging.info("Overwrote target table with matching columns only.")
        target_table_count = write_df.count()
        #staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
    except AnalysisException:
        flattened_df.write.mode(write_mode).format("parquet").saveAsTable(target_table)
        logging.info("Target table did not exist. Created it with all columns.")
        target_table_count = flattened_df.count()
        #staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
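# Any failure is captured here so the load-status and audit rows below are still written.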
except Exception as e:
    load_status = "FAILURE"
    error_message = str(e)
    logging.error("Error: %s", error_message)

ending_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
spark.sql(f""" INSERT INTO {args.loadstatustable} (spark_app_id, database_name, collection_name, status, starttime, endtime, errormessage) VALUES ( '{spark_app_id}', '{database_name}', '{collection_name}', '{load_status}', '{starting_time}', '{ending_time}', {f"'{error_message}'" if error_message else "NULL"}) """)
end_time = datetime.now()
staging_table_count = spark.sql(f"SELECT COUNT(*) FROM {staging_table_name}").first()[0]
load_duration = str(end_time - start_time)
target_table_count = spark.sql(f"SELECT COUNT(*) FROM {target_table}").first()[0]
audit_schema = StructType([
    StructField("spark_app_id", StringType(), True),
    StructField("database_name", StringType(), True),
    StructField("collection_name", StringType(), True),
    StructField("start_time", StringType(), True),
    StructField("end_time", StringType(), True),
    StructField("load_status", StringType(), True),
    StructField("load_duration", StringType(), True),
    StructField("error_message", StringType(), True),
    StructField("source_count", IntegerType(), True),
    StructField("staging_table_count", IntegerType(), True),
    StructField("target_table_count", IntegerType(), True)
])
audit_df = spark.createDataFrame([(
    spark_app_id,
    database_name,
    collection_name,
    start_time.strftime('%Y-%m-%d %H:%M:%S'),
    end_time.strftime('%Y-%m-%d %H:%M:%S'),
    load_status,
    load_duration,
    error_message,
    source_count,
    staging_table_count,
    target_table_count
)], schema=audit_schema)
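# Append this run's audit record to the audit table.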
audit_df.write.mode("append").format("parquet").saveAsTable(audit_table_name)
load_status_table = args.loadstatustable
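# Invalidate Impala metadata so the freshly written Hive tables are visible to Impala clients.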
impala_stg_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {staging_table_name};'"""
impala_target_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {target_table};'"""
impala_audit_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {audit_table_name};'"""
impala_loadstatus_cmd = f"""impala-shell -i ab00-dimpalalb.arabbank.plc:21000 -d default -k --ssl \
--ca_cert=/var/lib/cloudera-scm-agent/agent-cert/cm-auto-global_cacerts.pem -q 'INVALIDATE METADATA {load_status_table};'"""
try:
    subprocess.run(impala_stg_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {staging_table_name}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (staging): {e.stderr}")

try:
    subprocess.run(impala_target_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {target_table}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (target): {e.stderr}")

try:
    subprocess.run(impala_loadstatus_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {load_status_table}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (load status): {e.stderr}")

try:
    subprocess.run(impala_audit_cmd, shell=True, check=True, capture_output=True, text=True)
    logging.info(f"Impala metadata invalidated for {audit_table_name}")
except subprocess.CalledProcessError as e:
    logging.warning(f"Impala error (audit): {e.stderr}")
spark.stop()
```