# ============================================================================
# XERO ACCOUNTING API TO FABRIC LAKEHOUSE — PROLYTICS CONNECT
# Auto-generated notebook. Do not edit manually.
# Generated by: FabricXeroNotebookService
# ============================================================================

import json
import re
import time
from datetime import datetime, timedelta

import pandas as pd
import requests
from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_timestamp

# ============================================================================
# CONFIGURATION — injected by PHP FabricXeroNotebookService
# ============================================================================

# Xero OAuth2 app credentials and the token pair for a single tenant connection.
# XERO_ACCESS_TOKEN is only a starting value: XeroAPIClient refreshes it at
# start-up, and the refreshed pair is printed at the end of the run for PHP.
# NOTE(review): secrets are rendered into the notebook source as plain string
# literals, so anyone who can read the notebook can read them — consider a
# secret store instead.
XERO_CLIENT_ID     = "{{XERO_CLIENT_ID}}"
XERO_CLIENT_SECRET = "{{XERO_CLIENT_SECRET}}"
XERO_ACCESS_TOKEN  = "{{XERO_ACCESS_TOKEN}}"
XERO_REFRESH_TOKEN = "{{XERO_REFRESH_TOKEN}}"
XERO_TENANT_ID     = "{{XERO_TENANT_ID}}"  # sent as the Xero-Tenant-Id header on every request
XERO_BASE_URL      = "https://api.xero.com/api.xro/2.0"

# Run options. ENTITIES_TO_LOAD and INCREMENTAL_LOOKBACK_DAYS are substituted as
# raw Python literals (no surrounding quotes), so this file only parses once
# the placeholders have been filled in.
LOAD_TYPE                 = "{{LOAD_TYPE}}"                             # FULL or INCREMENTAL
ENTITIES_TO_LOAD          = {{ENTITIES_TO_LOAD}}                        # JSON array: ["accounts","invoices"]
INCREMENTAL_LOOKBACK_DAYS = {{INCREMENTAL_LOOKBACK_DAYS}}               # Integer: 30
TABLE_PREFIX              = "{{TABLE_PREFIX}}"                          # Notebook name: lowercase, no spaces

# ============================================================================
# ENTITY CONFIGURATION
# Table names are prefixed with TABLE_PREFIX so each notebook has its own
# isolated set of Delta tables in the Lakehouse.
# ============================================================================

