#!/usr/bin/env python3
"""Download historical options OI data from Theta Data for 10 pilot stocks."""

import json
import os
import time
import urllib.request
import urllib.error
from datetime import datetime, timedelta, date
from collections import defaultdict

# Base URL that api_get() prefixes to every endpoint path.
# NOTE(review): presumably a locally running Theta Data terminal -- confirm port/version.
BASE_URL = "http://localhost:25503/v3"
# Directory that receives one "<SYMBOL>_oi_history.json" file per symbol.
OUTPUT_DIR = "/Users/lutherbot/.openclaw/workspace/data/stock_gamma_history"
# Underlying symbols to download, processed in this order by main().
SYMBOLS = ["TSLA", "PLTR", "ARM", "MSTR", "COIN", "SOFI", "CRWD", "AMD", "SMCI", "AAPL"]
# Inclusive bounds of the historical window (both ends are compared with <=).
START_DATE = date(2025, 3, 1)
END_DATE = date(2026, 3, 17)
# Pause applied before each open-interest history request.
REQUEST_DELAY = 0.6  # seconds between requests

def api_get(endpoint, params):
    """Make a GET request to the Theta Data API and return the decoded JSON.

    Args:
        endpoint: Path relative to BASE_URL, e.g. "option/list/expirations".
        params: Mapping of query-parameter names to values. A trailing
            ``format=json`` parameter is always appended.

    Returns:
        The parsed JSON payload, or None if the request or decoding failed
        (the error is printed rather than raised, so callers can carry on
        with the remaining queries).
    """
    from urllib.parse import urlencode

    # urlencode percent-escapes keys and values, so symbols or dates that
    # contain reserved characters (&, =, space, #, ...) cannot corrupt the URL.
    query = urlencode(list(params.items()) + [("format", "json")])
    url = f"{BASE_URL}/{endpoint}?{query}"
    try:
        req = urllib.request.Request(url)
        req.add_header('Accept', 'application/json')
        with urllib.request.urlopen(req, timeout=120) as resp:
            return json.loads(resp.read().decode())
    except Exception as e:
        # Deliberately best-effort: a single failed request (network error,
        # HTTP error, bad JSON) must not abort the whole download run.
        print(f"  ERROR: {url} -> {e}")
        return None

def get_expirations(symbol):
    """Get all option expirations for a symbol.

    Args:
        symbol: Underlying ticker passed to the "option/list/expirations"
            endpoint.

    Returns:
        A sorted list of ``datetime.date`` objects. Returns an empty list when
        the request failed or the payload has no "response" key. Entries whose
        "expiration" value is not a "YYYY-MM-DD" string are skipped.
    """
    data = api_get("option/list/expirations", {"symbol": symbol})
    if not data or "response" not in data:
        return []
    exps = []
    for item in data["response"]:
        exp_str = item.get("expiration", "")
        try:
            exps.append(datetime.strptime(exp_str, "%Y-%m-%d").date())
        except (ValueError, TypeError):
            # ValueError: string does not match the format (including "").
            # TypeError: value is not a string at all (e.g. None).
            continue
    return sorted(exps)

def get_trading_dates(start, end):
    """Return every Monday-Friday date from start through end, inclusive.

    Only weekends are excluded; market holidays are not filtered out.
    An empty list is returned when end precedes start.
    """
    span_days = (end - start).days
    every_day = (start + timedelta(days=offset) for offset in range(span_days + 1))
    # weekday(): Monday is 0 ... Sunday is 6, so < 5 keeps Mon-Fri.
    return [day for day in every_day if day.weekday() < 5]

def get_nearest_expirations(trading_date, all_expirations, n=3):
    """Return the first n expirations that fall on or after trading_date.

    The input order is preserved, so the result is only "nearest first" when
    all_expirations is already sorted ascending.
    """
    upcoming = []
    for expiration in all_expirations:
        if expiration < trading_date:
            continue
        upcoming.append(expiration)
    return upcoming[:n]

