#!/usr/bin/env python3
"""
Tutorial: Causal Machine Learning and the Resource Curse (EconML)
=================================================================
Pedagogical replication of the main findings from
Hodler, Lechner & Raschky (2023) using simulated data and
the EconML Causal Forest (CausalForestDML) estimator.

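Uses the Double Machine Learning (DML) framework to estimate
heterogeneous treatment effects of mining and mineral prices on
economic development, measured by night-time lights (NTL). The
simulated data also contain a conflict outcome, but this script
estimates the NTL outcome only.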

Runtime: ~3-8 minutes

Usage:
    python tutorial-econml-resource-curse.py
"""

import os
import sys
import time

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# ============================================================================
# Ground-truth parameters (from the data-generating process)
# ============================================================================
# Keys prefixed `ntl_` belong to the night-time-lights (log) outcome; keys
# prefixed `conflict_` belong to the conflict outcome.
TRUE_PARAMS = dict(
    # --- Night-time lights (log) ---
    ntl_mining_base=0.25,            # Mining effect at mean institutions
    ntl_mining_inst_mod=0.15,        # Institutional moderation of mining
    ntl_price_med=0.05,              # Medium price premium (small)
    ntl_price_high=0.30,             # High price premium (large)
    ntl_noise_sd=0.25,               # Outcome noise
    # --- Conflict ---
    conflict_mining_base=0.70,       # Mining increases conflict
    conflict_mining_inst_mod=-0.50,  # Institutions dampen mining-conflict
    conflict_price_med=0.15,
    conflict_price_high=0.50,
    conflict_base_rate=0.12,
)


def expected_ates():
    """Return the ground-truth NTL average treatment effects implied by the DGP.

    Treatment levels: 0 = no mining, 1/2/3 = mining under low/medium/high
    prices. A key "a-b" is the effect of level a relative to level b, built
    additively from the ``ntl_*`` entries of ``TRUE_PARAMS``.

    Returns:
        dict: ``{'NTL': {'1-0', '2-0', '3-0', '2-1', '3-1', '3-2' -> float}}``.
    """
    mining = TRUE_PARAMS['ntl_mining_base']
    med = TRUE_PARAMS['ntl_price_med']
    high = TRUE_PARAMS['ntl_price_high']

    ntl = {}
    # Versus "no mining": the mining effect plus any price premium.
    ntl['1-0'] = mining
    ntl['2-0'] = mining + med
    ntl['3-0'] = mining + high
    # Among mining levels the mining base effect cancels out.
    ntl['2-1'] = med
    ntl['3-1'] = high
    ntl['3-2'] = high - med
    return {'NTL': ntl}


# ============================================================================
# Site color palette and dark theme
# ============================================================================
# Accent colors for bars, confidence bands and reference lines.
STEEL_BLUE = "#6a9bcc"
WARM_ORANGE = "#d97757"
NEAR_BLACK = "#141413"
TEAL = "#00d4c8"

# Background, grid and text colors of the dark theme.
DARK_NAVY = "#0f1729"
GRID_LINE = "#1f2b5e"
LIGHT_TEXT = "#c8d0e0"
WHITE_TEXT = "#e8ecf2"

# Global matplotlib defaults, applied once at import time: navy canvas,
# no spines or tick marks, a faint grid and light text.
plt.rcParams.update({
    # Canvas and axes background
    "figure.facecolor": DARK_NAVY,
    "figure.edgecolor": DARK_NAVY,
    "axes.facecolor": DARK_NAVY,
    "axes.edgecolor": DARK_NAVY,
    "axes.linewidth": 0,
    # Hide all four spines
    **{f"axes.spines.{side}": False
       for side in ("top", "right", "left", "bottom")},
    # Grid
    "axes.grid": True,
    "grid.color": GRID_LINE,
    "grid.linewidth": 0.6,
    "grid.alpha": 0.8,
    # Text and ticks
    "axes.labelcolor": LIGHT_TEXT,
    "axes.titlecolor": WHITE_TEXT,
    "text.color": WHITE_TEXT,
    "xtick.color": LIGHT_TEXT,
    "ytick.color": LIGHT_TEXT,
    "xtick.major.size": 0,
    "ytick.major.size": 0,
    "font.size": 12,
    # Legend
    "legend.frameon": False,
    "legend.fontsize": 11,
    "legend.labelcolor": LIGHT_TEXT,
    # Saved files
    "savefig.facecolor": DARK_NAVY,
    "savefig.edgecolor": DARK_NAVY,
})

# ============================================================================
# Configuration
# ============================================================================
# All paths are anchored to this script's directory, not the working dir.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
RESULTS_DIR = os.path.join(_SCRIPT_DIR, 'tutorial_results')  # CSV outputs
os.makedirs(RESULTS_DIR, exist_ok=True)  # created as an import side effect

# Remote copy of the simulated panel, plus a local fallback copy.
DATA_URL = (
    "https://github.com/cmg777/starter-academic-v501/raw/master"
    "/content/post/python_EconML/sim_resource_curse.csv"
)
_LOCAL_CSV = os.path.join(_SCRIPT_DIR, 'sim_resource_curse.csv')

# X: effect modifiers. The causal forest may split on these when searching
# for treatment-effect heterogeneity.
X_COLS = [
    'exec_constraints',
    'quality_of_govt',
    'gdp_pc',
    'elevation',
    'temperature',
    'ruggedness',
    'distance_capital',
    'agri_suitability',
    'population',
    'ethnic_frac',
]

# W: extra controls used only by the first-stage nuisance models
# (residualization); the second-stage causal forest never sees them.
W_COLS = ['country_id', 'year']

# Institutional variables for the GATE analysis.
Z_VARS = ['exec_constraints', 'quality_of_govt']

# Single-tree summary of the estimated CATEs. Depth 2 yields 4 leaves, which
# is easy to narrate in the post; depth 3-4 is more diagnostic but harder to
# read.
CATE_TREE_DEPTH = 2
CATE_TREE_MIN_LEAF = 100


def _savefig(fig, name):
    """Write *fig* to ``<script dir>/<name>`` at 300 dpi, then close it.

    Figures go next to the script (not into RESULTS_DIR) so the post's image
    references resolve without an extra copy step. Face and edge colors are
    forced to the dark-navy background and the canvas border is removed.
    """
    target = os.path.join(_SCRIPT_DIR, name)
    fig.patch.set_linewidth(0)  # no visible border around the canvas
    save_kwargs = dict(
        dpi=300,
        bbox_inches='tight',
        facecolor=DARK_NAVY,
        edgecolor=DARK_NAVY,
        pad_inches=0,
    )
    fig.savefig(target, **save_kwargs)
    plt.close(fig)  # free the figure so open figures don't pile up
    print(f"  Saved: {name}")


# ============================================================================
# Step 1: Load simulated data
# ============================================================================
def step_load_data():
    """Load the simulated panel dataset and print an overview.

    Reads ``DATA_URL`` first. If that fails for any reason (no network, HTTP
    error, unparseable payload), falls back to ``_LOCAL_CSV`` and prints the
    cause of the remote failure.

    Returns:
        pandas.DataFrame: the panel as read from CSV.

    Raises:
        FileNotFoundError: if the remote read fails and no local copy exists.
    """
    print("\n" + "=" * 70)
    print("STEP 1: LOADING SIMULATED DATA")
    print("=" * 70)

    try:
        df = pd.read_csv(DATA_URL)
    except Exception as err:  # deliberate best-effort: any failure -> local
        print("  (GitHub URL unavailable, loading local copy)")
        # Surface the cause so a stale local copy is never used silently.
        print(f"  (reason: {type(err).__name__}: {err})")
        if not os.path.isfile(_LOCAL_CSV):
            raise FileNotFoundError(
                f"Could not read {DATA_URL} and no local copy exists at "
                f"{_LOCAL_CSV}") from err
        df = pd.read_csv(_LOCAL_CSV)

    print(f"  Dataset: {len(df):,} observations")
    print(f"  Districts: {df['district_id'].nunique()}, "
          f"Countries: {df['country_id'].nunique()}, "
          f"Years: {df['year'].min()}-{df['year'].max()}")
    # The percentage is the share of rows with mining == 1; it equals the
    # share of districts only if mining status is constant within a district.
    print(f"  Mining districts: "
          f"{df.loc[df['mining']==1, 'district_id'].nunique()} "
          f"({df['mining'].mean():.0%} of total)")

    print("\n  Treatment distribution:")
    labels = {0: 'No mining', 1: 'Low prices',
              2: 'Med prices', 3: 'High prices'}
    for t, n in df['treatment'].value_counts().sort_index().items():
        print(f"    {t} ({labels[t]}): {n:,} ({n/len(df):.1%})")

    print("\n  Outcomes:")
    print(f"    NTL (log):  mean={df['ntl_log'].mean():.3f}, "
          f"std={df['ntl_log'].std():.3f}")
    print(f"    Conflict:   {df['conflict'].mean():.1%} of district-years")

    return df


# ============================================================================
# Step 2: Descriptive statistics
# ============================================================================
def step_descriptive_stats(df):
    """Summarize key variables, compare outcomes by treatment, chart arm sizes.

    Prints a mean/std/min/max table, the mean NTL and conflict rate for each
    treatment level, saves a horizontal bar chart of observations per level
    (``python_econml_treatment_dist.png``), and writes the summary table to
    ``RESULTS_DIR/descriptive-stats.csv``.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("STEP 2: DESCRIPTIVE STATISTICS")
    print(banner)

    # Display label -> column name, in the order rows are printed.
    desc_vars = {
        'NTL (log)': 'ntl_log',
        'Conflict': 'conflict',
        'Exec. Constraints': 'exec_constraints',
        'Quality of Govt.': 'quality_of_govt',
        'GDP per capita': 'gdp_pc',
        'Elevation': 'elevation',
        'Temperature': 'temperature',
        'Ruggedness': 'ruggedness',
        'Dist. to Capital': 'distance_capital',
        'Agri. Suitability': 'agri_suitability',
        'Population': 'population',
        'Ethnic Frac.': 'ethnic_frac',
    }
    stats = pd.DataFrame([
        {'Variable': name, 'Mean': df[col].mean(), 'Std': df[col].std(),
         'Min': df[col].min(), 'Max': df[col].max()}
        for name, col in desc_vars.items()
    ])
    print(stats.to_string(index=False, float_format='{:.3f}'.format))

    # Raw outcome means per treatment arm (no adjustment).
    arm_names = {0: 'No mining', 1: 'Low prices',
                 2: 'Med prices', 3: 'High prices'}
    print("\n  Outcomes by treatment group:")
    print(f"  {'Treatment':<20s} {'Mean NTL':>10s} {'Conflict Rate':>15s}")
    print("  " + "-" * 47)
    for arm in sorted(df['treatment'].unique()):
        in_arm = df['treatment'] == arm
        arm_ntl = df.loc[in_arm, 'ntl_log'].mean()
        arm_conflict = df.loc[in_arm, 'conflict'].mean()
        print(f"  {arm} ({arm_names[arm]:<13s})  "
              f"{arm_ntl:>10.3f} {arm_conflict:>14.1%}")

    # Bar chart of observations per treatment level.
    fig, ax = plt.subplots(figsize=(8, 3))
    per_arm = df['treatment'].value_counts().sort_index()
    total = len(df)
    bars = ax.barh(
        [arm_names[arm] for arm in per_arm.index],
        per_arm.values,
        color=[LIGHT_TEXT, STEEL_BLUE, TEAL, WARM_ORANGE],
        edgecolor=DARK_NAVY,
        linewidth=0.8,
    )
    # Label each bar just past its end with "count (share)".
    for bar, count in zip(bars, per_arm.values):
        y_mid = bar.get_y() + bar.get_height() / 2
        ax.text(bar.get_width() + 20, y_mid,
                f'{count:,} ({count/total:.1%})',
                va='center', fontsize=9, color=LIGHT_TEXT)
    ax.set_xlabel('Number of observations')
    ax.set_title('Treatment Distribution (M=4)')
    _savefig(fig, 'python_econml_treatment_dist.png')

    stats.to_csv(os.path.join(RESULTS_DIR, 'descriptive-stats.csv'),
                 index=False)


# ============================================================================
# Step 2b: Naive comparison (motivating causal ML)
# ============================================================================
def step_naive_comparison(df):
    """Show how far raw NTL mean differences fall from the true ATEs.

    For the contrasts 1-0, 2-1 and 3-1 it prints
    mean(ntl_log | level a) - mean(ntl_log | level b), the DGP ground truth
    from ``expected_ates()``, and their difference (bias). Nothing is
    returned or saved.
    """
    print("\n" + "=" * 70)
    print("STEP 2b: NAIVE COMPARISON (why we need causal ML)")
    print("=" * 70)

    truth_by_comp = expected_ates()['NTL']

    def arm_mean(level):
        # Unadjusted mean outcome within one treatment level.
        return df.loc[df['treatment'] == level, 'ntl_log'].mean()

    print("\n  Simple difference-in-means (no confounder adjustment):")
    print(f"  {'Comparison':<15s} {'Naive':>8s} {'Ground Truth':>14s} "
          f"{'Bias':>8s}")
    print("  " + "-" * 47)

    for comp in ('1-0', '2-1', '3-1'):
        treated, control = (int(part) for part in comp.split('-'))
        naive = arm_mean(treated) - arm_mean(control)
        truth = truth_by_comp[comp]
        print(f"  {comp:<15s} {naive:>8.3f} {truth:>14.3f} "
              f"{naive - truth:>+8.3f}")

    print("\n  The naive 1-0 estimate is biased because mining districts")
    print("  differ systematically from non-mining districts (geography,")
    print("  institutions). The Causal Forest controls for these confounders")
    print("  via the Double Machine Learning residualization step.")


# ============================================================================
# Step 3: EconML Causal Forest Estimation
# ============================================================================
def step_econml_estimation(df, outcome, label):
    """Fit a CausalForestDML model of *outcome* on the 4-level treatment.

    Stage 1: gradient-boosting nuisance models residualize the outcome and
    the treatment using X_COLS and W_COLS. Stage 2: a 500-tree honest causal
    forest over X_COLS estimates heterogeneous effects, with BLB inference.
    Cross-fitting folds are grouped by ``district_id``.

    Args:
        df: panel containing *outcome*, ``treatment`` (levels 0-3),
            ``district_id`` and every column in X_COLS and W_COLS.
        outcome: name of the outcome column to model.
        label: display name used only in the printed banner.

    Returns:
        The fitted ``CausalForestDML`` estimator.
    """
    from econml.dml import CausalForestDML
    from sklearn.ensemble import (GradientBoostingRegressor,
                                  GradientBoostingClassifier)

    n_trees = 500   # second-stage forest size (also shown in the banner)
    seed = 42       # shared by the nuisance models and the forest

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"STEP 3: CAUSAL FOREST DML — {label}")
    print(f"  Outcome: {outcome}  |  Treatment: treatment (M=4)")
    print(f"  Observations: {len(df):,}  |  Trees: {n_trees}")
    print(rule)

    Y = df[outcome].values
    T = df['treatment'].values
    X = df[X_COLS].values
    W = df[W_COLS].values

    setup_notes = [
        "  Configuring CausalForestDML...",
        "    - discrete_treatment=True (4 treatment levels)",
        "    - DML: GradientBoosting for both nuisance models",
        f"    - {n_trees} honest causal trees "
        "(separate split/estimation samples)",
        "    - BLB inference (bootstrap of little bags)",
        "    - GroupKFold CV by district_id (prevents within-district",
        "      data leakage in cross-fitting; does NOT cluster SEs)",
    ]
    for note in setup_notes:
        print(note)

    # Both first-stage learners share the same hyperparameters.
    nuisance_kw = dict(n_estimators=200, max_depth=4, random_state=seed)
    est = CausalForestDML(
        model_y=GradientBoostingRegressor(**nuisance_kw),
        model_t=GradientBoostingClassifier(**nuisance_kw),
        discrete_treatment=True,
        categories=[0, 1, 2, 3],   # first listed level (0) is the control
        n_estimators=n_trees,
        min_samples_leaf=10,
        max_depth=None,            # fully grown trees
        honest=True,               # separate split / estimation samples
        inference=True,            # BLB (bootstrap of little bags)
        cv=5,
        n_jobs=1,                  # single-threaded for reproducibility
        random_state=seed,
    )

    started = time.time()
    print("  Fitting (first-stage residualization + causal forest)...")
    # `groups` only drives GroupKFold in cross-fitting: every row of a
    # district lands in the same fold, so the first-stage models never
    # learn from a district they are then used to residualize. It does NOT
    # produce clustered second-stage standard errors.
    est.fit(Y, T, X=X, W=W, groups=df['district_id'].values)

    print(f"  Done in {(time.time() - started)/60:.1f} minutes")
    return est


# ============================================================================
# Step 4: Extract and display ATEs
# ============================================================================
def step_ate_table(est_ntl, df):
    """Estimate all six pairwise NTL ATEs and compare them with the DGP truth.

    Each contrast's full-sample ATE comes from ``ate_inference`` (BLB
    standard errors, 90% intervals) and is printed in the line format that
    index.md mirrors. A summary table with significance stars and the
    ground-truth values follows. The table is written to
    ``RESULTS_DIR/ate-table.csv`` and returned. Numbers should round-trip
    cleanly through .4f / .3f formatting.
    """
    print("\n" + "=" * 70)
    print("STEP 4: AVERAGE TREATMENT EFFECTS")
    print("=" * 70)

    features = df[X_COLS].values
    truth = expected_ates()['NTL']
    # (label, T0 baseline level, T1 comparison level)
    contrasts = (
        ('1-0', 0, 1), ('2-0', 0, 2), ('3-0', 0, 3),
        ('2-1', 1, 2), ('3-1', 1, 3), ('3-2', 2, 3),
    )

    records = []
    print("\n  Per-comparison inference (BLB; alpha=0.10):")
    for name, base, treated in contrasts:
        res = est_ntl.ate_inference(features, T0=base, T1=treated)
        lo, hi = res.conf_int_mean(alpha=0.1)  # 90% CI
        point, se = res.mean_point, res.stderr_mean
        records.append({
            'Comparison': name,
            'NTL Effect': point,
            'NTL SE': se,
            'NTL 90% CI Lo': lo,
            'NTL 90% CI Hi': hi,
            'NTL Ground Truth': truth.get(name, np.nan),
        })
        # index.md mirrors this exact line.
        print(f"  {name}: ATE={point:.4f} SE={se:.4f} "
              f"90%CI=[{lo:.3f}, {hi:.3f}]")

    table = pd.DataFrame(records)

    def stars(effect, se):
        # Two-sided normal cutoffs for p<0.01 / 0.05 / 0.10. A non-positive
        # or NaN SE gets no stars.
        if not se > 0:
            return ''
        z = abs(effect / se)
        for cutoff, mark in ((2.576, '***'), (1.96, '**'), (1.645, '*')):
            if z > cutoff:
                return mark
        return ''

    print("\n  Summary table (compare against ground truth):")
    print(f"  {'Comp':<8s} {'NTL Effect':>11s} {'NTL SE':>8s} "
          f"{'Ground Truth':>13s} {'Sig':>5s}")
    print("  " + "-" * 52)
    for _, rec in table.iterrows():
        mark = stars(rec['NTL Effect'], rec['NTL SE'])
        print(f"  {rec['Comparison']:<8s} {rec['NTL Effect']:>11.4f} "
              f"{rec['NTL SE']:>8.4f} {rec['NTL Ground Truth']:>13.3f} "
              f"{mark:>5s}")
    print("\n  * p<0.10, ** p<0.05, *** p<0.01 (two-sided, BLB SEs)")

    # Interpretation hooks quoted in the post.
    effect_of = dict(zip(table['Comparison'], table['NTL Effect']))
    print("\n  Key findings:")
    print(f"    Finding 1: Mining increases NTL "
          f"(1-0 effect = {effect_of['1-0']:.3f})")
    print(f"    Finding 2: Non-linear prices — "
          f"2-1 = {effect_of['2-1']:.3f} (small) vs "
          f"3-1 = {effect_of['3-1']:.3f} (large)")

    table.to_csv(os.path.join(RESULTS_DIR, 'ate-table.csv'), index=False)
    return table


# ============================================================================
# Step 5: GATE plots by institutional quality
# ============================================================================
def compute_gate(est, df, z_var, t0, t1, z_crit=1.645):
    """Compute Group Average Treatment Effects (GATEs) by a grouping variable.

    Uses effect_inference() for per-observation BLB standard errors and
    propagates them to group means as follows.

    Let tau_hat_i be the per-observation CATE for unit i in group g of size
    n_g, with BLB standard error se_i. Treating the n_g CATE estimates as
    approximately uncorrelated within g (a working assumption — EconML's
    BLB does not return their full covariance matrix), the GATE is

        GATE_g = (1/n_g) * sum_{i in g} tau_hat_i

    and its variance is

        Var(GATE_g) ~= (1/n_g^2) * sum_{i in g} se_i^2
                     = (1/n_g) * mean_{i in g}(se_i^2),

    so the SE used here is sqrt(mean(se_i^2) / n_g). This captures
    estimation uncertainty in each individual CATE; it does not capture
    sampling variability of who lands in group g, nor any within-group
    correlation across panel observations of the same district.

    Args:
        est: fitted estimator exposing ``effect_inference(X, T0=, T1=)``.
        df: data containing the ``X_COLS`` features and ``z_var``.
        z_var: column whose distinct values define the groups.
        t0, t1: baseline and comparison treatment levels.
        z_crit: normal critical value for the band ``gate +/- z_crit * se``.
            The default 1.645 gives a two-sided 90% interval.

    Returns:
        (gate_df, ite): ``gate_df`` has one row per distinct ``z_var`` value
        in ascending order, with columns z_value, gate, se, lower, upper and
        n. ``ite`` is the per-observation CATE array for t1 versus t0.
    """
    X = df[X_COLS].values
    inference = est.effect_inference(X, T0=t0, T1=t1)
    ite = inference.point_estimate
    ite_se = inference.stderr

    z_column = df[z_var].values  # hoisted: reused for every group mask
    gate_data = []
    for z in sorted(df[z_var].unique()):
        mask = z_column == z
        n = mask.sum()
        gate_mean = ite[mask].mean()
        # SE of the group mean under the independence assumption above.
        gate_se = np.sqrt(np.mean(ite_se[mask] ** 2) / n)
        gate_data.append({
            'z_value': z,
            'gate': gate_mean,
            'se': gate_se,
            'lower': gate_mean - z_crit * gate_se,
            'upper': gate_mean + z_crit * gate_se,
            'n': n,
        })
    return pd.DataFrame(gate_data), ite


def _plot_single_gate(est, df, z_var, z_label, t0, t1, title, color, fname):
    """Draw one GATE curve with its CI band against *z_var* and save it.

    The dashed reference line is the mean of all per-observation CATEs for
    the t1-vs-t0 contrast.

    Returns:
        pandas.DataFrame: the GATE table produced by ``compute_gate``.
    """
    gates, cates = compute_gate(est, df, z_var, t0, t1)
    overall = cates.mean()
    x = gates['z_value']

    fig, ax = plt.subplots(figsize=(8, 6))
    fig.patch.set_linewidth(0)

    # Band first, so the point estimates are drawn on top of it.
    ax.fill_between(x, gates['lower'], gates['upper'],
                    alpha=0.25, color=color)
    ax.plot(x, gates['gate'], 'o-',
            color=WHITE_TEXT, markersize=7, linewidth=1.5,
            markeredgecolor=DARK_NAVY, markeredgewidth=0.8, zorder=3)
    ax.axhline(overall, color=WARM_ORANGE, linewidth=1.5, linestyle='--',
               alpha=0.8, label=f'ATE = {overall:.3f}')

    ax.set_xlabel(z_label, fontsize=13)
    ax.set_ylabel('GATE', fontsize=13)
    ax.set_title(title, fontsize=14, fontweight='bold', pad=12)
    ax.legend(fontsize=11)

    _savefig(fig, fname)
    return gates


def step_gate_plots(est_ntl, df):
    """Plot NTL GATEs by institutional quality and print the exec-constraints table.

    Saves four figures, {1-0, 3-1} contrasts x {exec_constraints,
    quality_of_govt}, named ``python_econml_gate_<contrast>_<z>.png``. It then
    prints the GATEs by executive constraints for both contrasts in the
    format mirrored verbatim into index.md.
    """
    print("\n" + "=" * 70)
    print("STEP 5: GATEs BY INSTITUTIONAL QUALITY")
    print("=" * 70)

    # (grouping column, axis label, tag used in file names)
    z_specs = [
        ('exec_constraints', 'Constraints on the Executive', 'exec'),
        ('quality_of_govt', 'Quality of Government', 'qog'),
    ]
    # (estimator, T0, T1, title, band color, tag). The post covers NTL only.
    individual_specs = [
        (est_ntl, 0, 1, 'NTL: Mining vs No Mining (1-0)', STEEL_BLUE, 'ntl_1v0'),
        (est_ntl, 1, 3, 'NTL: High vs Low Prices (3-1)', STEEL_BLUE, 'ntl_3v1'),
    ]

    # One figure per (grouping variable, contrast); grouping variable outer.
    plot_jobs = [
        (est, df, z_var, z_label, t0, t1, title, color,
         f'python_econml_gate_{comp_short}_{z_short}.png')
        for z_var, z_label, z_short in z_specs
        for est, t0, t1, title, color, comp_short in individual_specs
    ]
    for job in plot_jobs:
        _plot_single_gate(*job)

    # Walkthrough table (mirrored verbatim into index.md).
    print("\n  GATE values for NTL by Executive Constraints:")
    print("  " + "=" * 65)
    walkthrough = [
        ((0, 1), 'Mining vs No Mining (1-0)'),
        ((1, 3), 'High vs Low Prices (3-1)'),
    ]
    for (base, treated), desc in walkthrough:
        gates, _ = compute_gate(est_ntl, df, 'exec_constraints',
                                base, treated)
        print(f"\n    {treated}-{base} ({desc}):")
        print(f"    {'Exec. Constr.':<15s} {'GATE':>8s} {'90% CI':>20s}"
              f" {'N':>6s}")
        print("    " + "-" * 52)
        for _, row in gates.iterrows():
            print(f"    {row['z_value']:>13.0f}   {row['gate']:>8.3f}   "
                  f"[{row['lower']:.3f}, {row['upper']:.3f}]"
                  f" {row['n']:>6.0f}")
        spread = gates['gate'].max() - gates['gate'].min()
        print(f"    Range: {spread:.3f}")

    print("\n    Finding 3a: 1-0 effects vary with institutions")
    print("    Finding 3b: 3-1 effects are FLAT across institutions")


# ============================================================================
# Step 5b: Variable importance
# ============================================================================
def step_variable_importance(est_ntl):
    """Report which X features drive treatment-effect heterogeneity.

    Reads ``est_ntl.feature_importances_`` (paired with X_COLS by position),
    prints the features from most to least important with a text bar, and
    saves a horizontal bar chart to ``python_econml_var_importance.png``.
    """
    print("\n" + "=" * 70)
    print("STEP 5b: VARIABLE IMPORTANCE")
    print("=" * 70)

    ranked = sorted(zip(X_COLS, est_ntl.feature_importances_),
                    key=lambda pair: pair[1], reverse=True)

    print("\n  Feature importances (heterogeneity drivers):")
    for name, score in ranked:
        print(f"    {name:<25s} {score:>6.3f}  {'#' * int(score * 100)}")

    print("\n  Note: These measure how much each feature contributes to")
    print("  treatment effect HETEROGENEITY, not to outcome prediction.")

    # barh puts the first category at the bottom, so pass the ranking in
    # ascending order to put the most important feature on top.
    ascending = ranked[::-1]
    fig, ax = plt.subplots(figsize=(8, 5))
    fig.patch.set_linewidth(0)
    ax.barh([name for name, _ in ascending],
            [score for _, score in ascending],
            color=STEEL_BLUE, edgecolor=DARK_NAVY,
            linewidth=0.8, alpha=0.9)
    ax.set_xlabel('Feature Importance (heterogeneity contribution)',
                  fontsize=13)
    ax.set_title('Treatment Effect Heterogeneity Drivers (NTL)',
                 fontsize=14, fontweight='bold', pad=12)
    _savefig(fig, 'python_econml_var_importance.png')


# ============================================================================
# Step 5c: CATE Interpreter (EconML-specific feature)
# ============================================================================
def step_cate_interpreter(est_ntl, df):
    """Summarize the mining-vs-no-mining (1-0) CATEs with a shallow tree.

    Fits EconML's ``SingleTreeCateInterpreter`` (depth ``CATE_TREE_DEPTH``,
    at least ``CATE_TREE_MIN_LEAF`` rows per leaf) to the estimated 1-0 CATEs
    over ``X_COLS``. It then saves a dark-themed plot of the tree to
    ``python_econml_cate_tree.png``.

    Why a contrast view is used: the estimator was fit with four discrete
    treatment categories. Under EconML's discrete-treatment convention its
    CATE output therefore has one column per non-baseline level (1-0, 2-0,
    3-0). Passing the estimator directly would grow a single multi-output
    tree over all three contrasts. That tree's splits, node values and
    colors would not describe the 1-0 contrast named in the title.
    """
    from econml.cate_interpreter import SingleTreeCateInterpreter

    class _ContrastView:
        # Minimal CATE "estimator" exposing a single T0 -> T1 contrast.
        # NOTE(review): relies on interpret() reading CATEs only through
        # const_marginal_effect(X) when include_uncertainty is off (the
        # default used below). Re-check this after upgrading econml.
        def __init__(self, est, t0, t1):
            self._est = est
            self._t0 = t0
            self._t1 = t1

        def const_marginal_effect(self, X):
            return self._est.effect(X, T0=self._t0, T1=self._t1)

    print("\n" + "=" * 70)
    print("STEP 5c: CATE INTERPRETER (EconML-specific feature)")
    print("=" * 70)

    X = df[X_COLS].values

    # Interpret the 1-0 contrast (mining vs no mining) only.
    intrp = SingleTreeCateInterpreter(
        max_depth=CATE_TREE_DEPTH,
        min_samples_leaf=CATE_TREE_MIN_LEAF,
    )
    intrp.interpret(_ContrastView(est_ntl, 0, 1), X)

    print("\n  The SingleTreeCateInterpreter fits a shallow decision tree")
    print("  to the estimated CATEs to find interpretable subgroups.")

    # Save the tree plot: dark theme with post-processed node colors.
    fig, ax = plt.subplots(figsize=(12, 6))
    fig.patch.set_linewidth(0)

    intrp.plot(feature_names=X_COLS, ax=ax)
    ax.set_title('Interpretable CATE Subgroups: Mining vs No Mining (1-0)',
                 fontsize=14, fontweight='bold', color=WHITE_TEXT, pad=12)

    # Post-process the tree for readability on the dark background.
    # Nodes are Annotation objects with bbox patches whose fill runs from
    # green (high CATE) to white (low CATE). Light fills are remapped to
    # darker tones so every node stays visible on dark navy.
    for child in ax.get_children():
        if not isinstance(child, plt.Annotation):
            continue
        child.set_color(WHITE_TEXT)
        # The connecting arrow is a FancyArrowPatch built when the annotation
        # was created. Editing child.arrowprops afterwards does not recolor
        # it, so recolor the patch itself. Otherwise it keeps matplotlib's
        # default patch color.
        arrow = getattr(child, 'arrow_patch', None)
        if arrow is not None:
            arrow.set_edgecolor(LIGHT_TEXT)
            arrow.set_facecolor(LIGHT_TEXT)
        bbox = child.get_bbox_patch()
        if bbox is None:
            continue
        fc = bbox.get_facecolor()
        r, g, b = fc[0], fc[1], fc[2]
        luminance = 0.299 * r + 0.587 * g + 0.114 * b
        if luminance > 0.85:
            bbox.set_facecolor((0.10, 0.23, 0.36, 1.0))   # dark steel-blue
        elif luminance > 0.7:
            bbox.set_facecolor((0.13, 0.30, 0.33, 1.0))   # dark teal
        elif luminance > 0.55:
            bbox.set_facecolor((0.20, 0.40, 0.30, 1.0))   # muted green
        # saturated green nodes (lum <= 0.55) keep their original color
        bbox.set_edgecolor(LIGHT_TEXT)
        bbox.set_linewidth(1.5)

    # Force every Text object in the figure (title included) to light text.
    for text_obj in fig.findobj(plt.Text):
        text_obj.set_color(WHITE_TEXT)

    _savefig(fig, 'python_econml_cate_tree.png')


# ============================================================================
# Step 6: Summary and comparison
# ============================================================================
def step_summary():
    """Print the DGP ground truth, the findings checklist and output locations.

    Only the NTL outcome is estimated by this script, so the checklist is
    stated for NTL. Tables are written to ``RESULTS_DIR``. PNG figures are
    written to the script directory, and the printed locations say so.
    """
    print("\n" + "=" * 70)
    print("SUMMARY: EconML CAUSAL FOREST vs GROUND TRUTH")
    print("=" * 70)

    gt = expected_ates()['NTL']
    print("\n  Expected ATEs (from DGP):")
    for comp, val in gt.items():
        print(f"    {comp}: {val:.3f}")

    print("\n  Three key findings to verify:")
    # Conflict is not estimated here (main() fits the NTL model only), so
    # the checklist does not ask the reader to verify conflict effects.
    print("    1. Mining -> positive NTL (all x-0 > 0)")
    print("    2. Non-linear prices: 2-1 ~ 0, 3-1 >> 0, 3-2 >> 0")
    print("    3. GATEs by institutions:")
    print("       - 1-0: slope (institutions moderate mining effect)")
    print("       - 3-1: flat (institutions don't affect price effects)")

    print(f"\n  Results saved to: {RESULTS_DIR}/")
    print("  Tables:  descriptive-stats.csv, ate-table.csv")
    # Figures are saved next to the script, not into RESULTS_DIR.
    print(f"  Figures: python_econml_*.png (in {_SCRIPT_DIR}/)")


# ============================================================================
# Main
# ============================================================================
def _print_environment():
    """Print the Python and key library versions at the top of the log.

    The numbers quoted in index.md depend on this stack (econml, sklearn,
    numpy). Recording the versions on every run makes later drift auditable.
    """
    import sklearn
    import econml

    stack = [
        ("Python:", sys.version.split()[0]),
        ("econml:", econml.__version__),
        ("scikit-learn:", sklearn.__version__),
        ("numpy:", np.__version__),
        ("pandas:", pd.__version__),
        ("matplotlib:", matplotlib.__version__),
    ]
    print("Environment:")
    for name, version in stack:
        # Labels padded to 14 chars so the version column lines up.
        print(f"  {name:<14}{version}")
    print()


def main():
    """Run every tutorial step in order and report the total runtime."""
    rule = "=" * 70
    print(rule)
    print("TUTORIAL: CAUSAL MACHINE LEARNING AND THE RESOURCE CURSE")
    print("EconML CausalForestDML — Replicating Hodler, Lechner & Raschky (2023)")
    print(rule)

    _print_environment()
    started = time.time()

    df = step_load_data()                 # Step 1: load data
    step_descriptive_stats(df)            # Step 2: descriptive statistics
    step_naive_comparison(df)             # Step 2b: naive comparison

    # Step 3: only the NTL outcome is estimated. The post covers NTL alone,
    # so fitting conflict would just slow the run and produce unused numbers.
    est_ntl = step_econml_estimation(df, 'ntl_log', 'NTL')

    step_ate_table(est_ntl, df)           # Step 4: pairwise ATEs (BLB)
    step_gate_plots(est_ntl, df)          # Step 5: GATEs by institutions
    step_variable_importance(est_ntl)     # Step 5b: variable importance
    step_cate_interpreter(est_ntl, df)    # Step 5c: CATE interpreter
    step_summary()                        # Step 6: summary

    minutes = (time.time() - started) / 60
    print(f"\n  Total runtime: {minutes:.1f} minutes")
    print("\n" + rule)
    print("Tutorial complete.")
    print(rule)


class _Tee:
    """Duplicate writes across multiple text streams (e.g. stdout + file).

    This object is installed as ``sys.stdout``, so it must look like a text
    stream to third-party code that inspects it:

    - ``write`` returns the number of characters written, like
      ``io.TextIOBase.write``.
    - Any other attribute (``encoding``, ``isatty``, ``fileno``, ...) is
      delegated to the first stream.

    Writes made directly through a delegated ``fileno()`` bypass the other
    streams.
    """

    def __init__(self, *streams):
        self._streams = streams

    def write(self, data):
        for s in self._streams:
            s.write(data)
        return len(data)

    def flush(self):
        for s in self._streams:
            s.flush()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Read _streams via
        # __dict__ to avoid infinite recursion if it was never set.
        streams = self.__dict__.get('_streams')
        if not streams:
            raise AttributeError(name)
        return getattr(streams[0], name)


if __name__ == '__main__':
    # Mirror everything printed during the run into execution_log.txt.
    log_path = os.path.join(_SCRIPT_DIR, 'execution_log.txt')
    real_stdout = sys.stdout
    # Explicit UTF-8: the output contains non-ASCII text (e.g. the em dash in
    # the step banners), so the log must not depend on the locale encoding.
    with open(log_path, 'w', encoding='utf-8') as fh:
        sys.stdout = _Tee(real_stdout, fh)
        try:
            main()
        finally:
            sys.stdout = real_stdout  # restore even if a step raises
    print(f"\nFull log written to: {log_path}")
