#!/usr/bin/env python3 -u
"""Download historical options OI data from Theta Data - optimized version.

Strategy: For each symbol, group trading dates by their nearest expiration,
then make ONE bulk API call per expiration covering all relevant dates.
This minimizes the number of API calls (one per unique expiration per symbol).
"""

import json
import sys
import time
import urllib.request
import urllib.error
from datetime import datetime, timedelta, date
from collections import defaultdict
from pathlib import Path

# Local Theta Data terminal REST endpoint (v3 API).
BASE = "http://localhost:25503/v3"
# Destination directory for the per-symbol JSON output files.
OUTPUT_DIR = Path("/Users/lutherbot/.openclaw/workspace/data/stock_gamma_history")
# Tickers to download, processed in order until the time budget runs out.
SYMBOLS = ["TSLA", "PLTR", "ARM", "MSTR", "COIN", "SOFI", "CRWD", "AMD", "SMCI", "AAPL"]

def api_get(url, timeout=120):
    """Fetch *url* and return its JSON payload, or None on any failure.

    HTTP 472 (Theta Data's "no data for request" status) is treated as a
    silent miss; every other error is printed before returning None.
    """
    try:
        request = urllib.request.Request(url)
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return json.loads(response.read())
    except urllib.error.HTTPError as err:
        if err.code != 472:
            print(f"  HTTP {err.code}: {url[:100]}", flush=True)
        return None
    except Exception as err:
        print(f"  Error: {err} for {url[:100]}", flush=True)
        return None

def get_trading_dates():
    """Return every weekday (Mon-Fri) from 2025-03-03 through 2026-03-17, ascending."""
    first = date(2025, 3, 3)
    last = date(2026, 3, 17)
    total_days = (last - first).days + 1  # inclusive span
    return [
        day
        for day in (first + timedelta(days=offset) for offset in range(total_days))
        if day.weekday() < 5  # 0-4 = Monday..Friday
    ]

def get_expirations(symbol):
    """Return sorted expiration dates for *symbol*, limited to 2025-03-01..2026-04-30.

    Returns an empty list when the listing call fails or yields nothing.
    """
    payload = api_get(f"{BASE}/option/list/expirations?symbol={symbol}&format=json")
    if not payload:
        return []
    lo = date(2025, 3, 1)
    hi = date(2026, 4, 30)
    in_range = []
    for entry in payload.get("response", []):
        parsed = datetime.strptime(entry["expiration"], "%Y-%m-%d").date()
        if lo <= parsed <= hi:
            in_range.append(parsed)
    in_range.sort()
    return in_range

def find_nearest_expiration(trade_date, expirations):
    """Return the closest expiration on/after *trade_date* within 10 days, else None."""
    # The earliest qualifying expiration is by definition the nearest one.
    candidates = (exp for exp in expirations if 0 <= (exp - trade_date).days <= 10)
    return min(candidates, default=None)

def process_oi_response(data):
    """Collapse a raw open-interest response into per-day aggregates.

    Returns a dict mapping 'YYYY-MM-DD' -> {call_oi, put_oi, total_oi,
    put_call_ratio, n_strikes}. Days whose combined OI is zero are dropped;
    a day with calls at zero uses 999.0 as the put/call-ratio sentinel.
    """
    if not data or "response" not in data:
        return {}

    per_day = defaultdict(lambda: {"call_oi": 0, "put_oi": 0, "strikes": set()})

    for record in data.get("response", []):
        meta = record.get("contract", {})
        side = meta.get("right", "").upper()
        strike = meta.get("strike")

        for point in record.get("data", []):
            day = point.get("timestamp", "")[:10]  # 'YYYY-MM-DD' prefix
            bucket = per_day[day]
            oi = point.get("open_interest", 0)
            if side == "CALL":
                bucket["call_oi"] += oi
            elif side == "PUT":
                bucket["put_oi"] += oi
            if strike is not None:
                bucket["strikes"].add(strike)

    summary = {}
    for day, bucket in per_day.items():
        calls = bucket["call_oi"]
        puts = bucket["put_oi"]
        combined = calls + puts
        if combined == 0:
            continue  # no open interest observed for this day
        summary[day] = {
            "call_oi": calls,
            "put_oi": puts,
            "total_oi": combined,
            "put_call_ratio": round(puts / calls, 4) if calls > 0 else 999.0,
            "n_strikes": len(bucket["strikes"]),
        }
    return summary

