import os
import re
import sys
import glob
import datetime as dt
from typing import List, Tuple, Optional

from dotenv import load_dotenv
# Load environment variables first
load_dotenv()

import psycopg2
from psycopg2.extras import execute_values
from bs4 import BeautifulSoup
from openai import OpenAI

# -----------------------------
# Configuration
# -----------------------------

# Root folder containing your exported HTML manuals.
# Override with the HTML_MANUALS_PATH environment variable.
HTML_ROOT = os.getenv("HTML_MANUALS_PATH", "/Users/garyevans/PycharmProjects/BasicRAG/Manual")

# Database connection (adjust as needed).
# The password comes from the DB_PASSWORD environment variable (loaded
# from .env by load_dotenv() above); it is None when unset.
PG_CONN_INFO = {
    "host": "localhost",
    "port": 5432,
    "dbname": "pgdevai",
    "user": "garyevans",
    "password": os.getenv("DB_PASSWORD")
}

# Product facet stored verbatim in every inserted chunk row.
PRODUCT_NAME = "Fujitsu Enterprise Postgres"

# Chunking target in characters.
CHUNK_CHAR_TARGET = 3000   # approximate size of each chunk
CHUNK_CHAR_OVERLAP = 300   # characters shared between consecutive chunks

# OpenAI client and embedding settings.
# NOTE(review): the client is constructed at import time; if
# OPENAI_API_KEY is unset this may fail before main() runs — confirm
# that is acceptable for this script.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
EMBED_MODEL = "text-embedding-3-small"
EMBED_BATCH_SIZE = 64  # batch embeddings to reduce latency/quota pressure


# Helpers:

def slugify(s: str) -> str:
    """Lowercase *s* and join its alphanumeric runs with underscores."""
    return "_".join(re.findall(r"[a-z0-9]+", s.lower()))

def detect_language(soup: BeautifulSoup) -> str:
    """Return the primary language code from <html lang="...">, or 'en'.

    A region suffix (e.g. "en-US") is stripped down to the base tag.
    """
    root = soup.find("html")
    lang_attr = root.get("lang") if root else None
    if not lang_attr:
        return "en"
    return lang_attr.split("-")[0].lower()

def extract_title(soup: BeautifulSoup) -> str:
    """Return the page <title> text, or "Untitled" when missing/empty."""
    title_tag = soup.find("title")
    # Truthiness (not `is None`): an empty <title> tag is falsy in bs4
    # and should also yield the fallback, matching prior behavior.
    if not title_tag:
        return "Untitled"
    return title_tag.get_text(strip=True)

def infer_version_from_title(title: str) -> Optional[str]:
    """Extract a major-version tag like "v17" from a manual title.

    Accepts "Enterprise Postgres 17" / "Postgres 9" style phrases first,
    then a bare "v17" token; returns None when neither appears.
    Recognized majors are 7-19.
    """
    match = re.search(
        r"\b(?:Enterprise\s+Postgres|Postgres)\s+(1[0-9]|[7-9])\b",
        title,
        re.IGNORECASE,
    )
    if match is None:
        match = re.search(r"\bv(1[0-9]|[7-9])\b", title.lower())
    return f"v{match.group(1)}" if match else None

def infer_section_from_title(title: str) -> str:
    """Return a leading dotted section number ("3.2.1") or "0" if absent."""
    match = re.match(r"\d+(?:\.\d+)*", title)
    return match.group(0) if match else "0"

def infer_doc_type_from_title(title: str) -> Optional[str]:
    """Classify a manual title into a doc-type facet, or None if unknown.

    Rules are checked in order, so the first matching category wins.
    """
    lowered = title.lower()
    rules = (
        ("install", ("install", "installation", "setup")),
        ("admin", ("administration", "operations", "admin")),
        ("release-notes", ("release notes", "release-notes")),
        ("tutorial", ("application development", "development guide", "tutorial", "guide")),
    )
    for doc_type, needles in rules:
        if any(needle in lowered for needle in needles):
            return doc_type
    return None

