"""
Firm Negative Gamma Deep Dive — Full Signal Analysis
=====================================================
Firm neg gamma tilt = |firm negative gamma above spot| / |total firm negative gamma|
Higher tilt → more short gamma above spot → bullish
"""

import json
import os
import warnings
import numpy as np
import pandas as pd
from datetime import datetime, time as dtime, timedelta
from scipy import stats
from collections import defaultdict

# Silence pandas/scipy chained-assignment and deprecation chatter for batch runs.
warnings.filterwarnings('ignore')

# All inputs/outputs live under this workspace root.
WORKSPACE = '/Users/lutherbot/.openclaw/workspace'
TRACE_DIR = os.path.join(WORKSPACE, 'data/trace_uncorrupted')   # per-day parquet GEX snapshots
ES_FILE = os.path.join(WORKSPACE, 'data/es_1min_delta_bars.csv')  # ES 1-minute bars with buy/sell volume
FOMC_FILE = os.path.join(WORKSPACE, 'data/fomc_dates.json')       # dates to exclude from the study
OUTPUT_FILE = os.path.join(WORKSPACE, 'data/firm_neg_gamma_deep_dive.json')

# Known-bad data window (inclusive) — these dates are skipped entirely.
CORRUPT_START = pd.Timestamp('2025-10-27')
CORRUPT_END = pd.Timestamp('2026-02-17')
# Participant prefixes used for the per-participant gamma columns
# ('<p>_gamma' and '<p>_gamma_0') in the snapshot files.
PARTICIPANTS = ['cust', 'procust', 'bd', 'firm', 'mm']

# Load FOMC dates
with open(FOMC_FILE) as f:
    fomc_data = json.load(f)
FOMC_DATES = set(fomc_data['dates'])  # set of 'YYYY-MM-DD' strings for O(1) lookup

# Load ES bars
print("Loading ES bars...")
es = pd.read_csv(ES_FILE, parse_dates=['timestamp'])
es['date'] = es['timestamp'].dt.date          # calendar date of each bar, used for per-day slicing
es_daily_vol = es.groupby('date')['volume'].sum()  # total volume per day (holiday filter below)

def compute_neg_gamma_tilt(snap, spot_price, participant=None):
    """Negative-gamma tilt: |neg gamma above spot| / |total neg gamma|.

    Sums '<p>_gamma' + '<p>_gamma_0' for one participant (or all of
    PARTICIPANTS when participant is None), keeps only the negative part,
    and returns the share of its magnitude sitting at strikes above
    spot_price. Returns NaN when there is no negative gamma at all.
    """
    if participant is not None:
        gamma = snap[f'{participant}_gamma'] + snap[f'{participant}_gamma_0']
    else:
        gamma = sum(snap[f'{p}_gamma'] + snap[f'{p}_gamma_0'] for p in PARTICIPANTS)
    # Magnitudes of the short-gamma portion only (positives become 0).
    short_gamma = -gamma.clip(upper=0)
    denom = short_gamma.sum()
    if denom <= 0:
        return np.nan
    numer = short_gamma[snap['strike_price'] > spot_price].sum()
    return numer / denom

def compute_total_gex(snap):
    """Total GEX: gamma + gamma_0 summed over every participant and strike."""
    combined = None
    for p in PARTICIPANTS:
        contrib = snap[f'{p}_gamma'] + snap[f'{p}_gamma_0']
        combined = contrib if combined is None else combined + contrib
    return combined.sum()

def get_spot_price(snap, es_data, date_str):
    """Return the ES close of the bar nearest the snapshot's timestamp.

    es_data must carry 'date', 'timestamp' (tz-naive) and 'close' columns.
    Returns None when es_data has no bars for date_str.
    """
    ts = pd.Timestamp(snap['timestamp'].iloc[0])
    day_bars = es_data[es_data['date'] == pd.Timestamp(date_str).date()]
    if len(day_bars) == 0:
        return None
    # Normalize to tz-naive UTC so it compares with the ES bar timestamps.
    if ts.tz is not None:
        ts = ts.tz_convert('UTC')
    target = ts.tz_localize(None)
    nearest = (day_bars['timestamp'] - target).abs().idxmin()
    return day_bars.loc[nearest, 'close']

def get_forward_returns(es_data, date_str, signal_time_utc, horizons_min=(30, 60, 120, 180)):
    """Forward ES returns (in percent) from the signal time at several horizons.

    Parameters
    ----------
    es_data : DataFrame with 'date', 'timestamp' (tz-naive, UTC) and 'close'.
    date_str : trading date 'YYYY-MM-DD'; only that day's bars are used.
    signal_time_utc : entry timestamp; tz-aware values are converted to UTC.
    horizons_min : iterable of horizons in minutes. Default is a tuple
        (was a mutable list default — an anti-pattern even when read-only).

    Returns
    -------
    dict mapping '<h>min' -> pct return for each horizon where a bar exists
    within 5 minutes of the exit time, plus 'eod' computed from the last
    bar at or after 19:00 UTC (RTH-close proxy). Empty dict when the date
    has no bars.
    """
    date_es = es_data[es_data['date'] == pd.Timestamp(date_str).date()].copy()
    if len(date_es) == 0:
        return {}

    # Entry price = close of the bar nearest the (normalized) signal time.
    target = pd.Timestamp(signal_time_utc)
    if target.tz is not None:
        target = target.tz_convert('UTC').tz_localize(None)

    diffs = (date_es['timestamp'] - target).abs()
    entry_idx = diffs.idxmin()
    entry_price = date_es.loc[entry_idx, 'close']

    results = {}
    for h in horizons_min:
        target_exit = target + pd.Timedelta(minutes=h)
        # Require a bar within 5 minutes of the intended exit; otherwise
        # skip the horizon (e.g. signal too close to the end of the day).
        exit_diffs = (date_es['timestamp'] - target_exit).abs()
        if exit_diffs.min() <= pd.Timedelta(minutes=5):
            exit_idx = exit_diffs.idxmin()
            exit_price = date_es.loc[exit_idx, 'close']
            results[f'{h}min'] = (exit_price - entry_price) / entry_price * 100

    # EOD return: close of the last bar with hour >= 19 UTC (RTH-close proxy;
    # 16:00 ET is 20:00/21:00 UTC depending on DST, so >= 19 is deliberately loose).
    rth_end = date_es[date_es['timestamp'].dt.hour >= 19]
    if len(rth_end) > 0:
        eod_price = rth_end.iloc[-1]['close']
        results['eod'] = (eod_price - entry_price) / entry_price * 100

    return results

print("Processing TRACE snapshots...")
files = sorted([f for f in os.listdir(TRACE_DIR) if f.endswith('.parquet')])

