#!/usr/bin/env python3 -u
"""Download historical options open interest (v3): monthly expirations only, for speed.

Uses only the monthly (3rd-Friday) expirations to minimize API calls: each
symbol needs ~13 API calls instead of ~55-70 when weeklies are included.
"""

import json
import sys
import time
import urllib.request
import urllib.error
from datetime import datetime, timedelta, date
from collections import defaultdict
from pathlib import Path

# Base URL of the local Theta Data v3 REST API; every request URL in this
# script is built from it.
BASE = "http://localhost:25503/v3"
# Directory that receives one <SYMBOL>_oi_history.json file per symbol.
OUTPUT_DIR = Path("/Users/lutherbot/.openclaw/workspace/data/stock_gamma_history")

# Underlyings to download, processed in this order.
# Skip TSLA (already done) -- its history was downloaded by an earlier run.
SYMBOLS = ["PLTR", "ARM", "MSTR", "COIN", "SOFI", "CRWD", "AMD", "SMCI", "AAPL"]

def api_get(url, timeout=120):
    """Fetch *url* and decode its body as JSON.

    Returns the decoded object, or None on any failure. An HTTP 472 status is
    returned as None without logging (presumably the server's "no data"
    code -- confirm against the API docs). Any other HTTP error or exception
    is printed, then None is returned.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            body = response.read()
        return json.loads(body)
    except urllib.error.HTTPError as err:
        if err.code != 472:
            print(f"  HTTP {err.code}: {url[:120]}", flush=True)
        return None
    except Exception as err:
        print(f"  Error: {err}", flush=True)
        return None

def get_trading_dates(start=date(2025, 3, 3), end=date(2026, 3, 17)):
    """Return every weekday from *start* through *end*, inclusive, in order.

    Only Saturdays and Sundays are removed; exchange holidays are NOT
    excluded, so some returned dates may have no market data.

    Args:
        start: First calendar date to consider (default 2025-03-03).
        end: Last calendar date to consider (default 2026-03-17).

    Returns:
        List of ``date`` objects; empty when *end* is before *start*.
    """
    span = (end - start).days + 1  # range() of a non-positive span is empty
    return [
        day
        for day in (start + timedelta(days=offset) for offset in range(span))
        if day.weekday() < 5  # Mon=0 .. Fri=4
    ]

def get_expirations(symbol):
    """Return the sorted expiration dates listed for *symbol* in the study window.

    Queries the option/list/expirations endpoint and keeps only expirations
    from 2025-03-01 through 2026-04-30 inclusive. Returns [] when the request
    fails or returns an empty payload. Each entry's ``"expiration"`` is
    parsed with the ``%Y-%m-%d`` format.
    """
    payload = api_get(f"{BASE}/option/list/expirations?symbol={symbol}&format=json")
    if not payload:
        return []
    window_start, window_end = date(2025, 3, 1), date(2026, 4, 30)
    parsed = (
        datetime.strptime(entry["expiration"], "%Y-%m-%d").date()
        for entry in payload.get("response", [])
    )
    return sorted(day for day in parsed if window_start <= day <= window_end)

def is_monthly_expiration(d):
    """Return True when *d* is the third Friday of its month.

    The third Friday of any month always falls on day 15..21, so a Friday in
    that range is the standard monthly expiration. Holiday-shifted monthlies
    (which expire on the preceding Thursday) are not recognized here.
    """
    return d.weekday() == 4 and d.day in range(15, 22)

def find_nearest_expiration(trade_date, expirations):
    """Return the closest expiration on or after *trade_date*, at most 45 days out.

    Only expirations whose distance from *trade_date* is 0..45 days inclusive
    are considered. Returns None when there is no such expiration.
    """
    in_window = [exp for exp in expirations if 0 <= (exp - trade_date).days <= 45]
    return min(in_window, key=lambda exp: exp - trade_date, default=None)

def process_oi_response(data):
    """Aggregate an open-interest history response into per-day totals.

    Args:
        data: Decoded JSON from the option/history/open_interest endpoint.
            Expected (not validated) shape, as read below:
            ``{"response": [{"contract": {"right": ..., "strike": ...},
            "data": [{"open_interest": ..., "timestamp": "YYYY-MM-DD..."}]}]}``.

    Returns:
        Dict mapping the first 10 characters of each row's timestamp to
        ``{"call_oi", "put_oi", "total_oi", "put_call_ratio", "n_strikes"}``.
        ``put_call_ratio`` is put/call rounded to 4 places, or 999.0 when call
        OI is zero. ``n_strikes`` counts distinct strikes seen that day. Days
        with zero total OI are omitted. Returns {} for a falsy payload or one
        without a "response" key.
    """
    if not data or "response" not in data:
        return {}

    # Accept both spelled-out and single-letter option rights.
    side_for_right = {"CALL": "call_oi", "C": "call_oi", "PUT": "put_oi", "P": "put_oi"}

    date_data = defaultdict(lambda: {"call_oi": 0, "put_oi": 0, "strikes": set()})

    for item in data.get("response") or []:
        contract = item.get("contract") or {}
        # A null right must not crash .upper(); it simply matches no side.
        side = side_for_right.get(str(contract.get("right") or "").upper())
        strike = contract.get("strike")

        for row in item.get("data") or []:
            date_str = (row.get("timestamp") or "")[:10]
            if not date_str:
                # A row without a timestamp cannot be assigned to any day.
                continue
            bucket = date_data[date_str]
            if side is not None:
                # open_interest may be missing or null; count it as zero.
                bucket[side] += row.get("open_interest") or 0
            if strike is not None:
                bucket["strikes"].add(strike)

    result = {}
    for date_str, info in date_data.items():
        total = info["call_oi"] + info["put_oi"]
        if total > 0:
            result[date_str] = {
                "call_oi": info["call_oi"],
                "put_oi": info["put_oi"],
                "total_oi": total,
                "put_call_ratio": round(info["put_oi"] / info["call_oi"], 4) if info["call_oi"] > 0 else 999.0,
                "n_strikes": len(info["strikes"])
            }
    return result

def _select_monthly_expirations(expirations):
    """Return the monthly expirations in *expirations*, preserving input order.

    A monthly is normally the third Friday (see is_monthly_expiration). When
    that Friday is a market holiday (e.g. Good Friday 2025-04-18), the
    contract expires on the Thursday before it instead. So a Thursday on
    day 14..20 is also accepted when the Friday after it is absent from
    *expirations*.
    """
    available = set(expirations)
    selected = []
    for exp in expirations:
        if is_monthly_expiration(exp):
            selected.append(exp)
        elif (exp.weekday() == 3 and 14 <= exp.day <= 20
              and exp + timedelta(days=1) not in available):
            selected.append(exp)
    return selected


def process_symbol(symbol, trading_dates, expirations):
    """Collect daily call/put open interest for *symbol* from monthly expirations.

    Each trading date is assigned to the nearest selected expiration 0-45 days
    ahead (via find_nearest_expiration). Dates with no such expiration are
    counted as unmapped and skipped. One history request is made per
    expiration, covering the span of dates assigned to it, with a 0.5 s pause
    after each request. Expirations whose request fails are skipped.

    Args:
        symbol: Underlying ticker.
        trading_dates: Dates to collect OI for.
        expirations: Expiration dates listed for *symbol* (as returned by
            get_expirations).

    Returns:
        List of per-date dicts sorted by date. Each dict holds "date",
        "expiration", and the fields produced by process_oi_response.
    """
    print(f"\n{'='*50}", flush=True)
    print(f"Processing {symbol}...", flush=True)

    # Filter to monthly expirations only (including holiday-shifted Thursdays)
    monthly_exps = _select_monthly_expirations(expirations)

    # Too few monthlies: replace them with a roughly biweekly subset of all
    # expirations, taking the first expiration >= 14 days after the last pick.
    if len(monthly_exps) < 10:
        monthly_exps = []
        last_added = None
        for e in sorted(expirations):
            if last_added is None or (e - last_added).days >= 14:
                monthly_exps.append(e)
                last_added = e

    print(f"  Using {len(monthly_exps)} expirations (from {len(expirations)} total)", flush=True)

    # Group trading dates by the expiration whose chain will be queried for them
    exp_to_dates = defaultdict(list)
    no_exp = 0
    for td in trading_dates:
        exp = find_nearest_expiration(td, monthly_exps)
        if exp is not None:
            exp_to_dates[exp].append(td)
        else:
            no_exp += 1

    print(f"  {len(exp_to_dates)} expirations to query, {no_exp} dates unmapped", flush=True)

    all_daily = {}

    for i, (exp, dates) in enumerate(sorted(exp_to_dates.items())):
        start_d = min(dates)
        end_d = max(dates)

        url = (f"{BASE}/option/history/open_interest?"
               f"symbol={symbol}&expiration={exp.strftime('%Y-%m-%d')}"
               f"&start_date={start_d.strftime('%Y-%m-%d')}"
               f"&end_date={end_d.strftime('%Y-%m-%d')}&format=json")

        data = api_get(url)
        time.sleep(0.5)  # throttle requests to the local API

        if data is None:
            print(f"  Skip exp {exp} (no data)", flush=True)
            continue

        parsed = process_oi_response(data)

        # Keep only the dates mapped to this expiration; each date maps to one.
        for td in dates:
            ds = td.strftime("%Y-%m-%d")
            if ds in parsed:
                all_daily[ds] = {
                    "date": ds,
                    "expiration": exp.strftime("%Y-%m-%d"),
                    **parsed[ds]
                }

        print(f"  {symbol}: {i+1}/{len(exp_to_dates)} exps done, {len(all_daily)} data points", flush=True)

    daily_data = [all_daily[k] for k in sorted(all_daily)]
    print(f"  {symbol} COMPLETE: {len(daily_data)} data points", flush=True)
    return daily_data

def main():
    """Download OI history for every symbol in SYMBOLS and save one JSON file each.

    For each symbol: list its expirations, aggregate daily open interest with
    process_symbol, and write ``OUTPUT_DIR/<SYMBOL>_oi_history.json``. No new
    symbol is started once 1600 s have elapsed. A failure on one symbol is
    printed with its traceback and does not stop the remaining symbols.
    """
    time_limit_s = 1600  # soft budget: checked only before starting each symbol
    start_time = time.time()
    print("=" * 50, flush=True)
    print("Theta Data OI Download v3 - Monthly Expirations", flush=True)
    print("=" * 50, flush=True)

    # Create the output directory up front, so a missing directory fails fast
    # instead of failing at open() after each symbol has been downloaded.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    download_date = date.today().isoformat()

    trading_dates = get_trading_dates()
    print(f"Trading dates: {len(trading_dates)} ({trading_dates[0]} to {trading_dates[-1]})", flush=True)

    results = {}

    for symbol in SYMBOLS:
        elapsed = time.time() - start_time
        if elapsed > time_limit_s:
            print(f"\nTime limit ({elapsed/60:.1f} min), stopping", flush=True)
            break

        try:
            expirations = get_expirations(symbol)
            if not expirations:
                print(f"  No expirations for {symbol}, skipping", flush=True)
                continue

            daily_data = process_symbol(symbol, trading_dates, expirations)
            if not daily_data:
                print(f"  No data points for {symbol}, nothing saved", flush=True)
                continue

            output = {
                "symbol": symbol,
                "download_date": download_date,
                "daily_data": daily_data,
            }
            outpath = OUTPUT_DIR / f"{symbol}_oi_history.json"
            # Write to a temp file and rename it into place, so an interrupted
            # run never leaves a truncated JSON where a good one may have been.
            tmp_path = outpath.with_suffix(".json.tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(output, f, indent=2)
            tmp_path.replace(outpath)
            print(f"  Saved {outpath.name}", flush=True)
            results[symbol] = len(daily_data)

        except Exception as e:
            # Top-level per-symbol boundary: report and move on to the next symbol.
            print(f"  FATAL {symbol}: {e}", flush=True)
            import traceback
            traceback.print_exc()

    elapsed = time.time() - start_time
    print(f"\n{'='*50}", flush=True)
    print(f"DONE in {elapsed/60:.1f} minutes", flush=True)
    for sym, count in results.items():
        print(f"  {sym}: {count} data points", flush=True)

# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
