#!/usr/bin/env python3
"""
Gamma Shift Tilt × Buy_pct Interaction Backtest
================================================
Tests whether buy_pct from ES delta bars confirms/improves gamma tilt signal.
"""
import json
import warnings
import glob
import os
from datetime import datetime, timedelta, time as dtime
import numpy as np
import pandas as pd
from scipy import stats

# Globally silence all warnings (pandas/scipy chatter) — note this hides real issues too.
warnings.filterwarnings('ignore')

WORKSPACE = '/Users/lutherbot/.openclaw/workspace'
TRACE_DIR = os.path.join(WORKSPACE, 'data/trace_uncorrupted')  # intraday strike GEX parquets
ES_FILE = os.path.join(WORKSPACE, 'data/es_1min_delta_bars.csv')  # ES 1-min delta bars (UTC timestamps)
FOMC_FILE = os.path.join(WORKSPACE, 'data/fomc_dates.json')  # FOMC meeting dates to exclude
OUTPUT_FILE = os.path.join(WORKSPACE, 'data/gamma_tilt_buypct_results.json')  # JSON results destination

# Corrupted TRACE dates to exclude (inclusive range, checked against the trade date)
CORRUPT_START = pd.Timestamp('2025-10-27')
CORRUPT_END = pd.Timestamp('2026-02-17')

# Regular-trading-hours session bounds in ET (bars are filtered to [open, close))
RTH_OPEN = dtime(9, 30)
RTH_CLOSE = dtime(16, 0)

# Checkpoints: minutes after the 9:30 ET open at which buy_pct / tilt are sampled
CHECKPOINTS = {
    '30min': 30,
    '60min': 60,
    '90min': 90,
}

# Tilt thresholds swept in the standalone / confirmation analyses
TILT_THRESHOLDS = [0.60, 0.65, 0.70, 0.75, 0.80, 0.85]
# buy_pct cutoffs for "bullish"/"bearish" flow — labelled as quartiles; TODO confirm source
BUYPCT_BULL = 0.509  # top quartile
BUYPCT_BEAR = 0.491  # bottom quartile


def load_fomc_dates():
    """Return the set of FOMC meeting dates (as ``datetime.date``) from FOMC_FILE."""
    with open(FOMC_FILE) as fh:
        payload = json.load(fh)
    # payload['dates'] holds date strings; normalize via pandas, keep date parts only.
    return {d for d in pd.to_datetime(payload['dates']).date}


def load_es_bars():
    """Load ES 1-min bars from ES_FILE and attach ET timestamp/date/time columns.

    Raw 'timestamp' values are UTC wall-clock (14:30 UTC == 9:30 ET); they are
    localized to UTC then converted to America/New_York.
    """
    bars = pd.read_csv(ES_FILE)
    bars['timestamp'] = pd.to_datetime(bars['timestamp'])
    et_ts = bars['timestamp'].dt.tz_localize('UTC').dt.tz_convert('America/New_York')
    bars['timestamp_et'] = et_ts
    bars['date'] = et_ts.dt.date
    bars['time_et'] = et_ts.dt.time
    return bars


def _nearest_row(frame, values, target, tol):
    """Return the row of *frame* whose *values* entry is closest to *target*.

    Returns None when even the closest entry is farther than *tol* away.
    *values* must share *frame*'s index; *tol* has the same units as *values*
    (float minutes or a pd.Timedelta).
    """
    diffs = (values - target).abs()
    idx = diffs.idxmin()
    if diffs.loc[idx] > tol:
        return None
    return frame.loc[idx]


