#!/usr/bin/env python3
"""
Gamma Indicators Backtest: GWB, ZDS, SFA
Using SpotGamma TRACE intradayStrikeGEX data for ES futures.
V2: Fixed spot price validation and forward return computation.
"""

import pandas as pd
import numpy as np
import json
import glob
import os
import warnings
from collections import defaultdict

warnings.filterwarnings('ignore')

# Input/output locations
DATA_DIR = '/Users/daniel/.openclaw/workspace/data'
TRACE_DIR = os.path.join(DATA_DIR, 'trace_api')
OUTPUT_CSV = os.path.join(DATA_DIR, 'gamma_indicators_backtest.csv')
OUTPUT_JSON = os.path.join(DATA_DIR, 'gamma_indicators_results.json')
OUTPUT_MD = os.path.join(DATA_DIR, 'gamma_indicators_results.md')

# Round-trip transaction cost in ES index points (0.25 pts = 1 tick = $12.50)
TX_COST_PTS = 0.25
# Trading days of history required before z-scores are considered valid
MIN_ZSCORE_DAYS = 20

###############################################################################
# 1. Load ES price data
###############################################################################
print("Loading ES price data...")

# Primary: intraday_gex_5min (spot prices aligned to the GEX snapshots)
gex_5min = pd.read_csv(os.path.join(DATA_DIR, 'intraday_gex_5min.csv'),
                       usecols=['date', 'timestamp', 'spot'])
# Parse as UTC, then convert to US/Eastern to match RTH session filtering
gex_5min['timestamp'] = pd.to_datetime(gex_5min['timestamp'], utc=True).dt.tz_convert('US/Eastern')
gex_5min['spot'] = pd.to_numeric(gex_5min['spot'], errors='coerce')  # bad values -> NaN
gex_5min = gex_5min.dropna(subset=['spot'])
gex_5min = gex_5min.set_index('timestamp').sort_index()
# Remove duplicate timestamps, keeping the first occurrence
gex_5min = gex_5min[~gex_5min.index.duplicated(keep='first')]

# Secondary (fallback): es_1min_bars - filter to ES only (price > 4000)
es_1min = pd.read_csv(os.path.join(DATA_DIR, 'es_1min_bars.csv'),
                      usecols=['ts_event', 'close', 'symbol'])
es_1min['close'] = pd.to_numeric(es_1min['close'], errors='coerce')
# Filter to ES contracts only (price should be > 4000); the mixed file also
# contains other symbols / micro contracts at much lower price levels
es_1min = es_1min[es_1min['close'] > 4000].copy()
es_1min['timestamp'] = pd.to_datetime(es_1min['ts_event'], utc=True).dt.tz_convert('US/Eastern')
es_1min = es_1min.set_index('timestamp').sort_index()
es_1min = es_1min[~es_1min.index.duplicated(keep='first')]

# Build 5min from 1min (last close in each 5-minute bucket)
es_5min_fallback = es_1min['close'].resample('5min').last().dropna()

print(f"  gex_5min: {gex_5min.index.min()} to {gex_5min.index.max()}, {len(gex_5min)} rows")
print(f"  es_1min (filtered): {es_1min.index.min()} to {es_1min.index.max()}, {len(es_1min)} rows")

def get_spot(ts):
    """Return a validated ES spot price for *ts*, or None if unavailable.

    Lookup order:
      1. gex_5min (SpotGamma-aligned spot) at the exact timestamp;
      2. nearest es_5min fallback bar within a 10-minute window.

    A price is accepted only when it exceeds 4000, which filters out NaN
    placeholders and non-ES prints (matches the load-time filter above).
    """
    ts_pd = pd.Timestamp(ts)

    # Try gex_5min first (exact-timestamp match)
    if ts_pd in gex_5min.index:
        val = gex_5min.loc[ts_pd, 'spot']
        if isinstance(val, pd.Series):
            # Duplicate index labels yield a Series; take the first value
            val = val.iloc[0]
        if not pd.isna(val) and val > 4000:
            return float(val)

    # Fall back to the nearest es_5min bar, but only within 10 minutes
    try:
        idx = es_5min_fallback.index.get_indexer([ts_pd], method='nearest')[0]
        if 0 <= idx < len(es_5min_fallback):
            nearest_ts = es_5min_fallback.index[idx]
            if abs((nearest_ts - ts_pd).total_seconds()) < 600:
                val = es_5min_fallback.iloc[idx]
                if val > 4000:
                    return float(val)
    except (KeyError, IndexError, ValueError, TypeError):
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). Lookup failures — empty fallback
        # index, tz-compat issues — simply mean "no price available".
        pass

    return None

###############################################################################
# 2. Load TRACE files and compute indicators
###############################################################################
print("Loading TRACE files and computing indicators...")

