
#!/usr/bin/env python3
"""Worker: extract SPXW 0DTE trades for spike events on one date."""
import json
import sys
import time
from datetime import datetime, date as date_cls
from pathlib import Path

import numpy as np

# CLI contract: <date> <events_file> <output_file> <spxw_dir>
DATE_STR = sys.argv[1]        # trading date; must match events' "date" field and trade filenames
EVENTS_FILE = sys.argv[2]     # JSON file containing spike events for ALL dates
OUTPUT_FILE = sys.argv[3]     # destination for this date's enriched-events JSON
SPXW_DIR = Path(sys.argv[4])  # directory holding spxw_trades_<date>.dbn.zst files

# INT64_MAX — sentinel for an undefined ts_event (matches Databento's UNDEF timestamp — confirm)
UNDEF = 9223372036854775807
FLOW_WINDOW_SEC = 60  # 60 seconds after spike


def main():
    """Extract SPXW 0DTE trade flow around spike events for one date.

    Loads spike events for ``DATE_STR`` from ``EVENTS_FILE``, streams the
    day's SPXW trade records from ``SPXW_DIR``, accumulates put/call
    buy/sell volume in the ``FLOW_WINDOW_SEC`` window after each spike,
    and writes the enriched events to ``OUTPUT_FILE`` as a JSON list.
    Writes an empty list (and returns) when there are no events or no
    trade file for this date.
    """
    import databento as db  # deferred: heavy third-party import, only needed at run time

    t0 = time.time()

    # Load spike events and keep only this date's.
    with open(EVENTS_FILE) as f:
        all_events = json.load(f)
    events = [e for e in all_events if e["date"] == DATE_STR]

    if not events:
        print(f"  {DATE_STR}: no events", file=sys.stderr)
        _write_output(OUTPUT_FILE, [])
        return

    # Locate the day's SPXW trade file.
    trade_file = SPXW_DIR / f"spxw_trades_{DATE_STR}.dbn.zst"
    if not trade_file.exists():
        print(f"  {DATE_STR}: no trade file", file=sys.stderr)
        _write_output(OUTPUT_FILE, [])
        return

    print(f"  {DATE_STR}: loading {len(events)} events, trade file...", file=sys.stderr)
    store = db.DBNStore.from_file(str(trade_file))

    inst_is_put = _build_put_call_map(store.metadata)
    n_puts = sum(1 for is_put in inst_is_put.values() if is_put)
    n_calls = len(inst_is_put) - n_puts
    print(f"  {DATE_STR}: {len(inst_is_put)} instruments ({n_puts} puts, {n_calls} calls)", file=sys.stderr)

    # One flow accumulator per event, kept in the SAME order as `events`
    # so the enrichment step below can pair them by list index.
    # BUG FIX: the accumulators used to be sorted by ts before being
    # paired back with the (unsorted) events list by index, so flow stats
    # could be attributed to the wrong event whenever the events were not
    # already in timestamp order. The sort provided no speedup (every
    # trade scans all windows anyway), so it is simply removed.
    event_windows = [_new_window(e["ts"]) for e in events]
    min_ts = min(w["ts"] for w in event_windows)
    max_ts_end = max(w["ts_end"] for w in event_windows)

    print(f"  {DATE_STR}: streaming trades...", file=sys.stderr)
    n_records, n_matched = _accumulate_flow(store, inst_is_put, event_windows, min_ts, max_ts_end)

    elapsed = time.time() - t0
    print(f"  {DATE_STR}: DONE — {n_records/1e6:.1f}M records, {n_matched} matched trades, {elapsed:.0f}s", file=sys.stderr)

    # Pair each event with its accumulator (same index — see note above).
    results = [_enrich_event(e, w) for e, w in zip(events, event_windows)]
    _write_output(OUTPUT_FILE, results)
    print(f"  {DATE_STR}: wrote {len(results)} enriched events", file=sys.stderr)


