from __future__ import annotations

import pymysql
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from config import DB_HOST, DB_NAME, DB_PASS, DB_PORT, DB_USER


@contextmanager
def db():
    """Yield a fresh PyMySQL connection scoped to a single transaction.

    The connection returns rows as dicts (``DictCursor``) and has autocommit
    disabled. If the ``with`` body finishes normally, the transaction is
    committed. If the body (or the commit itself) raises an ``Exception``, the
    transaction is rolled back and the original exception is re-raised. The
    connection is always closed afterwards.
    """
    conn = pymysql.connect(
        host=DB_HOST,
        user=DB_USER,
        password=DB_PASS,
        database=DB_NAME,
        port=DB_PORT,
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
    )
    try:
        yield conn
        conn.commit()
    except Exception:
        # If the connection itself died (e.g. "Lost connection to MySQL
        # server"), rollback() raises its own PyMySQL error -- typically
        # InterfaceError(0, '') -- which would mask the real cause. Swallow
        # it: MySQL rolls back an uncommitted transaction when the session
        # ends, so nothing is left half-applied.
        try:
            conn.rollback()
        except pymysql.Error:
            pass
        raise
    finally:
        conn.close()


def fetchone(sql: str, params: tuple = ()) -> Optional[Dict[str, Any]]:
    """Run *sql* with *params* in its own transaction and return the first row.

    Rows are dicts because ``db()`` configures a ``DictCursor``. ``None`` is
    returned when the query yields no rows.
    """
    with db() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        first_row = cur.fetchone()
    return first_row


def fetchall(sql: str, params: tuple = ()) -> List[Dict[str, Any]]:
    """Run *sql* with *params* in its own transaction and return every row.

    The cursor's result is copied into a plain ``list`` of dict rows.
    """
    with db() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        rows = [*cur.fetchall()]
    return rows


def execute(sql: str, params: tuple = ()) -> int:
    """Run a write statement in its own committed transaction.

    Returns the cursor's ``lastrowid`` (the auto-increment id produced by an
    INSERT). For statements that do not generate an id, the value is whatever
    the driver reports -- it is NOT an affected-row count.
    """
    with db() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        generated_id = cur.lastrowid
    return generated_id


def get_setting(key: str, default: str = "") -> str:
    """Return the ``settings.value`` stored under *key* as a string.

    Falls back to *default* when no row exists or the stored value is NULL.
    """
    setting_row = fetchone("SELECT value FROM settings WHERE `key`=%s", (key,))
    if not setting_row:
        return default
    stored = setting_row.get("value")
    if stored is None:
        return default
    return str(stored)


def set_setting(key: str, value: str | None) -> None:
    """Create or overwrite the ``settings`` row for *key*.

    Uses INSERT ... ON DUPLICATE KEY UPDATE, so this acts as an upsert only if
    ``settings.key`` is a primary/unique key (assumed -- confirm in schema).
    """
    upsert_sql = "INSERT INTO settings (`key`, `value`) VALUES (%s, %s) ON DUPLICATE KEY UPDATE value=VALUES(value)"
    execute(upsert_sql, (key, value))


def ensure_model(telegram_id: int, username: str | None, display_name: str | None) -> Dict[str, Any]:
    """Return the ``models`` row for *telegram_id*, creating or refreshing it.

    A leading ``@`` is stripped from *username*; an empty result is stored as
    NULL. An existing row gets its ``username``/``display_name`` overwritten.
    A missing row is inserted with ``status='pending'``. The freshly re-read
    row is returned.

    If a concurrent request inserts the same ``telegram_id`` between our
    SELECT and INSERT (and the column has a unique index), the resulting
    ``IntegrityError`` is absorbed. In that case the row is re-read and
    updated instead.
    """
    clean_username = (username or "").lstrip("@") or None
    model = fetchone("SELECT * FROM models WHERE telegram_id=%s", (telegram_id,))

    if not model:
        try:
            model_id = execute(
                "INSERT INTO models (telegram_id, username, display_name, status) VALUES (%s, %s, %s, 'pending')",
                (telegram_id, clean_username, display_name),
            )
        except pymysql.IntegrityError:
            # Lost the check-then-insert race: someone else created the row.
            # Re-read it; if it still isn't there, the error was something
            # else, so propagate it unchanged.
            model = fetchone("SELECT * FROM models WHERE telegram_id=%s", (telegram_id,))
            if not model:
                raise
        else:
            return fetchone("SELECT * FROM models WHERE id=%s", (model_id,))

    execute(
        "UPDATE models SET username=%s, display_name=%s WHERE id=%s",
        (clean_username, display_name, model["id"]),
    )
    return fetchone("SELECT * FROM models WHERE id=%s", (model["id"],))


def get_active_packages() -> List[Dict[str, Any]]:
    """Return every credit package flagged ``active=1``.

    Rows are ordered by ``sort_order`` and then by ``id``, both ascending.
    """
    query = "SELECT * FROM credit_packages WHERE active=1 ORDER BY sort_order ASC, id ASC"
    packages = fetchall(query)
    return packages


