#!/usr/bin/env python3
"""
Intraday Signal Research for ES Futures
Tests multiple signal candidates across time-of-day windows.
"""

import pandas as pd
import numpy as np
from scipy import stats
import json
import os
import glob
from datetime import datetime, timedelta, time as dtime
import warnings
warnings.filterwarnings('ignore')

# Root directory for all input data files and output reports.
DATA = '/Users/daniel/.openclaw/workspace/data'

# ============================================================
# STEP 1: Load ES price data — use 1-min bars, resample to 5-min
# ============================================================
print("=" * 60)
print("LOADING ES PRICE DATA (1-min bars → 5-min RTH)")
print("=" * 60)

es1 = pd.read_csv(f'{DATA}/es_1min_bars.csv')
# Filter to outright contracts only (no spreads with '-')
es1 = es1[~es1['symbol'].str.contains('-', na=False)].copy()
# Timestamps arrive as UTC event times; convert to US/Eastern for RTH logic.
es1['ts'] = pd.to_datetime(es1['ts_event'], utc=True).dt.tz_convert('US/Eastern')
es1 = es1.sort_values('ts').reset_index(drop=True)
es1['date'] = es1['ts'].dt.date
es1['time'] = es1['ts'].dt.time

# Keep only the most-traded symbol per day (front month): pick, per date,
# the contract with the highest total volume and inner-join back.
daily_vol = es1.groupby(['date', 'symbol'])['volume'].sum().reset_index()
front_month = daily_vol.loc[daily_vol.groupby('date')['volume'].idxmax()][['date', 'symbol']]
es1 = es1.merge(front_month, on=['date', 'symbol'])

# Filter RTH: 9:30 - 16:00 ET
es1_rth = es1[(es1['time'] >= dtime(9, 30)) & (es1['time'] < dtime(16, 0))].copy()

# Resample to 5-min bars. groupby('date').resample keeps bars from
# crossing session boundaries; dropna removes empty 5-min buckets.
es5 = es1_rth.set_index('ts').groupby('date').resample('5min').agg({
    'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'
}).dropna(subset=['close']).reset_index()

# After reset_index the datetime level may come back named 'level_1'
# instead of 'ts' depending on the pandas version — normalize it.
if 'ts' not in es5.columns and 'level_1' in es5.columns:
    es5 = es5.rename(columns={'level_1': 'ts'})

es5['time'] = es5['ts'].dt.time
# Re-apply the RTH filter on the resampled bars (bucket edges can drift).
es5 = es5[(es5['time'] >= dtime(9, 30)) & (es5['time'] < dtime(16, 0))].copy()
es5 = es5.sort_values('ts').reset_index(drop=True)

print(f"ES 5min RTH: {len(es5)} bars, {es5['date'].nunique()} days")
print(f"Date range: {es5['date'].min()} to {es5['date'].max()}")

# ============================================================
# Compute forward returns
# ============================================================
print("Computing forward returns...")

def compute_forward_returns(df):
    """Attach forward-looking return targets to each 5-min bar.

    For every row, computed strictly within its own trading day:
      - fwd_1h:  return over the next 12 bars (12 x 5min = 1 hour)
      - fwd_3h:  return over the next 36 bars (3 hours)
      - fwd_eod: return from this bar's close to the day's last close
                 (NaN on the day's final bar, which has no forward window)

    Bars whose horizon would extend past the end of the day get NaN,
    matching the original per-row loop's `i + k < n` condition exactly.
    Vectorized with grouped shifts instead of a Python loop over rows;
    grouping by 'date' guarantees shifts never leak across sessions.
    """
    df = df.copy()
    grouped_close = df.groupby('date')['close']

    # closes[i + k] / closes[i] - 1; shift(-k) yields NaN where i + k
    # falls outside the day, which is precisely the loop's guard.
    df['fwd_1h'] = grouped_close.shift(-12) / df['close'] - 1
    df['fwd_3h'] = grouped_close.shift(-36) / df['close'] - 1

    # End-of-day return: day's last close vs current close, then blank
    # out each day's final bar (its "forward" return would be 0 by construction).
    df['fwd_eod'] = grouped_close.transform('last') / df['close'] - 1
    df.loc[df.groupby('date').tail(1).index, 'fwd_eod'] = np.nan

    return df

# Attach the forward-return targets (1h / 3h / end-of-day) to every bar.
es5 = compute_forward_returns(es5)
print(f"Non-null: 1H={es5['fwd_1h'].notna().sum()}, 3H={es5['fwd_3h'].notna().sum()}, EOD={es5['fwd_eod'].notna().sum()}")

# Session windows used to slice the IC analysis by time of day.
def get_window(t):
    """Map an ET bar time to its session window, or None outside RTH.

    morning   [09:30, 11:00)
    midday    [11:00, 13:00)
    afternoon [13:00, 16:00)
    """
    if dtime(9, 30) <= t < dtime(11, 0):
        return 'morning'
    if dtime(11, 0) <= t < dtime(13, 0):
        return 'midday'
    if dtime(13, 0) <= t < dtime(16, 0):
        return 'afternoon'
    return None

# Tag each bar with its time-of-day window.
es5['window'] = es5['time'].apply(get_window)

# Chronological in-sample / out-of-sample split: first 67% of trading
# days are IS, the remaining 33% are OOS (no shuffling — time-ordered).
all_dates = sorted(es5['date'].unique())
n_dates = len(all_dates)
split_idx = int(n_dates * 0.67)
is_dates = set(all_dates[:split_idx])
oos_dates = set(all_dates[split_idx:])
es5['is_oos'] = np.where(es5['date'].isin(is_dates), 'IS', 'OOS')

print(f"IS: {len(is_dates)} days ({min(is_dates)} → {max(is_dates)})")
print(f"OOS: {len(oos_dates)} days ({min(oos_dates)} → {max(oos_dates)})")

