#!/usr/bin/env python3
"""
Comprehensive TRACE Pattern Mining
Mines 443 days of normalized TRACE data for predictive intraday patterns.
IS/OOS split: 60/40 by date. Only reports OOS metrics.
"""

import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
import json
import warnings
import traceback
warnings.filterwarnings('ignore')

DATA_DIR = Path('/Users/lutherbot/.openclaw/workspace/data')
TRACE_DIR = DATA_DIR / 'trace_normalized'

# FOMC dates are excluded from the whole study; they are compared as strings
# against the date suffix of each TRACE filename in load_all_trace().
with (DATA_DIR / 'fomc_dates.json').open() as f:
    _fomc_payload = json.load(f)
FOMC_DATES = set(_fomc_payload['dates'])

# GEX tier boundaries: half-open [lo, hi) intervals of net market-maker gamma,
# in the units of the TRACE gamma columns.
GEX_TIERS = {
    'DEEP_NEG': (-np.inf, -100e6),
    'NEG': (-100e6, 0),
    'LOW_POS': (0, 100e6),
    'MID_POS': (100e6, 250e6),
    'HIGH_POS': (250e6, 500e6),
    'EXTREME_POS': (500e6, np.inf),
}


def get_gex_tier(val):
    """Map a net market-maker gamma value to its GEX tier name.

    Tiers are the half-open ``[lo, hi)`` intervals in ``GEX_TIERS``. The only
    non-missing value outside every interval is ``+inf``, which is mapped to
    ``'EXTREME_POS'``.

    Args:
        val: Scalar gamma value.

    Returns:
        The tier name, or ``None`` when ``val`` is missing (``None``/NaN).
        A missing value used to fall through to ``'EXTREME_POS'`` because
        every comparison with NaN is False.
    """
    if pd.isna(val):
        return None
    for name, (lo, hi) in GEX_TIERS.items():
        if lo <= val < hi:
            return name
    # Only +inf reaches this point: every finite value lies in some interval.
    return 'EXTREME_POS'

def load_spx_5min():
    """Read the SPX 5-minute bar CSV and express bar times in US/Eastern.

    The ``datetime`` column is normalised to UTC and then converted to
    America/New_York. Rows are returned in chronological order with a fresh
    RangeIndex.
    """
    csv_path = DATA_DIR / 'spx_5min_polygon.csv'
    bars = pd.read_csv(csv_path, parse_dates=['datetime'])
    as_utc = pd.to_datetime(bars['datetime'], utc=True)
    bars['datetime'] = as_utc.dt.tz_convert('America/New_York')
    return bars.sort_values('datetime').reset_index(drop=True)

def load_es_1min():
    """Load ES 1-minute bars and add a US/Eastern ``datetime`` column.

    ``ts_event`` is parsed with ``utc=True``, as ``load_spx_5min`` does. This
    handles timezone-aware values and naive values (treated as UTC). Without
    it, naive timestamps make ``tz_convert`` raise.

    Returns:
        DataFrame sorted by ``datetime`` (America/New_York) with a fresh
        RangeIndex.
    """
    df = pd.read_csv(DATA_DIR / 'es_1min_bars.csv', parse_dates=['ts_event'])
    df['datetime'] = pd.to_datetime(df['ts_event'], utc=True).dt.tz_convert('America/New_York')
    df = df.sort_values('datetime').reset_index(drop=True)
    return df

print("Loading SPX 5-min data...")
spx = load_spx_5min()
print(f"  SPX rows: {len(spx)}, range: {spx['datetime'].min()} to {spx['datetime'].max()}")

print("Loading ES 1-min data...")
es = load_es_1min()
print(f"  ES rows: {len(es)}, range: {es['datetime'].min()} to {es['datetime'].max()}")

# Build SPX daily OHLC and intraday returns
spx['date'] = spx['datetime'].dt.date
spx_daily = spx.groupby('date').agg(
    open=('open', 'first'),
    high=('high', 'max'),
    low=('low', 'min'),
    close=('close', 'last')
).reset_index()
spx_daily['date'] = pd.to_datetime(spx_daily['date'])
spx_daily['ret'] = spx_daily['close'] / spx_daily['open'] - 1
spx_daily['range_pct'] = (spx_daily['high'] - spx_daily['low']) / spx_daily['open']
spx_daily['close_ret'] = spx_daily['close'].pct_change()  # close-to-close

# Create hourly SPX returns for intraday analysis
spx['hour'] = spx['datetime'].dt.hour
spx['minute'] = spx['datetime'].dt.minute

print("\nLoading TRACE files...")

def _strike_shape_metrics(df):
    """Compute per-timestamp shape metrics of one day's MM gamma strike profile.

    For each timestamp whose total |mm_gamma| is positive:
      * gamma_hhi    -- Herfindahl index of |mm_gamma| shares across strikes;
      * gamma_spread -- |mm_gamma|-weighted standard deviation of strike_price;
      * gamma_tilt   -- net mm_gamma at strikes >= the median strike divided by
                        |net above| + |net below| (only when that sum is > 0).

    Args:
        df: Strike-level rows with ``timestamp``, ``strike_price`` and
            ``mm_gamma`` columns.

    Returns:
        Tuple of three dicts ``(hhi, spread, tilt)`` keyed by timestamp.
        Timestamps for which a metric is undefined are absent.
    """
    hhi, spread, tilt = {}, {}, {}
    # One pass over the groups instead of re-filtering df for every timestamp.
    for ts, ts_data in df.groupby('timestamp', sort=False):
        mm_abs = ts_data['mm_gamma'].abs()
        total = mm_abs.sum()
        if not total > 0:
            continue
        strikes = ts_data['strike_price']
        hhi[ts] = ((mm_abs / total) ** 2).sum()

        weighted_mean = np.average(strikes, weights=mm_abs)
        spread[ts] = np.sqrt(np.average((strikes - weighted_mean) ** 2, weights=mm_abs))

        center = strikes.median()
        above = ts_data.loc[strikes >= center, 'mm_gamma'].sum()
        below = ts_data.loc[strikes < center, 'mm_gamma'].sum()
        total_mm = abs(above) + abs(below)
        if total_mm > 0:
            tilt[ts] = above / total_mm
    return hhi, spread, tilt


