#!/usr/bin/env python3
"""
TRACE Intraday Signal Analysis
Tests 7 directional signals using SpotGamma TRACE data for ES/SPX futures.
"""

import pandas as pd
import numpy as np
import json
import os
import sys
import warnings
from pathlib import Path
from scipy import stats
from datetime import timedelta
# NOTE(review): this silences every warning for the whole run (including
# scipy constant-input warnings from spearmanr); consider narrowing.
warnings.filterwarnings('ignore')

# Root data directory. Set the TRACE_DATA_DIR environment variable to run the
# analysis against a different location; the default is the original path.
DATA_DIR = Path(os.environ.get('TRACE_DATA_DIR', '/Users/daniel/.openclaw/workspace/data'))
TRACE_DIR = DATA_DIR / 'trace_api'  # per-day intraday TRACE parquet files
OUTPUT_JSON = DATA_DIR / 'trace_intraday_signals.json'
OUTPUT_MD = DATA_DIR / 'trace_intraday_signals.md'

# ============================================================
# STEP 1: Load ES price data (continuous front-month)
# ============================================================
print("Loading ES 1-min bars...")
es_raw = pd.read_csv(DATA_DIR / 'es_1min_bars.csv', usecols=['ts_event', 'open', 'high', 'low', 'close', 'volume', 'symbol'])
# Drop spread instruments (symbols containing '-'); outright contracts of every
# expiry remain here — the front month is selected further below.
es_raw = es_raw[~es_raw['symbol'].str.contains('-')]
es_raw['ts'] = pd.to_datetime(es_raw['ts_event'])
es_raw = es_raw.sort_values('ts')

# Build continuous series: for each calendar date of 'ts', keep only the
# contract with the highest total volume on that date.
# NOTE(review): 'date' is taken in whatever timezone pd.to_datetime produced
# for ts_event (presumably UTC), so the daily roll boundary follows that clock
# rather than the ET session — confirm this is intended.
es_raw['date'] = es_raw['ts'].dt.date
daily_vol = es_raw.groupby(['date', 'symbol'])['volume'].sum().reset_index()
front_month = daily_vol.loc[daily_vol.groupby('date')['volume'].idxmax()][['date', 'symbol']]
front_month = front_month.rename(columns={'symbol': 'front_symbol'})
# Both frames carry a 'date' column; joining through the alias key keeps the
# left 'date' as-is and renames the right one to 'date_fm'.
es_raw['date_key'] = es_raw['date']
es = es_raw.merge(front_month, left_on='date_key', right_on='date', suffixes=('', '_fm'))
# Keep only the bars of each date's front-month contract, OHLCV columns only.
es = es[es['symbol'] == es['front_symbol']][['ts', 'open', 'high', 'low', 'close', 'volume']].copy()
es = es.sort_values('ts').reset_index(drop=True)
es = es.set_index('ts')
# Drop duplicate timestamps (first row wins) so index lookups are unique.
# This does NOT resample: minutes with no bar remain missing.
es = es[~es.index.duplicated(keep='first')]
print(f"  ES bars: {len(es)} rows, {es.index.min()} to {es.index.max()}")

# ============================================================
# STEP 2: Get available TRACE dates
# ============================================================
gamma_files = sorted(TRACE_DIR.glob('intradayGamma_*.parquet'))
# The date token is the second '_'-separated part of the file stem. Sorting the
# strings is chronological only if the token is zero-padded (e.g. YYYY-MM-DD);
# assumed here.
trace_dates = sorted(set(f.stem.split('_')[1] for f in gamma_files))
if not trace_dates:
    # Fail with a clear message instead of an IndexError on trace_dates[0].
    sys.exit(f"No intradayGamma_*.parquet files found in {TRACE_DIR}")
print(f"  TRACE dates: {len(trace_dates)} days, {trace_dates[0]} to {trace_dates[-1]}")

# Chronological IS/OOS split: the first 120 TRACE dates are in-sample and
# every remaining date (however many there are) is out-of-sample.
IS_DATES = set(trace_dates[:120])
OOS_DATES = set(trace_dates[120:])
print(f"  IS: {len(IS_DATES)} days, OOS: {len(OOS_DATES)} days")

# ============================================================
# STEP 3: Helper functions
# ============================================================

def get_es_price_at(ts, tolerance_minutes=3):
    """Return the ES close at *ts*, falling back to the nearest bar.

    An exact index hit is returned directly. Otherwise the nearest bar in the
    module-level ``es`` frame is used, provided it lies within
    ``tolerance_minutes`` of *ts*; if not, None is returned.
    """
    if ts in es.index:
        return es.loc[ts, 'close']

    pos = es.index.get_indexer([ts], method='nearest')[0]
    if not 0 <= pos < len(es):
        return None

    gap_seconds = abs((es.index[pos] - ts).total_seconds())
    return None if gap_seconds > tolerance_minutes * 60 else es.iloc[pos]['close']

def get_forward_return(ts, horizon_minutes):
    """Return the simple ES return from *ts* to *ts* + ``horizon_minutes``.

    The start price allows a 3-minute lookup tolerance and the end price a
    5-minute one. Returns None when either price is unavailable or the start
    price is zero.
    """
    horizon_end = ts + timedelta(minutes=horizon_minutes)
    start_px = get_es_price_at(ts, tolerance_minutes=3)
    end_px = get_es_price_at(horizon_end, tolerance_minutes=5)
    if start_px is None or end_px is None or start_px == 0:
        return None
    return (end_px - start_px) / start_px