# ============================================================
# EVALUATION HELPERS
# ============================================================
def compute_ic(signal, returns):
    """Spearman rank IC between a signal and forward returns.

    Returns NaN when fewer than 30 jointly non-null pairs are available.
    """
    valid = signal.notna() & returns.notna()
    if valid.sum() < 30:
        return np.nan
    rho, _pvalue = stats.spearmanr(signal[valid], returns[valid])
    return rho

def compute_hit_rate(signal, returns, quantile=0.8):
    """Fraction of positive forward returns in the top signal quantile.

    Requires at least 50 jointly non-null observations; returns NaN
    otherwise, or when the top bucket is empty.
    """
    valid = signal.notna() & returns.notna()
    sig = signal[valid]
    ret = returns[valid]
    if len(sig) < 50:
        return np.nan
    cutoff = sig.quantile(quantile)
    selected = ret[sig >= cutoff]
    if len(selected) == 0:
        return np.nan
    return (selected > 0).mean()

def evaluate_signal(df, signal_col, name, description):
    """Score one signal column over all window x horizon combinations.

    Computes Spearman IC separately on IS and OOS dates for each of the
    3 time-of-day windows x 3 forward horizons.  The "best" combo is the
    one with the largest |OOS IC| whose IS and OOS ICs agree in sign;
    its OOS top-quintile hit rate is recorded.  A signal only "passes"
    (no FAILED note) if some combo has sign-consistent ICs with
    |OOS IC| > 0.02.

    Returns a plain nested dict (JSON-serializable after clean()).
    """
    result = {
        'name': name, 'description': description,
        'morning_ic': {}, 'midday_ic': {}, 'afternoon_ic': {},
        'best_horizon': None, 'best_time': None,
        'hit_rate_top_quintile': None, 'n_obs': 0, 'notes': ''
    }

    mask = df[signal_col].notna()
    result['n_obs'] = int(mask.sum())
    if result['n_obs'] < 100:
        result['notes'] = f'Insufficient data: {result["n_obs"]} obs'
        return result

    # Track the combo with the largest-magnitude sign-consistent OOS IC.
    best_ic = 0
    best_combo = (None, None)

    for window in ['morning', 'midday', 'afternoon']:
        wk = f'{window}_ic'
        for horizon in ['1h', '3h', 'eod']:
            fwd_col = f'fwd_{horizon}'

            is_mask = mask & (df['window'] == window) & (df['is_oos'] == 'IS')
            oos_mask = mask & (df['window'] == window) & (df['is_oos'] == 'OOS')

            ic_is = compute_ic(df.loc[is_mask, signal_col], df.loc[is_mask, fwd_col])
            ic_oos = compute_ic(df.loc[oos_mask, signal_col], df.loc[oos_mask, fwd_col])

            result[wk][horizon] = {
                'is': round(ic_is, 4) if pd.notna(ic_is) else None,
                'oos': round(ic_oos, 4) if pd.notna(ic_oos) else None,
                'n_is': int(is_mask.sum()), 'n_oos': int(oos_mask.sum())
            }

            # Only sign-agreeing combos are eligible to be "best"
            # (no 0.02 floor here — that applies to the pass/fail gate below).
            if pd.notna(ic_is) and pd.notna(ic_oos):
                if np.sign(ic_is) == np.sign(ic_oos) and abs(ic_oos) > abs(best_ic):
                    best_ic = ic_oos
                    best_combo = (window, horizon)

    if best_combo[0]:
        result['best_time'] = best_combo[0]
        result['best_horizon'] = best_combo[1]
        # Hit rate of the top signal quintile, OOS only, at the best combo.
        combo_mask = mask & (df['window'] == best_combo[0]) & (df['is_oos'] == 'OOS')
        hr = compute_hit_rate(df.loc[combo_mask, signal_col], df.loc[combo_mask, f'fwd_{best_combo[1]}'])
        result['hit_rate_top_quintile'] = round(hr, 4) if pd.notna(hr) else None

    # Pass/fail gate: any combo with sign agreement AND |OOS IC| > 0.02.
    # Note this is stricter than the best-combo selection above.
    consistent = any(
        pd.notna(result[f'{w}_ic'].get(h, {}).get('is')) and
        pd.notna(result[f'{w}_ic'].get(h, {}).get('oos')) and
        np.sign(result[f'{w}_ic'][h]['is']) == np.sign(result[f'{w}_ic'][h]['oos']) and
        abs(result[f'{w}_ic'][h]['oos']) > 0.02
        for w in ['morning', 'midday', 'afternoon']
        for h in ['1h', '3h', 'eod']
    )

    if not consistent:
        result['notes'] = 'FAILED: No consistent IS/OOS signal above |IC| > 0.02'

    return result

def print_signal(r):
    """Pretty-print one evaluate_signal() result dict to stdout."""

    def fmt(v):
        # Right-align numeric ICs to width 8; non-numeric slots show N/A.
        return f"{v:>8.4f}" if isinstance(v, float) else f"{'N/A':>8}"

    print(f"\n  {r['name']}: n={r['n_obs']}, best={r['best_time']}/{r['best_horizon']}, HR={r.get('hit_rate_top_quintile')}")
    for win in ['morning', 'midday', 'afternoon']:
        for hz in ['1h', '3h', 'eod']:
            cell = r[f'{win}_ic'].get(hz, {})
            is_ic = cell.get('is', '-')
            oos_ic = cell.get('oos', '-')
            if is_ic == '-' and oos_ic == '-':
                continue
            both_numeric = isinstance(is_ic, float) and isinstance(oos_ic, float)
            passed = both_numeric and np.sign(is_ic) == np.sign(oos_ic) and abs(oos_ic) > 0.02
            flag = '✓' if passed else ' '
            print(f"    {flag} {win:10s}/{hz:3s}: IS={fmt(is_ic)}, OOS={fmt(oos_ic)}")
    if r['notes']:
        print(f"    → {r['notes']}")

