#!/usr/bin/env python3
"""
Final deep dive: reverse causality checks and forward-looking tests
for Call Wall Ratio and other top signals.
"""

import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
import json
import warnings
warnings.filterwarnings('ignore')

# Workspace layout: raw inputs live under DATA_DIR; normalized per-day TRACE
# strike-level GEX snapshots live under TRACE_DIR (one parquet per trading date).
DATA_DIR = Path('/Users/lutherbot/.openclaw/workspace/data')
TRACE_DIR = DATA_DIR / 'trace_normalized'

# FOMC announcement dates, used to exclude known regime-breaking days.
# Assumes the JSON 'dates' entries use the same string format as the
# parquet file-stem dates they are compared against below — TODO confirm.
with open(DATA_DIR / 'fomc_dates.json') as f:
    FOMC_DATES = set(json.load(f)['dates'])

def get_gex_tier(val):
    """Map a net GEX value (dollars) to a discrete regime-tier label.

    Buckets (half-open, upper bound exclusive):
      (-inf, -100M) DEEP_NEG; [-100M, 0) NEG; [0, 100M) LOW_POS;
      [100M, 250M) MID_POS; [250M, 500M) HIGH_POS; [500M, inf) EXTREME_POS.
    """
    boundaries = (
        (-100e6, 'DEEP_NEG'),
        (0, 'NEG'),
        (100e6, 'LOW_POS'),
        (250e6, 'MID_POS'),
        (500e6, 'HIGH_POS'),
    )
    for upper, tier in boundaries:
        if val < upper:
            return tier
    return 'EXTREME_POS'

# Load SPX 5-minute bars and normalize timestamps to US/Eastern so that the
# hour/minute filters later in the script line up with the session clock.
spx = pd.read_csv(DATA_DIR / 'spx_5min_polygon.csv', parse_dates=['datetime'])
spx['datetime'] = pd.to_datetime(spx['datetime'], utc=True).dt.tz_convert('America/New_York')
spx = spx.sort_values('datetime').reset_index(drop=True)
spx['date'] = spx['datetime'].dt.date  # session date (datetime.date), used as a group/join key

print("Loading TRACE and computing call wall ratio + tilt...")
files = sorted(TRACE_DIR.glob('intradayStrikeGEX_*.parquet'))