def get_eod_return(ts, eod_hour=16, eod_tz='US/Eastern'):
    """Return the simple ES return from *ts* to ``eod_hour`` o'clock on the same day.

    Args:
        ts: Start timestamp. A tz-aware timestamp is converted to ``eod_tz``
            before its calendar date is taken; a naive one is assumed to
            already be in ``eod_tz`` wall-clock time.
        eod_hour: Hour of day (in ``eod_tz``) used as the end-of-day point.
        eod_tz: Timezone in which the end-of-day time is defined.

    Returns:
        ``(p_eod - p0) / p0``, or None when either price is unavailable
        (3-minute tolerance at *ts*, 10-minute tolerance at end of day) or
        the start price is zero.
    """
    # Previously eod_hour/eod_tz were accepted but ignored (16:00 US/Eastern
    # was hard-coded); the defaults reproduce that behavior.
    local_ts = ts.tz_convert(eod_tz) if getattr(ts, 'tz', None) is not None else ts
    eod = local_ts.normalize() + timedelta(hours=eod_hour)
    p0 = get_es_price_at(ts, tolerance_minutes=3)
    p1 = get_es_price_at(eod, tolerance_minutes=10)
    if p0 is None or p1 is None or p0 == 0:
        return None
    return (p1 - p0) / p0

def time_bucket(ts):
    """Classify a timestamp into an intraday session bucket in US/Eastern time.

    A tz-aware timestamp is converted to US/Eastern first; a naive timestamp
    is assumed to already be Eastern wall-clock time.

    Returns:
        'morning' before 11:00, 'midday' from 11:00 up to 13:00, and
        'afternoon' from 13:00 onward.
    """
    # The original had identical aware/naive branches and never converted,
    # so aware non-ET timestamps (e.g. UTC) were bucketed by the wrong clock.
    if getattr(ts, 'tz', None) is not None:
        ts = ts.tz_convert('US/Eastern')
    minute_of_day = ts.hour * 60 + ts.minute
    if minute_of_day < 11 * 60:
        return 'morning'
    if minute_of_day < 13 * 60:
        return 'midday'
    return 'afternoon'

def spearman_ic(signal, returns, min_obs=20):
    """Compute the Spearman rank IC between a signal and forward returns.

    Pairs where either value is NaN or infinite are dropped before ranking.

    Args:
        signal: 1-D array-like of signal values.
        returns: 1-D array-like of forward returns aligned with ``signal``.
        min_obs: Minimum number of finite pairs required to compute an IC.

    Returns:
        ``(ic, pval, n)``: the Spearman correlation, its two-sided p-value and
        the number of finite pairs. ``ic`` and ``pval`` are None when fewer
        than ``min_obs`` pairs remain or when the correlation is undefined
        (e.g. a constant input, for which scipy returns NaN).
    """
    signal = np.asarray(signal, dtype=float)
    returns = np.asarray(returns, dtype=float)
    mask = np.isfinite(signal) & np.isfinite(returns)
    n = int(mask.sum())
    if n < min_obs:
        return None, None, n
    ic, pval = stats.spearmanr(signal[mask], returns[mask])
    if not np.isfinite(ic):
        # Keep NaN out of downstream rounding/JSON; treat as "no IC".
        return None, None, n
    return float(ic), float(pval), n

# ============================================================
# STEP 4: Process each day - build signal matrix
# ============================================================

print("\nProcessing TRACE days...")

# Re-index the ES closes in US/Eastern once, up front, so per-timestamp
# lookups can compare directly against the ET-converted TRACE timestamps.
print("  Building ES price lookup...")
es_et = es.copy()
# NOTE(review): a tz-naive bar index is assumed to be UTC; confirm against
# the ts_event convention of the source CSV.
aware_index = es_et.index if es_et.index.tz is not None else es_et.index.tz_localize('UTC')
es_et.index = aware_index.tz_convert('US/Eastern')
es_prices = es_et['close']

def fast_es_price(ts_et, tol_min=3, prices=None):
    """Return the ES close nearest to *ts_et*, or None if none is close enough.

    Args:
        ts_et: Timestamp to look up. It must be comparable with the lookup
            series' index; the default ``es_prices`` is indexed in US/Eastern.
        tol_min: Maximum distance, in minutes, to the nearest bar.
        prices: Optional close-price Series with a sorted, unique
            DatetimeIndex. Defaults to the module-level ``es_prices``.

    Returns:
        The nearest bar's close, or None when there is no bar within
        ``tol_min`` minutes or the lookup itself fails (e.g. comparing
        tz-naive with tz-aware timestamps).
    """
    series = es_prices if prices is None else prices
    try:
        pos = series.index.get_indexer([ts_et], method='nearest')[0]
        if pos < 0:  # get_indexer signals "no match" with -1
            return None
        nearest = series.index[pos]
        if abs((nearest - ts_et).total_seconds()) > tol_min * 60:
            return None
        return series.iloc[pos]
    except Exception:
        # Deliberate best-effort lookup: any pandas lookup/comparison error
        # means "no price". Unlike the previous bare `except:`, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        return None

def fast_forward_price(ts_et, horizon_min):
    """Return the ES close ``horizon_min`` minutes after *ts_et*.

    The lookup uses a 5-minute tolerance around the target time and yields
    None when no bar lies that close.
    """
    return fast_es_price(ts_et + timedelta(minutes=horizon_min), tol_min=5)

# One dict per (day, RTH snapshot timestamp) observation; becomes `df` in STEP 5.
all_signals = []
processed_days = 0
skipped_days = 0

