#!/usr/bin/env python3
"""
Download historical options OI from Theta Data and compute daily net gamma.
All print statements flush immediately for real-time progress.
"""
import csv, json, math, os, sys, time, urllib.request
from datetime import datetime, timedelta

# Local Theta Terminal REST endpoint (v3 API).
BASE_URL = "http://localhost:25503/v3"
# Tickers to process, in order.
SYMBOLS = ["TSLA", "PLTR", "ARM", "MSTR", "COIN", "SOFI", "CRWD", "AMD", "SMCI", "AAPL"]
# Flat implied volatility assumed for every contract in the gamma calculation.
SIGMA = 0.30
# NOTE: despite the name, this is sqrt(2*pi) (normal-pdf normalization), not 2*pi.
PI2 = math.sqrt(2 * math.pi)
# Where per-symbol daily gamma CSVs and the run summary JSON are written.
OUTPUT_DIR = "/Users/lutherbot/.openclaw/workspace/data/stock_gamma_history"
# Source directory of 5-minute Polygon bars used for daily spot prices.
POLYGON_DIR = "/Users/lutherbot/.openclaw/workspace/data/polygon_historical"
# Most recent day to include, and how many weekdays to walk back from it.
END_DATE = datetime(2026, 3, 17)
NUM_DAYS = 250

def log(msg):
    """Write *msg* to stdout and flush immediately so progress is visible in real time."""
    print(msg, flush=True)

def fetch_json(url, retries=1):
    """Fetch *url* and parse the response body as JSON.

    Makes up to ``retries + 1`` attempts, sleeping 10s between them.
    Returns the parsed object, or None if every attempt failed.

    The catch is deliberately broad (network errors, HTTP errors, bad
    JSON all count as a failed attempt); callers treat None as "no data".
    """
    for attempt in range(retries + 1):
        try:
            with urllib.request.urlopen(urllib.request.Request(url), timeout=30) as resp:
                return json.loads(resp.read().decode())
        except Exception:  # broad by design: best-effort fetch, caller handles None
            if attempt < retries:
                time.sleep(10)  # back off before retrying (local terminal may be busy)
    # Explicit fall-through: all attempts failed.
    return None

def get_trading_days(end_date, num_days):
    """Return *num_days* weekday dates ending at *end_date*, newest first.

    Only weekends are excluded; market holidays are NOT filtered out.
    """
    collected = []
    offset = 0
    while len(collected) < num_days:
        candidate = end_date - timedelta(days=offset)
        offset += 1
        if candidate.weekday() >= 5:  # skip Saturday/Sunday
            continue
        collected.append(candidate)
    return collected

def load_spot_prices(symbol):
    """Load one opening price per date from the symbol's 5-minute Polygon CSV.

    The first row seen for each calendar date (the session's first 5-min bar)
    supplies that day's spot. Returns {'YYYY-MM-DD': open_price}, or an empty
    dict when the CSV does not exist. Rows with a missing/unparsable 'open'
    are silently ignored.
    """
    path = os.path.join(POLYGON_DIR, f"{symbol}_5min.csv")
    if not os.path.exists(path):
        return {}
    prices = {}
    with open(path, 'r') as fh:
        for row in csv.DictReader(fh):
            day = row['datetime'][:10]
            if day in prices:
                continue  # keep only the first bar of each day
            try:
                prices[day] = float(row['open'])
            except (ValueError, KeyError):
                pass
    return prices

def get_expirations(symbol):
    """Return the symbol's option expiration dates, sorted ascending.

    Queries the local Theta Terminal expirations endpoint; returns []
    when the request fails or the payload has no 'response' key.
    """
    payload = fetch_json(f"{BASE_URL}/option/list/expirations?symbol={symbol}&format=json")
    if not payload or 'response' not in payload:
        return []
    return sorted(entry['expiration'] for entry in payload['response'])

def find_nearest_friday_exp(expirations, trade_date):
    """Pick the nearest expiration on/after *trade_date*, preferring Fridays.

    *expirations* must be ascending 'YYYY-MM-DD' strings. Returns the first
    Friday expiration at or after the trade date; failing that, the earliest
    upcoming expiration of any weekday; None if every expiration has passed.
    """
    cutoff = trade_date.strftime("%Y-%m-%d")
    upcoming = [e for e in expirations if e >= cutoff]
    friday = next(
        (e for e in upcoming if datetime.strptime(e, "%Y-%m-%d").weekday() == 4),
        None,
    )
    if friday is not None:
        return friday
    return upcoming[0] if upcoming else None