def compute_buy_pct_at_checkpoints(es_df):
    """
    For each trading day, compute cumulative buy_pct at 30/60/90 min checkpoints.
    buy_pct = cumulative_buy_volume / cumulative_total_volume since 9:30.

    Also records the checkpoint close/timestamp and the 1H/3H forward returns
    (in bps) from each checkpoint. Returns one row per qualifying day.

    Fix vs. previous version: the checkpoint and forward-return bars are now the
    bar genuinely NEAREST the target time (within ±1 min), instead of the first
    bar inside the window — which systematically picked a bar ~1 min early.
    """
    # Restrict to regular trading hours [9:30, 16:00) ET.
    rth = es_df[(es_df['time_et'] >= RTH_OPEN) & (es_df['time_et'] < RTH_CLOSE)].copy()

    results = []
    for date, day_df in rth.groupby('date'):
        day_df = day_df.sort_values('timestamp_et').copy()
        day_vol = day_df['volume'].sum()
        if day_vol < 100_000:
            continue  # holiday / half-session filter

        # Cumulative buy and total volume from the open.
        day_df['cum_buy'] = day_df['buy_volume'].cumsum()
        day_df['cum_vol'] = day_df['volume'].cumsum()
        day_df['cum_buy_pct'] = day_df['cum_buy'] / day_df['cum_vol']

        # Minutes since 9:30 ET (anchored to the day's first bar's date/tz).
        open_ts = day_df['timestamp_et'].iloc[0]
        open_930 = open_ts.replace(hour=9, minute=30, second=0, microsecond=0)
        day_df['mins_since_open'] = (day_df['timestamp_et'] - open_930).dt.total_seconds() / 60.0

        row = {'date': date, 'rth_volume': day_vol}

        # Sample buy_pct / close / timestamp at each checkpoint.
        for cp_name, cp_mins in CHECKPOINTS.items():
            cp_bar = _nearest_row(day_df, day_df['mins_since_open'], cp_mins, 1.0)
            if cp_bar is not None:
                row[f'buy_pct_{cp_name}'] = cp_bar['cum_buy_pct']
                row[f'close_{cp_name}'] = cp_bar['close']
                row[f'timestamp_{cp_name}'] = cp_bar['timestamp_et']
            else:
                row[f'buy_pct_{cp_name}'] = np.nan
                row[f'close_{cp_name}'] = np.nan
                row[f'timestamp_{cp_name}'] = pd.NaT

        # Forward returns (bps): 1H and 3H from each checkpoint close.
        for cp_name in CHECKPOINTS:
            cp_ts = row.get(f'timestamp_{cp_name}')
            if pd.isna(cp_ts):
                row[f'fwd_1h_{cp_name}'] = np.nan
                row[f'fwd_3h_{cp_name}'] = np.nan
                continue

            cp_close = row[f'close_{cp_name}']
            for hz_name, hz_hours in (('1h', 1), ('3h', 3)):
                target_ts = cp_ts + pd.Timedelta(hours=hz_hours)
                fwd_bar = _nearest_row(day_df, day_df['timestamp_et'],
                                       target_ts, pd.Timedelta(minutes=1))
                if fwd_bar is not None:
                    row[f'fwd_{hz_name}_{cp_name}'] = (fwd_bar['close'] - cp_close) / cp_close * 10000  # bps
                else:
                    row[f'fwd_{hz_name}_{cp_name}'] = np.nan

        results.append(row)

    return pd.DataFrame(results)


def compute_gamma_tilt_from_trace(date_str, spot_price, checkpoint_times_et):
    """
    Compute gamma tilt from the TRACE parquet snapshot for one date.

    Tilt = |negative total gamma| at strikes above spot / total |negative gamma|.
    Returns {checkpoint_name: tilt} for checkpoints with usable data; empty
    dict when no parquet file exists for *date_str*.
    """
    fname = os.path.join(TRACE_DIR, f'intradayStrikeGEX_{date_str}.parquet')
    if not os.path.exists(fname):
        return {}

    df = pd.read_parquet(fname)

    # Total gamma summed across all participant categories (both gamma columns).
    gamma_cols = ['bd_gamma', 'cust_gamma', 'firm_gamma', 'mm_gamma', 'procust_gamma']
    gamma0_cols = ['bd_gamma_0', 'cust_gamma_0', 'firm_gamma_0', 'mm_gamma_0', 'procust_gamma_0']
    total_gamma = sum(df[c] for c in gamma_cols) + sum(df[c] for c in gamma0_cols)
    df['total_gamma'] = total_gamma
    # Keep only the negative (dealer short-gamma "magnet") part, as magnitude.
    df['neg_gamma'] = total_gamma.clip(upper=0).abs()

    out = {}
    for cp_name, cp_target_et in checkpoint_times_et.items():
        spot = spot_price.get(cp_name)
        if spot is None or pd.isna(spot):
            continue

        # Nearest TRACE snapshot timestamp to the checkpoint time.
        # TRACE timestamps are already in ET; checkpoint times are tz-aware ET.
        timestamps = df['timestamp'].unique()
        gaps = (pd.Series(timestamps) - cp_target_et).abs()
        nearest_ts = timestamps[gaps.idxmin()]

        # Discard snapshots more than 15 minutes from the checkpoint.
        if abs((nearest_ts - cp_target_et).total_seconds()) > 900:
            continue

        snap = df[df['timestamp'] == nearest_ts]
        if snap.empty:
            continue

        # Fraction of negative gamma sitting above spot; 0.5 = neutral fallback.
        above = snap.loc[snap['strike_price'] > spot, 'neg_gamma'].sum()
        denom = snap['neg_gamma'].sum()
        out[cp_name] = above / denom if denom > 0 else 0.5

    return out


