"""
Fossati AI Bot — Backtesting Engine
Replay candle a candle do SMCAnalyzer + checklist Fossati.
Calcula: win rate, P&L, max drawdown, Sharpe, SQN, profit factor.

Uso programático:
    from backtest import Backtester
    bt = Backtester(symbols=["BTCUSDT"], days=90, interval="15m")
    results = bt.run()
    bt.save("backtest_results.json")
"""

from __future__ import annotations

import json
import logging
import math
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import List, Optional, Dict

import numpy as np
import pandas as pd
import ta
from binance.client import Client
from binance.exceptions import BinanceAPIException

import config
from smc_analyzer import SMCAnalyzer, Direction

log = logging.getLogger("fossati.backtest")


# ─── ESTRUTURAS ─────────────────────────────────────────────
@dataclass
class BacktestTrade:
    """Record of one simulated trade, from entry signal to exit.

    Instances are created by ``Backtester._run_symbol`` and later flattened
    with ``dataclasses.asdict`` for the JSON report.
    """

    symbol: str
    direction: str            # value of the signal's Direction enum, e.g. "bullish" / "bearish"
    entry_time: str           # ISO-8601 timestamp of the candle that produced the signal
    entry_price: float
    exit_time: str            # ISO-8601 timestamp of the candle on which the trade was closed
    exit_price: float
    stop_loss: float
    take_profit: float
    leverage: int             # leverage looked up for the setup score (informational in the record)
    score: int                # confluence score returned by the analyzer for this setup
    pnl_pct: float            # P&L as a percentage of the account balance at entry
    pnl_usdt: float           # P&L in account currency, derived from pnl_pct and the running balance
    result: str               # "tp" / "sl" / "timeout"
    bars_held: int            # number of candles between the entry candle and the exit candle


@dataclass
class BacktestResult:
    """Aggregated outcome of a backtest run.

    Only ``period``, ``symbols`` and ``interval`` are required; every metric
    defaults to zero/empty so that a run with no trades still yields a valid
    result object.
    """

    period: str                       # human-readable label of the tested window
    symbols: List[str]
    interval: str                     # candle interval, e.g. "15m"
    total_trades: int = 0
    winning_trades: int = 0           # trades with pnl_pct > 0
    losing_trades: int = 0            # trades with pnl_pct < 0
    win_rate: float = 0.0             # percentage (0-100) of winning trades
    total_pnl_pct: float = 0.0        # plain sum of the per-trade pnl_pct values
    total_pnl_usdt: float = 0.0       # sum of the per-trade pnl_usdt values
    max_drawdown_pct: float = 0.0     # largest peak-to-trough drop of the equity curve, in percent
    sharpe_ratio: float = 0.0         # per-trade mean/std, annualised assuming 365 trading days
    sqn: float = 0.0                  # sqrt(N) * mean / std of per-trade pnl_pct
    profit_factor: float = 0.0        # gross wins / gross losses; infinite when there are wins but no losses
    avg_win_pct: float = 0.0
    avg_loss_pct: float = 0.0         # negative number (mean of the losing trades)
    largest_win_pct: float = 0.0
    largest_loss_pct: float = 0.0     # negative number (most negative trade)
    avg_bars_held: float = 0.0
    equity_curve: List[Dict] = field(default_factory=list)   # entries: {"time", "balance", "symbol"}
    by_symbol: Dict[str, Dict] = field(default_factory=dict)  # per symbol: trades, wins, win_rate, pnl_pct, pnl_usdt
    trades: List[Dict] = field(default_factory=list)          # BacktestTrade records converted with asdict()
    started_at: str = ""              # ISO-8601 timestamps; empty until filled in by the caller
    finished_at: str = ""
    initial_balance: float = 0.0