trace_files = sorted(glob.glob(os.path.join(TRACE_DIR, 'intradayStrikeGEX_*.csv')))
print(f"  Found {len(trace_files)} TRACE files")

results = []
file_count = 0

# One file per trading day; each file holds intraday per-strike gamma
# snapshots. For every snapshot we compute three raw indicators (GWB,
# ZDS, SFA) and store them together with the validated spot price.
for fpath in trace_files:
    # Filename pattern: intradayStrikeGEX_<date>.csv
    date_str = os.path.basename(fpath).replace('intradayStrikeGEX_', '').replace('.csv', '')
    
    try:
        df = pd.read_csv(fpath)
    except Exception as e:
        print(f"  Error reading {date_str}: {e}")
        continue
    
    df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True).dt.tz_convert('US/Eastern')
    
    # Coerce gamma columns to numeric (stray strings become NaN)
    num_cols = ['strike_price', 'mm_gamma', 'mm_gamma_0', 'firm_gamma', 'procust_gamma']
    for c in num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors='coerce')
    
    # Filter to RTH (9:30 - 15:55 ET) — exclude 16:00 which often has bad data
    time_min = df['timestamp'].dt.hour * 60 + df['timestamp'].dt.minute
    df = df[(time_min >= 9*60+30) & (time_min <= 15*60+55)]
    
    if df.empty:
        continue
    
    timestamps = sorted(df['timestamp'].unique())
    
    for ts in timestamps:
        # One cross-strike snapshot at this timestamp
        snap = df[df['timestamp'] == ts].copy()
        if snap.empty:
            continue
        
        # Skip snapshots with no validated spot price
        spot = get_spot(ts)
        if spot is None:
            continue
        
        strikes = snap['strike_price'].values
        mm_gamma = snap['mm_gamma'].fillna(0).values
        mm_gamma_0 = snap['mm_gamma_0'].fillna(0).values
        firm_gamma = snap['firm_gamma'].fillna(0).values
        procust_gamma = snap['procust_gamma'].fillna(0).values
        
        # === GWB: Gamma Wall Bias ===
        # Compare the mass of NEGATIVE market-maker gamma above vs below spot.
        above_mask = strikes > spot
        below_mask = strikes < spot
        
        neg_mm = np.where(mm_gamma < 0, mm_gamma, 0)
        neg_above = abs(neg_mm[above_mask].sum())  # magnitude of negative gamma above
        neg_below = abs(neg_mm[below_mask].sum())   # magnitude of negative gamma below
        
        denom_gwb = neg_above + neg_below
        gwb_raw = (neg_above - neg_below) / denom_gwb if denom_gwb > 0 else 0.0
        # GWB > 0 → bigger negative gamma wall above → BEARISH
        
        # === ZDS: 0DTE Dominance Score ===
        # Share of total |MM gamma| that is 0DTE, times a directional term
        # pointing toward the |0DTE gamma|-weighted strike center.
        total_mm_abs = np.sum(np.abs(mm_gamma))
        total_0dte_abs = np.sum(np.abs(mm_gamma_0))
        
        dte0_share = total_0dte_abs / total_mm_abs if total_mm_abs > 0 else 0.0
        
        abs_0dte = np.abs(mm_gamma_0)
        if abs_0dte.sum() > 0:
            dte0_center = np.average(strikes, weights=abs_0dte)
            # Positive when the 0DTE center sits above spot (upward pull),
            # expressed in percent of spot
            zds_direction = -1.0 * (spot - dte0_center) / spot * 100
            zds_raw = dte0_share * zds_direction
        else:
            zds_raw = 0.0
        
        # === SFA: Smart Flow Asymmetry ===
        # Net firm and pro-customer gamma above vs below spot, each
        # normalized to [-1, 1] by total absolute mass.
        firm_above = firm_gamma[above_mask].sum()
        firm_below = firm_gamma[below_mask].sum()
        firm_denom = abs(firm_above) + abs(firm_below)
        firm_asym = (firm_above - firm_below) / firm_denom if firm_denom > 0 else 0.0
        
        pc_above = procust_gamma[above_mask].sum()
        pc_below = procust_gamma[below_mask].sum()
        pc_denom = abs(pc_above) + abs(pc_below)
        pc_asym = (pc_above - pc_below) / pc_denom if pc_denom > 0 else 0.0
        
        # Blend: firm flow weighted higher than pro-customer flow
        sfa_raw = 0.6 * firm_asym + 0.4 * pc_asym
        
        results.append({
            'date': date_str,
            'timestamp': pd.Timestamp(ts),
            'es_price': spot,
            'gwb_raw': gwb_raw,
            'zds_raw': zds_raw,
            'sfa_raw': sfa_raw,
        })
    
    file_count += 1
    if file_count % 30 == 0:
        print(f"  Processed {file_count}/{len(trace_files)} files...")

