#!/usr/bin/env python3
"""HIRO Retail Flow vs Price — 5-Min Bar-by-Bar Fade Analysis"""

import json
import csv
import os
import requests
import numpy as np
from datetime import datetime, timezone, timedelta
from collections import defaultdict

DATA_DIR = os.environ.get("HIRO_DATA_DIR", "/Users/daniel/.openclaw/workspace/data")
HIRO_DIR = os.path.join(DATA_DIR, "hiro_history")
# NOTE(review): this API key is committed in plain text. Prefer exporting
# POLYGON_API_KEY and rotating the literal fallback out of the source.
POLYGON_KEY = os.environ.get("POLYGON_API_KEY", "LpewQUO2J2wIYlQVHQNkpLpxN1nuoB1R")

# FOMC dates are excluded from the analysis (the file holds {"dates": [...]}).
with open(os.path.join(DATA_DIR, "fomc_dates.json"), encoding="utf-8") as f:
    FOMC_DATES = set(json.load(f)["dates"])

# One HIRO snapshot per trading day, named "<YYYY-MM-DD>_sp500.json".
HIRO_SUFFIX = "_sp500.json"
hiro_files = sorted(name for name in os.listdir(HIRO_DIR) if name.endswith(HIRO_SUFFIX))
# Strip only the trailing suffix (str.replace would also hit an inner occurrence).
dates = [name[: -len(HIRO_SUFFIX)] for name in hiro_files]
dates = [d for d in dates if d not in FOMC_DATES]
print(f"Available dates (excl FOMC): {dates}")

# US Eastern time is EDT (UTC-4) from the second Sunday of March until the first
# Sunday of November (US rule since 2007), and EST (UTC-5) otherwise. The switch
# happens at 02:00 local time, so a whole-day offset is exact for RTH timestamps.
def _nth_sunday(year, month, n):
    """Return midnight of the ``n``-th (1-based) Sunday of ``month`` in ``year``."""
    first = datetime(year, month, 1)
    days_to_sunday = (6 - first.weekday()) % 7  # weekday(): Monday=0 ... Sunday=6
    return first + timedelta(days=days_to_sunday + 7 * (n - 1))


def get_et_offset(date_str):
    """Return the UTC offset, in hours, of US Eastern time on ``date_str``.

    Args:
        date_str: Trading date formatted as ``"YYYY-MM-DD"``.

    Returns:
        -4 while daylight saving time is in effect (EDT), otherwise -5 (EST).

    Raises:
        ValueError: if ``date_str`` is not in ``YYYY-MM-DD`` form.
    """
    day = datetime.strptime(date_str, "%Y-%m-%d")
    dst_start = _nth_sunday(day.year, 3, 2)   # second Sunday of March
    dst_end = _nth_sunday(day.year, 11, 1)    # first Sunday of November
    return -4 if dst_start <= day < dst_end else -5

def make_5min_floor(ts_ms, et_offset):
    """Floor an epoch-millisecond timestamp to the start of its 5-minute ET bar.

    Args:
        ts_ms: Epoch timestamp in milliseconds.
        et_offset: Whole-hour UTC offset of US Eastern time (e.g. -4 or -5).

    Returns:
        Epoch seconds (float, UTC) of the start of the enclosing 5-minute bar
        as measured on the ET wall clock.
    """
    # Shift into ET wall-clock time and view it through a UTC-aware datetime, so
    # .timestamp() below is exact. (A naive datetime's .timestamp() is
    # interpreted in the machine's local timezone, which skews the result.)
    et_dt = datetime.fromtimestamp(ts_ms / 1000 + et_offset * 3600, tz=timezone.utc)
    et_floored = et_dt.replace(minute=et_dt.minute - et_dt.minute % 5, second=0, microsecond=0)
    # Undo the ET shift to get back to a UTC epoch timestamp.
    return et_floored.timestamp() - et_offset * 3600