# ─── ENGINE ─────────────────────────────────────────────────
class Backtester:
    """
    Engine de backtest que reusa o `SMCAnalyzer` e o checklist Fossati.
    Mantém-se fiel à lógica de `main.scan_symbol`, mas sem chamadas de IA
    nem execução real — usa os preços do próprio dataset histórico.
    """

    INTERVAL_MINUTES = {
        "1m": 1, "3m": 3, "5m": 5, "15m": 15, "30m": 30,
        "1h": 60, "2h": 120, "4h": 240, "6h": 360, "12h": 720, "1d": 1440,
    }

    def __init__(
        self,
        symbols: Optional[List[str]] = None,
        days: int = 90,
        interval: str = "15m",
        warmup: int = 200,
        initial_balance: Optional[float] = None,
        max_bars_in_trade: int = 96,
    ):
        self.symbols          = symbols or list(config.SYMBOLS)
        self.days             = days
        self.interval         = interval
        self.warmup           = warmup
        self.max_bars_in_trade = max_bars_in_trade
        self.initial_balance  = float(initial_balance if initial_balance is not None else config.INITIAL_BALANCE)
        self.client           = Client(config.BINANCE_API_KEY, config.BINANCE_API_SECRET)
        self.analyzer         = SMCAnalyzer(config)
        self._trades: List[BacktestTrade] = []
        self._equity: List[Dict] = []

    # ── Coleta de dados ─────────────────────────────────────
    def _fetch_klines(self, symbol: str) -> pd.DataFrame:
        """Busca em páginas até cobrir `self.days`."""
        end_ms   = int(datetime.now(tz=timezone.utc).timestamp() * 1000)
        start_ms = end_ms - self.days * 24 * 60 * 60 * 1000
        all_rows: list = []
        cursor = start_ms
        try:
            while cursor < end_ms:
                batch = self.client.futures_klines(
                    symbol=symbol,
                    interval=self.interval,
                    startTime=cursor,
                    endTime=end_ms,
                    limit=1500,
                )
                if not batch:
                    break
                all_rows.extend(batch)
                last_open = batch[-1][0]
                next_cursor = last_open + self.INTERVAL_MINUTES[self.interval] * 60 * 1000
                if next_cursor <= cursor:
                    break
                cursor = next_cursor
                if len(batch) < 1500:
                    break
        except BinanceAPIException as e:
            log.error(f"Binance erro ao buscar {symbol}: {e.message}")
            return pd.DataFrame()
        except Exception as e:
            log.error(f"Erro ao buscar {symbol}: {e}")
            return pd.DataFrame()

        if not all_rows:
            return pd.DataFrame()

        df = pd.DataFrame(all_rows, columns=[
            "timestamp", "open", "high", "low", "close", "volume",
            "close_time", "quote_volume", "trades",
            "taker_buy_base", "taker_buy_quote", "ignore",
        ])
        for c in ["open", "high", "low", "close", "volume"]:
            df[c] = df[c].astype(float)
        df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True)
        return df.drop_duplicates(subset="timestamp").reset_index(drop=True)

    # ── Indicadores ─────────────────────────────────────────
    @staticmethod
    def _rsi(close: pd.Series, period: int = 14) -> float:
        if len(close) < period + 1:
            return 50.0
        try:
            return float(ta.momentum.RSIIndicator(close, window=period).rsi().iloc[-1])
        except Exception:
            return 50.0

    @staticmethod
    def _vol_ratio(volume: pd.Series, period: int = 20) -> float:
        if len(volume) < period + 1:
            return 1.0
        avg = volume.iloc[-period-1:-1].mean()
        cur = volume.iloc[-1]
        return float(cur / avg) if avg > 0 else 1.0

    # ── Sizing / leverage (reflete config.LEVERAGE_MAP) ─────
    @staticmethod
    def _leverage_for_score(score: int) -> int:
        lev_map = getattr(config, "LEVERAGE_MAP", {})
        if score in lev_map:
            return int(lev_map[score])
        # fallback: default
        return int(getattr(config, "DEFAULT_LEVERAGE", 10))

    # ── Lógica de sinal (mesmo critério de scan_symbol) ─────
    def _evaluate_window(self, df_15m: pd.DataFrame) -> Optional[dict]:
        analysis = self.analyzer.analyze(df_15m, "BT", self.interval)
        if not analysis:
            return None

        rsi       = self._rsi(df_15m["close"])
        vol_ratio = self._vol_ratio(df_15m["volume"])
        trend     = analysis.structure.trend

        signal_direction: Optional[Direction] = None
        if (trend == Direction.BULLISH and analysis.nearest_bullish_ob and
                analysis.in_discount and rsi < 60):
            signal_direction = Direction.BULLISH
        elif (trend == Direction.BEARISH and analysis.nearest_bearish_ob and
                analysis.in_premium and rsi > 40):
            signal_direction = Direction.BEARISH

        if not signal_direction:
            return None

        score, _ = self.analyzer.score_setup(
            analysis, signal_direction,
            volume_ratio=vol_ratio, rsi=rsi, htf_trend=trend,
        )
        if score < int(getattr(config, "MIN_CONFLUENCE_SCORE", 9)):
            return None

        ob = (analysis.nearest_bullish_ob if signal_direction == Direction.BULLISH
              else analysis.nearest_bearish_ob)
        entry = analysis.current_price
        if signal_direction == Direction.BULLISH:
            sl = float(ob.low) * 0.999
            risk = entry - sl
            if risk <= 0:
                return None
            tp = entry + risk * float(getattr(config, "MIN_RR_RATIO", 2.0))
        else:
            sl = float(ob.high) * 1.001
            risk = sl - entry
            if risk <= 0:
                return None
            tp = entry - risk * float(getattr(config, "MIN_RR_RATIO", 2.0))

        return {
            "direction": signal_direction,
            "entry":     entry,
            "stop_loss": sl,
            "take_profit": tp,
            "leverage":  self._leverage_for_score(score),
            "score":     score,
        }

    # ── Simulação de saída ──────────────────────────────────
    def _simulate_exit(
        self,
        df_full: pd.DataFrame,
        start_idx: int,
        signal: dict,
    ) -> Optional[dict]:
        """Percorre candles seguintes até bater SL, TP ou timeout."""
        end_idx = min(start_idx + self.max_bars_in_trade, len(df_full) - 1)
        is_long = signal["direction"] == Direction.BULLISH
        entry   = signal["entry"]
        sl      = signal["stop_loss"]
        tp      = signal["take_profit"]

        for i in range(start_idx + 1, end_idx + 1):
            high = float(df_full["high"].iloc[i])
            low  = float(df_full["low"].iloc[i])

            if is_long:
                if low <= sl:
                    return {"exit_price": sl, "exit_idx": i, "result": "sl"}
                if high >= tp:
                    return {"exit_price": tp, "exit_idx": i, "result": "tp"}
            else:
                if high >= sl:
                    return {"exit_price": sl, "exit_idx": i, "result": "sl"}
                if low <= tp:
                    return {"exit_price": tp, "exit_idx": i, "result": "tp"}

        # timeout — fecha no close do último candle visto
        last_close = float(df_full["close"].iloc[end_idx])
        return {"exit_price": last_close, "exit_idx": end_idx, "result": "timeout"}

    # ── Loop por símbolo ────────────────────────────────────
    def _run_symbol(self, symbol: str, balance_ref: List[float]) -> List[BacktestTrade]:
        log.info(f"[{symbol}] baixando {self.days}d de {self.interval}…")
        df = self._fetch_klines(symbol)
        if df.empty or len(df) < self.warmup + 10:
            log.warning(f"[{symbol}] dados insuficientes ({len(df)} candles)")
            return []
        log.info(f"[{symbol}] {len(df)} candles carregados, simulando…")

        trades: List[BacktestTrade] = []
        in_trade_until_idx = -1
        risk_pct = float(getattr(config, "MAX_RISK_PER_TRADE", 0.10))

        for i in range(self.warmup, len(df) - 1):
            if i <= in_trade_until_idx:
                continue

            window = df.iloc[i - self.warmup:i + 1].reset_index(drop=True)
            sig = self._evaluate_window(window)
            if not sig:
                continue

            exit_info = self._simulate_exit(df, i, sig)
            if not exit_info:
                continue

            entry = sig["entry"]
            exit_price = exit_info["exit_price"]
            is_long = sig["direction"] == Direction.BULLISH

            # P&L em % do preço (sem alavancagem)
            raw_pct = ((exit_price - entry) / entry) if is_long else ((entry - exit_price) / entry)
            # Sizing pelo risco fixo: posição calibrada para perder `risk_pct` do saldo no SL
            if is_long:
                stop_pct = (entry - sig["stop_loss"]) / entry
            else:
                stop_pct = (sig["stop_loss"] - entry) / entry
            if stop_pct <= 0:
                continue
            position_fraction = risk_pct / stop_pct  # fração do saldo alocada ao notional
            balance_pnl_pct  = raw_pct * position_fraction  # já reflete o risco real

            balance = balance_ref[0]
            pnl_usdt = balance * balance_pnl_pct
            balance_ref[0] = balance + pnl_usdt

            entry_ts = df["timestamp"].iloc[i]
            exit_ts  = df["timestamp"].iloc[exit_info["exit_idx"]]

            trade = BacktestTrade(
                symbol=symbol,
                direction=sig["direction"].value,
                entry_time=entry_ts.isoformat(),
                entry_price=round(entry, 6),
                exit_time=exit_ts.isoformat(),
                exit_price=round(exit_price, 6),
                stop_loss=round(sig["stop_loss"], 6),
                take_profit=round(sig["take_profit"], 6),
                leverage=int(sig["leverage"]),
                score=int(sig["score"]),
                pnl_pct=round(balance_pnl_pct * 100, 4),
                pnl_usdt=round(pnl_usdt, 4),
                result=exit_info["result"],
                bars_held=int(exit_info["exit_idx"] - i),
            )
            trades.append(trade)

            self._equity.append({
                "time":    exit_ts.isoformat(),
                "balance": round(balance_ref[0], 4),
                "symbol":  symbol,
            })

            in_trade_until_idx = exit_info["exit_idx"]

        log.info(f"[{symbol}] {len(trades)} trades simulados.")
        return trades

    # ── Métricas ────────────────────────────────────────────
    def _compute_metrics(self) -> BacktestResult:
        symbols = self.symbols
        period_label = f"últimos {self.days} dias ({self.interval})"
        result = BacktestResult(
            period=period_label,
            symbols=symbols,
            interval=self.interval,
            initial_balance=self.initial_balance,
        )
        result.equity_curve = self._equity

        if not self._trades:
            return result

        pnls_pct = np.array([t.pnl_pct for t in self._trades], dtype=float)
        wins     = pnls_pct[pnls_pct > 0]
        losses   = pnls_pct[pnls_pct < 0]

        result.total_trades   = len(self._trades)
        result.winning_trades = int(len(wins))
        result.losing_trades  = int(len(losses))
        result.win_rate       = round(len(wins) / len(pnls_pct) * 100, 2) if len(pnls_pct) else 0.0
        result.total_pnl_pct  = round(float(pnls_pct.sum()), 2)
        result.total_pnl_usdt = round(float(sum(t.pnl_usdt for t in self._trades)), 2)
        result.avg_win_pct    = round(float(wins.mean()), 3) if len(wins) else 0.0
        result.avg_loss_pct   = round(float(losses.mean()), 3) if len(losses) else 0.0
        result.largest_win_pct  = round(float(wins.max()), 3) if len(wins) else 0.0
        result.largest_loss_pct = round(float(losses.min()), 3) if len(losses) else 0.0
        result.avg_bars_held    = round(float(np.mean([t.bars_held for t in self._trades])), 2)

        # Profit factor
        gross_win  = float(wins.sum()) if len(wins) else 0.0
        gross_loss = abs(float(losses.sum())) if len(losses) else 0.0
        result.profit_factor = round(gross_win / gross_loss, 3) if gross_loss > 0 else float("inf") if gross_win > 0 else 0.0

        # Max drawdown sobre a curva de equity
        if self._equity:
            balances = np.array([e["balance"] for e in self._equity], dtype=float)
            peaks    = np.maximum.accumulate(balances)
            dd       = (peaks - balances) / peaks
            result.max_drawdown_pct = round(float(dd.max()) * 100, 2)

        # Sharpe (anualizado, considerando 365 dias)
        if len(pnls_pct) > 1:
            mean_r = pnls_pct.mean()
            std_r  = pnls_pct.std(ddof=1)
            if std_r > 0:
                trades_per_day = len(pnls_pct) / max(self.days, 1)
                result.sharpe_ratio = round(float((mean_r / std_r) * math.sqrt(trades_per_day * 365)), 3)

        # SQN = sqrt(N) * mean / std
        if len(pnls_pct) > 1 and pnls_pct.std(ddof=1) > 0:
            result.sqn = round(float(math.sqrt(len(pnls_pct)) * pnls_pct.mean() / pnls_pct.std(ddof=1)), 3)

        # Por símbolo
        by: Dict[str, Dict] = {}
        for s in symbols:
            ts = [t for t in self._trades if t.symbol == s]
            if not ts:
                continue
            wpct = [t.pnl_pct for t in ts if t.pnl_pct > 0]
            by[s] = {
                "trades":    len(ts),
                "wins":      len(wpct),
                "win_rate":  round(len(wpct) / len(ts) * 100, 2),
                "pnl_pct":   round(sum(t.pnl_pct for t in ts), 2),
                "pnl_usdt":  round(sum(t.pnl_usdt for t in ts), 2),
            }
        result.by_symbol = by
        result.trades    = [asdict(t) for t in self._trades]
        return result

    # ── API pública ─────────────────────────────────────────
    def run(self) -> BacktestResult:
        """Run the backtest over every configured symbol and return its metrics.

        Each call starts from a clean slate: previously recorded trades and
        equity points are dropped and the balance restarts at
        ``self.initial_balance``, so calling ``run()`` twice does not
        duplicate trades. Symbols are processed one after another and share
        the running balance. A failure in one symbol is logged and does not
        stop the remaining symbols.

        Returns:
            The ``BacktestResult`` for this run, with ``started_at`` and
            ``finished_at`` set to the wall-clock bounds of the run.
        """
        started = datetime.now(tz=timezone.utc)
        log.info(f"Iniciando backtest: {len(self.symbols)} símbolos · {self.days}d · {self.interval}")

        # Rebind instead of clearing so results returned by an earlier run
        # keep their own equity curve.
        self._trades = []
        self._equity = []
        balance_ref = [self.initial_balance]

        for sym in self.symbols:
            try:
                trades = self._run_symbol(sym, balance_ref)
                self._trades.extend(trades)
            except Exception as e:
                log.exception(f"[{sym}] erro durante backtest: {e}")

        result = self._compute_metrics()
        result.started_at = started.isoformat()
        result.finished_at = datetime.now(tz=timezone.utc).isoformat()
        # Remembered so that save() can report the real run window.
        self._run_window = (result.started_at, result.finished_at)

        log.info(
            f"Backtest concluído: {result.total_trades} trades, "
            f"WR {result.win_rate}%, PnL {result.total_pnl_pct}%, "
            f"DD {result.max_drawdown_pct}%, SQN {result.sqn}"
        )
        return result

    def save(self, path: str = "backtest_results.json") -> Path:
        """Write the current metrics to *path* as strictly valid JSON.

        Metrics are recomputed from the recorded trades. The timestamps of
        the last ``run()`` are reused when available; if ``save()`` is called
        without a prior run, both are set to the time of saving. Non-finite
        floats (for example an infinite profit factor when there are no
        losing trades) are written as ``null``, because ``Infinity``/``NaN``
        are not valid JSON and break strict parsers.

        Args:
            path: Destination file; overwritten if it exists.

        Returns:
            The ``Path`` that was written.
        """
        result = self._compute_metrics()

        started_at, finished_at = getattr(self, "_run_window", ("", ""))
        if not started_at:
            # save() called on its own: fall back to "now" for both bounds.
            started_at = finished_at = datetime.now(tz=timezone.utc).isoformat()
        result.started_at = started_at
        result.finished_at = finished_at

        payload = self._json_safe(asdict(result))
        out = Path(path)
        out.write_text(
            json.dumps(payload, indent=2, default=str, allow_nan=False),
            encoding="utf-8",
        )
        log.info(f"Resultados salvos em {out.resolve()}")
        return out

    @staticmethod
    def _json_safe(value):
        """Return *value* with every non-finite float replaced by ``None``."""
        if isinstance(value, float):
            return value if math.isfinite(value) else None
        if isinstance(value, dict):
            return {key: Backtester._json_safe(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [Backtester._json_safe(item) for item in value]
        return value


def run_and_save(
    symbols: Optional[List[str]] = None,
    days: int = 90,
    interval: str = "15m",
    output: str = "backtest_results.json",
) -> dict:
    """Run a backtest, save it to *output* and return the saved JSON as a dict.

    Convenience helper intended for the Flask route: the returned dict is
    read back from the file that was just written, so it matches the saved
    report exactly.
    """
    backtester = Backtester(symbols=symbols, days=days, interval=interval)
    backtester.run()
    backtester.save(output)

    saved_text = Path(output).read_text(encoding="utf-8")
    return json.loads(saved_text)