for date_str in trace_dates:
    # A day is used only if all three TRACE snapshot files exist for it.
    gamma_file = TRACE_DIR / f'intradayGamma_{date_str}.parquet'
    gex_file = TRACE_DIR / f'intradayStrikeGEX_{date_str}.parquet'
    delta_file = TRACE_DIR / f'intradayDelta_{date_str}.parquet'
    
    if not gamma_file.exists() or not gex_file.exists() or not delta_file.exists():
        skipped_days += 1
        continue
    
    try:
        gamma_df = pd.read_parquet(gamma_file)
        gex_df = pd.read_parquet(gex_file)
        delta_df = pd.read_parquet(delta_file)
    except Exception as e:
        # NOTE(review): unreadable files are skipped without logging the
        # exception; they only show up in the skipped_days count.
        skipped_days += 1
        continue
    
    # Normalize all snapshot timestamps to tz-aware US/Eastern: naive values
    # are assumed to already be Eastern wall-clock (localized), aware values
    # are converted.
    for df in [gamma_df, gex_df, delta_df]:
        if df['timestamp'].dtype == 'object':
            df['timestamp'] = pd.to_datetime(df['timestamp'])
        if df['timestamp'].dt.tz is None:
            df['timestamp'] = df['timestamp'].dt.tz_localize('US/Eastern')
        else:
            df['timestamp'] = df['timestamp'].dt.tz_convert('US/Eastern')
    
    # Snapshot times come from the gamma file only; keep 09:30 through 15:30 ET
    # inclusive.
    # NOTE(review): the 15:30 cutoff does not leave room for a full 3h window,
    # and forward targets are not clipped at 16:00 ET, so 1h/3h returns from
    # late-session timestamps are measured against bars after 16:00 (or are
    # None if no bar is found) — confirm this is intended.
    timestamps = sorted(gamma_df['timestamp'].unique())
    rth_timestamps = [t for t in timestamps 
                      if pd.Timestamp(t).hour * 60 + pd.Timestamp(t).minute >= 9*60+30 
                      and pd.Timestamp(t).hour * 60 + pd.Timestamp(t).minute <= 15*60+30]  # 15:30 ET cutoff (see note above)
    
    if len(rth_timestamps) == 0:
        skipped_days += 1
        continue
    
    is_day = date_str in IS_DATES
    
    for ts in rth_timestamps:
        ts_pd = pd.Timestamp(ts)
        
        # Current ES price: nearest 1-min close (fast_es_price default tolerance)
        es_price = fast_es_price(ts_pd)
        if es_price is None:
            continue
        
        # Forward prices at +60 and +180 minutes, plus the same day's 15:55 ET
        # bar (10-minute tolerance) as the end-of-day reference.
        fwd_1h = fast_forward_price(ts_pd, 60)
        fwd_3h = fast_forward_price(ts_pd, 180)
        eod_price = fast_es_price(ts_pd.normalize() + timedelta(hours=15, minutes=55), tol_min=10)
        
        # Keep the observation if at least one intraday horizon is available.
        if fwd_1h is None and fwd_3h is None:
            continue
        
        # Simple returns vs. the current price; a missing (or zero) forward
        # price yields None.
        ret_1h = (fwd_1h - es_price) / es_price if fwd_1h else None
        ret_3h = (fwd_3h - es_price) / es_price if fwd_3h else None
        ret_eod = (eod_price - es_price) / es_price if eod_price else None
        
        bucket = time_bucket(ts_pd)
        
        row = {
            'date': date_str,
            'timestamp': str(ts),
            'es_price': es_price,
            'ret_1h': ret_1h,
            'ret_3h': ret_3h,
            'ret_eod': ret_eod,
            'bucket': bucket,
            'is_sample': is_day,
        }
        
        # ---- SIGNAL 1: Call vs Put GEX Imbalance ----
        # All per-strike rows of the strike-GEX snapshot at this timestamp.
        gex_snap = gex_df[gex_df['timestamp'] == ts]
        if len(gex_snap) > 0:
            # Net GEX: sum of the five participant gamma columns over all strikes.
            total_gex = gex_snap[['bd_gamma', 'cust_gamma', 'firm_gamma', 'mm_gamma', 'procust_gamma']].sum()
            net_gex = total_gex.sum()
            # NOTE(review): calls and puts are not separated in the columns
            # used here. The split below is a proxy that assumes each strike's
            # summed gamma is a net call+put value: strikes with a positive
            # total count as "call" GEX, negative ones as "put" GEX. Confirm
            # against the StrikeGEX schema.
            strike_gex = gex_snap.copy()
            strike_gex['total_gex'] = strike_gex[['bd_gamma', 'cust_gamma', 'firm_gamma', 'mm_gamma', 'procust_gamma']].sum(axis=1)
            call_gex = strike_gex.loc[strike_gex['total_gex'] > 0, 'total_gex'].sum()
            put_gex = strike_gex.loc[strike_gex['total_gex'] < 0, 'total_gex'].sum()
            
            row['sig1_net_gex'] = net_gex
            # Reported as 0 (not infinity) when no strike has negative GEX.
            row['sig1_call_put_ratio'] = call_gex / abs(put_gex) if put_gex != 0 else 0
            # +1 when positive-GEX strikes outweigh |negative-GEX| strikes; ties -> -1.
            row['sig1_call_dominance'] = 1 if call_gex > abs(put_gex) else -1
        
        # ---- SIGNAL 2: Negative GEX Cluster Magnet ----
        if len(gex_snap) > 0:
            # Same per-strike total as in Signal 1, recomputed here.
            strike_gex = gex_snap.copy()
            strike_gex['total_gex'] = strike_gex[['bd_gamma', 'cust_gamma', 'firm_gamma', 'mm_gamma', 'procust_gamma']].sum(axis=1)
            # Strike with the most negative total GEX (first such row on ties)
            min_idx = strike_gex['total_gex'].idxmin()
            neg_gex_strike = strike_gex.loc[min_idx, 'strike_price']
            neg_gex_value = strike_gex.loc[min_idx, 'total_gex']
            
            # Side of price the cluster sits on, and its signed distance as a
            # fraction of the current ES price.
            row['sig2_neg_cluster_direction'] = 1 if neg_gex_strike > es_price else -1  # 1=above, -1=below (ties count as below)
            row['sig2_neg_cluster_distance'] = (neg_gex_strike - es_price) / es_price
            row['sig2_neg_cluster_magnitude'] = neg_gex_value
        
        # ---- SIGNAL 3: Smart Money (Firm + ProCust) Gamma & Delta ----
        gamma_snap = gamma_df[gamma_df['timestamp'] == ts]
        delta_snap = delta_df[delta_df['timestamp'] == ts]
        
        if len(gamma_snap) > 0:
            # The gamma profile has a 'spot' column (presumably one row per
            # hypothetical spot level); use the row closest to the ES price.
            spots = gamma_snap['spot'].values
            closest_idx = np.argmin(np.abs(spots - es_price))
            closest_row = gamma_snap.iloc[closest_idx]
            
            row['sig3_smart_gamma'] = closest_row['firm_gamma'] + closest_row['procust_gamma']
            row['sig3_firm_gamma'] = closest_row['firm_gamma']
            row['sig3_procust_gamma'] = closest_row['procust_gamma']
            row['sig3_cust_gamma'] = closest_row['cust_gamma']
            row['sig3_mm_gamma'] = closest_row['mm_gamma']
            row['sig3_bd_gamma'] = closest_row['bd_gamma']
        
        if len(delta_snap) > 0:
            # Same nearest-spot selection on the delta profile.
            spots = delta_snap['spot'].values
            closest_idx = np.argmin(np.abs(spots - es_price))
            closest_row = delta_snap.iloc[closest_idx]
            
            row['sig3_smart_delta'] = closest_row['firm_delta'] + closest_row['procust_delta']
            row['sig3_firm_delta'] = closest_row['firm_delta']
            row['sig3_procust_delta'] = closest_row['procust_delta']
            row['sig3_cust_delta'] = closest_row['cust_delta']
            row['sig3_mm_delta'] = closest_row['mm_delta']
            row['sig3_bd_delta'] = closest_row['bd_delta']
        
        # ---- SIGNAL 5: Gamma Asymmetry by Participant ----
        if len(gamma_snap) > 0:
            # Profile rows strictly above / below the current price; a row at
            # exactly es_price is in neither side.
            above = gamma_snap[gamma_snap['spot'] > es_price]
            below = gamma_snap[gamma_snap['spot'] < es_price]
            
            for part in ['bd_gamma', 'cust_gamma', 'firm_gamma', 'mm_gamma', 'procust_gamma']:
                above_sum = above[part].sum()
                below_sum = below[part].sum()
                # Normalized difference in [-1, 1]; 0 when both sides sum to 0.
                denom = abs(above_sum) + abs(below_sum)
                asym = (above_sum - below_sum) / denom if denom > 0 else 0
                row[f'sig5_{part}_asym'] = asym
        
        # ---- SIGNAL 6: 0DTE vs Non-0DTE Gamma ----
        # Only mm_gamma_0 is checked; the other *_gamma_0 columns used below
        # are assumed to be present whenever it is.
        if len(gex_snap) > 0 and 'mm_gamma_0' in gex_snap.columns:
            mm_total = gex_snap['mm_gamma'].sum()
            mm_0dte = gex_snap['mm_gamma_0'].sum()
            mm_non0dte = mm_total - mm_0dte
            
            row['sig6_0dte_ratio'] = mm_0dte / mm_total if mm_total != 0 else 0
            row['sig6_0dte_gamma'] = mm_0dte
            row['sig6_non0dte_gamma'] = mm_non0dte
            row['sig6_0dte_divergence'] = mm_0dte - mm_non0dte
            
            # 0DTE concentration: strike with the largest |0DTE gamma| summed
            # over all five participants.
            gex_snap_copy = gex_snap.copy()
            gex_snap_copy['total_0dte'] = gex_snap_copy[['bd_gamma_0', 'cust_gamma_0', 'firm_gamma_0', 'mm_gamma_0', 'procust_gamma_0']].sum(axis=1)
            max_0dte_idx = gex_snap_copy['total_0dte'].abs().idxmax()
            max_0dte_strike = gex_snap_copy.loc[max_0dte_idx, 'strike_price']
            row['sig6_0dte_magnet_dir'] = 1 if max_0dte_strike > es_price else -1
            row['sig6_0dte_magnet_dist'] = (max_0dte_strike - es_price) / es_price
        
        all_signals.append(row)
    
    processed_days += 1
    if processed_days % 20 == 0:
        print(f"  Processed {processed_days}/{len(trace_dates)} days, {len(all_signals)} observations so far")