# One dict per trading day, holding snapshot features at fixed intraday times.
all_daily = []
for i, f in enumerate(files):
    # File stems look like 'intradayStrikeGEX_<date>'; the trailing token is
    # the date string used for the FOMC filter and, later, the IS/OOS split.
    date_str = f.stem.split('_')[-1]
    if date_str in FOMC_DATES:
        continue  # exclude FOMC days (regime-breaking macro events)
    try:
        df = pd.read_parquet(f)
    except Exception:
        # Skip unreadable/corrupt files. Was a bare `except:`, which would
        # also swallow KeyboardInterrupt/SystemExit.
        continue
    
    row = {'date_str': date_str, 'date': pd.Timestamp(date_str)}
    
    # Snapshot features at fixed intraday times (Eastern session clock).
    for target_h, target_m, label in [
        (9, 30, '0930'), (9, 40, '0940'), (10, 0, '1000'), (10, 30, '1030'),
        (11, 0, '1100'), (12, 0, '1200'), (13, 0, '1300'), (14, 0, '1400'),
    ]:
        snap = df[(df['timestamp'].dt.hour == target_h) & (df['timestamp'].dt.minute == target_m)]
        if len(snap) == 0:
            continue  # no snapshot at this time for this day
        
        # Net gamma (market-maker and customer), summed across strikes
        row[f'mm_{label}'] = snap['mm_gamma'].sum()
        row[f'cust_{label}'] = snap['cust_gamma'].sum()
        
        # Tilt: net MM gamma at/above the median strike, normalized by
        # |above| + |below| (so tilt is in [-1, 1] around the strike center)
        center = snap['strike_price'].median()
        above = snap[snap['strike_price'] >= center]['mm_gamma'].sum()
        below = snap[snap['strike_price'] < center]['mm_gamma'].sum()
        total = abs(above) + abs(below)
        if total > 0:
            row[f'tilt_{label}'] = above / total
        
        # Call Wall Ratio (positive gamma above center / total positive gamma)
        pos_above = snap[snap['strike_price'] >= center]['mm_gamma'].clip(lower=0).sum()
        pos_below = snap[snap['strike_price'] < center]['mm_gamma'].clip(lower=0).sum()
        total_pos = pos_above + pos_below
        if total_pos > 0:
            row[f'cwr_{label}'] = pos_above / total_pos
        
        # Put wall strength (negative gamma below center / total negative gamma)
        neg_above = snap[snap['strike_price'] >= center]['mm_gamma'].clip(upper=0).abs().sum()
        neg_below = snap[snap['strike_price'] < center]['mm_gamma'].clip(upper=0).abs().sum()
        total_neg = neg_above + neg_below
        if total_neg > 0:
            row[f'put_wall_{label}'] = neg_below / total_neg
        
        # Concentration (HHI over |gamma| weights) and the gamma-weighted
        # standard deviation of strikes around their weighted mean
        mm_abs = snap['mm_gamma'].abs()
        total_abs = mm_abs.sum()
        if total_abs > 0:
            row[f'hhi_{label}'] = ((mm_abs / total_abs) ** 2).sum()
            weighted_mean = np.average(snap['strike_price'], weights=mm_abs)
            row[f'spread_{label}'] = np.sqrt(np.average((snap['strike_price'] - weighted_mean)**2, weights=mm_abs))
    
    # GEX velocity: change in net MM gamma over the first hour
    if 'mm_0930' in row and 'mm_1030' in row:
        row['velocity_1h'] = row['mm_1030'] - row['mm_0930']
    
    # Tilt change 9:30 → 10:30
    if 'tilt_0930' in row and 'tilt_1030' in row:
        row['tilt_delta'] = row['tilt_1030'] - row['tilt_0930']
    
    # CWR change 9:30 → 10:30
    if 'cwr_0930' in row and 'cwr_1030' in row:
        row['cwr_delta'] = row['cwr_1030'] - row['cwr_0930']
    
    # Opening GEX tier (fall back to the 9:40 snapshot if 9:30 is missing)
    open_gex = row.get('mm_0930', row.get('mm_0940', np.nan))
    try:
        if isinstance(open_gex, (int, float, np.floating, np.integer)) and not np.isnan(float(open_gex)):
            row['gex_tier'] = get_gex_tier(float(open_gex))
    except (TypeError, ValueError):
        # Guard only the numeric coercion; was a bare `except:`.
        pass
    
    all_daily.append(row)
    if (i+1) % 100 == 0:
        print(f"  {i+1}...")

daily = pd.DataFrame(all_daily)

# Merge SPX: collapse 5-min bars to daily OHLC, then inner-join so only
# dates present in BOTH the TRACE features and the price data survive.
spx_daily = spx.groupby('date').agg(
    spx_open=('open', 'first'), spx_high=('high', 'max'),
    spx_low=('low', 'min'), spx_close=('close', 'last')
).reset_index()
spx_daily['date'] = pd.to_datetime(spx_daily['date'])
daily['date'] = pd.to_datetime(daily['date'])
daily = daily.merge(spx_daily, on='date', how='inner')
daily['ret_oc'] = daily['spx_close'] / daily['spx_open'] - 1  # open→close return
daily['range_pct'] = (daily['spx_high'] - daily['spx_low']) / daily['spx_open']  # daily range as fraction of open