# ============================================================
# Helper: parse TradingView 5-min CSVs
# ============================================================
def load_tv_5min(filename, col_prefix):
    """Load a TradingView 5-min CSV and return ts + close in ET.

    Parameters
    ----------
    filename : str
        File name inside the DATA directory (e.g. 'vix_5min_tv.csv').
    col_prefix : str
        Prefix for the returned close column, e.g. 'vix' -> 'vix_close'.

    Returns a DataFrame with columns ['ts', f'{col_prefix}_close'],
    sorted by timestamp, tz-aware in US/Eastern.
    """
    # BUG FIX: the path previously hard-coded a '(unknown)' placeholder,
    # ignoring the filename argument entirely — every caller passed a real
    # file name that was never used. Build the path from the argument.
    df = pd.read_csv(f'{DATA}/{filename}')
    df['ts'] = pd.to_datetime(df['datetime'])
    # TradingView exports may be tz-naive or tz-aware; normalize to ET.
    if df['ts'].dt.tz is None:
        df['ts'] = df['ts'].dt.tz_localize('US/Eastern')
    else:
        df['ts'] = df['ts'].dt.tz_convert('US/Eastern')
    df = df.sort_values('ts')
    return df[['ts', 'close']].rename(columns={'close': f'{col_prefix}_close'})

# ============================================================
# Helper: merge auxiliary data with ES base
# ============================================================
def merge_aux(es_df, aux_df, on='ts'):
    """As-of join: attach the latest aux row at or before each ES bar.

    Backward direction with a 5-minute tolerance, so an ES bar only
    picks up auxiliary data that is at most one bar stale; older or
    missing aux data yields NaN columns for that bar.
    """
    left = es_df.sort_values(on)
    right = aux_df.sort_values(on)
    staleness_cap = pd.Timedelta(minutes=5)
    return pd.merge_asof(left, right, on=on,
                         direction='backward', tolerance=staleness_cap)

# Accumulates one result dict per tested signal (filled by every section below).
signals_results = []

# ============================================================
# SIGNAL 1: Intraday GEX
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 1: Intraday GEX")
print("=" * 60)

gex = pd.read_csv(f'{DATA}/intraday_gex_5min.csv')
gex['ts'] = pd.to_datetime(gex['timestamp'], utc=True).dt.tz_convert('US/Eastern')
gex = gex.sort_values('ts')

# As-of join the GEX snapshots onto the 5-min ES bars (<= 5 min stale).
es_gex = merge_aux(es5.copy(), gex[['ts', 'net_gex', 'gex_quintile', 'regime']])
print(f"GEX merged: {es_gex['net_gex'].notna().sum()} / {len(es_gex)}")

es_gex['sig_gex_quintile'] = es_gex['gex_quintile']
# diff(6) on 5-min bars = change over 30 minutes, restarted each day.
es_gex['sig_gex_change_30m'] = es_gex.groupby('date')['net_gex'].diff(6)
# Binary regime flag (1.0 = positive-gamma regime); keep NaN where missing.
es_gex['sig_gex_regime'] = (es_gex['regime'] == 'POS').astype(float)
es_gex.loc[es_gex['regime'].isna(), 'sig_gex_regime'] = np.nan