def compute_gamma(spot, strike, dte_days):
    """Black-Scholes gamma per share, assuming r = 0 and flat vol SIGMA.

    Gamma = phi(d1) / (spot * sigma * sqrt(t)) where phi is the standard
    normal pdf and, with zero rates,
        d1 = [ln(S/K) + (sigma^2 / 2) * t] / (sigma * sqrt(t)).
    Returns 0.0 for non-positive spot, strike, or days-to-expiry, and on
    any math-domain error.
    """
    if dte_days <= 0 or spot <= 0 or strike <= 0:
        return 0.0
    t = dte_days / 365.0
    sqrt_t = math.sqrt(t)
    try:
        # Fix: include the (sigma^2 / 2) * t variance drift term that the
        # original d1 omitted; with r = 0 the standard BS d1 requires it.
        d1 = (math.log(spot / strike) + 0.5 * SIGMA * SIGMA * t) / (SIGMA * sqrt_t)
        # phi(d1) = exp(-d1^2 / 2) / sqrt(2*pi); PI2 holds sqrt(2*pi).
        return math.exp(-0.5 * d1 * d1) / (spot * SIGMA * sqrt_t * PI2)
    except (ValueError, ZeroDivisionError):
        return 0.0

def process_oi_response(response_data, spot, trade_date, exp_date_str):
    """Aggregate one day's open-interest snapshot into dollar-gamma totals.

    Each contract with positive OI contributes
    ``gamma * OI * 100 * spot`` dollars of gamma to its side's total.
    Returns {'net_gamma', 'call_gamma', 'put_gamma', 'total_oi',
    'gamma_ratio'} (net = put - call; ratio = put/call, 0 if no call
    gamma), or None when the payload is empty, the expiration already
    passed, or no contract had positive OI.
    """
    if not response_data or 'response' not in response_data:
        return None
    contracts = response_data['response']
    if not contracts:
        return None
    expiry = datetime.strptime(exp_date_str, "%Y-%m-%d")
    dte = (expiry - trade_date).days
    if dte < 0:
        return None

    call_total = 0.0
    put_total = 0.0
    oi_sum = 0

    for entry in contracts:
        meta = entry.get('contract', {})
        samples = entry.get('data', [])
        strike = meta.get('strike', 0)
        side = meta.get('right', '').upper()
        if not samples or strike <= 0:
            continue
        oi = samples[0].get('open_interest', 0)
        if oi <= 0:
            continue
        oi_sum += oi
        # Dollar gamma for this strike: per-share gamma scaled by the
        # 100-share contract multiplier, OI, and spot.
        dollar_gamma = compute_gamma(spot, strike, dte) * oi * 100 * spot
        if side == 'CALL':
            call_total += dollar_gamma
        elif side == 'PUT':
            put_total += dollar_gamma

    if oi_sum == 0:
        return None
    ratio = put_total / call_total if call_total > 0 else 0
    return {
        'net_gamma': put_total - call_total,
        'call_gamma': call_total,
        'put_gamma': put_total,
        'total_oi': oi_sum,
        'gamma_ratio': round(ratio, 4),
    }

