#!/usr/bin/env python3
"""ORB Decision Tree Clean Retest - All Phases

Phase 1: core ORB metrics; Phase 2: factor cross-tabs; Phase 3: conviction
boosters; Phase 4: confluence curve. Writes JSON results and a markdown report.
"""

import pandas as pd
import numpy as np
import json
import glob
import pyarrow.parquet as pq  # parquet reader for TRACE snapshots
from datetime import time, timedelta
import warnings
warnings.filterwarnings('ignore')
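# NOTE: blanket suppression also hides pandas SettingWithCopy / FutureWarnings;
# narrow this filter when debugging.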

BASE = '/Users/lutherbot/.openclaw/workspace/data'

# ─── Load FOMC dates ───
with open(f'{BASE}/fomc_dates.json') as f:
    FOMC_DATES = set(json.load(f)['dates'])
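# fomc_dates.json holds {"dates": [...]}; entries are assumed to be
# 'YYYY-MM-DD' strings so they match the string date columns built below.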

# ─── Load SPX 5-min ───
print("Loading SPX 5-min...")
spx = pd.read_csv(f'{BASE}/spx_5min_polygon.csv', parse_dates=['datetime'])
# raw 'datetime' is treated as naive UTC: localize, then convert to ET
spx['datetime_et'] = spx['datetime'].dt.tz_localize('UTC').dt.tz_convert('America/New_York')
spx['date'] = spx['datetime_et'].dt.date.astype(str)
spx = spx[~spx['date'].isin(FOMC_DATES)]
# Filter RTH only (9:30-16:00 ET)
spx['time'] = spx['datetime_et'].dt.time
spx_rth = spx[(spx['time'] >= time(9,30)) & (spx['time'] < time(16,0))].copy()
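# Assumes bar timestamps mark the bar *open*, so < 16:00 keeps the full RTH
# session without pulling in a 16:00 post-close print.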
print(f"  SPX RTH bars: {len(spx_rth)}, dates: {spx_rth['date'].nunique()}")

# ─── Load ES 1-min ───
print("Loading ES 1-min...")
es = pd.read_csv(f'{BASE}/es_1min_bars.csv', parse_dates=['ts_event'])
es['datetime_et'] = es['ts_event'].dt.tz_convert('America/New_York')
es['date'] = es['datetime_et'].dt.date.astype(str)
es = es[~es['date'].isin(FOMC_DATES)]
es['time'] = es['datetime_et'].dt.time
es_rth = es[(es['time'] >= time(9,30)) & (es['time'] < time(16,0))].copy()
print(f"  ES RTH bars: {len(es_rth)}, dates: {es_rth['date'].nunique()}")

# ─── Load TRACE GEX snapshots ───
print("Loading TRACE normalized data...")
trace_files = sorted(glob.glob(f'{BASE}/trace_normalized/intradayStrikeGEX_*.parquet'))
trace_dates = {}
for f in trace_files:
    date_str = f.split('_')[-1].replace('.parquet','')
    if date_str in FOMC_DATES:
        continue
    try:
        df = pq.read_table(f).to_pandas()
        # Get 9:30 ET snapshot (closest to market open)
        # Timestamps are already ET
        ts_unique = sorted(df['timestamp'].unique())
        # Find closest to 9:30
        open_time = pd.Timestamp(f'{date_str} 09:30:00', tz='America/New_York')
        closest_idx = min(range(len(ts_unique)), key=lambda i: abs(ts_unique[i] - open_time))
        snap = df[df['timestamp'] == ts_unique[closest_idx]]
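        # NOTE: 'closest to 9:30' can land well after the open on days when
        # the feed starts late; snap_time is stored below for auditing.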
        
        net_gex = snap['mm_gamma'].sum()  # Net MM gamma = total GEX
        net_gex_0dte = snap['mm_gamma_0'].sum()  # 0DTE component
        
        # Compute tilt: % of positive gamma strikes vs total
        pos_gamma_strikes = (snap['mm_gamma'] > 0).sum()
        total_strikes = len(snap)
        tilt = pos_gamma_strikes / total_strikes if total_strikes > 0 else 0.5
        
        trace_dates[date_str] = {
            'net_gex': net_gex,
            'net_gex_0dte': net_gex_0dte,
            'tilt': tilt,
            'snap_time': str(ts_unique[closest_idx])
        }
    except Exception as e:
        print(f"  WARN: skipping {f}: {e}")

print(f"  TRACE dates loaded: {len(trace_dates)}")

# ─── Build daily ORB events ───
print("\nComputing ORB events...")

def compute_orb_for_date(date_str, bars_1min=None, bars_5min=None):
    """Compute ORB metrics for a single date. Prefer 1-min bars."""
    result = {'date': date_str}
    
    if bars_1min is not None and len(bars_1min) > 0:
        bars = bars_1min.copy()
        bar_freq = '1min'
    elif bars_5min is not None and len(bars_5min) > 0:
        bars = bars_5min.copy()
        bar_freq = '5min'
    else:
        return None
    
    bars = bars.sort_values('datetime_et')
    
    # OR15: 9:30-9:45
    or15_bars = bars[(bars['time'] >= time(9,30)) & (bars['time'] < time(9,45))]
    if len(or15_bars) == 0:
        return None
    or15_high = or15_bars['high'].max()
    or15_low = or15_bars['low'].min()
    or15_open = or15_bars.iloc[0]['open']
    
    # OR30: 9:30-10:00
    or30_bars = bars[(bars['time'] >= time(9,30)) & (bars['time'] < time(10,0))]
    or30_high = or30_bars['high'].max()
    or30_low = or30_bars['low'].min()
    
    # ORB volume (for RVOL)
    orb_volume = or15_bars['volume'].sum() if 'volume' in or15_bars.columns else np.nan
    
    # Post-ORB bars (after 9:45)
    post_bars = bars[bars['time'] >= time(9,45)]
    if len(post_bars) == 0:
        return None
    
    # Day close
    day_close = bars.iloc[-1]['close']
    day_open = or15_open
    
    # Detect breaks using 1-min/5-min closes
    first_break_up = None
    first_break_down = None
    break_up_confirmed = False
    break_down_confirmed = False
    
    for _, row in post_bars.iterrows():
        bar_time = row['time']
        
        # Break UP detection
        if first_break_up is None and row['close'] > or15_high:
            first_break_up = bar_time
        
        # Break DOWN detection
        if first_break_down is None and row['close'] < or15_low:
            first_break_down = bar_time
        
        # Stop scanning once both break directions have been seen
        if first_break_up is not None and first_break_down is not None:
            break
    
    # 15-min confirmation: price stays above/below for 15 consecutive minutes
    if first_break_up is not None:
        # Find bars in the 15 min after break
        break_up_dt = pd.Timestamp(f'{date_str} {first_break_up}', tz='America/New_York')
        confirm_end = break_up_dt + timedelta(minutes=15)
        confirm_bars = bars[(bars['datetime_et'] > break_up_dt) & 
                           (bars['datetime_et'] <= confirm_end)]
        if len(confirm_bars) > 0:
            break_up_confirmed = all(confirm_bars['close'] > or15_high)
        else:
            break_up_confirmed = False
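    # NOTE: on SPX fallback days the 15-min window holds at most three 5-min
    # bars, so confirmation is coarser there than with ES 1-min data.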
    
    if first_break_down is not None:
        break_down_dt = pd.Timestamp(f'{date_str} {first_break_down}', tz='America/New_York')
        confirm_end = break_down_dt + timedelta(minutes=15)
        confirm_bars = bars[(bars['datetime_et'] > break_down_dt) & 
                           (bars['datetime_et'] <= confirm_end)]
        if len(confirm_bars) > 0:
            break_down_confirmed = all(confirm_bars['close'] < or15_low)
        else:
            break_down_confirmed = False
    
    # Break timing (minutes after 9:45)
    break_time_min = None
    if first_break_up:
        dt = pd.Timestamp(f'2000-01-01 {first_break_up}')
        ref = pd.Timestamp('2000-01-01 09:45:00')
        break_time_min = (dt - ref).total_seconds() / 60
    if first_break_down:
        dt = pd.Timestamp(f'2000-01-01 {first_break_down}')
        ref = pd.Timestamp('2000-01-01 09:45:00')
        down_min = (dt - ref).total_seconds() / 60
        if break_time_min is None or down_min < break_time_min:
            break_time_min = down_min
    
    # Classify break type
    has_break_up = first_break_up is not None
    has_break_down = first_break_down is not None
    
    if has_break_up and has_break_down:
        break_type = 'double'
    elif has_break_up:
        break_type = 'up'
    elif has_break_down:
        break_type = 'down'
    else:
        break_type = 'none'
    
    # Failed break: broke out intraday but closed back inside the OR15 range
    failed_up = has_break_up and day_close < or15_high     # closed back below OR15 high
    failed_down = has_break_down and day_close > or15_low  # closed back above OR15 low
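    # Note: 'failed' is judged against the OR15 level, while the Phase 1
    # outcome is judged against the open, so a break-up day can close above
    # the open and still count as failed_up.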
    
    # Assemble result (gap vs prev close is attached later in the main loop)
    result.update({
        'or15_high': or15_high,
        'or15_low': or15_low,
        'or30_high': or30_high,
        'or30_low': or30_low,
        'day_open': day_open,
        'day_close': day_close,
        'break_type': break_type,
        'break_up': has_break_up,
        'break_down': has_break_down,
        'break_up_confirmed': break_up_confirmed,
        'break_down_confirmed': break_down_confirmed,
        'first_break_up_time': str(first_break_up) if first_break_up else None,
        'first_break_down_time': str(first_break_down) if first_break_down else None,
        'break_time_min': break_time_min,
        'failed_up': failed_up,
        'failed_down': failed_down,
        'orb_volume': orb_volume,
        'close_above_open': day_close > day_open,
        'bar_freq': bar_freq,
        'dow': pd.Timestamp(date_str).day_name(),
    })
    
    return result

# Get all unique dates
spx_dates = set(spx_rth['date'].unique())
es_dates = set(es_rth['date'].unique())
all_dates = sorted(spx_dates | es_dates)
print(f"Total trading dates: {len(all_dates)}")

orb_events = []
prev_close = {}

for idx_pos, date_str in enumerate(all_dates):
    # Get bars for this date
    es_day = es_rth[es_rth['date'] == date_str].copy() if date_str in es_dates else pd.DataFrame()
    spx_day = spx_rth[spx_rth['date'] == date_str].copy() if date_str in spx_dates else pd.DataFrame()
    
    # Use ES 1-min when available (better granularity), SPX 5-min as fallback
    if len(es_day) > 0:
        result = compute_orb_for_date(date_str, bars_1min=es_day)
    elif len(spx_day) > 0:
        result = compute_orb_for_date(date_str, bars_5min=spx_day)
    else:
        continue
    
    if result is None:
        continue
    
    # Gap calculation
    if date_str in prev_close:
        gap_pct = (result['day_open'] - prev_close[date_str]) / prev_close[date_str] * 100
        result['gap_pct'] = gap_pct
    else:
        result['gap_pct'] = np.nan
    
    # Record today's close as the next trading date's prev close
    # (all_dates is already sorted, so idx_pos + 1 is the next session)
    if idx_pos + 1 < len(all_dates):
        prev_close[all_dates[idx_pos + 1]] = result['day_close']
    
    # Match TRACE data
    if date_str in trace_dates:
        result['net_gex'] = trace_dates[date_str]['net_gex']
        result['gex_tilt'] = trace_dates[date_str]['tilt']
    else:
        result['net_gex'] = np.nan
        result['gex_tilt'] = np.nan
    
    orb_events.append(result)

df_orb = pd.DataFrame(orb_events)
print(f"ORB events computed: {len(df_orb)}")
print(f"  Break types: {df_orb['break_type'].value_counts().to_dict()}")

# ─── RVOL calculation ───
print("\nComputing RVOL...")
df_orb['orb_volume'] = pd.to_numeric(df_orb['orb_volume'], errors='coerce')
df_orb = df_orb.sort_values('date')
df_orb['avg_orb_vol_20'] = df_orb['orb_volume'].rolling(20, min_periods=10).mean().shift(1)
df_orb['rvol'] = df_orb['orb_volume'] / df_orb['avg_orb_vol_20']
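# NOTE: ES contract volume and SPX index volume (when present) are on very
# different scales, so the 20-day rolling mean may be distorted around the
# ES/SPX data-source switch.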

# ─── GEX Tiers ───
print("Computing GEX tiers...")
gex_vals = df_orb['net_gex'].dropna()
if len(gex_vals) > 0:
    pcts = gex_vals.quantile([0.1, 0.25, 0.5, 0.75, 0.9])
    def gex_tier(v):
        if pd.isna(v): return np.nan
        if v < pcts[0.1]: return 'DEEP_NEG'
        elif v < pcts[0.25]: return 'NEG'
        elif v < pcts[0.5]: return 'LOW_POS'
        elif v < pcts[0.75]: return 'MID_POS'
        elif v < pcts[0.9]: return 'HIGH_POS'
        else: return 'EXTREME_POS'
    df_orb['gex_tier'] = df_orb['net_gex'].apply(gex_tier)
    # NaN check first, rather than relying on NaN >= 0 evaluating False
    df_orb['gex_binary'] = df_orb['net_gex'].apply(
        lambda x: np.nan if pd.isna(x) else ('POS' if x >= 0 else 'NEG'))
else:
    df_orb['gex_tier'] = np.nan
    df_orb['gex_binary'] = np.nan

# ─── RVOL tiers ───
def rvol_tier(v):
    if pd.isna(v): return np.nan
    if v > 1.3: return 'HIGH'
    elif v >= 0.7: return 'NORMAL'
    else: return 'LOW'
df_orb['rvol_tier'] = df_orb['rvol'].apply(rvol_tier)

# ─── Tilt tiers ───
def tilt_label(v):
    if pd.isna(v): return np.nan
    return 'BULLISH' if v >= 0.75 else 'BEARISH'
df_orb['tilt_label'] = df_orb['gex_tilt'].apply(tilt_label)

# ─── Gap tiers ───
def gap_label(v):
    if pd.isna(v): return np.nan
    abs_v = abs(v)
    if abs_v < 0.05: return 'FLAT'
    elif v > 0:
        if abs_v < 0.3: return 'GAP_UP_SM'
        elif abs_v < 0.5: return 'GAP_UP_MD'
        else: return 'GAP_UP_LG'
    else:
        if abs_v < 0.3: return 'GAP_DN_SM'
        elif abs_v < 0.5: return 'GAP_DN_MD'
        else: return 'GAP_DN_LG'

def gap_dir(v):
    if pd.isna(v): return np.nan
    if v > 0.05: return 'UP'
    elif v < -0.05: return 'DOWN'
    else: return 'FLAT'

df_orb['gap_label'] = df_orb['gap_pct'].apply(gap_label)
df_orb['gap_dir'] = df_orb['gap_pct'].apply(gap_dir)

# ─── Break timing tiers ───
def timing_tier(v):
    if pd.isna(v): return np.nan
    if v <= 0: return 'EARLY'  # break on the very first post-OR bar (scan starts at 9:45, so v is never negative)
    elif v <= 30: return 'NORMAL'  # 9:45-10:15
    else: return 'LATE'  # after 10:15
df_orb['timing_tier'] = df_orb['break_time_min'].apply(timing_tier)

# ─── IS/OOS Split (60/40 by date) ───
dates_sorted = sorted(df_orb['date'].unique())
split_idx = int(len(dates_sorted) * 0.6)
is_dates = set(dates_sorted[:split_idx])
oos_dates = set(dates_sorted[split_idx:])
df_orb['sample'] = df_orb['date'].apply(lambda d: 'IS' if d in is_dates else 'OOS')
print(f"IS dates: {len(is_dates)}, OOS dates: {len(oos_dates)}")

# ═══════════════════════════════════════════════════════════════
# PHASE 1: Core ORB Metrics
# ═══════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("PHASE 1: Core ORB Metrics")
print("="*60)

results = {}

def wr_stats(mask, outcome_col='close_above_open', label=''):
    """Compute WR for IS and OOS"""
    sub = df_orb[mask]
    is_sub = sub[sub['sample'] == 'IS']
    oos_sub = sub[sub['sample'] == 'OOS']
    
    def _wr(d):
        if len(d) == 0: return {'wr': np.nan, 'n': 0}
        return {'wr': round(d[outcome_col].mean() * 100, 1), 'n': len(d)}
    
    r = {
        'all': _wr(sub),
        'IS': _wr(is_sub),
        'OOS': _wr(oos_sub),
    }
    if label:
        print(f"  {label}: ALL={r['all']['wr']}% (N={r['all']['n']}), IS={r['IS']['wr']}% (N={r['IS']['n']}), OOS={r['OOS']['wr']}% (N={r['OOS']['n']})")
    return r

# For break down, outcome = close below open
# (close == open counts as 'below', since close_above_open uses a strict >)
df_orb['close_below_open'] = ~df_orb['close_above_open']

print("\n--- Break Up ---")
r1 = wr_stats(df_orb['break_up'] & ~df_orb['break_down'], 'close_above_open', 'Break Up → close > open')
results['break_up_wr'] = r1

print("\n--- Break Down ---")
r2 = wr_stats(df_orb['break_down'] & ~df_orb['break_up'], 'close_below_open', 'Break Down → close < open')
results['break_down_wr'] = r2

print("\n--- Break Up + 15m Confirm ---")
r3 = wr_stats(df_orb['break_up_confirmed'] & ~df_orb['break_down'], 'close_above_open', 'Break Up + 15m confirm → close > open')
results['break_up_confirmed_wr'] = r3

print("\n--- Break Down + 15m Confirm ---")
r4 = wr_stats(df_orb['break_down_confirmed'] & ~df_orb['break_up'], 'close_below_open', 'Break Down + 15m confirm → close < open')
results['break_down_confirmed_wr'] = r4

print("\n--- Double Break ---")
double = df_orb[df_orb['break_type'] == 'double']
print(f"  Double breaks: {len(double)}")
if len(double) > 0:
    print(f"  Close above open: {double['close_above_open'].mean()*100:.1f}%")

print("\n--- Failed Breaks ---")
# Failed up = broke above the OR15 high intraday but closed back below it
# (the extra & with break_up/break_down is redundant: failed_* already implies the break)
failed_up_mask = df_orb['failed_up'] & df_orb['break_up']
failed_down_mask = df_orb['failed_down'] & df_orb['break_down']
r_fail_up = wr_stats(failed_up_mask, 'close_below_open', 'Failed Up → fade (close < open)')
r_fail_down = wr_stats(failed_down_mask, 'close_above_open', 'Failed Down → fade (close > open)')
results['failed_up_fade_wr'] = r_fail_up
results['failed_down_fade_wr'] = r_fail_down

# ═══════════════════════════════════════════════════════════════
# PHASE 2: Factor Analysis
# ═══════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("PHASE 2: Factor Analysis")
print("="*60)

def cross_tab(break_mask, outcome_col, factor_col, label):
    """Cross-tabulate break WR by factor level"""
    sub = df_orb[break_mask].copy()
    print(f"\n--- {label} ---")
    r = {}
    for level in sorted(sub[factor_col].dropna().unique()):
        mask = sub[factor_col] == level
        s = sub[mask]
        is_s = s[s['sample']=='IS']
        oos_s = s[s['sample']=='OOS']
        wr_all = s[outcome_col].mean()*100 if len(s)>0 else np.nan
        wr_is = is_s[outcome_col].mean()*100 if len(is_s)>0 else np.nan
        wr_oos = oos_s[outcome_col].mean()*100 if len(oos_s)>0 else np.nan
        r[str(level)] = {'wr_all': round(wr_all,1), 'n': len(s), 
                         'wr_is': round(wr_is,1) if not np.isnan(wr_is) else None, 'n_is': len(is_s),
                         'wr_oos': round(wr_oos,1) if not np.isnan(wr_oos) else None, 'n_oos': len(oos_s)}
        print(f"  {level}: ALL={wr_all:.1f}% (N={len(s)}), IS={wr_is:.1f}% (N={len(is_s)}), OOS={wr_oos:.1f}% (N={len(oos_s)})")
    return r

# Factor A: GEX Magnitude Tiers
print("\n=== Factor A: GEX Magnitude Tiers ===")
up_mask = df_orb['break_up'] & ~df_orb['break_down']
down_mask = df_orb['break_down'] & ~df_orb['break_up']

results['gex_tier_break_up'] = cross_tab(up_mask, 'close_above_open', 'gex_tier', 'Break UP WR by GEX Tier')
results['gex_tier_break_down'] = cross_tab(down_mask, 'close_below_open', 'gex_tier', 'Break DOWN WR by GEX Tier')

# Also confirmed breaks
up_conf = df_orb['break_up_confirmed'] & ~df_orb['break_down']
down_conf = df_orb['break_down_confirmed'] & ~df_orb['break_up']
results['gex_tier_break_up_conf'] = cross_tab(up_conf, 'close_above_open', 'gex_tier', 'Break UP+Confirm WR by GEX Tier')
results['gex_tier_break_down_conf'] = cross_tab(down_conf, 'close_below_open', 'gex_tier', 'Break DOWN+Confirm WR by GEX Tier')

# Binary GEX
results['gex_binary_break_up'] = cross_tab(up_mask, 'close_above_open', 'gex_binary', 'Break UP WR by GEX Binary')
results['gex_binary_break_down'] = cross_tab(down_mask, 'close_below_open', 'gex_binary', 'Break DOWN WR by GEX Binary')

# Factor B: RVOL
print("\n=== Factor B: RVOL ===")
results['rvol_break_up'] = cross_tab(up_mask, 'close_above_open', 'rvol_tier', 'Break UP WR by RVOL')
results['rvol_break_down'] = cross_tab(down_mask, 'close_below_open', 'rvol_tier', 'Break DOWN WR by RVOL')
results['rvol_break_up_conf'] = cross_tab(up_conf, 'close_above_open', 'rvol_tier', 'Break UP+Confirm WR by RVOL')

# Factor C: Gamma Tilt
print("\n=== Factor C: Gamma Tilt ===")
results['tilt_break_up'] = cross_tab(up_mask, 'close_above_open', 'tilt_label', 'Break UP WR by Tilt')
results['tilt_break_down'] = cross_tab(down_mask, 'close_below_open', 'tilt_label', 'Break DOWN WR by Tilt')

# Factor D: Gap Context
print("\n=== Factor D: Gap Context ===")
results['gap_break_up'] = cross_tab(up_mask, 'close_above_open', 'gap_dir', 'Break UP WR by Gap Direction')
results['gap_break_down'] = cross_tab(down_mask, 'close_below_open', 'gap_dir', 'Break DOWN WR by Gap Direction')
results['gap_detail_break_up'] = cross_tab(up_mask, 'close_above_open', 'gap_label', 'Break UP WR by Gap Detail')

# Factor E: Day of Week
print("\n=== Factor E: Day of Week ===")
results['dow_break_up'] = cross_tab(up_mask, 'close_above_open', 'dow', 'Break UP WR by Day of Week')
results['dow_break_down'] = cross_tab(down_mask, 'close_below_open', 'dow', 'Break DOWN WR by Day of Week')

# Factor G: Break Timing
print("\n=== Factor G: Break Timing ===")
results['timing_break_up'] = cross_tab(up_mask, 'close_above_open', 'timing_tier', 'Break UP WR by Timing')
results['timing_break_down'] = cross_tab(down_mask, 'close_below_open', 'timing_tier', 'Break DOWN WR by Timing')

# ═══════════════════════════════════════════════════════════════
# PHASE 3: Build Conviction Boosters
# ═══════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("PHASE 3: Conviction Boosters")
print("="*60)

# Base rates for confirmed breaks
base_up = df_orb[up_conf].copy()  # copy: an n_factors column is added in Phase 4
base_up_wr = base_up['close_above_open'].mean() * 100
base_up_n = len(base_up)

base_down = df_orb[down_conf].copy()  # copy: an n_factors column is added in Phase 4
base_down_wr = base_down['close_below_open'].mean() * 100
base_down_n = len(base_down)

print(f"\nBase Break UP + 15m confirm: {base_up_wr:.1f}% (N={base_up_n})")
print(f"Base Break DOWN + 15m confirm: {base_down_wr:.1f}% (N={base_down_n})")

boosters = []

def check_booster(name, mask, base_set, outcome_col, base_wr):
    """Check if a factor is a significant booster/degrader"""
    with_factor = base_set[mask]
    without_factor = base_set[~mask]
    
    if len(with_factor) < 10 or len(without_factor) < 10:
        return None
    
    wr_with = with_factor[outcome_col].mean() * 100
    wr_without = without_factor[outcome_col].mean() * 100
    diff = wr_with - base_wr
    
    # IS/OOS consistency
    is_with = with_factor[with_factor['sample']=='IS']
    oos_with = with_factor[with_factor['sample']=='OOS']
    wr_is = is_with[outcome_col].mean()*100 if len(is_with)>5 else np.nan
    wr_oos = oos_with[outcome_col].mean()*100 if len(oos_with)>5 else np.nan
    
    r = {
        'name': name,
        'wr_with': round(wr_with, 1),
        'wr_without': round(wr_without, 1),
        'diff_pp': round(diff, 1),
        'n_with': len(with_factor),
        'n_without': len(without_factor),
        'wr_is': round(wr_is, 1) if not np.isnan(wr_is) else None,
        'wr_oos': round(wr_oos, 1) if not np.isnan(wr_oos) else None,
        'significant': abs(diff) > 3 and len(with_factor) >= 30
    }
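    # NOTE: 'significant' is a heuristic screen (>3pp shift, N>=30) with no
    # multiple-comparison correction across the dozens of factor tests below.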
    
    sig = '✓' if r['significant'] else ' '
    print(f"  [{sig}] {name}: {wr_with:.1f}% (N={len(with_factor)}) vs base {base_wr:.1f}% → {diff:+.1f}pp | IS={wr_is:.1f}% OOS={wr_oos:.1f}%")
    return r

print("\n--- Break UP + 15m Confirm Boosters ---")
# GEX tiers
for tier in ['EXTREME_POS', 'HIGH_POS', 'MID_POS', 'LOW_POS', 'NEG', 'DEEP_NEG']:
    b = check_booster(f'GEX {tier}', base_up['gex_tier']==tier, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# GEX binary
for level in ['POS', 'NEG']:
    b = check_booster(f'GEX Binary {level}', base_up['gex_binary']==level, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# RVOL
for tier in ['HIGH', 'NORMAL', 'LOW']:
    b = check_booster(f'RVOL {tier}', base_up['rvol_tier']==tier, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# Tilt
for tilt in ['BULLISH', 'BEARISH']:
    b = check_booster(f'Tilt {tilt}', base_up['tilt_label']==tilt, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# Gap
for gap in ['UP', 'DOWN', 'FLAT']:
    b = check_booster(f'Gap {gap}', base_up['gap_dir']==gap, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# Timing
for t in ['EARLY', 'NORMAL', 'LATE']:
    b = check_booster(f'Timing {t}', base_up['timing_tier']==t, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

# DOW
for d in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']:
    b = check_booster(f'DOW {d}', base_up['dow']==d, base_up, 'close_above_open', base_up_wr)
    if b: boosters.append({**b, 'direction': 'UP'})

print("\n--- Break DOWN + 15m Confirm Boosters ---")
for tier in ['EXTREME_POS', 'HIGH_POS', 'MID_POS', 'LOW_POS', 'NEG', 'DEEP_NEG']:
    b = check_booster(f'GEX {tier}', base_down['gex_tier']==tier, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for level in ['POS', 'NEG']:
    b = check_booster(f'GEX Binary {level}', base_down['gex_binary']==level, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for tier in ['HIGH', 'NORMAL', 'LOW']:
    b = check_booster(f'RVOL {tier}', base_down['rvol_tier']==tier, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for tilt in ['BULLISH', 'BEARISH']:
    b = check_booster(f'Tilt {tilt}', base_down['tilt_label']==tilt, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for gap in ['UP', 'DOWN', 'FLAT']:
    b = check_booster(f'Gap {gap}', base_down['gap_dir']==gap, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for t in ['EARLY', 'NORMAL', 'LATE']:
    b = check_booster(f'Timing {t}', base_down['timing_tier']==t, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

for d in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']:
    b = check_booster(f'DOW {d}', base_down['dow']==d, base_down, 'close_below_open', base_down_wr)
    if b: boosters.append({**b, 'direction': 'DOWN'})

# ═══════════════════════════════════════════════════════════════
# PHASE 4: Confluence WR Curve
# ═══════════════════════════════════════════════════════════════
print("\n" + "="*60)
print("PHASE 4: Confluence WR Curve")
print("="*60)

# Use significant boosters only
sig_boosters_up = [b for b in boosters if b['significant'] and b['direction'] == 'UP' and b['diff_pp'] > 0]
sig_boosters_down = [b for b in boosters if b['significant'] and b['direction'] == 'DOWN' and b['diff_pp'] > 0]

print(f"\nSignificant UP boosters: {[b['name'] for b in sig_boosters_up]}")
print(f"Significant DOWN boosters: {[b['name'] for b in sig_boosters_down]}")

# Count supporting factors for each confirmed break up
def count_factors(row, booster_list):
    """Count how many boosters in booster_list this row satisfies."""
    count = 0
    for b in booster_list:
        name = b['name']
        if 'GEX Binary' in name:
            level = name.split()[-1]
            if row.get('gex_binary') == level: count += 1
        elif 'GEX ' in name and 'Binary' not in name:
            tier = name.replace('GEX ', '')
            if row.get('gex_tier') == tier: count += 1
        elif 'RVOL' in name:
            tier = name.split()[-1]
            if row.get('rvol_tier') == tier: count += 1
        elif 'Tilt' in name:
            tilt = name.split()[-1]
            if row.get('tilt_label') == tilt: count += 1
        elif 'Gap' in name:
            gap = name.split()[-1]
            if row.get('gap_dir') == gap: count += 1
        elif 'Timing' in name:
            timing = name.split()[-1]
            if row.get('timing_tier') == timing: count += 1
        elif 'DOW' in name:
            dow = name.split()[-1]
            if row.get('dow') == dow: count += 1
    return count
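# NOTE: count_factors keys off the booster display names built above, so
# renaming a label silently breaks the matching. Example: a row with
# gex_binary='POS' and rvol_tier='HIGH' scores 2 when both 'GEX Binary POS'
# and 'RVOL HIGH' are in booster_list.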

# For UP confirmed breaks
if len(sig_boosters_up) > 0:
    base_up['n_factors'] = base_up.apply(lambda r: count_factors(r, sig_boosters_up), axis=1)
    print("\n--- UP Confluence Curve ---")
    confluence_up = {}
    for n in sorted(base_up['n_factors'].unique()):
        sub = base_up[base_up['n_factors'] == n]
        wr = sub['close_above_open'].mean() * 100
        confluence_up[int(n)] = {'wr': round(wr, 1), 'n': len(sub)}
        print(f"  {n} factors: {wr:.1f}% (N={len(sub)})")
    results['confluence_up'] = confluence_up

if len(sig_boosters_down) > 0:
    base_down['n_factors'] = base_down.apply(lambda r: count_factors(r, sig_boosters_down), axis=1)
    print("\n--- DOWN Confluence Curve ---")
    confluence_down = {}
    for n in sorted(base_down['n_factors'].unique()):
        sub = base_down[base_down['n_factors'] == n]
        wr = sub['close_below_open'].mean() * 100
        confluence_down[int(n)] = {'wr': round(wr, 1), 'n': len(sub)}
        print(f"  {n} factors: {wr:.1f}% (N={len(sub)})")
    results['confluence_down'] = confluence_down
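# NOTE: the counted factors are not independent (e.g. a GEX tier and the GEX
# binary flag can both fire on the same day), so the confluence curve may
# double-count correlated signals.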

# ═══════════════════════════════════════════════════════════════
# Save results
# ═══════════════════════════════════════════════════════════════
results['boosters'] = boosters
results['summary'] = {
    'total_dates': len(df_orb),
    'is_dates': len(is_dates),
    'oos_dates': len(oos_dates),
    'trace_dates_matched': int(df_orb['date'].isin(trace_dates.keys()).sum()),
    'break_type_counts': df_orb['break_type'].value_counts().to_dict(),
}

# Convert any numpy types for JSON serialization
def convert_types(obj):
    if isinstance(obj, (np.integer,)): return int(obj)
    if isinstance(obj, (np.floating,)): return float(obj)
    if isinstance(obj, np.bool_): return bool(obj)
    if isinstance(obj, dict): return {k: convert_types(v) for k, v in obj.items()}
    if isinstance(obj, list): return [convert_types(i) for i in obj]
    return obj
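# np.ndarray / tuple values are not converted above; json.dump's default=str
# below catches any stragglers.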

results = convert_types(results)

with open(f'{BASE}/orb_tree_clean_retest.json', 'w') as f:
    json.dump(results, f, indent=2, default=str)
print(f"\nSaved: {BASE}/orb_tree_clean_retest.json")

# ─── Generate Report ───
print("\nGenerating report...")

sig_up_boosters = [b for b in boosters if b['significant'] and b['direction'] == 'UP']
sig_down_boosters = [b for b in boosters if b['significant'] and b['direction'] == 'DOWN']

report = f"""# ORB Decision Tree — Clean Retest Report
**Date:** 2026-03-15
**Data:** SPX 5-min (Jun 2024–Feb 2026), ES 1-min (Jan 2025–Mar 2026), TRACE Normalized (443 files)
**FOMC dates excluded**
**IS/OOS split:** 60/40 by date (IS: {len(is_dates)} days, OOS: {len(oos_dates)} days)

## Executive Summary

Dead conviction boosters removed: #2 Gamma Shift (IC -0.710), #3 Institutional gamma gap, #4 0DTE participant asymmetry, #5 Cum MM post-ORB, #6 Confluence WR curve.

Kept: #1 GEX regime POS/NEG (confirmed from Databento).

Retested core ORB metrics on clean price data and evaluated 7 new potential factors.

---

## Phase 1: Core ORB Metrics (Clean Retest)

| Metric | ALL WR% | N | IS WR% | IS N | OOS WR% | OOS N |
|--------|---------|---|--------|------|---------|-------|
| Break Up → close > open | {r1['all']['wr']}% | {r1['all']['n']} | {r1['IS']['wr']}% | {r1['IS']['n']} | {r1['OOS']['wr']}% | {r1['OOS']['n']} |
| Break Down → close < open | {r2['all']['wr']}% | {r2['all']['n']} | {r2['IS']['wr']}% | {r2['IS']['n']} | {r2['OOS']['wr']}% | {r2['OOS']['n']} |
| Break Up + 15m confirm | {r3['all']['wr']}% | {r3['all']['n']} | {r3['IS']['wr']}% | {r3['IS']['n']} | {r3['OOS']['wr']}% | {r3['OOS']['n']} |
| Break Down + 15m confirm | {r4['all']['wr']}% | {r4['all']['n']} | {r4['IS']['wr']}% | {r4['IS']['n']} | {r4['OOS']['wr']}% | {r4['OOS']['n']} |
| Failed Up → fade | {r_fail_up['all']['wr']}% | {r_fail_up['all']['n']} | {r_fail_up['IS']['wr']}% | {r_fail_up['IS']['n']} | {r_fail_up['OOS']['wr']}% | {r_fail_up['OOS']['n']} |
| Failed Down → fade | {r_fail_down['all']['wr']}% | {r_fail_down['all']['n']} | {r_fail_down['IS']['wr']}% | {r_fail_down['IS']['n']} | {r_fail_down['OOS']['wr']}% | {r_fail_down['OOS']['n']} |

### Break Type Distribution
"""
for bt, count in df_orb['break_type'].value_counts().items():
    report += f"- **{bt}**: {count} days ({count/len(df_orb)*100:.1f}%)\n"

report += f"""
---

## Phase 2: Factor Analysis

### Factor A: GEX Magnitude Tiers
"""

# Add cross-tab results (shared formatter: missing WRs render as an em-dash
# instead of the literal string "None%")
def fmt_wr(v):
    """Format a possibly-missing WR value for the markdown tables."""
    return f"{v}%" if v is not None else '—'

def cross_tab_table(key, lbl, first_col):
    """Render one cross-tab result dict as a markdown table."""
    out = f"\n**{lbl}:**\n\n| {first_col} | ALL WR% | N | IS WR% | OOS WR% |\n"
    out += f"|{'-' * len(first_col)}|---------|---|--------|--------|\n"
    for level, vals in results[key].items():
        out += (f"| {level} | {vals['wr_all']}% | {vals['n']} | "
                f"{fmt_wr(vals['wr_is'])} | {fmt_wr(vals['wr_oos'])} |\n")
    return out

for key, lbl in [('gex_tier_break_up', 'Break UP WR by GEX Tier'),
                 ('gex_tier_break_down', 'Break DOWN WR by GEX Tier'),
                 ('gex_tier_break_up_conf', 'Break UP + Confirm by GEX Tier')]:
    report += cross_tab_table(key, lbl, 'Tier')

report += "\n### Factor B: RVOL\n"
for key, lbl in [('rvol_break_up', 'Break UP by RVOL'), ('rvol_break_down', 'Break DOWN by RVOL')]:
    report += cross_tab_table(key, lbl, 'Tier')

report += "\n### Factor C: Gamma Tilt\n"
for key, lbl in [('tilt_break_up', 'Break UP by Tilt'), ('tilt_break_down', 'Break DOWN by Tilt')]:
    report += cross_tab_table(key, lbl, 'Tilt')

report += "\n### Factor D: Gap Context\n"
for key, lbl in [('gap_break_up', 'Break UP by Gap'), ('gap_break_down', 'Break DOWN by Gap')]:
    report += cross_tab_table(key, lbl, 'Gap')

report += "\n### Factor E: Day of Week\n"
for key, lbl in [('dow_break_up', 'Break UP by DOW'), ('dow_break_down', 'Break DOWN by DOW')]:
    report += cross_tab_table(key, lbl, 'Day')

report += "\n### Factor G: Break Timing\n"
for key, lbl in [('timing_break_up', 'Break UP by Timing'), ('timing_break_down', 'Break DOWN by Timing')]:
    report += cross_tab_table(key, lbl, 'Timing')

report += f"""
---

## Phase 3: Conviction Boosters

### Base Rates
- **Break UP + 15m confirm:** {base_up_wr:.1f}% (N={base_up_n})
- **Break DOWN + 15m confirm:** {base_down_wr:.1f}% (N={base_down_n})

### Significant Boosters/Degraders (>3pp effect, N≥30; IS/OOS WRs shown as a consistency check)

#### Break UP + 15m Confirm
"""

def booster_table(blist):
    """Render a booster list as a markdown table, header included."""
    if not blist:
        return "*No factors met the significance threshold (>3pp, N≥30).*\n"
    out = "| Factor | Effect | IS/OOS |\n|--------|--------|--------|\n"
    for b in blist:
        out += (f"| {b['name']} | {b['diff_pp']:+.1f}pp → {b['wr_with']}% "
                f"(N={b['n_with']}) | IS={fmt_wr(b['wr_is'])} OOS={fmt_wr(b['wr_oos'])} |\n")
    return out

report += booster_table(sig_up_boosters)

report += "\n#### Break DOWN + 15m Confirm\n"
report += booster_table(sig_down_boosters)

report += "\n### All Tested Factors (sorted by effect size)\n\n"
report += "| Direction | Factor | WR With | N | Diff vs Base | IS WR | OOS WR | Sig? |\n"
report += "|-----------|--------|---------|---|-------------|-------|--------|------|\n"
for b in sorted(boosters, key=lambda x: abs(x['diff_pp']), reverse=True):
    sig = '✓' if b['significant'] else ''
    report += (f"| {b['direction']} | {b['name']} | {b['wr_with']}% | {b['n_with']} | "
               f"{b['diff_pp']:+.1f}pp | {fmt_wr(b['wr_is'])} | {fmt_wr(b['wr_oos'])} | {sig} |\n")

report += f"""
---

## Phase 4: Confluence WR Curve
"""

if 'confluence_up' in results:
    report += "\n### Break UP + Confirm — By # Supporting Factors\n\n"
    report += "| # Factors | WR% | N |\n|-----------|-----|---|\n"
    for n, vals in sorted(results['confluence_up'].items()):
        report += f"| {n} | {vals['wr']}% | {vals['n']} |\n"

if 'confluence_down' in results:
    report += "\n### Break DOWN + Confirm — By # Supporting Factors\n\n"
    report += "| # Factors | WR% | N |\n|-----------|-----|---|\n"
    for n, vals in sorted(results['confluence_down'].items()):
        report += f"| {n} | {vals['wr']}% | {vals['n']} |\n"

report += f"""
---

## Methodology Notes

- **OR15** = 9:30-9:45 ET high/low from 1-min (ES) or 5-min (SPX) bars
- **Break** = bar close above OR15 high (up) or below OR15 low (down) after 9:45
- **15-min confirmation** = every bar close in the 15 min after the break stays on the break side
- **Outcome** = close above open (for UP) or close below open (for DOWN) at 16:00 ET
- **Failed break** = broke out but closed back inside OR range
- **GEX tiers** = quantile-based from TRACE normalized MM gamma sum at 9:30 ET
- **Tilt** = % of strikes with positive MM gamma at 9:30 ET (≥75% = bullish)
- **RVOL** = ORB period volume / 20-day trailing avg ORB volume
- **FOMC dates excluded** throughout
- **IS/OOS** = first 60% of dates (IS), last 40% (OOS)
- **ES 1-min preferred** (Jan 2025+), SPX 5-min fallback (Jun 2024+)

## Dead Boosters Removed
1. ❌ **#2 Gamma Shift agreement** (IC -0.710 → dead)
2. ❌ **#3 Institutional gamma gap filter** (unverified methodology)
3. ❌ **#4 0DTE participant asymmetry** (dead signal)
4. ❌ **#5 Cum MM post-ORB hold/bail** (dead signal)
5. ❌ **#6 Confluence WR curve** (built on corrupted data)

## Retained
- ✅ **#1 GEX regime POS/NEG** (confirmed from Databento, retested here)
"""

with open(f'{BASE}/orb_tree_clean_retest_report.md', 'w') as f:
    f.write(report)
print(f"Saved: {BASE}/orb_tree_clean_retest_report.md")

print("\n✅ All phases complete!")
