# ============================================================================
# DEEL HR API TO FABRIC LAKEHOUSE — PROLYTICS CONNECT
# Auto-generated notebook. Do not edit manually.
# Generated by: FabricDeelNotebookService
# ============================================================================

import json
import time
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Optional

import pandas as pd
import requests
from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_timestamp, lit, col
from pyspark.sql.types import NullType, StringType

# ============================================================================
# CONFIGURATION — injected by PHP FabricDeelNotebookService
# ============================================================================

# Deel API bearer token, substituted into this source by the generator.
# NOTE(review): the token is embedded in plain text in the generated notebook;
# consider reading it from a secret store at runtime instead.
DEEL_API_TOKEN = "{{DEEL_API_TOKEN}}"
# Every ENTITY_CONFIG "endpoint" is appended to this base URL.
DEEL_BASE_URL = "https://api.letsdeel.com/rest/v2"

# The two unquoted placeholders below are not valid Python values until the
# generator substitutes them. LOAD_TYPE is upper-cased by DeelDataLoader, so
# its case does not matter. An empty ENTITIES_TO_LOAD means "load every entity
# in ENTITY_CONFIG". INCREMENTAL_LOOKBACK_DAYS is only used when an existing
# table has no load_timestamp values yet.
LOAD_TYPE = "{{LOAD_TYPE}}"                             # FULL or INCREMENTAL
ENTITIES_TO_LOAD = {{ENTITIES_TO_LOAD}}                 # JSON array: ["contracts","timesheets"]
INCREMENTAL_LOOKBACK_DAYS = {{INCREMENTAL_LOOKBACK_DAYS}}  # Integer: 30
TABLE_PREFIX = "{{TABLE_PREFIX}}"                       # Notebook name: lowercase, no spaces

# ============================================================================
# ENTITY CONFIGURATION
# Table names are prefixed with TABLE_PREFIX so each notebook has its own
# isolated set of Delta tables in the Lakehouse.
# e.g. TABLE_PREFIX = "deelhrprod" → "deelhrprod_deel_contracts"
# ============================================================================

# One row per supported entity:
#   (entity key, API endpoint relative to DEEL_BASE_URL, human-readable description)
# The order here is the order in which entities are processed.
_ENTITY_SPECS = (
    ("contracts", "contracts", "Employment contracts and agreements"),
    ("timesheets", "timesheets", "Time tracking and timesheets"),
    ("people", "people", "Workers, contractors, and employees"),
    ("time_off", "time-off", "Time off requests and balances"),
    ("invoices", "accounting/invoices", "Billing invoices"),
    ("payments", "accounting/payments", "Payment records"),
    ("adjustments", "adjustments", "Payment adjustments"),
    ("organizations", "organizations", "Organization details"),
)

# Every entity upserts on "id", requests 99 rows per page, and is stored in
# the Delta table "<TABLE_PREFIX>_deel_<entity key>".
ENTITY_CONFIG = {
    key: {
        "endpoint": endpoint,
        "table_name": f"{TABLE_PREFIX}_deel_{key}",
        "merge_key": "id",
        "description": description,
        "page_size": 99,
    }
    for key, endpoint, description in _ENTITY_SPECS
}

# ============================================================================
# SPARK SESSION
# ============================================================================

# Reuse the notebook's active Spark session if there is one, otherwise create it.
spark = SparkSession.builder.appName("DeelHRIntegration").getOrCreate()

print("Spark session initialized successfully")

# ============================================================================
# DELTA LAKE DATA CLEANING
# CRITICAL: Delta Lake cannot store NullType (VOID) columns, including arrays whose
# elements are all null. Nested dict/list values are serialized to JSON strings.
# ============================================================================

def clean_dataframe_for_delta(df_pandas):
    """
    Prepare a pandas DataFrame of API records for Delta Lake / Microsoft Fabric.

    Two column shapes cause trouble when Spark infers a schema from pandas:
      - nested values (dict / list) -> inferred as nested STRUCT / ARRAY types
      - all-null columns            -> inferred as NullType (VOID), which Delta rejects

    Fix:
      1. Any column containing at least one dict/list value has each of those
         values serialized to a JSON string. Every value is inspected, not just
         the first non-null one, so a column whose early rows are scalars but
         whose later rows carry objects is still serialized.
      2. All-null columns are coerced to object dtype. Spark can still infer
         NullType for an all-None column; ``_cast_null_columns`` casts those
         to StringType after ``spark.createDataFrame``.

    The DataFrame is modified in place and also returned.
    """
    def is_nested(value):
        return isinstance(value, (dict, list))

    for column in df_pandas.columns:
        values = df_pandas[column]
        if values.map(is_nested).any():
            # Nested object -> JSON string. Scalars and nulls in the same
            # column are left untouched.
            df_pandas[column] = values.map(
                lambda v: json.dumps(v) if is_nested(v) else v
            )
        elif values.isna().all():
            # All-null column -> object dtype (e.g. an all-NaN float column
            # would otherwise keep float64).
            df_pandas[column] = values.astype(object)

    return df_pandas