# Per-entity extraction settings, keyed by the names accepted in ENTITIES_TO_LOAD.
#   endpoint     - path appended to XERO_BASE_URL
#   table_name   - Delta table written in the Lakehouse
#   merge_key    - column used to upsert; None means the table is overwritten
#   pagination   - "none" | "page" | "offset" | "report"; selects the fetch method
#   data_key     - key holding the record list in the JSON response
#   incremental  - when True, INCREMENTAL runs send If-Modified-Since
#                  (only the "page" and "none" fetch paths pass it on)
#   extra_params - optional extra query-string parameters (e.g. a Xero `where`)
ENTITY_CONFIG = {
    "accounts": {
        "endpoint": "Accounts",
        "table_name": f"{TABLE_PREFIX}_xero_accounts",
        "merge_key": "AccountID",
        "description": "Chart of Accounts",
        "pagination": "none",
        "data_key": "Accounts",
        "incremental": True,
    },
    "contacts": {
        "endpoint": "Contacts",
        "table_name": f"{TABLE_PREFIX}_xero_contacts",
        "merge_key": "ContactID",
        "description": "Customers and Suppliers",
        "pagination": "page",
        "data_key": "Contacts",
        "incremental": True,
    },
    "invoices": {
        "endpoint": "Invoices",
        "table_name": f"{TABLE_PREFIX}_xero_invoices",
        "merge_key": "InvoiceID",
        "description": "Sales Invoices",
        "pagination": "page",
        "data_key": "Invoices",
        "incremental": True,
        # The Invoices endpoint returns sales (ACCREC) and purchase (ACCPAY)
        # documents; bills are loaded separately below, so keep sales only.
        "extra_params": {"where": 'Type=="ACCREC"'},
    },
    "bills": {
        "endpoint": "Invoices",
        "table_name": f"{TABLE_PREFIX}_xero_bills",
        "merge_key": "InvoiceID",
        "description": "Purchase Bills (Accounts Payable)",
        "pagination": "page",
        "data_key": "Invoices",
        "incremental": True,
        "extra_params": {"where": 'Type=="ACCPAY"'},
    },
    "credit_notes": {
        "endpoint": "CreditNotes",
        "table_name": f"{TABLE_PREFIX}_xero_credit_notes",
        "merge_key": "CreditNoteID",
        "description": "Credit Notes",
        "pagination": "page",
        "data_key": "CreditNotes",
        "incremental": True,
    },
    "payments": {
        "endpoint": "Payments",
        "table_name": f"{TABLE_PREFIX}_xero_payments",
        "merge_key": "PaymentID",
        "description": "Payments",
        "pagination": "page",
        "data_key": "Payments",
        "incremental": True,
    },
    "bank_transactions": {
        "endpoint": "BankTransactions",
        "table_name": f"{TABLE_PREFIX}_xero_bank_transactions",
        "merge_key": "BankTransactionID",
        "description": "Bank Transactions",
        "pagination": "page",
        "data_key": "BankTransactions",
        "incremental": True,
    },
    "bank_transfers": {
        "endpoint": "BankTransfers",
        "table_name": f"{TABLE_PREFIX}_xero_bank_transfers",
        "merge_key": "BankTransferID",
        "description": "Bank Transfers",
        # NOTE(review): confirm BankTransfers honours the `page` parameter. If it
        # ignores it and returns >= 100 rows, the page loop keeps re-requesting
        # the same result set.
        "pagination": "page",
        "data_key": "BankTransfers",
        "incremental": False,
    },
    "journals": {
        "endpoint": "Journals",
        "table_name": f"{TABLE_PREFIX}_xero_journals",
        "merge_key": "JournalID",
        "description": "System Journals (offset-based pagination)",
        "pagination": "offset",
        "data_key": "Journals",
        # NOTE: currently has no effect — the offset fetch path is not given
        # If-Modified-Since, so journals are always read in full.
        "incremental": True,
    },
    "manual_journals": {
        "endpoint": "ManualJournals",
        "table_name": f"{TABLE_PREFIX}_xero_manual_journals",
        # Manual journals are identified by ManualJournalID (JournalID belongs
        # to the system Journals endpoint and is absent here).
        "merge_key": "ManualJournalID",
        "description": "Manual Journal Entries",
        "pagination": "page",
        "data_key": "ManualJournals",
        "incremental": True,
    },
    "items": {
        "endpoint": "Items",
        "table_name": f"{TABLE_PREFIX}_xero_items",
        "merge_key": "ItemID",
        "description": "Inventory Items",
        "pagination": "none",
        "data_key": "Items",
        "incremental": True,
    },
    "tracking_categories": {
        "endpoint": "TrackingCategories",
        "table_name": f"{TABLE_PREFIX}_xero_tracking_categories",
        "merge_key": "TrackingCategoryID",
        "description": "Tracking Categories",
        "pagination": "none",
        "data_key": "TrackingCategories",
        "incremental": False,
    },
    "tax_rates": {
        "endpoint": "TaxRates",
        "table_name": f"{TABLE_PREFIX}_xero_tax_rates",
        "merge_key": "TaxType",
        "description": "Tax Rates",
        "pagination": "none",
        "data_key": "TaxRates",
        "incremental": False,
    },
    "currencies": {
        "endpoint": "Currencies",
        "table_name": f"{TABLE_PREFIX}_xero_currencies",
        "merge_key": "Code",
        "description": "Currencies",
        "pagination": "none",
        "data_key": "Currencies",
        "incremental": False,
    },
    "organisations": {
        "endpoint": "Organisation",
        "table_name": f"{TABLE_PREFIX}_xero_organisation",
        "merge_key": "OrganisationID",
        "description": "Organisation Details",
        "pagination": "none",
        "data_key": "Organisations",
        "incremental": False,
    },
    # Reports have no stable row key, so they are overwritten on every run.
    "reports_trial_balance": {
        "endpoint": "Reports/TrialBalance",
        "table_name": f"{TABLE_PREFIX}_xero_trial_balance",
        "merge_key": None,
        "description": "Trial Balance Report",
        "pagination": "report",
        "data_key": "Reports",
        "incremental": False,
    },
    "reports_balance_sheet": {
        "endpoint": "Reports/BalanceSheet",
        "table_name": f"{TABLE_PREFIX}_xero_balance_sheet",
        "merge_key": None,
        "description": "Balance Sheet Report",
        "pagination": "report",
        "data_key": "Reports",
        "incremental": False,
    },
    "reports_profit_loss": {
        "endpoint": "Reports/ProfitAndLoss",
        "table_name": f"{TABLE_PREFIX}_xero_profit_and_loss",
        "merge_key": None,
        "description": "Profit & Loss Report",
        "pagination": "report",
        "data_key": "Reports",
        "incremental": False,
    },
}