# Build one row per trading day: gamma tilts measured at the morning snapshot
# closest to 10:00 ET, plus forward returns, RVOL inputs, order flow and gap
# context. Rows are collected into all_rows and framed after the loop.
all_rows = []
for fname in files:
    # Filenames look like 'intradayStrikeGEX_YYYY-MM-DD.parquet'.
    date_str = fname.replace('intradayStrikeGEX_', '').replace('.parquet', '')
    date_ts = pd.Timestamp(date_str)
    
    # Exclusions
    if CORRUPT_START <= date_ts <= CORRUPT_END:
        continue
    if date_str in FOMC_DATES:
        continue
    
    # Check volume (holiday filter)
    date_obj = date_ts.date()
    if date_obj in es_daily_vol.index and es_daily_vol[date_obj] < 100000:
        continue
    
    # NOTE(review): bare except silently skips any unreadable parquet file.
    try:
        df = pd.read_parquet(os.path.join(TRACE_DIR, fname))
    except:
        continue
    
    # Get morning snapshot closest to 10:00 ET (14:00 UTC)
    timestamps = df['timestamp'].unique()
    morning_ts = []
    for ts in timestamps:
        ts_pd = pd.Timestamp(ts)
        if ts_pd.tz is not None:
            hour_et = ts_pd.tz_convert('America/New_York').hour
            minute_et = ts_pd.tz_convert('America/New_York').minute
        else:
            # Assume UTC
            # NOTE(review): fixed -4h offset assumes EDT; it is off by one
            # hour during EST (winter). Confirm snapshots carry a tz.
            hour_et = (ts_pd.hour - 4) % 24
            minute_et = ts_pd.minute
        
        # Keep snapshots in the 9:00-11:59 ET window; 600 = 10:00 in minutes.
        if 9 <= hour_et <= 11:
            morning_ts.append((ts, abs((hour_et * 60 + minute_et) - 600)))  # dist from 10:00
    
    if not morning_ts:
        continue
    
    # Pick closest to 10:00 ET
    morning_ts.sort(key=lambda x: x[1])
    best_ts = morning_ts[0][0]
    snap = df[df['timestamp'] == best_ts].copy()
    
    if len(snap) == 0:
        continue
    
    # Get spot price
    date_es = es[es['date'] == date_obj]
    if len(date_es) == 0:
        continue
    
    # Find ES price at snapshot time
    # (inline version of get_spot_price: normalize to tz-naive UTC first)
    snap_time = pd.Timestamp(best_ts)
    if snap_time.tz is not None:
        snap_time_utc = snap_time.tz_convert('UTC').tz_localize(None)
    else:
        snap_time_utc = snap_time
    
    time_diffs = (date_es['timestamp'] - snap_time_utc).abs()
    closest_idx = time_diffs.idxmin()
    spot = date_es.loc[closest_idx, 'close']
    
    if pd.isna(spot):
        continue
    
    # Compute tilts for all participants
    row = {'date': date_str, 'spot': spot, 'signal_time_utc': str(snap_time_utc)}
    row['firm_neg_tilt'] = compute_neg_gamma_tilt(snap, spot, 'firm')
    row['mm_neg_tilt'] = compute_neg_gamma_tilt(snap, spot, 'mm')
    row['cust_neg_tilt'] = compute_neg_gamma_tilt(snap, spot, 'cust')
    row['total_neg_tilt'] = compute_neg_gamma_tilt(snap, spot, None)
    row['total_gex'] = compute_total_gex(snap)
    
    # Also compute positive gamma tilt for v5 comparison
    # v5 uses positive gamma tilt (we'll compute it here too)
    # Mirrors compute_neg_gamma_tilt but on the positive-gamma portion.
    for p in [None] + PARTICIPANTS:
        pname = p if p else 'total'
        if p is None:
            gamma = sum(snap[f'{pp}_gamma'] + snap[f'{pp}_gamma_0'] for pp in PARTICIPANTS)
        else:
            gamma = snap[f'{p}_gamma'] + snap[f'{p}_gamma_0']
        pos_gamma = gamma.clip(lower=0)
        above = pos_gamma[snap['strike_price'] > spot].sum()
        total_pos = pos_gamma.sum()
        if total_pos > 0:
            row[f'{pname}_pos_tilt'] = above / total_pos
        else:
            row[f'{pname}_pos_tilt'] = np.nan
    
    # Forward returns
    fwd = get_forward_returns(es, date_str, snap_time_utc, [30, 60, 120, 180])
    for k, v in fwd.items():
        row[f'fwd_{k}'] = v
    
    # RVOL: volume first 30 min / average first 30 min volume
    # Compute 30-min volume from 9:30-10:00 ET (13:30-14:00 UTC)
    # NOTE(review): dead code — `minute < 0` is never true, so the second
    # clause matches nothing, and morn_bars is never used afterwards.
    morn_bars = date_es[(date_es['timestamp'].dt.hour == 13) & (date_es['timestamp'].dt.minute >= 30) |
                        (date_es['timestamp'].dt.hour == 14) & (date_es['timestamp'].dt.minute < 0)]
    # Simpler: just use total daily volume as proxy
    row['daily_vol'] = date_es['volume'].sum()
    
    # Buy pct at 10:30 ET (14:30 UTC)
    # NOTE(review): only covers bars with hour == 14 UTC, i.e. 14:00-14:30,
    # not the full session up to 10:30 ET — confirm intent.
    bars_to_1030 = date_es[(date_es['timestamp'].dt.hour == 14) & (date_es['timestamp'].dt.minute <= 30)]
    if len(bars_to_1030) > 0:
        total_buy = bars_to_1030['buy_volume'].sum()
        total_sell = bars_to_1030['sell_volume'].sum()
        row['buy_pct_1030'] = total_buy / (total_buy + total_sell) if (total_buy + total_sell) > 0 else np.nan
    
    # Gap (open vs prev close)
    # Use 9:30 ET open
    rth_open = date_es[date_es['timestamp'].dt.hour == 13]
    if len(rth_open) > 0:
        row['open_price'] = rth_open.iloc[0]['open']
    
    # Day of week
    row['dow'] = date_ts.dayofweek
    
    # Get all morning snapshots for intraday evolution
    # morning_ts is sorted by distance from 10:00 ET, so [:6] keeps the six
    # snapshots CLOSEST to 10:00, not the six earliest of the morning.
    # Each tilt is computed against the same 10:00 spot.
    all_morning_snaps = {}
    for ts, dist in morning_ts[:6]:  # Keep up to 6 snapshots
        snap_t = df[df['timestamp'] == ts].copy()
        if len(snap_t) > 0:
            tilt = compute_neg_gamma_tilt(snap_t, spot, 'firm')
            ts_pd = pd.Timestamp(ts)
            if ts_pd.tz is not None:
                et_time = ts_pd.tz_convert('America/New_York')
                key = f"{et_time.hour}:{et_time.minute:02d}"
            else:
                key = str(ts)
            all_morning_snaps[key] = float(tilt) if tilt is not None and not pd.isna(tilt) else None
    row['intraday_tilts'] = json.dumps(all_morning_snaps)
    
    # Keep the day only when the primary signal (firm tilt) is defined.
    if not pd.isna(row.get('firm_neg_tilt', np.nan)):
        all_rows.append(row)

print(f"Processed {len(all_rows)} valid days")