print(f"  Processed {file_count} files, {len(results)} snapshots")

###############################################################################
# 3. Build DataFrame, z-scores, forward returns
###############################################################################
print("Computing z-scores...")

ind_df = pd.DataFrame(results)
ind_df = ind_df.sort_values('timestamp').reset_index(drop=True)
ind_df['date'] = pd.to_datetime(ind_df['date'])

# Validate prices — should be continuous (no jumps > 5%)
# Remove rows where price changed > 5% from previous within same day.
# NOTE(review): `prices` is a frozen copy of the column, so a row flagged
# NaN still serves as the baseline for the NEXT comparison — one bad print
# can also knock out the following (good) row when price jumps back.
# Confirm this is acceptable before relying on the filtered counts.
for date_val, grp in ind_df.groupby('date'):
    idx = grp.index
    prices = grp['es_price'].values
    for i in range(1, len(prices)):
        if abs(prices[i] - prices[i-1]) / prices[i-1] > 0.05:
            ind_df.loc[idx[i], 'es_price'] = np.nan

ind_df = ind_df.dropna(subset=['es_price']).reset_index(drop=True)
print(f"  After price validation: {len(ind_df)} rows")

# Trading day number: 1-based dense rank of the calendar date
ind_df['trade_day_num'] = ind_df['date'].rank(method='dense').astype(int)

# Z-scores with expanding window (shifted to avoid lookahead)
for col in ['gwb', 'zds', 'sfa']:
    raw_col = f'{col}_raw'
    z_col = f'{col}_zscore'
    
    # Expanding mean/std shifted by 1 so the current bar never sees itself
    mean_shift = ind_df[raw_col].expanding(min_periods=1).mean().shift(1)
    std_shift = ind_df[raw_col].expanding(min_periods=2).std().shift(1)
    
    z = (ind_df[raw_col] - mean_shift) / std_shift
    z = z.clip(-5, 5)  # cap outliers at ±5σ
    
    # Suppress z-scores until MIN_ZSCORE_DAYS of history have accumulated
    min_day = ind_df['trade_day_num'].min() + MIN_ZSCORE_DAYS - 1
    z[ind_df['trade_day_num'] <= min_day] = np.nan
    
    ind_df[z_col] = z

# Equal-weight composite of the three z-scores
ind_df['combined_zscore'] = (ind_df['gwb_zscore'] + ind_df['zds_zscore'] + ind_df['sfa_zscore']) / 3

###############################################################################
# 4. Forward returns (with proper validation)
###############################################################################
print("Computing forward returns...")

# Forward returns in basis points over three horizons, computed per day
ind_df['fwd_1h'] = np.nan
ind_df['fwd_3h'] = np.nan
ind_df['fwd_eod'] = np.nan

for date_val, group in ind_df.groupby('date'):
    idx = group.index
    prices = group['es_price'].values
    timestamps = group['timestamp'].values
    
    # EOD = last valid price of the day (not 16:00 if it's bad)
    eod_price = prices[-1]
    
    for i, row_idx in enumerate(idx):
        current_price = prices[i]
        current_ts = timestamps[i]
        
        # Forward 1H: nearest bar to ts+60min, accepted only if it is
        # within ±5 minutes of the target AND strictly after the current bar
        target_1h = current_ts + np.timedelta64(60, 'm')
        diffs_1h = np.abs(timestamps - target_1h)
        best_1h = np.argmin(diffs_1h)
        if diffs_1h[best_1h] <= np.timedelta64(5, 'm') and best_1h > i:
            ret_1h = (prices[best_1h] - current_price) / current_price * 10000
            if abs(ret_1h) < 500:  # sanity: < 5% in 1 hour
                ind_df.loc[row_idx, 'fwd_1h'] = ret_1h
        
        # Forward 3H: same matching rule at ts+180min
        target_3h = current_ts + np.timedelta64(180, 'm')
        diffs_3h = np.abs(timestamps - target_3h)
        best_3h = np.argmin(diffs_3h)
        if diffs_3h[best_3h] <= np.timedelta64(5, 'm') and best_3h > i:
            ret_3h = (prices[best_3h] - current_price) / current_price * 10000
            if abs(ret_3h) < 1000:  # sanity: < 10% in 3 hours
                ind_df.loc[row_idx, 'fwd_3h'] = ret_3h
        
        # Forward EOD: skipped for the day's final bar (no holding period)
        if i < len(idx) - 1:
            ret_eod = (eod_price - current_price) / current_price * 10000
            if abs(ret_eod) < 1000:
                ind_df.loc[row_idx, 'fwd_eod'] = ret_eod