def _write_output(path, payload):
    """Write *payload* to *path* as pretty-printed JSON."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2, default=str)


def _new_window(ts):
    """Return a zeroed flow accumulator covering [ts, ts + FLOW_WINDOW_SEC].

    `ts` is the event timestamp in epoch seconds (compared against trade
    ts_event // 1e9 — TODO confirm events store seconds, not ns).
    """
    return {
        "ts": ts,
        "ts_end": ts + FLOW_WINDOW_SEC,
        "call_buy": 0,
        "call_sell": 0,
        "put_buy": 0,
        "put_sell": 0,
        "total_trades": 0,
        "total_size": 0,
    }


def _build_put_call_map(md):
    """Map instrument_id -> True for puts, False for calls.

    Parses each raw symbol in the store's symbology mappings; assumes an
    OSI-style code where the 7th character of the last whitespace token
    is the 'P'/'C' flag — TODO confirm against the dataset's symbology.
    Symbols that don't fit the pattern default to call (False).
    """
    inst_is_put = {}
    if md.mappings:
        for raw_sym, mapping_list in md.mappings.items():
            parts = str(raw_sym).strip().replace("  ", " ").split()
            is_put = False
            if len(parts) >= 2:
                code = parts[-1]
                if len(code) >= 7:
                    is_put = code[6] == "P"
            for m in mapping_list:
                # mapping "symbol" carries the numeric instrument_id
                inst_is_put[int(m["symbol"])] = is_put
    return inst_is_put


def _trade_side(rec):
    """Classify a trade record's aggressor side.

    Returns True (buyer aggressor), False (seller aggressor), or None
    when the side is missing/unrecognized. Databento convention:
    'A' = seller aggressor, 'B' = buyer aggressor.
    """
    side = getattr(rec, 'side', None)
    if side is None:
        return None
    if isinstance(side, bytes):
        side = side.decode('ascii', errors='ignore')
    side = str(side).strip().upper()
    if side == 'B':
        return True
    if side == 'A':
        return False
    return None


def _accumulate_flow(store, inst_is_put, event_windows, min_ts, max_ts_end):
    """Stream trades from *store*, adding trade sizes into every matching window.

    A trade may fall inside several overlapping windows and is counted in
    each. Returns (n_records, n_matched).
    """
    n_records = 0
    n_matched = 0
    for rec in store:
        n_records += 1
        ts_ns = rec.ts_event
        if ts_ns >= UNDEF:  # undefined-timestamp sentinel
            continue
        ts_sec = ts_ns // 1_000_000_000

        # Cheap reject: outside every event window.
        if ts_sec < min_ts or ts_sec > max_ts_end:
            continue

        is_put = inst_is_put.get(rec.instrument_id)
        if is_put is None:  # instrument not in symbology mapping
            continue

        is_buy = _trade_side(rec)
        if is_buy is None:
            continue

        size = getattr(rec, 'size', 1)
        for ew in event_windows:
            if ew["ts"] <= ts_sec <= ew["ts_end"]:
                n_matched += 1
                ew["total_trades"] += 1
                ew["total_size"] += size
                key = ("put_" if is_put else "call_") + ("buy" if is_buy else "sell")
                ew[key] += size

        if n_records % 5_000_000 == 0:
            print(f"    {DATE_STR}: {n_records/1e6:.0f}M records, {n_matched} matched", file=sys.stderr)
    return n_records, n_matched


def _enrich_event(e, ew):
    """Combine one spike event with its accumulated flow into an output row."""
    net_call = ew["call_buy"] - ew["call_sell"]
    net_put = ew["put_buy"] - ew["put_sell"]
    net_bullish = net_call - net_put  # positive = call-skewed / bullish
    return {
        "date": DATE_STR,
        "ts": e["ts"],
        "spread_side": e.get("spread_side", ""),
        "call_buy": ew["call_buy"],
        "call_sell": ew["call_sell"],
        "put_buy": ew["put_buy"],
        "put_sell": ew["put_sell"],
        "net_call_flow": net_call,
        "net_put_flow": net_put,
        "net_bullish_flow": net_bullish,
        "total_size": ew["total_size"],
        "total_trades": ew["total_trades"],
        # sign(net_bullish): 1 bullish, -1 bearish, 0 flat
        "flow_direction": (net_bullish > 0) - (net_bullish < 0),
    }


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