print(f"  Done: {processed_days} days processed, {skipped_days} skipped, {len(all_signals)} total observations")

# ============================================================
# STEP 5: 30-min changes in participant gamma/delta (Signal 4 changes + Signal 7)
# ============================================================
print("\nComputing delta flow changes (Signal 4 changes + Signal 7)...")

df = pd.DataFrame(all_signals)
df['ts_parsed'] = pd.to_datetime(df['timestamp'])

# "30 minutes" is approximated as 6 rows back within the same date (assumes
# ~5-minute snapshots; observations dropped in STEP 4 make the real gap
# longer). The first 6 rows of each date get NaN.
participant_cols = [
    f'sig3_{who}_{greek}'
    for greek in ('gamma', 'delta')
    for who in ('firm', 'procust', 'cust', 'mm', 'bd')
]
for participant_col in participant_cols:
    if participant_col in df.columns:
        df[f'{participant_col}_chg'] = df.groupby('date')[participant_col].diff(6)

# Signal 7 delta-flow columns: the same within-day 6-row diff of a STEP 4 column.
flow_sources = {
    'sig7_smart_delta_chg': 'sig3_smart_delta',
    'sig7_cust_delta_chg': 'sig3_cust_delta',
    'sig7_firm_delta_chg': 'sig3_firm_delta',
    'sig7_procust_delta_chg': 'sig3_procust_delta',
}
for flow_col, source_col in flow_sources.items():
    if source_col in df.columns:
        df[flow_col] = df.groupby('date')[source_col].diff(6)