print(f"  fwd_1h: {ind_df['fwd_1h'].notna().sum()}, fwd_3h: {ind_df['fwd_3h'].notna().sum()}, fwd_eod: {ind_df['fwd_eod'].notna().sum()}")

###############################################################################
# 5. Save indicator CSV
###############################################################################
# Persist the per-snapshot indicator panel for downstream analysis
out_cols = ['date', 'timestamp', 'es_price', 'gwb_raw', 'gwb_zscore', 
            'zds_raw', 'zds_zscore', 'sfa_raw', 'sfa_zscore', 
            'combined_zscore', 'fwd_1h', 'fwd_3h', 'fwd_eod']
ind_df[out_cols].to_csv(OUTPUT_CSV, index=False)
print(f"Saved: {OUTPUT_CSV} ({len(ind_df)} rows)")

# Quick stats on the forward-return distributions (bps)
print(f"\n  fwd_1h: mean={ind_df['fwd_1h'].mean():.2f}, std={ind_df['fwd_1h'].std():.2f}, median={ind_df['fwd_1h'].median():.2f}")
print(f"  fwd_3h: mean={ind_df['fwd_3h'].mean():.2f}, std={ind_df['fwd_3h'].std():.2f}, median={ind_df['fwd_3h'].median():.2f}")
print(f"  fwd_eod: mean={ind_df['fwd_eod'].mean():.2f}, std={ind_df['fwd_eod'].std():.2f}, median={ind_df['fwd_eod'].median():.2f}")

###############################################################################
# 6. Backtest
###############################################################################
print("\nRunning backtests...")

# Backtest only rows where all three z-scores are available
bt = ind_df.dropna(subset=['gwb_zscore', 'zds_zscore', 'sfa_zscore']).copy()
print(f"  Valid backtest rows: {len(bt)}")

# Time-of-day buckets: morning < 11:00 ET, afternoon >= 13:00 ET, else midday
bt['time_minutes'] = bt['timestamp'].dt.hour * 60 + bt['timestamp'].dt.minute
bt['tod'] = 'midday'
bt.loc[bt['time_minutes'] < 11*60, 'tod'] = 'morning'
bt.loc[bt['time_minutes'] >= 13*60, 'tod'] = 'afternoon'

# IS/OOS split: first min(120, ~2/3 of days) are in-sample, rest out-of-sample
unique_days = sorted(bt['date'].unique())
n_days = len(unique_days)
is_cutoff = min(120, int(n_days * 0.67))
is_days = unique_days[:is_cutoff]
oos_days = unique_days[is_cutoff:]
print(f"  IS: {len(is_days)} days ({is_days[0].date()} to {is_days[-1].date()})")
print(f"  OOS: {len(oos_days)} days ({oos_days[0].date()} to {oos_days[-1].date()})")

bt['is_is'] = bt['date'].isin(is_days)

# Round-trip cost converted from index points to bps of the median ES price
tx_cost_bps = TX_COST_PTS / bt['es_price'].median() * 10000
print(f"  Tx cost: {tx_cost_bps:.2f} bps")

def compute_metrics(returns_bps, tx_cost=None, n_days_total=1):
    """Compute trade-level performance metrics from per-trade returns in bps.

    Parameters
    ----------
    returns_bps : pd.Series
        Gross per-trade returns in basis points.
    tx_cost : float, optional
        Round-trip transaction cost in bps, subtracted from every trade.
        Defaults to the module-level ``tx_cost_bps``. (Late-bound via a
        None sentinel — the original eager default captured the global at
        definition time, which was fragile and made the function depend
        on definition order.)
    n_days_total : int
        Number of trading days spanned; used to annualize the Sharpe ratio.

    Returns
    -------
    dict
        sharpe, win_rate, avg_return_bps, n_trades, max_dd (bps),
        profit_factor. With fewer than 5 trades, a placeholder dict with
        n_trades = 0 and None metrics is returned.
    """
    if tx_cost is None:
        # Late-bound default: read the module-level cost at call time
        tx_cost = tx_cost_bps

    if len(returns_bps) < 5:
        # Too few trades for meaningful statistics
        return {'sharpe': None, 'win_rate': None, 'avg_return_bps': None,
                'n_trades': 0, 'max_dd': None, 'profit_factor': None}

    net = returns_bps - tx_cost
    mean_ret = net.mean()
    std_ret = net.std()

    # Annualized Sharpe: assume 1 trade per signal-bar and scale the
    # per-trade Sharpe by sqrt(252 * avg_trades_per_day)
    trades_per_day = len(net) / max(n_days_total, 1)
    if std_ret > 0:
        sharpe = mean_ret / std_ret * np.sqrt(252 * max(trades_per_day, 1))
    else:
        sharpe = 0.0

    win_rate = (net > 0).mean()

    # Max drawdown of the cumulative net-return path, in bps
    cum = net.cumsum()
    max_dd = (cum - cum.cummax()).min()

    # Profit factor: gross wins / gross losses (inf when no losing trades)
    gross_wins = net[net > 0].sum()
    gross_losses = abs(net[net < 0].sum())
    pf = gross_wins / gross_losses if gross_losses > 0 else np.inf

    return {
        'sharpe': round(float(sharpe), 3),
        'win_rate': round(float(win_rate), 4),
        'avg_return_bps': round(float(mean_ret), 3),
        'n_trades': int(len(net)),
        'max_dd': round(float(max_dd), 2),
        'profit_factor': round(float(pf), 3)
    }