def load_all_trace():
    """Load every non-FOMC TRACE parquet file and aggregate it per timestamp.

    For each ``intradayStrikeGEX_<date>.parquet`` in ``TRACE_DIR`` whose date
    suffix is not in ``FOMC_DATES``, this function:
      * sums each participant group's gamma over strikes for every timestamp,
        both the all-expiry columns and the ``*_0`` columns (labelled 0DTE in
        this script);
      * adds total gamma, the ``*_0`` / all-expiry MM gamma ratio and the
        strike with the largest MM gamma;
      * adds date/time columns and the strike-profile metrics from
        ``_strike_shape_metrics``.
    Files that cannot be read are skipped with a printed warning.

    Returns:
        DataFrame with one row per (file, timestamp).

    Raises:
        ValueError: if no file could be loaded.
    """
    files = sorted(TRACE_DIR.glob('intradayStrikeGEX_*.parquet'))
    all_days = []

    for i, f in enumerate(files):
        date_str = f.stem.split('_')[-1]
        if date_str in FOMC_DATES:
            continue

        try:
            df = pd.read_parquet(f)
        except Exception as exc:
            # Best-effort load: skip unreadable files, but never silently.
            print(f"  WARNING: skipping unreadable file {f.name}: {exc}")
            continue

        # Compute per-timestamp aggregates
        ts_agg = df.groupby('timestamp').agg(
            net_mm=('mm_gamma', 'sum'),
            net_cust=('cust_gamma', 'sum'),
            net_firm=('firm_gamma', 'sum'),
            net_bd=('bd_gamma', 'sum'),
            net_procust=('procust_gamma', 'sum'),
            net_mm_0=('mm_gamma_0', 'sum'),
            net_cust_0=('cust_gamma_0', 'sum'),
            net_firm_0=('firm_gamma_0', 'sum'),
            net_bd_0=('bd_gamma_0', 'sum'),
            net_procust_0=('procust_gamma_0', 'sum'),
            n_strikes=('strike_price', 'count'),
            # Strike price of the row with the largest mm_gamma at this timestamp
            max_mm_strike=('mm_gamma', lambda x: df.loc[x.index, 'strike_price'].iloc[x.values.argmax()] if len(x) > 0 else np.nan),
        ).reset_index()

        # Total net gamma (all participants)
        ts_agg['net_total'] = ts_agg['net_mm'] + ts_agg['net_cust'] + ts_agg['net_firm'] + ts_agg['net_bd'] + ts_agg['net_procust']
        ts_agg['net_total_0'] = ts_agg['net_mm_0'] + ts_agg['net_cust_0'] + ts_agg['net_firm_0'] + ts_agg['net_bd_0'] + ts_agg['net_procust_0']

        # 0DTE ratio (NaN when the all-expiry MM gamma is exactly 0)
        ts_agg['dte0_mm_ratio'] = ts_agg['net_mm_0'] / ts_agg['net_mm'].replace(0, np.nan)

        ts_agg['date'] = pd.Timestamp(date_str)
        ts_agg['date_str'] = date_str
        ts_agg['hour'] = ts_agg['timestamp'].dt.hour
        ts_agg['minute'] = ts_agg['timestamp'].dt.minute
        ts_agg['time_str'] = ts_agg['timestamp'].dt.strftime('%H:%M')

        # Strike-level shape metrics; timestamps where a metric is undefined get NaN.
        hhi, spread, tilt = _strike_shape_metrics(df)
        ts_agg['gamma_hhi'] = ts_agg['timestamp'].map(pd.Series(hhi, dtype=float))
        ts_agg['gamma_spread'] = ts_agg['timestamp'].map(pd.Series(spread, dtype=float))
        ts_agg['gamma_tilt'] = ts_agg['timestamp'].map(pd.Series(tilt, dtype=float))

        all_days.append(ts_agg)

        if (i + 1) % 50 == 0:
            print(f"  Loaded {i+1} files...")

    if not all_days:
        raise ValueError(f"No TRACE files could be loaded from {TRACE_DIR}")

    result = pd.concat(all_days, ignore_index=True)
    print(f"  Total: {len(result)} timestamp-rows across {result['date_str'].nunique()} days")
    return result

trace = load_all_trace()

# Filter to market hours for most analysis (9:30 - 16:00 ET)
# NOTE(review): assumes TRACE timestamps are already expressed in ET — confirm.
def is_market_hours(row):
    """Return True when ``row['hour']``/``row['minute']`` lie in 09:30-16:00 inclusive."""
    h, m = row['hour'], row['minute']
    return (h == 9 and m >= 30) or (10 <= h <= 15) or (h == 16 and m == 0)

# Same rule as is_market_hours, evaluated column-wise instead of through a
# per-row Python .apply over every timestamp-row.
_hour, _minute = trace['hour'], trace['minute']
trace['is_rth'] = (
    ((_hour == 9) & (_minute >= 30))
    | _hour.between(10, 15)
    | ((_hour == 16) & (_minute == 0))
)
trace_rth = trace[trace['is_rth']].copy()

# Get unique dates, split IS/OOS: the first 60% of the sorted date strings are
# in-sample, the remainder out-of-sample.
all_dates = sorted(trace['date_str'].unique())
n_is = int(len(all_dates) * 0.6)
is_dates = set(all_dates[:n_is])
oos_dates = set(all_dates[n_is:])
print(f"\nIS dates: {len(is_dates)}, OOS dates: {len(oos_dates)}")
print(f"IS range: {all_dates[0]} to {all_dates[n_is-1]}")
print(f"OOS range: {all_dates[n_is]} to {all_dates[-1]}")

# Build daily summary from intraday RTH snapshots
def get_daily_gex():
    """Summarise each day's RTH TRACE aggregates (``trace_rth``) into one row.

    For every date in ``all_dates`` that has RTH rows, the row contains:
      * ``<metric>_<HHMM>`` snapshot columns at 15 fixed times, taken from the
        earliest row whose hour/minute match. They are absent if no row matches.
      * ``mm_velocity_first_hour`` (10:30 minus 9:30 MM gamma) and
        ``mm_afternoon_erosion`` (15:30 minus 14:00).
      * ``mm_max``/``mm_min``/``mm_range``/``mm_mean`` over all RTH rows.
      * ``gex_tier_open`` from the 9:30 MM gamma, falling back to 9:40.
      * ``mm_cust_div_open`` and ``agreement_score_open``: the count of
        customer/firm/procust/bd 9:30 readings whose sign matches MM's.
      * ``gex_flip_time``: the first 10:00..15:00 half-hour snapshot whose MM
        gamma sign differs from a non-zero 9:30 sign.

    Returns:
        DataFrame with one row per day. Columns missing on a day are NaN.
    """
    snapshot_times = [
        (9, 30, '0930'), (9, 40, '0940'), (10, 0, '1000'),
        (10, 30, '1030'), (11, 0, '1100'), (11, 30, '1130'),
        (12, 0, '1200'), (12, 30, '1230'), (13, 0, '1300'),
        (13, 30, '1330'), (14, 0, '1400'), (14, 30, '1430'),
        (15, 0, '1500'), (15, 30, '1530'), (16, 0, '1600'),
    ]
    flip_labels = ['1000', '1030', '1100', '1130', '1200', '1230',
                   '1300', '1330', '1400', '1430', '1500']

    # Group once instead of boolean-filtering all of trace_rth for every date.
    days_by_date = {d: g for d, g in trace_rth.groupby('date_str')}

    rows = []
    for date_str in all_dates:
        day = days_by_date.get(date_str)
        if day is None or len(day) == 0:
            continue
        day = day.sort_values('timestamp')

        row = {'date_str': date_str, 'date': pd.Timestamp(date_str)}

        # GEX at various times
        for target_h, target_m, label in snapshot_times:
            snap = day[(day['hour'] == target_h) & (day['minute'] == target_m)]
            if len(snap) == 0:
                continue
            s = snap.iloc[0]
            row[f'mm_{label}'] = s['net_mm']
            row[f'cust_{label}'] = s['net_cust']
            row[f'firm_{label}'] = s['net_firm']
            row[f'procust_{label}'] = s['net_procust']
            row[f'bd_{label}'] = s['net_bd']
            row[f'total_{label}'] = s['net_total']
            row[f'mm0_{label}'] = s['net_mm_0']
            row[f'dte0_ratio_{label}'] = s.get('dte0_mm_ratio', np.nan)
            row[f'hhi_{label}'] = s.get('gamma_hhi', np.nan)
            row[f'spread_{label}'] = s.get('gamma_spread', np.nan)
            row[f'tilt_{label}'] = s.get('gamma_tilt', np.nan)

        # GEX velocity (change from 9:30 to 10:30)
        if 'mm_0930' in row and 'mm_1030' in row:
            row['mm_velocity_first_hour'] = row['mm_1030'] - row['mm_0930']

        # Afternoon erosion (change from 14:00 to 15:30)
        if 'mm_1400' in row and 'mm_1530' in row:
            row['mm_afternoon_erosion'] = row['mm_1530'] - row['mm_1400']

        # Max/min GEX during the RTH session
        row['mm_max'] = day['net_mm'].max()
        row['mm_min'] = day['net_mm'].min()
        row['mm_range'] = row['mm_max'] - row['mm_min']
        row['mm_mean'] = day['net_mm'].mean()

        # GEX tier at open (9:30 snapshot, else 9:40)
        open_gex = row.get('mm_0930', row.get('mm_0940', np.nan))
        if not np.isnan(open_gex):
            row['gex_tier_open'] = get_gex_tier(open_gex)

        # Participant divergences at open
        if 'mm_0930' in row and 'cust_0930' in row:
            row['mm_cust_div_open'] = row['mm_0930'] - row['cust_0930']
            # Agreement score: how many participants have the same sign as MM
            mm_sign = np.sign(row['mm_0930'])
            row['agreement_score_open'] = sum(
                1 for p in ('cust_0930', 'firm_0930', 'procust_0930', 'bd_0930')
                if p in row and np.sign(row[p]) == mm_sign
            )

        # First snapshot at which MM gamma changed sign relative to the open
        if 'mm_0930' in row:
            open_sign = np.sign(row['mm_0930'])
            for label in flip_labels:
                key = f'mm_{label}'
                if key in row and open_sign != 0 and np.sign(row[key]) != open_sign:
                    row['gex_flip_time'] = label
                    break

        rows.append(row)

    return pd.DataFrame(rows)

