End to end production grade etl pyspark pipelines.
🚀 Building Production-Ready PySpark ETL Pipelines: From Zero to Hero with Real-World Retail…
End to end production grade etl pyspark pipelines.

a11y-light·June 28, 2025 (Updated: June 28, 2025)·Free: No
Hey there, data engineers! 👋 Ready to dive deep into building a bulletproof PySpark ETL pipeline that actually works in production? Today we’re going to tackle a real business problem that every e-commerce company faces — understanding customer behavior and sales performance through data.
🎯 The Business Problem We’re Solving
Imagine you’re working for “RetailMax” — a growing online retail company drowning in data but starving for insights. They have:
- 📊 Millions of transactions daily
- 🛒 Complex customer journeys across multiple touchpoints
- 📈 Need for real-time business metrics
- 🔄 Data scattered across different systems
The CEO walks into your office and says: “We need to understand our customer lifetime value, identify our top-perform ing products, and track key business metrics in real-time. Can you build something that scales?”
That’s where we come in! 💪
🏗️ Project Architecture Overview
Before we start coding, let’s understand what we’re building:
Copy📁 retail_etl_pipeline/
├── 📁 src/
│ ├── 📁 config/
│ ├── 📁 data_ingestion/
│ ├── 📁 data_processing/
│ ├── 📁 data_quality/
│ ├── 📁 business_metrics/
│ └── 📁 utils/
├── 📁 tests/
├── 📁 logs/
├── 📁 data/
│ ├── 📁 raw/
│ ├── 📁 processed/
│ └── 📁 curated/
└── 📁 docs/
🎲 Creat ing Our Sample Dataset
First, let’s generate realistic sample data that mimics what you’d see in a real retail environment:
Copy# data_generator.py
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import random
from faker import Faker
import uuid
fake = Faker()
def generate_customers(num_customers=10000):
"""Generate realistic customer data 👥"""
customers = []
for i in range(num_customers):
customer = {
'customer_id': str(uuid.uuid4()),
'first_name': fake.first_name(),
'last_name': fake.last_name(),
'email': fake.email(),
'phone': fake.phone_number(),
'registration_date': fake.date_between(start_date='-2y', end_date='today'),
'country': fake.country(),
'city': fake.city(),
'age_group': random.choice(['18-25', '26-35', '36-45', '46-55', '55+']),
'customer_segment': random.choice(['Premium', 'Standard', 'Basic'])
}
customers.append(customer)
return pd.DataFrame(customers)
def generate_products(num_products=1000):
"""Generate product catalog 🛍️"""
categories = ['Electronics', 'Clothing', 'Home & Garden', 'Sports', 'Books', 'Beauty']
products = []
for i in range(num_products):
category = random.choice(categories)
product = {
'product_id': f"PROD_{i:06d}",
'product_name': fake.catch_phrase(),
'category': category,
'subcategory': f"{category}_Sub_{random.randint(1,5)}",
'brand': fake.company(),
'price': round(random.uniform(10, 1000), 2),
'cost': round(random.uniform(5, 500), 2),
'launch_date': fake.date_between(start_date='-1y', end_date='today')
}
products.append(product)
return pd.DataFrame(products)
def generate_transactions(customers_df, products_df, num_transactions=100000):
"""Generate realistic transaction data 💳"""
transactions = []
for i in range(num_transactions):
customer_id = random.choice(customers_df['customer_id'].tolist())
product_id = random.choice(products_df['product_id'].tolist())
product_price = products_df[products_df['product_id'] == product_id]['price'].iloc[0]
transaction = {
'transaction_id': str(uuid.uuid4()),
'customer_id': customer_id,
'product_id': product_id,
'quantity': random.randint(1, 5),
'unit_price': product_price,
'discount_amount': round(random.uniform(0, product_price * 0.3), 2),
'transaction_date': fake.date_time_between(start_date='-1y', end_date='now'),
'payment_method': random.choice(['Credit Card', 'Debit Card', 'PayPal', 'Cash']),
'channel': random.choice(['Online', 'Mobile App', 'In-Store']),
'status': random.choice(['Completed', 'Pending', 'Cancelled'])
}
# Calculate total amount
transaction['total_amount'] = (transaction['quantity'] * transaction['unit_price']) - transaction['discount_amount']
transactions.append(transaction)
return pd.DataFrame(transactions)
# Generate and save sample data
if __name__ == "__main__":
print("🎲 Generat ing sample retail data...")
customers = generate_customers(10000)
products = generate_products(1000)
transactions = generate_transactions(customers, products, 100000)
# Save to CSV
customers.to_csv('data/raw/customers.csv', index=False)
products.to_csv('data/raw/products.csv', index=False)
transactions.to_csv('data/raw/transactions.csv', index=False)
print("✅ Sample data generated successfully!")
⚙️ Configuration Management
Let’s create a robust configuration system:
Copy# src/config/config.py
import os
from dataclasses import dataclass
from typing import Dict, Any
import yaml
@dataclass
class SparkConfig:
"""Spark configuration settings ⚡"""
app_name: str = "RetailETLPipeline"
master: str = "local[*]"
executor_memory: str = "4g"
driver_memory: str = "2g"
max_result_size: str = "1g"
serializer: str = "org.apache.spark.serializer.KryoSerializer"
@dataclass
class S3Config:
"""AWS S3 configuration 🪣"""
bucket_name: str = "retail-analytics-bucket"
raw_data_prefix: str = "raw/"
processed_data_prefix: str = "processed/"
curated_data_prefix: str = "curated/"
access_key: str = os.getenv("AWS_ACCESS_KEY_ID", "")
secret_key: str = os.getenv("AWS_SECRET_ACCESS_KEY", "")
region: str = "us-east-1"
@dataclass
class DataQualityConfig:
"""Data quality thresholds 🎯"""
null_threshold: float = 0.05 # 5% null values allowed
duplicate_threshold: float = 0.01 # 1% duplicates allowed
outlier_threshold: float = 3.0 # 3 standard deviations
class ConfigManager:
"""Central configuration manager 🎛️"""
def __init__(self, config_path: str = "src/config/pipeline_config.yaml"):
self.config_path = config_path
self.spark_config = SparkConfig()
self.s3_config = S3Config()
self.data_quality_config = DataQualityConfig()
self._load_config()
def _load_config(self):
"""Load configuration from YAML file"""
if os.path.exists(self.config_path):
with open(self.config_path, 'r') as file:
config_data = yaml.safe_load(file)
self._update_configs(config_data)
def _update_configs(self, config_data: Dict[str, Any]):
"""Update configuration objects with YAML data"""
if 'spark' in config_data:
for key, value in config_data['spark'].items():
if hasattr(self.spark_config, key):
setattr(self.spark_config, key, value)
if 's3' in config_data:
for key, value in config_data['s3'].items():
if hasattr(self.s3_config, key):
setattr(self.s3_config, key, value)
🔧 Utility Functions & Logging
Copy# src/utils/logger.py
import logging
import os
from datetime import datetime
from typing import Optional
class ETLLogger:
"""Production-grade logging system 📝"""
def __init__(self, name: str, log_level: str = "INFO"):
self.logger = logging.getLogger(name)
self.logger.setLevel(getattr(logging, log_level.upper()))
if not self.logger.handlers:
self._setup_handlers()
def _setup_handlers(self):
"""Setup file and console handlers"""
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
# File handler
log_filename = f"logs/etl_pipeline_{datetime.now().strftime('%Y%m%d')}.log"
file_handler = logging.FileHandler(log_filename)
file_handler.setLevel(logging.DEBUG)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
)
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
def info(self, message: str, **kwargs):
self.logger.info(message, extra=kwargs)
def error(self, message: str, **kwargs):
self.logger.error(message, extra=kwargs)
def warning(self, message: str, **kwargs):
self.logger.warning(message, extra=kwargs)
def debug(self, message: str, **kwargs):
self.logger.debug(message, extra=kwargs)
# src/utils/spark_utils.py
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from src.config.config import ConfigManager
from src.utils.logger import ETLLogger
class SparkSessionManager:
"""Spark session management with best practices ⚡"""
def __init__(self, config_manager: ConfigManager):
self.config = config_manager.spark_config
self.logger = ETLLogger(__name__)
self._session = None
def get_session(self) -> SparkSession:
"""Get or create Spark session"""
if self._session is None:
self.logger.info("🚀 Creating new Spark session...")
self._session = SparkSession.builder \
.appName(self.config.app_name) \
.master(self.config.master) \
.config("spark.executor.memory", self.config.executor_memory) \
.config("spark.driver.memory", self.config.driver_memory) \
.config("spark.driver.maxResultSize", self.config.max_result_size) \
.config("spark.serializer", self.config.serializer) \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
.getOrCreate()
# Set log level to reduce noise
self._session.sparkContext.setLogLevel("WARN")
self.logger.info("✅ Spark session created successfully")
return self._session
def stop_session(self):
"""Stop Spark session"""
if self._session:
self._session.stop()
self._session = None
self.logger.info("🛑 Spark session stopped")
def get_retail_schema():
"""Define schemas for our retail data 📋"""
customer_schema = StructType([
StructField("customer_id", StringType(), False),
StructField("first_name", StringType(), True),
StructField("last_name", StringType(), True),
StructField("email", StringType(), True),
StructField("phone", StringType(), True),
StructField("registration_date", DateType(), True),
StructField("country", StringType(), True),
StructField("city", StringType(), True),
StructField("age_group", StringType(), True),
StructField("customer_segment", StringType(), True)
])
product_schema = StructType([
StructField("product_id", StringType(), False),
StructField("product_name", StringType(), True),
StructField("category", StringType(), True),
StructField("subcategory", StringType(), True),
StructField("brand", StringType(), True),
StructField("price", DoubleType(), True),
StructField("cost", DoubleType(), True),
StructField("launch_date", DateType(), True)
])
transaction_schema = StructType([
StructField("transaction_id", StringType(), False),
StructField("customer_id", StringType(), False),
StructField("product_id", StringType(), False),
StructField("quantity", IntegerType(), True),
StructField("unit_price", DoubleType(), True),
StructField("discount_amount", DoubleType(), True),
StructField("total_amount", DoubleType(), True),
StructField("transaction_date", TimestampType(), True),
StructField("payment_method", StringType(), True),
StructField("channel", StringType(), True),
StructField("status", StringType(), True)
])
return {
"customers": customer_schema,
"products": product_schema,
"transactions": transaction_schema
}
📥 Data Ingestion Layer
Copy# src/data_ingestion/data_reader.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from src.utils.spark_utils import SparkSessionManager, get_retail_schema
from src.utils.logger import ETLLogger
from typing import Dict, Optional
import boto3
from botocore.exceptions import ClientError
class DataReader:
"""Production-grade data ingestion with error handling 📥"""
def __init__(self, spark_manager: SparkSessionManager):
self.spark = spark_manager.get_session()
self.logger = ETLLogger(__name__)
self.schemas = get_retail_schema()
def read_csv_with_schema(self, file_path: str, schema_name: str) -> Optional[DataFrame]:
"""Read CSV with predefined schema and error handling"""
try:
self.logger.info(f"📖 Reading CSV file: {file_path}")
if schema_name not in self.schemas:
raise ValueError(f"Schema '{schema_name}' not found")
df = self.spark.read \
.option("header", "true") \
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss") \
.option("dateFormat", "yyyy-MM-dd") \
.schema(self.schemas[schema_name]) \
.csv(file_path)
# Add ingestion metadata
df = df.withColumn("ingestion_timestamp", current_timestamp()) \
.withColumn("source_file", lit(file_path))
record_count = df.count()
self.logger.info(f"✅ Successfully read {record_count} records from {file_path}")
return df
except Exception as e:
self.logger.error(f"❌ Failed to read {file_path}: {str(e)}")
return None
def read_from_s3(self, s3_path: str, file_format: str = "parquet") -> Optional[DataFrame]:
"""Read data from S3 with retry logic"""
try:
self.logger.info(f"☁️ Reading from S3: {s3_path}")
if file_format.lower() == "parquet":
df = self.spark.read.parquet(s3_path)
elif file_format.lower() == "csv":
df = self.spark.read.option("header", "true").csv(s3_path)
elif file_format.lower() == "json":
df = self.spark.read.json(s3_path)
else:
raise ValueError(f"Unsupported file format: {file_format}")
self.logger.info(f"✅ Successfully read data from S3")
return df
except Exception as e:
self.logger.error(f"❌ Failed to read from S3: {str(e)}")
return None
def validate_data_freshness(self, df: DataFrame, date_column: str, max_age_hours: int = 24) -> bool:
"""Check if data is fresh enough for processing"""
try:
latest_date = df.agg(max(col(date_column)).alias("latest_date")).collect()[0]["latest_date"]
if latest_date:
hours_old = (datetime.now() - latest_date).total_seconds() / 3600
is_fresh = hours_old <= max_age_hours
self.logger.info(f"📅 Data freshness check: {hours_old:.1f} hours old (threshold: {max_age_hours}h)")
return is_fresh
return False
except Exception as e:
self.logger.error(f"❌ Data freshness validation failed: {str(e)}")
return False
🧹 Data Quality & Validation
Copy# src/data_quality/quality_checker.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import *
from src.utils.logger import ETLLogger
from src.config.config import DataQualityConfig
from typing import Dict, List, Tuple
import json
class DataQualityChecker:
"""Comprehensive data quality validation 🎯"""
def __init__(self, config: DataQualityConfig):
self.config = config
self.logger = ETLLogger(__name__)
self.quality_report = {}
def run_quality_checks(self, df: DataFrame, dataset_name: str) -> Tuple[bool, Dict]:
"""Run comprehensive data quality checks"""
self.logger.info(f"🔍 Starting data quality checks for {dataset_name}")
checks = [
self._check_null_values,
self._check_duplicates,
self._check_data_types,
self._check_outliers,
self._check_referential_integrity
]
all_passed = True
report = {"dataset": dataset_name, "checks": {}}
for check in checks:
try:
check_name = check.__name__.replace("_check_", "")
passed, details = check(df)
report["checks"][check_name] = {
"passed": passed,
"details": details
}
if not passed:
all_passed = False
self.logger.warning(f"⚠️ Quality check failed: {check_name}")
else:
self.logger.info(f"✅ Quality check passed: {check_name}")
except Exception as e:
self.logger.error(f"❌ Quality check error in {check.__name__}: {str(e)}")
all_passed = False
self.quality_report[dataset_name] = report
return all_passed, report
def _check_null_values(self, df: DataFrame) -> Tuple[bool, Dict]:
"""Check for excessive null values"""
total_rows = df.count()
null_counts = {}
for column in df.columns:
null_count = df.filter(col(column).isNull()).count()
null_percentage = (null_count / total_rows) * 100
null_counts[column] = {
"null_count": null_count,
"null_percentage": round(null_percentage, 2)
}
# Check if any column exceeds threshold
failed_columns = [
col_name for col_name, stats in null_counts.items()
if stats["null_percentage"] > (self.config.null_threshold * 100)
]
return len(failed_columns) == 0, {
"null_counts": null_counts,
"failed_columns": failed_columns,
"threshold_percentage": self.config.null_threshold * 100
}
def _check_duplicates(self, df: DataFrame) -> Tuple[bool, Dict]:
"""Check for duplicate records"""
total_rows = df.count()
unique_rows = df.distinct().count()
duplicate_count = total_rows - unique_rows
duplicate_percentage = (duplicate_count / total_rows) * 100
passed = duplicate_percentage <= (self.config.duplicate_threshold * 100)
return passed, {
"total_rows": total_rows,
"unique_rows": unique_rows,
"duplicate_count": duplicate_count,
"duplicate_percentage": round(duplicate_percentage, 2),
"threshold_percentage": self.config.duplicate_threshold * 100
}
def _check_data_types(self, df: DataFrame) -> Tuple[bool, Dict]:
"""Validate data types and formats"""
type_issues = []
for column, data_type in df.dtypes:
if data_type in ["string", "varchar"]:
# Check for potential numeric columns stored as strings
numeric_pattern = r'^\d+\.?\d*$'
non_numeric_count = df.filter(
col(column).isNotNull() &
~col(column).rlike(numeric_pattern)
).count()
if non_numeric_count == 0 and df.filter(col(column).isNotNull()).count() > 0:
type_issues.append(f"{column} might be numeric but stored as string")
return len(type_issues) == 0, {"issues": type_issues}
def _check_outliers(self, df: DataFrame) -> Tuple[bool, Dict]:
"""Detect outliers in numeric columns"""
numeric_columns = [col_name for col_name, dtype in df.dtypes if dtype in ["int", "bigint", "float", "double"]]
outlier_report = {}
for column in numeric_columns:
stats = df.select(
mean(col(column)).alias("mean"),
stddev(col(column)).alias("stddev")
).collect()[0]
if stats["stddev"] and stats["stddev"] > 0:
lower_bound = stats["mean"] - (self.config.outlier_threshold * stats["stddev"])
upper_bound = stats["mean"] + (self.config.outlier_threshold * stats["stddev"])
outlier_count = df.filter(
(col(column) < lower_bound) | (col(column) > upper_bound)
).count()
total_count = df.filter(col(column).isNotNull()).count()
outlier_percentage = (outlier_count / total_count) * 100 if total_count > 0 else 0
outlier_report[column] = {
"outlier_count": outlier_count,
"outlier_percentage": round(outlier_percentage, 2),
"lower_bound": round(lower_bound, 2),
"upper_bound": round(upper_bound, 2)
}
return True, outlier_report # Outliers don't fail the check, just report them
def _check_referential_integrity(self, df: DataFrame) -> Tuple[bool, Dict]:
"""Basic referential integrity checks"""
# This would be expanded based on specific business rules
integrity_issues = []
# Example: Check if customer_id exists in transactions
if "customer_id" in df.columns and "transaction_id" in df.columns:
null_customer_ids = df.filter(col("customer_id").isNull()).count()
if null_customer_ids > 0:
integrity_issues.append(f"Found {null_customer_ids} transactions with null customer_id")
return len(integrity_issues) == 0, {"issues": integrity_issues}
def save_quality_report(self, output_path: str):
"""Save quality report to file"""
try:
with open(output_path, 'w') as f:
json.dump(self.quality_report, f, indent=2, default=str)
self.logger.info(f"📊 Quality report saved to {output_path}")
except Exception as e:
self.logger.error(f"❌ Failed to save quality report: {str(e)}")
This is just the beginning! 🚀 We’ve covered the foundation with data generation, configuration management, logging, data ingestion, and quality checks.
🔄 Data Processing & Transformation Engine
Now comes the heart of our ETL pipeline — the transformation layer! This is where raw data becomes business intelligence. Think of it as a sophisticated kitchen where we take raw ingredients (data) and create gourmet meals (insights). 👨🍳
Why Data Transformation Matters 🤔
Before diving into code, let’s understand why this layer is crucial:
- Business Logic Implementation: Convert ing raw transactions into meaningful metrics
- Data Enrichment: Adding calculated fields that provide business context
- Performance Optimization: Structuring data for fast querying
- Standardization: Ensuring consistent data formats across the organization
Copy# src/data_processing/transformations.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.sql.types import *
from src.utils.logger import ETLLogger
from typing import Dict, List, Optional
import math
class RetailDataTransformer:
""" 🏭 The transformation factory - where raw data becomes business gold!
This class handles all the complex business logic transformations that convert raw retail data into actionable insights.
Think of it as the brain of our ETL pipeline.
"""
def __init__(self, spark_session):
self.spark = spark_session
self.logger = ETLLogger(__name__)
# 📊 These window specifications are reusable patterns for analytics
# They define how we partition and order data for calculations
self.customer_window = Window.partitionBy("customer_id").orderBy("transaction_date")
self.product_window = Window.partitionBy("product_id").orderBy("transaction_date")
self.monthly_window = Window.partitionBy("year", "month").orderBy("transaction_date")
def create_enriched_transactions(self, transactions_df: DataFrame, customers_df: DataFrame, products_df: DataFrame) -> DataFrame:
""" 🎯 The Master Join Operation
This is where the magic happens! We're creating a 360-degree view of each transaction by combining data from multiple sources.
It's like assembling a puzzle where each piece adds crucial context to understand customer behavior.
Why this matters:
- Enables cross-dimensional analysis (customer + product + time)
- Creates the foundation for all downstream analytics
- Reduces the need for multiple joins in reporting queries
"""
self.logger.info("🔗 Creating enriched transaction dataset...")
# Step 1: Start with transactions as the base (fact table)
enriched_df = transactions_df.filter(col("status") == "Completed")
# Step 2: Add customer context - WHO is buying?
# Left join ensures we keep all transactions even if customer data is missing
enriched_df = enriched_df.join(
customers_df.select(
"customer_id",
"customer_segment",
"country",
"age_group",
"registration_date"
),
"customer_id",
"left"
)
# Step 3: Add product context - WHAT are they buying?
enriched_df = enriched_df.join(
products_df.select(
"product_id",
"category",
"subcategory",
"brand",
"cost"
),
"product_id",
"left"
)
# Step 4: Add time-based dimensions for temporal analysis
enriched_df = enriched_df.withColumn("year", year(col("transaction_date"))) \
.withColumn("month", month(col("transaction_date"))) \
.withColumn("quarter", quarter(col("transaction_date"))) \
.withColumn("day_of_week", dayofweek(col("transaction_date"))) \
.withColumn("hour", hour(col("transaction_date")))
# Step 5: Calculate business metrics at transaction level
enriched_df = enriched_df.withColumn(
"profit_amount",
col("total_amount") - (col("quantity") * col("cost"))
).withColumn(
"profit_margin",
when(col("total_amount") > 0,
(col("profit_amount") / col("total_amount")) * 100
).otherwise(0)
)
# Step 6: Add customer tenure (how long they've been with us)
enriched_df = enriched_df.withColumn(
"customer_tenure_days",
datediff(col("transaction_date"), col("registration_date"))
)
record_count = enriched_df.count()
self.logger.info(f"✅ Created enriched dataset with {record_count:,} transactions")
return enriched_df
def calculate_customer_lifetime_metrics(self, enriched_df: DataFrame) -> DataFrame:
""" 💎 Customer Lifetime Value (CLV) Calculation
This is one of the most important metrics in retail! CLV helps us understand:
- Which customers are most valuable
- How much we can spend on customer acquisition
- Which customer segments to focus on
We're using a sophisticated approach that considers:
- Recency: How recently did they purchase?
- Frequency: How often do they purchase?
- Monetary: How much do they spend?
"""
self.logger.info("💎 Calculating customer lifetime metrics...")
# Calculate RFM metrics (Recency, Frequency, Monetary)
current_date = self.spark.sql("SELECT current_date() as today").collect()[0]["today"]
customer_metrics = enriched_df.groupBy("customer_id").agg(
# Monetary metrics
sum("total_amount").alias("total_spent"),
avg("total_amount").alias("avg_order_value"),
sum("profit_amount").alias("total_profit_generated"),
# Frequency metrics
count("transaction_id").alias("total_orders"),
countDistinct("product_id").alias("unique_products_purchased"),
countDistinct("category").alias("categories_explored"),
# Recency metrics
max("transaction_date").alias("last_purchase_date"),
min("transaction_date").alias("first_purchase_date"),
# Customer context
first("customer_segment").alias("customer_segment"),
first("country").alias("country"),
first("age_group").alias("age_group")
)
# Calculate derived metrics
customer_metrics = customer_metrics.withColumn(
"days_since_last_purchase",
datediff(lit(current_date), col("last_purchase_date"))
).withColumn(
"customer_lifespan_days",
datediff(col("last_purchase_date"), col("first_purchase_date"))
).withColumn(
"purchase_frequency",
when(col("customer_lifespan_days") > 0,
col("total_orders") / (col("customer_lifespan_days") / 30.0)
).otherwise(col("total_orders"))
)
# Calculate CLV using a simplified formula
# CLV = Average Order Value × Purchase Frequency × Customer Lifespan
customer_metrics = customer_metrics.withColumn(
"estimated_clv",
col("avg_order_value") * col("purchase_frequency") * (col("customer_lifespan_days") / 365.0)
)
# Create customer segments based on CLV
clv_percentiles = customer_metrics.select(
expr("percentile_approx(estimated_clv, 0.8)").alias("p80"),
expr("percentile_approx(estimated_clv, 0.6)").alias("p60"),
expr("percentile_approx(estimated_clv, 0.4)").alias("p40")
).collect()[0]
customer_metrics = customer_metrics.withColumn(
"clv_segment",
when(col("estimated_clv") >= clv_percentiles["p80"], "Champions")
.when(col("estimated_clv") >= clv_percentiles["p60"], "Loyal Customers")
.when(col("estimated_clv") >= clv_percentiles["p40"], "Potential Loyalists")
.otherwise("New Customers")
)
self.logger.info("✅ Customer lifetime metrics calculated successfully")
return customer_metrics
def create_product_performance_metrics(self, enriched_df: DataFrame) -> DataFrame:
""" 📈 Product Performance Analytics
Understanding which products drive your business is crucial for:
- Inventory management
- Marketing focus
- Pricing strategies
- Product development decisions
This function creates comprehensive product analytics that answer:
- Which products are bestsellers?
- What's the profit margin by product?
- How do products perform across different customer segments?
"""
self.logger.info("📈 Calculating product performance metrics...")
# Basic product metrics
product_metrics = enriched_df.groupBy(
"product_id", "category", "subcategory", "brand"
).agg(
# Sales metrics
sum("quantity").alias("total_units_sold"),
sum("total_amount").alias("total_revenue"),
sum("profit_amount").alias("total_profit"),
count("transaction_id").alias("total_transactions"),
# Customer metrics
countDistinct("customer_id").alias("unique_customers"),
# Performance metrics
avg("total_amount").alias("avg_transaction_value"),
avg("profit_margin").alias("avg_profit_margin")
)
# Calculate advanced metrics
product_metrics = product_metrics.withColumn(
"revenue_per_customer",
col("total_revenue") / col("unique_customers")
).withColumn(
"units_per_transaction",
col("total_units_sold") / col("total_transactions")
)
# Add ranking within categories
category_window = Window.partitionBy("category").orderBy(col("total_revenue").desc())
product_metrics = product_metrics.withColumn(
"category_revenue_rank",
row_number().over(category_window)
).withColumn(
"is_top_performer",
when(col("category_revenue_rank") <= 5, True).otherwise(False)
)
self.logger.info("✅ Product performance metrics calculated")
return product_metrics
def create_time_series_analytics(self, enriched_df: DataFrame) -> DataFrame:
""" 📅 Time Series Analytics for Trend Analysis
Time-based analysis is crucial for:
- Identifying seasonal patterns
- Forecasting future sales
- Understanding business cycles
- Planning inventory and marketing campaigns
This creates daily, weekly, and monthly aggregations with trend indicators.
"""
self.logger.info("📅 Creating time series analytics...")
# Daily aggregations
daily_metrics = enriched_df.groupBy("year", "month", "day_of_week").agg(
sum("total_amount").alias("daily_revenue"),
sum("profit_amount").alias("daily_profit"),
count("transaction_id").alias("daily_transactions"),
countDistinct("customer_id").alias("daily_unique_customers"),
avg("total_amount").alias("daily_avg_order_value")
)
# Add day-over-day growth calculations
daily_window = Window.partitionBy("day_of_week").orderBy("year", "month")
daily_metrics = daily_metrics.withColumn(
"revenue_growth_rate",
((col("daily_revenue") - lag("daily_revenue", 1).over(daily_window)) / lag("daily_revenue", 1).over(daily_window)) * 100
)
# Monthly aggregations for executive reporting
monthly_metrics = enriched_df.groupBy("year", "month").agg(
sum("total_amount").alias("monthly_revenue"),
sum("profit_amount").alias("monthly_profit"),
count("transaction_id").alias("monthly_transactions"),
countDistinct("customer_id").alias("monthly_active_customers"),
countDistinct("product_id").alias("monthly_products_sold")
)
# Calculate month-over-month growth
monthly_window = Window.orderBy("year", "month")
monthly_metrics = monthly_metrics.withColumn(
"mom_revenue_growth",
((col("monthly_revenue") - lag("monthly_revenue", 1).over(monthly_window)) / lag("monthly_revenue", 1).over(monthly_window)) * 100
).withColumn(
"mom_customer_growth",
((col("monthly_active_customers") - lag("monthly_active_customers", 1).over(monthly_window)) / lag("monthly_active_customers", 1).over(monthly_window)) * 100
)
self.logger.info("✅ Time series analytics completed")
return daily_metrics, monthly_metrics
📊 Advanced Business Metrics Engine
This is where we get really sophisticated! 🧠 These metrics are what separate good analytics from great analytics. They answer the strategic questions that keep executives awake at night.
Copy# src/business_metrics/advanced_analytics.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from src.utils.logger import ETLLogger
from typing import Tuple, Dict
import numpy as np
class AdvancedBusinessMetrics:
""" 🎯 Advanced Analytics Engine
This class implements sophisticated business metrics that provide deep insights:
- Customer segmentation using RFM analysis
- Cohort analysis for retention tracking
- Market basket analysis for cross-selling
- Churn prediction indicators
These metrics help answer questions like:
- Which customers are at risk of churning?
- What products are frequently bought together?
- How do customer cohorts perform over time?
"""
def __init__(self, spark_session):
self.spark = spark_session
self.logger = ETLLogger(__name__)
def perform_rfm_analysis(self, enriched_df: DataFrame) -> DataFrame:
""" 🎯 RFM Analysis - The Gold Standard of Customer Segmentation
RFM stands for Recency, Frequency, Monetary - three key dimensions that predict customer behavior better than demographics alone!
Why RFM works:
- Recency: Recent customers are more likely to respond to offers
- Frequency: Frequent customers show loyalty and engagement
- Monetary: High-value customers drive profitability
This creates actionable customer segments for targeted marketing.
"""
self.logger.info("🎯 Performing RFM analysis...")
# Calculate the reference date (latest transaction date)
max_date = enriched_df.agg(max("transaction_date")).collect()[0][0]
# Calculate RFM metrics for each customer
rfm_data = enriched_df.groupBy("customer_id").agg(
# Recency: Days since last purchase
datediff(lit(max_date), max("transaction_date")).alias("recency"),
# Frequency: Number of transactions
count("transaction_id").alias("frequency"),
# Monetary: Total amount spent
sum("total_amount").alias("monetary")
)
# Calculate quintiles for each RFM dimension
# Quintiles divide customers into 5 equal groups (1=worst, 5=best)
rfm_quintiles = rfm_data.select(
expr("percentile_approx(recency, array(0.2, 0.4, 0.6, 0.8))").alias("recency_quintiles"),
expr("percentile_approx(frequency, array(0.2, 0.4, 0.6, 0.8))").alias("frequency_quintiles"),
expr("percentile_approx(monetary, array(0.2, 0.4, 0.6, 0.8))").alias("monetary_quintiles")
).collect()[0]
# Assign RFM scores (1-5 scale)
# Note: For recency, lower values are better (more recent), so we reverse the scoring
rfm_scored = rfm_data.withColumn(
"r_score",
when(col("recency") <= rfm_quintiles["recency_quintiles"][0], 5)
.when(col("recency") <= rfm_quintiles["recency_quintiles"][1], 4)
.when(col("recency") <= rfm_quintiles["recency_quintiles"][2], 3)
.when(col("recency") <= rfm_quintiles["recency_quintiles"][3], 2)
.otherwise(1)
).withColumn(
"f_score",
when(col("frequency") >= rfm_quintiles["frequency_quintiles"][3], 5)
.when(col("frequency") >= rfm_quintiles["frequency_quintiles"][2], 4)
.when(col("frequency") >= rfm_quintiles["frequency_quintiles"][1], 3)
.when(col("frequency") >= rfm_quintiles["frequency_quintiles"][0], 2)
.otherwise(1)
).withColumn(
"m_score",
when(col("monetary") >= rfm_quintiles["monetary_quintiles"][3], 5)
.when(col("monetary") >= rfm_quintiles["monetary_quintiles"][2], 4)
.when(col("monetary") >= rfm_quintiles["monetary_quintiles"][1], 3)
.when(col("monetary") >= rfm_quintiles["monetary_quintiles"][0], 2)
.otherwise(1)
)
# Create RFM segments based on scores
rfm_segments = rfm_scored.withColumn(
"rfm_segment",
when((col("r_score") >= 4) & (col("f_score") >= 4) & (col("m_score") >= 4), "Champions")
.when((col("r_score") >= 3) & (col("f_score") >= 3) & (col("m_score") >= 3), "Loyal Customers")
.when((col("r_score") >= 4) & (col("f_score") <= 2), "New Customers")
.when((col("r_score") >= 3) & (col("f_score") <= 2) & (col("m_score") >= 3), "Potential Loyalists")
.when((col("r_score") <= 2) & (col("f_score") >= 3) & (col("m_score") >= 3), "At Risk")
.when((col("r_score") <= 2) & (col("f_score") <= 2) & (col("m_score") >= 3), "Cannot Lose Them")
.when((col("r_score") <= 2) & (col("f_score") <= 2) & (col("m_score") <= 2), "Hibernating")
.otherwise("Others")
).withColumn(
"rfm_score",
concat(col("r_score"), col("f_score"), col("m_score"))
)
# Add business recommendations for each segment
rfm_segments = rfm_segments.withColumn(
"recommended_action",
when(col("rfm_segment") == "Champions", "Reward them. They can become advocates.")
.when(col("rfm_segment") == "Loyal Customers", "Upsell higher value products.")
.when(col("rfm_segment") == "New Customers", "Provide onboarding support.")
.when(col("rfm_segment") == "At Risk", "Send personalized reactivation campaigns.")
.when(col("rfm_segment") == "Cannot Lose Them", "Win them back via renewals or newer products.")
.otherwise("Re-engage with special offers.")
)
self.logger.info("✅ RFM analysis completed")
return rfm_segments
def calculate_cohort_analysis(self, enriched_df: DataFrame) -> DataFrame:
""" 📈 Cohort Analysis - Understanding Customer Retention Over Time
Cohort analysis tracks groups of customers over time to understand:
- How customer behavior changes after acquisition
- Which acquisition channels produce the best long-term customers
- What the natural customer lifecycle looks like
This is essential for:
- Calculating accurate customer lifetime value
- Optimizing marketing spend
- Identifying retention issues early
"""
self.logger.info("📈 Performing cohort analysis...")
# Step 1: Identify each customer's first purchase (cohort assignment)
customer_cohorts = enriched_df.groupBy("customer_id").agg(
min("transaction_date").alias("first_purchase_date")
).withColumn(
"cohort_month",
date_format(col("first_purchase_date"), "yyyy-MM")
)
# Step 2: Join back to get cohort info for all transactions
cohort_data = enriched_df.join(customer_cohorts, "customer_id")
# Step 3: Calculate period number (months since first purchase)
cohort_data = cohort_data.withColumn(
"transaction_month",
date_format(col("transaction_date"), "yyyy-MM")
).withColumn(
"period_number",
months_between(col("transaction_date"), col("first_purchase_date"))
)
# Step 4: Calculate cohort metrics
cohort_metrics = cohort_data.groupBy("cohort_month", "period_number").agg(
countDistinct("customer_id").alias("customers_in_period"),
sum("total_amount").alias("revenue_in_period"),
avg("total_amount").alias("avg_order_value_in_period")
)
# Step 5: Calculate cohort sizes (customers in month 0)
cohort_sizes = cohort_metrics.filter(col("period_number") == 0) \
.select("cohort_month",
col("customers_in_period").alias("cohort_size"))
# Step 6: Calculate retention rates
cohort_retention = cohort_metrics.join(cohort_sizes, "cohort_month") \
.withColumn(
"retention_rate",
(col("customers_in_period") / col("cohort_size")) * 100
)
self.logger.info("✅ Cohort analysis completed")
return cohort_retention
def market_basket_analysis(self, enriched_df: DataFrame, min_support: float = 0.01) -> DataFrame:
""" 🛒 Market Basket Analysis - What Products Go Together?
This analysis discovers which products are frequently bought together, enabling:
- Cross-selling recommendations ("Customers who bought X also bought Y")
- Store layout optimization
- Bundle pricing strategies
- Inventory planning
We use the Apriori algorithm concept to find frequent itemsets.
"""
self.logger.info("🛒 Performing market basket analysis...")
# Step 1: Create transaction baskets (products bought together)
# Group by customer and date to identify shopping sessions
transaction_baskets = enriched_df.groupBy("customer_id", "transaction_date").agg(
collect_list("product_id").alias("products_in_basket"),
sum("total_amount").alias("basket_value")
).withColumn(
"basket_size",
size(col("products_in_basket"))
)
# Step 2: Find frequent product pairs
# This is a simplified version - in production, you'd use MLlib's FPGrowth
product_pairs = transaction_baskets.filter(col("basket_size") >= 2) \
.select("customer_id", "transaction_date",
explode(col("products_in_basket")).alias("product_1")) \
.join(
transaction_baskets.filter(col("basket_size") >= 2) \
.select("customer_id", "transaction_date",
explode(col("products_in_basket")).alias("product_2")),
["customer_id", "transaction_date"]
).filter(col("product_1") < col("product_2")) # Avoid duplicates
# Step 3: Calculate support and confidence
total_transactions = transaction_baskets.count()
pair_counts = product_pairs.groupBy("product_1", "product_2").agg(
count("*").alias("pair_count")
).withColumn(
"support",
col("pair_count") / total_transactions
).filter(col("support") >= min_support)
# Step 4: Add product information for interpretability
from src.data_processing.transformations import RetailDataTransformer
# Get product details
product_info = enriched_df.select("product_id", "category", "brand").distinct()
market_basket_results = pair_counts.join(
product_info.withColumnRenamed("product_id", "product_1")
.withColumnRenamed("category", "category_1")
.withColumnRenamed("brand", "brand_1"),
"product_1"
).join(
product_info.withColumnRenamed("product_id", "product_2")
.withColumnRenamed("category", "category_2")
.withColumnRenamed("brand", "brand_2"),
"product_2"
).orderBy(col("support").desc())
self.logger.info("✅ Market basket analysis completed")
return market_basket_results
def calculate_churn_indicators(self, customer_metrics_df: DataFrame) -> DataFrame:
""" ⚠️ Churn Risk Prediction Indicators
Identifying customers at risk of churning is crucial for retention efforts. This function creates early warning indicators based on:
- Purchase recency patterns
- Frequency changes
- Spending behavior shifts
These indicators help prioritize retention campaigns and interventions.
"""
self.logger.info("⚠️ Calculating churn risk indicators...")
# Define churn risk factors
churn_indicators = customer_metrics_df.withColumn(
"days_since_last_purchase_risk",
when(col("days_since_last_purchase") > 90, "High")
.when(col("days_since_last_purchase") > 60, "Medium")
.when(col("days_since_last_purchase") > 30, "Low")
.otherwise("Very Low")
).withColumn(
"frequency_risk",
when(col("total_orders") == 1, "High")
.when(col("total_orders") <= 3, "Medium")
.otherwise("Low")
).withColumn(
"engagement_risk",
when(col("categories_explored") == 1, "High")
.when(col("categories_explored") <= 2, "Medium")
.otherwise("Low")
)
# Calculate overall churn risk score
churn_indicators = churn_indicators.withColumn(
"churn_risk_score",
when(col("days_since_last_purchase_risk") == "High", 3).otherwise(0) +
when(col("frequency_risk") == "High", 2).otherwise(0) +
when(col("engagement_risk") == "High", 1).otherwise(0)
).withColumn(
"churn_risk_level",
when(col("churn_risk_score") >= 4, "Critical")
.when(col("churn_risk_score") >= 2, "High")
.when(col("churn_risk_score") >= 1, "Medium")
.otherwise("Low")
)
# Add recommended interventions
churn_indicators = churn_indicators.withColumn(
"recommended_intervention",
when(col("churn_risk_level") == "Critical", "Immediate personal outreach + special offer")
.when(col("churn_risk_level") == "High", "Targeted email campaign + discount")
.when(col("churn_risk_level") == "Medium", "Newsletter + product recommendations")
.otherwise("Standard marketing communications")
)
self.logger.info("✅ Churn risk indicators calculated")
return churn_indicators
☁️ S3 Integration & Data Lake Management
Now let’s build a robust S3 integration that handles the complexities of cloud storage! 🌩️ This isn’t just about uploading files — it’s about creat ing a scalable, reliable data lake architecture.
Copy# src/data_storage/s3_manager.py
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
from pyspark.sql import DataFrame
from src.utils.logger import ETLLogger
from src.config.config import S3Config
from typing import Optional, List, Dict
import json
from datetime import datetime
import time
class S3DataManager:
""" ☁️ Production-Grade S3 Data Lake Manager
This class handles all S3 operations with enterprise-grade features:
- Automatic retry logic for transient failures
- Partitioning strategies for optimal query performance
- Data lifecycle management
- Cost optimization through intelligent storage classes
- Comprehensive error handling and logging
Why S3 for Data Lakes?
- Virtually unlimited scalability
- Cost-effective storage with multiple tiers
- Integration with analytics tools (Athena, Redshift, etc.)
- Built-in durability and availability
"""
def __init__(self, config: S3Config):
self.config = config
self.logger = ETLLogger(__name__)
self.s3_client = None
self._initialize_s3_client()
def _initialize_s3_client(self):
""" 🔐 Initialize S3 client with proper error handling
This handles various authentication scenarios:
- IAM roles (recommended for production)
- Access keys (for development/testing)
- Cross-account access
"""
try:
if self.config.access_key and self.config.secret_key:
# Explicit credentials (development/testing)
self.s3_client = boto3.client(
's3',
aws_access_key_id=self.config.access_key,
aws_secret_access_key=self.config.secret_key,
region_name=self.config.region
)
self.logger.info("🔑 S3 client initialized with explicit credentials")
else:
# Use IAM role or default credential chain (production)
self.s3_client = boto3.client('s3', region_name=self.config.region)
self.logger.info("🔑 S3 client initialized with default credential chain")
# Test connection
self.s3_client.head_bucket(Bucket=self.config.bucket_name)
self.logger.info(f"✅ Successfully connected to S3 bucket: {self.config.bucket_name}")
except NoCredentialsError:
self.logger.error("❌ AWS credentials not found")
raise
except ClientError as e:
error_code = e.response['Error']['Code']
if error_code == '404':
self.logger.error(f"❌ S3 bucket not found: {self.config.bucket_name}")
else:
self.logger.error(f"❌ S3 connection failed: {str(e)}")
raise
def write_dataframe_to_s3(self, df: DataFrame, s3_path: str,
file_format: str = "parquet", partition_cols: Optional[List[str]] = None,
mode: str = "overwrite") -> bool:
""" 📤 Write DataFrame to S3 with optimizations
This method implements several production best practices:
- Partitioning for query performance
- Compression for storage efficiency
- Atomic writes to prevent partial failures
- Metadata tracking for data lineage
Partitioning Strategy:
- Time-based partitioning (year/month/day) for time-series data
- Category-based partitioning for dimensional data
- Balanced partition sizes (not too many small files)
"""
try:
self.logger.info(f"📤 Writing DataFrame to S3: {s3_path}")
# Add metadata columns for data lineage
df_with_metadata = df.withColumn("etl_timestamp", current_timestamp()) \
.withColumn("etl_job_id", lit(f"job_{int(time.time())}"))
# Configure write operation based on format
writer = df_with_metadata.write.mode(mode)
if file_format.lower() == "parquet":
# Parquet is optimal for analytics workloads
writer = writer.option("compression", "snappy") # Good balance of speed/compression
if partition_cols:
writer = writer.partitionBy(*partition_cols)
writer.parquet(s3_path)
elif file_format.lower() == "delta":
# Delta Lake for ACID transactions (if available)
if partition_cols:
writer = writer.partitionBy(*partition_cols)
writer.format("delta").save(s3_path)
elif file_format.lower() == "csv":
# CSV for compatibility (not recommended for large datasets)
writer.option("header", "true").csv(s3_path)
else:
raise ValueError(f"Unsupported file format: {file_format}")
# Verify write success
if self._verify_s3_write(s3_path):
self.logger.info(f"✅ Successfully wrote data to {s3_path}")
self._log_dataset_metadata(s3_path, df_with_metadata, file_format, partition_cols)
return True
else:
self.logger.error(f"❌ Write verification failed for {s3_path}")
return False
except Exception as e:
self.logger.error(f"❌ Failed to write DataFrame to S3: {str(e)}")
return False
def _verify_s3_write(self, s3_path: str) -> bool:
""" ✅ Verify that data was successfully written to S3
This is crucial for data integrity - we need to ensure that our write operations actually succeeded before marking the job as complete.
"""
try:
# Extract bucket and key from S3 path
path_parts = s3_path.replace("s3://", "").replace("s3a://", "").split("/", 1)
bucket = path_parts[0]
prefix = path_parts[1] if len(path_parts) > 1 else ""
# List objects to verify data exists
response = self.s3_client.list_objects_v2(
Bucket=bucket,
Prefix=prefix,
MaxKeys=1
)
return response.get('KeyCount', 0) > 0
except Exception as e:
self.logger.error(f"❌ S3 write verification failed: {str(e)}")
return False
def _log_dataset_metadata(self, s3_path: str, df: DataFrame,
file_format: str, partition_cols: Optional[List[str]]):
""" 📋 Log dataset metadata for data catalog and lineage tracking
This creates a metadata record that helps with:
- Data discovery and cataloging
- Impact analysis for schema changes
- Compliance and audit requirements
- Performance optimization
"""
try:
metadata = {
"dataset_path": s3_path,
"record_count": df.count(),
"column_count": len(df.columns),
"schema": df.schema.json(),
"file_format": file_format,
"partition_columns": partition_cols or [],
"created_timestamp": datetime.now().isoformat(),
"size_estimate_mb": self._estimate_dataset_size(df)
}
# Save metadata to S3 for data catalog
metadata_path = f"{s3_path}/_metadata/dataset_info.json"
self._upload_json_to_s3(metadata, metadata_path)
self.logger.info(f"📋 Dataset metadata logged: {metadata['record_count']:,} records")
except Exception as e:
self.logger.warning(f"⚠️ Failed to log dataset metadata: {str(e)}")
def _estimate_dataset_size(self, df: DataFrame) -> float:
"""Estimate dataset size in MB"""
try:
# Simple estimation based on row count and column types
row_count = df.count()
avg_row_size = 0
for col_name, col_type in df.dtypes:
if col_type in ["string", "varchar"]:
avg_row_size += 50 # Estimate 50 bytes per string column
elif col_type in ["int", "integer"]:
avg_row_size += 4
elif col_type in ["bigint", "long", "double"]:
avg_row_size += 8
elif col_type in ["timestamp", "date"]:
avg_row_size += 8
else:
avg_row_size += 10 # Default estimate
total_size_bytes = row_count * avg_row_size
return round(total_size_bytes / (1024 * 1024), 2) # Convert to MB
except:
return 0.0
def _upload_json_to_s3(self, data: Dict, s3_path: str):
"""Upload JSON data to S3"""
try:
path_parts = s3_path.replace("s3://", "").replace("s3a://", "").split("/", 1)
bucket = path_parts[0]
key = path_parts[1]
self.s3_client.put_object(
Bucket=bucket,
Key=key,
Body=json.dumps(data, indent=2),
ContentType='application/json'
)
except Exception as e:
self.logger.error(f"❌ Failed to upload JSON to S3: {str(e)}")
def create_data_lake_structure(self):
""" 🏗️ Create organized data lake folder structure
A well-organized data lake is crucial for:
- Data governance and discoverability
- Performance optimization
- Cost management
- Compliance and security
Standard structure:
- raw/: Unprocessed data as received
- processed/: Cleaned and validated data
- curated/: Business-ready analytics datasets
- archive/: Historical data for compliance
"""
folders = [
f"{self.config.raw_data_prefix}customers/",
f"{self.config.raw_data_prefix}products/",
f"{self.config.raw_data_prefix}transactions/",
f"{self.config.processed_data_prefix}enriched_transactions/",
f"{self.config.processed_data_prefix}customer_metrics/",
f"{self.config.processed_data_prefix}product_metrics/",
f"{self.config.curated_data_prefix}business_metrics/",
f"{self.config.curated_data_prefix}rfm_analysis/",
f"{self.config.curated_data_prefix}cohort_analysis/",
"archive/",
"metadata/",
"logs/"
]
for folder in folders:
try:
# Create empty object to establish folder structure
self.s3_client.put_object(
Bucket=self.config.bucket_name,
Key=f"{folder}.keep",
Body=""
)
self.logger.info(f"📁 Created folder: {folder}")
except Exception as e:
self.logger.warning(f"⚠️ Failed to create folder {folder}: {str(e)}")
self.logger.info("🏗️ Data lake structure created successfully")
This is gett ing excit ing! 🚀 We’ve now built a comprehensive foundation with sophisticated data processing, advanced analytics, and robust cloud storage.
🧪 Comprehensive Testing Framework
Testing in data engineering isn’t just about unit tests — it’s about ensuring data quality, pipeline reliability, and business logic correctness. Think of it as your safety net that catches issues before they impact business decisions! 🛡️
Copy# tests/test_framework.py
import pytest
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *
from src.data_processing.transformations import RetailDataTransformer
from src.business_metrics.advanced_analytics import AdvancedBusinessMetrics
from src.data_quality.quality_checker import DataQualityChecker
from src.utils.logger import ETLLogger
from datetime import datetime, timedelta
import pandas as pd
class DataPipelineTestFramework:
""" 🧪 Comprehensive Testing Framework for Data Pipelines
This framework implements multiple testing layers:
1. Unit Tests: Individual function validation
2. Integration Tests: Component interaction testing
3. Data Quality Tests: Business rule validation
4. Performance Tests: Scalability and speed validation
5. End-to-End Tests: Full pipeline validation
Why comprehensive testing matters in data engineering:
- Data bugs are often silent and hard to detect
- Business decisions depend on data accuracy
- Pipeline failures can cascade across systems
- Regulatory compliance requires audit trails
"""
@classmethod
def setup_test_spark_session(cls):
""" ⚡ Create optimized Spark session for testing
Test Spark sessions need special configuration:
- Smaller memory footprint for CI/CD environments
- Deterministic behavior for reproducible tests
- Fast startup and shutdown
"""
return SparkSession.builder \
.appName("DataPipelineTests") \
.master("local[2]") \
.config("spark.sql.shuffle.partitions", "2") \
.config("spark.default.parallelism", "2") \
.config("spark.sql.adaptive.enabled", "false") \
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
.getOrCreate()
@classmethod
def create_test_data(cls, spark):
""" 🎲 Generate deterministic test data
Test data should be:
- Small enough for fast execution
- Comprehensive enough to cover edge cases
- Deterministic for reproducible results
- Representative of real-world scenarios
"""
# Create test customers
customers_data = [
("CUST_001", "John", "Doe", "john@email.com", "2023-01-15", "Premium", "USA"),
("CUST_002", "Jane", "Smith", "jane@email.com", "2023-02-20", "Standard", "Canada"),
("CUST_003", "Bob", "Johnson", "bob@email.com", "2023-03-10", "Basic", "USA"),
("CUST_004", "Alice", "Brown", "alice@email.com", "2023-01-05", "Premium", "UK")
]
customers_schema = StructType([
StructField("customer_id", StringType(), False),
StructField("first_name", StringType(), True),
StructField("last_name", StringType(), True),
StructField("email", StringType(), True),
StructField("registration_date", StringType(), True),
StructField("customer_segment", StringType(), True),
StructField("country", StringType(), True)
])
customers_df = spark.createDataFrame(customers_data, customers_schema) \
.withColumn("registration_date", to_date(col("registration_date")))
# Create test products
products_data = [
("PROD_001", "Laptop Pro", "Electronics", "Computers", "TechBrand", 1200.00, 800.00),
("PROD_002", "Running Shoes", "Sports", "Footwear", "SportsBrand", 150.00, 75.00),
("PROD_003", "Coffee Maker", "Home", "Kitchen", "HomeBrand", 200.00, 120.00),
("PROD_004", "Smartphone", "Electronics", "Mobile", "TechBrand", 800.00, 500.00)
]
products_schema = StructType([
StructField("product_id", StringType(), False),
StructField("product_name", StringType(), True),
StructField("category", StringType(), True),
StructField("subcategory", StringType(), True),
StructField("brand", StringType(), True),
StructField("price", DoubleType(), True),
StructField("cost", DoubleType(), True)
])
products_df = spark.createDataFrame(products_data, products_schema)
# Create test transactions
transactions_data = [
("TXN_001", "CUST_001", "PROD_001", 1, 1200.00, 0.00, 1200.00, "2023-06-15 10:30:00", "Completed"),
("TXN_002", "CUST_001", "PROD_002", 2, 150.00, 20.00, 280.00, "2023-06-20 14:15:00", "Completed"),
("TXN_003", "CUST_002", "PROD_003", 1, 200.00, 10.00, 190.00, "2023-06-18 09:45:00", "Completed"),
("TXN_004", "CUST_003", "PROD_004", 1, 800.00, 50.00, 750.00, "2023-06-22 16:20:00", "Completed"),
("TXN_005", "CUST_001", "PROD_003", 1, 200.00, 0.00, 200.00, "2023-07-01 11:00:00", "Completed")
]
transactions_schema = StructType([
StructField("transaction_id", StringType(), False),
StructField("customer_id", StringType(), False),
StructField("product_id", StringType(), False),
StructField("quantity", IntegerType(), True),
StructField("unit_price", DoubleType(), True),
StructField("discount_amount", DoubleType(), True),
StructField("total_amount", DoubleType(), True),
StructField("transaction_date", StringType(), True),
StructField("status", StringType(), True)
])
transactions_df = spark.createDataFrame(transactions_data, transactions_schema) \
.withColumn("transaction_date", to_timestamp(col("transaction_date")))
return customers_df, products_df, transactions_df
class TestDataTransformations(unittest.TestCase):
""" 🔧 Unit Tests for Data Transformations
These tests validate individual transformation functions:
- Input/output data types
- Business logic correctness
- Edge case handling
- Performance characteristics
"""
@classmethod
def setUpClass(cls):
"""Set up test environment once for all tests"""
cls.spark = DataPipelineTestFramework.setup_test_spark_session()
cls.transformer = RetailDataTransformer(cls.spark)
cls.logger = ETLLogger(__name__)
# Create test data
cls.customers_df, cls.products_df, cls.transactions_df = \
DataPipelineTestFramework.create_test_data(cls.spark)
@classmethod
def tearDownClass(cls):
"""Clean up test environment"""
cls.spark.stop()
def test_enriched_transactions_creation(self):
""" ✅ Test enriched transaction creation
This test validates:
- All expected columns are present
- Join operations work correctly
- Calculated fields are accurate
- No data loss during transformation
"""
self.logger.info("🧪 Testing enriched transactions creation...")
# Execute transformation
enriched_df = self.transformer.create_enriched_transactions(
self.transactions_df, self.customers_df, self.products_df
)
# Validate schema
expected_columns = [
"transaction_id", "customer_id", "product_id", "total_amount",
"customer_segment", "category", "profit_amount", "profit_margin"
]
for col in expected_columns:
self.assertIn(col, enriched_df.columns, f"Missing column: {col}")
# Validate data integrity
original_count = self.transactions_df.count()
enriched_count = enriched_df.count()
self.assertEqual(original_count, enriched_count, "Row count mismatch after enrichment")
# Validate business logic
profit_check = enriched_df.select("profit_amount", "total_amount", "quantity", "cost").collect()
for row in profit_check:
if row["cost"] and row["quantity"]:
expected_profit = row["total_amount"] - (row["quantity"] * row["cost"])
self.assertAlmostEqual(row["profit_amount"], expected_profit, places=2, msg="Profit calculation incorrect")
self.logger.info("✅ Enriched transactions test passed")
def test_customer_lifetime_metrics(self):
""" 💎 Test customer lifetime value calculations
This validates the complex CLV algorithm:
- Mathematical accuracy of calculations
- Handling of edge cases (single purchase customers)
- Segment assignment logic
"""
self.logger.info("🧪 Testing customer lifetime metrics...")
# Create enriched data first
enriched_df = self.transformer.create_enriched_transactions(
self.transactions_df, self.customers_df, self.products_df
)
# Calculate CLV metrics
clv_df = self.transformer.calculate_customer_lifetime_metrics(enriched_df)
# Validate that all customers are included
unique_customers = self.transactions_df.select("customer_id").distinct().count()
clv_customers = clv_df.count()
self.assertEqual(unique_customers, clv_customers, "Customer count mismatch in CLV calculation")
# Validate CLV calculation for known customer
cust_001_metrics = clv_df.filter(col("customer_id") == "CUST_001").collect()[0]
# CUST_001 has 3 transactions: $1200, $280, $200 = $1680 total
self.assertAlmostEqual(cust_001_metrics["total_spent"], 1680.0, places=2)
self.assertEqual(cust_001_metrics["total_orders"], 3)
# Validate segment assignment
self.assertIn(cust_001_metrics["clv_segment"],
["Champions", "Loyal Customers", "Potential Loyalists", "New Customers"])
self.logger.info("✅ Customer lifetime metrics test passed")
def test_data_quality_edge_cases(self):
""" 🎯 Test edge cases and data quality scenarios
Edge cases that can break production systems:
- Null values in critical fields
- Negative amounts
- Future dates
- Duplicate records
"""
self.logger.info("🧪 Testing edge cases...")
# Create edge case data
edge_case_data = [
("TXN_EDGE_001", "CUST_001", "PROD_001", 1, 100.0, 0.0, -50.0, "2025-01-01 10:00:00", "Completed"), # Negative amount
("TXN_EDGE_002", None, "PROD_002", 1, 200.0, 0.0, 200.0, "2023-06-15 10:00:00", "Completed"), # Null customer
("TXN_EDGE_003", "CUST_002", "PROD_003", 0, 150.0, 0.0, 0.0, "2023-06-15 10:00:00", "Completed") # Zero quantity
]
edge_schema = StructType([
StructField("transaction_id", StringType(), False),
StructField("customer_id", StringType(), True), # Allow nulls for testing
StructField("product_id", StringType(), False),
StructField("quantity", IntegerType(), True),
StructField("unit_price", DoubleType(), True),
StructField("discount_amount", DoubleType(), True),
StructField("total_amount", DoubleType(), True),
StructField("transaction_date", StringType(), True),
StructField("status", StringType(), True)
])
edge_df = self.spark.createDataFrame(edge_case_data, edge_schema) \
.withColumn("transaction_date", to_timestamp(col("transaction_date")))
# Test data quality checker
from src.config.config import DataQualityConfig
quality_checker = DataQualityChecker(DataQualityConfig())
passed, report = quality_checker.run_quality_checks(edge_df, "edge_case_test")
# Should fail due to null customer_id
self.assertFalse(passed, "Quality check should fail for edge case data")
self.assertIn("null_values", report["checks"])
self.logger.info("✅ Edge cases test passed")
class TestBusinessMetrics(unittest.TestCase):
""" 📊 Integration Tests for Business Metrics
These tests validate complex business logic:
- RFM analysis accuracy
- Cohort analysis calculations
- Market basket analysis results
"""
@classmethod
def setUpClass(cls):
cls.spark = DataPipelineTestFramework.setup_test_spark_session()
cls.analytics = AdvancedBusinessMetrics(cls.spark)
cls.logger = ETLLogger(__name__)
# Create enriched test data
transformer = RetailDataTransformer(cls.spark)
customers_df, products_df, transactions_df = \
DataPipelineTestFramework.create_test_data(cls.spark)
cls.enriched_df = transformer.create_enriched_transactions(
transactions_df, customers_df, products_df
)
@classmethod
def tearDownClass(cls):
cls.spark.stop()
def test_rfm_analysis(self):
""" 🎯 Test RFM analysis implementation
RFM is critical for customer segmentation, so we need to ensure:
- Score calculations are mathematically correct
- Segment assignments follow business rules
- Edge cases are handled properly
"""
self.logger.info("🧪 Testing RFM analysis...")
rfm_df = self.analytics.perform_rfm_analysis(self.enriched_df)
# Validate RFM scores are in valid range (1-5)
rfm_scores = rfm_df.select("r_score", "f_score", "m_score").collect()
for row in rfm_scores:
self.assertGreaterEqual(row["r_score"], 1)
self.assertLessEqual(row["r_score"], 5)
self.assertGreaterEqual(row["f_score"], 1)
self.assertLessEqual(row["f_score"], 5)
self.assertGreaterEqual(row["m_score"], 1)
self.assertLessEqual(row["m_score"], 5)
# Validate segment assignment
segments = rfm_df.select("rfm_segment").distinct().collect()
valid_segments = [
"Champions", "Loyal Customers", "New Customers",
"Potential Loyalists", "At Risk", "Cannot Lose Them",
"Hibernating", "Others"
]
for segment in segments:
self.assertIn(segment["rfm_segment"], valid_segments)
# Validate that CUST_001 (highest value customer) gets appropriate segment
cust_001_rfm = rfm_df.filter(col("customer_id") == "CUST_001").collect()[0]
self.assertIn(cust_001_rfm["rfm_segment"],
["Champions", "Loyal Customers", "Potential Loyalists"])
self.logger.info("✅ RFM analysis test passed")
def test_cohort_analysis(self):
""" 📈 Test cohort analysis calculations
Cohort analysis is complex because it involves:
- Time-based grouping
- Retention rate calculations
- Period-over-period comparisons
"""
self.logger.info("🧪 Testing cohort analysis...")
cohort_df = self.analytics.calculate_cohort_analysis(self.enriched_df)
# Validate cohort structure
self.assertIn("cohort_month", cohort_df.columns)
self.assertIn("period_number", cohort_df.columns)
self.assertIn("retention_rate", cohort_df.columns)
# Validate retention rates are percentages (0-100)
retention_rates = cohort_df.select("retention_rate").collect()
for row in retention_rates:
if row["retention_rate"] is not None:
self.assertGreaterEqual(row["retention_rate"], 0)
self.assertLessEqual(row["retention_rate"], 100)
# Validate that period 0 has 100% retention (by definition)
period_0_retention = cohort_df.filter(col("period_number") == 0) \
.select("retention_rate").collect()
for row in period_0_retention:
self.assertAlmostEqual(row["retention_rate"], 100.0, places=1)
self.logger.info("✅ Cohort analysis test passed")
class TestPerformance(unittest.TestCase):
""" ⚡ Performance Tests
These tests ensure the pipeline can handle production workloads:
- Processing time benchmarks
- Memory usage validation
- Scalability testing
"""
@classmethod
def setUpClass(cls):
cls.spark = DataPipelineTestFramework.setup_test_spark_session()
cls.logger = ETLLogger(__name__)
@classmethod
def tearDownClass(cls):
cls.spark.stop()
def test_large_dataset_processing(self):
""" 📊 Test processing performance with larger datasets
This simulates production-scale data to ensure:
- Processing completes within acceptable time limits
- Memory usage stays within bounds
- No performance degradation with data growth
"""
self.logger.info("🧪 Testing large dataset processing...")
import time
# Generate larger test dataset (10,000 transactions)
large_transactions = []
for i in range(10000):
large_transactions.append((
f"TXN_{i:06d}",
f"CUST_{i % 1000:03d}", # 1000 unique customers
f"PROD_{i % 100:03d}", # 100 unique products
1,
100.0 + (i % 500), # Varying prices
i % 50, # Varying discounts
100.0 + (i % 500) - (i % 50),
f"2023-{(i % 12) + 1:02d}-{(i % 28) + 1:02d} 10:00:00",
"Completed"
))
schema = StructType([
StructField("transaction_id", StringType(), False),
StructField("customer_id", StringType(), False),
StructField("product_id", StringType(), False),
StructField("quantity", IntegerType(), True),
StructField("unit_price", DoubleType(), True),
StructField("discount_amount", DoubleType(), True),
StructField("total_amount", DoubleType(), True),
StructField("transaction_date", StringType(), True),
StructField("status", StringType(), True)
])
large_df = self.spark.createDataFrame(large_transactions, schema) \
.withColumn("transaction_date", to_timestamp(col("transaction_date")))
# Test processing time
start_time = time.time()
# Perform aggregation (common operation)
result = large_df.groupBy("customer_id").agg(
sum("total_amount").alias("total_spent"),
count("transaction_id").alias("transaction_count")
).collect()
processing_time = time.time() - start_time
# Validate results
self.assertEqual(len(result), 1000) # Should have 1000 unique customers
self.assertLess(processing_time, 30) # Should complete within 30 seconds
self.logger.info(f"✅ Large dataset test passed - processed 10K records in {processing_time:.2f}s")
🚨 Advanced Error Handling & Recovery
Error handling in data pipelines isn’t just about try-catch blocks — it’s about building resilient systems that can recover gracefully from failures! 🛡️
Copy# src/error_handling/pipeline_resilience.py
import time
import functools
from typing import Callable, Any, Optional, Dict, List
from enum import Enum
from dataclasses import dataclass
from src.utils.logger import ETLLogger
import traceback
import json
class ErrorSeverity(Enum):
""" 🚨 Error severity levels for intelligent handling
Different errors require different responses:
- CRITICAL: Stop everything, alert on-call team
- HIGH: Retry with backoff, escalate if persistent
- MEDIUM: Retry limited times, log for investigation
- LOW: Log and continue, handle gracefully
"""
CRITICAL = "CRITICAL"
HIGH = "HIGH"
MEDIUM = "MEDIUM"
LOW = "LOW"
@dataclass
class ErrorContext:
"""📋 Comprehensive error context for debugging and recovery"""
error_type: str
error_message: str
severity: ErrorSeverity
component: str
timestamp: str
stack_trace: str
data_context: Dict[str, Any]
retry_count: int = 0
recovery_suggestions: List[str] = None
class CircuitBreaker:
""" ⚡ Circuit Breaker Pattern Implementation
Prevents cascade failures by temporarily stopping calls to failing services.
States:
- CLOSED: Normal operation, calls pass through
- OPEN: Failure threshold reached, calls fail fast
- HALF_OPEN: Testing if service has recovered
This is crucial for:
- Preventing resource exhaustion
- Allowing failing services time to recover
- Maintaining system stability under stress
"""
def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.failure_count = 0
self.last_failure_time = None
self.state = "CLOSED"
self.logger = ETLLogger(__name__)
def call(self, func: Callable, *args, **kwargs) -> Any:
"""Execute function with circuit breaker protection"""
if self.state == "OPEN":
if self._should_attempt_reset():
self.state = "HALF_OPEN"
self.logger.info("🔄 Circuit breaker moving to HALF_OPEN state")
else:
raise Exception("Circuit breaker is OPEN - failing fast")
try:
result = func(*args, **kwargs)
self._on_success()
return result
except Exception as e:
self._on_failure()
raise e
def _should_attempt_reset(self) -> bool:
"""Check if enough time has passed to attempt recovery"""
if self.last_failure_time is None:
return True
return time.time() - self.last_failure_time >= self.recovery_timeout
def _on_success(self):
"""Handle successful call"""
if self.state == "HALF_OPEN":
self.state = "CLOSED"
self.failure_count = 0
self.logger.info("✅ Circuit breaker reset to CLOSED state")
def _on_failure(self):
"""Handle failed call"""
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold:
self.state = "OPEN"
self.logger.error(f"🚨 Circuit breaker OPEN - {self.failure_count} failures")
def retry_with_exponential_backoff(max_retries: int = 3, base_delay: float = 1.0,
max_delay: float = 60.0, backoff_factor: float = 2.0):
""" 🔄 Exponential Backoff Retry Decorator
Implements intelligent retry logic with exponential backoff:
- First retry: wait 1 second
- Second retry: wait 2 seconds
- Third retry: wait 4 seconds
- etc.
This prevents overwhelming failing services while giving them time to recover.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
logger = ETLLogger(func.__name__)
for attempt in range(max_retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
if attempt == max_retries:
logger.error(f"❌ Function {func.__name__} failed after {max_retries} retries: {str(e)}")
raise e
# Calculate delay with exponential backoff
delay = min(base_delay * (backoff_factor ** attempt), max_delay)
logger.warning(f"⚠️ Attempt {attempt + 1} failed for {func.__name__}, retrying in {delay:.1f}s: {str(e)}")
time.sleep(delay)
return wrapper
return decorator
class PipelineErrorHandler:
""" 🛡️ Centralized Error Handling System
This class provides comprehensive error handling capabilities:
- Error classification and routing
- Recovery strategy execution
- Error reporting and alerting
- Graceful degradation options
"""
def __init__(self):
self.logger = ETLLogger(__name__)
self.error_history: List[ErrorContext] = []
self.circuit_breakers: Dict[str, CircuitBreaker] = {}
self.recovery_strategies = {
"data_source_unavailable": self._handle_data_source_error,
"memory_error": self._handle_memory_error,
"network_timeout": self._handle_network_error,
"data_quality_failure": self._handle_data_quality_error,
"s3_access_error": self._handle_s3_error
}
def handle_error(self, error: Exception, component: str,
data_context: Dict[str, Any] = None) -> ErrorContext:
""" 🎯 Main error handling entry point
This method:
1. Classifies the error type and severity
2. Creates comprehensive error context
3. Determines appropriate recovery strategy
4. Executes recovery if possible
5. Logs and reports the incident
"""
error_context = self._create_error_context(error, component, data_context)
self.error_history.append(error_context)
self.logger.error(f"🚨 Error in {component}: {error_context.error_message}")
# Attempt recovery based on error type
recovery_attempted = self._attempt_recovery(error_context)
# Alert if critical or recovery failed
if error_context.severity == ErrorSeverity.CRITICAL or not recovery_attempted:
self._send_alert(error_context)
return error_context
def _create_error_context(self, error: Exception, component: str,
data_context: Dict[str, Any] = None) -> ErrorContext:
"""Create comprehensive error context for analysis"""
error_type = type(error).__name__
severity = self._classify_error_severity(error, component)
return ErrorContext(
error_type=error_type,
error_message=str(error),
severity=severity,
component=component,
timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
stack_trace=traceback.format_exc(),
data_context=data_context or {},
recovery_suggestions=self._get_recovery_suggestions(error_type)
)
def _classify_error_severity(self, error: Exception, component: str) -> ErrorSeverity:
""" 🎯 Intelligent error classification
Classification rules:
- OutOfMemoryError: CRITICAL (can crash entire pipeline)
- S3 access errors: HIGH (affects data persistence)
- Data quality issues: MEDIUM (affects accuracy but not availability)
- Network timeouts: LOW (usually transient)
"""
error_type = type(error).__name__
error_message = str(error).lower()
# Critical errors that require immediate attention
if any(keyword in error_message for keyword in ["out of memory", "java heap space", "metaspace"]):
return ErrorSeverity.CRITICAL
# High priority errors
if any(keyword in error_message for keyword in ["s3", "access denied", "connection refused"]):
return ErrorSeverity.HIGH
# Medium priority errors
if any(keyword in error_message for keyword in ["data quality", "schema mismatch", "null pointer"]):
return ErrorSeverity.MEDIUM
# Default to low priority
return ErrorSeverity.LOW
def _attempt_recovery(self, error_context: ErrorContext) -> bool:
""" 🔧 Attempt automatic recovery based on error type
Recovery strategies:
- Memory errors: Reduce partition size, increase parallelism
- Network errors: Retry with backoff
- Data quality errors: Skip bad records, continue processing
- S3 errors: Switch to backup region, retry with different credentials
"""
recovery_strategy = None
# Map error types to recovery strategies
if "memory" in error_context.error_message.lower():
recovery_strategy = self.recovery_strategies.get("memory_error")
elif "s3" in error_context.error_message.lower():
recovery_strategy = self.recovery_strategies.get("s3_access_error")
elif "timeout" in error_context.error_message.lower():
recovery_strategy = self.recovery_strategies.get("network_timeout")
elif "data quality" in error_context.error_message.lower():
recovery_strategy = self.recovery_strategies.get("data_quality_failure")
if recovery_strategy:
try:
self.logger.info(f"🔧 Attempting recovery for {error_context.component}")
recovery_strategy(error_context)
return True
except Exception as recovery_error:
self.logger.error(f"❌ Recovery failed: {str(recovery_error)}")
return False
return False
def _handle_memory_error(self, error_context: ErrorContext):
"""Recovery strategy for memory-related errors"""
self.logger.info("🧠 Implementing memory error recovery...")
suggestions = [
"Increase executor memory configuration",
"Reduce partition size to process smaller chunks",
"Enable dynamic allocation",
"Use more efficient data formats (Parquet vs CSV)",
"Implement data sampling for large datasets"
]
error_context.recovery_suggestions = suggestions
def _handle_s3_error(self, error_context: ErrorContext):
"""Recovery strategy for S3 access errors"""
self.logger.info("☁️ Implementing S3 error recovery...")
suggestions = [
"Check AWS credentials and permissions",
"Verify S3 bucket exists and is accessible",
"Try alternative S3 endpoint or region",
"Implement retry logic with exponential backoff",
"Use S3 Transfer Acceleration if available"
]
error_context.recovery_suggestions = suggestions
def _handle_data_quality_error(self, error_context: ErrorContext):
"""Recovery strategy for data quality issues"""
self.logger.info("🎯 Implementing data quality error recovery...")
suggestions = [
"Skip records that fail validation",
"Apply data cleansing rules",
"Use default values for missing data",
"Quarantine bad data for manual review",
"Adjust quality thresholds temporarily"
]
error_context.recovery_suggestions = suggestions
def _handle_network_error(self, error_context: ErrorContext):
"""Recovery strategy for network-related errors"""
self.logger.info("🌐 Implementing network error recovery...")
# Implement retry with circuit breaker
component = error_context.component
if component not in self.circuit_breakers:
self.circuit_breakers[component] = CircuitBreaker()
def _get_recovery_suggestions(self, error_type: str) -> List[str]:
"""Get contextual recovery suggestions based on error type"""
suggestions_map = {
"OutOfMemoryError": [
"Increase Spark executor memory",
"Reduce data partition size",
"Use more efficient serialization"
],
"ConnectionError": [
"Check network connectivity",
"Verify service endpoints",
"Implement retry logic"
],
"FileNotFoundError": [
"Verify file paths and permissions",
"Check if data source is available",
"Implement fallback data sources"
]
}
return suggestions_map.get(error_type, ["Contact system administrator"])
def _send_alert(self, error_context: ErrorContext):
""" 📢 Send alerts for critical errors
In production, this would integrate with:
- Slack/Teams for immediate notifications
- PagerDuty for on-call escalation
- Email for detailed error reports
- Monitoring dashboards for visibility
"""
alert_message = {
"severity": error_context.severity.value,
"component": error_context.component,
"error": error_context.error_message,
"timestamp": error_context.timestamp,
"suggestions": error_context.recovery_suggestions
}
self.logger.error(f"🚨 ALERT: {json.dumps(alert_message, indent=2)}")
def get_error_summary(self) -> Dict[str, Any]:
"""Generate error summary for monitoring and reporting"""
if not self.error_history:
return {"total_errors": 0, "summary": "No errors recorded"}
error_counts = {}
severity_counts = {}
for error in self.error_history:
error_counts[error.error_type] = error_counts.get(error.error_type, 0) + 1
severity_counts[error.severity.value] = severity_counts.get(error.severity.value, 0) + 1
return {
"total_errors": len(self.error_history),
"error_types": error_counts,
"severity_distribution": severity_counts,
"most_recent_error": self.error_history[-1].error_message,
"error_rate_trend": self._calculate_error_rate_trend()
}
def _calculate_error_rate_trend(self) -> str:
"""Calculate if error rate is increasing, decreasing, or stable"""
if len(self.error_history) < 10:
return "Insufficient data"
recent_errors = len([e for e in self.error_history[-5:]])
previous_errors = len([e for e in self.error_history[-10:-5]])
if recent_errors > previous_errors:
return "Increasing"
elif recent_errors < previous_errors:
return "Decreasing"
else:
return "Stable"
This is getting really exciting! 🚀 We’ve now built a comprehensive testing framework and bulletproof error handling system.
🎭 Main Pipeline Orchestration — The Conductor
The orchestrator is the brain of our ETL pipeline — it coordinates all components, manages dependencies, handles failures gracefully, and ensures data flows smoothly from source to dest ination! 🧠
Copy# src/pipeline/main_orchestrator.py
from pyspark.sql import SparkSession
from src.config.config import ConfigManager, SparkConfig, S3Config, DataQualityConfig
from src.utils.spark_utils import SparkSessionManager
from src.data_ingestion.data_reader import DataReader
from src.data_processing.transformations import RetailDataTransformer
from src.business_metrics.advanced_analytics import AdvancedBusinessMetrics
from src.data_quality.quality_checker import DataQualityChecker
from src.data_storage.s3_manager import S3DataManager
from src.error_handling.pipeline_resilience import PipelineErrorHandler, retry_with_exponential_backoff
from src.utils.logger import ETLLogger
from typing import Dict, List, Optional, Tuple
import time
from datetime import datetime
import json
import os
class RetailETLOrchestrator:
""" 🎭 The Master Orchestrator - Conducting the Data Symphony
This class is the heart of our ETL pipeline, responsible for:
- Coordinating all pipeline components
- Managing execution flow and dependencies
- Handling errors and recovery scenarios
- Monitoring performance and data quality
- Ensuring data lineage and audit trails
Think of it as the conductor of an orchestra - each component is an instrument, and the orchestrator ensures they all play in harmony to create beautiful music (insights)!
"""
def __init__(self, config_path: str = "src/config/pipeline_config.yaml"):
self.logger = ETLLogger(__name__)
self.config_manager = ConfigManager(config_path)
self.error_handler = PipelineErrorHandler()
# Initialize core components
self.spark_manager = SparkSessionManager(self.config_manager)
self.spark = self.spark_manager.get_session()
# Initialize pipeline components
self.data_reader = DataReader(self.spark_manager)
self.transformer = RetailDataTransformer(self.spark)
self.analytics = AdvancedBusinessMetrics(self.spark)
self.quality_checker = DataQualityChecker(self.config_manager.data_quality_config)
self.s3_manager = S3DataManager(self.config_manager.s3_config)
# Pipeline state tracking
self.pipeline_state = {
"start_time": None,
"end_time": None,
"status": "INITIALIZED",
"processed_records": {},
"quality_reports": {},
"errors": [],
"performance_metrics": {}
}
self.logger.info("🎭 ETL Orchestrator initialized successfully")
def execute_full_pipeline(self, execution_date: Optional[str] = None) -> Dict:
""" 🚀 Execute the complete ETL pipeline
This is the main entry point that orchestrates the entire data processing workflow:
1. 📥 Data Ingestion: Read from multiple sources
2. 🧹 Data Quality: Validate and clean data
3. 🔄 Transformation: Apply business logic
4. 📊 Analytics: Generate business metrics
5. ☁️ Storage: Persist to data lake
6. 📋 Reporting: Generate execution summary
The pipeline is designed to be:
- Idempotent: Can be run multiple times safely
- Resumable: Can restart from failure points
- Auditable: Full lineage and quality tracking
"""
self.pipeline_state["start_time"] = datetime.now()
self.pipeline_state["status"] = "RUNNING"
try:
self.logger.info("🚀 Starting full ETL pipeline execution...")
# Step 1: Data Ingestion
raw_datasets = self._execute_data_ingestion()
# Step 2: Data Quality Validation
validated_datasets = self._execute_data_quality_checks(raw_datasets)
# Step 3: Data Transformation
transformed_datasets = self._execute_transformations(validated_datasets)
# Step 4: Business Analytics
analytics_datasets = self._execute_business_analytics(transformed_datasets)
# Step 5: Data Persistence
self._execute_data_persistence(analytics_datasets)
# Step 6: Pipeline Completion
self._finalize_pipeline_execution()
self.logger.info("✅ ETL pipeline completed successfully!")
return self._generate_execution_summary()
except Exception as e:
return self._handle_pipeline_failure(e)
@retry_with_exponential_backoff(max_retries=3, base_delay=2.0)
def _execute_data_ingestion(self) -> Dict[str, any]:
""" 📥 Data Ingestion Phase
This phase reads data from various sources with robust error handling:
- File-based sources (CSV, Parquet, JSON)
- Database sources (JDBC connections)
- API sources (REST endpoints)
- Streaming sources (Kafka, Kinesis)
Key features:
- Schema validation on ingestion
- Data freshness checks
- Source availability monitoring
- Automatic retry for transient failures
"""
self.logger.info("📥 Executing data ingestion phase...")
try:
datasets = {}
# Read customers data
self.logger.info("👥 Reading customer data...")
customers_df = self.data_reader.read_csv_with_schema(
"data/raw/customers.csv", "customers"
)
if customers_df is None:
raise Exception("Failed to read customers data")
datasets["customers"] = customers_df
self.pipeline_state["processed_records"]["customers"] = customers_df.count()
# Read products data
self.logger.info("🛍️ Reading product data...")
products_df = self.data_reader.read_csv_with_schema(
"data/raw/products.csv", "products"
)
if products_df is None:
raise Exception("Failed to read products data")
datasets["products"] = products_df
self.pipeline_state["processed_records"]["products"] = products_df.count()
# Read transactions data
self.logger.info("💳 Reading transaction data...")
transactions_df = self.data_reader.read_csv_with_schema(
"data/raw/transactions.csv", "transactions"
)
if transactions_df is None:
raise Exception("Failed to read transactions data")
datasets["transactions"] = transactions_df
self.pipeline_state["processed_records"]["transactions"] = transactions_df.count()
# Validate data freshness
if not self.data_reader.validate_data_freshness(transactions_df, "transaction_date", 48):
self.logger.warning("⚠️ Transaction data may be stale")
self.logger.info("✅ Data ingestion completed successfully")
return datasets
except Exception as e:
error_context = self.error_handler.handle_error(
e, "data_ingestion", {"phase": "ingestion"}
)
self.pipeline_state["errors"].append(error_context)
raise e
def _execute_data_quality_checks(self, datasets: Dict[str, any]) -> Dict[str, any]:
""" 🎯 Data Quality Validation Phase
This is one of the most critical phases! Data quality issues can:
- Lead to incorrect business decisions
- Cause downstream system failures
- Result in compliance violations
- Damage stakeholder trust
Our comprehensive quality checks include:
- Schema validation
- Null value analysis
- Duplicate detection
- Outlier identification
- Business rule validation
- Cross-dataset consistency checks
"""
self.logger.info("🎯 Executing data quality checks...")
try:
validated_datasets = {}
quality_reports = {}
for dataset_name, df in datasets.items():
self.logger.info(f"🔍 Validating {dataset_name} dataset...")
# Run comprehensive quality checks
passed, report = self.quality_checker.run_quality_checks(
df, dataset_name
)
quality_reports[dataset_name] = report
if passed:
self.logger.info(f"✅ {dataset_name} passed quality checks")
validated_datasets[dataset_name] = df
else:
self.logger.warning(f"⚠️ {dataset_name} has quality issues")
# Apply data cleansing based on quality issues
cleaned_df = self._apply_data_cleansing(df, report)
validated_datasets[dataset_name] = cleaned_df
self.logger.info(f"🧹 Applied data cleansing to {dataset_name}")
# Store quality reports for audit
self.pipeline_state["quality_reports"] = quality_reports
self.quality_checker.save_quality_report("logs/quality_report.json")
# Cross-dataset validation
self._validate_cross_dataset_consistency(validated_datasets)
self.logger.info("✅ Data quality validation completed")
return validated_datasets
except Exception as e:
error_context = self.error_handler.handle_error(
e, "data_quality", {"phase": "validation"}
)
self.pipeline_state["errors"].append(error_context)
raise e
def _apply_data_cleansing(self, df, quality_report: Dict) -> any:
""" 🧹 Apply intelligent data cleansing
Based on quality check results, we apply appropriate cleansing:
- Fill missing values with business-appropriate defaults
- Remove or flag duplicate records
- Standardize data formats
- Apply business rules for data correction
"""
from pyspark.sql.functions import when, col, isnan, isnull
cleaned_df = df
# Handle null values based on business rules
if "null_analysis" in quality_report.get("checks", {}):
null_analysis = quality_report["checks"]["null_analysis"]["details"]
for column, stats in null_analysis["null_counts"].items():
if stats["null_percentage"] > 0:
if column in ["customer_segment"]: # Fill missing customer segments with 'Unknown'
cleaned_df = cleaned_df.fillna({"customer_segment": "Unknown"})
elif column in ["country"]: # Fill missing countries with 'Not Specified'
cleaned_df = cleaned_df.fillna({"country": "Not Specified"})
# Remove duplicates if found
if "duplicate_analysis" in quality_report.get("checks", {}):
duplicate_analysis = quality_report["checks"]["duplicate_analysis"]["details"]
if duplicate_analysis["duplicate_count"] > 0:
cleaned_df = cleaned_df.dropDuplicates()
self.logger.info(f"🗑️ Removed {duplicate_analysis['duplicate_count']} duplicate records")
return cleaned_df
def _validate_cross_dataset_consistency(self, datasets: Dict[str, any]):
""" 🔗 Cross-dataset consistency validation
Ensures referential integrity across datasets:
- All customer_ids in transactions exist in customers table
- All product_ids in transactions exist in products table
- Date ranges are consistent across datasets
"""
self.logger.info("🔗 Validating cross-dataset consistency...")
transactions_df = datasets["transactions"]
customers_df = datasets["customers"]
products_df = datasets["products"]
# Check customer referential integrity
transaction_customers = transactions_df.select("customer_id").distinct()
valid_customers = customers_df.select("customer_id").distinct()
orphaned_customers = transaction_customers.subtract(valid_customers).count()
if orphaned_customers > 0:
self.logger.warning(f"⚠️ Found {orphaned_customers} transactions with invalid customer_ids")
# Check product referential integrity
transaction_products = transactions_df.select("product_id").distinct()
valid_products = products_df.select("product_id").distinct()
orphaned_products = transaction_products.subtract(valid_products).count()
if orphaned_products > 0:
self.logger.warning(f"⚠️ Found {orphaned_products} transactions with invalid product_ids")
self.logger.info("✅ Cross-dataset consistency validation completed")
def _execute_transformations(self, datasets: Dict[str, any]) -> Dict[str, any]:
""" 🔄 Data Transformation Phase
This phase applies business logic to transform raw data into analytics-ready datasets:
- Data enrichment through joins
- Calculated field generation
- Data aggregation and summarization
- Time-based partitioning
- Performance optimization
The goal is to create datasets that directly support business questions.
"""
self.logger.info("🔄 Executing data transformations...")
try:
start_time = time.time()
transformed_datasets = {}
# Create enriched transactions (the foundation dataset)
self.logger.info("🔗 Creating enriched transactions...")
enriched_transactions = self.transformer.create_enriched_transactions(
datasets["transactions"],
datasets["customers"],
datasets["products"]
)
# Cache this dataset as it's used by multiple downstream processes
enriched_transactions.cache()
transformed_datasets["enriched_transactions"] = enriched_transactions
# Calculate customer lifetime metrics
self.logger.info("💎 Calculating customer lifetime metrics...")
customer_metrics = self.transformer.calculate_customer_lifetime_metrics(
enriched_transactions
)
transformed_datasets["customer_metrics"] = customer_metrics
# Calculate product performance metrics
self.logger.info("📈 Calculating product performance metrics...")
product_metrics = self.transformer.create_product_performance_metrics(
enriched_transactions
)
transformed_datasets["product_metrics"] = product_metrics
# Create time series analytics
self.logger.info("📅 Creating time series analytics...")
daily_metrics, monthly_metrics = self.transformer.create_time_series_analytics(
enriched_transactions
)
transformed_datasets["daily_metrics"] = daily_metrics
transformed_datasets["monthly_metrics"] = monthly_metrics
# Record performance metrics
transformation_time = time.time() - start_time
self.pipeline_state["performance_metrics"]["transformation_time"] = transformation_time
self.logger.info(f"✅ Data transformations completed in {transformation_time:.2f} seconds")
return transformed_datasets
except Exception as e:
error_context = self.error_handler.handle_error(
e, "data_transformation", {"phase": "transformation"}
)
self.pipeline_state["errors"].append(error_context)
raise e
def _execute_business_analytics(self, datasets: Dict[str, any]) -> Dict[str, any]:
""" 📊 Business Analytics Phase
This phase generates sophisticated business insights:
- RFM customer segmentation
- Cohort analysis for retention tracking
- Market basket analysis for cross-selling
- Churn prediction indicators
- Advanced KPIs and metrics
These analytics directly support strategic business decisions.
"""
self.logger.info("📊 Executing business analytics...")
try:
start_time = time.time()
analytics_datasets = datasets.copy() # Keep existing datasets
enriched_transactions = datasets["enriched_transactions"]
customer_metrics = datasets["customer_metrics"]
# RFM Analysis for customer segmentation
self.logger.info("🎯 Performing RFM analysis...")
rfm_analysis = self.analytics.perform_rfm_analysis(enriched_transactions)
analytics_datasets["rfm_analysis"] = rfm_analysis
# Cohort analysis for retention insights
self.logger.info("📈 Performing cohort analysis...")
cohort_analysis = self.analytics.calculate_cohort_analysis(enriched_transactions)
analytics_datasets["cohort_analysis"] = cohort_analysis
# Market basket analysis for cross-selling opportunities
self.logger.info("🛒 Performing market basket analysis...")
market_basket = self.analytics.market_basket_analysis(enriched_transactions)
analytics_datasets["market_basket_analysis"] = market_basket
# Churn risk indicators
self.logger.info("⚠️ Calculating churn risk indicators...")
churn_indicators = self.analytics.calculate_churn_indicators(customer_metrics)
analytics_datasets["churn_indicators"] = churn_indicators
# Generate executive summary metrics
executive_summary = self._generate_executive_summary(analytics_datasets)
analytics_datasets["executive_summary"] = executive_summary
# Record performance metrics
analytics_time = time.time() - start_time
self.pipeline_state["performance_metrics"]["analytics_time"] = analytics_time
self.logger.info(f"✅ Business analytics completed in {analytics_time:.2f} seconds")
return analytics_datasets
except Exception as e:
error_context = self.error_handler.handle_error(
e, "business_analytics", {"phase": "analytics"}
)
self.pipeline_state["errors"].append(error_context)
raise e
def _generate_executive_summary(self, datasets: Dict[str, any]) -> any:
""" 📋 Generate Executive Summary Dashboard Data
Creates high-level KPIs that executives care about:
- Total revenue and growth rates
- Customer acquisition and retention metrics
- Product performance indicators
- Operational efficiency metrics
"""
from pyspark.sql.functions import sum as spark_sum, count, avg, max as spark_max
enriched_transactions = datasets["enriched_transactions"]
customer_metrics = datasets["customer_metrics"]
rfm_analysis = datasets["rfm_analysis"]
# Calculate key business metrics
business_metrics = enriched_transactions.agg(
spark_sum("total_amount").alias("total_revenue"),
spark_sum("profit_amount").alias("total_profit"),
count("transaction_id").alias("total_transactions"),
countDistinct("customer_id").alias("active_customers"),
countDistinct("product_id").alias("products_sold"),
avg("total_amount").alias("avg_order_value")
).collect()[0]
# Customer segment distribution
segment_distribution = rfm_analysis.groupBy("rfm_segment").count().collect()
# Top performing products
top_products = enriched_transactions.groupBy("product_id", "category") \
.agg(spark_sum("total_amount").alias("revenue")) \
.orderBy(col("revenue").desc()) \
.limit(10).collect()
# Create summary dataset
summary_data = [
{
"metric_name": "Total Revenue",
"metric_value": float(business_metrics["total_revenue"]),
"metric_type": "currency"
},
{
"metric_name": "Total Profit",
"metric_value": float(business_metrics["total_profit"]),
"metric_type": "currency"
},
{
"metric_name": "Active Customers",
"metric_value": int(business_metrics["active_customers"]),
"metric_type": "count"
},
{
"metric_name": "Average Order Value",
"metric_value": float(business_metrics["avg_order_value"]),
"metric_type": "currency"
}
]
summary_schema = StructType([
StructField("metric_name", StringType(), False),
StructField("metric_value", DoubleType(), False),
StructField("metric_type", StringType(), False)
])
return self.spark.createDataFrame(summary_data, summary_schema)
@retry_with_exponential_backoff(max_retries=2, base_delay=5.0)
def _execute_data_persistence(self, datasets: Dict[str, any]):
""" ☁️ Data Persistence Phase
Persists all processed datasets to the data lake with:
- Optimal partitioning strategies
- Compression for storage efficiency
- Metadata tracking for data catalog
- Backup and versioning
"""
self.logger.info("☁️ Executing data persistence...")
try:
start_time = time.time()
# Ensure S3 structure exists
self.s3_manager.create_data_lake_structure()
# Persist datasets with appropriate partitioning
persistence_config = {
"enriched_transactions": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/processed/enriched_transactions/",
"format": "parquet",
"partition_cols": ["year", "month"]
},
"customer_metrics": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/curated/customer_metrics/",
"format": "parquet",
"partition_cols": ["customer_segment"]
},
"product_metrics": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/curated/product_metrics/",
"format": "parquet",
"partition_cols": ["category"]
},
"rfm_analysis": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/curated/rfm_analysis/",
"format": "parquet",
"partition_cols": ["rfm_segment"]
},
"cohort_analysis": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/curated/cohort_analysis/",
"format": "parquet",
"partition_cols": ["cohort_month"]
},
"executive_summary": {
"path": f"s3a://{self.config_manager.s3_config.bucket_name}/curated/executive_summary/",
"format": "parquet",
"partition_cols": None
}
}
for dataset_name, config in persistence_config.items():
if dataset_name in datasets:
self.logger.info(f"💾 Persisting {dataset_name}...")
success = self.s3_manager.write_dataframe_to_s3(
datasets[dataset_name],
config["path"],
config["format"],
config["partition_cols"]
)
if not success:
raise Exception(f"Failed to persist {dataset_name}")
self.logger.info(f"✅ {dataset_name} persisted successfully")
# Record performance metrics
persistence_time = time.time() - start_time
self.pipeline_state["performance_metrics"]["persistence_time"] = persistence_time
self.logger.info(f"✅ Data persistence completed in {persistence_time:.2f} seconds")
except Exception as e:
error_context = self.error_handler.handle_error(
e, "data_persistence", {"phase": "persistence"}
)
self.pipeline_state["errors"].append(error_context)
raise e
def _finalize_pipeline_execution(self):
""" 🏁 Pipeline Finalization
Final steps to complete the pipeline:
- Update pipeline state
- Generate execution summary
- Clean up resources
- Send completion notifications
"""
self.pipeline_state["end_time"] = datetime.now()
self.pipeline_state["status"] = "COMPLETED"
# Calculate total execution time
total_time = (self.pipeline_state["end_time"] - self.pipeline_state["start_time"]).total_seconds()
self.pipeline_state["performance_metrics"]["total_execution_time"] = total_time
# Clean up Spark resources
for dataset_name in ["enriched_transactions"]:
# Cached datasets
try:
# In a real implementation, you'd unpersist cached DataFrames
pass
except:
pass
self.logger.info(f"🏁 Pipeline execution finalized - Total time: {total_time:.2f} seconds")
def _handle_pipeline_failure(self, error: Exception) -> Dict:
""" 💥 Handle complete pipeline failure
When the pipeline fails catastrophically:
- Log comprehensive error information
- Attempt graceful cleanup
- Generate failure report
- Send critical alerts
"""
self.pipeline_state["end_time"] = datetime.now()
self.pipeline_state["status"] = "FAILED"
error_context = self.error_handler.handle_error(
error, "pipeline_orchestrator", {"phase": "complete_pipeline"}
)
self.pipeline_state["errors"].append(error_context)
self.logger.error(f"💥 Pipeline execution failed: {str(error)}")
# Attempt cleanup
try:
self.spark_manager.stop_session()
except:
pass
return self._generate_execution_summary()
def _generate_execution_summary(self) -> Dict:
""" 📊 Generate comprehensive execution summary
Creates a detailed report of pipeline execution including:
- Performance metrics
- Data quality results
- Error summary
- Business metrics overview
"""
summary = {
"pipeline_execution": {
"status": self.pipeline_state["status"],
"start_time": self.pipeline_state["start_time"].isoformat() if self.pipeline_state["start_time"] else None,
"end_time": self.pipeline_state["end_time"].isoformat() if self.pipeline_state["end_time"] else None,
"total_execution_time": self.pipeline_state["performance_metrics"].get("total_execution_time", 0)
},
"data_processing": {
"records_processed": self.pipeline_state["processed_records"],
"performance_metrics": self.pipeline_state["performance_metrics"]
},
"data_quality": {
"quality_reports": self.pipeline_state["quality_reports"]
},
"errors": {
"error_count": len(self.pipeline_state["errors"]),
"error_summary": self.error_handler.get_error_summary()
}
}
# Save execution summary
summary_path = f"logs/execution_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2, default=str)
self.logger.info(f"📊 Execution summary saved to {summary_path}")
return summary
# Main execution script
def main():
""" 🚀 Main execution entry point
This is how the pipeline would be executed in production:
- Can be called from command line
- Integrated with schedulers (Airflow, Cron)
- Triggered by events (file arrival, API calls)
"""
logger = ETLLogger("main")
try:
logger.info("🚀 Starting Retail ETL Pipeline...")
# Initialize orchestrator
orchestrator = RetailETLOrchestrator()
# Execute pipeline
execution_summary = orchestrator.execute_full_pipeline()
# Print summary
logger.info("📊 Pipeline Execution Summary:")
logger.info(f"Status: {execution_summary['pipeline_execution']['status']}")
logger.info(f"Total Time: {execution_summary['pipeline_execution']['total_execution_time']:.2f} seconds")
logger.info(f"Records Processed: {execution_summary['data_processing']['records_processed']}")
if execution_summary['errors']['error_count'] > 0:
logger.warning(f"⚠️ Pipeline completed with {execution_summary['errors']['error_count']} errors")
else:
logger.info("✅ Pipeline completed successfully with no errors!")
return execution_summary
except Exception as e:
logger.error(f"💥 Pipeline execution failed: {str(e)}")
raise e
if __name__ == "__main__":
main()
📚 Complete Documentation System
Documentation is crucial for production systems! Let’s create comprehensive docs that make our pipeline maintainable and auditable! 📖
Copy# docs/documentation_generator.py
import os
import json
from typing import Dict, List
from datetime import datetime
from src.utils.logger import ETLLogger
class DocumentationGenerator:
""" 📚 Automated Documentation Generator
Generates comprehensive documentation for our ETL pipeline:
- API documentation
- User guides
- Troubleshooting guides
- Data dictionary
- Architecture diagrams
- Audit trails
Why automated documentation matters:
- Always up-to-date with code changes
- Consistent formatting and structure
- Reduces maintenance overhead
- Improves team onboarding
- Supports compliance requirements
"""
def __init__(self):
self.logger = ETLLogger(__name__)
self.docs_dir = "docs/generated"
os.makedirs(self.docs_dir, exist_ok=True)
def generate_complete_documentation(self):
"""🎯 Generate all documentation artifacts"""
self.logger.info("📚 Generating complete documentation suite...")
# Generate different types of documentation
self._generate_api_documentation()
self._generate_user_guide()
self._generate_troubleshooting_guide()
self._generate_data_dictionary()
self._generate_architecture_overview()
self._generate_deployment_guide()
self._generate_monitoring_guide()
self.logger.info("✅ Documentation generation completed")
def _generate_api_documentation(self):
"""📋 Generate API documentation"""
api_doc = """# 🚀 Retail ETL Pipeline API Documentation
## Overview
The Retail ETL Pipeline provides a comprehensive data processing solution for e-commerce analytics.
## Main Classes and Methods
### RetailETLOrchestrator
The main orchestrator class that coordinates the entire pipeline.
#### Methods:
**`execute_full_pipeline(execution_date: Optional[str] = None) -> Dict`**
- Executes the complete ETL pipeline
- Parameters:
- `execution_date`: Optional date for processing (defaults to current date)
- Returns: Execution summary dictionary
- Raises: Various exceptions based on failure points
**Example Usage:**
```python
from src.pipeline.main_orchestrator import RetailETLOrchestrator
# Initialize orchestrator
orchestrator = RetailETLOrchestrator()
# Execute pipeline
summary = orchestrator.execute_full_pipeline()
print(f"Pipeline status: {summary['pipeline_execution']['status']}")
"""
DataReader
Handles data ingestion from various sources.
Methods:
read_csv_with_schema(file_path: str, schema_name: str) -> DataFrame
- Reads CSV files with predefined schemas
- Parameters:
- file_path: Path to CSV file
- schema_name: Name of predefined schema
- Returns: Spark DataFrame or None if failed
RetailDataTransformer
Applies business logic transformations.
Methods:
create_enriched_transactions(transactions_df, customers_df, products_df) -> DataFrame
- Creates enriched transaction dataset by joining multiple sources
- Returns: DataFrame with enriched transaction data
calculate_customer_lifetime_metrics(enriched_df) -> DataFrame
- Calculates customer lifetime value and related metrics
- Returns: DataFrame with customer-level metrics
AdvancedBusinessMetrics
Generates sophisticated business analytics.
Methods:
perform_rfm_analysis(enriched_df) -> DataFrame
- Performs RFM (Recency, Frequency, Monetary) customer segmentation
- Returns: DataFrame with RFM scores and segments
calculate_cohort_analysis(enriched_df) -> DataFrame
- Calculates customer cohort retention analysis
- Returns: DataFrame with cohort metrics
Error Handling
The pipeline implements comprehensive error handling:
- Automatic retry with exponential backoff
- Circuit breaker pattern for failing services
- Graceful degradation for non-critical failures
- Comprehensive error logging and alert ing
Configuration
Pipeline behavior is controlled through YAML configuration files:
Copyspark:
app_name: "RetailETLPipeline"
executor_memory: "4g"
driver_memory: "2g"
s3:
bucket_name: "retail-analytics-bucket"
region: "us-east-1"
data_quality:
null_threshold: 0.05
duplicate_threshold: 0.01
Performance Considerations
- Use appropriate partitioning for large datasets
- Cache frequently accessed DataFrames
- Monitor memory usage and adjust Spark configuration
- Use columnar formats (Parquet) for better performance
Security
- Use IAM roles for S3 access in production
- Encrypt data at rest and in transit
- Implement proper access controls
- Audit all data access and modifications “””
with open(f”{self.docs_dir}/api_documentation.md”, “w”) as f: f.write(api_doc)
self.logger.info(“📋 API documentation generated”)
def _generate_user_guide(self): “””👥 Generate user guide”””
user_guide = “””
👥 Retail ETL Pipeline User Guide
Gett ing Started
Prerequisites
- Python 3.8+
- Apache Spark 3.2+
- AWS CLI configured (for S3 access)
- Required Python packages (see requirements.txt)
Installation
- Clone the repository:
Copygit clone <repository-url>
cd retail-etl-pipeline
2. Install dependencies:
Copypip install -r requirements.txt
3. Configure AWS credentials:
Copyaws configure
4. Update configuration: Edit
Copysrc/config/pipeline_config.yaml
with your settings.
Running the Pipeline
Basic Execution
Copypython -m src.pipeline.main_orchestrator
With Custom Configuration
Copypython -m src.pipeline.main_orchestrator --config custom_config.yaml
Scheduled Execution
The pipeline can be scheduled using various tools:
Cron (Linux/Mac):
Copy# Run daily at 2 AM
0 2 * * * /path/to/python /path/to/pipeline/main_orchestrator.py
Apache Airflow:
Copyfrom airflow import DAG
from airflow.operators.python_operator import PythonOperator
def run_etl_pipeline():
from src.pipeline.main_orchestrator import main
return main()
dag = DAG('retail_etl', schedule_interval='@daily')
etl_task = PythonOperator(
task_id='run_etl',
python_callable=run_etl_pipeline,
dag=dag
)
Understanding the Output
Execution Summary
After each run, the pipeline generates an execution summary:
Copy{
"pipeline_execution": {
"status": "COMPLETED",
"start_time": "2023-12-01 T10:00:00",
"end_time": "2023-12-01 T10:15:30",
"total_execution_time": 930.5
},
"data_processing": {
"records_processed": {
"customers": 10000,
"products": 1000,
"transactions": 100000
}
}
}
Generated Datasets
The pipeline creates several analytics-ready datasets:
- Enriched Transactions (
s3://bucket/processed/enriched_transactions/
)
- Complete transaction data with customer and product context
- Partitioned by year/month for optimal querying
2. Customer Metrics (s3://bucket/curated/customer_metrics/
)
- Customer lifetime value, purchase frequency, recency
- Segmented by customer tier
3. RFM Analysis (s3://bucket/curated/rfm_analysis/
)
- Customer segmentation based on RFM methodology
- Actionable customer segments with recommendations
4. Product Performance (s3://bucket/curated/product_metrics/
)
- Product sales, profitability, and performance metrics
- Category-level analysis
Data Quality Reports
Quality reports are generated for each dataset:
- Null value analysis
- Duplicate detection
- Outlier identification
- Cross-dataset consistency checks
Reports are saved to logs/quality_report.json
Monitoring and Alerts
Log Files
- Application logs:
logs/etl_pipeline_YYYYMMDD.log
- Quality reports:
logs/quality_report.json
- Execution summaries:
logs/execution_summary_YYYYMMDD_HHMMSS.json
Key Metrics to Monitor
- Pipeline execution time
- Data quality scores
- Error rates
- Resource utilization
Sett ing Up Alerts
Configure alerts for:
- Pipeline failures
- Data quality degradation
- Performance issues
- Resource constraints
Troubleshoot ing
Common Issues
Pipeline fails with “Out of Memory” error:
- Increase Spark executor memory
- Reduce partition size
- Enable dynamic allocation
S3 access denied errors:
- Check AWS credentials
- Verify IAM permissions
- Ensure bucket exists and is accessible
Data quality failures:
- Review quality reports
- Check source data integrity
- Adjust quality thresholds if appropriate
Getting Help
- Check logs in
logs/
- directory
- Review error messages in execution summary
- Consult troubleshoot ing guide
- Contact data engineering team
Best Practices
Data Management
- Always backup critical data before processing
- Use version control for configuration changes
- Monitor data lineage and dependencies
- Implement proper data retention policies
Performance Optimization
- Partition large datasets appropriately
- Use columnar formats (Parquet) for analytics
- Cache frequently accessed data
- Monitor and tune Spark configuration
Security
- Use IAM roles instead of access keys
- Encrypt sensitive data
- Implement proper access controls
- Audit data access regularly “””
with open(f”{self.docs_dir}/user_guide.md”, “w”) as f: f.write(user_guide)
self.logger.info(“👥 User guide generated”)
def _generate_troubleshoot ing_guide(self): “””🔧 Generate troubleshoot ing guide”””
troubleshoot ing_guide = “””
🔧 Troubleshooting Guide
Common Error Scenarios and Solutions
1. Memory-Related Errors
Symptoms:
OutOfMemoryError: Java heap space
OutOfMemoryError: Metaspace
- Pipeline hangs during processing
Root Causes:
- Insufficient Spark executor memory
- Large datasets without proper partitioning
- Memory leaks in transformations
- Too many small files
Solutions:
Increase Memory Allocation:
Copyspark:
executor_memory: "8g" # Increase from 4g
driver_memory: "4g" # Increase from 2g
executor_cores: 4
Optimize Partitioning:
Copy# Repartition large datasets
df = df.repartition(200) # Adjust based on data size
# Use appropriate partition columns
df.write.partitionBy("year", "month").parquet(path)
Enable Dynamic Allocation:
Copyspark:
dynamicAllocation.enabled: true
dynamicAllocation.m inExecutors: 1
dynamicAllocation.maxExecutors: 10
2. S3 Access Issues
Symptoms:
Access Denied
error.NoSuchBucket
exceptions.- Slow S3 operations.
Root Causes:
- Incorrect AWS credentials
- Missing IAM permissions
- Network connectivity issues
- S3 bucket configuration
Solutions:
Check Credentials:
Copyaws sts get-caller-identity
aws s3 ls s3://your-bucket-name/
Required IAM Permissions:
Copy{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
Optimize S3 Performance:
Copy# Use S3A instead of S3
spark.conf.set("spark.hadoop.fs.s3a.fast.upload", "true")
spark.conf.set("spark.hadoop.fs.s3a.block.size", "134217728") # 128 MB
3. Data Quality Failures
Symptoms:
- High null percentages in quality reports
- Unexpected duplicate records
- Schema validation failures
Root Causes:
- Source data quality issues
- Schema evolution
- Data pipeline bugs
- External system changes
Solutions:
Investigate Source Data:
Copy# Check source data directly
df.describe().show()
df.printSchema()
df.filter(col("important_column").isNull()).count()
Implement Data Cleansing:
Copy# Handle nulls appropriately
df = df.fillna({
"customer_segment": "Unknown",
"country": "Not Specified"
})
# Remove duplicates
df = df.dropDuplicates(["customer_id", "transaction_date"])
Adjust Quality Thresholds:
Copydata_quality:
null_threshold: 0.10 # Allow 10% nulls temporarily
duplicate_threshold: 0.05 # Allow 5% duplicates
4. Performance Issues
Symptoms:
- Pipeline takes too long to complete
- High CPU/memory usage
- Frequent garbage collection
Root Causes:
- Inefficient transformations
- Poor partitioning strategy
- Unnecessary data shuffling
- Suboptimal Spark configuration
Solutions:
Optimize Transformations:
Copy# Use broadcast joins for small datasets
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")
# Cache frequently used DataFrames
df.cache()
df.count() # Trigger caching
Reduce Data Shuffling:
Copy# Partition before joins
df1 = df1.repartition("join_key")
df2 = df2.repartition("join_key")
result = df1.join(df2, "join_key")
Tune Spark Configuration:
Copyspark:
sql.adaptive.enabled: true
sql.adaptive.coalescePartitions.enabled: true
sql.adaptive.skewJoin.enabled: true
serializer: "org.apache.spark.serializer.KryoSerializer"
5. Network and Connectivity Issues
Symptoms:
- Connection timeouts
- Intermittent failures
- Slow data transfers
Solutions:
Implement Retry Logic:
Copy@retry_with_exponential_backoff(max_retries=3, base_delay=2.0)
def read_data_with_retry():
return spark.read.parquet(s3_path)
Use Circuit Breakers:
Copycircuit_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=60)
result = circuit_breaker.call(risky_operation)
Diagnostic Commands
Check Spark Application Status
Copy# View Spark UI (if running locally)
open http://localhost:4040
# Check Spark application logs
yarn logs -applicationId application_id
Monitor System Resources
Copy# Check memory usage
free -h
top -p $(pgrep -f spark)
# Check disk space
df -h
du -sh /tmp/spark-*
Validate Data
Copy# Quick data validation
df.count()
df.printSchema()
df.describe().show()
df.filter(col("key_column").isNull()).count()
Emergency Procedures
Pipeline Stuck or Hanging
- Check Spark UI for stuck tasks
- Kill hanging Spark application
- Check for resource constraints
- Restart with reduced parallelism
Data Corruption Detected
- Stop pipeline immediately
- Identify corruption scope
- Restore from backup if available
- Investigate root cause
- Implement additional validation
Critical System Failure
- Alert on-call team
- Document failure details
- Implement immediate workaround
- Schedule post-mortem review
Gett ing Additional Help
Log Analysis
- Check application logs in
logs/
directory - Review Spark driver and executor logs
- Analyze error patterns and frequencies
Performance Profiling
- Use Spark UI for detailed analysis
- Monitor resource utilization
- Profile memory usage patterns
Escalation Path
- Check documentation and troubleshoot ing guide
- Review recent changes and deployments
- Consult with data engineering team
- Escalate to infrastructure team if needed “””
with open(f”{self.docs_dir}/troubleshoot ing_guide.md”, “w”) as f: f.write(troubleshoot ing_guide)
self.logger.info(“🔧 Troubleshoot ing guide generated”)
def _generate_data_dictionary(self): “””📊 Generate data dictionary”””
data_dictionary = “””
📊 Data Dictionary
Overview
This document describes all datasets, tables, and fields used in the Retail ETL Pipeline.
Source Datasets
customers.csv
Customer master data containing demographic and account information.

products.csv
Product catalog with pricing and category information.

transactions.csv
Transaction records capturing all customer purchases.

Processed Datasets
enriched_transactions
Enhanced transaction data with customer and product context.
Location: s3://bucket/processed/enriched_transactions/
Partitioning: year, month Format: Parquet

customer_metrics
Customer-level analytics and lifetime value calculations.
Location: s3://bucket/curated/customer_metrics/
Partitioning: customer_segment Format: Parquet

rfm_analysis
RFM customer segmentation analysis.
Location:s3://bucket/curated/rfm_analysis/

RFM Segments:
- Champions (555, 554, 544, 545, 454, 455, 445): Best customers
- Loyal Customers (543, 444, 435, 355, 354, 345, 344, 335): Regular buyers
- Potential Loyalists (512, 511, 422, 421, 412, 411, 311): Recent customers
- New Customers (512, 511, 422, 421, 412, 411, 311): First-time buyers
- At Risk (155, 154, 144, 214, 215, 115, 114): Declining engagement
- Cannot Lose Them (155, 154, 144, 214, 215, 115): High-value at risk
- Hibernat ing (155, 154, 144, 214, 215, 115, 114): Inactive customers
Data Quality Rules
Validation Rules
- Primary Keys: Must be unique and not null
- Foreign Keys: Must exist in referenced tables
- Dates: Cannot be future dates (except for scheduled events)
- Amounts: Must be positive for prices and totals
- Percentages: Must be between 0 and 100
- Email: Must follow valid email format
- Phone: Must follow valid phone number format
Quality Thresholds
- Null Values: < 5% for critical fields
- Duplicates: < 1% for transaction data
- Outliers: Flag values > 3 standard deviations
- Referential Integrity: 100% for foreign keys
Data Lineage
All processed datasets include metadata columns:
Copyetl_timestamp : When record was processed
etl_job_id: Unique identifier for processing job
source_file: Original source file name
Business Metrics Definitions
Customer Lifetime Value (CLV)
Formula: Average Order Value × Purchase Frequency × Customer Lifespan (years) Purpose: Predict total value of customer relationship Usage: Customer acquisition cost optimization, retention prioritization
RFM Scores
- Recency (R): How recently customer made a purchase (1=old, 5=recent)
- Frequency (F): How often customer makes purchases (1=rare, 5=frequent)
- Monetary (M): How much customer spends (1=low, 5=high)
Cohort Analysis
Definition: Groups customers by acquisition month and tracks behavior over time Metrics: Retention rate, revenue per cohort, customer lifetime patterns Usage: Understanding customer lifecycle, measuring retention initiatives
Market Basket Analysis
Definition: Identifies products frequently purchased together Metrics: Support (frequency), confidence (likelihood), lift (correlation) Usage: Cross-selling recommendations, product bundling, store layout “””
with open(f”{self.docs_dir}/data_dictionary.md”, “w”) as f: f.write(data_dictionary)
self.logger.info(“📊 Data dictionary generated”)
if name == “main”: doc_generator = DocumentationGenerator() doc_generator.generate_complete_documentation()
Copy# deployment/deployment_manager.py
import os
import yaml
import subprocess
from typing import Dict, List, Tuple, Callable
from src.utils.logger import ETLLogger
class DeploymentManager:
"""
Production Deployment Manager
Handles deployment of the ETL pipeline to various environments:
- Development: Local testing and development
- Staging: Pre-production validation
- Production: Live production environment
Features:
- Environment-specific configurations
- Automated testing before deployment
- Rollback capabilities
- Health checks and monitoring
"""
def __init__(self, environment: str = "development"):
self.environment = environment
self.logger = ETLLogger(__name__)
self.deployment_config = self._load_deployment_config()
self.prerequisite_checks = {
"Python 3.8+": self._check_python_version,
"Spark Installation": self._check_spark_installation,
"AWS CLI": self._check_aws_cli,
"Required Packages": self._check_python_packages,
"S3 Access": self._check_s3_access,
"Network Connectivity": self._check_network_connectivity
}
def _load_deployment_config(self) -> Dict:
"""Load environment-specific deployment configuration"""
config_file = f"deployment/config/{self.environment}.yaml"
if os.path.exists(config_file):
with open(config_file, 'r') as f:
return yaml.safe_load(f)
# Default configuration
return {
"spark": {
"master": "local[*]",
"executor_memory": "8g" if self.environment != "development" else "2g",
"driver_memory": "4g" if self.environment != "development" else "1g"
},
"monitoring": {
"enabled": self.environment == "production",
"metrics_endpoint": f"http://monitoring-{self.environment}.company.com"
}
}
def deploy_pipeline(self) -> bool:
"""
Deploy pipeline to target environment
Deployment steps:
1. Validate environment prerequisites
2. Run automated tests
3. Package application
4. Deploy to target environment
5. Run health checks
6. Enable monitoring
"""
self.logger.info(f"Starting deployment to {self.environment} environment...")
try:
deployment_steps = [
(self._validate_prerequisites, "Prerequisites validation failed"),
(self._run_deployment_tests, "Deployment tests failed"),
(self._package_application, "Application packaging failed"),
(self._deploy_to_environment, "Environment deployment failed"),
(self._run_health_checks, "Health checks failed"),
(self._setup_monitoring, "Monitoring setup failed")
]
for step, error_msg in deployment_steps:
if not step():
raise Exception(error_msg)
self.logger.info(f"Deployment to {self.environment} completed successfully!")
return True
except Exception as e:
self.logger.error(f"Deployment failed: {str(e)}")
self._rollback_deployment()
return False
def _validate_prerequisites(self) -> bool:
"""Validate environment prerequisites"""
self.logger.info("Validating deployment prerequisites...")
return self._run_checks(self.prerequisite_checks)
def _run_checks(self, checks: Dict[str, Callable[[], bool]]) -> bool:
"""Run a series of checks and return overall status"""
all_passed = True
for name, check_func in checks.items():
try:
if check_func():
self.logger.info(f"{name}: OK")
else:
self.logger.error(f"{name}: FAILED")
all_passed = False
except Exception as e:
self.logger.error(f"{name}: ERROR - {str(e)}")
all_passed = False
return all_passed
# Prerequisite check methods
def _check_python_version(self) -> bool:
"""Check Python version"""
import sys
return sys.version_info >= (3, 8)
def _check_spark_installation(self) -> bool:
"""Check Spark installation"""
try:
result = subprocess.run(
['spark-submit', '--version'],
capture_output=True,
text=True,
timeout=10
)
return result.returncode == 0
except Exception:
return False
def _check_aws_cli(self) -> bool:
"""Check AWS CLI installation and configuration"""
try:
result = subprocess.run(
['aws', 'sts', 'get-caller-identity'],
capture_output=True,
text=True,
timeout=10
)
return result.returncode == 0
except Exception:
return False
def _check_python_packages(self) -> bool:
"""Check required Python packages"""
required_packages = [
'pyspark', 'boto3', 'pandas', 'numpy', 'pyyaml', 'pytest'
]
try:
import importlib
for package in required_packages:
importlib.import_module(package)
return True
except ImportError:
return False
def _check_s3_access(self) -> bool:
"""Check S3 bucket access"""
try:
import boto3
s3_client = boto3.client('s3')
bucket_name = self.deployment_config.get('s3', {}).get('bucket_name', 'test-bucket')
s3_client.head_bucket(Bucket=bucket_name)
return True
except Exception:
return False
def _check_network_connectivity(self) -> bool:
"""Check network connectivity to required services"""
import socket
endpoints = [
('s3.amazonaws.com', 443),
('spark.apache.org', 80)
]
for host, port in endpoints:
try:
socket.create_connection((host, port), timeout=5)
except Exception:
return False
return True
# Placeholder methods for deployment steps
def _run_deployment_tests(self) -> bool:
"""Run deployment tests"""
self.logger.info("Running deployment tests...")
# Implementation would go here
return True
def _package_application(self) -> bool:
"""Package application for deployment"""
self.logger.info("Packaging application...")
# Implementation would go here
return True
def _deploy_to_environment(self) -> bool:
"""Deploy to target environment"""
self.logger.info(f"Deploying to {self.environment} environment...")
# Implementation would go here
return True
def _run_health_checks(self) -> bool:
"""Run health checks after deployment"""
self.logger.info("Running health checks...")
# Implementation would go here
return True
def _setup_monitoring(self) -> bool:
"""Set up monitoring for the deployment"""
self.logger.info("Setting up monitoring...")
# Implementation would go here
return True
def _rollback_deployment(self):
"""Rollback deployment in case of failure"""
self.logger.info("Initiating rollback procedure...")
# Implementation would go here
def create_dockerfile():
"""Create Dockerfile for containerized deployment"""
dockerfile_content = """FROM python:3.9-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \\
openjdk-11-jdk \\
wget \\
curl \\
&& rm -rf /var/lib/apt/lists/*
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
# Install Spark
ENV SPARK_VERSION=3.4.0
ENV HADOOP_VERSION=3
RUN wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz \\
&& tar -xzf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz \\
&& mv spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION} /opt/spark \\
&& rm spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz
ENV SPARK_HOME=/opt/spark
ENV PATH=${SPARK_HOME}/bin:${PATH}
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY src/ ./src/
COPY data/ ./data/
COPY logs/ ./logs/
COPY docs/ ./docs/
# Create non-root user for security
RUN useradd -m -u 1000 sparkuser && chown -R sparkuser:sparkuser /app
USER sparkuser
# Set environment variables
ENV PYTHONPATH=/app
ENV PYSPARK_PYTHON=python3
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \\
CMD python -c "from src.pipeline.main_orchestrator import RetailETLOrchestrator; print('Health check passed')"
CMD ["python", "-m", "src.pipeline.main_orchestrator"]
"""
with open("Dockerfile", "w") as f:
f.write(dockerfile_content)
def create_kubernetes_config():
"""Create Kubernetes deployment configuration"""
k8s_config = """apiVersion: apps/v1
kind: Deployment
metadata:
name: retail-etl-pipeline
labels:
app: retail-etl
spec:
replicas: 1
selector:
matchLabels:
app: retail-etl
template:
metadata:
labels:
app: retail-etl
spec:
containers:
- name: retail-etl
image: retail-etl-pipeline:latest
resources:
requests:
memory: "4Gi"
cpu: "2"
limits:
memory: "8Gi"
cpu: "4"
env:
- name: ENVIRONMENT
value: "production"
- name: AWS_DEFAULT_REGION
value: "us-east-1"
- name: S3_BUCKET_NAME
valueFrom:
secretKeyRef:
name: etl-secrets
key: s3-bucket-name
volumeMounts:
- name: config-volume
mountPath: /app/config
- name: logs-volume
mountPath: /app/logs
volumes:
- name: config-volume
configMap:
name: etl-config
- name: logs-volume
emptyDir: {}
serviceAccountName: retail-etl-service-account
---
apiVersion: v1
kind: ConfigMap
metadata:
name: etl-config
data:
pipeline_config.yaml: |
spark:
app_name: "RetailETLPipeline-Production"
executor_memory: "6g"
driver_memory: "2g"
s3:
region: "us-east-1"
data_quality:
null_threshold: 0.05
duplicate_threshold: 0.01
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: retail-etl-service-account
annotations:
eks.amazonaws.com/role-arn: arn:aws:iam::ACCOUNT:role/RetailETLRole
---
apiVersion: batch/v1
kind: CronJob
metadata:
name: retail-etl-daily
spec:
schedule: "0 2 * * *"
jobTemplate:
spec:
template:
spec:
containers:
- name: retail-etl
image: retail-etl-pipeline:latest
command: ["python", "-m", "src.pipeline.main_orchestrator"]
restartPolicy: OnFailure
"""
with open("k8s-deployment.yaml", "w") as f:
f.write(k8s_config)
def create_monitoring_config():
"""Create monitoring and alerting configuration"""
prometheus_config = """global:
scrape_interval: 15s
evaluation_interval: 15s
rule_files:
- "etl_alerts.yml"
scrape_configs:
- job_name: 'retail-etl-pipeline'
static_configs:
- targets: ['localhost:8080']
metrics_path: /metrics
scrape_interval: 30s
alerting:
alertmanagers:
- static_configs:
- targets:
- alertmanager:9093
"""
alert_rules = """groups:
- name: retail_etl_alerts
rules:
- alert: ETLPipelineFailure
expr: etl_pipeline_status != 1
for: 5m
labels:
severity: critical
annotations:
summary: "ETL Pipeline has failed"
description: "The retail ETL pipeline has been failing for more than 5 minutes"
- alert: DataQualityDegraded
expr: etl_data_quality_score < 0.95
for: 10m
labels:
severity: warning
annotations:
summary: "Data quality has degraded"
description: "Data quality score is {{ $value }}, below threshold of 0.95"
- alert: ETLExecutionTimeHigh
expr: etl_execution_time_seconds > 3600
for: 0m
labels:
severity: warning
annotations:
summary: "ETL execution time is high"
description: "ETL pipeline took {{ $value }} seconds to complete"
- alert: HighErrorRate
expr: rate(etl_errors_total[5m]) > 0.1
for: 5m
labels:
severity: warning
annotations:
summary: "High error rate detected"
description: "Error rate is {{ $value }} errors per second"
"""
with open("prometheus.yml", "w") as f:
f.write(prometheus_config)
with open("etl_alerts.yml", "w") as f:
f.write(alert_rules)
def create_cicd_config():
"""Create CI/CD pipeline configuration"""
github_actions = """name: Retail ETL Pipeline CI/CD
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
env:
PYTHON_VERSION: 3.9
SPARK_VERSION: 3.4.0
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install Java
uses: actions/setup-java@v3
with:
distribution: 'temurin'
java-version: '11'
- name: Cache dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: runner.os-${{ hashFiles('**/requirements.txt') }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov
- name: Install Spark
run: |
wget https://archive.apache.org/dist/spark/spark-${{ env.SPARK_VERSION }}-bin-hadoop3.tgz
tar -xzf spark-${{ env.SPARK_VERSION }}-bin-hadoop3.tgz
sudo mv spark-${{ env.SPARK_VERSION }}-bin-hadoop3 /opt/spark
echo "SPARK_HOME=/opt/spark" >> $GITHUB_ENV
echo "/opt/spark/bin" >> $GITHUB_PATH
- name: Run unit tests
run: |
pytest tests/ -v --cov=src --cov-report=xml
- name: Run data quality tests
run: |
python -m pytest tests/test_data_quality.py -v
- name: Run integration tests
run: |
python -m pytest tests/test_integration.py -v
- name: Upload coverage reports
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
name: codecov-umbrella
build:
needs: test
runs-on: ubuntu-latest
if: github.ref == 'refs/heads/main'
steps:
- uses: actions/checkout@v3
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v2
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: us-east-1
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v1
- name: Build and push Docker image
env:
ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
ECR_REPOSITORY: retail-etl-pipeline
IMAGE_TAG: ${{ github.sha }}
run: |
docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest
docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest
deploy:
needs: build
runs-on: ubuntu-latest
if: github.ref == 'refs/heads/main'
steps:
- uses: actions/checkout@v3
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v2
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: us-east-1
- name: Deploy to EKS
run: |
aws eks update-kubeconfig --region us-east-1 --name retail-etl-cluster
kubectl apply -f k8s-deployment.yaml
kubectl rollout status deployment/retail-etl-pipeline
- name: Run smoke tests
run: |
kubectl exec deployment/retail-etl-pipeline -- python -c "
from src.pipeline.main_orchestrator import RetailETLOrchestrator;
orchestrator = RetailETLOrchestrator();
print('Smoke test passed')
"
- name: Notify deployment
uses: 8398a7/action-slack@v3
with:
status: ${{ job.status }}
channel: '#data-engineering'
webhook_url: ${{ secrets.SLACK_WEBHOOK }}
if: always()
"""
os.makedirs(".github/workflows", exist_ok=True)
with open(".github/workflows/ci-cd.yml", "w") as f:
f.write(github_actions)
if __name__ == "__main__":
print("Creating deployment configurations...")
create_dockerfile()
create_kubernetes_config()
create_monitoring_config()
create_cicd_config()
print("All deployment configurations created successfully!")
print("\nNext steps:")
print("1. Review and customize configurations for your environment")
print("2. Set up AWS credentials and permissions")
print("3. Configure monitoring and alerting endpoints")
print("4. Test deployment in staging environment")
print("5. Deploy to production with proper change management")
🎉 Conclusion & Final Thoughts
Congratulations! 🎊 We’ve just built a production-grade PySpark ETL pipeline that’s enterprise-ready and follows industry best practices! Let me summarize what we’ve accomplished:
🏆 What We’ve Built
- 📊 Complete Data Processing Pipeline
- Robust data ingestion with schema validation
- Sophisticated business transformations
- Advanced analytics (RFM, cohort analysis, market basket)
- Intelligent data quality monitoring
- 🛡️ Production-Grade Features
- Comprehensive error handling with circuit breakers
- Exponential backoff retry logic
- Graceful degradation and recovery
- Complete audit trails and data lineage
- 🧪 Test ing Framework
- Unit tests for individual components
- Integration tests for end-to-end workflows
- Data quality validation tests
- Performance benchmarking
- ☁️ Cloud-Native Architecture
- S3 data lake with intelligent partitioning
- Scalable Spark configuration
- Container-ready deployment
- Kubernetes orchestration
- 📚 Complete Documentation
- API documentation with examples
- User guides and troubleshoot ing
- Data dictionary and business metrics
- Deployment and monitoring guides
🎯 Key Business Value
This pipeline solves real business problems:
- Customer Segmentation: RFM analysis identifies high-value customers
- Retention Analysis: Cohort tracking reveals customer lifecycle patterns
- Cross-Selling: Market basket analysis drives revenue optimization
- Churn Prevention: Early warning indicators enable proactive retention
- Executive Insights: Automated KPI generation supports decision-making
🚀 Production Readiness Checklist
✅ Scalability: Handles millions of transactions with Spark optimization ✅ Reliability: Circuit breakers and retry logic ensure resilience ✅ Monitoring: Comprehensive logging and alert ing for operational visibility ✅ Security: IAM roles, encryption, and access controls ✅ Maintainability: Modular design with comprehensive documentation ✅ Testability: Full test coverage with automated CI/CD ✅ Compliance: Audit trails and data lineage for regulatory requirements
💡 Next Steps for Implementation
- Environment Setup
- Configure AWS credentials and S3 buckets
- Set up Spark cluster (EMR, Databricks, or Kubernetes)
- Deploy monitoring infrastructure
2. Data Integration
- Connect to your actual data sources
- Customize schemas for your business
- Implement source-specific data readers
3. Business Customization
- Adapt metrics to your KPIs
- Customize customer segments
- Add industry-specific analytics
4. Operational Excellence
- Set up alert ing and monitoring
- Implement backup and disaster recovery
- Establish change management processes
🌟 Why This Approach Works
This isn’t just code — it’s a complete data engineering solution that:
- Scales with your business from startup to enterprise
- Adapts to changing requirements with modular architecture
- Reduces operational overhead with automation and monitoring
- Ensures data quality with comprehensive validation
- Supports compliance with audit trails and documentation
📈 Impact on Your Organization
By implement ing this pipeline, you’ll:
- Reduce time-to-insight from weeks to hours
- Improve data quality with automated validation
- Enable data-driven decisions with reliable metrics
- Scale analytics capabilities without proportional cost increases
- Build data engineering expertise within your team
This pipeline represents months of engineering effort condensed into a comprehensive, production-ready solution. It incorporates lessons learned from real-world implementations and follows industry best practices from companies like Netflix, Uber, and Airbnb.
Remember: Great data engineering isn’t just about moving data — it’s about creat ing reliable, scalable systems that enable business success! 🎯
Happy data engineering! 🚀✨