def run_quintile(indicator_col, return_col, data):
    """Bucket rows into quintiles of `indicator_col` and summarize `return_col`.

    Returns a dict keyed by quintile label (mean/median bps, n, win rate),
    plus 'Q1_Q5_spread_bps' and a 'monotonic_increasing' flag. Returns an
    empty dict when there are fewer than 100 valid rows or the indicator
    cannot be split into 5 distinct bins.
    """
    valid = data.dropna(subset=[indicator_col, return_col]).copy()
    if len(valid) < 100:
        return {}

    try:
        valid['quintile'] = pd.qcut(valid[indicator_col], 5,
                                     labels=['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull'],
                                     duplicates='drop')
    except ValueError:
        # qcut raises ValueError when duplicate bin edges are dropped and
        # the 5 labels no longer match the bin count. (Narrowed from a
        # bare `except:`.)
        return {}

    result = {}
    for q in ['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull']:
        qdata = valid[valid['quintile'] == q][return_col]
        if len(qdata) > 0:
            result[q] = {
                'mean_bps': round(float(qdata.mean()), 2),
                'median_bps': round(float(qdata.median()), 2),
                'n': int(len(qdata)),
                'win_rate': round(float((qdata > 0).mean()), 4)
            }

    # Sign convention: a higher indicator value is the more BULLISH bucket
    # (Q5), so if the signal direction is right, mean returns should
    # INCREASE from Q1 to Q5 and the Q1-Q5 spread is negative.
    # Missing quintiles contribute 0 to the spread. (The original
    # `len(means) >= 5` guard was always true — the list comprehension
    # always yields 5 entries — so it is removed.)
    means = [result.get(q, {}).get('mean_bps', 0) for q in ['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull']]
    result['Q1_Q5_spread_bps'] = round(means[0] - means[-1], 2)
    result['monotonic_increasing'] = all(means[i] <= means[i+1] for i in range(4))

    return result


def run_threshold(indicator_col, return_col, data, threshold=1.0):
    """Backtest a symmetric threshold rule on one indicator.

    Sign convention: positive indicator = bearish, so readings above
    +threshold are traded SHORT and readings below -threshold are traded
    LONG. Returns the metrics dict from compute_metrics(), augmented with
    n_short/n_long trade counts when any signal fired.
    """
    valid = data.dropna(subset=[indicator_col, return_col]).copy()
    if len(valid) < 20:
        return compute_metrics(pd.Series(dtype=float))

    signal = valid[indicator_col]
    is_short = signal > threshold
    is_long = signal < -threshold

    # Short trades earn the negated forward return
    trade_returns = pd.concat([
        -valid.loc[is_short, return_col],
        valid.loc[is_long, return_col],
    ])
    if trade_returns.empty:
        return compute_metrics(pd.Series(dtype=float))

    day_count = valid['date'].nunique() if 'date' in valid.columns else 1
    metrics = compute_metrics(trade_returns, n_days_total=day_count)
    metrics['n_short'] = int(is_short.sum())
    metrics['n_long'] = int(is_long.sum())
    return metrics