print("\nBuilding daily GEX summary...")
daily_gex = get_daily_gex()
print(f"  Daily GEX rows: {len(daily_gex)}")

# Merge with SPX daily
daily_gex['date'] = pd.to_datetime(daily_gex['date'])
spx_daily['date'] = pd.to_datetime(spx_daily['date'])
daily = daily_gex.merge(spx_daily[['date', 'open', 'high', 'low', 'close', 'ret', 'range_pct', 'close_ret']], 
                         on='date', how='inner')
print(f"  Merged with SPX: {len(daily)} days")

# Add day of week
daily['dow'] = daily['date'].dt.dayofweek  # 0=Mon

# Add previous day GEX for day-over-day
daily = daily.sort_values('date').reset_index(drop=True)
daily['prev_mm_0930'] = daily['mm_0930'].shift(1)
daily['gex_dod_change'] = daily['mm_0930'] - daily['prev_mm_0930']

# IS/OOS split
daily['is_is'] = daily['date_str'].isin(is_dates)
daily['is_oos'] = daily['date_str'].isin(oos_dates)

# ============================================================
# TESTING FRAMEWORK
# ============================================================
# Every test_* helper below stores its outcome in this dict under `name`.
results = {}


def test_continuous(name, signal_col, target_col, df, min_n=50, desc=""):
    """Spearman IC of a continuous signal vs a continuous target, IS and OOS.

    Rows are split by the boolean ``is_is`` / ``is_oos`` columns of ``df``.
    Rows with NaN in either column are dropped. If either split has fewer than
    ``min_n`` rows, ``results[name]`` becomes ``{'status': 'INSUFFICIENT_N', ...}``.
    Otherwise it records the IS/OOS IC, p-value and n, plus ``signal`` =
    |OOS IC| >= 0.08 and OOS p < 0.05.
    """
    by_split = {}
    for split_name, mask in (('IS', df['is_is']), ('OOS', df['is_oos'])):
        sub = df[mask].dropna(subset=[signal_col, target_col])
        if len(sub) < min_n:
            results[name] = {'status': 'INSUFFICIENT_N', 'desc': desc}
            return
        ic, p = stats.spearmanr(sub[signal_col], sub[target_col])
        by_split[split_name] = (ic, p, len(sub))

    is_ic, is_p, is_n = by_split['IS']
    oos_ic, oos_p, oos_n = by_split['OOS']
    results[name] = {
        'type': 'continuous',
        'desc': desc,
        'is_ic': round(is_ic, 4), 'is_p': round(is_p, 4), 'is_n': is_n,
        'oos_ic': round(oos_ic, 4), 'oos_p': round(oos_p, 4), 'oos_n': oos_n,
        'signal': abs(oos_ic) >= 0.08 and oos_p < 0.05,
    }


def test_quintile(name, signal_col, target_col, df, min_n=50, desc=""):
    """Win-rate spread of ``target_col`` across quintiles of ``signal_col`` (OOS).

    Both splits must have at least ``min_n`` rows after dropping NaNs in either
    column; otherwise ``results[name]`` is marked INSUFFICIENT_N. The IS split
    is only checked for size; statistics are reported for OOS only.

    OOS rows are bucketed with ``pd.qcut`` into up to 5 bins (duplicate edges
    dropped), and each bucket's win rate is the fraction of rows with
    ``target_col > 0``. The function records:
      * OOS Spearman IC, p-value and n;
      * bottom- and top-bucket win rates and their difference;
      * ``signal`` = |spread| >= 0.08 and IC p < 0.05.
    """
    subs = {}
    for split_name, mask in (('IS', df['is_is']), ('OOS', df['is_oos'])):
        sub = df[mask].dropna(subset=[signal_col, target_col])
        if len(sub) < min_n:
            results[name] = {'status': 'INSUFFICIENT_N', 'desc': desc}
            return
        subs[split_name] = sub

    oos = subs['OOS'].copy()
    oos['q'] = pd.qcut(oos[signal_col], 5, labels=False, duplicates='drop')
    q_stats = oos.groupby('q')[target_col].agg(['mean', 'count', lambda x: (x > 0).mean()])
    q_stats.columns = ['mean_ret', 'n', 'wr']

    # qcut codes run 0..k-1 (k <= 5 once duplicate edges are dropped). q_stats
    # can be empty if qcut produced no valid bins.
    q1_wr = q_stats.loc[0, 'wr'] if 0 in q_stats.index else None
    q5_wr = q_stats.loc[q_stats.index.max(), 'wr'] if len(q_stats) else None
    oos_spread = q5_wr - q1_wr if q1_wr is not None and q5_wr is not None else 0

    ic, p = stats.spearmanr(oos[signal_col], oos[target_col])

    results[name] = {
        'type': 'quintile',
        'desc': desc,
        'oos_ic': round(ic, 4), 'oos_p': round(p, 4), 'oos_n': len(oos),
        'oos_q1_wr': round(q1_wr, 4) if q1_wr is not None else None,
        'oos_q5_wr': round(q5_wr, 4) if q5_wr is not None else None,
        'oos_wr_spread': round(oos_spread, 4),
        'signal': abs(oos_spread) >= 0.08 and p < 0.05,
    }


def test_conditional(name, condition_func, target_col, df, min_n=50, desc=""):
    """Win rate of ``target_col > 0`` on rows where ``condition_func(row)`` is true.

    ``condition_func`` is applied row-wise to each split after dropping rows
    with NaN ``target_col``. If either split has fewer than ``min_n // 2``
    triggered rows, ``results[name]`` is marked INSUFFICIENT_N, with the
    triggered count.

    Otherwise the function records:
      * IS/OOS triggered win rate and n;
      * the OOS win rate of non-triggered rows (0.5 if there are none);
      * a binomial-test p-value of the OOS triggered win count against p = 0.5;
      * ``signal`` = |OOS WR - 0.5| >= 0.04 and p < 0.10.
    """
    for split_name, mask in (('IS', df['is_is']), ('OOS', df['is_oos'])):
        sub = df[mask].dropna(subset=[target_col])
        if sub.empty:
            # DataFrame.apply on an empty frame need not return a boolean mask.
            cond = pd.Series(False, index=sub.index, dtype=bool)
        else:
            cond = sub.apply(condition_func, axis=1)
        triggered = sub[cond]
        not_triggered = sub[~cond]

        if len(triggered) < min_n // 2:
            results[name] = {'status': 'INSUFFICIENT_N', 'desc': desc, 'n_triggered': len(triggered)}
            return

        wins = int((triggered[target_col] > 0).sum())
        wr_triggered = (triggered[target_col] > 0).mean()
        wr_not = (not_triggered[target_col] > 0).mean() if len(not_triggered) > 0 else 0.5

        if split_name == 'IS':
            is_wr, is_n = wr_triggered, len(triggered)
        else:
            oos_wr, oos_n, oos_wins = wr_triggered, len(triggered), wins
            oos_wr_base = wr_not

    # P-value: binomial test on the exact OOS win count. Reconstructing it as
    # int(oos_wr * oos_n) truncates, e.g. 0.57 * 100 == 56.999... -> 56.
    try:
        from scipy.stats import binomtest
        p_val = binomtest(oos_wins, oos_n, 0.5).pvalue
    except Exception:
        try:
            from scipy.stats import binom_test  # older SciPy releases
            p_val = binom_test(oos_wins, oos_n, 0.5)
        except Exception:
            p_val = 1.0

    results[name] = {
        'type': 'conditional',
        'desc': desc,
        'is_wr': round(is_wr, 4), 'is_n': is_n,
        'oos_wr': round(oos_wr, 4), 'oos_n': oos_n,
        'oos_base_wr': round(oos_wr_base, 4),
        'oos_wr_spread': round(oos_wr - oos_wr_base, 4),
        'p_value': round(p_val, 4),
        'signal': abs(oos_wr - 0.5) >= 0.04 and p_val < 0.10,
    }