def _cast_null_columns(df_spark):
    """Return *df_spark* with every NullType (VOID) column cast to StringType.

    Applied after ``spark.createDataFrame`` to catch all-null columns that
    pandas-side cleaning cannot prevent Spark from inferring as NullType.
    """
    void_columns = [
        field.name
        for field in df_spark.schema.fields
        if isinstance(field.dataType, NullType)
    ]
    for name in void_columns:
        df_spark = df_spark.withColumn(name, col(name).cast(StringType()))
    return df_spark


# ============================================================================
# DEEL API CLIENT
# ============================================================================

class DeelAPIClient:
    """Deel REST API v2 client with retry/backoff and cursor-based pagination.

    Error policy:
      - Network errors, HTTP 429 and HTTP 5xx are retried with exponential
        backoff (a numeric ``Retry-After`` header overrides the delay). If they
        persist through ``max_retries`` attempts an exception is raised, so the
        entity is reported as failed rather than as an empty load.
      - HTTP 403, 404 and any other non-200 status are logged and reported as
        ``None`` so the remaining entities can still be processed.
    """

    # Statuses treated as transient: rate limiting plus server-side failures.
    _RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    def __init__(self, token: str, base_url: str = DEEL_BASE_URL):
        self.token = token
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _retry_delay(response, attempt: int) -> float:
        """Seconds to wait before the next attempt.

        Uses a numeric ``Retry-After`` header when the server sends one,
        otherwise exponential backoff (1s, 2s, 4s, ...).
        """
        retry_after = (response.headers.get("Retry-After") or "").strip()
        if retry_after.isdigit():
            return float(retry_after)
        return float(2 ** attempt)

    @staticmethod
    def _error_message(response, default: str) -> str:
        """Return ``errors[0].message`` from a Deel error body, or *default*.

        Tolerates empty, non-JSON, or differently shaped bodies so a parse
        error never masks the original HTTP status.
        """
        try:
            errors = response.json().get("errors") or [{}]
            return errors[0].get("message") or default
        except (ValueError, AttributeError, IndexError, TypeError):
            return default

    def _make_request(self, endpoint: str, params: Dict = None, max_retries: int = 3) -> Optional[Dict]:
        """GET ``{base_url}/{endpoint}`` and return the decoded JSON body.

        Returns:
            The parsed JSON body on HTTP 200, or ``None`` for non-retryable
            errors (403, 404, other 4xx) so the caller can skip the entity.

        Raises:
            requests.exceptions.RequestException: when a network error, HTTP 429
                or HTTP 5xx persists for all ``max_retries`` attempts
                (``HTTPError`` for the status-code cases).
        """
        url = f"{self.base_url}/{endpoint}"

        for attempt in range(max_retries):
            last_attempt = attempt == max_retries - 1

            try:
                response = requests.get(
                    url, headers=self.headers, params=params, timeout=30
                )
            except requests.exceptions.RequestException as e:
                print(f"  Request failed (attempt {attempt + 1}/{max_retries}): {str(e)}")
                if last_attempt:
                    raise
                time.sleep(2 ** attempt)
                continue

            status = response.status_code

            if status == 200:
                return response.json()

            if status in self._RETRYABLE_STATUS_CODES:
                if last_attempt:
                    print(f"  Error {status} persisted after {max_retries} attempts: {response.text[:200]}")
                    # Raise rather than return None: None would be read as
                    # "no data" and the entity would be reported as a success.
                    response.raise_for_status()
                wait_time = self._retry_delay(response, attempt)
                reason = "Rate limited" if status == 429 else f"Server error {status}"
                print(f"  {reason}. Waiting {wait_time:g}s before retry...")
                time.sleep(wait_time)
                continue

            if status == 403:
                error_msg = self._error_message(response, "Forbidden")
                print(f"  Permission denied for {endpoint}: {error_msg}")
                return None  # Continue with other entities
            if status == 404:
                print(f"  Endpoint not found: {endpoint}")
                return None

            print(f"  Error {status}: {response.text[:200]}")
            return None

        return None

    def get_data(
        self,
        endpoint: str,
        params: Dict = None,
        incremental: bool = False,
        last_modified: Optional[str] = None,
        page_size: int = 99,
    ) -> List[Dict]:
        """
        Fetch every record from *endpoint*, following cursor-based pagination.

        Deel API v2 uses cursor pagination:
          - Request:  limit + (after_cursor on subsequent pages)
          - Response: {"data": [...], "page": {"cursor": "...", "total_rows": N}}
        Fetching continues while response["page"]["cursor"] is present,
        non-empty, and different from the cursor just sent.

        Args:
            endpoint: Path relative to ``base_url`` (e.g. "contracts").
            params: Extra query parameters. The dict is copied, never mutated.
            incremental: When True and *last_modified* is set, sends ``modified_after``.
            last_modified: Timestamp string for the ``modified_after`` filter.
            page_size: Value sent as ``limit``.

        Returns:
            All records concatenated across pages. Returns ``[]`` when the first
            request yields nothing (e.g. 403/404 or an empty result).

        Raises:
            RuntimeError: if a page after the first fails, so a truncated result
                is never mistaken for a complete one.
        """
        # Copy so limit/cursor keys never leak back into the caller's dict.
        query = dict(params) if params else {}

        if incremental and last_modified:
            query["modified_after"] = last_modified
            print(f"  Filtering records modified after: {last_modified}")

        query["limit"] = page_size
        # Do NOT set 'after_cursor' on the first request

        all_data = []
        page_num = 1

        while True:
            response = self._make_request(endpoint, query)

            if not response:
                if response is None and page_num > 1:
                    raise RuntimeError(
                        f"Pagination of '{endpoint}' failed at page {page_num} after "
                        f"{len(all_data)} records; refusing to return a partial result"
                    )
                break

            data = response.get("data") or []
            if not data:
                break

            all_data.extend(data)
            print(f"  Page {page_num}: Fetched {len(data)} records (Total: {len(all_data)})")

            # `or {}` also covers an explicit "page": null in the response.
            next_cursor = (response.get("page") or {}).get("cursor")
            if not next_cursor:
                break
            if next_cursor == query.get("after_cursor"):
                # Same cursor back again would re-fetch this page forever.
                print(f"  Cursor did not advance on page {page_num}; stopping pagination")
                break

            query["after_cursor"] = next_cursor
            page_num += 1

        return all_data


