import os
import threading
from functools import lru_cache

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# Model identifier handed to ``from_pretrained``; override with NLLB_MODEL.
MODEL_NAME = os.getenv("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
# Token budget used both to truncate the input and to cap generation length;
# override with NLLB_MAX_LENGTH.
MAX_LENGTH = int(os.getenv("NLLB_MAX_LENGTH", "384"))

# Client-facing codes (optionally with a region) mapped to the NLLB-200 codes
# the tokenizer expects. Keys are lower-case and use "-" as the separator,
# which is the form normalize_language_code builds before looking them up.
# The table is written grouped by target code and flattened here.
LANGUAGE_ALIASES = {
    alias: nllb_code
    for nllb_code, aliases in (
        ("spa_Latn", ("es", "es-es")),
        ("eng_Latn", ("en", "en-us")),
        ("fra_Latn", ("fr", "fr-fr")),
        ("cat_Latn", ("ca", "ca-es")),
        ("deu_Latn", ("de",)),
        ("ita_Latn", ("it",)),
        ("por_Latn", ("pt",)),
    )
    for alias in aliases
}


class TranslateRequest(BaseModel):
    """Request body for ``POST /translate``.

    Each value may be sent under several field names. The ``/translate``
    handler uses the first non-empty one, in this order:

    * text: ``text``, then ``q``;
    * source language: ``source_lang``, ``sourceLanguage``, ``source``;
    * target language: ``target_lang``, ``targetLanguage``, ``target``.
    """

    # Text to translate. ``min_length=1`` rejects an explicit empty string at
    # validation time, even when ``q`` is also provided.
    # NOTE(review): ``q`` carries no such constraint; confirm the asymmetry is intended.
    text: str | None = Field(default=None, min_length=1)
    # Source-language aliases (first non-empty wins, in declaration order).
    source_lang: str | None = None
    # Target-language alias.
    target_lang: str | None = None
    # camelCase source-language alias.
    sourceLanguage: str | None = None
    # camelCase target-language alias.
    targetLanguage: str | None = None
    # Short source-language alias.
    source: str | None = None
    # Short target-language alias.
    target: str | None = None
    # Alternate name for the text to translate; used only when ``text`` is falsy.
    q: str | None = None


class TranslateResponse(BaseModel):
    """Response body returned by ``POST /translate``."""

    # The translated text, with special tokens removed and surrounding whitespace stripped.
    translatedText: str
    # NLLB source-language code actually used (after alias normalization).
    source_lang: str
    # NLLB target-language code actually used (after alias normalization).
    target_lang: str
    # Identifier of the model that produced the translation (MODEL_NAME).
    model: str


def normalize_language_code(value: str | None) -> str:
    """Map a client-supplied language code to an NLLB-200 language code.

    Lookup in ``LANGUAGE_ALIASES`` is case-insensitive and treats ``_`` and
    ``-`` as the same separator, so ``"en-US"``, ``"en_us"`` and ``"EN-us"``
    all resolve alike. A regional code that is not listed falls back to its
    primary subtag (``"de-AT"`` -> ``"de"`` -> ``"deu_Latn"``). Any other
    value is returned stripped but otherwise unchanged, on the assumption
    that it is already an NLLB code such as ``"spa_Latn"``.

    Args:
        value: Raw language code from the request, possibly ``None``.

    Returns:
        The NLLB language code to hand to the tokenizer.

    Raises:
        HTTPException: 400 if ``value`` is missing, empty or whitespace-only.
    """
    # Strip before checking, so that "   " is reported as missing instead of
    # leaking through as an empty language code.
    stripped = (value or "").strip()
    if not stripped:
        raise HTTPException(status_code=400, detail="Missing language code")

    key = stripped.replace("_", "-").lower()
    if key in LANGUAGE_ALIASES:
        return LANGUAGE_ALIASES[key]

    # Alias keys are 2-letter (optionally regional) codes, while NLLB codes
    # start with a 3-letter language code. This fallback therefore never
    # rewrites a genuine NLLB code.
    primary = key.split("-", 1)[0]
    return LANGUAGE_ALIASES.get(primary, stripped)


@lru_cache(maxsize=1)
def get_translator():
    """Load the NLLB tokenizer and model and return them as a pair.

    ``lru_cache`` memoizes the result, so later calls reuse the same
    ``(tokenizer, model)`` tuple instead of reloading from ``MODEL_NAME``.
    The model is put into evaluation mode before being returned.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # nn.Module.eval() returns the module itself, so it can be chained.
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).eval()
    return loaded_tokenizer, seq2seq_model


# The tokenizer returned by get_translator() is a process-wide singleton.
# Setting ``src_lang`` (and passing truncation options) mutates it. FastAPI
# runs sync endpoints in a thread pool, so without this lock two concurrent
# requests can tokenize with each other's source language.
_TOKENIZER_LOCK = threading.Lock()


def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Uses the cached NLLB tokenizer/model from ``get_translator``. Input longer
    than ``MAX_LENGTH`` tokens is truncated without notice, and generation is
    capped at ``max_length=MAX_LENGTH``.

    Args:
        text: Non-empty text to translate.
        source_lang: NLLB language code of the input (e.g. ``"spa_Latn"``).
        target_lang: NLLB language code to translate into.

    Returns:
        The decoded translation with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        HTTPException: 400 if either language code is not in the tokenizer's
            vocabulary. Unknown codes would otherwise map to the ``<unk>`` id
            and silently produce a meaningless translation.
    """
    tokenizer, model = get_translator()

    for code in (source_lang, target_lang):
        if tokenizer.convert_tokens_to_ids(code) == tokenizer.unk_token_id:
            raise HTTPException(status_code=400, detail=f"Unsupported language code: {code}")

    # NLLB selects the output language by forcing the first generated token
    # to be the target language code.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    # Set-then-tokenize must be atomic with respect to other requests.
    with _TOKENIZER_LOCK:
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=MAX_LENGTH,
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0].strip()