# Build forward returns from each time
# For each day, anchor prices at fixed intraday times and compute:
#   ret_{T}_close   : T → session close (forward-looking target)
#   price_{T}       : bar close at T
#   ret_open_{T}    : open → T (backward-looking, for reverse-causality tests)
#   ret_1030_{1h,2h}: 10:30 → 11:30 / 12:30
# NOTE(review): this re-filters the full `spx` frame once per day (O(days × bars));
# a precomputed groupby('date') would be faster, but correctness is unaffected.
print("Building forward returns...")
for idx, row in daily.iterrows():
    date_val = row['date'].date()
    day_spx = spx[spx['date'] == date_val].sort_values('datetime')
    if len(day_spx) < 5:
        continue  # too few bars to be a usable session
    close_price = day_spx.iloc[-1]['close']
    
    for th, tm, label in [
        (9, 30, '0930'), (10, 0, '1000'), (10, 30, '1030'),
        (11, 0, '1100'), (12, 0, '1200'), (13, 0, '1300'), (14, 0, '1400'),
    ]:
        snap = day_spx[(day_spx['datetime'].dt.hour == th) & (day_spx['datetime'].dt.minute == tm)]
        if len(snap) > 0:
            daily.loc[idx, f'ret_{label}_close'] = close_price / snap.iloc[0]['close'] - 1
            daily.loc[idx, f'price_{label}'] = snap.iloc[0]['close']
    
    # Return from open to various times (for reverse causality)
    open_price = day_spx.iloc[0]['open']
    for th, tm, label in [
        (10, 0, '1000'), (10, 30, '1030'), (12, 0, '1200'), (14, 0, '1400'),
    ]:
        snap = day_spx[(day_spx['datetime'].dt.hour == th) & (day_spx['datetime'].dt.minute == tm)]
        if len(snap) > 0:
            daily.loc[idx, f'ret_open_{label}'] = snap.iloc[0]['close'] / open_price - 1
    
    # 1H and 2H forward from 10:30
    for th, tm, label in [(11, 30, '1h'), (12, 30, '2h')]:
        snap = day_spx[(day_spx['datetime'].dt.hour == th) & (day_spx['datetime'].dt.minute == tm)]
        t1030 = day_spx[(day_spx['datetime'].dt.hour == 10) & (day_spx['datetime'].dt.minute == 30)]
        if len(snap) > 0 and len(t1030) > 0:
            daily.loc[idx, f'ret_1030_{label}'] = snap.iloc[0]['close'] / t1030.iloc[0]['close'] - 1

# IS/OOS split: chronological 60/40 over unique dates (no shuffling), so the
# out-of-sample period is strictly later in time than the in-sample period.
daily = daily.sort_values('date').reset_index(drop=True)
all_dates = sorted(daily['date_str'].unique())
n_is = int(len(all_dates) * 0.6)
is_dates = set(all_dates[:n_is])
oos_dates = set(all_dates[n_is:])
daily['is_oos'] = daily['date_str'].isin(oos_dates)  # True for the later 40% of dates
daily['is_is'] = daily['date_str'].isin(is_dates)    # True for the earlier 60%
oos = daily[daily['is_oos']].copy()
iis = daily[daily['is_is']].copy()

print(f"IS: {len(iis)}, OOS: {len(oos)}")

# ============================================================
# CALL WALL RATIO - DEEP ANALYSIS
# ============================================================
print("\n" + "="*70)
print("CALL WALL RATIO (CWR) - COMPREHENSIVE ANALYSIS")
print("="*70)

# 1. Forward-looking test: does the CWR snapshot at time T carry any
#    information about the remaining T→close move (OOS Spearman IC)?
print("\n--- CWR at time T → T→close (FORWARD RETURN) ---")
for t in ['0930', '1000', '1030', '1200', '1400']:
    signal_col, target_col = f'cwr_{t}', f'ret_{t}_close'
    if signal_col not in oos.columns or target_col not in oos.columns:
        continue  # no snapshot at this time on any day
    clean = oos.dropna(subset=[signal_col, target_col])
    if len(clean) <= 20:
        continue  # too few observations for a meaningful rank correlation
    rho, pval = stats.spearmanr(clean[signal_col], clean[target_col])
    print(f"  CWR@{t} → {t}→Close: IC={rho:+.4f}, p={pval:.4f}, N={len(clean)}")

# 2. Reverse causality: Does return from open→T predict CWR at T?
#    (If yes, CWR is partly just a restatement of realized price action.)
print("\n--- REVERSE CAUSALITY: Open→T return vs CWR@T ---")
for time_label in ['1000', '1030', '1200', '1400']:
    cwr_col = f'cwr_{time_label}'
    lag_col = f'ret_open_{time_label}'
    if cwr_col in oos.columns and lag_col in oos.columns:
        sub = oos.dropna(subset=[cwr_col, lag_col])
        if len(sub) > 20:
            # Note the argument order: past return is the predictor here
            ic, p = stats.spearmanr(sub[lag_col], sub[cwr_col])
            print(f"  Open→{time_label} return → CWR@{time_label}: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")