print(f"  DataFrame shape: {df.shape}")
print(f"  IS obs: {df['is_sample'].sum()}, OOS obs: {(~df['is_sample']).sum()}")

# ============================================================
# STEP 6: Compute IC metrics for all signals
# ============================================================
print("\nComputing signal metrics...")


def _round_or_none(value, ndigits=4):
    """Round *value*; map None, NaN and inf to None so reports/JSON stay clean."""
    if value is None or not np.isfinite(value):
        return None
    return round(float(value), ndigits)


def _oos_quintile_stats(df, signal_col, ret_col='ret_1h', min_obs=50):
    """Return (top-minus-bottom quintile mean return, top-quintile hit rate) on OOS rows.

    OOS rows are those with ``is_sample`` False. Only rows where both the
    signal and the return are finite are used. Both statistics are 0 when a
    column is missing, when there are ``min_obs`` or fewer usable rows, or
    when the signal has too few distinct values to form quantile bins.
    """
    if signal_col not in df.columns or ret_col not in df.columns:
        return 0, 0
    oos = df.loc[~df['is_sample'], [signal_col, ret_col]].astype(float)
    oos = oos[np.isfinite(oos[signal_col]) & np.isfinite(oos[ret_col])]
    if len(oos) <= min_obs:
        return 0, 0
    try:
        quintile = pd.qcut(oos[signal_col], 5, labels=False, duplicates='drop')
    except ValueError:
        # Degenerate signal (e.g. constant): no bins can be formed.
        return 0, 0
    q_means = oos[ret_col].groupby(quintile).mean()
    spread = (q_means.iloc[-1] - q_means.iloc[0]) if len(q_means) >= 2 else 0
    top = oos.loc[quintile == quintile.max(), ret_col]
    hit_rate = (top > 0).mean() if len(top) > 0 else 0
    return spread, hit_rate


def compute_signal_metrics(df, signal_col, return_cols=('ret_1h', 'ret_3h', 'ret_eod')):
    """Compute IC metrics for a signal across time buckets and IS/OOS.

    Args:
        df: Observation frame with ``bucket``, ``is_sample``, the signal column
            and the forward-return columns.
        signal_col: Name of the signal column to evaluate.
        return_cols: Forward-return columns. The horizon key is the column name
            with ``ret_`` removed (``ret_1h`` -> ``1h``).

    Returns:
        ``(results, quintile_spread, hit_q5, n_obs, consistent)`` where
        ``results[horizon][bucket]`` holds the IS/OOS Spearman IC under
        ``'is'``/``'oos'`` (None when not computable), plus ``*_n`` sample
        sizes and ``*_pval`` p-values. ``quintile_spread`` and ``hit_q5`` are
        OOS 1h quintile statistics. ``n_obs`` counts non-null signal values.
        ``consistent`` is False if, for any horizon and bucket, the IS and OOS
        ICs have opposite signs with both beyond +/-0.02.
    """
    buckets = ('morning', 'midday', 'afternoon')
    results = {}

    for ret_col in return_cols:
        horizon = ret_col.replace('ret_', '')
        results[horizon] = {}
        for bucket in buckets:
            cell = results[horizon][bucket] = {}
            for sample, label in [(True, 'is'), (False, 'oos')]:
                if signal_col not in df.columns or ret_col not in df.columns:
                    cell[label] = None
                    continue
                sub = df.loc[(df['bucket'] == bucket) & (df['is_sample'] == sample)]
                ic, pval, n = spearman_ic(sub[signal_col].values.astype(float),
                                          sub[ret_col].values.astype(float))
                # NaN ICs (constant inputs) become None rather than NaN in output.
                cell[label] = _round_or_none(ic)
                cell[f'{label}_n'] = int(n)
                cell[f'{label}_pval'] = _round_or_none(pval)

    quintile_spread, hit_q5 = _oos_quintile_stats(df, signal_col)

    # Horizons come from return_cols (previously hard-coded to 1h/3h/eod).
    consistent = True
    for horizon_results in results.values():
        for bucket in buckets:
            is_ic = horizon_results[bucket].get('is')
            oos_ic = horizon_results[bucket].get('oos')
            if is_ic is None or oos_ic is None:
                continue
            if (is_ic > 0.02 and oos_ic < -0.02) or (is_ic < -0.02 and oos_ic > 0.02):
                consistent = False

    n_obs = len(df[signal_col].dropna()) if signal_col in df.columns else 0

    return results, quintile_spread, hit_q5, n_obs, consistent