def run_full_backtest(indicator_col, data):
    """Run the full backtest suite for one indicator column.

    Produces threshold and quintile results for every time-of-day bucket x
    horizon combination, plus an in-sample / out-of-sample threshold
    comparison per horizon. Returns a dict with keys 'threshold',
    'quintile' and 'is_oos'.
    """
    horizons = {'1h': 'fwd_1h', '3h': 'fwd_3h', 'eod': 'fwd_eod'}
    tod_buckets = {'morning': 'morning', 'midday': 'midday',
                   'afternoon': 'afternoon', 'all_day': None}

    threshold_results = {}
    quintile_results = {}

    for tod_name, tod_value in tod_buckets.items():
        # None means "no time-of-day filter" (all_day bucket)
        subset = data if tod_value is None else data[data['tod'] == tod_value]
        threshold_results[tod_name] = {
            h_name: run_threshold(indicator_col, h_col, subset)
            for h_name, h_col in horizons.items()
        }
        quintile_results[tod_name] = {
            h_name: run_quintile(indicator_col, h_col, subset)
            for h_name, h_col in horizons.items()
        }

    # In-sample vs out-of-sample threshold comparison
    is_data = data[data['is_is']]
    oos_data = data[~data['is_is']]

    is_oos = {}
    for h_name, h_col in horizons.items():
        is_metrics = run_threshold(indicator_col, h_col, is_data)
        oos_metrics = run_threshold(indicator_col, h_col, oos_data)
        is_sharpe = is_metrics.get('sharpe')
        oos_sharpe = oos_metrics.get('sharpe')

        # "Consistent" = same-sign Sharpe and OOS retains > 30% of |IS|
        consistent = False
        if is_sharpe is not None and oos_sharpe is not None:
            same_sign = is_sharpe * oos_sharpe > 0
            consistent = bool(same_sign and abs(oos_sharpe) > 0.3 * abs(is_sharpe))

        is_oos[h_name] = {
            'is': is_metrics, 'oos': oos_metrics,
            'is_sharpe': is_sharpe, 'oos_sharpe': oos_sharpe,
            'consistent': consistent
        }

    return {'threshold': threshold_results, 'quintile': quintile_results, 'is_oos': is_oos}


# Indicator registry: column to backtest plus human-readable description/
# calculation strings that are copied verbatim into the JSON/markdown reports
indicators_info = {
    'GWB': {
        'col': 'gwb_zscore',
        'description': 'Gamma Wall Bias - asymmetry of negative MM gamma walls above vs below spot',
        'calculation': 'GWB = (|neg_mm_gamma_above| - |neg_mm_gamma_below|) / total. Positive = bigger wall above = BEARISH. Z-scored with expanding window.'
    },
    'ZDS': {
        'col': 'zds_zscore',
        'description': '0DTE Dominance Score - 0DTE gamma share × directional magnet pull',
        'calculation': 'ZDS = 0dte_share × (-1 × (spot - 0dte_center) / spot × 100). Positive = magnet above = BULLISH pull. Z-scored.'
    },
    'SFA': {
        'col': 'sfa_zscore',
        'description': 'Smart Flow Asymmetry - firm + pro-customer gamma above vs below spot',
        'calculation': 'SFA = 0.6 × firm_asym + 0.4 × procust_asym. Positive = smart money above = BEARISH. Z-scored.'
    },
    'Combined': {
        'col': 'combined_zscore',
        'description': 'Equal-weight average of GWB + ZDS + SFA z-scores',
        'calculation': '(GWB_z + ZDS_z + SFA_z) / 3'
    }
}

# Run the full suite for each indicator and collect results for reporting
all_results = {}
for name, info in indicators_info.items():
    print(f"\n  Backtesting {name}...")
    r = run_full_backtest(info['col'], bt)
    all_results[name] = {
        'description': info['description'],
        'calculation': info['calculation'],
        'backtest': r['threshold'],
        'quintile_returns': r['quintile'],
        'is_oos_split': r['is_oos']
    }

###############################################################################
# 7. Determine best indicator
###############################################################################
print("\n\nIS/OOS Summary:")
scores = {}
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    r = all_results[name]
    print(f"\n  {name}:")
    for h in ['1h', '3h', 'eod']:
        iso = r['is_oos_split'][h]
        print(f"    {h}: IS={iso['is_sharpe']}, OOS={iso['oos_sharpe']}, Consistent={iso['consistent']}")
    
    # Score: prioritize OOS consistency across horizons.
    # OOS Sharpe weighted 0.4/0.3/0.3 across 1h/3h/eod, plus a flat 0.5
    # bonus for every horizon that is IS/OOS consistent.
    score = 0
    for h in ['1h', '3h', 'eod']:
        iso = r['is_oos_split'][h]
        oos_s = iso.get('oos_sharpe')
        if oos_s is not None:
            score += oos_s * (0.4 if h == '1h' else 0.3 if h == '3h' else 0.3)
        if iso.get('consistent'):
            score += 0.5
    scores[name] = score

best = max(scores, key=scores.get)
print(f"\n  Best: {best} (score={scores[best]:.3f})")

# Also check quintile monotonicity (direction sanity check on the signal)
print("\nQuintile Analysis (all_day, 1h):")
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    q = all_results[name]['quintile_returns'].get('all_day', {}).get('1h', {})
    if q:
        means = [q.get(qn, {}).get('mean_bps', 'N/A') for qn in ['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull']]
        spread = q.get('Q1_Q5_spread_bps', 'N/A')
        mono = q.get('monotonic_increasing', False)
        print(f"  {name}: Q1→Q5 means = {means}, spread={spread}, monotonic={mono}")