# ============================================================================
# XERO API CLIENT
# ============================================================================

class XeroAPIClient:
    """Xero Accounting API client with OAuth2 self-refresh, rate limit handling, and pagination.

    The current ``access_token`` / ``refresh_token`` live on the instance and are
    replaced whenever a refresh succeeds, so callers can read them back after a
    run and persist the newest pair.
    """

    # Characters Delta Lake rejects in column names (without column mapping).
    # Report headers such as "YTD Debit" or "31 Dec 2023" contain them.
    _INVALID_COLUMN_CHARS = re.compile(r"[ ,;{}()\n\t=]+")

    def __init__(self, access_token, refresh_token, client_id, client_secret, tenant_id):
        self.access_token  = access_token
        self.refresh_token = refresh_token
        self.client_id     = client_id
        self.client_secret = client_secret
        self.tenant_id     = tenant_id
        # Always refresh at startup — token may have aged since notebook was created
        self._refresh_token()

    def _get_headers(self):
        """Return the base request headers: bearer token, tenant id and JSON accept."""
        return {
            "Authorization":  f"Bearer {self.access_token}",
            "Xero-Tenant-Id": self.tenant_id,
            "Accept":         "application/json",
        }

    def _refresh_token(self):
        """Refresh OAuth2 token. Called at startup and on 401 responses.

        Best effort: a non-200 response or a network error is logged and
        swallowed, leaving the current tokens in place. If the response omits
        a new refresh token, the existing one is kept.
        """
        try:
            response = requests.post(
                "https://identity.xero.com/connect/token",
                data={
                    "grant_type":    "refresh_token",
                    "refresh_token": self.refresh_token,
                    "client_id":     self.client_id,
                    "client_secret": self.client_secret,
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30,
            )
            if response.status_code == 200:
                tokens = response.json()
                self.access_token  = tokens["access_token"]
                self.refresh_token = tokens.get("refresh_token", self.refresh_token)
                print("[TOKEN] Refreshed successfully")
            else:
                print(f"[TOKEN] Refresh failed: {response.status_code} {response.text[:200]}")
        except Exception as e:
            print(f"[TOKEN] Refresh error: {str(e)}")

    def _request(self, endpoint, params=None, extra_headers=None, max_retries=3):
        """Make a single GET request with retry and rate limit handling.

        Args:
            endpoint: Path appended to ``XERO_BASE_URL``.
            params: Optional query-string parameters.
            extra_headers: Optional headers (e.g. ``If-Modified-Since``) merged
                over the base headers on every attempt.
            max_retries: Total number of attempts.

        Returns:
            The decoded JSON body on HTTP 200. Returns ``None`` on 403/404, or
            when every attempt ends in a 401/429/other error status.

        Raises:
            Exception: whatever the final attempt raised, if every attempt
            failed with an exception.
        """
        url = f"{XERO_BASE_URL}/{endpoint}"

        def build_headers():
            # Always re-apply the caller's headers on top of the current token;
            # dropping them after a refresh would silently turn an incremental
            # (If-Modified-Since) request into a full fetch.
            merged = self._get_headers()
            if extra_headers:
                merged.update(extra_headers)
            return merged

        headers = build_headers()

        for attempt in range(max_retries):
            try:
                response = requests.get(url, headers=headers, params=params, timeout=30)
                if response.status_code == 200:
                    return response.json()
                elif response.status_code == 401:
                    print("  [401] Token expired, refreshing...")
                    self._refresh_token()
                    headers = build_headers()
                elif response.status_code == 429:
                    wait = int(response.headers.get("Retry-After", 60))
                    print(f"  [429] Rate limited. Waiting {wait}s...")
                    time.sleep(wait)
                elif response.status_code in (403, 404):
                    print(f"  [{response.status_code}] {endpoint}: access denied or not found")
                    return None
                else:
                    print(f"  [ERROR {response.status_code}] {response.text[:200]}")
                    if attempt == max_retries - 1:
                        return None
                    time.sleep(2 ** attempt)
            except Exception as e:
                print(f"  [EXCEPTION] attempt {attempt + 1}: {str(e)}")
                if attempt == max_retries - 1:
                    raise
                time.sleep(2 ** attempt)
        return None

    def fetch_paged(self, endpoint, data_key, extra_params=None, if_modified_since=None):
        """Fetch page-based paginated data (100 records per page).

        Stops on an empty page, on a page shorter than 100 records, or when a
        request fails. In the failure case the records collected so far are
        returned.
        """
        all_data = []
        page     = 1
        params   = dict(extra_params or {})
        h_extra  = {"If-Modified-Since": if_modified_since} if if_modified_since else {}

        while True:
            params["page"] = page
            result = self._request(endpoint, params, h_extra)
            if not result:
                break
            data = result.get(data_key, [])
            if not data:
                break
            all_data.extend(data)
            print(f"  Page {page}: +{len(data)} (total: {len(all_data)})")
            if len(data) < 100:
                break
            page += 1
        return all_data

    def fetch_offset(self, endpoint, data_key):
        """Fetch offset-based paginated data (Journals endpoint — sequential JournalNumber).

        After each batch, the highest JournalNumber seen becomes the next
        ``offset``. A batch shorter than 100 records ends the loop.
        """
        all_data = []
        offset   = 0

        while True:
            result = self._request(endpoint, {"offset": offset})
            if not result:
                break
            data = result.get(data_key, [])
            if not data:
                break
            all_data.extend(data)
            offset = max(int(j.get("JournalNumber", 0)) for j in data)
            print(f"  Offset {offset}: +{len(data)} (total: {len(all_data)})")
            if len(data) < 100:
                break
        return all_data

    def fetch_single(self, endpoint, data_key, if_modified_since=None):
        """Fetch a non-paginated endpoint with a single request; ``[]`` on failure."""
        h_extra = {"If-Modified-Since": if_modified_since} if if_modified_since else {}
        result  = self._request(endpoint, extra_headers=h_extra)
        return result.get(data_key, []) if result else []

    @classmethod
    def _column_names(cls, cells):
        """Turn report header cells into unique, Delta-safe column names.

        Delta-invalid characters become ``_``. An empty or missing value falls
        back to ``Col_<index>``, and a repeated name gets ``_<index>`` appended.
        """
        names = []
        seen  = set()
        for i, cell in enumerate(cells):
            raw  = str(cell.get("Value") or "").strip()
            name = cls._INVALID_COLUMN_CHARS.sub("_", raw).strip("_") or f"Col_{i}"
            if name in seen:
                name = f"{name}_{i}"
            seen.add(name)
            names.append(name)
        return names

    def fetch_report(self, endpoint):
        """Fetch a Xero financial report and flatten it into rows.

        Column names come from the report's ``Header`` row, sanitized by
        ``_column_names``. Only rows with RowType ``"Row"`` inside each
        top-level section are emitted; other row types, such as summary rows,
        are skipped. Each row also carries ``_Section``, ``_ReportName`` and
        ``_ReportDate``.
        """
        result = self._request(endpoint)
        if not result:
            return []
        reports = result.get("Reports", [])
        if not reports:
            return []

        report  = reports[0]
        rows    = []
        headers = []

        for section in report.get("Rows", []):
            if section.get("RowType") == "Header":
                headers = self._column_names(section.get("Cells", []))
                break

        for section in report.get("Rows", []):
            section_title = section.get("Title", "")
            for row in section.get("Rows", []):
                if row.get("RowType") == "Row":
                    cells    = row.get("Cells", [])
                    row_data = {
                        "_Section":    section_title,
                        "_ReportName": report.get("ReportName", ""),
                        "_ReportDate": report.get("ReportDate", ""),
                    }
                    for i, cell in enumerate(cells):
                        col = headers[i] if i < len(headers) else f"Column_{i}"
                        row_data[col] = cell.get("Value", "")
                    rows.append(row_data)
        return rows

# ============================================================================
# DATA LOADING (Delta table merge — same pattern as Deel template)
# ============================================================================

def clean_dataframe_for_delta(df):
    """Make a pandas DataFrame safe to hand to Spark/Delta.

    Any dict or list cell in an object-typed column is replaced by its JSON
    text; every other cell is left untouched. The frame is modified in place
    and also returned for chaining.
    """
    def _encode_nested(cell):
        if not isinstance(cell, (dict, list)):
            return cell
        return json.dumps(cell)

    text_like_columns = [name for name in df.columns if df[name].dtype == object]
    for name in text_like_columns:
        df[name] = df[name].apply(_encode_nested)
    return df


def load_to_lakehouse(spark, data, table_name, merge_key):
    """Load extracted data to a Delta table via merge (upsert) or overwrite.

    Args:
        spark: Active SparkSession.
        data: List of record dicts as returned by the Xero client.
        table_name: Lakehouse table name (also used as ``Tables/<name>``).
        merge_key: Column to upsert on. ``None`` forces an overwrite.

    Behaviour:
        * If ``merge_key`` is set and the Delta table already exists, the rows
          are MERGEd: matched keys are updated and new keys inserted.
        * Otherwise the table is overwritten, and its schema is replaced to
          match this batch.

    Returns:
        Number of rows written after de-duplication on ``merge_key``
        (0 when ``data`` is empty).
    """
    if not data:
        print(f"  No data for {table_name}")
        return 0

    df_pandas = clean_dataframe_for_delta(pd.DataFrame(data))

    # Delta MERGE fails when several source rows match the same target row, and
    # paged fetches can return a record twice. Keep only the latest copy of each key.
    if merge_key and merge_key in df_pandas.columns:
        df_pandas = df_pandas.drop_duplicates(subset=[merge_key], keep="last")
    row_count = len(df_pandas)

    df_spark   = spark.createDataFrame(df_pandas).withColumn("_loaded_at", current_timestamp())
    table_path = f"Tables/{table_name}"

    if merge_key and DeltaTable.isDeltaTable(spark, table_path):
        delta = DeltaTable.forPath(spark, table_path)
        delta.alias("target").merge(
            df_spark.alias("source"),
            f"target.{merge_key} = source.{merge_key}"
        ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
        print(f"  Merged {row_count} rows into {table_name}")
    else:
        # overwriteSchema (not mergeSchema): a full overwrite should take this
        # batch's schema. Otherwise report columns that change between runs
        # (e.g. date headers) would linger as null columns, and type changes
        # would fail the write.
        df_spark.write.format("delta").mode("overwrite") \
            .option("overwriteSchema", "true").saveAsTable(table_name)
        print(f"  Wrote {row_count} rows to {table_name} (overwrite)")

    return row_count

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def run_xero_integration():
    """Run one extract-and-load pass over ENTITIES_TO_LOAD.

    For each requested entity, pick the fetch method from its ``pagination``
    setting, load the records with ``load_to_lakehouse``, and record a
    per-entity status. An exception in one entity is recorded as a failure
    and does not stop the others.

    Two marker-delimited JSON blobs are printed for FabricXeroNotebookService:
    the run summary and the current OAuth token pair. Both are emitted from a
    ``finally`` block. The client refreshes tokens at start-up, and the refresh
    response may carry a new refresh token that PHP must persist, so the pair
    is reported even if the run is interrupted part-way.
    """
    spark  = SparkSession.builder.getOrCreate()
    client = XeroAPIClient(
        XERO_ACCESS_TOKEN, XERO_REFRESH_TOKEN,
        XERO_CLIENT_ID, XERO_CLIENT_SECRET,
        XERO_TENANT_ID
    )

    results       = {}
    total_records = 0
    # Tolerate case/whitespace differences in the injected load type.
    is_incremental    = str(LOAD_TYPE).strip().upper() == "INCREMENTAL"
    if_modified_since = None

    if is_incremental:
        cutoff = datetime.utcnow() - timedelta(days=INCREMENTAL_LOOKBACK_DAYS)
        if_modified_since = cutoff.strftime("%Y-%m-%dT%H:%M:%S")

    try:
        for entity_key in ENTITIES_TO_LOAD:
            cfg = ENTITY_CONFIG.get(entity_key)
            if not cfg:
                print(f"\n[SKIP] Unknown entity: {entity_key}")
                continue

            print(f"\n[START] {entity_key} → {cfg['table_name']}")
            try:
                ims = if_modified_since if (is_incremental and cfg.get("incremental")) else None

                # Reports and offset-paged journals are always fetched in full.
                if cfg["pagination"] == "report":
                    data = client.fetch_report(cfg["endpoint"])
                elif cfg["pagination"] == "page":
                    data = client.fetch_paged(
                        cfg["endpoint"], cfg["data_key"],
                        extra_params=cfg.get("extra_params"),
                        if_modified_since=ims,
                    )
                elif cfg["pagination"] == "offset":
                    data = client.fetch_offset(cfg["endpoint"], cfg["data_key"])
                else:  # none — single request
                    data = client.fetch_single(cfg["endpoint"], cfg["data_key"], if_modified_since=ims)

                count = load_to_lakehouse(spark, data, cfg["table_name"], cfg["merge_key"])
                results[entity_key] = {"status": "success", "records": count}
                total_records += count
                print(f"[OK]   {entity_key}: {count} records")

            except Exception as e:
                results[entity_key] = {"status": "failed", "error": str(e)}
                print(f"[FAIL] {entity_key}: {str(e)}")
    finally:
        # ---------------------------------------------------------------------
        # Structured output — parsed by FabricXeroNotebookService for run history
        # ---------------------------------------------------------------------
        print("\n###XERO_RESULTS_START###")
        print(json.dumps({
            "load_type":     LOAD_TYPE,
            "entities":      results,
            "total_records": total_records,
            "completed_at":  datetime.utcnow().isoformat(),
        }))
        print("###XERO_RESULTS_END###")

        # Output refreshed tokens so PHP can update xero_connections
        print("\n###XERO_TOKEN_UPDATE_START###")
        print(json.dumps({
            "access_token":  client.access_token,
            "refresh_token": client.refresh_token,
        }))
        print("###XERO_TOKEN_UPDATE_END###")


# Notebook entry point: the integration runs as soon as this cell executes.
run_xero_integration()
