#!/usr/bin/env python3
"""Download OI history for remaining symbols using monthly expirations only.
TSLA already done. Uses ~13 API calls per symbol instead of ~55.
"""

import json
import sys
import time
import urllib.request
import urllib.error
from datetime import datetime, timedelta, date
from collections import defaultdict
from pathlib import Path

# Local REST API endpoint (port 25503 / HTTP 472 suggest a ThetaData terminal -- TODO confirm).
BASE = "http://localhost:25503/v3"
# Destination for the per-symbol <SYM>_oi_history.json output files.
OUTPUT_DIR = Path("/Users/lutherbot/.openclaw/workspace/data/stock_gamma_history")
# Symbols still to download; TSLA was handled by an earlier run (see module docstring).
SYMBOLS = ["PLTR", "ARM", "MSTR", "COIN", "SOFI", "CRWD", "AMD", "SMCI", "AAPL"]

def api_get(url, timeout=180):
    """Fetch *url* and return its decoded JSON payload, or None on failure.

    HTTP 472 (presumably the API's "no data" status -- confirm against the
    server docs) is treated as an expected empty result and returns None
    silently; any other error prints a short diagnostic and returns None.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            payload = resp.read()
        return json.loads(payload)
    except urllib.error.HTTPError as err:
        if err.code != 472:
            print(f"  HTTP {err.code}", flush=True)
        return None
    except Exception as err:  # deliberate best-effort: never crash the download loop
        print(f"  Error: {err}", flush=True)
        return None

def get_trading_dates():
    """Return every weekday (Mon-Fri) from 2025-03-03 through 2026-03-17 inclusive."""
    first = date(2025, 3, 3)
    last = date(2026, 3, 17)
    n_days = (last - first).days + 1
    return [day for offset in range(n_days)
            if (day := first + timedelta(days=offset)).weekday() < 5]

def get_monthly_expirations(symbol):
    """Return *symbol*'s expirations thinned to roughly one per month.

    Prefers true monthlies (Fridays whose day-of-month is 15-21, i.e. third
    Fridays). If fewer than 10 of those fall in the window, falls back to
    greedily picking expirations spaced at least 14 days apart.
    """
    data = api_get(f"{BASE}/option/list/expirations?symbol={symbol}&format=json")
    if not data:
        return []

    window_lo, window_hi = date(2025, 3, 1), date(2026, 4, 30)
    candidates = sorted(
        exp for exp in (
            datetime.strptime(entry["expiration"], "%Y-%m-%d").date()
            for entry in data.get("response", [])
        )
        if window_lo <= exp <= window_hi
    )

    # Third Fridays: the only Friday of a month that can land on day 15-21.
    picks = [e for e in candidates if e.weekday() == 4 and 15 <= e.day <= 21]

    if len(picks) < 10:
        # Too few monthlies (e.g. a sparse chain): use bi-weekly spacing instead.
        picks, previous = [], None
        for exp in candidates:
            if previous is None or (exp - previous).days >= 14:
                picks.append(exp)
                previous = exp

    print(f"  {len(candidates)} total exps -> {len(picks)} monthly", flush=True)
    return picks

def find_nearest_exp(trade_date, expirations):
    """Return the expiration nearest on/after *trade_date*, at most 45 days out.

    Returns None when no expiration falls inside the [0, 45]-day window.
    """
    eligible = [exp for exp in expirations if 0 <= (exp - trade_date).days <= 45]
    return min(eligible, default=None, key=lambda exp: (exp - trade_date).days)

def parse_oi_response(data):
    """Aggregate a bulk open-interest response into per-date totals.

    Returns a dict keyed by "YYYY-MM-DD" with call/put/total OI, a put/call
    ratio (999.0 sentinel when there are no calls), and the count of distinct
    strikes seen. Dates whose combined OI is zero are dropped.
    """
    if not data or "response" not in data:
        return {}

    per_day = defaultdict(lambda: {"call_oi": 0, "put_oi": 0, "strikes": set()})

    for record in data.get("response", []):
        meta = record.get("contract", {})
        side = meta.get("right", "").upper()
        strike = meta.get("strike")

        for row in record.get("data", []):
            day = row.get("timestamp", "")[:10]  # keep just the date portion
            oi = row.get("open_interest", 0)
            bucket = per_day[day]
            if side == "CALL":
                bucket["call_oi"] += oi
            elif side == "PUT":
                bucket["put_oi"] += oi
            if strike is not None:
                bucket["strikes"].add(strike)

    summary = {}
    for day, bucket in per_day.items():
        calls, puts = bucket["call_oi"], bucket["put_oi"]
        total = calls + puts
        if total > 0:
            summary[day] = {
                "call_oi": calls,
                "put_oi": puts,
                "total_oi": total,
                "put_call_ratio": round(puts / calls, 4) if calls > 0 else 999.0,
                "n_strikes": len(bucket["strikes"]),
            }
    return summary

def process_symbol(symbol, trading_dates):
    """Download the OI history for *symbol* across *trading_dates*.

    Buckets trading dates under their nearest monthly expiration so each
    expiration needs only one bulk API call. Returns a date-sorted list of
    per-day dicts, or None when the symbol has no usable expirations.
    """
    print(f"\n{'='*50}", flush=True)
    print(f"  {symbol}: getting expirations...", flush=True)

    expirations = get_monthly_expirations(symbol)
    if not expirations:
        print(f"  {symbol}: NO EXPIRATIONS, skipping", flush=True)
        return None

    # Bucket every trading date under the nearest expiration on/after it.
    buckets = defaultdict(list)
    unmapped = 0
    for trade_day in trading_dates:
        nearest = find_nearest_exp(trade_day, expirations)
        if nearest is None:
            unmapped += 1
        else:
            buckets[nearest].append(trade_day)

    n_queries = len(buckets)
    print(f"  {symbol}: {n_queries} API calls needed, {unmapped} dates unmapped", flush=True)

    collected = {}

    for idx, (exp, days) in enumerate(sorted(buckets.items()), start=1):
        exp_s = exp.strftime("%Y-%m-%d")
        url = (f"{BASE}/option/history/open_interest?"
               f"symbol={symbol}&expiration={exp_s}"
               f"&start_date={min(days).strftime('%Y-%m-%d')}"
               f"&end_date={max(days).strftime('%Y-%m-%d')}&format=json")

        began = time.time()
        payload = api_get(url)
        elapsed = time.time() - began
        time.sleep(0.5)  # be polite to the local API between bulk queries

        if payload is None:
            print(f"  {symbol}: exp {exp_s} -> no data ({elapsed:.0f}s)", flush=True)
            continue

        day_stats = parse_oi_response(payload)
        matched = 0
        for trade_day in days:
            key = trade_day.strftime("%Y-%m-%d")
            if key in day_stats:
                collected[key] = {"date": key, "expiration": exp_s, **day_stats[key]}
                matched += 1

        print(f"  {symbol}: [{idx}/{n_queries}] exp={exp_s} -> {matched} days ({elapsed:.0f}s) | total={len(collected)}", flush=True)

    ordered = [collected[k] for k in sorted(collected)]
    print(f"  {symbol} DONE: {len(ordered)} data points", flush=True)
    return ordered

def main():
    """Download OI history for every symbol that does not already have it.

    Skips symbols whose existing output file holds >= 200 daily records,
    enforces a soft wall-clock budget so the job finishes inside ~30 minutes,
    and writes one JSON file per symbol under OUTPUT_DIR.
    """
    start = time.time()
    print("=" * 50, flush=True)
    print("OI History Download - Monthly Expirations", flush=True)
    print(f"Symbols: {', '.join(SYMBOLS)}", flush=True)
    print("=" * 50, flush=True)

    trading_dates = get_trading_dates()
    print(f"Trading dates: {len(trading_dates)} ({trading_dates[0]} to {trading_dates[-1]})", flush=True)

    # Check which symbols already have good data; a corrupt or unreadable
    # file simply means we re-download that symbol.
    done = {}
    for sym in SYMBOLS:
        path = OUTPUT_DIR / f"{sym}_oi_history.json"
        if not path.exists():
            continue
        try:
            with open(path) as f:
                d = json.load(f)
            n = len(d.get("daily_data", []))
        except Exception:
            # Was a bare `except:` which also swallowed KeyboardInterrupt;
            # Exception keeps the best-effort behavior without that hazard.
            continue
        if n >= 200:
            print(f"  {sym}: already has {n} data points, SKIPPING", flush=True)
            done[sym] = n

    results = dict(done)
    # Build the work list instead of mutating the module-level SYMBOLS.
    pending = [s for s in SYMBOLS if s not in done]

    for symbol in pending:
        elapsed = time.time() - start
        if elapsed > 1700:  # soft time budget (~28 min) so the run ends cleanly
            print(f"\nTime limit ({elapsed/60:.1f} min)", flush=True)
            break

        try:
            daily_data = process_symbol(symbol, trading_dates)
            if daily_data:
                output = {
                    "symbol": symbol,
                    "download_date": "2026-03-18",
                    "daily_data": daily_data
                }
                OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # first run may lack the dir
                outpath = OUTPUT_DIR / f"{symbol}_oi_history.json"
                with open(outpath, "w") as f:
                    json.dump(output, f, indent=2)
                print(f"  Saved {outpath.name}", flush=True)
                results[symbol] = len(daily_data)
        except Exception as e:
            # One symbol failing must not abort the rest of the batch.
            print(f"  FATAL {symbol}: {e}", flush=True)
            import traceback; traceback.print_exc()

    elapsed = time.time() - start
    print(f"\n{'='*50}", flush=True)
    print(f"COMPLETE in {elapsed/60:.1f} minutes", flush=True)
    # TSLA was downloaded by an earlier run; list it first when present.
    for sym in ["TSLA"] + [s for s in results if s != "TSLA"]:
        if sym in results:
            print(f"  {sym}: {results[sym]} data points", flush=True)

# Script entry point: run the downloader only when executed directly.
if __name__ == "__main__":
    main()
