#!/usr/bin/env python3
"""
TRACE Participant Positioning vs Market Direction Backtest
Comprehensive analysis of who's directionally correct and who makes money.
"""

import pandas as pd
import numpy as np
import json
import os
import warnings
from scipy import stats
from datetime import datetime, timedelta

warnings.filterwarnings('ignore')

# ─── Configuration ───
# All inputs and outputs live under a single workspace data directory.
DATA_DIR = '/Users/lutherbot/.openclaw/workspace/data'
TRACE_HIST_DIR = f'{DATA_DIR}/trace_uncorrupted'   # historical GEX parquet files (one file per day)
TRACE_LIVE_DIR = f'{DATA_DIR}/trace_live/daily'    # live capture: one directory per day, many snapshot files
SPX_FILE = f'{DATA_DIR}/spx_5min_polygon.csv'      # SPX 5-minute bars
FOMC_FILE = f'{DATA_DIR}/fomc_dates.json'          # FOMC meeting dates to exclude from the backtest
OUTPUT_MD = f'{DATA_DIR}/participant_backtest_report.md'
OUTPUT_JSON = f'{DATA_DIR}/participant_backtest_results.json'

# TRACE participant categories and their display names, used throughout.
PARTICIPANTS = ['bd', 'cust', 'firm', 'mm', 'procust']
PART_NAMES = {'bd': 'Broker-Dealer', 'cust': 'Customer', 'firm': 'Firm', 
              'mm': 'Market Maker', 'procust': 'Pro Customer'}

# Corruption period
# Days in [CORRUPT_START, CORRUPT_END] are flagged below and excluded from the
# main analysis (presumably a known bad-data window in the feed — confirm upstream).
CORRUPT_START = '2025-10-27'
CORRUPT_END = '2026-02-17'

# ─── Load FOMC dates ───
# FOMC days are skipped entirely when loading TRACE data (event-day regime differs).
with open(FOMC_FILE) as f:
    fomc_dates = set(json.load(f)['dates'])

print("Loading SPX 5-min data...")
spx = pd.read_csv(SPX_FILE)
spx['datetime'] = pd.to_datetime(spx['datetime'])
spx['date'] = spx['datetime'].dt.strftime('%Y-%m-%d')  # calendar-day key used to join with TRACE
spx = spx.sort_values('datetime').reset_index(drop=True)

# Precompute daily reference prices from SPX
# daily_prices[date] -> {'open', 'close', 'bars', 'price_1h', 'price_3h'}
#   price_1h / price_3h: close of the last 5-min bar at or before open + 60/180 min
#   (None when the session has no bars in that window).
print("Computing daily price references...")
daily_prices = {}
for date, group in spx.groupby('date'):
    g = group.sort_values('datetime')
    # Times in UTC from the data - convert to ET
    # Data appears to be UTC: 13:30 UTC = 9:30 ET
    # NOTE(review): no tz conversion actually happens here — the offsets below
    # are taken relative to the day's first bar, so this works either way.
    daily_prices[date] = {
        'open': g.iloc[0]['open'],
        'close': g.iloc[-1]['close'],
        'bars': g[['datetime', 'open', 'high', 'low', 'close']].reset_index(drop=True)
    }
    # Find price at specific offsets from open
    open_time = g.iloc[0]['datetime']
    for offset_name, offset_min in [('1h', 60), ('3h', 180)]:
        target_time = open_time + timedelta(minutes=offset_min)
        mask = g['datetime'] <= target_time
        if mask.any():
            daily_prices[date][f'price_{offset_name}'] = g[mask].iloc[-1]['close']
        else:
            daily_prices[date][f'price_{offset_name}'] = None

# Build next-day map
# next_day_map[d] = the following trading day (used for next-day returns/gaps).
sorted_dates = sorted(daily_prices.keys())
next_day_map = {}
for i, d in enumerate(sorted_dates):
    if i + 1 < len(sorted_dates):
        next_day_map[d] = sorted_dates[i + 1]

print(f"SPX: {len(daily_prices)} trading days, {sorted_dates[0]} to {sorted_dates[-1]}")

# ─── Load TRACE data ───
print("Loading TRACE historical data...")

def load_trace_day_hist(filepath, date_str):
    """Extract morning and EOD snapshot features from one historical TRACE GEX file.

    The earliest timestamp in the file is treated as the "morning" snapshot and
    the latest as "EOD". Returns a flat dict keyed like 'morning_mm_net_gamma',
    'eod_firm_tilt', etc., or None when the file holds no rows.
    """
    frame = pd.read_parquet(filepath)
    if frame.empty:
        return None

    # First/last unique timestamps delimit the two snapshots.
    ts_order = sorted(frame['timestamp'].unique())
    snapshots = [
        ('morning', frame[frame['timestamp'] == ts_order[0]].copy()),
        ('eod', frame[frame['timestamp'] == ts_order[-1]].copy()),
    ]

    features = {'date': date_str}

    for label, snap in snapshots:
        if snap.empty:
            continue

        strikes = snap['strike_price'].values

        # Spot reference: SPX open when available, otherwise fall back to the
        # strike carrying the largest absolute total gamma across participants.
        if date_str in daily_prices:
            spot = daily_prices[date_str]['open']
        else:
            total_gamma = sum(snap[f'{p}_gamma'].values + snap[f'{p}_gamma_0'].values for p in PARTICIPANTS)
            spot = strikes[np.argmax(np.abs(total_gamma))]

        features[f'{label}_spot'] = spot

        for p in PARTICIPANTS:
            non0 = snap[f'{p}_gamma'].values    # non-0DTE gamma per strike
            zero = snap[f'{p}_gamma_0'].values  # 0DTE gamma per strike
            combined = non0 + zero

            # Net gamma summed across all strikes (total / non-0DTE / 0DTE).
            features[f'{label}_{p}_net_gamma'] = combined.sum()
            features[f'{label}_{p}_net_gamma_non0'] = non0.sum()
            features[f'{label}_{p}_net_gamma_0dte'] = zero.sum()

            above = strikes > spot
            below = strikes < spot

            # Absolute tilt: share of |gamma| sitting above spot (0.5 = balanced).
            abs_above = np.abs(combined[above]).sum()
            abs_below = np.abs(combined[below]).sum()
            abs_total = abs_above + abs_below
            features[f'{label}_{p}_tilt'] = abs_above / abs_total if abs_total > 0 else 0.5

            # Signed tilt: signed gamma above spot minus signed gamma below it.
            features[f'{label}_{p}_signed_tilt'] = combined[above].sum() - combined[below].sum()

    return features