def create_payment(
    model_id: int,
    package_id: int,
    amount: float,
    credits: int,
    provider: str,
    external_reference: str,
    payment_url: str = "",
    qr_code: str = "",
    provider_payment_id: str = "",
    raw_response: str = "",
) -> int:
    """Insert a ``payments`` row and return its generated id.

    The optional provider fields default to empty strings rather than NULL.
    The row's status is not set here and is left to the column default
    (see schema).
    """
    insert_sql = """
        INSERT INTO payments
        (model_id, package_id, amount, credits, provider, external_reference, payment_url, qr_code, provider_payment_id, raw_response)
        VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """
    # Order must match the column list in insert_sql.
    column_values = (
        model_id, package_id, amount, credits, provider,
        external_reference, payment_url, qr_code, provider_payment_id, raw_response,
    )
    new_payment_id = execute(insert_sql, column_values)
    return new_payment_id


def spend_one_credit(model_id: int, reason: str) -> bool:
    """Atomically deduct one credit from a model and log the spend.

    The balance row is locked with ``SELECT ... FOR UPDATE`` so concurrent
    callers cannot both spend the last credit. Returns ``False`` (and changes
    nothing) when the model is missing or its balance is not positive;
    otherwise decrements the balance, writes a ``credit_transactions`` row of
    type ``'spend'``, and returns ``True``.
    """
    with db() as conn, conn.cursor() as cur:
        cur.execute("SELECT credit_balance FROM models WHERE id=%s FOR UPDATE", (model_id,))
        locked_row = cur.fetchone()

        can_spend = bool(locked_row) and int(locked_row["credit_balance"]) > 0
        if can_spend:
            cur.execute("UPDATE models SET credit_balance = credit_balance - 1 WHERE id=%s", (model_id,))
            cur.execute(
                "INSERT INTO credit_transactions (model_id, type, amount, reason) VALUES (%s,'spend',1,%s)",
                (model_id, reason),
            )
        return can_spend


def ads_today_count(model_id: int) -> int:
    """Count a model's ads created today, excluding rejected/cancelled ones.

    "Today" is the server's CURDATE() in the session time zone. The date filter
    is written as a half-open range on ``created_at`` instead of
    ``DATE(created_at)=CURDATE()``. Both select the same rows, but the range
    form lets MySQL use an index on ``created_at`` (e.g. a composite
    ``(model_id, created_at)`` index) rather than evaluating DATE() on every
    row of the model.
    """
    row = fetchone(
        "SELECT COUNT(*) AS total FROM ads "
        "WHERE model_id=%s "
        "AND created_at >= CURDATE() AND created_at < CURDATE() + INTERVAL 1 DAY "
        "AND status NOT IN ('rejected','cancelled')",
        (model_id,),
    )
    # COUNT(*) always yields one row; the guard only protects against a None.
    return int(row["total"]) if row else 0


def save_ad(
    model: Dict[str, Any],
    answers: Dict[str, str],
    media_file_id: str,
    media_type: str,
    final_text: str,
    status: str,
) -> int:
    """Persist an ad plus its questionnaire answers in one transaction.

    The ``ads`` row copies ``id``/``telegram_id``/``username`` from *model*
    and is always written with ``credit_spent=1``; this function does not
    itself deduct any credit. Each ``answers`` entry becomes an
    ``ad_answers`` row whose label comes from the known-question table below,
    falling back to the raw key for unknown questions. Returns the new ad id.
    """
    # Human-readable (Portuguese) labels for the known question keys.
    question_labels = {
        "video_call": "Faz vídeo chamada?",
        "sexting": "Faz sexting?",
        "packs": "Vende packs/prévias?",
        "available": "Está disponível agora?",
    }

    with db() as conn:
        with conn.cursor() as cur:
            ad_values = (
                model["id"],
                model["telegram_id"],
                model.get("username"),
                final_text,
                media_file_id,
                media_type,
                status,
            )
            cur.execute(
                """
                INSERT INTO ads (model_id, telegram_user_id, username, final_text, media_file_id, media_type, status, credit_spent)
                VALUES (%s,%s,%s,%s,%s,%s,%s,1)
                """,
                ad_values,
            )
            new_ad_id = cur.lastrowid

            answer_rows = [
                (new_ad_id, question_key, question_labels.get(question_key, question_key), answer_text)
                for question_key, answer_text in answers.items()
            ]
            for answer_row in answer_rows:
                cur.execute(
                    "INSERT INTO ad_answers (ad_id, question_key, question_label, answer) VALUES (%s,%s,%s,%s)",
                    answer_row,
                )

            return int(new_ad_id)


def upsert_group(chat_id: int, title: str, chat_type: str, status: str = "pending") -> None:
    """Register a target chat, or refresh its title and type if already known.

    *status* is written only when the row is first inserted. The
    ON DUPLICATE KEY UPDATE clause touches only ``title`` and ``type``, so an
    existing group's status is preserved.
    """
    group_values = (chat_id, title, chat_type, status)
    execute(
        """
        INSERT INTO target_groups (chat_id, title, type, status)
        VALUES (%s,%s,%s,%s)
        ON DUPLICATE KEY UPDATE title=VALUES(title), type=VALUES(type)
        """,
        group_values,
    )



def update_model_cpf_cnpj(model_id: int, cpf_cnpj: str) -> None:
    """Store *cpf_cnpj* on the ``models`` row identified by *model_id*.

    The value is written as given; no format validation happens here.
    """
    update_sql = "UPDATE models SET cpf_cnpj=%s WHERE id=%s"
    execute(update_sql, (cpf_cnpj, model_id))