# Assemble per-day rows into a chronologically ordered DataFrame.
data = (
    pd.DataFrame(all_rows)
    .assign(date=lambda d: pd.to_datetime(d['date']))
    .sort_values('date')
    .reset_index(drop=True)
)

# RVOL = daily volume relative to its trailing 20-day mean (min 5 obs).
rolling_mean = data['daily_vol'].rolling(20, min_periods=5).mean()
data['vol_ma20'] = rolling_mean
data['rvol'] = data['daily_vol'] / rolling_mean

# Previous session's ES close for each row (NaN when unavailable),
# then the overnight gap in percent (RTH open vs previous close).
prev_closes = []
for pos in range(len(data)):
    if pos == 0:
        prev_closes.append(np.nan)
        continue
    # Close of the last ES bar on the prior row's calendar date.
    prior_day = data.loc[pos - 1, 'date'].date()
    prior_bars = es[es['date'] == prior_day]
    if len(prior_bars) > 0:
        prev_closes.append(prior_bars.iloc[-1]['close'])
    else:
        prev_closes.append(np.nan)
data['prev_close'] = prev_closes
data['gap_pct'] = np.where(data['prev_close'].notna() & data['open_price'].notna(),
                           (data['open_price'] - data['prev_close']) / data['prev_close'] * 100, np.nan)

# 60/40 chronological in-sample / out-of-sample split (index is a clean
# RangeIndex after the reset above, so positions and labels coincide).
n_total = len(data)
is_cutoff = int(n_total * 0.6)
data['period'] = np.where(data.index >= is_cutoff, 'OOS', 'IS')

# Thirds split for stability checks: P1 oldest, P3 newest.
third = n_total // 3
data['period3'] = np.select(
    [data.index >= 2 * third, data.index >= third],
    ['P3', 'P2'],
    default='P1',
)

print(f"IS: {is_cutoff} days, OOS: {n_total - is_cutoff} days")
print(f"Date range: {data['date'].min().date()} to {data['date'].max().date()}")
print(f"Firm neg tilt range: {data['firm_neg_tilt'].min():.3f} to {data['firm_neg_tilt'].max():.3f}")
print(f"Mean: {data['firm_neg_tilt'].mean():.3f}, Median: {data['firm_neg_tilt'].median():.3f}")

# Accumulates every section's findings; serialized to OUTPUT_FILE downstream.
results = {}

# ============================================================
# 1. THRESHOLD LADDER (BOTH DIRECTIONS)
# ============================================================
print("\n" + "="*70)
print("1. THRESHOLD LADDER")
print("="*70)

threshold_results = {}

# BULLISH (high tilt)
print("\n--- BULLISH: Firm neg gamma tilt > threshold → long ---")
print(f"{'Threshold':>10} | {'Full WR':>8} {'N':>4} {'Avg Ret':>8} {'t-stat':>7} | {'IS WR':>8} {'N':>4} | {'OOS WR':>8} {'N':>4}")
print("-" * 85)

# All three horizons are stored in threshold_results; only the 60-min row
# is printed to keep the table compact.
for thresh in [0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90]:
    mask = data['firm_neg_tilt'] > thresh
    
    for horizon in ['fwd_60min', 'fwd_180min', 'fwd_eod']:
        sub = data[mask & data[horizon].notna()]
        is_sub = sub[sub['period'] == 'IS']
        oos_sub = sub[sub['period'] == 'OOS']
        
        # Bullish win rate = fraction of positive forward returns.
        full_wr = (sub[horizon] > 0).mean() if len(sub) > 0 else np.nan
        full_n = len(sub)
        avg_ret = sub[horizon].mean() if len(sub) > 0 else np.nan
        # One-sample t-test of mean return against zero (needs n > 2).
        t = stats.ttest_1samp(sub[horizon], 0).statistic if len(sub) > 2 else np.nan
        
        is_wr = (is_sub[horizon] > 0).mean() if len(is_sub) > 0 else np.nan
        oos_wr = (oos_sub[horizon] > 0).mean() if len(oos_sub) > 0 else np.nan
        
        h_label = horizon.replace('fwd_', '')
        key = f'bull_{thresh}_{h_label}'
        threshold_results[key] = {
            'direction': 'bull', 'threshold': thresh, 'horizon': h_label,
            'full_wr': round(full_wr, 4) if not pd.isna(full_wr) else None,
            'full_n': full_n,
            'avg_ret': round(avg_ret, 5) if not pd.isna(avg_ret) else None,
            't_stat': round(t, 3) if not pd.isna(t) else None,
            'is_wr': round(is_wr, 4) if not pd.isna(is_wr) else None,
            'is_n': len(is_sub),
            'oos_wr': round(oos_wr, 4) if not pd.isna(oos_wr) else None,
            'oos_n': len(oos_sub),
        }
        
        if horizon == 'fwd_60min':
            print(f"  >{thresh:.0%}     | {full_wr:>7.1%} {full_n:>4d} {avg_ret:>7.4f}% {t:>7.2f} | {is_wr:>7.1%} {len(is_sub):>4d} | {oos_wr:>7.1%} {len(oos_sub):>4d}" if not pd.isna(full_wr) else f"  >{thresh:.0%}     | N/A")

# BEARISH (low tilt)
print("\n--- BEARISH: Firm neg gamma tilt < threshold → short ---")
print(f"{'Threshold':>10} | {'Full WR':>8} {'N':>4} {'Avg Ret':>8} {'t-stat':>7} | {'IS WR':>8} {'N':>4} | {'OOS WR':>8} {'N':>4}")
print("-" * 85)

for thresh in [0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.15, 0.10]:
    mask = data['firm_neg_tilt'] < thresh
    
    for horizon in ['fwd_60min', 'fwd_180min', 'fwd_eod']:
        sub = data[mask & data[horizon].notna()]
        is_sub = sub[sub['period'] == 'IS']
        oos_sub = sub[sub['period'] == 'OOS']
        
        # Bearish: WR = % of days with negative return
        full_wr = (sub[horizon] < 0).mean() if len(sub) > 0 else np.nan
        full_n = len(sub)
        # avg_ret/t-stat stay on the RAW return (not sign-flipped for shorts).
        avg_ret = sub[horizon].mean() if len(sub) > 0 else np.nan
        t = stats.ttest_1samp(sub[horizon], 0).statistic if len(sub) > 2 else np.nan
        
        is_wr = (is_sub[horizon] < 0).mean() if len(is_sub) > 0 else np.nan
        oos_wr = (oos_sub[horizon] < 0).mean() if len(oos_sub) > 0 else np.nan
        
        h_label = horizon.replace('fwd_', '')
        key = f'bear_{thresh}_{h_label}'
        threshold_results[key] = {
            'direction': 'bear', 'threshold': thresh, 'horizon': h_label,
            'full_wr': round(full_wr, 4) if not pd.isna(full_wr) else None,
            'full_n': full_n,
            'avg_ret': round(avg_ret, 5) if not pd.isna(avg_ret) else None,
            't_stat': round(t, 3) if not pd.isna(t) else None,
            'is_wr': round(is_wr, 4) if not pd.isna(is_wr) else None,
            'is_n': len(is_sub),
            'oos_wr': round(oos_wr, 4) if not pd.isna(oos_wr) else None,
            'oos_n': len(oos_sub),
        }
        
        if horizon == 'fwd_60min':
            print(f"  <{thresh:.0%}     | {full_wr:>7.1%} {full_n:>4d} {avg_ret:>7.4f}% {t:>7.2f} | {is_wr:>7.1%} {len(is_sub):>4d} | {oos_wr:>7.1%} {len(oos_sub):>4d}" if not pd.isna(full_wr) else f"  <{thresh:.0%}     | N/A")