# These are analysis routines, not unit tests: stop pytest from collecting
# them because of their ``test_`` prefix.
test_continuous.__test__ = False
test_quintile.__test__ = False
test_conditional.__test__ = False

print("\n" + "="*60)
print("RUNNING TESTS")
print("="*60)

# ============================================================
# 1. INTRADAY GEX DYNAMICS
# ============================================================
print("\n--- 1. Intraday GEX Dynamics ---")

# 1a. GEX velocity (first hour) → afternoon return
test_continuous('gex_velocity_1h_to_ret', 'mm_velocity_first_hour', 'ret', daily,
                desc="GEX velocity (9:30→10:30 change) vs daily O→C return")

# 1b. GEX velocity → range
test_continuous('gex_velocity_1h_to_range', 'mm_velocity_first_hour', 'range_pct', daily,
                desc="GEX velocity (9:30→10:30 change) vs daily range")

# 1c. Afternoon erosion → close direction
daily['close_direction'] = (daily['close'] > daily.apply(
    lambda r: r.get('close', r['open']), axis=1)).astype(int)  # will fix below

# Need intraday SPX returns from specific times.
# For each day in `daily`, build from the SPX 5-min bars:
#   * close prices at fixed ET bar times and the return from each to the close;
#   * opening-range returns;
#   * per-hour high-low ranges.
print("  Building intraday returns from SPX...")
# Group the bars once instead of scanning all of `spx` for every day.
spx_by_date = {day_key: bars for day_key, bars in spx.groupby('date')}
intraday_rets = {}
for date in daily['date']:
    date_val = date.date() if hasattr(date, 'date') else date
    day_spx = spx_by_date.get(date_val)
    if day_spx is None or len(day_spx) < 10:
        continue
    day_spx = day_spx.sort_values('datetime')
    bar_hour = day_spx['datetime'].dt.hour
    bar_minute = day_spx['datetime'].dt.minute

    d = {}
    d['open_price'] = day_spx.iloc[0]['open']
    d['close_price'] = day_spx.iloc[-1]['close']
    d['ret_oc'] = d['close_price'] / d['open_price'] - 1

    # Returns from various times to close (using that bar's close price)
    for target_h, target_m, label in [
        (9, 30, '0930'), (10, 0, '1000'), (10, 30, '1030'),
        (11, 0, '1100'), (11, 30, '1130'), (12, 0, '1200'),
        (13, 0, '1300'), (14, 0, '1400'), (15, 0, '1500'),
    ]:
        t_snap = day_spx[(bar_hour == target_h) & (bar_minute == target_m)]
        if len(t_snap) > 0:
            d[f'price_{label}'] = t_snap.iloc[0]['close']
            d[f'ret_{label}_to_close'] = d['close_price'] / t_snap.iloc[0]['close'] - 1

    # First 15 min return: 9:30-bar open to 9:45-bar close
    t0 = day_spx[(bar_hour == 9) & (bar_minute == 30)]
    t1 = day_spx[(bar_hour == 9) & (bar_minute == 45)]
    if len(t0) > 0 and len(t1) > 0:
        d['ret_first_15'] = t1.iloc[0]['close'] / t0.iloc[0]['open'] - 1

    # First 30 min return: 9:30-bar open to 10:00-bar close
    t2 = day_spx[(bar_hour == 10) & (bar_minute == 0)]
    if len(t0) > 0 and len(t2) > 0:
        d['ret_first_30'] = t2.iloc[0]['close'] / t0.iloc[0]['open'] - 1

    # Hourly ranges for vol analysis, as a fraction of the day's open
    for h in range(10, 16):
        hour_data = day_spx[bar_hour == h]
        if len(hour_data) > 0:
            d[f'range_h{h}'] = (hour_data['high'].max() - hour_data['low'].min()) / d['open_price']

    # Day's open (ends up as `open_intra` after the merge below, since `daily`
    # already has an `open` column)
    d['open'] = d['open_price']

    intraday_rets[str(date_val)] = d

intraday_df = pd.DataFrame.from_dict(intraday_rets, orient='index')
intraday_df.index.name = 'date_str'
intraday_df = intraday_df.reset_index()
print(f"  Intraday returns built: {len(intraday_df)} days")

# Merge
# NOTE(review): joins on str(date) vs the TRACE `date_str`; this assumes the
# filename dates are YYYY-MM-DD — confirm.
daily = daily.merge(intraday_df, on='date_str', how='left', suffixes=('', '_intra'))

# Now run more tests

# 1d. 0DTE takeover ratio at 10:30 → rest of day return
# Alias column (all NaN if no 10:30→close return was built on any day).
daily['ret_1030_close'] = daily.get('ret_1030_to_close', pd.Series(dtype=float))
test_continuous('dte0_ratio_1030_to_ret', 'dte0_ratio_1030', 'ret_1030_to_close', daily,
                desc="0DTE/total MM gamma ratio at 10:30 vs 10:30→close return")

# 1e. Afternoon erosion → last hour return
# NOTE(review): erosion is measured up to 15:30, which overlaps the 14:00→close window.
test_continuous('afternoon_erosion_to_ret', 'mm_afternoon_erosion', 'ret_1400_to_close', daily,
                desc="MM gamma change 14:00→15:30 vs 14:00→close return")

# 1f. GEX at specific times → forward returns
for time_label, ret_label in [
    ('0930', 'ret_oc'), ('1000', 'ret_1000_to_close'), ('1030', 'ret_1030_to_close'),
    ('1100', 'ret_1100_to_close'), ('1200', 'ret_1200_to_close'), ('1400', 'ret_1400_to_close'),
]:
    test_continuous(f'mm_gex_{time_label}_to_{ret_label}', f'mm_{time_label}', ret_label, daily,
                    desc=f"MM GEX at {time_label} vs {ret_label}")

# 1g. GEX reversal (flip during day)
# `gex_flip_time` exists only if at least one day flipped. Guard so that a
# sample without flips yields an all-False flag instead of a KeyError.
if 'gex_flip_time' in daily.columns:
    daily['had_flip'] = daily['gex_flip_time'].notna()
else:
    daily['had_flip'] = False
test_conditional('gex_flip_positive_ret', lambda r: r.get('had_flip', False), 'ret_oc', daily,
                 desc="When GEX flips sign during day, is return positive?")

# ============================================================
# 2. PARTICIPANT FLOW PATTERNS
# ============================================================
print("\n--- 2. Participant Flow Patterns ---")

# 2a. MM vs Customer divergence at open
test_continuous('mm_cust_div_to_ret', 'mm_cust_div_open', 'ret_oc', daily,
                desc="MM-Customer gamma divergence at open vs O→C return")

test_continuous('mm_cust_div_to_range', 'mm_cust_div_open', 'range_pct', daily,
                desc="MM-Customer gamma divergence at open vs daily range")

# 2b. Procust (hedge fund) positioning relative to MM at open
if 'procust_0930' in daily.columns and 'mm_0930' in daily.columns:
    # Ratio is NaN when the 9:30 MM gamma is exactly 0.
    daily['procust_mm_ratio'] = daily['procust_0930'] / daily['mm_0930'].replace(0, np.nan)
    # Count "opposite sign" only when both 9:30 readings exist. np.sign(NaN)
    # compares unequal to everything, so days missing a 9:30 snapshot would
    # otherwise be flagged as opposite.
    both_present = daily['procust_0930'].notna() & daily['mm_0930'].notna()
    daily['procust_mm_opposite'] = both_present & (
        np.sign(daily['procust_0930']) != np.sign(daily['mm_0930'])
    )

    test_continuous('procust_mm_ratio_to_ret', 'procust_mm_ratio', 'ret_oc', daily,
                    desc="Procust/MM gamma ratio at open vs O→C return")

    test_conditional('procust_vs_mm_opposite_ret',
                     lambda r: r.get('procust_mm_opposite', False), 'ret_oc', daily,
                     desc="When procust and MM have opposite sign at open, WR")

# 2c. Participant agreement score (0-4 other groups sharing MM's 9:30 sign)
test_continuous('agreement_score_to_ret', 'agreement_score_open', 'ret_oc', daily,
                desc="Participant agreement score at open vs O→C return")