# Keyword -> component lookup. Order matters: infer_component returns the
# FIRST category with a matching keyword, so specific categories precede
# broader ones. All keywords must be lowercase because the input text is
# lowercased before matching (the old "WebAdmin" entry could never match).
COMPONENT_KEYWORDS = {
    "replication":             ["replication", "logical", "physical", "wal", "publisher", "subscriber", "slot"],
    "backup":                  ["backup", "restore", "recovery", "barman", "base backup"],
    "security":                ["tls", "ssl", "audit", "row security", "rls", "encryption"],
    "knowledge management":    ["vector", "embedding", "semantic", "pgvector", "rag", "ai"],
    "monitoring":              ["monitoring", "metrics", "prometheus", "exporter"],
    "administration":          ["webadmin", "admin", "user", "role", "group", "permission", "authentication", "authorization"],
    "application development": ["application", "development", "guide", "tutorial", "guideline", "sql", "function", "api", "programming"],
    "dev":                     ["application development", "literal", "syntax", "sql", "function", "api", "programming"],
}

def infer_component(text: str) -> Optional[str]:
    """Return the first component category whose keyword occurs in *text*.

    Matching is case-insensitive and anchored on word boundaries, so short
    keywords such as "ai" or "rag" no longer fire inside unrelated words
    ("maintain", "storage"). Multi-word keywords match as whole phrases.

    Returns:
        The category name, or None when no keyword matches.
    """
    lowered = text.lower()
    for component, keywords in COMPONENT_KEYWORDS.items():
        for keyword in keywords:
            if re.search(r"\b" + re.escape(keyword) + r"\b", lowered):
                return component
    return None

def infer_updated_date(soup: BeautifulSoup) -> dt.date:
    """Best-effort "last updated" date for a page.

    Uses the latest year of a "Copyright YYYY" / "Copyright YYYY-YYYY"
    notice, pinned to January 1st; falls back to today when absent.
    """
    page_text = soup.get_text(" ", strip=True)
    found = re.search(r"Copyright\s+(\d{4})(?:-(\d{4}))?", page_text, re.IGNORECASE)
    if not found:
        return dt.date.today()
    # Prefer the range's end year when a "YYYY-YYYY" span is present.
    latest_year = found.group(2) or found.group(1)
    return dt.date(int(latest_year), 1, 1)

def build_doc_id(title: str, version: Optional[str], lang: str) -> str:
    """Create a stable doc_id per manual (not per page).

    Example: "Application Development Guide" + v17 + en
             -> "application_development_guide_v17_en"
    """
    # Drop the leading product/version words so the id reflects the guide name.
    stripped = title
    for prefix_pattern in (
        r"^enterprise\s+postgres\s+\d+(\s+sp\d+)?\s*",
        r"^fujitsu\s+enterprise\s+postgres\s+\d+(\s+sp\d+)?\s*",
    ):
        stripped = re.sub(prefix_pattern, "", stripped, flags=re.IGNORECASE)
    return f"{slugify(stripped)}_{version or 'vxx'}_{lang}"

def extract_breadcrumbs(soup: BeautifulSoup) -> str:
    """Return the breadcrumb trail text (".breadcrumbslist" element), or ""."""
    node = soup.select_one(".breadcrumbslist")
    if not node:
        return ""
    return node.get_text(" ", strip=True)

def extract_section_blocks(soup: BeautifulSoup) -> List[Tuple[str, str]]:
    """Split a page into (heading, body) blocks in reading order.

    For each h2/h3/h4, the body is the text of content elements appearing
    between that heading and the next one. When the page has no headings,
    the whole page text becomes a single ("Page", body) block.

    Fix vs. the previous version: find_all_next() yields containers AND
    their descendants, so a <p> nested in an already-collected <div> was
    appended twice, duplicating text in every chunk. Descendants of a
    collected element are now skipped. (The old script/style check was
    unreachable — those names never pass the content-tag filter — and has
    been removed.)
    """
    content_tags = ("div", "p", "pre", "ul", "ol", "table", "dl")
    blocks: List[Tuple[str, str]] = []

    for h in soup.find_all(["h2", "h3", "h4"]):
        heading = h.get_text(" ", strip=True)
        collected = []          # top-most content elements for this section
        collected_ids = set()   # identity set (bs4 Tag __eq__ compares markup)
        for el in h.find_all_next():
            # Stop at the next heading of any tracked level.
            if el.name in ("h2", "h3", "h4"):
                break
            if el.name not in content_tags:
                continue
            # Skip nav & footer containers.
            classes = el.get("class") or []
            if any(c in ("top_header", "header_footer") for c in classes):
                continue
            # Skip descendants of an element we already collected: their
            # text is already included by the ancestor's get_text().
            if any(id(parent) in collected_ids for parent in el.parents):
                continue
            collected.append(el)
            collected_ids.add(id(el))
        body = "\n".join(
            t for t in (el.get_text(" ", strip=True) for el in collected) if t
        )
        if heading or body:
            blocks.append((heading, body))

    # Fallback: if no headings found, use the entire page text.
    if not blocks:
        body = soup.get_text(" ", strip=True)
        if body:
            blocks = [("Page", body)]
    return blocks