print("\nQuintile Analysis (all_day, eod):")
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    q = all_results[name]['quintile_returns'].get('all_day', {}).get('eod', {})
    if q:
        means = [q.get(qn, {}).get('mean_bps', 'N/A') for qn in ['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull']]
        spread = q.get('Q1_Q5_spread_bps', 'N/A')
        mono = q.get('monotonic_increasing', False)
        print(f"  {name}: Q1→Q5 means = {means}, spread={spread}, monotonic={mono}")

###############################################################################
# 8. Trading rules
###############################################################################
# Determine best horizon for best indicator (highest OOS Sharpe)
best_horizon = '1h'
best_oos_sharpe = -999
for h in ['1h', '3h', 'eod']:
    s = all_results[best]['is_oos_split'][h].get('oos_sharpe')
    if s is not None and s > best_oos_sharpe:
        best_oos_sharpe = s
        best_horizon = h

# Also find the best morning indicator (ranked by morning/1h Sharpe)
morning_scores = {}
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    m = all_results[name]['backtest'].get('morning', {}).get('1h', {})
    if m and m.get('sharpe') is not None:
        morning_scores[name] = m['sharpe']
best_morning = max(morning_scores, key=morning_scores.get) if morning_scores else best

# Human-readable rule sheet embedded verbatim in both the JSON and markdown
# outputs; the f-string interpolates the winners selected above.
trading_rules = f"""
Trading Rules for Gamma Indicators (ES Futures):

BEST OVERALL: {best} (best OOS horizon: {best_horizon}, OOS Sharpe: {best_oos_sharpe:.2f})
BEST MORNING: {best_morning}

1. SIGNAL GENERATION (at each 5-min TRACE snapshot):
   - Compute {best} z-score using expanding mean/std (no lookahead)
   - When z-score > +1.0: SHORT bias (bearish gamma positioning)
   - When z-score < -1.0: LONG bias (bullish gamma positioning)  
   - Between -1.0 and +1.0: FLAT / no signal

2. ENTRY TIMING:
   - Morning signals (9:30-11:00 ET) — check {best_morning} for 1-3H trades
   - Midday/afternoon — use Combined z-score for confirmation

3. EXIT:
   - Primary: Hold {best_horizon} from entry
   - Stop: Exit if indicator flips sign (crosses 0)
   - Hard stop: 3x ATR or 50 bps

4. POSITION: 1 ES contract per signal
   - Transaction cost: {TX_COST_PTS} pts ($12.50) per round trip
"""