for col, name, desc in [
    ('sig_gex_quintile', 'gex_quintile', 'GEX quintile (1-5) at current time'),
    ('sig_gex_change_30m', 'gex_change_30m', 'GEX change over last 30 minutes'),
    ('sig_gex_regime', 'gex_regime', 'Binary: positive vs negative GEX regime'),
]:
    r = evaluate_signal(es_gex, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 2: VIX Changes
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 2: VIX Intraday Changes")
print("=" * 60)

# As-of join the three vol indices onto the ES bars, one at a time.
es_vix = es5.copy()
for fname, prefix in [
    ('vix_5min_tv.csv', 'vix'), ('vix9d_5min_tv.csv', 'vix9d'), ('vix1d_5min_tv.csv', 'vix1d')
]:
    aux = load_tv_5min(fname, prefix)
    es_vix = merge_aux(es_vix, aux)
    print(f"  {prefix}: {es_vix[f'{prefix}_close'].notna().sum()} merged")

# VIX 30-min change (inverted — VIX drop = bullish)
es_vix['sig_vix_30m_chg'] = -es_vix.groupby('date')['vix_close'].pct_change(6)

# VIX9D/VIX ratio change (short-term vs 30-day vol term structure)
es_vix['vix_ratio'] = es_vix['vix9d_close'] / es_vix['vix_close']
es_vix['sig_vix_ratio_chg'] = es_vix.groupby('date')['vix_ratio'].diff(6)

# VIX1D/VIX ratio change
es_vix['vix1d_ratio'] = es_vix['vix1d_close'] / es_vix['vix_close']
es_vix['sig_vix1d_ratio_chg'] = es_vix.groupby('date')['vix1d_ratio'].diff(6)

# Fast VIX spike contrarian: fade any 30-min VIX move larger than 3%,
# flat (0) otherwise.
raw_vix_chg = es_vix.groupby('date')['vix_close'].pct_change(6)
es_vix['sig_vix_spike'] = np.where(raw_vix_chg.abs() > 0.03, -np.sign(raw_vix_chg), 0)

for col, name, desc in [
    ('sig_vix_30m_chg', 'vix_30m_inv', 'Inverted VIX 30-min change (VIX drop → bullish)'),
    ('sig_vix_ratio_chg', 'vix9d_vix_ratio_chg', 'VIX9D/VIX ratio 30-min change'),
    ('sig_vix1d_ratio_chg', 'vix1d_vix_ratio_chg', 'VIX1D/VIX ratio 30-min change'),
    ('sig_vix_spike', 'vix_spike_contrarian', 'Contrarian on VIX >3% 30m moves'),
]:
    r = evaluate_signal(es_vix, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 3: NYSE TICK
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 3: NYSE TICK")
print("=" * 60)

# Read the TICK CSV directly (rather than via load_tv_5min) because the
# extreme-reading signal needs the high/low columns, not just close.
# (Removed a redundant load_tv_5min call whose result was never used.)
tick_raw = pd.read_csv(f'{DATA}/tick_nyse_5min_tv.csv')
tick_raw['ts'] = pd.to_datetime(tick_raw['datetime'])
if tick_raw['ts'].dt.tz is None:
    tick_raw['ts'] = tick_raw['ts'].dt.tz_localize('US/Eastern')
tick_raw = tick_raw.sort_values('ts')

es_tick = merge_aux(es5.copy(), tick_raw[['ts', 'close', 'high', 'low']].rename(
    columns={'close': 'tick_close', 'high': 'tick_high', 'low': 'tick_low'}
))
print(f"TICK merged: {es_tick['tick_close'].notna().sum()}")

es_tick['sig_tick_level'] = es_tick['tick_close']
# 30-min (6-bar) rolling mean of TICK, restarted each session.
es_tick['sig_tick_30m_ma'] = es_tick.groupby('date')['tick_close'].transform(
    lambda x: x.rolling(6, min_periods=1).mean()
)
# Contrarian extremes: +1 (buy) on washout lows, -1 (sell) on blowoff highs.
es_tick['sig_tick_extreme'] = np.where(
    es_tick['tick_low'] < -1000, 1,
    np.where(es_tick['tick_high'] > 1000, -1, 0)
)

for col, name, desc in [
    ('sig_tick_level', 'tick_level', 'NYSE TICK close value'),
    ('sig_tick_30m_ma', 'tick_30m_ma', 'NYSE TICK 30-min moving average'),
    ('sig_tick_extreme', 'tick_extreme', 'Contrarian: buy TICK<-1000, sell TICK>+1000'),
]:
    r = evaluate_signal(es_tick, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 4: Cross-Asset
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 4: Cross-Asset")
print("=" * 60)

# As-of join each cross-asset close series onto the ES bars.
es_cross = es5.copy()
for fname, prefix in [
    ('dxy_5min_tv.csv', 'dxy'), ('tlt_5min_tv.csv', 'tlt'),
    ('usdjpy_5min_tv.csv', 'usdjpy'), ('hyg_5min_tv.csv', 'hyg'),
]:
    aux = load_tv_5min(fname, prefix)
    es_cross = merge_aux(es_cross, aux)
    print(f"  {prefix}: {es_cross[f'{prefix}_close'].notna().sum()} merged")

# 30-minute (6-bar) momentum per asset, within each day.
# DXY is sign-flipped: dollar weakness treated as risk-on for ES.
es_cross['sig_dxy_mom'] = -es_cross.groupby('date')['dxy_close'].pct_change(6)
es_cross['sig_tlt_mom'] = es_cross.groupby('date')['tlt_close'].pct_change(6)
es_cross['sig_usdjpy_mom'] = es_cross.groupby('date')['usdjpy_close'].pct_change(6)
es_cross['sig_hyg_mom'] = es_cross.groupby('date')['hyg_close'].pct_change(6)

# TLT/ES divergence: bonds outperforming stocks over the last 30 minutes.
es_cross['es_30m_ret'] = es_cross.groupby('date')['close'].pct_change(6)
es_cross['tlt_30m_ret'] = es_cross.groupby('date')['tlt_close'].pct_change(6)
es_cross['sig_tlt_es_div'] = es_cross['tlt_30m_ret'] - es_cross['es_30m_ret']

for col, name, desc in [
    ('sig_dxy_mom', 'dxy_30m_inv', 'DXY 30m momentum inverted (DXY down = ES up)'),
    ('sig_tlt_mom', 'tlt_30m_mom', 'TLT 30m momentum'),
    ('sig_usdjpy_mom', 'usdjpy_30m_mom', 'USDJPY 30m momentum (risk-on proxy)'),
    ('sig_hyg_mom', 'hyg_30m_mom', 'HYG 30m momentum (credit risk proxy)'),
    ('sig_tlt_es_div', 'tlt_es_divergence', 'TLT outperforming ES 30m (mean-reversion)'),
]:
    r = evaluate_signal(es_cross, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 5: Cumulative Delta (Order Flow)
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 5: Cumulative Delta")
print("=" * 60)

delta_1m = pd.read_csv(f'{DATA}/es_1min_delta_bars.csv')
delta_1m['ts'] = pd.to_datetime(delta_1m['timestamp'], utc=True).dt.tz_convert('US/Eastern')
delta_1m['date'] = delta_1m['ts'].dt.date
delta_1m['time_val'] = delta_1m['ts'].dt.time

delta_rth = delta_1m[(delta_1m['time_val'] >= dtime(9, 30)) & (delta_1m['time_val'] < dtime(16, 0))].copy()
delta_rth = delta_rth.sort_values('ts')
# Running intraday totals, restarted each session.
delta_rth['cum_delta'] = delta_rth.groupby('date')['delta'].cumsum()
delta_rth['cum_volume'] = delta_rth.groupby('date')['volume'].cumsum()

print(f"Delta 1min RTH: {len(delta_rth)} bars, {delta_rth['date'].nunique()} days")
print(f"Date range: {delta_rth['date'].min()} to {delta_rth['date'].max()}")

# Resample to 5-min
delta_5m = delta_rth.set_index('ts').groupby('date').resample('5min').agg({
    'delta': 'sum', 'cum_delta': 'last', 'close': 'last', 'volume': 'sum', 'cum_volume': 'last'
}).dropna(subset=['close']).reset_index()
# Some pandas versions name the datetime level 'level_1' after reset_index.
if 'ts' not in delta_5m.columns:
    delta_5m = delta_5m.rename(columns={'level_1': 'ts'})
delta_5m = delta_5m.sort_values('ts')

# Signals
# Net buying pressure normalized by total volume traded so far today.
delta_5m['sig_cum_delta_norm'] = delta_5m['cum_delta'] / delta_5m['cum_volume']
delta_5m['sig_delta_30m'] = delta_5m.groupby('date')['delta'].transform(lambda x: x.rolling(6, min_periods=1).sum())
# 30-min rolling correlation of cumulative delta vs price; the
# reset_index(level=0, drop=True) re-aligns the grouped result to rows.
# NOTE(review): groupby.apply result shape varies across pandas versions —
# confirm the alignment holds on the pinned pandas.
delta_5m['sig_delta_price_corr'] = delta_5m.groupby('date').apply(
    lambda g: g['cum_delta'].rolling(6, min_periods=4).corr(g['close'])
).reset_index(level=0, drop=True)

# Merge with ES one signal column at a time (each dropna'd separately so a
# gap in one signal does not knock out the others).
es_delta = es5.copy()
for col in ['sig_cum_delta_norm', 'sig_delta_30m', 'sig_delta_price_corr']:
    sub = delta_5m[['ts', col]].dropna().sort_values('ts')
    es_delta = merge_aux(es_delta, sub)

print(f"Delta merged: cum_delta={es_delta['sig_cum_delta_norm'].notna().sum()}, delta_30m={es_delta['sig_delta_30m'].notna().sum()}")

for col, name, desc in [
    ('sig_cum_delta_norm', 'cum_delta_norm', 'Cumulative delta / cumulative volume (order flow bias)'),
    ('sig_delta_30m', 'delta_30m_sum', '30-min rolling delta sum'),
    ('sig_delta_price_corr', 'delta_price_corr', '30-min rolling correlation of cum_delta vs price (low = divergence)'),
]:
    r = evaluate_signal(es_delta, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 6: Time-of-Day Patterns
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 6: Time-of-Day Patterns")
print("=" * 60)

es_tod = es5.copy()

# First 30-min direction
# NOTE(review): the filter includes the bar labeled 10:00, so this spans
# the 9:30 bar close through the 10:00 bar close — confirm the intended
# window really is "first 30 minutes".
first_30m = es_tod[es_tod['time'] <= dtime(10, 0)].groupby('date').apply(
    lambda g: g['close'].iloc[-1] / g['close'].iloc[0] - 1 if len(g) > 1 else np.nan
).to_dict()
es_tod['sig_first_30m'] = es_tod['date'].map(first_30m)
# Blank the signal before 10:05 so it is never used before it is knowable.
es_tod.loc[es_tod['time'] < dtime(10, 5), 'sig_first_30m'] = np.nan

# Morning range position (only after 11am)
morn_ranges = es_tod[es_tod['time'] <= dtime(11, 0)].groupby('date').agg(
    mh=('close', 'max'), ml=('close', 'min')
)
es_tod = es_tod.merge(morn_ranges, on='date', how='left')
rng = es_tod['mh'] - es_tod['ml']
# 0 = at morning low, 1 = at morning high; NaN for degenerate (flat) ranges.
es_tod['sig_morn_range_pos'] = np.where(rng > 0, (es_tod['close'] - es_tod['ml']) / rng, np.nan)
es_tod.loc[es_tod['time'] < dtime(11, 0), 'sig_morn_range_pos'] = np.nan

# Return from open
# NOTE(review): "open" here is the first 5-min bar's CLOSE, not the
# session's opening print — rename or switch to the open column if that matters.
opens = es_tod.groupby('date')['close'].first().to_dict()
es_tod['sig_ret_from_open'] = es_tod['close'] / es_tod['date'].map(opens) - 1

for col, name, desc in [
    ('sig_first_30m', 'first_30m_dir', 'First 30-min return direction (continuation?)'),
    ('sig_morn_range_pos', 'morning_range_pos', 'Position in morning range (0=low, 1=high) after 11am'),
    ('sig_ret_from_open', 'return_from_open', 'Intraday return from open'),
]:
    r = evaluate_signal(es_tod, col, name, desc)
    signals_results.append(r)
    print_signal(r)

# ============================================================
# SIGNAL 7: HIRO (SPX)
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 7: HIRO (SPX)")
print("=" * 60)

hiro_files = sorted(glob.glob(f'{DATA}/trace_api/running_hiro_*.json'))
print(f"Found {len(hiro_files)} HIRO files")

hiro_recs = []
for f in hiro_files:
    try:
        with open(f) as fh:
            data = json.load(fh)
        for e in data:
            if e.get('symbol') == 'SPX':
                # Prefer the record's own 'day'; fall back to the date token
                # embedded in the filename (running_hiro_<date>.json).
                date_str = e.get('day', os.path.basename(f).split('_')[-1].replace('.json', ''))
                hiro_recs.append({
                    'date_str': date_str,
                    'hiro_signal': float(e.get('currentDaySignal', 0)),
                })
    except:
        # NOTE(review): bare except silently drops malformed files — consider
        # narrowing to (OSError, json.JSONDecodeError, ValueError) and logging.
        pass

if hiro_recs:
    hiro_df = pd.DataFrame(hiro_recs)
    hiro_df['date'] = pd.to_datetime(hiro_df['date_str']).dt.date
    # One record per day: keep the last SPX entry seen for each date.
    hiro_df = hiro_df.groupby('date').last()
    
    # NOTE(review): the day's final signal value is broadcast onto every bar
    # of that day — potential look-ahead for intraday IC; verify the
    # snapshot timing before trusting these results.
    es_hiro = es5.copy()
    es_hiro = es_hiro.merge(hiro_df[['hiro_signal']], on='date', how='left')
    es_hiro['sig_hiro'] = es_hiro['hiro_signal']
    # Expanding z-score uses only history up to each bar (no full-sample stats).
    es_hiro['sig_hiro_z'] = (es_hiro['sig_hiro'] - es_hiro['sig_hiro'].expanding().mean()) / es_hiro['sig_hiro'].expanding().std()
    
    print(f"HIRO merged: {es_hiro['sig_hiro'].notna().sum()} bars")
    
    for col, name, desc in [
        ('sig_hiro', 'hiro_raw', 'SPX HIRO current-day signal (raw level)'),
        ('sig_hiro_z', 'hiro_zscore', 'SPX HIRO z-score'),
    ]:
        r = evaluate_signal(es_hiro, col, name, desc)
        signals_results.append(r)
        print_signal(r)

# ============================================================
# SIGNAL 8: MM Gamma from TRACE
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 8: MM Gamma (TRACE)")
print("=" * 60)

gamma_files = sorted(glob.glob(f'{DATA}/trace_api/intradayGamma_*.csv'))
print(f"Found {len(gamma_files)} gamma files, loading...")

gamma_summaries = []
for f in gamma_files:
    try:
        gdf = pd.read_csv(f)
        if len(gdf) == 0:
            continue
        gdf['ts'] = pd.to_datetime(gdf['time'], utc=True).dt.tz_convert('US/Eastern')
        
        # Group by snapshot time, compute total mm_gamma and cust_gamma
        # (collapses the per-strike rows into one row per snapshot).
        summary = gdf.groupby('ts').agg(
            mm_gamma_total=('mm_gamma', 'sum'),
            cust_gamma_total=('cust_gamma', 'sum'),
        ).reset_index()
        gamma_summaries.append(summary)
    except:
        # NOTE(review): bare except hides read/parse failures silently —
        # consider narrowing and logging which file was skipped.
        pass

if gamma_summaries:
    gamma_all = pd.concat(gamma_summaries, ignore_index=True)
    gamma_all['date'] = gamma_all['ts'].dt.date
    gamma_all['time_val'] = gamma_all['ts'].dt.time
    # Keep RTH snapshots only, then sort for the as-of join below.
    gamma_all = gamma_all[(gamma_all['time_val'] >= dtime(9, 30)) & (gamma_all['time_val'] < dtime(16, 0))]
    gamma_all = gamma_all.sort_values('ts')
    
    print(f"Gamma summaries: {len(gamma_all)} snapshots, {gamma_all['date'].nunique()} days")
    
    es_gamma = merge_aux(es5.copy(), gamma_all[['ts', 'mm_gamma_total', 'cust_gamma_total']])
    es_gamma['sig_mm_gamma'] = es_gamma['mm_gamma_total']
    es_gamma['sig_cust_gamma'] = es_gamma['cust_gamma_total']
    
    print(f"Gamma merged: {es_gamma['sig_mm_gamma'].notna().sum()}")
    
    for col, name, desc in [
        ('sig_mm_gamma', 'mm_gamma_total', 'Total MM gamma across all strikes'),
        ('sig_cust_gamma', 'cust_gamma_total', 'Total customer gamma across all strikes'),
    ]:
        r = evaluate_signal(es_gamma, col, name, desc)
        signals_results.append(r)
        print_signal(r)

# ============================================================
# SIGNAL 9: VVIX
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL 9: VVIX")
print("=" * 60)

# Whole section is best-effort: if the VVIX file is missing or malformed,
# the error is reported and the script moves on without VVIX signals.
try:
    vvix_aux = load_tv_5min('vvix_5min_tv.csv', 'vvix')
    es_vvix = merge_aux(es5.copy(), vvix_aux)
    # 30-minute (6-bar) percent change, computed within each day.
    es_vvix['sig_vvix_chg'] = es_vvix.groupby('date')['vvix_close'].pct_change(6)
    es_vvix['sig_vvix_level'] = es_vvix['vvix_close']
    
    print(f"VVIX merged: {es_vvix['vvix_close'].notna().sum()}")
    
    for col, name, desc in [
        ('sig_vvix_chg', 'vvix_30m_change', 'VVIX 30-min pct change'),
        ('sig_vvix_level', 'vvix_level', 'VVIX absolute level'),
    ]:
        r = evaluate_signal(es_vvix, col, name, desc)
        signals_results.append(r)
        print_signal(r)
except Exception as e:
    print(f"VVIX error: {e}")

# ============================================================
# SIGNAL COMBINATIONS
# ============================================================
print("\n" + "=" * 60)
print("SIGNAL COMBINATIONS")
print("=" * 60)

# Build combined dataframe
all_sigs = es5.copy()

# Merge all signals we've computed: frame -> list of columns to pull in.
sig_frames = {
    'gex': (es_gex, ['sig_gex_quintile', 'sig_gex_change_30m']),
    'vix': (es_vix, ['sig_vix_30m_chg', 'sig_vix_ratio_chg', 'sig_vix1d_ratio_chg']),
    'tick': (es_tick, ['sig_tick_level', 'sig_tick_30m_ma']),
    'cross': (es_cross, ['sig_dxy_mom', 'sig_hyg_mom', 'sig_tlt_mom', 'sig_usdjpy_mom', 'sig_tlt_es_div']),
    'delta': (es_delta, ['sig_cum_delta_norm', 'sig_delta_30m']),
    'tod': (es_tod, ['sig_first_30m', 'sig_ret_from_open']),
}

# Columns are copied positionally (.values); the len() guard assumes each
# signal frame is still row-aligned with es5.
# NOTE(review): merge_asof preserves left row order/count, but es_tod went
# through a .merge — verify that frame stays aligned before trusting combos.
for key, (df, cols) in sig_frames.items():
    for c in cols:
        if c in df.columns and len(df) == len(all_sigs):
            all_sigs[c] = df[c].values

print(f"Combined frame: {len(all_sigs)} bars, signals: {[c for c in all_sigs.columns if c.startswith('sig_')]}")

def zscore(s):
    """Standardize a Series to mean 0 / std 1 over its non-null values.

    pandas mean()/std() skip NaN and NaN propagates through the
    arithmetic, so null entries stay null in the output.
    """
    return (s - s.mean()) / s.std()

# One result dict per tested combination signal.
combinations_results = []

# Candidate combos: each is an equal-weight sum of z-scored components.
combos = [
    ('tick_plus_vix', ['sig_tick_30m_ma', 'sig_vix_30m_chg']),
    ('gex_plus_delta', ['sig_gex_quintile', 'sig_cum_delta_norm']),
    ('tick_plus_hyg', ['sig_tick_30m_ma', 'sig_hyg_mom']),
    ('vix_plus_dxy', ['sig_vix_30m_chg', 'sig_dxy_mom']),
    ('tick_vix_dxy', ['sig_tick_30m_ma', 'sig_vix_30m_chg', 'sig_dxy_mom']),
    ('delta_plus_vix', ['sig_delta_30m', 'sig_vix_30m_chg']),
    ('open_ret_plus_tick', ['sig_ret_from_open', 'sig_tick_30m_ma']),
]

for combo_name, cols in combos:
    # Skip combos whose components never made it into the combined frame.
    valid = all(c in all_sigs.columns for c in cols)
    if not valid:
        print(f"  Skipping {combo_name}: missing columns")
        continue
    
    # Sum of z-scores; a NaN in any component makes the combo NaN there.
    combo = None
    for c in cols:
        z = zscore(all_sigs[c])
        combo = z if combo is None else combo + z
    
    all_sigs[f'combo_{combo_name}'] = combo
    
    r = evaluate_signal(all_sigs, f'combo_{combo_name}', f'combo_{combo_name}', f'Z-score sum: {" + ".join(cols)}')
    combinations_results.append(r)
    print_signal(r)

# ============================================================
# QUINTILE ANALYSIS for top signals
# ============================================================
print("\n" + "=" * 60)
print("QUINTILE ANALYSIS")
print("=" * 60)

# Signals that passed the IS/OOS gate and have a best window/horizon.
passed_signals = [s for s in signals_results + combinations_results
                  if 'FAILED' not in s.get('notes', '') and s['best_time'] is not None]
print(f"Signals with IS/OOS consistency: {len(passed_signals)}")

quintile_results = {}

# Build lookup of signal columns to dataframes
all_dfs = {
    'es_gex': es_gex, 'es_vix': es_vix, 'es_tick': es_tick,
    'es_cross': es_cross, 'es_delta': es_delta, 'es_tod': es_tod,
    'all_sigs': all_sigs,
}
# These frames only exist when their source data was present (see guards above).
if 'es_hiro' in dir():
    all_dfs['es_hiro'] = es_hiro
if 'es_gamma' in dir():
    all_dfs['es_gamma'] = es_gamma
if 'es_vvix' in dir():
    all_dfs['es_vvix'] = es_vvix

for s in passed_signals:
    name = s['name']
    bt, bh = s['best_time'], s['best_horizon']
    
    # Find the column in any dataframe.
    # NOTE(review): signals whose report name differs from their column
    # suffix (e.g. name 'vix_30m_inv' vs column 'sig_vix_30m_chg') match
    # none of these candidates and silently get no quintile analysis.
    # Flow: `continue` tries the next frame when too few OOS rows; `break`
    # stops scanning frames after the first qcut attempt (success or not);
    # a failed qcut falls through to the next col_try candidate.
    for col_try in [name, f'sig_{name}', f'combo_{name}']:
        found = False
        for df_name, df in all_dfs.items():
            if col_try in df.columns and 'fwd_1h' in df.columns:
                mask = df[col_try].notna() & df[f'fwd_{bh}'].notna() & (df['window'] == bt) & (df['is_oos'] == 'OOS')
                sub = df[mask].copy()
                if len(sub) < 50:
                    continue
                try:
                    sub['q'] = pd.qcut(sub[col_try], 5, labels=[1,2,3,4,5], duplicates='drop')
                    qa = sub.groupby('q')[f'fwd_{bh}'].agg(['mean', 'std', 'count']).reset_index()
                    qa['mean_bps'] = qa['mean'] * 10000
                    quintile_results[name] = qa.to_dict('records')
                    print(f"\n{name} ({bt}/{bh}):")
                    for _, row in qa.iterrows():
                        print(f"  Q{row['q']}: {row['mean_bps']:+.1f} bps  (n={int(row['count'])})")
                    found = True
                except:
                    # qcut can fail on discrete signals with too few
                    # distinct values; try the next column candidate.
                    pass
                break
        if found:
            break

# ============================================================
# SAVE RESULTS
# ============================================================
# Persist results as JSON (machine-readable) and Markdown (human report).
print("\n" + "=" * 60)
print("SAVING")
print("=" * 60)

def clean(obj):
    """Recursively convert numpy/pandas scalars into JSON-safe builtins.

    Dicts and lists are walked; numpy ints/floats/bools become Python
    ints/floats/bools (NaN -> None); timestamps/datetimes become strings.
    Anything else passes through unchanged.
    """
    if isinstance(obj, dict):
        return {key: clean(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [clean(item) for item in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # json.dump would emit invalid `NaN` tokens; map them to null.
        return None if np.isnan(obj) else float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, (pd.Timestamp, datetime)):
        return str(obj)
    return obj

# Build summary text
# (Recomputed from scratch; `passed` mirrors `passed_signals` above.)
passed = [s for s in signals_results + combinations_results if 'FAILED' not in s.get('notes', '') and s['best_time'] is not None]
failed = [s for s in signals_results + combinations_results if 'FAILED' in s.get('notes', '')]

summary = f"Tested {len(signals_results)} signals + {len(combinations_results)} combos. "
summary += f"{len(passed)} passed IS/OOS consistency, {len(failed)} failed.\n\n"
summary += "TOP SIGNALS:\n"
for s in passed:
    bt, bh = s['best_time'], s['best_horizon']
    ic_oos = s.get(f'{bt}_ic', {}).get(bh, {}).get('oos', '?')
    summary += f"  {s['name']}: {bt}/{bh}, OOS IC={ic_oos}, HR={s.get('hit_rate_top_quintile')}\n"

# clean() strips numpy scalar types / NaN so json.dump cannot choke.
output = {
    'signals': clean(signals_results),
    'combinations': clean(combinations_results),
    'quintile_analysis': clean(quintile_results),
    'metadata': {
        'n_bars': len(es5),
        'n_days': n_dates,
        'date_range': f"{min(all_dates)} to {max(all_dates)}",
        'is_range': f"{min(is_dates)} to {max(is_dates)} ({len(is_dates)} days)",
        'oos_range': f"{min(oos_dates)} to {max(oos_dates)} ({len(oos_dates)} days)",
    },
    'summary': summary,
}

with open(f'{DATA}/intraday_signal_research.json', 'w') as f:
    json.dump(output, f, indent=2)
print(f"Saved JSON: {DATA}/intraday_signal_research.json")

# Markdown report
md = []
md.append("# Intraday Signal Research — ES Futures")
md.append(f"\n*Generated {datetime.now().strftime('%Y-%m-%d %H:%M')}*\n")
md.append(f"## Data")
md.append(f"- {n_dates} trading days: {min(all_dates)} → {max(all_dates)}")
md.append(f"- IS: {min(is_dates)} → {max(is_dates)} ({len(is_dates)} days)")
md.append(f"- OOS: {min(oos_dates)} → {max(oos_dates)} ({len(oos_dates)} days)")
md.append(f"- {len(es5)} 5-min RTH bars\n")
md.append(f"## Results: {len(passed)} passed, {len(failed)} failed\n")

# Passed signals, strongest |OOS IC| first; `or 0` guards None entries.
for s in sorted(passed, key=lambda x: abs(x.get(f"{x.get('best_time','')}_ic",{}).get(x.get('best_horizon',''),{}).get('oos',0) or 0), reverse=True):
    bt, bh = s['best_time'], s['best_horizon']
    ic_is = s.get(f'{bt}_ic',{}).get(bh,{}).get('is','?')
    ic_oos = s.get(f'{bt}_ic',{}).get(bh,{}).get('oos','?')
    hr = s.get('hit_rate_top_quintile','?')
    
    md.append(f"### {s['name']}")
    md.append(f"**{s['description']}**\n")
    md.append(f"Best: **{bt} / {bh}** | IS IC: {ic_is} | OOS IC: {ic_oos} | Hit Rate (Q5): {hr}\n")
    
    md.append("| Window | Horizon | IS IC | OOS IC | n_IS | n_OOS |")
    md.append("|--------|---------|-------|--------|------|-------|")
    for w in ['morning', 'midday', 'afternoon']:
        for h in ['1h', '3h', 'eod']:
            e = s.get(f'{w}_ic',{}).get(h,{})
            consistent = ''
            # NOTE(review): truthiness test treats an IC of exactly 0.0 as
            # missing; harmless here since 0.0 can never pass |IC| > 0.02.
            if e.get('is') and e.get('oos') and isinstance(e['is'], float) and isinstance(e['oos'], float):
                if np.sign(e['is']) == np.sign(e['oos']) and abs(e['oos']) > 0.02:
                    consistent = ' ✓'
            md.append(f"| {w} | {h} | {e.get('is','-')} | {e.get('oos','-')}{consistent} | {e.get('n_is','-')} | {e.get('n_oos','-')} |")
    md.append("")
    
    if s['name'] in quintile_results:
        md.append("**Quintile returns (OOS, bps):**\n")
        md.append("| Q | Mean bps | n |")
        md.append("|---|---------|---|")
        for q in quintile_results[s['name']]:
            md.append(f"| {q.get('q','?')} | {q.get('mean_bps',0):.1f} | {q.get('count','?')} |")
        md.append("")

md.append("## Combinations\n")
for s in combinations_results:
    # NOTE(review): the '?' defaults never trigger — 'best_time' always
    # exists (possibly None); a None bt falls back to '?' via the dict .get.
    bt, bh = s.get('best_time','?'), s.get('best_horizon','?')
    ic_oos = s.get(f'{bt}_ic',{}).get(bh,{}).get('oos','?') if bt != '?' else '?'
    status = "✅" if 'FAILED' not in s.get('notes','') else "❌"
    md.append(f"- {status} **{s['name']}**: {bt}/{bh}, OOS IC={ic_oos}")
    md.append(f"  {s['description']}")

md.append("\n## Failed Signals\n")
for s in failed:
    md.append(f"- **{s['name']}**: {s['description']}")

with open(f'{DATA}/intraday_signal_research.md', 'w') as f:
    f.write('\n'.join(md))
print(f"Saved MD: {DATA}/intraday_signal_research.md")

print("\n" + "=" * 60)
print("COMPLETE")
print("=" * 60)
print(summary)