# Chunking

def chunk_text(heading: str, text: str, target_chars=CHUNK_CHAR_TARGET, overlap=CHUNK_CHAR_OVERLAP) -> List[str]:
    """Slice heading+body into ~target_chars windows with overlap.

    Windows after the first are re-prefixed with the heading so each
    chunk keeps its section context for retrieval grounding.
    """
    full = (heading.strip() + "\n" + text.strip()).strip()
    if not full:
        return []

    pieces: List[str] = []
    total = len(full)
    pos = 0
    while pos < total:
        stop = min(total, pos + target_chars)
        window = full[pos:stop].strip()
        if pos > 0 and heading:
            window = f"{heading}\n{window}"
        if window:
            pieces.append(window)
        if stop == total:
            break
        # Step back by `overlap` for continuity, but always advance at
        # least one character so the loop terminates even when
        # overlap >= target_chars.
        pos = max(stop - overlap, pos + 1)
    return pieces

# Embeddings (batched)

def embed_texts(texts: List[str]) -> List[List[float]]:
    """Embed *texts* with the configured OpenAI model, in batches.

    Batching (EMBED_BATCH_SIZE per request) reduces latency and quota
    pressure; results are returned in input order.
    """
    collected: List[List[float]] = []
    for start in range(0, len(texts), EMBED_BATCH_SIZE):
        response = client.embeddings.create(
            model=EMBED_MODEL,
            input=texts[start:start + EMBED_BATCH_SIZE],
        )
        collected.extend(item.embedding for item in response.data)
    return collected

def vector_literal(vec: List[float]) -> str:
    """Render a float list as a pgvector text literal, e.g. "[0.1,0.2]".

    Each component is fixed to 8 decimal places.
    """
    formatted = (format(component, ".8f") for component in vec)
    return f"[{','.join(formatted)}]"

# DB load

# Bulk-insert statement for chunk rows. The single "VALUES %s" placeholder
# is expanded into many row tuples by psycopg2.extras.execute_values (see
# insert_rows, which also supplies the per-row template with the ::vector
# cast for the embedding column).
INSERT_SQL = """
INSERT INTO fep_manual_chunks
  (doc_id, section, doc_type, product, version, component, os, language, updated_at, content, embedding)
VALUES %s
"""

def insert_rows(rows: List[Tuple[str, str, Optional[str], str, Optional[str], Optional[str], Optional[str], str, dt.date, str, str]]):
    """Bulk-insert chunk rows into fep_manual_chunks in one transaction.

    Args:
        rows: Tuples matching INSERT_SQL's column order; the last element
            is a pgvector text literal (see vector_literal). The `section`
            element is a string such as "3.2.1" — the previous annotation
            claimed `int`, which was wrong.

    The embedding param must be cast with ::vector; execute_values doesn't
    let us place casts per-row easily, so the cast lives in the template.
    """
    if not rows:
        return  # nothing to do; don't open a connection for an empty batch
    conn = psycopg2.connect(**PG_CONN_INFO)
    try:
        # psycopg2's `with connection` block wraps a transaction
        # (commit on success, rollback on error) but does NOT close the
        # connection — the old code leaked one connection per call.
        with conn:
            with conn.cursor() as cur:
                execute_values(
                    cur,
                    INSERT_SQL,
                    rows,
                    template="(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s::vector)",
                )
    finally:
        conn.close()

# Per-file processing