results['threshold_ladder'] = threshold_results

# ============================================================
# 2. DECILE ANALYSIS WITH MONOTONICITY
# ============================================================
print("\n" + "="*70)
print("2. DECILE ANALYSIS")
print("="*70)

# NOTE(review): duplicates='drop' means fewer than 10 bins are possible when
# tilt values repeat; labels=False then yields 0-based codes, shifted to 1-based.
data['decile'] = pd.qcut(data['firm_neg_tilt'], 10, labels=False, duplicates='drop') + 1
decile_results = {}

print(f"\n{'Decile':>7} | {'Tilt Range':>15} | {'N':>4} | {'1H WR':>7} {'1H Ret':>8} | {'3H WR':>7} {'3H Ret':>8} | {'EOD WR':>7} {'EOD Ret':>8}")
print("-" * 100)

# decile_means collects the mean 1H return per decile for the monotonicity
# and Q10-Q1 spread computations below.
decile_means = {}
for d_val in sorted(data['decile'].unique()):
    sub = data[data['decile'] == d_val]
    tilt_range = f"{sub['firm_neg_tilt'].min():.3f}-{sub['firm_neg_tilt'].max():.3f}"
    
    row_data = {
        'n': len(sub),
        'tilt_min': round(sub['firm_neg_tilt'].min(), 4),
        'tilt_max': round(sub['firm_neg_tilt'].max(), 4),
        'tilt_mean': round(sub['firm_neg_tilt'].mean(), 4),
    }
    
    for h, label in [('fwd_60min', '1H'), ('fwd_180min', '3H'), ('fwd_eod', 'EOD')]:
        valid = sub[sub[h].notna()]
        if len(valid) > 0:
            row_data[f'{label}_wr'] = round((valid[h] > 0).mean(), 4)
            row_data[f'{label}_ret'] = round(valid[h].mean(), 5)
        else:
            row_data[f'{label}_wr'] = None
            row_data[f'{label}_ret'] = None
    
    decile_results[f'D{d_val}'] = row_data
    # `or 0` maps a None (empty decile) to 0 for the monotonicity math.
    decile_means[d_val] = row_data.get('1H_ret', 0) or 0
    
    wr_1h = row_data['1H_wr'] or 0
    ret_1h = row_data['1H_ret'] or 0
    wr_3h = row_data['3H_wr'] or 0
    ret_3h = row_data['3H_ret'] or 0
    wr_eod = row_data['EOD_wr'] or 0
    ret_eod = row_data['EOD_ret'] or 0
    
    print(f"  D{d_val:>4}  | {tilt_range:>15} | {len(sub):>4} | {wr_1h:>6.1%} {ret_1h:>7.4f}% | {wr_3h:>6.1%} {ret_3h:>7.4f}% | {wr_eod:>6.1%} {ret_eod:>7.4f}%")

# Monotonicity score
# Fraction of adjacent decile pairs where the mean 1H return increases.
sorted_means = [decile_means[k] for k in sorted(decile_means.keys())]
mono_score = 0
for i in range(1, len(sorted_means)):
    if sorted_means[i] > sorted_means[i-1]:
        mono_score += 1
mono_pct = mono_score / (len(sorted_means) - 1) if len(sorted_means) > 1 else 0

# Q10-Q1 spread
# Top-decile minus bottom-decile mean 1H return.
q10_ret = decile_means.get(max(decile_means.keys()), 0)
q1_ret = decile_means.get(min(decile_means.keys()), 0)
spread = q10_ret - q1_ret

print(f"\nMonotonicity: {mono_score}/{len(sorted_means)-1} = {mono_pct:.0%}")
print(f"Q10-Q1 spread (1H): {spread:.4f}%")
print(f"Linear vs threshold: {'Linear' if mono_pct > 0.7 else 'Threshold-based' if mono_pct < 0.5 else 'Mixed'}")

decile_results['monotonicity'] = round(mono_pct, 4)
decile_results['q10_q1_spread'] = round(spread, 5)
results['decile_analysis'] = decile_results

# ============================================================
# 3. TIME HORIZONS (>70% threshold)
# ============================================================
# How does the >70% bullish signal decay/strengthen across holding periods?
print("\n" + "="*70)
print("3. TIME HORIZONS — Firm neg tilt > 70%")
print("="*70)

mask_70 = data['firm_neg_tilt'] > 0.70
sub_70 = data[mask_70].copy()

time_horizon_results = {}
print(f"\n{'Horizon':>10} | {'WR':>7} | {'Avg Ret':>8} | {'t-stat':>7} | {'N':>4}")
print("-" * 50)

for h, label in [('fwd_30min', '30min'), ('fwd_60min', '1H'), ('fwd_120min', '2H'), ('fwd_180min', '3H'), ('fwd_eod', 'EOD')]:
    valid = sub_70[sub_70[h].notna()]
    if len(valid) > 0:
        wr = (valid[h] > 0).mean()
        avg = valid[h].mean()
        # t-stat of mean return vs zero; NaN when too few observations.
        t = stats.ttest_1samp(valid[h], 0).statistic if len(valid) > 2 else np.nan
        print(f"  {label:>8} | {wr:>6.1%} | {avg:>7.4f}% | {t:>7.2f} | {len(valid):>4}")
        time_horizon_results[label] = {
            'wr': round(wr, 4), 'avg_ret': round(avg, 5), 
            't_stat': round(t, 3) if not pd.isna(t) else None, 'n': len(valid)
        }
    else:
        print(f"  {label:>8} | N/A")

results['time_horizons_gt70'] = time_horizon_results

# ============================================================
# 4. RVOL INTERACTION
# ============================================================
# Does the tilt signal depend on relative volume? Cross-tab RVOL quartile
# against tilt quartile for 1H and 3H forward returns.
print("\n" + "="*70)
print("4. RVOL INTERACTION")
print("="*70)

valid_rvol = data[data['rvol'].notna()].copy()
valid_rvol['rvol_q'] = pd.qcut(valid_rvol['rvol'], 4, labels=['Q1_Low', 'Q2', 'Q3', 'Q4_High'], duplicates='drop')
valid_rvol['tilt_q'] = pd.qcut(valid_rvol['firm_neg_tilt'], 4, labels=['Q1_Low', 'Q2', 'Q3', 'Q4_High'], duplicates='drop')

rvol_results = {}

# NOTE(review): the 1H and 3H grids below are copy-paste duplicates differing
# only in the horizon column — refactor candidate.
print("\n--- RVOL Quartile × Firm Tilt Quartile → 1H Return ---")
print(f"{'':>12} | {'Tilt Q1':>10} {'Tilt Q2':>10} {'Tilt Q3':>10} {'Tilt Q4':>10}")
print("-" * 60)