# 3. CWR at 9:30 (before any intraday price action) → O→C
print("\n--- CWR@9:30 (pre-market gamma) → O→C ---")
for split, mask in [('IS', daily['is_is']), ('OOS', daily['is_oos'])]:
    sub = daily[mask].dropna(subset=['cwr_0930', 'ret_oc'])
    if len(sub) > 20:
        ic, p = stats.spearmanr(sub['cwr_0930'], sub['ret_oc'])
        # Quintiles
        # NOTE(review): qcut with 5 explicit labels plus duplicates='drop'
        # raises if duplicate bin edges actually get dropped (label count must
        # match the surviving bins) — assumes CWR values are granular enough.
        sub = sub.copy()
        sub['q'] = pd.qcut(sub['cwr_0930'], 5, labels=['Q1', 'Q2', 'Q3', 'Q4', 'Q5'], duplicates='drop')
        print(f"\n  {split} CWR@9:30 → O→C: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
        for q in ['Q1', 'Q2', 'Q3', 'Q4', 'Q5']:
            qd = sub[sub['q'] == q]
            if len(qd) > 0:
                wr = (qd['ret_oc'] > 0).mean()          # win rate (share of up days)
                avg = qd['ret_oc'].mean() * 10000       # mean return in bps
                print(f"    {q}: WR={wr:.1%}, Avg={avg:+.1f}bps, N={len(qd)}, Range=[{qd['cwr_0930'].min():.3f}, {qd['cwr_0930'].max():.3f}]")

# 4. CWR at 10:30 → 10:30→close (the true forward test)
print("\n--- CWR@10:30 FORWARD: → 10:30→Close ---")
for split, mask in [('IS', daily['is_is']), ('OOS', daily['is_oos'])]:
    sub = daily[mask].dropna(subset=['cwr_1030', 'ret_1030_close'])
    if len(sub) > 20:
        ic, p = stats.spearmanr(sub['cwr_1030'], sub['ret_1030_close'])
        sub = sub.copy()
        sub['q'] = pd.qcut(sub['cwr_1030'], 5, labels=['Q1', 'Q2', 'Q3', 'Q4', 'Q5'], duplicates='drop')
        print(f"\n  {split} CWR@10:30 → 10:30→Close: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
        for q in ['Q1', 'Q2', 'Q3', 'Q4', 'Q5']:
            qd = sub[sub['q'] == q]
            if len(qd) > 0:
                wr = (qd['ret_1030_close'] > 0).mean()
                avg = qd['ret_1030_close'].mean() * 10000
                print(f"    {q}: WR={wr:.1%}, Avg={avg:+.1f}bps, N={len(qd)}")

# 5. Shorter horizons: does CWR@10:30 predict the next 1H / 2H move (OOS)?
print("\n--- CWR@10:30 → 1H and 2H forward ---")
horizons = (('ret_1030_1h', '1H fwd'), ('ret_1030_2h', '2H fwd'))
for horizon_col, horizon_name in horizons:
    if horizon_col not in oos.columns:
        continue
    valid = oos.dropna(subset=['cwr_1030', horizon_col])
    if len(valid) <= 20:
        continue
    rho, pval = stats.spearmanr(valid['cwr_1030'], valid[horizon_col])
    print(f"  CWR@10:30 → {horizon_name}: IC={rho:+.4f}, p={pval:.4f}, N={len(valid)}")

# ============================================================
# TILT DELTA (change in tilt 9:30→10:30) - FORWARD TEST
# ============================================================
print("\n" + "="*70)
print("TILT DELTA (9:30→10:30) - FORWARD TEST")
print("="*70)

# Does tilt change predict 10:30→close?
print("\n--- Tilt Delta → 10:30→Close ---")
for split, mask in [('IS', daily['is_is']), ('OOS', daily['is_oos'])]:
    sub = daily[mask].dropna(subset=['tilt_delta', 'ret_1030_close'])
    if len(sub) > 20:
        ic, p = stats.spearmanr(sub['tilt_delta'], sub['ret_1030_close'])
        # NOTE(review): qcut with explicit labels + duplicates='drop' raises if
        # duplicate bin edges are actually dropped — assumes enough granularity.
        sub = sub.copy()
        sub['q'] = pd.qcut(sub['tilt_delta'], 5, labels=['Q1', 'Q2', 'Q3', 'Q4', 'Q5'], duplicates='drop')
        print(f"\n  {split} Tilt Delta → 10:30→Close: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
        for q in ['Q1', 'Q2', 'Q3', 'Q4', 'Q5']:
            qd = sub[sub['q'] == q]
            if len(qd) > 0:
                wr = (qd['ret_1030_close'] > 0).mean()   # win rate
                avg = qd['ret_1030_close'].mean() * 10000  # mean return in bps
                print(f"    {q}: WR={wr:.1%}, Avg={avg:+.1f}bps, N={len(qd)}")

# Reverse causality: Does first hour return predict tilt delta?
print("\n--- REVERSE CAUSALITY: Open→10:30 return vs Tilt Delta ---")
sub = oos.dropna(subset=['tilt_delta', 'ret_open_1030'])
if len(sub) > 20:
    ic, p = stats.spearmanr(sub['ret_open_1030'], sub['tilt_delta'])
    print(f"  Open→10:30 return → Tilt Delta: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")

# ============================================================
# CWR DELTA - FORWARD TEST
# ============================================================
print("\n" + "="*70)
print("CWR DELTA (9:30→10:30) - FORWARD TEST")
print("="*70)

# cwr_delta only exists if at least one day had both 9:30 and 10:30 snapshots
if 'cwr_delta' in daily.columns:
    print("\n--- CWR Delta → O→C ---")
    for split, mask in [('IS', daily['is_is']), ('OOS', daily['is_oos'])]:
        sub = daily[mask].dropna(subset=['cwr_delta', 'ret_oc'])
        if len(sub) > 20:
            ic, p = stats.spearmanr(sub['cwr_delta'], sub['ret_oc'])
            print(f"  {split}: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
    
    print("\n--- CWR Delta → 10:30→Close ---")
    for split, mask in [('IS', daily['is_is']), ('OOS', daily['is_oos'])]:
        sub = daily[mask].dropna(subset=['cwr_delta', 'ret_1030_close'])
        if len(sub) > 20:
            ic, p = stats.spearmanr(sub['cwr_delta'], sub['ret_1030_close'])
            print(f"  {split}: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
    
    # Reverse causality
    sub = oos.dropna(subset=['cwr_delta', 'ret_open_1030'])
    if len(sub) > 20:
        ic, p = stats.spearmanr(sub['ret_open_1030'], sub['cwr_delta'])
        print(f"\n  REVERSE: Open→10:30 return → CWR Delta: IC={ic:+.4f}, p={p:.4f}")

# ============================================================
# RESIDUALIZED TEST: CWR/Tilt controlling for morning return
# ============================================================
print("\n" + "="*70)
print("RESIDUALIZED TESTS (controlling for morning return)")
print("="*70)

# Regress the signal on the morning (open→T) return, then test whether the
# residual — the part of the signal NOT explained by price action already
# realized — still predicts the T→close move.
# (Removed `from numpy.polynomial import polynomial as P`: the alias was
# never used anywhere; the OLS fit below uses np.polyfit directly.)

for sig_col, time_label in [('cwr_1030', '1030'), ('tilt_1030', '1030'), ('cwr_1200', '1200')]:
    lag_col = f'ret_open_{time_label}'   # backward-looking control
    fwd_col = f'ret_{time_label}_close'  # forward-looking target
    
    sub = oos.dropna(subset=[sig_col, lag_col, fwd_col]).copy()
    if len(sub) < 30:
        continue  # need a reasonable sample for the regression + IC
    
    # Residualize: degree-1 OLS of signal on morning return, keep the residual
    slope, intercept = np.polyfit(sub[lag_col], sub[sig_col], 1)
    sub['signal_resid'] = sub[sig_col] - (slope * sub[lag_col] + intercept)
    
    ic_raw, p_raw = stats.spearmanr(sub[sig_col], sub[fwd_col])
    ic_resid, p_resid = stats.spearmanr(sub['signal_resid'], sub[fwd_col])
    
    print(f"\n  {sig_col} → {fwd_col}:")
    print(f"    Raw IC:         {ic_raw:+.4f}, p={p_raw:.4f}")
    print(f"    Residualized IC: {ic_resid:+.4f}, p={p_resid:.4f}")
    print(f"    (after removing morning return component)")

# ============================================================
# CONDITIONAL CWR: by GEX tier
# ============================================================
print("\n" + "="*70)
print("CWR@10:30 BY GEX TIER (OOS)")
print("="*70)

# NOTE(review): assumes 'gex_tier' was assigned for at least one day in the
# loading loop; if it never was, the column is absent and this raises KeyError.
for tier in ['LOW_POS', 'MID_POS', 'HIGH_POS', 'EXTREME_POS']:
    sub = oos[(oos['gex_tier'] == tier)].dropna(subset=['cwr_1030', 'ret_oc'])
    if len(sub) < 10:
        continue  # too few days in this tier
    ic, p = stats.spearmanr(sub['cwr_1030'], sub['ret_oc'])
    
    # Fixed high/low CWR buckets (thresholds chosen a priori, not fitted)
    high = sub[sub['cwr_1030'] > 0.65]
    low = sub[sub['cwr_1030'] < 0.5]
    wr_high = (high['ret_oc'] > 0).mean() if len(high) > 0 else np.nan
    wr_low = (low['ret_oc'] > 0).mean() if len(low) > 0 else np.nan
    
    print(f"  {tier:12s}: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
    print(f"    High CWR (>0.65): WR={wr_high:.1%}, N={len(high)}")
    print(f"    Low CWR (<0.50):  WR={wr_low:.1%}, N={len(low)}")

# ============================================================
# PRACTICAL TRADING SIGNALS - THRESHOLD ANALYSIS
# ============================================================
print("\n" + "="*70)
print("PRACTICAL TRADING SIGNALS - THRESHOLD ANALYSIS (OOS)")
print("="*70)

# For CWR@10:30 and Tilt@10:30, test actionable thresholds
# NOTE(review): the target here is the full O→C return, which for the 10:30
# signals includes the already-known morning move — not purely forward-looking.
for sig_name, sig_col in [('CWR@10:30', 'cwr_1030'), ('Tilt@10:30', 'tilt_1030'), 
                           ('CWR@9:30', 'cwr_0930'), ('Tilt Delta', 'tilt_delta')]:
    print(f"\n--- {sig_name} ---")
    sub = oos.dropna(subset=[sig_col, 'ret_oc']).copy()
    if len(sub) < 30:
        continue
    
    # Test various thresholds
    # Delta signals are centered near 0; level signals are ratios around 0.5
    if 'delta' in sig_col.lower():
        thresholds_long = [0.05, 0.1, 0.15, 0.2]
        thresholds_short = [-0.05, -0.1, -0.15, -0.2]
    else:
        thresholds_long = [0.55, 0.6, 0.65, 0.7, 0.75]
        thresholds_short = [0.45, 0.4, 0.35, 0.3, 0.25]
    
    print(f"  LONG when {sig_name} > threshold:")
    for t in thresholds_long:
        triggered = sub[sub[sig_col] > t]
        if len(triggered) >= 5:
            wr = (triggered['ret_oc'] > 0).mean()      # long win rate
            avg = triggered['ret_oc'].mean() * 10000   # mean long PnL in bps
            freq = len(triggered) / len(sub) * 100     # trigger frequency (% of days)
            print(f"    >{t:.2f}: WR={wr:.1%}, Avg={avg:+.1f}bps, N={len(triggered)} ({freq:.0f}% of days)")
    
    print(f"  SHORT when {sig_name} < threshold:")
    for t in thresholds_short:
        triggered = sub[sub[sig_col] < t]
        if len(triggered) >= 5:
            wr = (triggered['ret_oc'] < 0).mean()  # short WR
            avg = -triggered['ret_oc'].mean() * 10000  # short PnL
            freq = len(triggered) / len(sub) * 100
            print(f"    <{t:.2f}: Short WR={wr:.1%}, Avg={avg:+.1f}bps, N={len(triggered)} ({freq:.0f}% of days)")

# ============================================================
# HHI and SPREAD - VOL SIGNALS
# ============================================================
print("\n" + "="*70)
print("VOL PREDICTION SIGNALS (OOS)")  
print("="*70)

# Test whether gamma concentration (HHI) / strike dispersion (spread) at the
# open predict the realized daily high-low range.
for sig_name, sig_col in [('HHI@9:30', 'hhi_0930'), ('Spread@9:30', 'spread_0930')]:
    sub = oos.dropna(subset=[sig_col, 'range_pct']).copy()
    if len(sub) < 30:
        continue
    
    ic, p = stats.spearmanr(sub[sig_col], sub['range_pct'])
    # NOTE(review): qcut with explicit labels + duplicates='drop' raises if
    # duplicate bin edges are actually dropped — assumes enough granularity.
    sub['q'] = pd.qcut(sub[sig_col], 5, labels=['Q1', 'Q2', 'Q3', 'Q4', 'Q5'], duplicates='drop')
    
    print(f"\n  {sig_name} → Daily Range: IC={ic:+.4f}, p={p:.4f}, N={len(sub)}")
    for q in ['Q1', 'Q2', 'Q3', 'Q4', 'Q5']:
        qd = sub[sub['q'] == q]
        if len(qd) > 0:
            avg_range = qd['range_pct'].mean() * 10000  # mean range in bps
            print(f"    {q}: Avg Range={avg_range:.0f}bps, N={len(qd)}")

# ============================================================
# VELOCITY → O→C: IS THE VELOCITY JUST FIRST HOUR RETURN?
# ============================================================
print("\n" + "="*70)
print("GEX VELOCITY: Is it just measuring price momentum?")
print("="*70)

# Correlation between velocity and first hour return
sub = oos.dropna(subset=['velocity_1h', 'ret_open_1030']).copy()
if len(sub) > 20:
    ic, p = stats.spearmanr(sub['velocity_1h'], sub['ret_open_1030'])
    print(f"  Velocity vs First Hour Return: IC={ic:+.4f}, p={p:.4f}")
    
    # Residualized test: degree-1 OLS of velocity on first-hour return
    slope, intercept = np.polyfit(sub['ret_open_1030'], sub['velocity_1h'], 1)
    sub['vel_resid'] = sub['velocity_1h'] - (slope * sub['ret_open_1030'] + intercept)
    
    # Does residualized velocity predict afternoon?
    # BUGFIX: drop rows with a missing forward return BEFORE correlating.
    # The old code first called spearmanr on vel_resid vs ret_1030_close.dropna()
    # — a length-mismatched pair that raises in scipy whenever the forward
    # return has NaNs — and then overwrote the result with the correct
    # computation below. The dead, crash-prone call is removed.
    sub2 = sub.dropna(subset=['vel_resid', 'ret_1030_close'])
    ic_resid, p_resid = stats.spearmanr(sub2['vel_resid'], sub2['ret_1030_close'])
    print(f"  Residualized Velocity → 10:30→Close: IC={ic_resid:+.4f}, p={p_resid:.4f}")

# ============================================================
# FINAL FREQ/WR SUMMARY FOR TRADEABLE SIGNALS
# ============================================================
print("\n" + "="*70)
print("FINAL TRADEABLE SIGNAL SUMMARY (OOS)")
print("="*70)

print("""
Key: Only signals with FORWARD-LOOKING power are truly tradeable.
Signals that just reflect past price action are "momentum confirmation" at best.
""")

# (label, signal column, target column) — OOS Spearman IC for each pairing
signal_tests = [
    ('CWR@9:30→O→C', 'cwr_0930', 'ret_oc'),
    ('CWR@10:30→O→C', 'cwr_1030', 'ret_oc'),
    ('CWR@10:30→10:30→Close', 'cwr_1030', 'ret_1030_close'),
    ('Tilt@10:30→O→C', 'tilt_1030', 'ret_oc'),
    ('Tilt@10:30→10:30→Close', 'tilt_1030', 'ret_1030_close'),
    ('Tilt Delta→O→C', 'tilt_delta', 'ret_oc'),
    ('Tilt Delta→10:30→Close', 'tilt_delta', 'ret_1030_close'),
    ('Velocity→O→C', 'velocity_1h', 'ret_oc'),
    ('Velocity→10:30→Close', 'velocity_1h', 'ret_1030_close'),
    ('HHI@9:30→Range', 'hhi_0930', 'range_pct'),
    ('Spread@9:30→Range', 'spread_0930', 'range_pct'),
]

for test_name, signal_col, target_col in signal_tests:
    clean = oos.dropna(subset=[signal_col, target_col])
    if len(clean) < 20:
        continue  # too few OOS observations to report
    rho, pval = stats.spearmanr(clean[signal_col], clean[target_col])
    # "Tradeable" bar: |IC| at least 0.08 with 5% significance
    passed = abs(rho) >= 0.08 and pval < 0.05
    print(f"  {'✅' if passed else '❌'} {test_name:35s}: IC={rho:+.4f}, p={pval:.4f}, N={len(clean)}")

print("\nDone!")