def get_month_ranges(start=None, end=None):
    """Generate (month_start, month_end) tuples covering a date range.

    Each tuple is one calendar month clipped to the requested window, so the
    first tuple begins at ``start`` and the last one finishes at ``end``.

    Args:
        start: First date of the window (inclusive). Defaults to START_DATE.
        end: Last date of the window (inclusive). Defaults to END_DATE.

    Returns:
        A list of (month_start, month_end) date tuples in chronological order;
        empty when ``start`` is after ``end``.
    """
    range_start = START_DATE if start is None else start
    range_end = END_DATE if end is None else end

    ranges = []
    current = range_start.replace(day=1)
    while current <= range_end:
        # First day of the following month (December rolls over to January).
        if current.month == 12:
            next_month = current.replace(year=current.year + 1, month=1)
        else:
            next_month = current.replace(month=current.month + 1)
        month_start = max(current, range_start)
        month_end = min(next_month - timedelta(days=1), range_end)
        # Guard against an inverted tuple when start > end within one month.
        if month_start <= month_end:
            ranges.append((month_start, month_end))
        current = next_month
    return ranges

def _plan_expiration_windows(trading_dates, expirations, n=3):
    """Map each trading date to its nearest expirations, and each used expiration to its (first, last) trading date."""
    date_to_exps = {}  # trading_date -> list of expirations
    exp_ranges = {}  # expiration -> (first trading date, last trading date)
    for td in trading_dates:
        nearest = get_nearest_expirations(td, expirations, n=n)
        date_to_exps[td] = nearest
        for exp in nearest:
            if exp in exp_ranges:
                first, last = exp_ranges[exp]
                exp_ranges[exp] = (min(first, td), max(last, td))
            else:
                exp_ranges[exp] = (td, td)
    return date_to_exps, exp_ranges


def _fetch_open_interest(symbol, exp_ranges):
    """Query OI history per expiration in month-sized chunks; return (raw_data, total_queries)."""
    raw_data = {}  # exp_str -> list of response items
    month_ranges = get_month_ranges()
    total_queries = 0

    for exp in sorted(exp_ranges):
        exp_start, exp_end = exp_ranges[exp]
        exp_str = exp.strftime("%Y-%m-%d")

        for m_start, m_end in month_ranges:
            # Clip the month to the span in which this expiration is needed.
            q_start = max(m_start, exp_start)
            q_end = min(m_end, exp_end)
            if q_start > q_end:
                continue

            start_str = q_start.strftime("%Y-%m-%d")
            end_str = q_end.strftime("%Y-%m-%d")

            time.sleep(REQUEST_DELAY)
            result = api_get("option/history/open_interest", {
                "symbol": symbol,
                "expiration": exp_str,
                "start_date": start_str,
                "end_date": end_str
            })
            total_queries += 1

            if result and "response" in result:
                items = result["response"]
                raw_data.setdefault(exp_str, []).extend(items)
                status = f"got {len(items)} contracts"
            else:
                status = "NO DATA"
            # Progress line every tenth query only, to keep the log readable.
            if total_queries % 10 == 0:
                print(f"  ... {total_queries} queries done, exp={exp_str}, {start_str} to {end_str}, {status}")

    return raw_data, total_queries


def _index_contracts_by_date(raw_data):
    """Turn raw response items into exp_str -> date_str -> list of contract records with positive OI."""
    index = {}
    for exp_str, items in raw_data.items():
        by_date = defaultdict(list)
        for item in items:
            contract = item.get("contract", {})
            if not isinstance(contract, dict):
                # Malformed contract descriptor: nothing usable to record.
                continue
            for dp in item.get("data", []):
                ts = dp.get("timestamp", "")
                oi = dp.get("open_interest", 0)
                # Skip data points with a non-string timestamp or a
                # non-numeric open interest instead of failing the symbol.
                if not isinstance(ts, str) or not isinstance(oi, (int, float)):
                    continue
                if oi > 0:
                    # The first 10 characters of the timestamp are the
                    # calendar date, e.g. "2025-03-03".
                    by_date[ts[:10]].append({
                        "strike": contract.get("strike", 0),
                        "right": contract.get("right", ""),
                        "expiration": contract.get("expiration", ""),
                        "open_interest": oi
                    })
        index[exp_str] = by_date
    return index