def is_rth(ts_utc, et_offset):
    """Return True if ``ts_utc`` falls in regular trading hours, 09:30 <= t < 16:00 ET.

    Args:
        ts_utc: Epoch timestamp in seconds.
        et_offset: Whole-hour UTC offset of US Eastern time (e.g. -4 or -5).
    """
    # Shift into ET wall-clock time and read it through a UTC-aware datetime;
    # datetime.utcfromtimestamp is deprecated since Python 3.12.
    et_dt = datetime.fromtimestamp(ts_utc + et_offset * 3600, tz=timezone.utc)
    minute_of_day = et_dt.hour * 60 + et_dt.minute
    return 570 <= minute_of_day < 960  # 9:30=570, 16:00=960

# Process each date: bucket HIRO retail flow and SPY prices into 5-min RTH bars,
# then attach forward SPY returns (in bps) to every bar.
all_bars = []

# Every RTH 5-min bar start from 09:30 through 15:55 ET (78 bars), in clock order.
bar_times = [f"{m // 60:02d}:{m % 60:02d}" for m in range(570, 960, 5)]
# Bar label -> position in the session; used to find forward bars by clock time.
bar_slot = {bt: k for k, bt in enumerate(bar_times)}
# (result key, horizon measured in 5-min bars)
fwd_horizons = [("fwd_30m", 6), ("fwd_1h", 12), ("fwd_2h", 24)]


def _et_bar_label(ts_utc, et_offset):
    """Return the "HH:MM" ET start of the 5-minute bar containing ``ts_utc`` (epoch s)."""
    et_dt = datetime.fromtimestamp(ts_utc + et_offset * 3600, tz=timezone.utc)
    return f"{et_dt.hour:02d}:{et_dt.minute - et_dt.minute % 5:02d}"


for date_str in dates:
    print(f"\nProcessing {date_str}...")
    et_offset = get_et_offset(date_str)

    # Load HIRO retail data
    with open(os.path.join(HIRO_DIR, f"{date_str}_sp500.json"), encoding="utf-8") as f:
        hiro = json.load(f)
    retail = hiro["retail"]

    # Bucket retail flow into 5-min bars (sum mid_signal per bar); the dict is
    # rebuilt per date, so the bar label alone is a sufficient key.
    bar_flow = defaultdict(float)
    for entry in retail:
        ts_utc = entry["utc_time"] / 1000  # utc_time is in milliseconds
        if not is_rth(ts_utc, et_offset):
            continue
        bar_flow[_et_bar_label(ts_utc, et_offset)] += entry["mid_signal"]

    # Get SPY 5-min bars from Polygon
    url = f"https://api.polygon.io/v2/aggs/ticker/SPY/range/5/minute/{date_str}/{date_str}"
    resp = requests.get(
        url,
        params={"adjusted": "true", "sort": "asc", "limit": 5000, "apiKey": POLYGON_KEY},
        timeout=30,
    )
    if not resp.ok:
        # Report HTTP failures (bad key, rate limit, ...) explicitly instead of
        # treating them as "no data". The URL is not printed: it carries the key.
        print(f"  SPY request failed for {date_str} (HTTP {resp.status_code}), skipping")
        continue
    spy_data = resp.json()

    if spy_data.get("resultsCount", 0) == 0:
        print(f"  No SPY data for {date_str}, skipping")
        continue

    # Build SPY price series keyed by ET bar label; the day's reference open is
    # the open of the first RTH bar returned.
    spy_bars = {}
    spy_open_price = None
    for bar in spy_data["results"]:
        ts_utc = bar["t"] / 1000
        if not is_rth(ts_utc, et_offset):
            continue
        spy_bars[_et_bar_label(ts_utc, et_offset)] = {
            "open": bar["o"],
            "high": bar["h"],
            "low": bar["l"],
            "close": bar["c"],
            "volume": bar["v"],
        }
        if spy_open_price is None:
            spy_open_price = bar["o"]

    if not spy_bars or spy_open_price is None:
        print(f"  No RTH SPY bars for {date_str}, skipping")
        continue

    # Walk the full session so cumulative flow keeps accruing through bars
    # that have no SPY print; only bars with a SPY print are emitted.
    cumulative = 0.0
    day_bars = []
    for bt in bar_times:
        flow = bar_flow.get(bt, 0.0)
        cumulative += flow
        spy = spy_bars.get(bt)
        if spy is None:
            continue
        day_bars.append({
            "date": date_str,
            "time": bt,
            "flow_cumulative": cumulative,
            "flow_delta": flow,  # this bar's own 5-min flow
            "spy_close": spy["close"],
            "spy_open_day": spy_open_price,
            "price_from_open_pct": (spy["close"] / spy_open_price - 1) * 100,  # % return from open
            "bar_index": len(day_bars),
        })

    # Forward returns are looked up by clock time rather than list position, so
    # a missing SPY bar cannot silently stretch a "30m" horizon to 35m or more.
    close_by_slot = {bar_slot[b["time"]]: b["spy_close"] for b in day_bars}
    for bar in day_bars:
        slot = bar_slot[bar["time"]]
        for horizon_name, horizon_bars in fwd_horizons:
            fwd_close = close_by_slot.get(slot + horizon_bars)
            bar[horizon_name] = (
                None if fwd_close is None else (fwd_close / bar["spy_close"] - 1) * 10000  # bps
            )

    all_bars.extend(day_bars)
    print(f"  {len(day_bars)} RTH bars, cumulative range: [{min(b['flow_cumulative'] for b in day_bars):.0f}, {max(b['flow_cumulative'] for b in day_bars):.0f}]")