app = FastAPI(title="ZGZ Local NLLB Service", version="1.0.0")

# Browser origins allowed by CORS, given as a comma-separated list in
# NLLB_CORS_ORIGINS (e.g. "http://localhost:3000,https://example.org").
# The default "*" keeps the previous allow-any-origin behaviour. Blank
# entries are ignored, so setting the variable to "" allows no
# cross-origin callers.
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("NLLB_CORS_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def startup():
    """Warm the model cache when the application starts.

    Calling the memoized loader here pays the tokenizer/model load cost at
    boot instead of on the first ``/translate`` request.
    """
    _ = get_translator()


@app.api_route("/health", methods=["GET", "HEAD"])
def health():
    """Return a static OK status together with the configured model name.

    Registered for both GET and HEAD so lightweight probes can use either.
    """
    status_payload = {"status": "ok"}
    status_payload["model"] = MODEL_NAME
    return status_payload


@app.post("/translate", response_model=TranslateResponse)
def translate(request: TranslateRequest):
    """Translate the request's text and report the NLLB codes that were used.

    For each of text, source language and target language, the first
    non-empty of the request's alias fields is used (see ``TranslateRequest``).
    Language values are normalized with ``normalize_language_code``. Any
    exception raised by it or by ``translate_text`` propagates to FastAPI.

    Raises:
        HTTPException: 400 when no non-blank text was supplied.
    """
    raw_text = request.text or request.q or ""
    text = raw_text.strip()
    if text == "":
        raise HTTPException(status_code=400, detail="Missing text")

    requested_source = request.source_lang or request.sourceLanguage or request.source
    requested_target = request.target_lang or request.targetLanguage or request.target
    source_lang = normalize_language_code(requested_source)
    target_lang = normalize_language_code(requested_target)

    return TranslateResponse(
        translatedText=translate_text(text, source_lang, target_lang),
        source_lang=source_lang,
        target_lang=target_lang,
        model=MODEL_NAME,
    )