###############################################################################
# 9. Save JSON
###############################################################################
class NpEncoder(json.JSONEncoder):
    """JSON encoder that maps numpy/pandas scalars onto plain Python types.

    NaN numpy floats become None (JSON null), arrays become lists, and
    Timestamps serialize as their string representation. Everything else
    falls through to the base encoder.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            # NaN is not valid strict JSON -> emit null instead
            return None if np.isnan(obj) else float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, pd.Timestamp):
            return str(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

# Full results payload; NpEncoder converts numpy/pandas scalars for JSON
output = {
    'indicators': all_results,
    'best_indicator': best,
    'trading_rules': trading_rules.strip(),
    'metadata': {
        'n_trace_files': file_count,
        'n_snapshots': len(ind_df),
        'n_valid_backtest': len(bt),
        'is_days': len(is_days),
        'oos_days': len(oos_days),
        'date_range': f"{unique_days[0].date()} to {unique_days[-1].date()}",
        'tx_cost_bps': round(tx_cost_bps, 2)
    }
}

with open(OUTPUT_JSON, 'w') as f:
    json.dump(output, f, indent=2, cls=NpEncoder)
print(f"\nSaved: {OUTPUT_JSON}")

###############################################################################
# 10. Generate markdown report
###############################################################################
# Build the markdown report as a list of lines, then write it out once
md = []
md.append("# Gamma Indicators Backtest Report\n")
md.append(f"**Date Range:** {unique_days[0].date()} to {unique_days[-1].date()}")
md.append(f"**Data:** {file_count} TRACE files | {len(ind_df)} snapshots | {len(bt)} valid backtest rows")
md.append(f"**IS:** {len(is_days)} days ({is_days[0].date()} → {is_days[-1].date()}) | **OOS:** {len(oos_days)} days ({oos_days[0].date()} → {oos_days[-1].date()})")
md.append(f"**Transaction Cost:** {TX_COST_PTS} pts ({tx_cost_bps:.2f} bps) round trip\n")
md.append(f"## 🏆 Best Indicator: **{best}** (OOS {best_horizon} Sharpe: {best_oos_sharpe:.2f})\n")

# One section per indicator: threshold table, quintile tables, IS/OOS table
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    r = all_results[name]
    md.append(f"\n---\n## {name}: {r['description']}\n")
    md.append(f"**Calculation:** {r['calculation']}\n")
    
    # Threshold backtest table
    md.append("### Threshold Backtest (±1σ entry, short above / long below)\n")
    md.append("| Period | Horizon | Sharpe | Win% | Avg (bps) | #Trades | MaxDD (bps) | PF |")
    md.append("|--------|---------|--------|------|-----------|---------|-------------|------|")
    
    # Rows with zero trades are omitted from the table
    for tod in ['morning', 'midday', 'afternoon', 'all_day']:
        for horizon in ['1h', '3h', 'eod']:
            m = r['backtest'].get(tod, {}).get(horizon, {})
            if m and m.get('n_trades', 0) > 0:
                md.append(f"| {tod} | {horizon} | {m.get('sharpe', '-')} | {m.get('win_rate', '-'):.1%} | {m.get('avg_return_bps', '-')} | {m.get('n_trades', 0)} | {m.get('max_dd', '-')} | {m.get('profit_factor', '-')} |")
    
    # Quintile table (all_day, multiple horizons)
    md.append("\n### Quintile Returns (all_day)\n")
    for h_name in ['1h', 'eod']:
        q = r['quintile_returns'].get('all_day', {}).get(h_name, {})
        if q:
            md.append(f"**{h_name.upper()} horizon:**")
            md.append("| Quintile | Mean (bps) | Median (bps) | Win% | N |")
            md.append("|----------|-----------|-------------|------|---|")
            for qn in ['Q1_bear', 'Q2', 'Q3', 'Q4', 'Q5_bull']:
                qd = q.get(qn, {})
                if qd:
                    md.append(f"| {qn} | {qd.get('mean_bps', '-')} | {qd.get('median_bps', '-')} | {qd.get('win_rate', 0):.1%} | {qd.get('n', 0)} |")
            spread = q.get('Q1_Q5_spread_bps', '-')
            mono = q.get('monotonic_increasing', False)
            md.append(f"\nQ1-Q5 spread: **{spread} bps** | Monotonic: **{mono}**\n")
    
    # IS/OOS comparison table
    md.append("### IS/OOS Comparison\n")
    md.append("| Horizon | IS Sharpe | OOS Sharpe | Consistent |")
    md.append("|---------|-----------|------------|------------|")
    for h in ['1h', '3h', 'eod']:
        iso = r['is_oos_split'].get(h, {})
        is_s = iso.get('is_sharpe', '-')
        oos_s = iso.get('oos_sharpe', '-')
        cons = iso.get('consistent', '-')
        md.append(f"| {h} | {is_s} | {oos_s} | {cons} |")

# Pairwise correlations between the three z-scored indicators
corr_df = bt[['gwb_zscore', 'zds_zscore', 'sfa_zscore']].corr()
md.append("\n---\n## Indicator Correlations\n")
md.append(f"- GWB vs ZDS: **{corr_df.loc['gwb_zscore', 'zds_zscore']:.3f}**")
md.append(f"- GWB vs SFA: **{corr_df.loc['gwb_zscore', 'sfa_zscore']:.3f}**")
md.append(f"- ZDS vs SFA: **{corr_df.loc['zds_zscore', 'sfa_zscore']:.3f}**")

# Raw (pre-z-score) indicator distribution statistics
md.append("\n## Raw Indicator Statistics\n")
for col in ['gwb_raw', 'zds_raw', 'sfa_raw']:
    s = ind_df[col]
    md.append(f"- **{col}:** mean={s.mean():.4f}, std={s.std():.4f}, min={s.min():.4f}, max={s.max():.4f}")

md.append(f"\n---\n## Trading Rules\n\n```\n{trading_rules.strip()}\n```")

# Honest assessment: call out which (if any) signals actually held up
md.append("\n---\n## Honest Assessment\n")
md.append("*Evaluating whether these indicators actually work:*\n")

# Check if any indicator has consistent IS/OOS performance
any_works = False
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    r = all_results[name]
    for h in ['1h', '3h', 'eod']:
        if r['is_oos_split'][h].get('consistent'):
            md.append(f"- ✅ **{name} ({h})**: IS/OOS consistent")
            any_works = True

if not any_works:
    md.append("- ⚠️ **No indicator showed consistent IS→OOS performance at ±1σ threshold**")

# Check quintile monotonicity as a direction-of-signal confirmation
for name in ['GWB', 'ZDS', 'SFA', 'Combined']:
    for h in ['1h', 'eod']:
        q = all_results[name]['quintile_returns'].get('all_day', {}).get(h, {})
        if q.get('monotonic_increasing'):
            md.append(f"- ✅ **{name} ({h})**: Monotonic quintile returns (signal direction confirmed)")

with open(OUTPUT_MD, 'w') as f:
    f.write('\n'.join(md))
print(f"Saved: {OUTPUT_MD}")

print("\n✅ DONE!")
