
#!/usr/bin/env python3
"""Worker: process one date for Study 3."""
import gc
import json
import os
import sys
import time
from collections import deque
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd

# Command-line contract: <date> <output-json> <spx-csv> <cbbo-root-dir>.
DATE_STR = sys.argv[1]
OUTPUT_FILE = sys.argv[2]
SPX_FILE = sys.argv[3]
CBBO_DIR = Path(sys.argv[4])

# Detection parameters: a second counts as "elevated" when its median spread
# is at least SUSTAINED_THRESHOLD times the rolling baseline; an event must
# last MIN_DURATION_SEC seconds; the baseline is a BASELINE_WINDOW-second
# rolling median.
SUSTAINED_THRESHOLD = 1.5
MIN_DURATION_SEC = 120
BASELINE_WINDOW = 60

# int64 "undefined" sentinel used by the data feed (max signed 64-bit value).
UNDEF = 2**63 - 1


def load_spx_1min():
    """Load SPX 1-minute bars from SPX_FILE.

    Parses the "datetime" column as UTC, sorts chronologically, and adds an
    epoch-second int64 column "ts" used for searchsorted lookups.
    """
    bars = pd.read_csv(SPX_FILE)
    bars["datetime"] = pd.to_datetime(bars["datetime"], utc=True)
    bars = bars.sort_values("datetime").reset_index(drop=True)
    epoch = pd.Timestamp("1970-01-01", tz="UTC")
    bars["ts"] = (bars["datetime"] - epoch).dt.total_seconds().astype(np.int64)
    return bars


def spx_return_after(spx_df, ts_epoch, minutes):
    """SPX close-price change from ts_epoch to ts_epoch + minutes.

    Each endpoint snaps to the most recent 1-minute bar at or before the
    target timestamp. Returns None when either endpoint precedes the first
    bar. Note this is a price difference in index points, not a percentage.
    """
    ts_values = spx_df["ts"].values
    closes = spx_df["close"]
    n_bars = len(spx_df)

    start_idx = int(np.searchsorted(ts_values, ts_epoch, side="right")) - 1
    end_idx = int(np.searchsorted(ts_values, ts_epoch + minutes * 60, side="right")) - 1

    if not (0 <= start_idx < n_bars and 0 <= end_idx < n_bars):
        return None
    return float(closes.iloc[end_idx] - closes.iloc[start_idx])