test_continuous('agreement_score_to_range', 'agreement_score_open', 'range_pct', daily,
                desc="Participant agreement score at open vs daily range")

# 2d. Firm gamma magnitude at open → vol
if 'firm_0930' in daily.columns:
    daily['firm_abs_open'] = daily['firm_0930'].abs()
    test_continuous('firm_abs_to_range', 'firm_abs_open', 'range_pct', daily,
                    desc="Firm gamma magnitude at open vs daily range")
    test_continuous('firm_sign_to_ret', 'firm_0930', 'ret_oc', daily,
                    desc="Firm gamma sign/magnitude at open vs O→C return")

# ============================================================
# 3. STRIKE-LEVEL PATTERNS
# ============================================================
print("\n--- 3. Strike-Level Patterns ---")

# 3a. Gamma concentration (HHI) at open → range
# hhi_0930 is the Herfindahl index of |MM gamma| shares across strikes at the 9:30 snapshot.
test_continuous('gamma_hhi_0930_to_range', 'hhi_0930', 'range_pct', daily,
                desc="Gamma HHI (concentration) at 9:30 vs daily range")

test_continuous('gamma_hhi_0930_to_ret', 'hhi_0930', 'ret_oc', daily,
                desc="Gamma HHI (concentration) at 9:30 vs O→C return")

# 3b. Gamma spread at open → range
# spread_0930 is the |MM gamma|-weighted standard deviation of strike prices at 9:30.
test_continuous('gamma_spread_0930_to_range', 'spread_0930', 'range_pct', daily,
                desc="Gamma spread (points) at 9:30 vs daily range")

# 3c. Gamma tilt at various times
# tilt = net MM gamma at strikes >= the median strike / (|above| + |below|).
# NOTE(review): tilt at 10:30/12:00/14:00 is compared with the full-day O→C
# return, which includes the price move before the snapshot.
for time_label in ['0930', '1030', '1200', '1400']:
    test_continuous(f'gamma_tilt_{time_label}_to_ret', f'tilt_{time_label}', 'ret_oc', daily,
                    desc=f"Gamma tilt (above/total ratio) at {time_label} vs O→C return")

# ============================================================
# 4. REGIME TRANSITIONS
# ============================================================
print("\n--- 4. Regime Transitions ---")

# 4a. Day-over-day GEX change → return
# gex_dod_change = 9:30 MM gamma minus the previous row's 9:30 MM gamma.
test_continuous('gex_dod_change_to_ret', 'gex_dod_change', 'ret_oc', daily,
                desc="Day-over-day GEX change vs O→C return")

test_quintile('gex_dod_change_quintile', 'gex_dod_change', 'ret_oc', daily,
              desc="Day-over-day GEX change quintiles vs O→C return")

# 4b. GEX tier transitions
if 'gex_tier_open' in daily.columns:
    daily['prev_tier'] = daily['gex_tier_open'].shift(1)
    # NaN != NaN is True, so the first row (and rows next to a missing tier)
    # count as "changed". This column is not tested in this section.
    daily['tier_changed'] = daily['gex_tier_open'] != daily['prev_tier']
    
    # Tier upgrade (moving to higher GEX): tier_delta > 0 for upgrades,
    # < 0 for downgrades, NaN when either tier is missing.
    tier_order = {'DEEP_NEG': 0, 'NEG': 1, 'LOW_POS': 2, 'MID_POS': 3, 'HIGH_POS': 4, 'EXTREME_POS': 5}
    daily['tier_num'] = daily['gex_tier_open'].map(tier_order)
    daily['prev_tier_num'] = daily['prev_tier'].map(tier_order)
    daily['tier_delta'] = daily['tier_num'] - daily['prev_tier_num']
    
    test_continuous('tier_delta_to_ret', 'tier_delta', 'ret_oc', daily,
                    desc="GEX tier change (upgrade=+, downgrade=-) vs O→C return")

# 4c. GEX at open vs GEX range during day
# Ratio of the intraday MM gamma range (max - min over RTH rows) to |9:30 MM
# gamma|; NaN when the 9:30 value is exactly 0.
daily['gex_intraday_range_ratio'] = daily['mm_range'] / daily['mm_0930'].abs().replace(0, np.nan)
test_continuous('gex_intraday_instability_to_range', 'gex_intraday_range_ratio', 'range_pct', daily,
                desc="GEX intraday range / open GEX vs daily price range")

# ============================================================
# 5. CONDITIONAL SETUPS
# ============================================================
print("\n--- 5. Conditional Setups ---")

# 5a. High GEX + first 15 min direction
if 'ret_first_15' in daily.columns:
    # In HIGH/EXTREME POS, if first 15 min goes up, does rest of day continue?
    # A missing first-15 return fails both comparisons (NaN > 0 and NaN < 0
    # are False), so such days never trigger.
    def high_gex_first15_up(r):
        tier = r.get('gex_tier_open', '')
        return tier in ('HIGH_POS', 'EXTREME_POS') and r.get('ret_first_15', 0) > 0

    def high_gex_first15_down(r):
        tier = r.get('gex_tier_open', '')
        return tier in ('HIGH_POS', 'EXTREME_POS') and r.get('ret_first_15', 0) < 0

    test_conditional('high_gex_first15_up_continuation', high_gex_first15_up,
                     'ret_1000_to_close', daily, min_n=30,
                     desc="HIGH/EXTREME GEX + first 15 min UP → 10:00→close return positive?")

    test_conditional('high_gex_first15_down_reversal', high_gex_first15_down,
                     'ret_1000_to_close', daily, min_n=30,
                     desc="HIGH/EXTREME GEX + first 15 min DOWN → 10:00→close return positive? (mean reversion)")

# 5b. GEX × day of week
for dow, dow_name in [(0, 'Mon'), (1, 'Tue'), (2, 'Wed'), (3, 'Thu'), (4, 'Fri')]:
    sub = daily[daily['dow'] == dow].copy()
    if len(sub) > 30:
        test_continuous(f'mm_gex_0930_to_ret_{dow_name}', 'mm_0930', 'ret_oc', sub,
                        desc=f"MM GEX at 9:30 vs O→C return on {dow_name}")

# 5c. GEX × first 30 min
if 'ret_first_30' in daily.columns:
    # Interaction: 9:30 MM GEX signed by the direction of the first-30-min
    # return. Days without that return stay NaN and are dropped by
    # test_continuous, rather than being scored as a 0 signal.
    daily['gex_x_first30'] = daily['mm_0930'] * np.sign(daily['ret_first_30'])
    test_continuous('gex_x_first30_to_ret', 'gex_x_first30', 'ret_1000_to_close', daily,
                    desc="GEX × first 30 min direction interaction vs 10:00→close return")

# 5d. Large GEX decline from yesterday → today's direction
def big_gex_decline(r):
    """True when the 9:30 MM gamma fell by more than 100M vs the previous row (NaN -> False)."""
    return r.get('gex_dod_change', 0) < -100e6

test_conditional('big_gex_decline_ret', big_gex_decline, 'ret_oc', daily,
                 desc="When GEX drops >100M from yesterday → O→C return direction")

# ============================================================  
# 6. TIMING PATTERNS
# ============================================================
print("\n--- 6. Timing Patterns ---")

# 6a. Best time to enter by GEX tier - using intraday returns
# For each tier, what's the avg return from each time to close?
# OOS days only. Tiers with < 10 OOS days are skipped, as are time points
# with <= 5 non-NaN returns. 9:30 uses the open-to-close return `ret_oc`.
# NOTE(review): raises KeyError if a `ret_<HHMM>_to_close` column was never built.
timing_results = {}
for tier in ['DEEP_NEG', 'NEG', 'LOW_POS', 'MID_POS', 'HIGH_POS', 'EXTREME_POS']:
    tier_data = daily[(daily['gex_tier_open'] == tier) & daily['is_oos']]
    if len(tier_data) < 10:
        continue
    tier_timing = {}
    for time_label in ['0930', '1000', '1030', '1100', '1200', '1300', '1400', '1500']:
        col = f'ret_{time_label}_to_close' if time_label != '0930' else 'ret_oc'
        vals = tier_data[col].dropna()
        if len(vals) > 5:
            tier_timing[time_label] = {
                'mean_ret': round(vals.mean() * 10000, 2),  # bps
                'wr': round((vals > 0).mean(), 4),
                'n': len(vals),
            }
    timing_results[tier] = tier_timing