for rq in ['Q1_Low', 'Q2', 'Q3', 'Q4_High']:
    row_str = f"  RVOL {rq:>5} |"
    for tq in ['Q1_Low', 'Q2', 'Q3', 'Q4_High']:
        cell = valid_rvol[(valid_rvol['rvol_q'] == rq) & (valid_rvol['tilt_q'] == tq)]
        valid_cell = cell[cell['fwd_60min'].notna()]
        if len(valid_cell) >= 3:
            ret = valid_cell['fwd_60min'].mean()
            wr = (valid_cell['fwd_60min'] > 0).mean()
            row_str += f" {ret:>+6.3f}%({wr:.0%})"
            rvol_results[f'{rq}_{tq}_1H'] = {'ret': round(ret, 5), 'wr': round(wr, 4), 'n': len(valid_cell)}
        else:
            row_str += f" {'N/A':>10}"
    print(row_str)

print("\n--- RVOL Quartile × Firm Tilt Quartile → 3H Return ---")
print(f"{'':>12} | {'Tilt Q1':>10} {'Tilt Q2':>10} {'Tilt Q3':>10} {'Tilt Q4':>10}")
print("-" * 60)

for rq in ['Q1_Low', 'Q2', 'Q3', 'Q4_High']:
    row_str = f"  RVOL {rq:>5} |"
    for tq in ['Q1_Low', 'Q2', 'Q3', 'Q4_High']:
        cell = valid_rvol[(valid_rvol['rvol_q'] == rq) & (valid_rvol['tilt_q'] == tq)]
        valid_cell = cell[cell['fwd_180min'].notna()]
        if len(valid_cell) >= 3:
            ret = valid_cell['fwd_180min'].mean()
            wr = (valid_cell['fwd_180min'] > 0).mean()
            row_str += f" {ret:>+6.3f}%({wr:.0%})"
            rvol_results[f'{rq}_{tq}_3H'] = {'ret': round(ret, 5), 'wr': round(wr, 4), 'n': len(valid_cell)}
        else:
            row_str += f" {'N/A':>10}"
    print(row_str)

# Does firm tilt >70% still work at each RVOL quartile?
# (printed only — not stored in rvol_results)
print("\n--- Firm tilt >70% by RVOL quartile (1H) ---")
for rq in ['Q1_Low', 'Q2', 'Q3', 'Q4_High']:
    sub = valid_rvol[(valid_rvol['rvol_q'] == rq) & (valid_rvol['firm_neg_tilt'] > 0.7) & valid_rvol['fwd_60min'].notna()]
    if len(sub) >= 2:
        wr = (sub['fwd_60min'] > 0).mean()
        print(f"  RVOL {rq}: WR={wr:.1%} N={len(sub)} Avg={sub['fwd_60min'].mean():.4f}%")
    else:
        print(f"  RVOL {rq}: N={len(sub)} (insufficient)")

results['rvol_interaction'] = rvol_results

# ============================================================
# 5. GEX REGIME INTERACTION
# ============================================================
# Condition the tilt signal on the sign of total GEX at the snapshot.
print("\n" + "="*70)
print("5. GEX REGIME INTERACTION")
print("="*70)

data['gex_regime'] = np.where(data['total_gex'] > 0, 'Positive', 'Negative')
gex_results = {}

for regime in ['Positive', 'Negative']:
    sub = data[data['gex_regime'] == regime]
    print(f"\n--- {regime} GEX regime (N={len(sub)}) ---")
    
    # Bullish side: high-tilt thresholds, long WR = % positive 1H returns.
    for thresh in [0.65, 0.70, 0.75]:
        masked = sub[sub['firm_neg_tilt'] > thresh]
        valid = masked[masked['fwd_60min'].notna()]
        if len(valid) >= 3:
            wr = (valid['fwd_60min'] > 0).mean()
            avg = valid['fwd_60min'].mean()
            print(f"  Firm tilt >{thresh:.0%}: WR={wr:.1%} N={len(valid)} Avg={avg:.4f}%")
            gex_results[f'{regime}_gt{thresh}_1H'] = {'wr': round(wr, 4), 'n': len(valid), 'avg_ret': round(avg, 5)}
        else:
            print(f"  Firm tilt >{thresh:.0%}: N={len(valid)} (insufficient)")

    # Bearish side
    # Short WR = % negative 1H returns; no "insufficient" line is printed
    # here, unlike the bullish side above.
    for thresh in [0.30, 0.25]:
        masked = sub[sub['firm_neg_tilt'] < thresh]
        valid = masked[masked['fwd_60min'].notna()]
        if len(valid) >= 3:
            wr = (valid['fwd_60min'] < 0).mean()
            avg = valid['fwd_60min'].mean()
            print(f"  Firm tilt <{thresh:.0%}: Short WR={wr:.1%} N={len(valid)} Avg={avg:.4f}%")
            gex_results[f'{regime}_lt{thresh}_1H'] = {'wr': round(wr, 4), 'n': len(valid), 'avg_ret': round(avg, 5)}

results['gex_interaction'] = gex_results

# ============================================================
# 6. FREQUENCY ANALYSIS
# ============================================================
# How often does each threshold fire, does it persist day-over-day, and
# does it cluster into streaks?
print("\n" + "="*70)
print("6. FREQUENCY ANALYSIS")
print("="*70)

freq_results = {}
print(f"\n{'Threshold':>10} | {'Count':>6} | {'% of Days':>9} | {'Days/Month':>10}")
print("-" * 50)

for thresh in [0.65, 0.70, 0.75, 0.80, 0.85]:
    count = (data['firm_neg_tilt'] > thresh).sum()
    pct = count / len(data)
    # Approximate months in sample as calendar days / 30.
    months = (data['date'].max() - data['date'].min()).days / 30
    per_month = count / months if months > 0 else 0
    print(f"  >{thresh:.0%}     | {count:>6} | {pct:>8.1%} | {per_month:>9.1f}")
    freq_results[f'gt_{thresh}'] = {'count': int(count), 'pct': round(pct, 4), 'days_per_month': round(per_month, 2)}

# Persistence: if >75% today, what % chance >75% tomorrow?
# "Tomorrow" = next row in the sample, which skips excluded dates.
print("\n--- Persistence analysis ---")
data_sorted = data.sort_values('date')
for thresh in [0.65, 0.70, 0.75]:
    signal_today = data_sorted['firm_neg_tilt'] > thresh
    # NOTE(review): shift(-1) leaves a NaN on the final row; confirm pandas
    # treats it as False in the `&` below (it does in current versions).
    signal_tomorrow = signal_today.shift(-1)
    both = signal_today & signal_tomorrow
    persist_pct = both.sum() / signal_today.sum() if signal_today.sum() > 0 else np.nan
    print(f"  >{thresh:.0%} today → >{thresh:.0%} tomorrow: {persist_pct:.1%} (N={signal_today.sum()})")
    freq_results[f'persist_{thresh}'] = round(persist_pct, 4) if not pd.isna(persist_pct) else None