def process_symbol(symbol, trading_days):
    """Build the daily net-gamma history for one symbol.

    For each trading day (*trading_days* is newest-first): look up the
    day's spot price, pick the nearest Friday expiration, fetch that
    expiration's open interest from the local Theta Terminal, and
    aggregate it into dollar-gamma totals. Successful days are written
    to OUTPUT_DIR/<symbol>_daily_gamma.csv, oldest-first.

    Returns a summary dict: status, days_processed, errors, skipped,
    and (when any results exist) the covered date range.
    """
    log(f"\n{'='*60}")
    log(f"Processing {symbol}")

    # Daily opening spot prices keyed by 'YYYY-MM-DD' (from Polygon CSVs).
    spots = load_spot_prices(symbol)
    if not spots:
        log(f"  No spot data for {symbol}, skipping")
        return {'status': 'error', 'reason': 'no spot data', 'days_processed': 0}

    expirations = get_expirations(symbol)
    if not expirations:
        log(f"  No expirations for {symbol}, skipping")
        return {'status': 'error', 'reason': 'no expirations', 'days_processed': 0}
    log(f"  {len(expirations)} expirations, {len(spots)} spot dates")

    results = []
    errors = 0   # failed OI fetches (fetch_json returned None)
    skipped = 0  # days with no spot, no usable expiration, or empty OI

    for i, trade_date in enumerate(trading_days):
        date_str = trade_date.strftime("%Y-%m-%d")

        # Progress heartbeat every 25 days.
        if (i + 1) % 25 == 0:
            log(f"  {symbol}: {i+1}/{len(trading_days)} days ({len(results)} results, {errors} err, {skipped} skip)")

        spot = spots.get(date_str)
        if not spot:
            skipped += 1
            continue

        exp = find_nearest_friday_exp(expirations, trade_date)
        if not exp:
            skipped += 1
            continue

        # Historical OI snapshot for every contract in this expiration on this day.
        url = (f"{BASE_URL}/option/history/open_interest"
               f"?symbol={symbol}&expiration={exp}"
               f"&start_date={date_str}&end_date={date_str}&format=json")

        data = fetch_json(url)
        if not data:
            errors += 1
            continue

        gamma_result = process_oi_response(data, spot, trade_date, exp)
        if gamma_result:
            results.append({
                'date': date_str,
                'spot': round(spot, 2),
                'net_gamma': round(gamma_result['net_gamma'], 2),
                'call_gamma': round(gamma_result['call_gamma'], 2),
                'put_gamma': round(gamma_result['put_gamma'], 2),
                'total_oi': gamma_result['total_oi'],
                'gamma_ratio': gamma_result['gamma_ratio'],
                'nearest_exp': exp,
            })
        else:
            skipped += 1

        time.sleep(0.05)  # small delay between requests to the local terminal

    if results:
        csv_path = os.path.join(OUTPUT_DIR, f"{symbol}_daily_gamma.csv")
        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['date','spot','net_gamma','call_gamma','put_gamma','total_oi','gamma_ratio','nearest_exp'])
            writer.writeheader()
            # Write rows oldest-first even though `results` is newest-first
            # (sorted() returns a copy; `results` itself stays unsorted).
            writer.writerows(sorted(results, key=lambda x: x['date']))
        log(f"  Saved {len(results)} rows to {csv_path}")

    log(f"  {symbol} DONE: {len(results)} results, {errors} errors, {skipped} skipped")
    # `results` is newest-first, so [-1] is the oldest date and [0] the newest.
    return {
        'status': 'ok', 'days_processed': len(results),
        'errors': errors, 'skipped': skipped,
        'date_range': f"{results[-1]['date']} to {results[0]['date']}" if results else "N/A"
    }

def main():
    """Entry point: download gamma history for every symbol in SYMBOLS.

    Symbols whose output CSV already holds at least 200 data rows are
    skipped (resume support). After all symbols, writes a
    download_summary.json describing the run to OUTPUT_DIR.
    """
    log(f"Gamma Download — {END_DATE.strftime('%Y-%m-%d')} — {NUM_DAYS} days — {len(SYMBOLS)} symbols")
    trading_days = get_trading_days(END_DATE, NUM_DAYS)
    log(f"Range: {trading_days[-1].strftime('%Y-%m-%d')} to {trading_days[0].strftime('%Y-%m-%d')}")

    summary = {}
    start_time = time.time()

    for sym in SYMBOLS:
        # Skip already completed
        csv_path = os.path.join(OUTPUT_DIR, f"{sym}_daily_gamma.csv")
        if os.path.exists(csv_path):
            # Fix: use a context manager so the handle is closed promptly
            # (the original leaked the file object from a bare open()).
            with open(csv_path) as f:
                row_count = sum(1 for _ in f) - 1  # minus the header row
            if row_count >= 200:
                log(f"  SKIP {sym} — {row_count} rows already done")
                summary[sym] = {'status': 'ok', 'days_processed': row_count, 'errors': 0, 'skipped': 0, 'elapsed_sec': 0, 'note': 'cached'}
                continue

        sym_start = time.time()
        result = process_symbol(sym, trading_days)
        result['elapsed_sec'] = round(time.time() - sym_start, 1)
        summary[sym] = result
        time.sleep(0.5)  # brief pause between symbols

    total_elapsed = round(time.time() - start_time, 1)

    summary_data = {
        'generated_at': datetime.now().isoformat(),
        'end_date': END_DATE.strftime('%Y-%m-%d'),
        'trading_days': NUM_DAYS,
        'total_elapsed_sec': total_elapsed,
        'symbols': summary
    }
    with open(os.path.join(OUTPUT_DIR, 'download_summary.json'), 'w') as f:
        json.dump(summary_data, f, indent=2)

    log(f"\n{'='*60}")
    log(f"ALL DONE in {total_elapsed}s")
    for sym, info in summary.items():
        log(f"  {sym}: {info.get('days_processed', 0)} days, {info.get('errors', 0)} errors")

# Script entry point: run the full download only when executed directly.
if __name__ == '__main__':
    main()