def load_trace_day_live(day_dir, date_str):
    """Load a live TRACE day (directory of snapshot files) and extract
    morning/EOD features.

    Uses the first/last intradayStrikeGEX_* files as the morning/EOD gamma
    snapshots, and the first/last intradayDelta_* files (when present) for
    spot prices plus per-participant net delta at spot.

    Returns a flat feature dict shaped like load_trace_day_hist's output,
    or None when the day has no GEX files.
    """
    gex_files = sorted([f for f in os.listdir(day_dir) if f.startswith('intradayStrikeGEX_')])
    delta_files = sorted([f for f in os.listdir(day_dir) if f.startswith('intradayDelta_')])
    
    if not gex_files:
        return None
    
    morning_gex = pd.read_parquet(os.path.join(day_dir, gex_files[0]))
    eod_gex = pd.read_parquet(os.path.join(day_dir, gex_files[-1]))
    
    result = {'date': date_str}
    
    # ── Spot + net delta from the delta snapshots, when available ──
    if delta_files:
        morning_delta = pd.read_parquet(os.path.join(day_dir, delta_files[0]))
        eod_delta = pd.read_parquet(os.path.join(day_dir, delta_files[-1]))
        if 'spot' in morning_delta.columns:
            morning_spot = morning_delta['spot'].median()
            result['morning_spot'] = morning_spot
            result['eod_spot'] = eod_delta['spot'].median() if not eod_delta.empty else morning_spot
            
            # Extract delta at spot for each participant
            for snapshot_name, delta_df in [('morning', morning_delta), ('eod', eod_delta)]:
                # BUG FIX: idxmin() raises ValueError on an empty frame (possible
                # for a truncated EOD file); the original only guarded the spot
                # median above, not this loop. Also guard a missing 'spot' column.
                if delta_df.empty or 'spot' not in delta_df.columns:
                    continue
                spot_val = result[f'{snapshot_name}_spot']
                # Row whose recorded spot is closest to this snapshot's spot.
                closest_idx = (delta_df['spot'] - spot_val).abs().idxmin()
                row = delta_df.loc[closest_idx]
                for p in PARTICIPANTS:
                    col = f'{p}_delta'
                    if col in delta_df.columns:
                        result[f'{snapshot_name}_{p}_net_delta'] = row[col]
    
    # ── Gamma features from the GEX snapshots (same schema as historical) ──
    for snapshot_name, snap_df in [('morning', morning_gex), ('eod', eod_gex)]:
        if snap_df.empty:
            continue
        
        # Live files may hold several timestamps; use only the first one.
        ts = snap_df['timestamp'].iloc[0]
        snap = snap_df[snap_df['timestamp'] == ts]
        strikes = snap['strike_price'].values
        
        # Fallback spot when the delta files didn't provide one:
        # SPX open if known, otherwise the middle strike.
        if f'{snapshot_name}_spot' not in result:
            if date_str in daily_prices:
                result[f'{snapshot_name}_spot'] = daily_prices[date_str]['open']
            else:
                result[f'{snapshot_name}_spot'] = strikes[len(strikes)//2]
        
        spot = result[f'{snapshot_name}_spot']
        
        for p in PARTICIPANTS:
            # Total gamma = non-0DTE + 0DTE, per strike.
            gamma_total = snap[f'{p}_gamma'].values + snap[f'{p}_gamma_0'].values
            gamma_non0 = snap[f'{p}_gamma'].values
            gamma_0dte = snap[f'{p}_gamma_0'].values
            
            result[f'{snapshot_name}_{p}_net_gamma'] = gamma_total.sum()
            result[f'{snapshot_name}_{p}_net_gamma_non0'] = gamma_non0.sum()
            result[f'{snapshot_name}_{p}_net_gamma_0dte'] = gamma_0dte.sum()
            
            above_mask = strikes > spot
            below_mask = strikes < spot
            
            # Absolute tilt: share of |gamma| above spot (0.5 = balanced/no data).
            gamma_above = np.abs(gamma_total[above_mask]).sum()
            gamma_below = np.abs(gamma_total[below_mask]).sum()
            total_abs = gamma_above + gamma_below
            
            if total_abs > 0:
                result[f'{snapshot_name}_{p}_tilt'] = gamma_above / total_abs
            else:
                result[f'{snapshot_name}_{p}_tilt'] = 0.5
            
            # Signed tilt: signed gamma above minus below (directional lean).
            gamma_above_signed = gamma_total[above_mask].sum()
            gamma_below_signed = gamma_total[below_mask].sum()
            result[f'{snapshot_name}_{p}_signed_tilt'] = gamma_above_signed - gamma_below_signed
    
    return result


# Load all historical data
# Historical files are named intradayStrikeGEX_<YYYY-MM-DD>.parquet; FOMC days
# are skipped, and per-file failures are logged but do not abort the run.
all_days = []
hist_files = sorted([f for f in os.listdir(TRACE_HIST_DIR) if f.endswith('.parquet')])
print(f"Processing {len(hist_files)} historical TRACE files...")

for i, fname in enumerate(hist_files):
    date_str = fname.replace('intradayStrikeGEX_', '').replace('.parquet', '')
    
    if date_str in fomc_dates:
        continue
    
    filepath = os.path.join(TRACE_HIST_DIR, fname)
    try:
        result = load_trace_day_hist(filepath, date_str)
        if result:
            all_days.append(result)
    except Exception as e:
        print(f"  Error on {date_str}: {e}")
    
    if (i + 1) % 50 == 0:
        print(f"  Processed {i+1}/{len(hist_files)} historical files...")

# Load live data
# Live capture layout: one directory per date under TRACE_LIVE_DIR.
live_dates = sorted([d for d in os.listdir(TRACE_LIVE_DIR) if os.path.isdir(os.path.join(TRACE_LIVE_DIR, d))])
print(f"Processing {len(live_dates)} live TRACE directories...")
for date_str in live_dates:
    if date_str in fomc_dates:
        continue
    day_dir = os.path.join(TRACE_LIVE_DIR, date_str)
    try:
        result = load_trace_day_live(day_dir, date_str)
        if result:
            all_days.append(result)
    except Exception as e:
        print(f"  Error on {date_str}: {e}")

# One row per trading day, sorted chronologically.
df = pd.DataFrame(all_days)
df = df.sort_values('date').reset_index(drop=True)
print(f"\nTotal days loaded: {len(df)}")
print(f"Date range: {df['date'].min()} to {df['date'].max()}")

# ─── Compute forward returns ───
print("Computing forward returns...")

# For each TRACE day, derive forward returns and daily-move stats from the
# SPX reference prices. Each dict's keys become columns after the merge below.
returns = {}
for _, row in df.iterrows():
    date = row['date']
    if date not in daily_prices:
        continue
    
    dp = daily_prices[date]
    # (the original also fetched morning_spot here, but never used it)
    ret = {'date': date}
    
    # Intraday returns measured from the session open.
    if dp.get('price_1h') is not None:
        ret['ret_1h'] = (dp['price_1h'] - dp['open']) / dp['open']
    if dp.get('price_3h') is not None:
        ret['ret_3h'] = (dp['price_3h'] - dp['open']) / dp['open']
    ret['ret_eod'] = (dp['close'] - dp['open']) / dp['open']
    
    # Next-day returns and overnight gap — single lookup (the original
    # resolved the next trading day twice in two separate blocks).
    nd = next_day_map.get(date)
    if nd is not None and nd in daily_prices:
        ndp = daily_prices[nd]
        ret['ret_next_open'] = (ndp['open'] - dp['close']) / dp['close']
        ret['ret_next_close'] = (ndp['close'] - dp['close']) / dp['close']
        # NOTE: identical to ret_next_open by construction; kept so downstream
        # consumers of the column keep working.
        ret['next_gap_pct'] = (ndp['open'] - dp['close']) / dp['close']
    
    # Daily move magnitude (for the gamma P&L calculation).
    ret['daily_move'] = dp['close'] - dp['open']
    ret['daily_move_sq'] = (dp['close'] - dp['open']) ** 2
    ret['daily_range'] = dp['bars']['high'].max() - dp['bars']['low'].min()
    
    returns[date] = ret

returns_df = pd.DataFrame(returns.values())
print(f"Returns computed for {len(returns_df)} days")

# Merge
# Inner join keeps only days that have both TRACE features and SPX prices.
df = df.merge(returns_df, on='date', how='inner')
print(f"Merged dataset: {len(df)} days")

# Flag corruption period
# (string comparison is safe here: dates are zero-padded ISO strings)
df['is_corrupt'] = (df['date'] >= CORRUPT_START) & (df['date'] <= CORRUPT_END)
print(f"Corrupt period days: {df['is_corrupt'].sum()}")

# Mark day of week
df['dow'] = pd.to_datetime(df['date']).dt.dayofweek  # 0=Mon
df['dow_name'] = pd.to_datetime(df['date']).dt.day_name()

# ─── Exclude corrupt period for main analysis ───
df_clean = df[~df['is_corrupt']].copy()
print(f"Clean dataset (excl corrupt): {len(df_clean)} days")

# IS/OOS split (60/40)
# Chronological split (df is already date-sorted), so the OOS segment is
# strictly later in time than IS — no lookahead across the boundary.
n = len(df_clean)
is_cutoff = int(n * 0.6)
df_clean['is_oos'] = ['IS'] * is_cutoff + ['OOS'] * (n - is_cutoff)
is_date = df_clean.iloc[is_cutoff-1]['date']
oos_start = df_clean.iloc[is_cutoff]['date']
print(f"IS: {is_cutoff} days ({df_clean.iloc[0]['date']} to {is_date})")
print(f"OOS: {n - is_cutoff} days ({oos_start} to {df_clean.iloc[-1]['date']})")

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 1: Directional Accuracy by Participant (IC)
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("ANALYSIS 1: Directional Accuracy (IC)")
print("="*60)

# Forward-return horizons and the positioning metrics scored against them.
horizons = ['ret_1h', 'ret_3h', 'ret_eod', 'ret_next_close']
horizon_names = {'ret_1h': '1H', 'ret_3h': '3H', 'ret_eod': 'EOD', 'ret_next_close': 'NextDay'}

metrics = ['net_gamma', 'tilt', 'signed_tilt']
metric_labels = {'net_gamma': 'Net Gamma', 'tilt': 'Abs Tilt', 'signed_tilt': 'Signed Tilt'}

# ic_results[sample][participant][metric][horizon] ->
#   {'ic': Spearman rho, 'pval', 'n', 'tstat'}; cells with < 20 valid
#   observations are omitted entirely.
ic_results = {}

for sample_name, sample_df in [('IS', df_clean[df_clean['is_oos']=='IS']), 
                                 ('OOS', df_clean[df_clean['is_oos']=='OOS']),
                                 ('ALL', df_clean)]:
    ic_results[sample_name] = {}
    for p in PARTICIPANTS:
        ic_results[sample_name][p] = {}
        for metric in metrics:
            ic_results[sample_name][p][metric] = {}
            col = f'morning_{p}_{metric}'
            if col not in sample_df.columns:
                continue
            for horizon in horizons:
                if horizon not in sample_df.columns:
                    continue
                valid = sample_df[[col, horizon]].dropna()
                if len(valid) < 20:
                    continue
                ic, pval = stats.spearmanr(valid[col], valid[horizon])
                # tstat: t-statistic of the rank correlation; the |ic| = 1
                # branch guards the division by zero in the degenerate case.
                ic_results[sample_name][p][metric][horizon] = {
                    'ic': round(ic, 4),
                    'pval': round(pval, 4),
                    'n': len(valid),
                    'tstat': round(ic * np.sqrt(len(valid) - 2) / np.sqrt(1 - ic**2), 2) if abs(ic) < 1 else 0
                }

# Print summary tables — one renderer for both metrics (the original
# duplicated this ~15-line table loop verbatim for signed_tilt and net_gamma).
def _print_ic_table(metric, title):
    """Print the ALL-sample IC summary table for one positioning metric."""
    print(f"\nIC Summary ({title} → Forward Returns) — ALL sample:")
    print(f"{'Participant':<20} {'1H':>8} {'3H':>8} {'EOD':>8} {'NextDay':>8}")
    print("-" * 56)
    for p in PARTICIPANTS:
        row = f"{PART_NAMES[p]:<20}"
        for h in horizons:
            val = ic_results['ALL'].get(p, {}).get(metric, {}).get(h, {})
            if val:
                sig = '*' if val['pval'] < 0.05 else ''  # flag p < 0.05
                row += f" {val['ic']:>7.3f}{sig}"
            else:
                row += f" {'N/A':>8}"
        print(row)

_print_ic_table('signed_tilt', 'Signed Tilt')
_print_ic_table('net_gamma', 'Net Gamma')

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 2: Contrarian vs Directional Classification
# ═══════════════════════════════════════════════════════════════════
# A participant/metric/horizon cell is DIRECTIONAL when its IC is positive in
# both IS and OOS, CONTRARIAN when negative in both, MIXED otherwise.
print("\n" + "="*60)
print("ANALYSIS 2: Contrarian vs Directional")
print("="*60)

for p in PARTICIPANTS:
    print(f"\n{PART_NAMES[p]}:")
    for metric in ['signed_tilt', 'net_gamma']:
        col = f'morning_{p}_{metric}'
        if col not in df_clean.columns:
            continue
        for h in horizons:
            if h not in df_clean.columns:
                continue
            valid = df_clean[[col, h, 'is_oos']].dropna()
            is_data = valid[valid['is_oos'] == 'IS']
            oos_data = valid[valid['is_oos'] == 'OOS']
            
            # Require > 20 observations per side; otherwise NaN suppresses the row.
            is_ic = stats.spearmanr(is_data[col], is_data[h])[0] if len(is_data) > 20 else np.nan
            oos_ic = stats.spearmanr(oos_data[col], oos_data[h])[0] if len(oos_data) > 20 else np.nan
            
            if not np.isnan(is_ic) and not np.isnan(oos_ic):
                direction = "DIRECTIONAL" if is_ic > 0 and oos_ic > 0 else \
                           "CONTRARIAN" if is_ic < 0 and oos_ic < 0 else "MIXED"
                print(f"  {metric_labels[metric]:>15} → {horizon_names[h]:>7}: IS={is_ic:+.3f} OOS={oos_ic:+.3f} [{direction}]")

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 3: Participant Agreement/Divergence
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("ANALYSIS 3: Participant Agreement/Divergence")
print("="*60)

# Define directional lean using signed tilt
# A participant counts as "bullish" when more signed gamma sits above spot.
for p in PARTICIPANTS:
    col = f'morning_{p}_signed_tilt'
    if col in df_clean.columns:
        df_clean[f'{p}_bullish'] = df_clean[col] > 0

# Pro + Firm agreement
# NOTE(review): assumes the *_bullish columns were actually created above
# (i.e. the signed-tilt columns exist) — raises KeyError otherwise.
agree_mask = df_clean['procust_bullish'] == df_clean['firm_bullish']
disagree_mask = ~agree_mask

print(f"\nPro Customer + Firm AGREE: {agree_mask.sum()} days")
print(f"Pro Customer + Firm DISAGREE: {disagree_mask.sum()} days")

# Mean return and win rate (share of positive returns) per agreement bucket.
for label, mask in [('AGREE (both bullish)', df_clean['procust_bullish'] & df_clean['firm_bullish']),
                     ('AGREE (both bearish)', ~df_clean['procust_bullish'] & ~df_clean['firm_bullish']),
                     ('DISAGREE', disagree_mask)]:
    subset = df_clean[mask]
    if len(subset) < 10:
        continue
    print(f"\n  {label} ({len(subset)} days):")
    for h in horizons:
        if h in subset.columns:
            valid = subset[h].dropna()
            if len(valid) > 5:
                mean_ret = valid.mean() * 100
                wr = (valid > 0).mean() * 100
                print(f"    {horizon_names[h]:>7}: mean={mean_ret:+.3f}% WR={wr:.1f}%")

# MM + Customer agreement
print("\n\nMM + Customer agreement:")
mm_cust_agree = df_clean['mm_bullish'] == df_clean['cust_bullish']
for label, mask in [('MM+Cust AGREE bullish', df_clean['mm_bullish'] & df_clean['cust_bullish']),
                     ('MM+Cust AGREE bearish', ~df_clean['mm_bullish'] & ~df_clean['cust_bullish']),
                     ('MM+Cust DISAGREE', ~mm_cust_agree)]:
    subset = df_clean[mask]
    if len(subset) < 10:
        continue
    print(f"\n  {label} ({len(subset)} days):")
    for h in horizons:
        if h in subset.columns:
            valid = subset[h].dropna()
            if len(valid) > 5:
                mean_ret = valid.mean() * 100
                wr = (valid > 0).mean() * 100
                print(f"    {horizon_names[h]:>7}: mean={mean_ret:+.3f}% WR={wr:.1f}%")

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 4: 0DTE vs Non-0DTE
# ═══════════════════════════════════════════════════════════════════
# For each participant, compare the predictive power (Spearman IC) of the
# 0DTE, non-0DTE, and total net-gamma exposures across all horizons.
rule = "=" * 60
print(f"\n{rule}")
print("ANALYSIS 4: 0DTE vs Non-0DTE Positioning")
print(rule)

gamma_variants = [('net_gamma_0dte', '0DTE'), ('net_gamma_non0', 'Non-0DTE'), ('net_gamma', 'Total')]
for part in PARTICIPANTS:
    print(f"\n{PART_NAMES[part]}:")
    for suffix, label in gamma_variants:
        feature = f'morning_{part}_{suffix}'
        if feature not in df_clean.columns:
            continue
        # Collect one "<horizon>=<ic>" fragment per horizon with enough data.
        pieces = []
        for hz in horizons:
            if hz not in df_clean.columns:
                continue
            pair = df_clean[[feature, hz]].dropna()
            if len(pair) > 20:
                rho = stats.spearmanr(pair[feature], pair[hz])[0]
                pieces.append(f"{horizon_names[hz]}={rho:+.3f}")
        if pieces:
            print(f"  {label:>8}: {', '.join(pieces)}")

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 5: Profitability Estimation
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("ANALYSIS 5: Profitability Estimation")
print("="*60)

# Stylized P&L model: P&L ≈ 0.5 * net_gamma * move² + delta * move.
# Historical files carry no delta, so the directional leg is proxied by the
# signed gamma tilt. Both exposures are normalized by their own std so the
# numbers are comparable across participants (normalized units, not dollars).
pnl_results = {}
for p in PARTICIPANTS:
    gamma_col = f'morning_{p}_net_gamma'
    if gamma_col not in df_clean.columns:
        continue
    
    valid = df_clean[[gamma_col, 'daily_move', 'daily_move_sq', 'date']].dropna()
    
    # Gamma P&L: long gamma profits from big moves in either direction.
    gamma_std = valid[gamma_col].std()
    if gamma_std == 0:
        continue
    gamma_norm = valid[gamma_col] / gamma_std
    gamma_pnl = 0.5 * gamma_norm * valid['daily_move_sq']
    
    # Directional P&L from signed tilt.
    # BUG FIX: the original tested `tilt_col in valid.columns`, but `valid`
    # only contains the gamma/move columns, so that branch never ran and the
    # directional leg was silently always zero.
    tilt_col = f'morning_{p}_signed_tilt'
    directional_pnl = pd.Series(0.0, index=valid.index)
    if tilt_col in df_clean.columns:
        tilt = df_clean.loc[valid.index, tilt_col]
        tilt_std = tilt.std()
        if tilt_std > 0:  # also rejects NaN std (all-NaN tilt)
            directional_pnl = (tilt / tilt_std) * valid['daily_move']
    
    total_pnl = gamma_pnl + directional_pnl
    pnl_std = total_pnl.std()
    
    pnl_results[p] = {
        'name': PART_NAMES[p],
        'gamma_pnl_total': float(gamma_pnl.sum()),
        'gamma_pnl_mean': float(gamma_pnl.mean()),
        'directional_pnl_total': float(directional_pnl.sum()),
        'directional_pnl_mean': float(directional_pnl.mean()),
        'total_pnl': float(total_pnl.sum()),
        'total_pnl_mean': float(total_pnl.mean()),
        # Annualized Sharpe of the daily P&L series.
        'sharpe': float(total_pnl.mean() / pnl_std * np.sqrt(252)) if pnl_std > 0 else 0,
        'n_days': len(valid),
        # How consistently the participant is long (+1) vs short (-1) gamma.
        'avg_gamma_sign': float(np.sign(valid[gamma_col]).mean()),
    }

print(f"\n{'Participant':<20} {'Gamma P&L':>12} {'Dir P&L':>12} {'Total P&L':>12} {'Sharpe':>8} {'Avg Sign':>10}")
print("-" * 80)
for p in sorted(pnl_results.keys(), key=lambda x: pnl_results[x]['total_pnl'], reverse=True):
    r = pnl_results[p]
    print(f"{r['name']:<20} {r['gamma_pnl_total']:>12.1f} {r['directional_pnl_total']:>12.1f} {r['total_pnl']:>12.1f} {r['sharpe']:>8.2f} {r['avg_gamma_sign']:>10.3f}")

# ═══════════════════════════════════════════════════════════════════
# ANALYSIS 6: Conditional Analysis
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("ANALYSIS 6: Conditional Analysis")
print("="*60)

# 6a: By GEX regime (positive vs negative total net gamma)
# Total market-wide gamma = sum of all participants' morning net gamma.
total_net_gamma = sum(df_clean[f'morning_{p}_net_gamma'] for p in PARTICIPANTS if f'morning_{p}_net_gamma' in df_clean.columns)
df_clean['total_net_gamma'] = total_net_gamma
df_clean['gex_regime'] = np.where(df_clean['total_net_gamma'] > 0, 'Positive GEX', 'Negative GEX')

conditional_results = {}

print("\n--- By GEX Regime ---")
for regime in ['Positive GEX', 'Negative GEX']:
    subset = df_clean[df_clean['gex_regime'] == regime]
    print(f"\n{regime} ({len(subset)} days):")
    conditional_results[regime] = {}
    for p in PARTICIPANTS:
        col = f'morning_{p}_signed_tilt'
        if col not in subset.columns:
            continue
        for h in ['ret_1h', 'ret_eod']:
            valid = subset[[col, h]].dropna()
            if len(valid) > 15:
                ic = stats.spearmanr(valid[col], valid[h])[0]
                conditional_results[regime][f'{p}_{h}'] = round(ic, 4)
                print(f"  {PART_NAMES[p]:>15} → {horizon_names[h]:>5}: IC={ic:+.3f} (n={len(valid)})")

# 6b: By day of week
print("\n--- By Day of Week ---")
dow_results = {}
for dow_name in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']:
    subset = df_clean[df_clean['dow_name'] == dow_name]
    if len(subset) < 15:
        continue
    print(f"\n{dow_name} ({len(subset)} days):")
    dow_results[dow_name] = {}
    for p in ['mm', 'procust', 'firm']:  # only a subset of participants reported here
        col = f'morning_{p}_signed_tilt'
        if col not in subset.columns:
            continue
        for h in ['ret_1h', 'ret_eod']:
            valid = subset[[col, h]].dropna()
            if len(valid) > 10:
                ic = stats.spearmanr(valid[col], valid[h])[0]
                dow_results[dow_name][f'{p}_{h}'] = round(ic, 4)
                print(f"  {PART_NAMES[p]:>15} → {horizon_names[h]:>5}: IC={ic:+.3f}")

# 6c: By gap direction
# NOTE(review): despite the name, this buckets by the *first-hour return*
# (±0.1% threshold), not the overnight gap — next_gap_pct exists but is unused.
df_clean['gap_direction'] = np.where(df_clean['ret_1h'] > 0.001, 'Gap Up', 
                              np.where(df_clean['ret_1h'] < -0.001, 'Gap Down', 'Flat'))

print("\n--- By Gap Direction (using 1H return as proxy) ---")
gap_results = {}
for gap in ['Gap Up', 'Gap Down', 'Flat']:
    subset = df_clean[df_clean['gap_direction'] == gap]
    if len(subset) < 15:
        continue
    print(f"\n{gap} ({len(subset)} days):")
    gap_results[gap] = {}
    for p in ['mm', 'procust', 'firm']:
        col = f'morning_{p}_signed_tilt'
        if col not in subset.columns:
            continue
        for h in ['ret_eod', 'ret_next_close']:
            valid = subset[[col, h]].dropna()
            if len(valid) > 10:
                ic = stats.spearmanr(valid[col], valid[h])[0]
                gap_results[gap][f'{p}_{h}'] = round(ic, 4)
                print(f"  {PART_NAMES[p]:>15} → {horizon_names[h]:>7}: IC={ic:+.3f}")

# ═══════════════════════════════════════════════════════════════════
# Additional: Win Rate by Quintile
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("QUINTILE ANALYSIS: MM Signed Tilt")
print("="*60)

QUINTILE_LABELS = ['Q1(bear)', 'Q2', 'Q3', 'Q4', 'Q5(bull)']

def _quintile_tables(col):
    """Print IS/OOS quintile tables (mean return + win rate) for one column.

    The original duplicated this ~17-line section verbatim for the MM and
    ProCust tilt columns; output is unchanged.
    """
    for sample_name, sample_df in [('IS', df_clean[df_clean['is_oos'] == 'IS']),
                                   ('OOS', df_clean[df_clean['is_oos'] == 'OOS'])]:
        # Sample header is printed even when the table is skipped (as before).
        print(f"\n{sample_name}:")
        if col not in sample_df.columns:
            continue
        valid = sample_df[[col, 'ret_1h', 'ret_3h', 'ret_eod']].dropna()
        if len(valid) < 25:
            continue
        valid['quintile'] = pd.qcut(valid[col], 5, labels=QUINTILE_LABELS)

        print(f"  {'Quintile':<12} {'N':>4} {'1H mean':>10} {'1H WR':>8} {'EOD mean':>10} {'EOD WR':>8}")
        for q in QUINTILE_LABELS:
            qd = valid[valid['quintile'] == q]
            if len(qd) > 0:
                print(f"  {q:<12} {len(qd):>4} {qd['ret_1h'].mean()*100:>9.3f}% {(qd['ret_1h']>0).mean()*100:>7.1f}% {qd['ret_eod'].mean()*100:>9.3f}% {(qd['ret_eod']>0).mean()*100:>7.1f}%")

_quintile_tables('morning_mm_signed_tilt')

# Same for ProCust (note: intentionally no "=" rule lines here, as before)
print("\nQUINTILE ANALYSIS: ProCust Signed Tilt")
_quintile_tables('morning_procust_signed_tilt')


# ═══════════════════════════════════════════════════════════════════
# SAVE RESULTS
# ═══════════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("SAVING RESULTS")
print("="*60)

# JSON results
# ic_results already holds plain nested dicts of JSON-serializable scalars,
# so it is embedded directly. (The original rebuilt it with a four-level
# copy loop that produced a byte-identical structure — removed as dead weight.)
json_output = {
    'metadata': {
        'total_days': len(df),
        'clean_days': len(df_clean),
        'is_days': int((df_clean['is_oos'] == 'IS').sum()),
        'oos_days': int((df_clean['is_oos'] == 'OOS').sum()),
        'is_range': f"{df_clean[df_clean['is_oos']=='IS']['date'].min()} to {df_clean[df_clean['is_oos']=='IS']['date'].max()}",
        'oos_range': f"{df_clean[df_clean['is_oos']=='OOS']['date'].min()} to {df_clean[df_clean['is_oos']=='OOS']['date'].max()}",
        'corrupt_excluded': int(df['is_corrupt'].sum()),
        'fomc_excluded': 'per fomc_dates.json',
        'spx_range': f"{sorted_dates[0]} to {sorted_dates[-1]}",
        'trace_range': f"{df['date'].min()} to {df['date'].max()}",
    },
    'ic_results': ic_results,
    'pnl_results': pnl_results,
    'conditional': {
        'gex_regime': conditional_results,
        'day_of_week': dow_results,
        'gap_direction': gap_results,
    }
}

# default=str covers any stray non-JSON types (e.g. numpy scalars).
with open(OUTPUT_JSON, 'w') as f:
    json.dump(json_output, f, indent=2, default=str)
print(f"Saved JSON to {OUTPUT_JSON}")

# ─── Generate Markdown Report ───
# The report is accumulated line-by-line in `md`; the write to OUTPUT_MD
# presumably happens further down the file (past this section).
md = []
md.append("# TRACE Participant Positioning vs Market Direction — Comprehensive Backtest")
md.append(f"\n*Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}*\n")

md.append("## Data Quality & Coverage")
md.append(f"- **Total TRACE days loaded:** {len(df)}")
md.append(f"- **Clean days (excl corrupt period):** {len(df_clean)}")
md.append(f"- **IS period:** {json_output['metadata']['is_days']} days ({json_output['metadata']['is_range']})")
md.append(f"- **OOS period:** {json_output['metadata']['oos_days']} days ({json_output['metadata']['oos_range']})")
md.append(f"- **Corruption period excluded:** {json_output['metadata']['corrupt_excluded']} days (Oct 27 2025 – Feb 17 2026)")
md.append(f"- **FOMC dates excluded:** per fomc_dates.json")
md.append(f"- **SPX data range:** {json_output['metadata']['spx_range']}")
# NOTE(review): the "366 historical files" / "7 live days" counts below are
# hardcoded literals and can drift from len(hist_files)/len(live_dates).
md.append(f"- **TRACE source:** trace_uncorrupted/ (366 historical files) + trace_live/daily/ (7 live days)")
md.append(f"- **Historical TRACE = single EOD-ish snapshots with intraday timestamps; live = multi-snapshot**")
md.append(f"- **Note:** Historical data has multiple intraday timestamps per file — earliest used as 'morning', latest as 'EOD'")
md.append("")

# IC tables — one renderer for report sections 1 and 2 (the original
# duplicated this ~17-line markdown-table loop verbatim for both metrics).
def _append_ic_tables(metric_key):
    """Append the ALL/IS/OOS IC markdown tables for one metric to `md`."""
    for sample in ['ALL', 'IS', 'OOS']:
        md.append(f"### {sample} Sample")
        md.append("| Participant | 1H | 3H | EOD | Next Day |")
        md.append("|---|---|---|---|---|")
        for p in PARTICIPANTS:
            row = f"| {PART_NAMES[p]} "
            for h in horizons:
                val = ic_results[sample].get(p, {}).get(metric_key, {}).get(h, {})
                if val:
                    sig = '**' if val['pval'] < 0.05 else ''  # bold when p < 0.05
                    row += f"| {sig}{val['ic']:+.4f}{sig} (t={val['tstat']:.1f}) "
                else:
                    row += "| N/A "
            row += "|"
            md.append(row)
        md.append("")

md.append("## 1. Information Coefficient (IC): Signed Tilt → Forward Returns")
md.append("")
md.append("Rank correlation of each participant's gamma tilt (gamma above spot minus below) with subsequent SPX returns.")
md.append("Positive IC = positioning predicts direction correctly. Negative IC = fade their positioning.")
md.append("")
_append_ic_tables('signed_tilt')

# Net Gamma IC
md.append("## 2. IC: Net Gamma → Forward Returns")
md.append("")
_append_ic_tables('net_gamma')

# Contrarian vs Directional: does a participant's positioning predict the
# market (both IS and OOS ICs positive), fade it (both negative), or flip?
md.append("## 3. Contrarian vs Directional Classification")
md.append("")
md.append("| Participant | Metric | Horizon | IS IC | OOS IC | Classification |")
md.append("|---|---|---|---|---|---|")
# Hoist the IS/OOS partitions out of the triple loop — the boolean mask is
# loop-invariant, and the original recomputed it for every (p, metric, h).
_df_is = df_clean[df_clean['is_oos'] == 'IS']
_df_oos = df_clean[df_clean['is_oos'] == 'OOS']
for p in PARTICIPANTS:
    for metric in ['signed_tilt', 'net_gamma']:
        col = f'morning_{p}_{metric}'
        if col not in df_clean.columns:
            continue
        for h in horizons:
            if h not in df_clean.columns:
                continue
            is_data = _df_is[[col, h]].dropna()
            oos_data = _df_oos[[col, h]].dropna()
            # Require a minimum sample in BOTH regimes for a stable rank corr.
            if len(is_data) < 20 or len(oos_data) < 20:
                continue
            is_ic = stats.spearmanr(is_data[col], is_data[h])[0]
            oos_ic = stats.spearmanr(oos_data[col], oos_data[h])[0]
            if is_ic > 0 and oos_ic > 0:
                direction = "✅ DIRECTIONAL"
            elif is_ic < 0 and oos_ic < 0:
                direction = "🔄 CONTRARIAN"
            else:
                direction = "⚠️ MIXED"  # sign flip between IS and OOS (or NaN IC)
            md.append(f"| {PART_NAMES[p]} | {metric_labels[metric]} | {horizon_names[h]} | {is_ic:+.4f} | {oos_ic:+.4f} | {direction} |")
md.append("")

# P&L Leaderboard — rank participants by modeled total P&L, best first.
md.append("## 4. Profitability Leaderboard")
md.append("")
md.append("Estimated P&L from gamma and directional positioning (normalized units).")
md.append("Gamma P&L = 0.5 × normalized_net_gamma × daily_move². Dir P&L = normalized_signed_tilt × daily_move.")
md.append("")
md.append("| Rank | Participant | Gamma P&L | Dir P&L | Total P&L | Sharpe | Avg Gamma Sign |")
md.append("|---|---|---|---|---|---|---|")
_ranked = sorted(pnl_results.values(), key=lambda rec: rec['total_pnl'], reverse=True)
for rank, r in enumerate(_ranked, start=1):
    md.append(
        f"| {rank} | {r['name']} | {r['gamma_pnl_total']:.0f} | {r['directional_pnl_total']:.0f} "
        f"| {r['total_pnl']:.0f} | {r['sharpe']:.2f} | {r['avg_gamma_sign']:+.3f} |"
    )
md.append("")

# 0DTE vs Non-0DTE: is near-dated or longer-dated gamma positioning the
# better predictor? One IC per (participant, bucket, horizon).
md.append("## 5. 0DTE vs Non-0DTE Positioning Predictiveness")
md.append("")
md.append("| Participant | Type | 1H IC | 3H IC | EOD IC | Next Day IC |")
md.append("|---|---|---|---|---|---|")
for p in PARTICIPANTS:
    for gamma_type, gamma_label in [('net_gamma_0dte', '0DTE'), ('net_gamma_non0', 'Non-0DTE')]:
        col = f'morning_{p}_{gamma_type}'
        if col not in df_clean.columns:
            continue
        cells = [f"| {PART_NAMES[p]} | {gamma_label} "]
        for h in horizons:
            valid = df_clean[[col, h]].dropna()
            if len(valid) > 20:
                rho = stats.spearmanr(valid[col], valid[h])[0]
                cells.append(f"| {rho:+.4f} ")
            else:
                cells.append("| N/A ")
        cells.append("|")
        md.append("".join(cells))
md.append("")

# Agreement/Divergence: when a pair of participants agree on direction,
# what does the market subsequently do? The original duplicated this whole
# section for (ProCust, Firm) and (MM, Cust); a single loop over pair specs
# replaces the copy-paste.
md.append("## 6. Participant Agreement/Divergence")
md.append("")
for pair_title, a, b in [("Pro Customer + Firm", 'procust_bullish', 'firm_bullish'),
                         ("MM + Customer", 'mm_bullish', 'cust_bullish')]:
    md.append(f"### {pair_title}")
    md.append("| Condition | N days | 1H Mean | 1H WR | EOD Mean | EOD WR |")
    md.append("|---|---|---|---|---|---|")
    for label, mask in [('Both Bullish', df_clean[a] & df_clean[b]),
                        ('Both Bearish', ~df_clean[a] & ~df_clean[b]),
                        ('Disagree', df_clean[a] != df_clean[b])]:
        subset = df_clean[mask]
        # Skip regimes with too few days for a meaningful mean/win-rate.
        if len(subset) < 5:
            continue
        ret1h = subset['ret_1h'].dropna()
        ret_eod = subset['ret_eod'].dropna()
        md.append(f"| {label} | {len(subset)} | {ret1h.mean()*100:+.3f}% | {(ret1h>0).mean()*100:.1f}% | {ret_eod.mean()*100:+.3f}% | {(ret_eod>0).mean()*100:.1f}% |")
    md.append("")

# Conditional Analysis — the by-GEX-regime and by-day-of-week tables share
# the same rendering logic; the original duplicated it, deduplicated here.
md.append("## 7. Conditional Analysis")
md.append("")

def _append_conditional_table(heading, key_label, results):
    """Append a | <key> | Participant | 1H IC | EOD IC | table.

    ``results`` maps condition -> {'{p}_ret_1h': ic, '{p}_ret_eod': ic};
    missing entries default to the literal string 'N/A'. Numeric ICs are
    formatted to 4 dp (numpy floats subclass ``float``, so isinstance works).
    """
    md.append(heading)
    md.append(f"| {key_label} | Participant | 1H IC | EOD IC |")
    md.append("|---|---|---|---|")
    for cond in results:
        for p in ['mm', 'procust', 'firm']:
            ic_1h = results[cond].get(f'{p}_ret_1h', 'N/A')
            ic_eod = results[cond].get(f'{p}_ret_eod', 'N/A')
            ic_1h_str = f"{ic_1h:+.4f}" if isinstance(ic_1h, float) else ic_1h
            ic_eod_str = f"{ic_eod:+.4f}" if isinstance(ic_eod, float) else ic_eod
            md.append(f"| {cond} | {PART_NAMES[p]} | {ic_1h_str} | {ic_eod_str} |")
    md.append("")

_append_conditional_table("### By GEX Regime", "Regime", conditional_results)
_append_conditional_table("### By Day of Week (Key Participants)", "Day", dow_results)

# Key Findings — screen for signed-tilt signals whose IS and OOS ICs agree
# in sign and whose full-sample |IC| clears 0.05, then report the strongest.
md.append("## 8. Key Findings")
md.append("")

best_signals = []
for p in PARTICIPANTS:
    for h in horizons:
        ic_all = ic_results['ALL'].get(p, {}).get('signed_tilt', {}).get(h, {})
        ic_is = ic_results['IS'].get(p, {}).get('signed_tilt', {}).get(h, {})
        ic_oos = ic_results['OOS'].get(p, {}).get('signed_tilt', {}).get(h, {})
        # Need all three samples populated to judge consistency.
        if not (ic_all and ic_is and ic_oos):
            continue
        same_sign = ic_is['ic'] * ic_oos['ic'] > 0
        if same_sign and abs(ic_all['ic']) > 0.05:
            best_signals.append({
                'participant': PART_NAMES[p],
                'horizon': horizon_names[h],
                'all_ic': ic_all['ic'],
                'is_ic': ic_is['ic'],
                'oos_ic': ic_oos['ic'],
                'pval': ic_all['pval'],
                'consistent': True,
            })

# Strongest absolute full-sample IC first.
best_signals.sort(key=lambda s: abs(s['all_ic']), reverse=True)

if best_signals:
    md.append("### Strongest Consistent Signals (IS+OOS same sign, |IC| > 0.05)")
    for s in best_signals[:10]:
        direction = "DIRECTIONAL" if s['all_ic'] > 0 else "CONTRARIAN"
        md.append(f"- **{s['participant']}** → {s['horizon']}: IC={s['all_ic']:+.4f} (IS={s['is_ic']:+.3f}, OOS={s['oos_ic']:+.3f}) — **{direction}** (p={s['pval']:.3f})")
    md.append("")

md.append("### Summary Bullets")
md.append("")

# Auto-generate bullets from the results above.
# BUG FIX: the original hard-coded the ordinal "4." on every per-participant
# stance bullet, so the list rendered 1,2,3,4,4,4,4,4 (and started at 1 even
# when the first bullet was skipped). A running counter numbers them
# sequentially and stays correct when earlier bullets are conditionally absent.
bullet_no = 0

# Strongest signal (only if a consistent one survived the IS/OOS screen).
if best_signals:
    top = best_signals[0]
    bullet_no += 1
    md.append(f"{bullet_no}. **Strongest predictor:** {top['participant']} signed tilt → {top['horizon']} returns (IC={top['all_ic']:+.4f})")

# Who makes / loses money under the gamma P&L model.
if pnl_results:
    top_pnl = max(pnl_results.keys(), key=lambda x: pnl_results[x]['total_pnl'])
    worst_pnl = min(pnl_results.keys(), key=lambda x: pnl_results[x]['total_pnl'])
    bullet_no += 1
    md.append(f"{bullet_no}. **Most profitable (gamma model):** {pnl_results[top_pnl]['name']} (Sharpe={pnl_results[top_pnl]['sharpe']:.2f})")
    bullet_no += 1
    md.append(f"{bullet_no}. **Biggest loser (gamma model):** {pnl_results[worst_pnl]['name']} (Sharpe={pnl_results[worst_pnl]['sharpe']:.2f})")

# Typical gamma stance per participant: avg sign beyond ±0.1 → long/short.
for p in PARTICIPANTS:
    if p in pnl_results:
        sign = pnl_results[p]['avg_gamma_sign']
        stance = "long gamma" if sign > 0.1 else "short gamma" if sign < -0.1 else "mixed gamma"
        bullet_no += 1
        md.append(f"{bullet_no}. **{PART_NAMES[p]}:** typically {stance} (avg sign={sign:+.3f})")

md.append("")
md.append("### Caveats & Data Quality Notes")
md.append("")
# Static data-quality caveats, appended in one shot.
md.extend([
    "- Historical TRACE data from `trace_uncorrupted/` — single files per day with multiple intraday timestamps",
    "- No separate delta profiles in historical data — directional analysis uses gamma tilt as proxy",
    "- Live TRACE data (7 days) has both GEX and delta files but minimal overlap with SPX price data",
    "- SPX price data ends Feb 27, 2026; live TRACE starts Mar 6, 2026 — minimal price overlap for live data",
    "- Corruption period (Oct 27 2025 – Feb 17 2026) entirely excluded from analysis",
    "- 'Morning snapshot' = earliest timestamp in daily file; actual time varies by day",
    "- P&L model is illustrative (normalized units) — not dollar P&L",
    "- **No VIX data available** — VIX conditional analysis omitted",
    "",
])

# Write markdown report.
# BUG FIX: the report contains non-ASCII characters (✅, 🔄, —, ×); without
# an explicit encoding, open() uses the platform default (e.g. cp1252 on
# Windows) and the write raises UnicodeEncodeError. Pin UTF-8.
with open(OUTPUT_MD, 'w', encoding='utf-8') as f:
    f.write('\n'.join(md))
print(f"Saved report to {OUTPUT_MD}")

print("\n✅ ANALYSIS COMPLETE")