print(f"\n{'='*80}")
print(f"Total bars across all days: {len(all_bars)}")
if not all_bars:
    # Everything below (percentiles, correlations) needs at least one bar; stop
    # with a clear message instead of an IndexError from np.percentile.
    raise SystemExit("No bars collected (no HIRO dates with SPY data); nothing to analyze.")

# Predicate: does a bar have a forward return at the given horizon?
def has_fwd(bar, horizon):
    """Return True when ``bar[horizon]`` holds a forward return (is not None).

    Raises:
        KeyError: if ``horizon`` is not a key of ``bar``.
    """
    fwd_value = bar[horizon]
    return not (fwd_value is None)

# ------------------------------------------------------------
# ANALYSIS 1: forward returns by cumulative-flow quintile
# ------------------------------------------------------------
print(f"\n{'='*80}")
print("ANALYSIS 1: Forward Returns by Cumulative Retail Flow Quintile")
print("="*80)

# Six edges (0th, 20th, ..., 100th percentile) of cumulative flow, pooled
# across every bar of every day.
cumulatives = [bar["flow_cumulative"] for bar in all_bars]
quintile_breaks = [np.percentile(cumulatives, pct) for pct in range(0, 101, 20)]
print(f"Quintile breaks: {[f'${edge/1e6:.1f}M' for edge in quintile_breaks]}")

def get_quintile(val, breaks=None):
    """Return the 1-based bucket that ``val`` falls into.

    Args:
        val: Value to bucket.
        breaks: Ascending bucket edges ``[min, e1, ..., e_{k-1}, max]`` for k
            buckets. Defaults to the module-level ``quintile_breaks`` (k = 5).

    Returns:
        The first bucket ``i`` (1..k-1) whose upper interior edge satisfies
        ``val <= edge``; otherwise ``k``. Values on an interior edge therefore
        land in the lower bucket, and anything above the last interior edge
        (even above ``max``) lands in the top bucket.
    """
    if breaks is None:
        breaks = quintile_breaks
    # Only the interior edges matter: the first/last are the observed min/max.
    for bucket, upper in enumerate(breaks[1:-1], start=1):
        if val <= upper:
            return bucket
    return len(breaks) - 1

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    print(f"\n  {horizon}:")
    print(f"  {'Quintile':<10} {'N':>6} {'Avg(bps)':>10} {'Med(bps)':>10} {'WR%':>8}")
    print("  " + "-" * 46)

    # Group this horizon's available forward returns by cumulative-flow quintile,
    # preserving bar order within each group.
    rets_by_quintile = defaultdict(list)
    for bar in all_bars:
        if bar[horizon] is not None:
            rets_by_quintile[get_quintile(bar["flow_cumulative"])].append(bar[horizon])

    for q in range(1, 6):
        rets = rets_by_quintile[q]
        if not rets:
            continue
        # Extreme quintiles get a sell/buy tag; every label is 12 characters wide.
        if q == 1:
            label = f"  Q{q} (sell) "
        elif q == 5:
            label = f"  Q{q} (buy)  "
        else:
            label = f"  Q{q}        "
        win_pct = sum(1 for r in rets if r > 0) / len(rets) * 100
        print(label + f"{len(rets):>6} {np.mean(rets):>10.2f} {np.median(rets):>10.2f} {win_pct:>7.1f}%")