# Clustering: streak analysis
# Run-length encode consecutive >70% days.
streaks = []
current_streak = 0
for v in (data_sorted['firm_neg_tilt'] > 0.70).values:
    if v:
        current_streak += 1
    else:
        if current_streak > 0:
            streaks.append(current_streak)
        current_streak = 0
# Flush a streak still open at the end of the sample.
if current_streak > 0:
    streaks.append(current_streak)

if streaks:
    print(f"\n  >70% streak stats: avg={np.mean(streaks):.1f}, max={max(streaks)}, median={np.median(streaks):.0f}")
    freq_results['streak_stats_70'] = {
        'avg': round(np.mean(streaks), 2), 'max': int(max(streaks)), 
        'median': float(np.median(streaks)), 'n_streaks': len(streaks)
    }

results['frequency'] = freq_results

# ============================================================
# 7. INTRADAY EVOLUTION
# ============================================================
# Is the morning CHANGE in firm tilt more predictive than its level?
print("\n" + "="*70)
print("7. INTRADAY EVOLUTION")
print("="*70)

# Parse intraday tilts and check if multiple snapshots exist
intraday_results = {}
n_multi = 0
tilt_changes = []

for _, row in data.iterrows():
    # NOTE(review): bare except hides both JSON and key errors per row.
    try:
        tilts = json.loads(row['intraday_tilts'])
        if len(tilts) > 1:
            n_multi += 1
            # NOTE(review): keys are "H:MM" strings, so this sort is
            # lexicographic — "10:00" sorts BEFORE "9:30". first/last are
            # therefore not guaranteed chronological; confirm intent.
            times = sorted(tilts.keys())
            if len(times) >= 2:
                first_val = tilts[times[0]]
                last_val = tilts[times[-1]]
                if first_val is not None and last_val is not None:
                    tilt_changes.append({
                        'date': str(row['date']),
                        'first_time': times[0],
                        'first_tilt': first_val,
                        'last_time': times[-1],
                        'last_tilt': last_val,
                        'change': last_val - first_val,
                        'fwd_60min': row.get('fwd_60min', np.nan)
                    })
    except:
        pass

print(f"Days with multiple morning snapshots: {n_multi}/{len(data)}")

if tilt_changes:
    changes_df = pd.DataFrame(tilt_changes)
    changes_df = changes_df[changes_df['fwd_60min'].notna()]
    
    # Does CHANGE predict better than level?
    if len(changes_df) > 10:
        # 5-percentage-point move between first and last morning snapshot.
        increasing = changes_df[changes_df['change'] > 0.05]
        decreasing = changes_df[changes_df['change'] < -0.05]
        
        print(f"\nFirm tilt INCREASING (change > 5%): N={len(increasing)}")
        if len(increasing) > 2:
            wr = (increasing['fwd_60min'] > 0).mean()
            print(f"  1H WR: {wr:.1%}, Avg ret: {increasing['fwd_60min'].mean():.4f}%")
            intraday_results['increasing_wr'] = round(wr, 4)
            intraday_results['increasing_n'] = len(increasing)
        
        print(f"Firm tilt DECREASING (change < -5%): N={len(decreasing)}")
        if len(decreasing) > 2:
            # Short win rate: fraction of negative 1H returns.
            wr = (decreasing['fwd_60min'] < 0).mean()
            print(f"  1H Short WR: {wr:.1%}, Avg ret: {decreasing['fwd_60min'].mean():.4f}%")
            intraday_results['decreasing_short_wr'] = round(wr, 4)
            intraday_results['decreasing_n'] = len(decreasing)
        
        # IC of change vs level
        # IC here = Pearson correlation with the 1H forward return.
        ic_level = changes_df['last_tilt'].corr(changes_df['fwd_60min'])
        ic_change = changes_df['change'].corr(changes_df['fwd_60min'])
        print(f"\nIC (level): {ic_level:.4f}")
        print(f"IC (change): {ic_change:.4f}")
        print(f"{'Change' if abs(ic_change) > abs(ic_level) else 'Level'} is more predictive")
        intraday_results['ic_level'] = round(ic_level, 4)
        intraday_results['ic_change'] = round(ic_change, 4)

results['intraday_evolution'] = intraday_results

# ============================================================
# 8. BEARISH SIGNAL DEEP DIVE
# ============================================================
# Low firm tilt as a SHORT signal, alone and combined with MM tilt / GEX /
# day-over-day shifts. All win rates below are short win rates (% negative).
print("\n" + "="*70)
print("8. BEARISH SIGNAL DEEP DIVE")
print("="*70)

bear_results = {}

# Firm tilt <25% standalone
for h, label in [('fwd_60min', '1H'), ('fwd_180min', '3H'), ('fwd_eod', 'EOD')]:
    sub = data[(data['firm_neg_tilt'] < 0.25) & data[h].notna()]
    if len(sub) >= 3:
        wr = (sub[h] < 0).mean()
        avg = sub[h].mean()
        print(f"Firm tilt <25% standalone {label}: Short WR={wr:.1%} N={len(sub)} Avg={avg:.4f}%")
        bear_results[f'standalone_{label}'] = {'wr': round(wr, 4), 'n': len(sub), 'avg_ret': round(avg, 5)}

# Combo: firm <25% AND mm <40%
combo = data[(data['firm_neg_tilt'] < 0.25) & (data['mm_neg_tilt'] < 0.40)]
valid = combo[combo['fwd_60min'].notna()]
if len(valid) >= 2:
    wr = (valid['fwd_60min'] < 0).mean()
    print(f"\nFirm <25% + MM <40%: Short WR={wr:.1%} N={len(valid)} Avg={valid['fwd_60min'].mean():.4f}%")
    bear_results['firm25_mm40'] = {'wr': round(wr, 4), 'n': len(valid), 'avg_ret': round(valid['fwd_60min'].mean(), 5)}

# Firm <25% + negative GEX
combo = data[(data['firm_neg_tilt'] < 0.25) & (data['total_gex'] < 0)]
valid = combo[combo['fwd_60min'].notna()]
if len(valid) >= 2:
    wr = (valid['fwd_60min'] < 0).mean()
    print(f"Firm <25% + neg GEX: Short WR={wr:.1%} N={len(valid)} Avg={valid['fwd_60min'].mean():.4f}%")
    bear_results['firm25_negGEX'] = {'wr': round(wr, 4), 'n': len(valid), 'avg_ret': round(valid['fwd_60min'].mean(), 5)}

# Day-over-day shift: was >60% yesterday, dropped to <35% today
# Note: this rebinds data_sorted (first defined in section 6) with a reset index.
data_sorted = data.sort_values('date').reset_index(drop=True)
shift_mask = (data_sorted['firm_neg_tilt'].shift(1) > 0.60) & (data_sorted['firm_neg_tilt'] < 0.35)
shift_days = data_sorted[shift_mask & data_sorted['fwd_60min'].notna()]
if len(shift_days) >= 2:
    wr = (shift_days['fwd_60min'] < 0).mean()
    print(f"\nBearish SHIFT (>60% yesterday → <35% today): Short WR={wr:.1%} N={len(shift_days)} Avg={shift_days['fwd_60min'].mean():.4f}%")
    bear_results['bearish_shift'] = {'wr': round(wr, 4), 'n': len(shift_days), 'avg_ret': round(shift_days['fwd_60min'].mean(), 5)}