# ============================================================================
# DATA LOADER
# ============================================================================

class DeelDataLoader:
    """Load Deel records into Lakehouse Delta tables in FULL or INCREMENTAL mode.

    FULL overwrites the entity's table. INCREMENTAL upserts (Delta MERGE) on
    the entity's merge key, and falls back to an overwrite when the table does
    not exist yet. Every row written carries ``load_timestamp``,
    ``source_system`` and ``load_type`` metadata columns.
    """

    def __init__(self, client: DeelAPIClient, load_type: str = "INCREMENTAL"):
        self.client = client
        self.load_type = load_type.upper()

        if self.load_type not in ["FULL", "INCREMENTAL"]:
            raise ValueError("LOAD_TYPE must be 'FULL' or 'INCREMENTAL'")

    @staticmethod
    def _to_utc_string(moment: datetime) -> str:
        """Format *moment* as an ISO-8601 UTC string (``YYYY-MM-DDTHH:MM:SSZ``).

        ``astimezone`` treats a naive datetime as local time of the Python
        process. PySpark's ``collect()`` returns TIMESTAMP values as naive
        local-time datetimes, so this converts them to real UTC before the
        literal ``Z`` suffix is appended. Aware datetimes are converted directly.
        """
        return moment.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    def get_last_load_timestamp(self, table_name: str) -> Optional[str]:
        """Return the incremental watermark for *table_name* as a UTC string.

        Returns:
            - ``None`` if the table does not exist or cannot be queried (the
              caller then performs an initial full extraction);
            - ``MAX(load_timestamp)`` when the table has any timestamps;
            - now minus ``INCREMENTAL_LOOKBACK_DAYS`` otherwise.
        """
        # NOTE(review): load_timestamp is assigned when rows are written, which
        # happens after the API extraction. Records modified during a previous
        # run's extraction may therefore predate this watermark and be skipped.
        # Confirm whether an overlap window is needed.
        try:
            if not spark.catalog.tableExists(table_name):
                return None

            result = spark.sql(
                f"SELECT MAX(load_timestamp) as last_load FROM {table_name}"
            ).collect()
            last_load = result[0]["last_load"]

            if last_load:
                return self._to_utc_string(last_load)

            lookback_date = datetime.now(timezone.utc) - timedelta(days=INCREMENTAL_LOOKBACK_DAYS)
            return self._to_utc_string(lookback_date)
        except Exception as e:
            print(f"  Could not get last load timestamp: {str(e)}")
            return None

    def _build_spark_df(self, data: List[Dict]):
        """Convert raw API records into a Delta-safe Spark DataFrame with metadata columns."""
        df_pandas = clean_dataframe_for_delta(pd.DataFrame(data))
        df_spark = _cast_null_columns(spark.createDataFrame(df_pandas))
        return (
            df_spark
            .withColumn("load_timestamp", current_timestamp())
            .withColumn("source_system", lit("Deel"))
            .withColumn("load_type", lit(self.load_type))
        )

    @staticmethod
    def _align_to_target(df_spark, delta_table):
        """Add every target-table column missing from *df_spark* as a typed NULL.

        ``whenMatchedUpdateAll`` / ``whenNotMatchedInsertAll`` require the
        source to contain all target columns. The pandas frame only has columns
        for keys present in this batch, so an incremental batch lacking a field
        seen in earlier loads would otherwise fail the MERGE. Extra source
        columns are left in place.
        """
        # Spark column resolution is case-insensitive by default; compare
        # lower-cased names so no ambiguous duplicate column is added.
        present = {name.lower() for name in df_spark.columns}
        for field in delta_table.toDF().schema.fields:
            if field.name.lower() not in present:
                df_spark = df_spark.withColumn(field.name, lit(None).cast(field.dataType))
        return df_spark

    def _merge_into_table(self, data: List[Dict], table_name: str, merge_key: str):
        """Upsert *data* into the existing Delta table *table_name* on *merge_key*."""
        delta_table = DeltaTable.forName(spark, table_name)
        df_spark = self._align_to_target(self._build_spark_df(data), delta_table)
        (
            delta_table.alias("target")
            .merge(df_spark.alias("source"), f"target.{merge_key} = source.{merge_key}")
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )

    def load_to_lakehouse(self, data: List[Dict], table_name: str, mode: str = "overwrite"):
        """Write *data* to the Delta table *table_name* using save *mode*.

        Only logs when *data* is empty. ``mergeSchema`` lets columns that are
        new in this batch be added to an existing table.
        """
        if not data:
            print(f"No data to load for {table_name}")
            return

        df_spark = self._build_spark_df(data)

        df_spark.write \
            .format("delta") \
            .mode(mode) \
            .option("mergeSchema", "true") \
            .saveAsTable(table_name)

        print(f"Successfully loaded {len(data)} records to {table_name}")

    def load_entity(self, entity_key: str, config: Dict) -> Dict:
        """Extract one entity from Deel and load it into its Delta table.

        Args:
            entity_key: Entity name, used for logging.
            config: The entity's ENTITY_CONFIG entry (endpoint, table_name,
                merge_key, description, optional page_size).

        Returns:
            One of the following dicts:
            - ``{"status": "success", "records", "mode", "total_records"}``
              after a write;
            - ``{"status": "success", "records": 0, "mode": "no_data"}`` when
              the API returned nothing;
            - ``{"status": "failed", "error", "records": 0}`` if any step
              raised.
        """
        endpoint = config["endpoint"]
        table_name = config["table_name"]
        merge_key = config["merge_key"]
        description = config["description"]

        print(f"\n{'='*60}")
        print(f"ENTITY: {entity_key.upper()} - {description}")
        print(f"{'='*60}")
        print(f"  Endpoint : {endpoint}")
        print(f"  Table    : {table_name}")
        print(f"  Load Type: {self.load_type}")

        try:
            last_modified = None
            is_incremental = False

            if self.load_type == "INCREMENTAL":
                last_modified = self.get_last_load_timestamp(table_name)
                is_incremental = True

                if not last_modified:
                    print("  No previous load found. Performing initial full load.")
                    is_incremental = False

            print("\n  Extracting data from Deel API...")
            data = self.client.get_data(
                endpoint=endpoint,
                incremental=is_incremental,
                last_modified=last_modified,
                page_size=config.get("page_size", 99),
            )

            if not data:
                print("  No data returned from API")
                return {"status": "success", "records": 0, "mode": "no_data"}

            print(f"\n  Loading {len(data)} records to Lakehouse...")

            if self.load_type == "FULL" or not spark.catalog.tableExists(table_name):
                self.load_to_lakehouse(data, table_name, mode="overwrite")
                mode_used = "OVERWRITE"
                print(f"  Full load completed: {len(data)} records written")
            else:
                # Incremental merge (upsert)
                self._merge_into_table(data, table_name, merge_key)
                mode_used = "MERGE"
                print(f"  Incremental load completed: {len(data)} records merged")

            final_count = spark.table(table_name).count()
            print(f"  Total records in table: {final_count:,}")

            return {
                "status": "success",
                "records": len(data),
                "mode": mode_used,
                "total_records": final_count,
            }

        except Exception as e:
            print(f"  Error: {str(e)}")
            return {"status": "failed", "error": str(e), "records": 0}


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def run_deel_integration():
    """Run the Deel -> Lakehouse load for the configured entities.

    Processes ENTITIES_TO_LOAD (matched case-insensitively against
    ENTITY_CONFIG keys) or every configured entity when it is empty. Prints a
    human-readable summary plus a JSON payload between
    ``###DEEL_RESULTS_START###`` / ``###DEEL_RESULTS_END###`` markers for
    FabricDeelNotebookService to parse.

    Returns:
        Dict mapping entity key -> result dict from ``DeelDataLoader.load_entity``.
    """
    print("\n" + "=" * 60)
    print("DEEL HR DATA INTEGRATION — PROLYTICS CONNECT")
    print("=" * 60)
    print(f"Load Type  : {LOAD_TYPE}")
    print(f"Start Time : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Base URL   : {DEEL_BASE_URL}")
    print("=" * 60)

    client = DeelAPIClient(DEEL_API_TOKEN)
    loader = DeelDataLoader(client, LOAD_TYPE)

    # Determine which entities to process
    if ENTITIES_TO_LOAD:
        requested = {str(name).strip().lower() for name in ENTITIES_TO_LOAD}
        # Surface typos instead of silently skipping them.
        unknown = sorted(requested.difference(ENTITY_CONFIG))
        if unknown:
            print(
                f"\nWARNING: ignoring unknown entities: {', '.join(unknown)} "
                f"(valid: {', '.join(ENTITY_CONFIG)})"
            )
        entities_to_process = {
            k: v for k, v in ENTITY_CONFIG.items() if k in requested
        }
        print(f"\nProcessing {len(entities_to_process)} selected entities: {', '.join(entities_to_process.keys())}")
    else:
        entities_to_process = ENTITY_CONFIG
        print(f"\nProcessing all {len(entities_to_process)} entities")

    results = {
        entity_key: loader.load_entity(entity_key, config)
        for entity_key, config in entities_to_process.items()
    }

    # Print summary
    print("\n" + "=" * 60)
    print("EXECUTION SUMMARY")
    print("=" * 60)

    success_count = sum(1 for r in results.values() if r["status"] == "success")
    failed_count = len(results) - success_count
    total_records = sum(r.get("records", 0) for r in results.values())

    print(f"Successful : {success_count}")
    print(f"Failed     : {failed_count}")
    print(f"Total Rows : {total_records:,}")
    print("-" * 60)

    for entity, result in results.items():
        if result["status"] == "success":
            mode = result.get("mode", "N/A")
            records = result.get("records", 0)
            total = result.get("total_records", records)
            print(f"OK  {entity:20s}: {records:>6,} records [{mode}] -> Total: {total:,}")
        else:
            error = str(result.get("error") or "Unknown error")[:60]
            print(f"ERR {entity:20s}: FAILED - {error}")

    print("\n" + "=" * 60)
    print(f"End Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 60)

    # Structured output for FabricDeelNotebookService to parse
    print("\n" + "=" * 80)
    print("###DEEL_RESULTS_START###")
    print(json.dumps(results, indent=2, default=str))
    print("###DEEL_RESULTS_END###")
    print("=" * 80)

    return results


# ============================================================================
# EXECUTE
# ============================================================================

# Runs the whole integration as soon as the notebook executes. The per-entity
# result dicts stay bound to the module-level name `results` after the run.
results = run_deel_integration()