# ------------------------------------------------------------
# ANALYSIS 2: cumulative flow direction vs price direction (fade detection)
# ------------------------------------------------------------
# "faded" = the sign of retail's cumulative flow disagrees with the day's price
# move so far; "right" = they agree. A price change of exactly 0 counts as DOWN,
# and bars with exactly zero cumulative flow fall in no bucket.
print(f"\n{'='*80}")
print("ANALYSIS 2: Retail Flow Direction + Price Direction (Fade Detection)")
print("="*80)

categories = {
    "SELL+UP (faded)": lambda bar: bar["flow_cumulative"] < 0 and bar["price_from_open_pct"] > 0,
    "SELL+DOWN (right)": lambda bar: bar["flow_cumulative"] < 0 and bar["price_from_open_pct"] <= 0,
    "BUY+DOWN (faded)": lambda bar: bar["flow_cumulative"] > 0 and bar["price_from_open_pct"] <= 0,
    "BUY+UP (right)": lambda bar: bar["flow_cumulative"] > 0 and bar["price_from_open_pct"] > 0,
}

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    print(f"\n  {horizon}:")
    print(f"  {'Category':<22} {'N':>6} {'Avg(bps)':>10} {'Med(bps)':>10} {'WR%':>8}")
    print("  " + "-" * 58)
    for cat_name, cat_fn in categories.items():
        matched = [bar[horizon] for bar in all_bars if cat_fn(bar) and bar[horizon] is not None]
        if not matched:
            continue
        count = len(matched)
        win_pct = sum(1 for r in matched if r > 0) / count * 100
        print(
            f"  {cat_name:<22} {count:>6} {np.mean(matched):>10.2f} "
            f"{np.median(matched):>10.2f} {win_pct:>7.1f}%"
        )

# ------------------------------------------------------------
# ANALYSIS 3: per-bar flow change (5-min delta) vs price direction
# ------------------------------------------------------------
# Same layout as Analysis 2, but split on the sign of the bar's own flow
# rather than the cumulative level. Bars with zero delta fall in no bucket.
print(f"\n{'='*80}")
print("ANALYSIS 3: Flow Momentum (5-min Delta) x Price Direction")
print("="*80)

momentum_cats = {
    "FLOW_INC + PRICE_UP": lambda bar: bar["flow_delta"] > 0 and bar["price_from_open_pct"] > 0,
    "FLOW_INC + PRICE_DOWN": lambda bar: bar["flow_delta"] > 0 and bar["price_from_open_pct"] <= 0,
    "FLOW_DEC + PRICE_UP": lambda bar: bar["flow_delta"] < 0 and bar["price_from_open_pct"] > 0,
    "FLOW_DEC + PRICE_DOWN": lambda bar: bar["flow_delta"] < 0 and bar["price_from_open_pct"] <= 0,
}

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    print(f"\n  {horizon}:")
    print(f"  {'Category':<25} {'N':>6} {'Avg(bps)':>10} {'Med(bps)':>10} {'WR%':>8}")
    print("  " + "-" * 61)
    for cat_name, cat_fn in momentum_cats.items():
        picked = [bar[horizon] for bar in all_bars if cat_fn(bar) and bar[horizon] is not None]
        if not picked:
            continue
        count = len(picked)
        win_pct = sum(1 for r in picked if r > 0) / count * 100
        print(
            f"  {cat_name:<25} {count:>6} {np.mean(picked):>10.2f} "
            f"{np.median(picked):>10.2f} {win_pct:>7.1f}%"
        )