# Define all signals to test: (column, report title, report description).
# Distances are relative to the ES price, i.e. (strike - price) / price.
signal_defs = [
    # Signal 1: Call vs Put GEX
    ('sig1_net_gex', 'Signal 1a: Net GEX (call+put sum)', 'Positive = call-dominated GEX, negative = put-dominated'),
    ('sig1_call_put_ratio', 'Signal 1b: Call/Put GEX Ratio', 'Total GEX of positive-GEX strikes divided by |total GEX of negative-GEX strikes| (0 if no negative-GEX strikes)'),
    
    # Signal 2: Negative GEX Cluster
    ('sig2_neg_cluster_direction', 'Signal 2a: Neg GEX Cluster Direction', '1 if largest neg GEX is above price, -1 if below'),
    ('sig2_neg_cluster_distance', 'Signal 2b: Neg GEX Cluster Distance', 'Signed relative distance (fraction of price) from price to largest neg GEX cluster'),
    
    # Signal 3: Smart Money
    ('sig3_smart_gamma', 'Signal 3a: Smart Money Gamma (Firm+ProCust)', 'Combined firm + pro-customer gamma at current spot'),
    ('sig3_smart_delta', 'Signal 3b: Smart Money Delta (Firm+ProCust)', 'Combined firm + pro-customer delta at current spot'),
    ('sig3_firm_gamma', 'Signal 3c: Firm Gamma', 'Prop firm gamma at current spot'),
    ('sig3_procust_gamma', 'Signal 3d: ProCust Gamma', 'Professional customer gamma at current spot'),
    ('sig3_cust_gamma', 'Signal 3e: Retail Customer Gamma', 'Retail gamma at current spot (test as contrarian)'),
    
    # Signal 4: Per-Participant Gamma/Delta
    ('sig3_mm_gamma', 'Signal 4a: MM Gamma at Spot', 'Market maker gamma at current spot level'),
    ('sig3_bd_gamma', 'Signal 4b: Broker-Dealer Gamma at Spot', 'BD gamma at current spot'),
    ('sig3_mm_delta', 'Signal 4c: MM Delta at Spot', 'Market maker delta at current spot'),
    ('sig3_firm_delta', 'Signal 4d: Firm Delta at Spot', 'Prop firm delta at current spot'),
    ('sig3_procust_delta', 'Signal 4e: ProCust Delta at Spot', 'Pro-customer delta at current spot'),
    ('sig3_cust_delta', 'Signal 4f: Retail Customer Delta at Spot', 'Retail delta at current spot'),
    ('sig3_bd_delta', 'Signal 4g: BD Delta at Spot', 'Broker-dealer delta at current spot'),
    
    # Signal 4 changes (30-min momentum)
    ('sig3_firm_gamma_chg', 'Signal 4h: Firm Gamma 30m Change', '30-min change in firm gamma at spot'),
    ('sig3_procust_gamma_chg', 'Signal 4i: ProCust Gamma 30m Change', '30-min change in procust gamma'),
    ('sig3_mm_gamma_chg', 'Signal 4j: MM Gamma 30m Change', '30-min change in MM gamma at spot'),
    ('sig3_cust_gamma_chg', 'Signal 4k: Retail Gamma 30m Change', '30-min change in retail gamma'),
    
    # Signal 5: Gamma Asymmetry
    ('sig5_mm_gamma_asym', 'Signal 5a: MM Gamma Asymmetry', 'MM gamma above vs below spot. Positive = more gamma above'),
    ('sig5_firm_gamma_asym', 'Signal 5b: Firm Gamma Asymmetry', 'Firm gamma above vs below spot'),
    ('sig5_procust_gamma_asym', 'Signal 5c: ProCust Gamma Asymmetry', 'ProCust gamma above vs below spot'),
    ('sig5_cust_gamma_asym', 'Signal 5d: Retail Gamma Asymmetry', 'Retail gamma above vs below spot'),
    ('sig5_bd_gamma_asym', 'Signal 5e: BD Gamma Asymmetry', 'BD gamma above vs below spot'),
    
    # Signal 6: 0DTE
    ('sig6_0dte_ratio', 'Signal 6a: 0DTE MM Gamma Ratio', '0DTE share of total MM gamma'),
    ('sig6_0dte_divergence', 'Signal 6b: 0DTE vs Non-0DTE Divergence', '0DTE gamma minus non-0DTE gamma'),
    ('sig6_0dte_magnet_dir', 'Signal 6c: 0DTE Magnet Direction', 'Direction to largest 0DTE gamma concentration'),
    ('sig6_0dte_magnet_dist', 'Signal 6d: 0DTE Magnet Distance', 'Signed relative distance (fraction of price) to largest 0DTE gamma strike'),
    
    # Signal 7: Delta Flow
    ('sig7_smart_delta_chg', 'Signal 7a: Smart Money Delta Flow', '30-min change in firm+procust delta'),
    ('sig7_cust_delta_chg', 'Signal 7b: Retail Delta Flow', '30-min change in retail delta'),
    ('sig7_firm_delta_chg', 'Signal 7c: Firm Delta Flow', '30-min change in firm delta'),
    ('sig7_procust_delta_chg', 'Signal 7d: ProCust Delta Flow', '30-min change in procust delta'),
]