def main() -> None:
    """Run the gamma-tilt × buy_pct interaction backtest end to end.

    Pipeline: load ES 1-min bars and FOMC dates; compute buy_pct and forward
    returns at the 30/60/90-min checkpoints; compute gamma tilt from TRACE
    parquets; filter out FOMC / corrupted / holiday / no-TRACE days; split
    IS/OOS chronologically; run four analyses (buy_pct standalone, tilt
    standalone, confirmation/disagreement, z-score composite); print summary
    tables and a verdict; save the results dict as JSON to OUTPUT_FILE.
    """
    print("=" * 80)
    print("GAMMA SHIFT TILT × BUY_PCT INTERACTION BACKTEST")
    print("=" * 80)
    
    # Load data
    print("\n[1] Loading ES 1-min delta bars...")
    es_df = load_es_bars()
    print(f"    Loaded {len(es_df):,} bars, {es_df['date'].nunique()} unique dates")
    print(f"    Date range: {es_df['date'].min()} to {es_df['date'].max()}")
    
    print("\n[2] Loading FOMC dates...")
    fomc_dates = load_fomc_dates()
    print(f"    {len(fomc_dates)} FOMC dates loaded")
    
    print("\n[3] Computing buy_pct at checkpoints...")
    bp_df = compute_buy_pct_at_checkpoints(es_df)
    print(f"    {len(bp_df)} trading days with buy_pct data")
    
    # List TRACE files available; file names encode the date as YYYY-MM-DD.
    trace_files = sorted(glob.glob(os.path.join(TRACE_DIR, 'intradayStrikeGEX_*.parquet')))
    trace_dates = set()
    for f in trace_files:
        d = os.path.basename(f).replace('intradayStrikeGEX_', '').replace('.parquet', '')
        trace_dates.add(d)
    print(f"\n[4] Found {len(trace_dates)} TRACE parquet files")
    
    # Filter out excluded dates
    bp_df['date_str'] = bp_df['date'].apply(lambda d: d.strftime('%Y-%m-%d'))
    bp_df['date_pd'] = pd.to_datetime(bp_df['date'])
    
    # Exclude FOMC
    bp_df['is_fomc'] = bp_df['date'].apply(lambda d: d in fomc_dates)
    n_fomc = bp_df['is_fomc'].sum()
    
    # Exclude corrupted TRACE dates
    bp_df['is_corrupt'] = (bp_df['date_pd'] >= CORRUPT_START) & (bp_df['date_pd'] <= CORRUPT_END)
    n_corrupt = bp_df['is_corrupt'].sum()
    
    # Exclude holidays (redundant with the filter inside compute_buy_pct_at_checkpoints,
    # but kept for explicit reporting of the count)
    bp_df['is_holiday'] = bp_df['rth_volume'] < 100_000
    n_holiday = bp_df['is_holiday'].sum()
    
    bp_df['has_trace'] = bp_df['date_str'].apply(lambda d: d in trace_dates)
    
    valid = bp_df[~bp_df['is_fomc'] & ~bp_df['is_corrupt'] & ~bp_df['is_holiday'] & bp_df['has_trace']].copy()
    print(f"\n[5] After filters:")
    print(f"    Excluded: {n_fomc} FOMC, {n_corrupt} corrupted, {n_holiday} holidays")
    print(f"    No TRACE data: {(~bp_df['has_trace']).sum()}")
    print(f"    Valid days: {len(valid)}")
    
    # Compute gamma tilt for each valid day
    print("\n[6] Computing gamma tilt from TRACE data...")
    tilt_data = []
    for idx, row in valid.iterrows():
        date_str = row['date_str']
        # Build checkpoint times in ET (only checkpoints that had a bar)
        cp_times = {}
        spot_prices = {}
        for cp_name in CHECKPOINTS:
            ts = row.get(f'timestamp_{cp_name}')
            if pd.notna(ts):
                cp_times[cp_name] = ts
                spot_prices[cp_name] = row[f'close_{cp_name}']
        
        tilts = compute_gamma_tilt_from_trace(date_str, spot_prices, cp_times)
        
        tilt_row = {'date': row['date']}
        for cp_name in CHECKPOINTS:
            tilt_row[f'tilt_{cp_name}'] = tilts.get(cp_name, np.nan)
        tilt_data.append(tilt_row)
    
    # NOTE(review): assumes at least one valid day; an empty tilt_data would make
    # this merge KeyError on 'date' — acceptable for a research script.
    tilt_df = pd.DataFrame(tilt_data)
    valid = valid.merge(tilt_df, on='date', how='left')
    
    # Check tilt coverage
    for cp in CHECKPOINTS:
        n_tilt = valid[f'tilt_{cp}'].notna().sum()
        print(f"    Tilt coverage at {cp}: {n_tilt}/{len(valid)} days")
    
    # IS/OOS split (first 60% IS, last 40% OOS) — chronological, no shuffling
    valid = valid.sort_values('date').reset_index(drop=True)
    split_idx = int(len(valid) * 0.6)
    valid['split'] = ['IS' if i < split_idx else 'OOS' for i in range(len(valid))]
    is_df = valid[valid['split'] == 'IS']
    oos_df = valid[valid['split'] == 'OOS']
    print(f"\n[7] IS/OOS split: IS={len(is_df)} days ({is_df['date'].min()} to {is_df['date'].max()})")
    print(f"    OOS={len(oos_df)} days ({oos_df['date'].min()} to {oos_df['date'].max()})")
    
    # ============================================================
    # ANALYSIS 1: Buy_pct standalone signal (IC + quintiles)
    # ============================================================
    print("\n" + "=" * 80)
    print("ANALYSIS 1: BUY_PCT AS STANDALONE SIGNAL")
    print("=" * 80)
    
    results = {'buy_pct_standalone': {}, 'gamma_tilt_standalone': {}, 
               'combined': {}, 'confirmation_disagreement': {}}
    
    for cp in CHECKPOINTS:
        print(f"\n--- Checkpoint: {cp} ---")
        for horizon in ['1h', '3h']:
            col_signal = f'buy_pct_{cp}'
            col_ret = f'fwd_{horizon}_{cp}'
            
            for split_name, split_df in [('IS', is_df), ('OOS', oos_df)]:
                mask = split_df[col_signal].notna() & split_df[col_ret].notna()
                sub = split_df[mask]
                if len(sub) < 20:
                    continue
                
                # Pearson IC and its t-stat (guarded against |IC| == 1)
                ic = sub[col_signal].corr(sub[col_ret])
                n = len(sub)
                t_stat = ic * np.sqrt(n - 2) / np.sqrt(1 - ic**2) if abs(ic) < 1 else 0
                
                # Quintile analysis (duplicates='drop' may yield fewer than 5 bins)
                sub = sub.copy()
                sub['q'] = pd.qcut(sub[col_signal], 5, labels=False, duplicates='drop')
                q_stats = sub.groupby('q').agg(
                    n=(col_ret, 'count'),
                    mean_ret=(col_ret, 'mean'),
                    wr=(col_ret, lambda x: (x > 0).mean())
                ).to_dict('index')
                
                # Check monotonicity: fraction of adjacent quintile-mean moves in
                # the dominant direction (1.0 = perfectly monotone)
                means = [q_stats[q]['mean_ret'] for q in sorted(q_stats.keys())]
                mono = 0
                if len(means) >= 3:
                    diffs = [means[i+1] - means[i] for i in range(len(means)-1)]
                    pos = sum(1 for d in diffs if d > 0)
                    neg = sum(1 for d in diffs if d < 0)
                    mono = max(pos, neg) / len(diffs) if diffs else 0
                
                # Q5 falls back to the highest surviving bin when qcut dropped bins
                key = f'{cp}_{horizon}_{split_name}'
                results['buy_pct_standalone'][key] = {
                    'checkpoint': cp, 'horizon': horizon, 'split': split_name,
                    'IC': round(ic, 4), 't_stat': round(t_stat, 2), 'N': n,
                    'monotonicity': round(mono, 2),
                    'Q1_wr': round(q_stats.get(0, {}).get('wr', 0), 3),
                    'Q5_wr': round(q_stats.get(4, {}).get('wr', 0) if 4 in q_stats else q_stats.get(max(q_stats.keys()), {}).get('wr', 0), 3),
                    'Q1_ret': round(q_stats.get(0, {}).get('mean_ret', 0), 2),
                    'Q5_ret': round(q_stats.get(4, {}).get('mean_ret', 0) if 4 in q_stats else q_stats.get(max(q_stats.keys()), {}).get('mean_ret', 0), 2),
                }
                
                print(f"  {split_name} {horizon}: IC={ic:.4f} t={t_stat:.2f} N={n} mono={mono:.2f} "
                      f"Q1_WR={q_stats.get(0,{}).get('wr',0):.1%} Q5_WR={q_stats.get(4,q_stats.get(max(q_stats.keys()),{})).get('wr',0):.1%}")
    
    # ============================================================
    # ANALYSIS 2: Gamma Tilt standalone (sanity check)
    # ============================================================
    print("\n" + "=" * 80)
    print("ANALYSIS 2: GAMMA TILT STANDALONE (SANITY CHECK)")
    print("=" * 80)
    
    for cp in CHECKPOINTS:
        print(f"\n--- Checkpoint: {cp} ---")
        col_tilt = f'tilt_{cp}'
        
        for horizon in ['1h', '3h']:
            col_ret = f'fwd_{horizon}_{cp}'
            
            for split_name, split_df in [('IS', is_df), ('OOS', oos_df)]:
                mask = split_df[col_tilt].notna() & split_df[col_ret].notna()
                sub = split_df[mask]
                if len(sub) < 20:
                    print(f"  {split_name} {horizon}: N={len(sub)} (too few)")
                    continue
                
                ic = sub[col_tilt].corr(sub[col_ret])
                n = len(sub)
                t_stat = ic * np.sqrt(n - 2) / np.sqrt(1 - ic**2) if abs(ic) < 1 else 0
                
                # Win rate at various tilt thresholds (min 5 qualifying days)
                thresh_stats = {}
                for thresh in TILT_THRESHOLDS:
                    bullish = sub[sub[col_tilt] > thresh]
                    if len(bullish) >= 5:
                        wr = (bullish[col_ret] > 0).mean()
                        avg = bullish[col_ret].mean()
                        thresh_stats[f'>{thresh:.0%}'] = {'N': len(bullish), 'WR': round(wr, 3), 'avg_bps': round(avg, 2)}
                
                key = f'{cp}_{horizon}_{split_name}'
                results['gamma_tilt_standalone'][key] = {
                    'checkpoint': cp, 'horizon': horizon, 'split': split_name,
                    'IC': round(ic, 4), 't_stat': round(t_stat, 2), 'N': n,
                    'threshold_stats': thresh_stats
                }
                
                print(f"  {split_name} {horizon}: IC={ic:.4f} t={t_stat:.2f} N={n}")
                for thr, st in thresh_stats.items():
                    print(f"    Tilt {thr}: WR={st['WR']:.1%} avg={st['avg_bps']:.1f}bps N={st['N']}")
    
    # ============================================================
    # ANALYSIS 3: Confirmation / Disagreement
    # ============================================================
    print("\n" + "=" * 80)
    print("ANALYSIS 3: TILT × BUY_PCT CONFIRMATION / DISAGREEMENT")
    print("=" * 80)
    
    for cp in CHECKPOINTS:
        print(f"\n--- Checkpoint: {cp} ---")
        col_tilt = f'tilt_{cp}'
        col_bp = f'buy_pct_{cp}'
        
        for horizon in ['1h', '3h']:
            col_ret = f'fwd_{horizon}_{cp}'
            
            for split_name, split_df in [('IS', is_df), ('OOS', oos_df)]:
                mask = split_df[col_tilt].notna() & split_df[col_bp].notna() & split_df[col_ret].notna()
                sub = split_df[mask]
                if len(sub) < 20:
                    continue
                
                print(f"\n  {split_name} {horizon}:")
                
                for thresh in [0.65, 0.70, 0.75, 0.80]:
                    # Tilt alone (baseline for the confirmation comparison)
                    tilt_bull = sub[sub[col_tilt] > thresh]
                    if len(tilt_bull) < 5:
                        continue
                    wr_tilt = (tilt_bull[col_ret] > 0).mean()
                    avg_tilt = tilt_bull[col_ret].mean()
                    
                    # Tilt + buy_pct confirm (both bullish)
                    confirm = sub[(sub[col_tilt] > thresh) & (sub[col_bp] > BUYPCT_BULL)]
                    if len(confirm) >= 3:
                        wr_conf = (confirm[col_ret] > 0).mean()
                        avg_conf = confirm[col_ret].mean()
                    else:
                        wr_conf = np.nan
                        avg_conf = np.nan
                    
                    # Tilt + buy_pct disagree (tilt bullish, flow bearish)
                    disagree = sub[(sub[col_tilt] > thresh) & (sub[col_bp] < BUYPCT_BEAR)]
                    if len(disagree) >= 3:
                        wr_dis = (disagree[col_ret] > 0).mean()
                        avg_dis = disagree[col_ret].mean()
                    else:
                        wr_dis = np.nan
                        avg_dis = np.nan
                    
                    key = f'{cp}_{horizon}_{split_name}_tilt{int(thresh*100)}'
                    results['confirmation_disagreement'][key] = {
                        'checkpoint': cp, 'horizon': horizon, 'split': split_name,
                        'tilt_threshold': thresh,
                        'tilt_alone': {'N': len(tilt_bull), 'WR': round(wr_tilt, 3), 'avg_bps': round(avg_tilt, 2)},
                        'both_bullish': {'N': len(confirm), 
                                        'WR': round(wr_conf, 3) if not np.isnan(wr_conf) else None,
                                        'avg_bps': round(avg_conf, 2) if not np.isnan(avg_conf) else None},
                        'tilt_bull_bp_bear': {'N': len(disagree),
                                             'WR': round(wr_dis, 3) if not np.isnan(wr_dis) else None,
                                             'avg_bps': round(avg_dis, 2) if not np.isnan(avg_dis) else None}
                    }
                    
                    n_conf = len(confirm)
                    n_dis = len(disagree)
                    wr_conf_s = f"{wr_conf:.1%}" if not np.isnan(wr_conf) else "N/A"
                    wr_dis_s = f"{wr_dis:.1%}" if not np.isnan(wr_dis) else "N/A"
                    print(f"    Tilt>{thresh:.0%}: alone WR={wr_tilt:.1%}(N={len(tilt_bull)}) | "
                          f"+bp_bull WR={wr_conf_s}(N={n_conf}) | +bp_bear WR={wr_dis_s}(N={n_dis})")
    
    # ============================================================
    # ANALYSIS 4: Combined composite model
    # ============================================================
    print("\n" + "=" * 80)
    print("ANALYSIS 4: COMBINED COMPOSITE (Z-SCORE AVERAGE)")
    print("=" * 80)
    
    for cp in CHECKPOINTS:
        print(f"\n--- Checkpoint: {cp} ---")
        col_tilt = f'tilt_{cp}'
        col_bp = f'buy_pct_{cp}'
        
        for horizon in ['1h', '3h']:
            col_ret = f'fwd_{horizon}_{cp}'
            
            # Use IS to compute z-score params, apply to both (no look-ahead)
            mask_is = is_df[col_tilt].notna() & is_df[col_bp].notna() & is_df[col_ret].notna()
            is_sub = is_df[mask_is].copy()
            
            if len(is_sub) < 20:
                continue
            
            # Z-score standardize using IS mean/std
            tilt_mean = is_sub[col_tilt].mean()
            tilt_std = is_sub[col_tilt].std()
            bp_mean = is_sub[col_bp].mean()
            bp_std = is_sub[col_bp].std()
            
            for split_name, split_df in [('IS', is_df), ('OOS', oos_df)]:
                mask = split_df[col_tilt].notna() & split_df[col_bp].notna() & split_df[col_ret].notna()
                sub = split_df[mask].copy()
                if len(sub) < 20:
                    continue
                
                # Equal-weight average of the two z-scored signals
                sub['tilt_z'] = (sub[col_tilt] - tilt_mean) / tilt_std
                sub['bp_z'] = (sub[col_bp] - bp_mean) / bp_std
                sub['composite'] = (sub['tilt_z'] + sub['bp_z']) / 2
                
                # IC of composite vs individual
                ic_tilt = sub[col_tilt].corr(sub[col_ret])
                ic_bp = sub[col_bp].corr(sub[col_ret])
                ic_comp = sub['composite'].corr(sub[col_ret])
                n = len(sub)
                
                t_comp = ic_comp * np.sqrt(n - 2) / np.sqrt(1 - ic_comp**2) if abs(ic_comp) < 1 else 0
                
                # Quintile WR for composite (Q5 = highest surviving bin)
                sub['q'] = pd.qcut(sub['composite'], 5, labels=False, duplicates='drop')
                q1 = sub[sub['q'] == 0]
                q5 = sub[sub['q'] == sub['q'].max()]
                
                key = f'{cp}_{horizon}_{split_name}'
                results['combined'][key] = {
                    'checkpoint': cp, 'horizon': horizon, 'split': split_name,
                    'IC_tilt': round(ic_tilt, 4),
                    'IC_buypct': round(ic_bp, 4),
                    'IC_composite': round(ic_comp, 4),
                    't_stat_composite': round(t_comp, 2),
                    'N': n,
                    'Q1_wr': round((q1[col_ret] > 0).mean(), 3) if len(q1) > 0 else None,
                    'Q5_wr': round((q5[col_ret] > 0).mean(), 3) if len(q5) > 0 else None,
                    'Q1_ret': round(q1[col_ret].mean(), 2) if len(q1) > 0 else None,
                    'Q5_ret': round(q5[col_ret].mean(), 2) if len(q5) > 0 else None,
                }
                
                print(f"  {split_name} {horizon}: IC_tilt={ic_tilt:.4f} IC_bp={ic_bp:.4f} IC_combo={ic_comp:.4f} "
                      f"t={t_comp:.2f} N={n}")
                if len(q1) > 0 and len(q5) > 0:
                    print(f"    Q1(bearish): WR={(q1[col_ret]>0).mean():.1%} avg={q1[col_ret].mean():.1f}bps")
                    print(f"    Q5(bullish): WR={(q5[col_ret]>0).mean():.1%} avg={q5[col_ret].mean():.1f}bps")
    
    # ============================================================
    # SUMMARY TABLE
    # ============================================================
    print("\n" + "=" * 80)
    print("SUMMARY TABLE")
    print("=" * 80)
    
    print("\n┌─────────────────────────────────────────────────────────────────────────────┐")
    print("│ Buy_pct Standalone IC                                                       │")
    print("├───────────┬──────────┬────────┬────────┬────────┬──────┬───────┬───────────┤")
    print("│ Checkpoint│ Horizon  │ Split  │   IC   │ t-stat │  N   │ Q1 WR │  Q5 WR    │")
    print("├───────────┼──────────┼────────┼────────┼────────┼──────┼───────┼───────────┤")
    for key, val in sorted(results['buy_pct_standalone'].items()):
        print(f"│ {val['checkpoint']:9s} │ {val['horizon']:8s} │ {val['split']:6s} │ {val['IC']:+.4f} │ {val['t_stat']:+6.2f} │ {val['N']:4d} │ {val['Q1_wr']:.1%}  │ {val['Q5_wr']:.1%}     │")
    print("└───────────┴──────────┴────────┴────────┴────────┴──────┴───────┴───────────┘")
    
    print("\n┌─────────────────────────────────────────────────────────────────────────────┐")
    print("│ Gamma Tilt × Buy_pct: Confirmation vs Disagreement (1H forward)             │")
    print("├───────────┬─────────┬────────┬──────────────────┬──────────────────┬─────────┤")
    print("│ Checkpoint│ Tilt Thr│ Split  │  Both Bull WR(N) │  Disagree WR(N)  │Tilt WR  │")
    print("├───────────┼─────────┼────────┼──────────────────┼──────────────────┼─────────┤")
    for key, val in sorted(results['confirmation_disagreement'].items()):
        if '_1h_' not in key:
            continue
        bb = val['both_bullish']
        dis = val['tilt_bull_bp_bear']
        ta = val['tilt_alone']
        bb_s = f"{bb['WR']:.1%}({bb['N']})" if bb['WR'] is not None else f"N/A({bb['N']})"
        dis_s = f"{dis['WR']:.1%}({dis['N']})" if dis['WR'] is not None else f"N/A({dis['N']})"
        print(f"│ {val['checkpoint']:9s} │ >{val['tilt_threshold']:.0%}   │ {val['split']:6s} │ {bb_s:16s} │ {dis_s:16s} │ {ta['WR']:.1%}   │")
    print("└───────────┴─────────┴────────┴──────────────────┴──────────────────┴─────────┘")
    
    print("\n┌─────────────────────────────────────────────────────────────────────────────┐")
    print("│ Combined Model IC Comparison                                                │")
    print("├───────────┬──────────┬────────┬──────────┬──────────┬──────────┬─────┬───────┤")
    print("│ Checkpoint│ Horizon  │ Split  │ IC_tilt  │ IC_bp    │ IC_combo │  N  │ t-stat│")
    print("├───────────┼──────────┼────────┼──────────┼──────────┼──────────┼─────┼───────┤")
    for key, val in sorted(results['combined'].items()):
        print(f"│ {val['checkpoint']:9s} │ {val['horizon']:8s} │ {val['split']:6s} │ {val['IC_tilt']:+.4f}  │ {val['IC_buypct']:+.4f}  │ {val['IC_composite']:+.4f}  │ {val['N']:3d} │ {val['t_stat_composite']:+5.2f} │")
    print("└───────────┴──────────┴────────┴──────────┴──────────┴──────────┴─────┴───────┘")
    
    # Save results
    # Convert any remaining non-serializable types (numpy scalars/arrays, timestamps)
    def make_serializable(obj):
        # Returns a JSON-friendly value, or the object unchanged if already fine.
        if isinstance(obj, (np.integer,)):
            return int(obj)
        elif isinstance(obj, (np.floating,)):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, pd.Timestamp):
            return str(obj)
        elif isinstance(obj, (datetime,)):
            return obj.isoformat()
        return obj
    
    class NpEncoder(json.JSONEncoder):
        # JSON encoder that delegates unknown types to make_serializable;
        # identity check detects whether a conversion actually happened.
        def default(self, obj):
            val = make_serializable(obj)
            if val is not obj:
                return val
            return super().default(obj)
    
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(results, f, indent=2, cls=NpEncoder)
    
    print(f"\n✅ Results saved to {OUTPUT_FILE}")
    
    # Final verdict
    print("\n" + "=" * 80)
    print("VERDICT")
    print("=" * 80)
    
    # Check if buy_pct has meaningful IC (|t| > 2 flagged as significant)
    bp_ics = [(v['IC'], v['t_stat'], v['split'], v['checkpoint'], v['horizon']) 
              for v in results['buy_pct_standalone'].values()]
    oos_bp = [x for x in bp_ics if x[2] == 'OOS']
    
    print("\nBuy_pct OOS performance:")
    for ic, t, split, cp, hz in sorted(oos_bp, key=lambda x: abs(x[0]), reverse=True):
        sig = "✅" if abs(t) > 2 else "❌"
        print(f"  {sig} {cp} {hz}: IC={ic:+.4f} t={t:+.2f}")
    
    # Check confirmation effect (OOS, 1H horizon, tilt>75% only)
    print("\nDoes buy_pct improve gamma tilt?")
    for key, val in sorted(results['confirmation_disagreement'].items()):
        if '_1h_OOS' in key and val['tilt_threshold'] == 0.75:
            ta = val['tilt_alone']
            bb = val['both_bullish']
            dis = val['tilt_bull_bp_bear']
            print(f"  {val['checkpoint']} 1H OOS (tilt>75%):")
            print(f"    Tilt alone: WR={ta['WR']:.1%} N={ta['N']}")
            if bb['WR'] is not None:
                diff = bb['WR'] - ta['WR']
                print(f"    +buy_pct bull: WR={bb['WR']:.1%} N={bb['N']} (Δ={diff:+.1%})")
            if dis['WR'] is not None:
                diff = dis['WR'] - ta['WR']
                print(f"    +buy_pct bear: WR={dis['WR']:.1%} N={dis['N']} (Δ={diff:+.1%})")


# Script entry point — run the whole backtest when executed directly.
if __name__ == '__main__':
    main()