# ------------------------------------------------------------
# ANALYSIS 4: forward returns by conviction (|cumulative flow|)
# ------------------------------------------------------------
# Direction is ignored here; thresholds are in dollars and every upper bound
# is inclusive.
print(f"\n{'='*80}")
print("ANALYSIS 4: Forward Returns by Conviction Level")
print("="*80)

conviction_levels = {
    "HIGH (|cum|>$400M)": lambda bar: abs(bar["flow_cumulative"]) > 400e6,
    "MODERATE ($200-400M)": lambda bar: 200e6 < abs(bar["flow_cumulative"]) <= 400e6,
    "LOW ($50-200M)": lambda bar: 50e6 < abs(bar["flow_cumulative"]) <= 200e6,
    "MINIMAL (<$50M)": lambda bar: abs(bar["flow_cumulative"]) <= 50e6,
}

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    print(f"\n  {horizon}:")
    print(f"  {'Level':<25} {'N':>6} {'Avg(bps)':>10} {'Med(bps)':>10} {'WR%':>8}")
    print("  " + "-" * 61)
    for level_name, level_fn in conviction_levels.items():
        hits = [bar[horizon] for bar in all_bars if level_fn(bar) and bar[horizon] is not None]
        if hits:
            count = len(hits)
            positive = sum(1 for r in hits if r > 0)
            print(
                f"  {level_name:<25} {count:>6} {np.mean(hits):>10.2f} "
                f"{np.median(hits):>10.2f} {positive / count * 100:>7.1f}%"
            )

# ------------------------------------------------------------
# ANALYSIS 4b: conviction with direction (signed cumulative flow)
# ------------------------------------------------------------
# Seven contiguous dollar bands from heavy selling to heavy buying; the
# neutral band includes both of its endpoints.
print(f"\n{'='*80}")
print("ANALYSIS 4b: Conviction + Direction (Signed Cumulative)")
print("="*80)

signed_conviction = {
    "HIGH SELL (cum<-$400M)": lambda bar: bar["flow_cumulative"] < -400e6,
    "MOD SELL (-$400M to -$200M)": lambda bar: -400e6 <= bar["flow_cumulative"] < -200e6,
    "LOW SELL (-$200M to -$50M)": lambda bar: -200e6 <= bar["flow_cumulative"] < -50e6,
    "NEUTRAL (-$50M to $50M)": lambda bar: -50e6 <= bar["flow_cumulative"] <= 50e6,
    "LOW BUY ($50M to $200M)": lambda bar: 50e6 < bar["flow_cumulative"] <= 200e6,
    "MOD BUY ($200M to $400M)": lambda bar: 200e6 < bar["flow_cumulative"] <= 400e6,
    "HIGH BUY (cum>$400M)": lambda bar: bar["flow_cumulative"] > 400e6,
}

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    print(f"\n  {horizon}:")
    print(f"  {'Level':<32} {'N':>6} {'Avg(bps)':>10} {'Med(bps)':>10} {'WR%':>8}")
    print("  " + "-" * 68)
    for level_name, level_fn in signed_conviction.items():
        band = [bar[horizon] for bar in all_bars if level_fn(bar) and bar[horizon] is not None]
        if not band:
            continue
        count = len(band)
        win_pct = sum(1 for r in band if r > 0) / count * 100
        print(
            f"  {level_name:<32} {count:>6} {np.mean(band):>10.2f} "
            f"{np.median(band):>10.2f} {win_pct:>7.1f}%"
        )

# ------------------------------------------------------------
# Information Coefficient (IC)
# ------------------------------------------------------------
# Correlation, pooled over all days, between each bar's cumulative retail flow
# and its forward return: Pearson (linear) and Spearman (rank).
print(f"\n{'='*80}")
print("INFORMATION COEFFICIENT: Cumulative Flow vs Forward Returns")
print("="*80)

from scipy import stats

for horizon in ("fwd_30m", "fwd_1h", "fwd_2h"):
    usable = [bar for bar in all_bars if bar[horizon] is not None]
    if len(usable) <= 2:
        continue  # too few points for a meaningful correlation
    flows = [bar["flow_cumulative"] for bar in usable]
    rets = [bar[horizon] for bar in usable]
    ic_pearson, p_pearson = stats.pearsonr(flows, rets)
    ic_spearman, p_spearman = stats.spearmanr(flows, rets)
    print(
        f"  {horizon}: Pearson IC={ic_pearson:.4f} (p={p_pearson:.4f}), "
        f"Spearman IC={ic_spearman:.4f} (p={p_spearman:.4f})"
    )