# Evaluate every defined signal present in df; print a one-line console summary.
results_list = []
for column, title, blurb in signal_defs:
    if column in df.columns:
        per_horizon, spread, hit_rate, obs, agrees = compute_signal_metrics(df, column)

        results_list.append({
            'name': title,
            'column': column,
            'description': blurb,
            'ic_1h': per_horizon.get('1h', {}),
            'ic_3h': per_horizon.get('3h', {}),
            'ic_eod': per_horizon.get('eod', {}),
            'quintile_spread_oos': round(spread * 10000, 2),  # in bps
            'hit_rate_q5_oos': round(hit_rate, 4),
            'n_obs': obs,
            'is_oos_consistent': agrees,
        })

        # Mean of the available OOS 1h bucket ICs (0 when none are available).
        bucket_ics = (per_horizon.get('1h', {}).get(b, {}).get('oos') for b in ['morning', 'midday', 'afternoon'])
        available_ics = [value for value in bucket_ics if value is not None]
        mean_oos_ic = np.mean(available_ics) if available_ics else 0
        mark = "✓" if agrees else "✗"
        print(f"  {mark} {title}: OOS 1H IC avg={mean_oos_ic:.4f}, Q-spread={spread*10000:.1f}bps, n={obs}")
    else:
        print(f"  SKIP {title} - column not found")

# ============================================================
# STEP 7: Participant Ranking
# ============================================================
print("\nParticipant Rankings...")

# Each participant is scored by the larger |IC| of its gamma and delta columns
# against the forward return, over all out-of-sample observations.
participant_signals = {
    who: (f'sig3_{who}_gamma', f'sig3_{who}_delta')
    for who in ('mm', 'firm', 'procust', 'cust', 'bd')
}

participant_ranking = {}
oos_data = df[~df['is_sample']]
for horizon in ['1h', '3h', 'eod']:
    target_col = f'ret_{horizon}'
    scored = []
    for who, candidate_cols in participant_signals.items():
        strongest = 0  # participants with no computable IC score 0
        for col in candidate_cols:
            if col not in oos_data.columns or target_col not in oos_data.columns:
                continue
            ic, _, _ = spearman_ic(oos_data[col].values.astype(float),
                                   oos_data[target_col].values.astype(float))
            if ic is not None:
                strongest = max(strongest, abs(ic))
        scored.append((who, strongest))

    # Stable sort: ties keep the participant_signals order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    participant_ranking[horizon] = [{'participant': who, 'abs_ic': round(score, 4)} for who, score in scored]
    print(f"  {horizon}: {' > '.join(f'{who}({score:.4f})' for who, score in scored)}")

# ============================================================
# STEP 8: Generate Summary
# ============================================================

# Rank all signals by |OOS quintile spread|, then keep the IS/OOS-consistent ones.
best_signals = sorted(results_list, key=lambda sig: abs(sig.get('quintile_spread_oos', 0)), reverse=True)
consistent_signals = [sig for sig in best_signals if sig['is_oos_consistent']]

summary_lines = ["## Key Findings\n"]

if not consistent_signals:
    summary_lines.append("No signals showed IS/OOS consistency. This means the tested signals are likely noise for intraday direction prediction.")
else:
    summary_lines.append(f"Of {len(results_list)} signals tested, {len(consistent_signals)} show IS/OOS consistency.\n")
    summary_lines.append("### Top Consistent Signals (by OOS quintile spread):")
    summary_lines.extend(
        f"- **{sig['name']}**: Q-spread={sig['quintile_spread_oos']}bps, Hit Q5={sig['hit_rate_q5_oos']:.1%}"
        for sig in consistent_signals[:5]
    )

summary_lines.append("\n### Participant Ranking (best predictor of direction):")
for h in ['1h', '3h', 'eod']:
    leader = participant_ranking[h][0]
    summary_lines.append(f"- **{h}**: {leader['participant']} (abs IC = {leader['abs_ic']:.4f})")

summary_text = '\n'.join(summary_lines)

# ============================================================
# STEP 9: Save outputs
# ============================================================
print("\nSaving results...")

run_metadata = {
    'n_days': processed_days,
    'n_observations': len(df),
    'is_days': len(IS_DATES),
    'oos_days': len(OOS_DATES),
    'date_range': f"{trace_dates[0]} to {trace_dates[-1]}",
}
output = {
    'signals': results_list,
    'participant_ranking': participant_ranking,
    'summary': summary_text,
    'metadata': run_metadata,
}

# default=str stringifies any value json cannot encode natively.
with open(OUTPUT_JSON, 'w') as json_fh:
    json.dump(output, json_fh, indent=2, default=str)
print(f"  Saved JSON: {OUTPUT_JSON}")

# Generate Markdown report: header, legend, then one IC table per signal.
md_lines = [
    "# TRACE Intraday Signal Analysis Report",
    f"\n**Date Range:** {trace_dates[0]} to {trace_dates[-1]}",
    f"**Days:** {processed_days} ({len(IS_DATES)} IS / {len(OOS_DATES)} OOS)",
    f"**Observations:** {len(df):,}",
    f"\n{summary_text}",
    "\n---\n",
    "## Detailed Signal Results\n",
    "### Legend",
    "- **IC**: Spearman rank correlation between signal and forward return",
    "- **Q-spread**: OOS 1H return difference between top and bottom signal quintiles (bps)",
    "- **Hit Q5**: Share of top-quintile OOS observations with a positive 1H return",
    "- **IS/OOS**: ✓ = no horizon/bucket where IS and OOS IC have opposite signs beyond ±0.02",
    "- **-**: IC not available (e.g. too few observations)\n",
]