results['timing_by_tier'] = {
    'type': 'timing',
    'desc': 'Average return (bps) and WR from each time to close, by GEX tier (OOS)',
    'data': timing_results,
}

# 6b. Close prediction from midday GEX
test_continuous('mm_1200_to_close_ret', 'mm_1200', 'ret_1200_to_close', daily,
                desc="MM GEX at 12:00 vs 12:00→close return")

test_continuous('mm_1300_to_close_ret', 'mm_1300', 'ret_1300_to_close', daily,
                desc="MM GEX at 13:00 vs 13:00→close return")

# ============================================================
# 7. VOLATILITY PATTERNS
# ============================================================
print("\n--- 7. Volatility Patterns ---")

# 7a. Hourly range by GEX tier
# OOS days only; range_h<H> is the hour's high-low range as a fraction of the
# day's open, reported here in bps.
vol_by_tier = {}
for tier in ['DEEP_NEG', 'NEG', 'LOW_POS', 'MID_POS', 'HIGH_POS', 'EXTREME_POS']:
    tier_data = daily[(daily['gex_tier_open'] == tier) & daily['is_oos']]
    if len(tier_data) < 10:
        continue
    hourly = {}
    for h in range(10, 16):
        col = f'range_h{h}'
        if col in tier_data.columns:
            vals = tier_data[col].dropna()
            if len(vals) > 5:
                hourly[f'h{h}'] = {
                    'mean_range_bps': round(vals.mean() * 10000, 2),
                    'n': len(vals),
                }
    vol_by_tier[tier] = hourly

results['vol_profile_by_tier'] = {
    'type': 'vol_profile',
    'desc': 'Average hourly range (bps) by GEX tier (OOS)',
    'data': vol_by_tier,
}

# 7b. GEX magnitude vs hourly ranges
for h in range(10, 16):
    col = f'range_h{h}'
    if col in daily.columns:
        test_continuous(f'mm_gex_0930_to_range_h{h}', 'mm_0930', col, daily,
                        desc=f"MM GEX at 9:30 vs hour {h} range")

# 7c. Explosion detection: High GEX but high range
if 'range_pct' in daily.columns:
    # Percentile ranks computed separately within the IS rows and the non-IS rows.
    daily['range_pct_rank'] = daily.groupby('is_is')['range_pct'].rank(pct=True)
    daily['gex_rank'] = daily.groupby('is_is')['mm_0930'].rank(pct=True)
    
    # Days with high GEX (>P60) but high range (>P80) - "explosions"
    def is_explosion(r):
        return r.get('gex_rank', 0) > 0.6 and r.get('range_pct_rank', 0) > 0.8
    
    # What predicts these? Check morning velocity, 0DTE ratio, etc.
    # NOTE(review): is_explosion is computed but not tested in this section.
    daily['is_explosion'] = daily.apply(is_explosion, axis=1)

# ============================================================
# 8. ADDITIONAL PATTERNS
# ============================================================
print("\n--- 8. Additional Patterns ---")

# 8a. GEX mean (full day average) vs close direction
test_continuous('gex_mean_to_ret', 'mm_mean', 'ret_oc', daily,
                desc="Average MM GEX (all RTH snapshots) vs O→C return")

# 8b. MM gamma at open vs range — quintile analysis
test_quintile('mm_gex_0930_quintile_ret', 'mm_0930', 'ret_oc', daily,
              desc="MM GEX at 9:30 quintiles vs O→C return")

test_quintile('mm_gex_0930_quintile_range', 'mm_0930', 'range_pct', daily,
              desc="MM GEX at 9:30 quintiles vs daily range")

# 8c. Overnight gap analysis
# NOTE(review): `shift(1)` takes the previous ROW's close, so if any session is
# missing from `daily` (e.g. dropped FOMC days) the gap spans more than one
# session. Confirm `daily` is date-sorted and contains every prior session.
daily['overnight_gap'] = daily['open_price'] / daily['close'].shift(1) - 1

# A gap "fades" when the O→C return moves against a non-zero gap.
# Multiplying the signs (instead of comparing them with `!=`) keeps NaN rows
# from being flagged as fades. NaN rows include the first day, which has no
# prior close. Zero-gap days are also not flagged. `np.sign(NaN) != x` is
# always True, so the old form marked all of these as "faded".
daily['gap_faded'] = (np.sign(daily['overnight_gap']) * np.sign(daily['ret_oc'])) < 0

# Does GEX predict gap fade? Interaction term: 9:30 MM GEX × overnight gap.
daily['gex_x_gap'] = daily['mm_0930'] * daily['overnight_gap']
test_continuous('gex_x_gap_to_ret', 'gex_x_gap', 'ret_oc', daily,
                desc="GEX × overnight gap interaction vs O→C return")


# High GEX + gap fade
def high_gex_gap_up(r):
    """Return True for rows whose opening GEX tier (`gex_tier_open`) is
    HIGH_POS or EXTREME_POS and whose overnight gap is above +0.2%."""
    return r.get('gex_tier_open', '') in ('HIGH_POS', 'EXTREME_POS') and r.get('overnight_gap', 0) > 0.002


test_conditional('high_gex_gap_up_fade', high_gex_gap_up, 'ret_oc', daily, min_n=20,
                 desc="HIGH/EXTREME GEX + gap up >0.2% → O→C return direction (fade expected)")

# 8d. 0DTE dominance: 0DTE ratio at each snapshot time vs daily range.
# Snapshots whose `dte0_ratio_<time>` column is absent are skipped.
for snap in ('0930', '1030', '1200', '1400'):
    ratio_col = f'dte0_ratio_{snap}'
    if ratio_col not in daily.columns:
        continue
    test_continuous(f'dte0_ratio_{snap}_to_range', ratio_col, 'range_pct', daily,
                    desc=f"0DTE ratio at {snap} vs daily range")

# 8e. Magnitude of 9:30 MM GEX (sign dropped) vs daily range, probing
# whether large MM gamma of either sign suppresses the range.
daily['mm_abs_0930'] = daily['mm_0930'].abs()
test_continuous('mm_abs_gex_to_range', 'mm_abs_0930', 'range_pct', daily,
                desc="Absolute MM GEX at 9:30 vs daily range (vol suppression)")

# 8f. Percentile rank of 9:30 MM GEX → quintile test vs O→C return.
# NOTE(review): this rank is taken over every row of `daily` at once (IS and
# OOS pooled, no `is_is` grouping). Confirm the pooling is intended.
daily['mm_0930_pctile'] = daily['mm_0930'].rank(pct=True)
test_quintile('mm_gex_percentile_quintile', 'mm_0930_pctile', 'ret_oc', daily,
              desc="MM GEX percentile rank quintiles vs O→C return")

# ============================================================
# 9. INTRADAY TIME-SPECIFIC GEX SIGNALS
# ============================================================
print("\n--- 9. Time-Specific GEX Signals ---")

# Which snapshot's GEX reading best predicts the rest of the session?
# Each snapshot's MM GEX is tested against the return from that snapshot to
# the close. The 9:30 snapshot uses the full O→C return. Snapshots without a
# forward-return column in `daily` are skipped.
for snap_time in ['0930', '1000', '1030', '1100', '1130', '1200', '1300', '1400', '1500']:
    if snap_time == '0930':
        fwd_ret_col = 'ret_oc'
    else:
        fwd_ret_col = f'ret_{snap_time}_to_close'
    if fwd_ret_col not in daily.columns:
        continue
    test_continuous(f'mm_gex_{snap_time}_predictive', f'mm_{snap_time}', fwd_ret_col, daily,
                    desc=f"MM GEX at {snap_time} vs {snap_time}→close return (predictiveness)")