# ============================================================
# Save outputs
# ============================================================

def _summarize_bps(vals):
    """Summarize a non-empty list of forward returns in bps.

    Returns the count, the mean and median (2 dp), and the percentage of
    strictly positive returns (1 dp), matching the console tables.
    """
    return {
        "n": len(vals),
        "avg_bps": round(np.mean(vals), 2),
        "median_bps": round(np.median(vals), 2),
        "win_rate": round(sum(1 for v in vals if v > 0) / len(vals) * 100, 1),
    }


def _bucket_summaries(horizon, rules):
    """Summarize ``horizon`` forward returns for each named bar filter in ``rules``.

    ``rules`` maps bucket name -> predicate over a bar dict. A bucket with no bar
    that has a forward return at ``horizon`` is omitted, as in the console output.
    """
    summaries = {}
    for name, rule in rules.items():
        vals = [b[horizon] for b in all_bars if rule(b) and b[horizon] is not None]
        if vals:
            summaries[name] = _summarize_bps(vals)
    return summaries


# Quintile buckets expressed as predicates so they share the summary code path.
# (q=q binds the current value; a bare closure would see only the last q.)
quintile_rules = {
    f"Q{q}": (lambda b, q=q: get_quintile(b["flow_cumulative"]) == q)
    for q in range(1, 6)
}

results = {
    "metadata": {
        "dates": dates,
        "total_bars": len(all_bars),
        "description": "HIRO retail flow vs SPY price 5-min bar fade analysis",
        "quintile_breaks_dollars": quintile_breaks,
    },
    "quintile_analysis": {},
    "fade_analysis": {},
    "momentum_analysis": {},
    "conviction_analysis": {},
    "signed_conviction_analysis": {},
    "information_coefficients": {},
}

for horizon in ["fwd_30m", "fwd_1h", "fwd_2h"]:
    results["quintile_analysis"][horizon] = _bucket_summaries(horizon, quintile_rules)
    results["fade_analysis"][horizon] = _bucket_summaries(horizon, categories)
    # Momentum / conviction tables mirror Analyses 3, 4 and 4b.
    results["momentum_analysis"][horizon] = _bucket_summaries(horizon, momentum_cats)
    results["conviction_analysis"][horizon] = _bucket_summaries(horizon, conviction_levels)
    results["signed_conviction_analysis"][horizon] = _bucket_summaries(horizon, signed_conviction)

    # IC: pooled Pearson / Spearman correlation of cumulative flow vs forward return
    flows = [b["flow_cumulative"] for b in all_bars if b[horizon] is not None]
    rets = [b[horizon] for b in all_bars if b[horizon] is not None]
    if len(flows) > 2:
        ic_p, _ = stats.pearsonr(flows, rets)
        ic_s, _ = stats.spearmanr(flows, rets)
        results["information_coefficients"][horizon] = {
            "pearson": round(ic_p, 4), "spearman": round(ic_s, 4)
        }

with open(os.path.join(DATA_DIR, "hiro_5min_fade_backtest.json"), "w") as f:
    json.dump(results, f, indent=2)
print("\nSaved results to hiro_5min_fade_backtest.json")

# Save the merged per-bar table: one row per RTH bar that had a SPY print.
csv_path = os.path.join(DATA_DIR, "hiro_5min_bars_merged.csv")
fieldnames = [
    "date", "time",
    "flow_cumulative", "flow_delta",
    "spy_close", "spy_open_day", "price_from_open_pct",
    "fwd_30m", "fwd_1h", "fwd_2h",
]
with open(csv_path, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    # Project each bar onto the CSV columns; keys not listed (e.g. bar_index)
    # are dropped, and None values are written as empty cells.
    writer.writerows({col: bar.get(col) for col in fieldnames} for bar in all_bars)
print(f"Saved {len(all_bars)} bars to hiro_5min_bars_merged.csv")