def process_html_file(path: str, version: str = "v17", os_facet: str = "Linux") -> List[Tuple]:
    """Parse one exported HTML manual page into DB-ready chunk rows.

    Pipeline: parse HTML -> infer metadata (title, language, section,
    doc type, component, updated date) -> split into section blocks ->
    chunk -> embed -> assemble row tuples matching INSERT_SQL's columns.

    Args:
        path: Filesystem path of the HTML file.
        version: Version facet stored with each row (e.g. "v17").
        os_facet: OS facet stored with each row; it is taken as given,
            not inferred from the page content.

    Returns:
        Tuples ordered as (doc_id, section, doc_type, product, version,
        component, os, language, updated_at, content, embedding), where
        embedding is a pgvector text literal. Empty list when the page
        yields no chunks.

    Raises:
        RuntimeError: if the embedding API returns a different number of
            vectors than chunks submitted.
    """
    with open(path, "rb") as file:
        soup = BeautifulSoup(file, "html.parser")

    title = extract_title(soup)
    lang = detect_language(soup)
    section_number = infer_section_from_title(title)
    doc_type = infer_doc_type_from_title(title)
    breadcrumbs = extract_breadcrumbs(soup)

    # Breadcrumbs + title give the broadest text for component inference.
    component = infer_component(breadcrumbs + " " + title)

    # Use provided OS facet instead of inferring it.
    os_name = os_facet

    updated_at = infer_updated_date(soup)

    # Build a stable doc_id per manual/guide (not per page).
    doc_id = build_doc_id(title, version, lang)

    # Extract logical section blocks in reading order.
    blocks = extract_section_blocks(soup)

    # Chunk each block; collect every chunk so embeddings go out in batches.
    all_chunk_texts: List[str] = []
    for heading, body in blocks:
        # Prefix the section number when available ("0" means none) and
        # the heading doesn't already carry it.
        if section_number != "0" and heading and not heading.startswith(section_number):
            heading_with_section = f"{section_number} {heading}"
        else:
            heading_with_section = heading

        chs = chunk_text(heading_with_section, body)
        if chs:
            all_chunk_texts.extend(chs)

    if not all_chunk_texts:
        return []

    # Embed all chunks in batches.
    embeds = embed_texts(all_chunk_texts)
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # and a silent mismatch here would pair chunks with the wrong vectors
    # in the database.
    if len(embeds) != len(all_chunk_texts):
        raise RuntimeError(
            f"Embedding count mismatch for {path}: "
            f"{len(embeds)} vectors for {len(all_chunk_texts)} chunks"
        )

    rows: List[Tuple] = []
    for chunk_content, vec in zip(all_chunk_texts, embeds):
        rows.append((
            doc_id,                      # doc_id
            section_number,              # section
            doc_type,                    # doc_type
            PRODUCT_NAME,                # product (module-level constant)
            version,                     # version (parameter)
            component,                   # component
            os_name,                     # os
            lang,                        # language
            updated_at,                  # updated_at
            chunk_content,               # content
            vector_literal(vec),         # embedding (pgvector literal, cast in SQL)
        ))

    return rows

# Main: walk directory, process, load


def main():
    """CLI entry point: ingest one manual version per run.

    Usage: script.py [version] [os_facet]  (defaults: "v17", "Linux").
    Walks HTML_ROOT recursively, converts each page to chunk rows, and
    inserts them in batches of ~1000 to keep transactions bounded.
    """
    # One manual version (and OS facet) per invocation, by design.
    version = "v17"
    os_facet = "Linux"

    argv = sys.argv
    if len(argv) > 1:
        version = argv[1]
        print(f"Using version: {version}")
    if len(argv) > 2:
        os_facet = argv[2]
        print(f"Using OS: {os_facet}")

    pattern = os.path.join(HTML_ROOT, "**", "*.html")
    html_files = sorted(glob.glob(pattern, recursive=True))
    if not html_files:
        print(f"No HTML files found under {HTML_ROOT}", file=sys.stderr)
        sys.exit(1)

    file_count = len(html_files)
    total_rows = 0
    pending: List[Tuple] = []

    for idx, path in enumerate(html_files, 1):
        try:
            pending.extend(process_html_file(path, version, os_facet))
            # Flush in DB-sized batches to avoid large transactions.
            if len(pending) >= 1000:
                insert_rows(pending)
                total_rows += len(pending)
                print(f"[{idx}/{file_count}] Inserted {len(pending)} rows (cumulative {total_rows})")
                pending.clear()
        except Exception as e:
            # Top-level boundary: report the file and keep going.
            print(f"Error processing {path}: {e}", file=sys.stderr)

    # Flush whatever is left after the loop.
    if pending:
        insert_rows(pending)
        total_rows += len(pending)
        print(f"Inserted final {len(pending)} rows. Total rows: {total_rows}")

if __name__ == "__main__":
    main()