# 10. GEX change between consecutive snapshots vs the return from the later
# snapshot to the close.
# - A `mm_chg_<from>_<to>` column is added whenever both snapshot columns exist.
# - The test runs only if the forward-return column exists as well.
print("\n--- 10. GEX Change Patterns ---")
gex_snapshots = ['0930', '1000', '1030', '1100', '1130', '1200', '1300', '1400', '1500']
for start, end in zip(gex_snapshots, gex_snapshots[1:]):
    start_col, end_col = f'mm_{start}', f'mm_{end}'
    if start_col not in daily.columns or end_col not in daily.columns:
        continue
    chg_col = f'mm_chg_{start}_{end}'
    daily[chg_col] = daily[end_col] - daily[start_col]
    # `end` is never the opening snapshot here, so the forward return is always
    # the "<time> to close" column.
    fwd_ret_col = f'ret_{end}_to_close'
    if fwd_ret_col in daily.columns:
        test_continuous(f'mm_chg_{start}_{end}_to_ret', chg_col, fwd_ret_col, daily,
                        desc=f"MM GEX change {start}→{end} vs {end}→close return")

# ============================================================
# 11. PARTICIPANT-SPECIFIC TIME PATTERNS
# ============================================================
print("\n--- 11. Participant-Specific Patterns ---")

# Gamma of each participant class at three snapshot times vs O→C return.
# Pairs are visited participant-major (all times for 'cust', then 'firm', ...).
participant_snapshots = [(who, when)
                         for who in ['cust', 'firm', 'procust', 'bd']
                         for when in ['0930', '1030', '1200']]
for who, when in participant_snapshots:
    gamma_col = f'{who}_{when}'
    if gamma_col in daily.columns:
        test_continuous(f'{who}_{when}_to_ret', gamma_col, 'ret_oc', daily,
                        desc=f"{who} gamma at {when} vs O→C return")

# ============================================================
# COMPILE RESULTS
# ============================================================
print("\n" + "="*60)
print("COMPILING RESULTS")
print("="*60)

# Split results into flagged signals and evaluated non-signals. A result that
# is not a signal and has status INSUFFICIENT_N goes into neither bucket.
signals = {name: res for name, res in results.items() if res.get('signal', False)}
non_signals = {
    name: res for name, res in results.items()
    if not res.get('signal', False) and res.get('status') != 'INSUFFICIENT_N'
}

print(f"\nTotal tests run: {len(results)}")
print(f"Signals found (OOS): {len(signals)}")
print(f"Non-signals: {len(non_signals)}")

# Save raw results
def _to_jsonable(o):
    """Recursively convert a results value into plain JSON-serializable data.

    - dict: values converted recursively; numpy-scalar keys unwrapped.
    - list / tuple: converted element-wise into a list.
    - numpy array: converted via ``tolist()`` (handles 0-d arrays too).
    - numpy scalar: unwrapped to the matching Python scalar.
    - missing scalar (None, NaN, pd.NA, NaT): ``None``, i.e. JSON ``null``.
      Without this, NaN would be written as the non-standard token ``NaN``.
    - date / datetime / Timestamp: ISO-8601 string.

    Anything else is returned unchanged; ``json.dump`` raises ``TypeError``
    for types it cannot encode.
    """
    if isinstance(o, dict):
        return {(k.item() if isinstance(k, np.generic) else k): _to_jsonable(v)
                for k, v in o.items()}
    if isinstance(o, np.ndarray):
        return _to_jsonable(o.tolist())
    if isinstance(o, (list, tuple)):
        return [_to_jsonable(v) for v in o]
    if isinstance(o, np.generic):
        o = o.item()
    # `pd.isna` is only asked about scalars. On a list it returns an array,
    # whose truth value is ambiguous.
    if pd.api.types.is_scalar(o) and pd.isna(o):
        return None
    if hasattr(o, 'isoformat'):
        return o.isoformat()
    return o


# Build the payload before opening the file, so a conversion error cannot
# leave a truncated JSON file behind.
results_payload = _to_jsonable(results)
with open(DATA_DIR / 'trace_pattern_mining.json', 'w', encoding='utf-8') as f:
    json.dump(results_payload, f, indent=2)

print("Saved trace_pattern_mining.json")

# Generate report
def _signal_rank_key(item):
    """Sort key for a (name, result) pair, used largest-first.

    Uses |oos_ic| when the result has an 'oos_ic' entry. Otherwise uses
    |oos_wr - 0.5|, with oos_wr defaulting to 0, so a result carrying
    neither entry scores 0.5.
    """
    r = item[1]
    return abs(r.get('oos_ic', r.get('oos_wr', 0) - 0.5))


# Stamp the actual run date. This used to be a hard-coded literal
# ("2026-03-15") that went stale on every rerun.
generated_on = pd.Timestamp.now().strftime('%Y-%m-%d')

report_lines = []
report_lines.append("# TRACE Pattern Mining Report")
report_lines.append(f"*Generated: {generated_on}*\n")
report_lines.append(f"**Dataset:** {len(all_dates)} trading days (FOMC excluded)")
report_lines.append(f"**IS period:** {all_dates[0]} to {all_dates[n_is-1]} ({len(is_dates)} days)")
report_lines.append(f"**OOS period:** {all_dates[n_is]} to {all_dates[-1]} ({len(oos_dates)} days)")
report_lines.append(f"**Tests run:** {len(results)}")
report_lines.append(f"**OOS signals:** {len(signals)}\n")

report_lines.append("---\n")
report_lines.append("## 🎯 Tradeable Findings (OOS Confirmed)\n")

if len(signals) == 0:
    report_lines.append("*No new signals survived OOS testing at IC ≥ 0.08 or WR spread ≥ 8pp.*\n")
else:
    for name, r in sorted(signals.items(), key=_signal_rank_key, reverse=True):
        report_lines.append(f"### {name}")
        report_lines.append(f"**Description:** {r.get('desc', '')}")
        if r.get('type') == 'continuous':
            report_lines.append(f"- OOS IC: {r['oos_ic']}, p={r['oos_p']}, N={r['oos_n']}")
            report_lines.append(f"- IS IC: {r['is_ic']}, p={r['is_p']}, N={r['is_n']}")
        elif r.get('type') == 'conditional':
            report_lines.append(f"- OOS WR: {r['oos_wr']:.1%}, N={r['oos_n']}, p={r['p_value']}")
            report_lines.append(f"- Base WR: {r['oos_base_wr']:.1%}, Spread: {r['oos_wr_spread']:.1%}")
            report_lines.append(f"- IS WR: {r['is_wr']:.1%}, N={r['is_n']}")
        elif r.get('type') == 'quintile':
            report_lines.append(f"- OOS IC: {r['oos_ic']}, p={r['oos_p']}, N={r['oos_n']}")
            report_lines.append(f"- Q1 WR: {r.get('oos_q1_wr', 'N/A')}, Q5 WR: {r.get('oos_q5_wr', 'N/A')}, Spread: {r.get('oos_wr_spread', 'N/A')}")
        report_lines.append("")

report_lines.append("---\n")
report_lines.append("## 📊 All Test Results\n")

# Group by category, assigning each result key to exactly ONE category.
# The original prefix rules overlap:
# - 'mm_gex_1' also matches section 9's 'mm_gex_1000_predictive'.
# - 'firm' / 'procust' also match section 11's 'firm_0930_to_ret'.
# As a result the same test was listed under two headings. Rules are now tried
# in `category_rules` order and the first match wins. The exact names
# generated by sections 8d, 9, 10 and 11 are checked before the broad
# prefix rules.
_PARTICIPANT_TIME_KEYS = {f'{p}_{t}_to_ret'
                          for p in ('cust', 'firm', 'procust', 'bd')
                          for t in ('0930', '1030', '1200')}
_DTE0_RANGE_KEYS = {f'dte0_ratio_{t}_to_range' for t in ('0930', '1030', '1200', '1400')}