def download_symbol(symbol):
    """Download all OI data for a single symbol and save it as JSON.

    For every weekday between START_DATE and END_DATE the three nearest
    expirations (on or after that day) are selected; open-interest history is
    fetched per expiration in month-sized chunks and then regrouped by trading
    date. The result is written to "<OUTPUT_DIR>/<symbol>_oi_history.json".

    Args:
        symbol: Underlying ticker to download.

    Returns:
        The number of trading dates for which at least one contract with
        positive open interest was found; 0 when the symbol has no
        expirations (nothing is written in that case).
    """
    print(f"\n{'='*60}")
    print(f"DOWNLOADING: {symbol}")
    print(f"{'='*60}")

    # Get all expirations
    all_exps = get_expirations(symbol)
    if not all_exps:
        print(f"  No expirations found for {symbol}, skipping")
        return 0

    # Filter to relevant expirations (>= START_DATE)
    relevant_exps = [e for e in all_exps if e >= START_DATE]
    print(f"  Found {len(all_exps)} total expirations, {len(relevant_exps)} relevant")

    # Get all trading dates
    trading_dates = get_trading_dates(START_DATE, END_DATE)
    print(f"  Trading dates to cover: {len(trading_dates)}")

    # Work out which expirations are needed and over which trading-date span.
    date_to_exps, exp_ranges = _plan_expiration_windows(trading_dates, relevant_exps, n=3)
    print(f"  Unique expirations to query: {len(exp_ranges)}")

    raw_data, total_queries = _fetch_open_interest(symbol, exp_ranges)
    print(f"  Total API queries for {symbol}: {total_queries}")

    # Index once so each trading date is a dictionary lookup rather than a
    # rescan of every downloaded data point.
    contracts_index = _index_contracts_by_date(raw_data)

    # Reorganize by date
    output_data = {}
    for td in trading_dates:
        td_str = td.strftime("%Y-%m-%d")
        exp_strs = [e.strftime("%Y-%m-%d") for e in date_to_exps[td]]

        contracts = []
        for exp_str in exp_strs:
            contracts.extend(contracts_index.get(exp_str, {}).get(td_str, []))

        if contracts:
            output_data[td_str] = {
                "expirations_used": exp_strs,
                "contracts": contracts
            }

    # Save
    output = {
        "symbol": symbol,
        # Record the day the download actually ran.
        "download_date": date.today().isoformat(),
        # Only dates with at least one contract are stored in output_data.
        "dates_covered": len(output_data),
        "data": output_data
    }

    output_path = os.path.join(OUTPUT_DIR, f"{symbol}_oi_history.json")
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(output, f)

    file_size = os.path.getsize(output_path) / (1024 * 1024)
    print(f"  ✓ Saved {symbol}: {output['dates_covered']} dates with data, {file_size:.1f} MB")
    return output['dates_covered']

def main():
    """Run the download for every configured symbol and print a summary.

    A failure while processing one symbol is reported with its traceback and
    recorded as 0 days, after which the remaining symbols are still processed.
    """
    rule = "=" * 60

    print(rule)
    print("Theta Data Options OI Historical Download")
    print(f"Symbols: {', '.join(SYMBOLS)}")
    print(f"Date range: {START_DATE} to {END_DATE}")
    print(rule)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    days_by_symbol = {}
    for ticker in SYMBOLS:
        try:
            days_by_symbol[ticker] = download_symbol(ticker)
        except Exception as exc:
            print(f"  FAILED: {ticker} -> {exc}")
            import traceback
            traceback.print_exc()
            days_by_symbol[ticker] = 0

    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    for ticker, days in days_by_symbol.items():
        print(f"  {ticker}: {days} trading days with data")
    print("Done!")

# Run the download only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