for s in results_list:
    status = "✓" if s['is_oos_consistent'] else "✗"
    md_lines.append(f"### {status} {s['name']}")
    md_lines.append(f"*{s['description']}*\n")
    md_lines.append(f"N={s['n_obs']:,} | Q-spread={s['quintile_spread_oos']}bps | Hit Q5={s['hit_rate_q5_oos']:.1%} | Consistent={s['is_oos_consistent']}\n")
    
    # IC table
    md_lines.append("| Horizon | Morning IS | Morning OOS | Midday IS | Midday OOS | Afternoon IS | Afternoon OOS |")
    md_lines.append("|---------|-----------|-------------|-----------|------------|--------------|---------------|")
    
    for horizon_key, horizon_label in [('ic_1h', '1H'), ('ic_3h', '3H'), ('ic_eod', 'EOD')]:
        h_data = s.get(horizon_key, {})
        cells = [horizon_label]
        for bucket in ['morning', 'midday', 'afternoon']:
            b_data = h_data.get(bucket, {})
            for sample in ('is', 'oos'):
                val = b_data.get(sample)
                # Uncomputable ICs are stored as None; print '-' rather than
                # the literal "None".
                cells.append(f"{val:.4f}" if isinstance(val, (int, float)) else '-')
        md_lines.append(f"| {' | '.join(cells)} |")
    
    md_lines.append("")

# Participant ranking section: one numbered list per forward horizon.
md_lines.extend([
    "\n## Participant Ranking\n",
    "Which participant type's positioning best predicts direction?\n",
])
for h in ['1h', '3h', 'eod']:
    md_lines.append(f"### {h.upper()} Forward Return")
    md_lines.extend(
        f"{rank}. **{entry['participant']}** — abs IC = {entry['abs_ic']:.4f}"
        for rank, entry in enumerate(participant_ranking[h], start=1)
    )
    md_lines.append("")

# Answers to Daniel's questions (Q1-Q4), derived from the signal results.
md_lines.append("\n## Answers to Daniel's Questions\n")


def _abs_spread(sig):
    """Sort key: magnitude of a signal's OOS quintile spread (bps)."""
    return abs(sig.get('quintile_spread_oos', 0))


def _score_line(sig):
    """One markdown line with a signal's OOS quintile spread and consistency flag."""
    return f"**{sig['name']}**: Q-spread={sig['quintile_spread_oos']}bps, Consistent={sig['is_oos_consistent']}"


# Q1: strongest Signal 1 variant; under 5 bps of |spread| counts as weak.
md_lines.append("### Q1: Does price follow the largest side (call vs put GEX)?")
sig1_results = [s for s in results_list if s['column'].startswith('sig1_')]
if sig1_results:
    best_sig1 = max(sig1_results, key=_abs_spread)
    md_lines.append(_score_line(best_sig1))
    if abs(best_sig1['quintile_spread_oos']) >= 5:
        md_lines.append("→ **Signal detected.** The call/put GEX split has some predictive value.\n")
    else:
        md_lines.append("→ **Weak/No signal.** Call vs put GEX imbalance does not reliably predict intraday direction.\n")

# Q2: magnet vs repulsion, judged by the morning OOS 1h IC of the signed distance.
md_lines.append("### Q2: Does price follow clusters of high negative GEX?")
sig2_results = [s for s in results_list if s['column'].startswith('sig2_')]
if sig2_results:
    best_sig2 = max(sig2_results, key=_abs_spread)
    md_lines.append(_score_line(best_sig2))
    sig2_dir = {s['column']: s for s in sig2_results}.get('sig2_neg_cluster_distance')
    if sig2_dir:
        oos_1h_morning = sig2_dir.get('ic_1h', {}).get('morning', {}).get('oos')
        if oos_1h_morning is not None and oos_1h_morning > 0.02:
            verdict = "→ **Magnet effect detected.** Price tends to move TOWARD negative GEX clusters.\n"
        elif oos_1h_morning is not None and oos_1h_morning < -0.02:
            verdict = "→ **Repulsion detected.** Price tends to move AWAY from negative GEX clusters.\n"
        else:
            verdict = "→ **No clear magnet/repulsion effect intraday.**\n"
        md_lines.append(verdict)

# Q3: combined firm + pro-customer gamma/delta signals.
md_lines.append("### Q3: Does price follow smart money (firm + pro-customer) flow?")
sig3_results = [s for s in results_list if s['column'] in ['sig3_smart_gamma', 'sig3_smart_delta']]
if sig3_results:
    md_lines.extend(_score_line(sr) for sr in sig3_results)
    best_sig3 = max(sig3_results, key=_abs_spread)
    if abs(best_sig3['quintile_spread_oos']) >= 5:
        md_lines.append(f"→ **Signal present** in smart money positioning.\n")
    else:
        md_lines.append("→ **Weak/No signal.** Smart money positioning is not a reliable intraday predictor.\n")

# Q4: top-ranked participant for the two intraday horizons.
md_lines.append("### Q4: Which participant type best predicts price?")
for h in ['1h', '3h']:
    top = participant_ranking[h][0]
    md_lines.append(f"- **{h}**: {top['participant']} (abs IC = {top['abs_ic']:.4f})")
md_lines.append("")

# The report contains non-ASCII characters (✓, ✗, →, —, ±). Write UTF-8
# explicitly so the output does not depend on the platform's default encoding,
# which can raise UnicodeEncodeError on non-UTF-8 locales.
with open(OUTPUT_MD, 'w', encoding='utf-8') as f:
    f.write('\n'.join(md_lines))
print(f"  Saved MD: {OUTPUT_MD}")

print("\n✅ Analysis complete!")