def find_cbbo_file(date_str):
    """Locate the CBBO .dbn.zst file for date_str under CBBO_DIR.

    Tries the flat renamed layout first, then falls back to scanning
    OPRA-* batch-download directories. Returns a Path, or None when no
    candidate exists.
    """
    candidates = [CBBO_DIR / f"spxw_cbbo_1s_{date_str}.dbn.zst"]
    candidates.extend(
        opra_dir / f"opra-pillar-{date_str}.cbbo-1s.dbn.zst"
        for opra_dir in CBBO_DIR.glob("OPRA-*")
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def make_event(date_str, side, state, end_sec, spx_df):
    """Build an event record for a completed sustained-spread episode.

    Args:
        date_str: Trading date string, copied into the record.
        side: Event label ("PUT_SUSTAINED", "CALL_SUSTAINED", "BOTH_SUSTAINED").
        state: State-machine dict providing "start" (epoch sec),
            "consecutive" (elevated-second count), and optionally "max_ratio".
        end_sec: Epoch second at which the episode ended.
        spx_df: SPX 1-minute bars (from load_spx_1min) for return lookups.

    Returns:
        Dict with clock-time labels, duration, peak ratio, and SPX price
        changes measured from both the start and the end of the episode.
    """
    start_ts = state["start"]
    # datetime.utcfromtimestamp() is deprecated since Python 3.12; build an
    # aware UTC datetime instead — the strftime output is identical.
    fmt = "%H:%M:%S"
    start_label = datetime.fromtimestamp(start_ts, tz=timezone.utc).strftime(fmt)
    end_label = datetime.fromtimestamp(end_sec, tz=timezone.utc).strftime(fmt)
    # NOTE(review): the keys say "_et" but these are UTC wall-clock times;
    # the ET conversion happens only inside process_date's RTH filter.
    # Confirm whether downstream consumers expect UTC here.
    return {
        "date": date_str,
        "side": side,
        "start_et": start_label,
        "end_et": end_label,
        "duration_sec": state["consecutive"],
        "max_ratio": round(state.get("max_ratio", 0), 3),
        "start_ts": int(start_ts),
        "end_ts": int(end_sec),
        "ret_from_start_10m": spx_return_after(spx_df, start_ts, 10),
        "ret_from_start_30m": spx_return_after(spx_df, start_ts, 30),
        "ret_from_start_60m": spx_return_after(spx_df, start_ts, 60),
        "ret_from_end_10m": spx_return_after(spx_df, end_sec, 10),
        "ret_from_end_30m": spx_return_after(spx_df, end_sec, 30),
        "ret_from_end_60m": spx_return_after(spx_df, end_sec, 60),
    }


def process_date(date_str: str, spx_df: pd.DataFrame) -> list:
    """Scan one day's 1-second CBBO stream for sustained spread-widening events.

    For every wall-clock second inside regular trading hours, computes the
    median quoted spread across put instruments and across call instruments,
    compares each against a rolling-median baseline of the preceding
    BASELINE_WINDOW seconds, and advances three state machines:

      * PUT_SUSTAINED  — put ratio elevated while call ratio is not
      * CALL_SUSTAINED — call ratio elevated while put ratio is not
      * BOTH_SUSTAINED — both ratios elevated simultaneously

    A run lasting at least MIN_DURATION_SEC elevated seconds is emitted as an
    event (see make_event) when it ends, or when the stream is exhausted.

    Args:
        date_str: Trading date; compared lexically against "2026-03-09" for
            the DST cutoff, so it is assumed to be in YYYY-MM-DD form.
        spx_df: SPX 1-minute bars from load_spx_1min() ("ts"/"close" columns).

    Returns:
        List of event dicts; empty when no CBBO file exists for the date.
    """
    # Imported lazily so the heavyweight dependency is only required when a
    # date is actually processed.
    import databento as db

    cbbo_file = find_cbbo_file(date_str)
    if not cbbo_file:
        print(f"  {date_str}: no CBBO file", file=sys.stderr)
        return []

    print(f"  {date_str}: loading store...", file=sys.stderr)
    t0 = time.time()
    store = db.DBNStore.from_file(str(cbbo_file))

    # Build instrument_id -> put/call mapping from metadata
    inst_is_put = {}
    md = store.metadata
    if md.mappings:
        for raw_sym, mapping_list in md.mappings.items():
            raw_sym_str = str(raw_sym).strip()
            # Collapse the fixed-width padding between root and option code so
            # a plain split() yields the code as the last token.
            parts = raw_sym_str.replace("  ", " ").split()
            is_put = False
            if len(parts) >= 2:
                code = parts[-1]
                if len(code) >= 7:
                    # Assumes an OSI-style code where the 7th character
                    # (index 6, after YYMMDD) is the P/C flag —
                    # NOTE(review): confirm against the feed's symbology.
                    pc_char = code[6]
                    is_put = pc_char == "P"
                else:
                    # Fallback for short codes: any "P" marks a put.
                    is_put = "P" in code
            for m in mapping_list:
                # NOTE(review): assumes each mapping interval's "symbol" field
                # holds the numeric instrument_id — verify against databento's
                # metadata mappings schema.
                inst_id = int(m["symbol"])
                inst_is_put[inst_id] = is_put

    n_puts = sum(inst_is_put.values())
    n_calls = sum(1 for v in inst_is_put.values() if not v)
    print(f"  {date_str}: {len(inst_is_put)} instruments ({n_puts} puts, {n_calls} calls)", file=sys.stderr)

    # RTH bounds
    rth_start_h, rth_start_m = 9, 30
    rth_end_h, rth_end_m = 16, 0

    # Rolling per-second median spreads; the +10 slack keeps a full
    # BASELINE_WINDOW of history available alongside the current second.
    put_spread_buffer = deque(maxlen=BASELINE_WINDOW + 10)
    call_spread_buffer = deque(maxlen=BASELINE_WINDOW + 10)

    # Accumulators for the wall-clock second currently being filled.
    current_sec = None
    sec_put_spreads = []
    sec_call_spreads = []

    # One state machine per event type; "consecutive" counts elevated seconds.
    states = {
        "PUT_SUSTAINED": {"in_event": False, "start": None, "consecutive": 0, "max_ratio": 0},
        "CALL_SUSTAINED": {"in_event": False, "start": None, "consecutive": 0, "max_ratio": 0},
        "BOTH_SUSTAINED": {"in_event": False, "start": None, "consecutive": 0, "max_put": 0, "max_call": 0},
    }

    events = []
    n_records = 0
    n_valid_secs = 0

    print(f"  {date_str}: streaming records...", file=sys.stderr)

    for rec in store:
        n_records += 1
        ts_ns = rec.ts_event
        # Skip records whose timestamp is the int64 "undefined" sentinel.
        if ts_ns >= UNDEF:
            continue

        ts_sec = ts_ns // 1_000_000_000

        # RTH filter (rough UTC to ET: -5 in winter, -4 in summer)
        # March dates are EST (-5), April dates are EDT (-4)
        dt = datetime.utcfromtimestamp(ts_sec)
        # Use -4 for dates >= March 9 2026 (DST starts), -5 before
        # NOTE(review): lexicographic comparison — valid only for YYYY-MM-DD
        # date strings, and the DST cutoff is hard-coded for 2026.
        is_dst = date_str >= "2026-03-09"
        et_offset = 4 if is_dst else 5
        hour_et = dt.hour - et_offset
        if hour_et < 0:
            hour_et += 24
        min_et = dt.minute

        # Keep only 09:30:00 <= ET < 16:00:00.
        if hour_et < rth_start_h or (hour_et == rth_start_h and min_et < rth_start_m):
            continue
        if hour_et > rth_end_h or (hour_et == rth_end_h and min_et >= rth_end_m):
            continue

        # Get spread from first level
        # Prices share the same int64 "undefined" sentinel as timestamps.
        l = rec.levels[0]
        if l.bid_px >= UNDEF or l.ask_px >= UNDEF or l.bid_px <= 0 or l.ask_px <= 0:
            continue

        # Fixed-point prices scaled by 1e-9; drop crossed (negative) spreads.
        spread = (l.ask_px - l.bid_px) / 1e9
        if spread < 0:
            continue

        inst_id = rec.instrument_id
        is_put = inst_is_put.get(inst_id)
        if is_put is None:
            # Instrument absent from the metadata mapping — cannot classify.
            continue

        if current_sec is None:
            current_sec = ts_sec

        if ts_sec != current_sec:
            # Process the completed second
            # Only seconds with BOTH put and call quotes count toward the
            # baseline and the detectors.
            if sec_put_spreads and sec_call_spreads:
                put_med = float(np.median(sec_put_spreads))
                call_med = float(np.median(sec_call_spreads))

                put_spread_buffer.append((current_sec, put_med))
                call_spread_buffer.append((current_sec, call_med))
                n_valid_secs += 1

                # Warm-up: require at least 30 buffered seconds before any
                # detection, even though the full baseline window is 60.
                if len(put_spread_buffer) >= 30 and len(call_spread_buffer) >= 30:
                    # Baseline excludes the just-appended current second
                    # ([:-1]) and takes up to the trailing BASELINE_WINDOW.
                    put_baseline_vals = [v for _, v in list(put_spread_buffer)[:-1]][-BASELINE_WINDOW:]
                    call_baseline_vals = [v for _, v in list(call_spread_buffer)[:-1]][-BASELINE_WINDOW:]

                    put_baseline = float(np.median(put_baseline_vals)) if put_baseline_vals else put_med
                    call_baseline = float(np.median(call_baseline_vals)) if call_baseline_vals else call_med

                    # Epsilon guards division by a zero baseline.
                    put_ratio = put_med / (put_baseline + 1e-9)
                    call_ratio = call_med / (call_baseline + 1e-9)

                    # --- PUT_SUSTAINED ---
                    # Exclusive condition: puts elevated while calls are not.
                    st = states["PUT_SUSTAINED"]
                    if put_ratio >= SUSTAINED_THRESHOLD and call_ratio < SUSTAINED_THRESHOLD:
                        if not st["in_event"]:
                            st["in_event"] = True
                            st["start"] = current_sec
                            st["consecutive"] = 1
                            st["max_ratio"] = put_ratio
                        else:
                            st["consecutive"] += 1
                            st["max_ratio"] = max(st["max_ratio"], put_ratio)
                    else:
                        # Run ended: emit if it met the minimum duration.
                        if st["in_event"] and st["consecutive"] >= MIN_DURATION_SEC:
                            events.append(make_event(date_str, "PUT_SUSTAINED", st, current_sec, spx_df))
                        st["in_event"] = False
                        st["consecutive"] = 0

                    # --- CALL_SUSTAINED ---
                    # Mirror image of PUT_SUSTAINED.
                    st = states["CALL_SUSTAINED"]
                    if call_ratio >= SUSTAINED_THRESHOLD and put_ratio < SUSTAINED_THRESHOLD:
                        if not st["in_event"]:
                            st["in_event"] = True
                            st["start"] = current_sec
                            st["consecutive"] = 1
                            st["max_ratio"] = call_ratio
                        else:
                            st["consecutive"] += 1
                            st["max_ratio"] = max(st["max_ratio"], call_ratio)
                    else:
                        if st["in_event"] and st["consecutive"] >= MIN_DURATION_SEC:
                            events.append(make_event(date_str, "CALL_SUSTAINED", st, current_sec, spx_df))
                        st["in_event"] = False
                        st["consecutive"] = 0

                    # --- BOTH_SUSTAINED ---
                    # Both sides elevated; tracks separate put/call peaks.
                    st = states["BOTH_SUSTAINED"]
                    if put_ratio >= SUSTAINED_THRESHOLD and call_ratio >= SUSTAINED_THRESHOLD:
                        if not st["in_event"]:
                            st["in_event"] = True
                            st["start"] = current_sec
                            st["consecutive"] = 1
                            st["max_put"] = put_ratio
                            st["max_call"] = call_ratio
                        else:
                            st["consecutive"] += 1
                            st["max_put"] = max(st["max_put"], put_ratio)
                            st["max_call"] = max(st["max_call"], call_ratio)
                    else:
                        if st["in_event"] and st["consecutive"] >= MIN_DURATION_SEC:
                            ev = make_event(date_str, "BOTH_SUSTAINED", st, current_sec, spx_df)
                            # make_event reads "max_ratio", which this state
                            # lacks; overwrite with the larger of the peaks.
                            ev["max_ratio"] = round(max(st["max_put"], st["max_call"]), 3)
                            events.append(ev)
                        st["in_event"] = False
                        st["consecutive"] = 0

            # Start the new second. Gaps between valid seconds are not
            # treated specially: "consecutive" counts processed seconds,
            # not wall-clock seconds.
            current_sec = ts_sec
            sec_put_spreads = []
            sec_call_spreads = []

        # Accumulate
        if is_put:
            sec_put_spreads.append(spread)
        else:
            sec_call_spreads.append(spread)

        # NOTE(review): n_records counts every record, but this line is only
        # reached by records that survive the filters above, so a 10M-th
        # record that got skipped produces no progress line.
        if n_records % 10_000_000 == 0:
            elapsed = time.time() - t0
            print(f"    {date_str}: {n_records/1e6:.0f}M records, {n_valid_secs} secs, {len(events)} events, {elapsed:.0f}s", file=sys.stderr)

    # Close any open events at end of day
    # NOTE(review): the final partially-accumulated second is never flushed
    # (flushing happens only when a later second arrives), so its medians and
    # ratio update are dropped; open events are still closed here, ending at
    # the last flushed second.
    for side_name, st in states.items():
        if st["in_event"] and st["consecutive"] >= MIN_DURATION_SEC:
            ev = make_event(date_str, side_name, st, current_sec, spx_df)
            if side_name == "BOTH_SUSTAINED":
                ev["max_ratio"] = round(max(st.get("max_put", 0), st.get("max_call", 0)), 3)
            events.append(ev)

    elapsed = time.time() - t0
    print(f"  {date_str}: DONE — {n_records/1e6:.1f}M records, {n_valid_secs} secs, {len(events)} events, {elapsed:.0f}s", file=sys.stderr)

    return events


# ── Main ──
spx_df = load_spx_1min()
events = process_date(DATE_STR, spx_df)

# Persist results as JSON for the parent process to collect; default=str
# stringifies anything json can't encode natively.
with open(OUTPUT_FILE, "w") as fh:
    fh.write(json.dumps(events, indent=2, default=str))

print(f"  {DATE_STR}: wrote {len(events)} events to {OUTPUT_FILE}", file=sys.stderr)