def process_symbol(symbol, trading_dates, expirations):
    """Download OI for one symbol using one bulk request per nearest expiration.

    Trading days are bucketed under their nearest upcoming expiration, then a
    single history call per expiration covers that bucket's full date span.
    Returns the per-day entries sorted by date.
    """
    print(f"\n{'='*50}", flush=True)
    print(f"Processing {symbol}...", flush=True)

    # Bucket each trading day under its nearest expiration (within 10 days).
    dates_by_exp = defaultdict(list)
    unmatched = 0
    for day in trading_dates:
        nearest = find_nearest_expiration(day, expirations)
        if nearest:
            dates_by_exp[nearest].append(day)
        else:
            unmatched += 1

    print(f"  {len(dates_by_exp)} unique expirations to query, {unmatched} dates with no matching exp", flush=True)

    collected = {}  # 'YYYY-MM-DD' -> daily entry dict
    total_exps = len(dates_by_exp)

    for idx, (exp, days) in enumerate(sorted(dates_by_exp.items())):
        exp_str = exp.strftime("%Y-%m-%d")
        url = (f"{BASE}/option/history/open_interest?"
               f"symbol={symbol}&expiration={exp_str}"
               f"&start_date={min(days).strftime('%Y-%m-%d')}"
               f"&end_date={max(days).strftime('%Y-%m-%d')}&format=json")

        raw = api_get(url)
        time.sleep(0.5)  # throttle requests to the local terminal

        if raw is None:
            continue  # failed call also skips the progress line (as before)

        per_day = process_oi_response(raw)
        for day in days:
            key = day.strftime("%Y-%m-%d")
            if key in per_day:
                entry = {"date": key, "expiration": exp_str}
                entry.update(per_day[key])
                collected[key] = entry

        if (idx + 1) % 10 == 0 or idx == total_exps - 1:
            print(f"  {symbol}: {idx+1}/{total_exps} expirations done, {len(collected)} data points so far", flush=True)

    # Emit entries in chronological order.
    ordered = [collected[key] for key in sorted(collected)]
    print(f"  {symbol} COMPLETE: {len(ordered)} data points", flush=True)
    return ordered

def main():
    """Download OI history for each symbol in SYMBOLS and write one JSON file per symbol.

    Stops early when the soft wall-clock budget (~27 min) is exceeded. Per-symbol
    failures are logged and skipped so one bad ticker cannot kill the run.
    """
    start_time = time.time()
    print("=" * 50, flush=True)
    print("Theta Data OI Download v2 - Optimized", flush=True)
    print("=" * 50, flush=True)

    trading_dates = get_trading_dates()
    print(f"Trading dates: {len(trading_dates)} ({trading_dates[0]} to {trading_dates[-1]})", flush=True)

    # Robustness: make sure the output directory exists before any write.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    results_summary = {}

    for symbol in SYMBOLS:
        elapsed = time.time() - start_time
        if elapsed > 1600:
            # Soft time budget so the run finishes cleanly instead of being killed.
            print(f"\nTime limit approaching ({elapsed/60:.1f} min), stopping", flush=True)
            break

        try:
            expirations = get_expirations(symbol)
            if not expirations:
                print(f"  No expirations for {symbol}, skipping", flush=True)
                continue
            print(f"  {symbol}: {len(expirations)} expirations in range", flush=True)

            daily_data = process_symbol(symbol, trading_dates, expirations)

            if daily_data:
                output = {
                    "symbol": symbol,
                    # Record the actual run date (was hard-coded to "2026-03-18").
                    "download_date": date.today().isoformat(),
                    "daily_data": daily_data,
                }
                outpath = OUTPUT_DIR / f"{symbol}_oi_history.json"
                with open(outpath, "w") as f:
                    json.dump(output, f, indent=2)
                print(f"  Saved {outpath.name} ({len(daily_data)} days)", flush=True)
                results_summary[symbol] = len(daily_data)
            else:
                print(f"  No data for {symbol}", flush=True)

        except Exception as e:
            # Log the failure but keep going with the remaining symbols.
            print(f"  FATAL error for {symbol}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            continue

    elapsed = time.time() - start_time
    print(f"\n{'='*50}", flush=True)
    print(f"ALL DONE in {elapsed/60:.1f} minutes", flush=True)
    for sym, count in results_summary.items():
        print(f"  {sym}: {count} data points", flush=True)
# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