else:
    print(f"\nBearish SHIFT: N={len(shift_days)} (insufficient)")

# Try different combo thresholds
# Small grid search over (firm, mm) threshold pairs — exploratory only.
for firm_t, mm_t in [(0.30, 0.35), (0.30, 0.40), (0.35, 0.40), (0.25, 0.35)]:
    combo = data[(data['firm_neg_tilt'] < firm_t) & (data['mm_neg_tilt'] < mm_t)]
    valid = combo[combo['fwd_60min'].notna()]
    if len(valid) >= 3:
        wr = (valid['fwd_60min'] < 0).mean()
        print(f"Firm <{firm_t:.0%} + MM <{mm_t:.0%}: Short WR={wr:.1%} N={len(valid)}")
        bear_results[f'firm{int(firm_t*100)}_mm{int(mm_t*100)}'] = {'wr': round(wr, 4), 'n': len(valid)}

results['bearish_deep_dive'] = bear_results

# ============================================================
# 9. COMBINATION WITH OTHER SIGNALS
# ============================================================
print("\n" + "="*70)
print("9. COMBINATION WITH OTHER SIGNALS")
print("="*70)

combo_results = {}

# Buy_pct persistence
if 'buy_pct_1030' in data.columns:
    # Persistent buy = buy_pct > 0.52 at 10:30
    data['buy_persistent'] = data['buy_pct_1030'] > 0.52
    
    # Firm >70% AND persistent buy
    combo1 = data[(data['firm_neg_tilt'] > 0.70) & data['buy_persistent'] & data['fwd_60min'].notna()]
    if len(combo1) >= 2:
        wr = (combo1['fwd_60min'] > 0).mean()
        print(f"Firm >70% + persistent buy_pct: WR={wr:.1%} N={len(combo1)} Avg={combo1['fwd_60min'].mean():.4f}%")
        combo_results['firm70_buypct_persistent'] = {'wr': round(wr, 4), 'n': len(combo1)}
    
    # Firm >70% AND faded buy (buy_pct < 0.48)
    combo2 = data[(data['firm_neg_tilt'] > 0.70) & (data['buy_pct_1030'] < 0.48) & data['fwd_60min'].notna()]
    if len(combo2) >= 2:
        wr = (combo2['fwd_60min'] > 0).mean()
        print(f"Firm >70% + faded buy_pct: WR={wr:.1%} N={len(combo2)} (contradicting signals)")
        combo_results['firm70_buypct_faded'] = {'wr': round(wr, 4), 'n': len(combo2)}

# Interaction with the overnight gap direction (±0.1% threshold)
if 'gap_pct' in data.columns:
    data['gap_up'] = data['gap_pct'] > 0.1
    data['gap_down'] = data['gap_pct'] < -0.1

    for gap_col, gap_label in (('gap_up', 'Gap Up'), ('gap_down', 'Gap Down')):
        sel = data[(data['firm_neg_tilt'] > 0.70) & data[gap_col] & data['fwd_60min'].notna()]
        if len(sel) >= 2:
            wr = (sel['fwd_60min'] > 0).mean()
            print(f"Firm >70% + {gap_label}: WR={wr:.1%} N={len(sel)}")
            combo_results[f'firm70_{gap_col}'] = {'wr': round(wr, 4), 'n': len(sel)}

# Day-of-week breakdown of the >70% firm tilt signal
print("\n--- By Day of Week (Firm >70%) ---")
weekday_labels = [(0, 'Mon'), (1, 'Tue'), (2, 'Wed'), (3, 'Thu'), (4, 'Fri')]
for dow, dow_name in weekday_labels:
    day_rows = data[(data['firm_neg_tilt'] > 0.70) & (data['dow'] == dow) & data['fwd_60min'].notna()]
    if len(day_rows) >= 2:
        wr = (day_rows['fwd_60min'] > 0).mean()
        print(f"  {dow_name}: WR={wr:.1%} N={len(day_rows)}")
        combo_results[f'firm70_dow_{dow_name}'] = {'wr': round(wr, 4), 'n': len(day_rows)}

results['combinations'] = combo_results

# ============================================================
# 10. COMPARISON: firm neg tilt vs total tilt vs v5
# ============================================================
print("\n" + "="*70)
print("10. SIGNAL COMPARISON (same day, 1H)")
print("="*70)

comparison_results = {}

# (column, threshold, display label) for each candidate signal
signals = [
    ('total_pos_tilt', 0.55, 'v5 Total Pos Tilt >55%'),
    ('total_neg_tilt', 0.65, 'Total Neg Tilt >65%'),
    ('firm_neg_tilt', 0.65, 'Firm Neg Tilt >65%'),
    ('firm_neg_tilt', 0.70, 'Firm Neg Tilt >70%'),
]

print(f"\n{'Signal':>30} | {'Full WR':>8} {'N':>4} | {'IS WR':>8} {'N':>4} | {'OOS WR':>8} {'N':>4}")
print("-" * 85)

for col, thresh, label in signals:
    if col not in data.columns:
        continue
    # days where the signal fires and a 60-min forward return exists
    valid = data[(data[col] > thresh) & data['fwd_60min'].notna()]
    is_v = valid[valid['period'] == 'IS']
    oos_v = valid[valid['period'] == 'OOS']

    full_wr = (valid['fwd_60min'] > 0).mean() if len(valid) > 0 else np.nan
    is_wr = (is_v['fwd_60min'] > 0).mean() if len(is_v) > 0 else np.nan
    oos_wr = (oos_v['fwd_60min'] > 0).mean() if len(oos_v) > 0 else np.nan

    if pd.isna(full_wr):
        print(f"  {label:>28} | N/A")
    else:
        print(f"  {label:>28} | {full_wr:>7.1%} {len(valid):>4} | {is_wr:>7.1%} {len(is_v):>4} | {oos_wr:>7.1%} {len(oos_v):>4}")

    comparison_results[label] = {
        'full_wr': None if pd.isna(full_wr) else round(full_wr, 4),
        'full_n': len(valid),
        'is_wr': None if pd.isna(is_wr) else round(is_wr, 4),
        'is_n': len(is_v),
        'oos_wr': None if pd.isna(oos_wr) else round(oos_wr, 4),
        'oos_n': len(oos_v)
    }

# Weighted blend of participant tilts: 50% firm + 30% mm + 20% cust
data['weighted_neg_tilt'] = (0.5 * data['firm_neg_tilt']
                             + 0.3 * data['mm_neg_tilt']
                             + 0.2 * data['cust_neg_tilt'])