category_rules = [
    ('9. Time-Specific Signals', lambda k: k.startswith('mm_gex_') and 'predictive' in k),
    ('10. GEX Change Patterns', lambda k: k.startswith('mm_chg_')),
    ('11. Participant Time Patterns', lambda k: k in _PARTICIPANT_TIME_KEYS),
    ('8. Additional Patterns', lambda k: k in _DTE0_RANGE_KEYS),
    ('1. Intraday GEX Dynamics', lambda k: k.startswith(('gex_velocity', 'dte0_ratio_1030', 'afternoon_erosion', 'mm_gex_0930_to_ret', 'mm_gex_1', 'gex_flip'))),
    ('2. Participant Flow Patterns', lambda k: k.startswith(('mm_cust', 'procust', 'agreement', 'firm'))),
    ('3. Strike-Level Patterns', lambda k: k.startswith(('gamma_hhi', 'gamma_spread', 'gamma_tilt'))),
    ('4. Regime Transitions', lambda k: k.startswith(('gex_dod', 'tier_', 'gex_intraday'))),
    ('5. Conditional Setups', lambda k: k.startswith(('high_gex', 'big_gex', 'gex_x'))),
    ('6. Timing Patterns', lambda k: k in ('timing_by_tier', 'mm_1200_to_close_ret', 'mm_1300_to_close_ret')),
    ('7. Volatility Patterns', lambda k: k.startswith(('vol_profile', 'mm_gex_0930_to_range_h', 'mm_abs_gex'))),
    ('8. Additional Patterns', lambda k: k.startswith(('gex_mean', 'mm_gex_0930_quintile', 'mm_gex_percentile', 'overnight'))),
    # Broad fallback for participant-prefixed keys not claimed above.
    ('11. Participant Time Patterns', lambda k: k.startswith(('cust_', 'firm_', 'procust_', 'bd_'))),
]

# Headings in display order. Every heading is emitted, even when empty.
categories = {name: [] for name in (
    '1. Intraday GEX Dynamics',
    '2. Participant Flow Patterns',
    '3. Strike-Level Patterns',
    '4. Regime Transitions',
    '5. Conditional Setups',
    '6. Timing Patterns',
    '7. Volatility Patterns',
    '8. Additional Patterns',
    '9. Time-Specific Signals',
    '10. GEX Change Patterns',
    '11. Participant Time Patterns',
)}

# Also catch anything not categorized
uncategorized = []
for k in results:
    cat_name = next((name for name, matches in category_rules if matches(k)), None)
    if cat_name is None:
        uncategorized.append(k)
    else:
        categories[cat_name].append(k)
if uncategorized:
    categories['Other'] = uncategorized

for cat_name, keys in categories.items():
    report_lines.append(f"### {cat_name}\n")
    for k in sorted(keys):
        r = results[k]
        status = '✅' if r.get('signal', False) else '❌'

        if r.get('type') == 'continuous':
            report_lines.append(f"- {status} **{k}**: OOS IC={r['oos_ic']}, p={r['oos_p']}, N={r['oos_n']} | IS IC={r['is_ic']}")
        elif r.get('type') == 'conditional':
            report_lines.append(f"- {status} **{k}**: OOS WR={r['oos_wr']:.1%} (N={r['oos_n']}), p={r.get('p_value', 'N/A')} | IS WR={r['is_wr']:.1%}")
        elif r.get('type') == 'quintile':
            report_lines.append(f"- {status} **{k}**: OOS IC={r['oos_ic']}, Q1→Q5 WR spread={r.get('oos_wr_spread', 'N/A')}")
        elif r.get('type') == 'timing':
            report_lines.append(f"- 📊 **{k}**: See timing table below")
        elif r.get('type') == 'vol_profile':
            report_lines.append(f"- 📊 **{k}**: See vol profile table below")
        elif r.get('status') == 'INSUFFICIENT_N':
            report_lines.append(f"- ⚠️ **{k}**: Insufficient N ({r.get('n_triggered', 'N/A')})")
        else:
            report_lines.append(f"- {status} **{k}**: {r}")

        report_lines.append(f"  *{r.get('desc', '')}*")
    report_lines.append("")

# Tier rows are rendered in this fixed order; tiers absent from a table's
# data are skipped.
_TIER_ROW_ORDER = ['DEEP_NEG', 'NEG', 'LOW_POS', 'MID_POS', 'HIGH_POS', 'EXTREME_POS']


def _tier_table_lines(by_tier, columns, render_cell):
    """Return markdown table lines: header, separator, then one row per tier.

    ``by_tier`` maps tier name -> {column: cell data}. ``render_cell`` turns
    one cell's data into the text placed between the pipes. Missing cells are
    shown as an em dash.
    """
    lines = [
        "| Tier | " + " | ".join(columns) + " |",
        "|" + "---|" * (len(columns) + 1),
    ]
    for tier in _TIER_ROW_ORDER:
        if tier not in by_tier:
            continue
        tier_cells = by_tier[tier]
        cells = "".join(
            f" {render_cell(tier_cells[c])} |" if c in tier_cells else " — |"
            for c in columns
        )
        lines.append(f"| {tier} |" + cells)
    return lines


# Timing table: mean return (bps), WR and count from each entry time to close.
report_lines.append("---\n")
report_lines.append("## ⏰ Timing by GEX Tier (OOS)\n")
report_lines.append("Average return (bps) and WR from entry time to close:\n")

if timing_results:
    report_lines.extend(_tier_table_lines(
        timing_results,
        ['0930', '1000', '1030', '1100', '1200', '1300', '1400', '1500'],
        lambda d: f"{d['mean_ret']:+.1f}bps ({d['wr']:.0%}, n={d['n']})",
    ))
    report_lines.append("")

# Vol profile table: mean hourly range (bps) for hours 10..15.
report_lines.append("## 📈 Hourly Vol Profile by GEX Tier (OOS)\n")
report_lines.append("Average hourly range in bps:\n")

if vol_by_tier:
    report_lines.extend(_tier_table_lines(
        vol_by_tier,
        [f'h{hr}' for hr in range(10, 16)],
        lambda d: f"{d['mean_range_bps']:.1f}",
    ))

report_lines.append("\n---\n")
report_lines.append("## Methodology Notes\n")
report_lines.append("- IS/OOS split: 60/40 chronological")
report_lines.append("- Signal threshold: |IC| ≥ 0.08 with p < 0.05, or WR spread ≥ 8pp")
# NOTE(review): the conditional minimum stated here (25) differs from the
# min_n=20 passed for 'high_gex_gap_up_fade'. Confirm against the defaults
# of test_continuous / test_conditional.
report_lines.append("- Minimum N = 50 for continuous tests, 25 for conditional")
report_lines.append("- FOMC dates excluded")
report_lines.append("- All returns are SPX-based (open→close or time→close)")
report_lines.append("- GEX = net MM gamma across all strikes at specified timestamp")
report_lines.append("- Tilt = ratio of positive gamma above median strike to total absolute gamma")

report = '\n'.join(report_lines)

# Write UTF-8 explicitly. The report is full of non-ASCII text (✅/❌ and other
# emoji, arrows, em dashes), and open()'s default encoding is locale-dependent,
# so without this the write can raise UnicodeEncodeError on non-UTF-8 systems.
(DATA_DIR / 'trace_pattern_mining_report.md').write_text(report, encoding='utf-8')
print("\nSaved trace_pattern_mining_report.md")

# Print summary: every signal with its key OOS metric, then the non-signals
# (only continuous and conditional non-signals get a line). Both lists are
# sorted by test name.
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"\nSignals that survived OOS ({len(signals)}):")
for sig_name in sorted(signals):
    sig = signals[sig_name]
    print(f"  ✅ {sig_name}: {sig.get('desc', '')}")
    result_type = sig.get('type')
    if result_type == 'continuous':
        print(f"     OOS IC={sig['oos_ic']}, p={sig['oos_p']}")
    elif result_type == 'conditional':
        print(f"     OOS WR={sig['oos_wr']:.1%}, N={sig['oos_n']}")
    elif result_type == 'quintile':
        print(f"     OOS Q1→Q5 spread={sig.get('oos_wr_spread', 'N/A')}")

print(f"\nNon-signals ({len(non_signals)}):")
for ns_name in sorted(non_signals):
    ns = non_signals[ns_name]
    result_type = ns.get('type')
    if result_type == 'continuous':
        print(f"  ❌ {ns_name}: IC={ns.get('oos_ic', 'N/A')}, p={ns.get('oos_p', 'N/A')}")
    elif result_type == 'conditional':
        print(f"  ❌ {ns_name}: WR={ns.get('oos_wr', 'N/A')}")

print("\nDone!")