for thresh in (0.55, 0.60, 0.65):
    valid = data[(data['weighted_neg_tilt'] > thresh) & data['fwd_60min'].notna()]
    if len(valid) == 0:
        continue
    is_v = valid[valid['period'] == 'IS']
    oos_v = valid[valid['period'] == 'OOS']

    # full-sample WR is always defined here (len(valid) > 0)
    full_wr = (valid['fwd_60min'] > 0).mean()
    is_wr = (is_v['fwd_60min'] > 0).mean() if len(is_v) > 0 else np.nan
    oos_wr = (oos_v['fwd_60min'] > 0).mean() if len(oos_v) > 0 else np.nan
    label = f'Weighted Combo >{thresh:.0%}'
    print(f"  {label:>28} | {full_wr:>7.1%} {len(valid):>4} | {is_wr:>7.1%} {len(is_v):>4} | {oos_wr:>7.1%} {len(oos_v):>4}")
    comparison_results[label] = {
        'full_wr': round(full_wr, 4),
        'full_n': len(valid),
        'is_wr': None if pd.isna(is_wr) else round(is_wr, 4),
        'is_n': len(is_v),
        'oos_wr': None if pd.isna(oos_wr) else round(oos_wr, 4),
        'oos_n': len(oos_v)
    }

results['signal_comparison'] = comparison_results

# ============================================================
# 11. WALK-FORWARD STABILITY (3 periods)
# ============================================================
print("\n" + "="*70)
print("11. WALK-FORWARD STABILITY")
print("="*70)

walk_results = {}
print(f"\n{'Period':>10} | {'Date Range':>25} | {'N':>4} | {'Firm>70% WR':>12} {'N':>4} | {'Firm>65% WR':>12} {'N':>4}")
print("-" * 90)

for period in ['P1', 'P2', 'P3']:
    sub = data[data['period3'] == period]

    # Guard: an empty period bucket would make .min()/.max() return NaT,
    # and NaT.date() raises — record the empty bucket instead of crashing.
    if sub.empty:
        print(f"  {period:>8} | {'(no data)':>25} | {0:>4} |")
        walk_results[period] = {'date_range': None, 'n_total': 0}
        continue

    date_range = f"{sub['date'].min().date()} to {sub['date'].max().date()}"

    period_data = {}
    period_data['date_range'] = date_range
    period_data['n_total'] = len(sub)

    row_str = f"  {period:>8} | {date_range:>25} | {len(sub):>4} |"

    # Two bullish thresholds: primary 0.70 and relaxed 0.65
    for thresh in [0.70, 0.65]:
        masked = sub[(sub['firm_neg_tilt'] > thresh) & sub['fwd_60min'].notna()]
        if len(masked) >= 2:  # need at least two signal days to quote a WR
            wr = (masked['fwd_60min'] > 0).mean()
            row_str += f" {wr:>11.1%} {len(masked):>4} |"
            period_data[f'gt{int(thresh*100)}_wr'] = round(wr, 4)
            period_data[f'gt{int(thresh*100)}_n'] = len(masked)
        else:
            row_str += f" {'N/A':>11} {len(masked):>4} |"

    print(row_str)
    walk_results[period] = period_data

results['walk_forward'] = walk_results

# ============================================================
# SUMMARY & RECOMMENDATIONS
# ============================================================
print("\n" + "="*70)
print("SUMMARY & RECOMMENDATIONS")
print("="*70)

summary = {}

# Best threshold: highest full-sample WR among thresholds with N >= 15
print("\n--- Recommended Threshold ---")
best_full_wr = 0
best_thresh = None
for thresh in [0.60, 0.65, 0.70, 0.75]:
    key = f'bull_{thresh}_60min'
    if key in threshold_results:
        r = threshold_results[key]
        # Explicit None check: a legitimate 0.0 win rate is falsy and the
        # old truthiness test (`r['full_wr'] and ...`) would discard it.
        if r['full_n'] >= 15 and r['full_wr'] is not None and r['full_wr'] > best_full_wr:
            best_full_wr = r['full_wr']
            best_thresh = thresh

if best_thresh:
    best_key = f'bull_{best_thresh}_60min'
    r = threshold_results[best_key]
    print(f"  Best threshold (N≥15): >{best_thresh:.0%}")
    print(f"  Full WR: {r['full_wr']:.1%} (N={r['full_n']})")
    # `is not None` so a genuine 0.0 IS/OOS WR prints as a number, not N/A
    print(f"  IS WR: {r['is_wr']}" if r['is_wr'] is not None else "  IS WR: N/A")
    print(f"  OOS WR: {r['oos_wr']}" if r['oos_wr'] is not None else "  OOS WR: N/A")
    summary['recommended_threshold'] = best_thresh
    summary['recommended_wr'] = r['full_wr']
    summary['recommended_n'] = r['full_n']

# Expected firing frequency of the recommended threshold (key format
# matches the earlier frequency section; falls back to the 0.7 key).
freq_key = f'gt_{best_thresh}' if best_thresh else 'gt_0.7'
if freq_key in freq_results:
    freq = freq_results[freq_key]
    print(f"\n  Expected frequency: {freq['days_per_month']:.1f} days/month ({freq['pct']:.1%} of trading days)")
    summary['expected_frequency_per_month'] = freq['days_per_month']

# Best combinations: rank by WR, but only combos with N >= 5 may compete
# (smaller samples are keyed to 0 so they sink to the bottom).
print("\n--- Best Combinations ---")
if combo_results:
    def _combo_rank(entry):
        vals = entry[1]
        return vals.get('wr', 0) if vals.get('n', 0) >= 5 else 0
    sorted_combos = sorted(combo_results.items(), key=_combo_rank, reverse=True)
    for name, vals in sorted_combos[:5]:
        print(f"  {name}: WR={vals['wr']:.1%} N={vals['n']}")
    summary['best_combo'] = sorted_combos[0][0] if sorted_combos else None

# Centerpiece verdict: unconditional YES only with WR >= 65% on N >= 20
best_stats = threshold_results.get(f'bull_{best_thresh}_60min', {})
is_centerpiece = best_full_wr >= 0.65 and best_stats.get('full_n', 0) >= 20
summary['centerpiece_verdict'] = "YES" if is_centerpiece else "CONDITIONAL"

results['summary'] = summary
# Dataset-level metadata for downstream consumers of the JSON output
results['metadata'] = {
    'n_days': len(data),
    'date_range': [str(data['date'].min().date()), str(data['date'].max().date())],
    'is_days': len(data[data['period'] == 'IS']),
    'oos_days': len(data[data['period'] == 'OOS']),
    'firm_tilt_mean': round(data['firm_neg_tilt'].mean(), 4),
    'firm_tilt_std': round(data['firm_neg_tilt'].std(), 4),
    'firm_tilt_median': round(data['firm_neg_tilt'].median(), 4),
}

# Persist the full results payload; default=str stringifies any
# non-JSON-native values (dates, Timestamps) left in the dicts.
with open(OUTPUT_FILE, 'w') as out_file:
    json.dump(results, out_file, indent=2, default=str)

print(f"\n✅ Results saved to {OUTPUT_FILE}")
print(f"Total days analyzed: {len(data)}")
