import logging
import asyncio

from datetime import datetime
from collections import namedtuple
from peewee import OperationalError

from defence360agent.contracts.plugins import MessageSource
from defence360agent.subsys.persistent_state import register_lock_file, Scope
from defence360agent.model.analyst_cleanup import AnalystCleanupRequest
from defence360agent.utils import recurring_check
from defence360agent.utils.common import DAY
from defence360agent.utils.check_lock import check_lock
from defence360agent.api.server.analyst_cleanup import AnalystCleanupAPI
from defence360agent.utils.sshutil import remove_pub_key
from defence360agent.internals.iaid import IAIDTokenError


logger = logging.getLogger(__name__)

# Lock file handed to check_lock when the periodic update task is scheduled
# in AnalystCleanupUpdate.create_source.
LOCK_FILE = register_lock_file("analyst-cleanup-update", Scope.IM360)

# A single pending status change: the Zendesk ticket id, the new local
# status and the ticket's "updated_at" moment (a datetime parsed from the
# Zendesk payload).
UpdateStatusRow = namedtuple(
    "UpdateStatusRow", "zendesk_id new_status updated_at"
)


class AnalystCleanupUpdate(MessageSource):
    """Keep local analyst cleanup requests in sync with their Zendesk tickets.

    ``_update_task`` is scheduled through ``recurring_check(check_lock, ...)``
    with ``check_lock_period=DAY / 2`` and ``LOCK_FILE`` (presumably so it
    runs about twice a day, guarded by the lock file; confirm against
    ``check_lock``). Each run does the following:

    - reads the relevant requests from the database;
    - fetches their tickets through ``AnalystCleanupAPI``;
    - stores any status change;
    - removes the user's SSH public key when a ticket becomes completed.
    """

    # Upper bound on requests processed at the same time; each one may run
    # remove_pub_key in a worker thread.
    _MAX_CONCURRENCY = 5

    # Zendesk ticket status -> local request status. Any Zendesk status not
    # listed here maps to _DEFAULT_STATUS.
    _STATUS_MAP = {
        "new": "pending",
        "solved": "completed",
        "closed": "completed",
    }
    _DEFAULT_STATUS = "in_progress"

    async def create_source(self, loop, sink):
        """Remember *loop*/*sink* and start the periodic update task."""
        self._loop = loop
        self._sink = sink
        self._task = loop.create_task(
            recurring_check(
                check_lock,
                check_period_first=True,
                check_lock_period=DAY / 2,
                lock_file=LOCK_FILE,
            )(self._update_task)()
        )

    async def shutdown(self):
        """Cancel the periodic task and wait until it has stopped."""
        self._task.cancel()
        # CancelledError is handled by @recurring_check():
        await self._task

    @staticmethod
    async def _process(
        old_request, new_tickets_map, semaphore
    ) -> UpdateStatusRow | None:
        """Work out whether one local request needs a status update.

        :param old_request: local request row. Its ``zendesk_id``,
            ``status`` and ``username`` attributes are used.
        :param new_tickets_map: Zendesk tickets keyed by
            ``str(ticket["id"])``.
        :param semaphore: bounds how many requests are processed at once.
        :return: an ``UpdateStatusRow`` when the local status changed.
            Otherwise ``None``, including when the ticket is missing from
            the Zendesk response.
        :raises KeyError, ValueError: on a malformed ticket payload.
            Exceptions from ``remove_pub_key`` propagate as well.
        """
        async with semaphore:
            zendesk_id = old_request.zendesk_id
            # The map is keyed by str(ticket["id"]); normalise the local id
            # the same way so that a numeric zendesk_id still matches.
            ticket = new_tickets_map.get(str(zendesk_id))
            if ticket is None:
                logger.warning(
                    f"Ticket {zendesk_id} not found in Zendesk API response"
                )
                return None

            new_status = AnalystCleanupUpdate._STATUS_MAP.get(
                ticket["status"], AnalystCleanupUpdate._DEFAULT_STATUS
            )
            if new_status == old_request.status:
                return None

            # Parse the timestamp before touching the SSH key, so a
            # malformed payload fails without side effects.
            updated_at = datetime.fromisoformat(
                ticket["updated_at"].replace("Z", "+00:00")
            )

            logger.info(
                f"Updating ticket {zendesk_id} status from"
                f" '{old_request.status}' to '{new_status}'"
            )

            # If transitioning to completed, remove the SSH key
            if new_status == "completed":
                logger.info(
                    f"Removing SSH key for user '{old_request.username}'"
                )
                await asyncio.to_thread(remove_pub_key, old_request.username)

            return UpdateStatusRow(zendesk_id, new_status, updated_at)

    @staticmethod
    def _update_db_statuses(rows: list[UpdateStatusRow | None]):
        """Persist every non-``None`` status change in *rows*."""
        for ticket in rows:
            if not ticket:
                continue
            AnalystCleanupRequest.update_status(
                ticket.zendesk_id, ticket.new_status, ticket.updated_at
            )

    @staticmethod
    def _successful_rows(requests, results) -> list[UpdateStatusRow | None]:
        """Return gather() results with per-request failures logged and dropped.

        A failed request gets no row. Its status therefore stays unchanged,
        and it is retried on the next run.
        """
        rows = []
        for request, result in zip(requests, results):
            if isinstance(result, BaseException):
                logger.error(
                    "Failed to process analyst cleanup request"
                    f" {request.zendesk_id}: {result!r}",
                    exc_info=result,
                )
                continue
            rows.append(result)
        return rows

    async def _update_task(self):
        """Sync the relevant analyst cleanup requests with Zendesk.

        Steps:

        1. Read all active and recently closed requests. Closed ones are
           included in case a ticket was reopened.
        2. Ask Zendesk, via ``AnalystCleanupAPI``, for the current state of
           their tickets.
        3. Store every changed status.
        4. Remove the user's SSH public key for tickets that became
           completed.

        A failure while processing one request is logged and does not
        prevent the other requests from being updated.
        """
        try:
            current_requests = (
                AnalystCleanupRequest.get_all_relevant_requests()
            )

            # Skip if there are no requests to check
            if not current_requests:
                logger.info(
                    "No relevant analyst cleanup requests found to update"
                )
                return
        except OperationalError as e:
            if "no such table" in str(e):
                # Presumably the schema migration creating the table has
                # not run yet; not an error.
                logger.info("Database hasn't been updated yet")
            else:
                logger.error(
                    f"Can't get data from analyst cleanup  table: {e}"
                )
            return

        # Extract Zendesk IDs from the requests
        zendesk_ids = [request.zendesk_id for request in current_requests]

        try:
            # Get ticket status updates from Zendesk API
            new_tickets = await AnalystCleanupAPI.get_tickets(zendesk_ids)
            if not new_tickets:
                logger.warning(
                    "Didn't get tickets info from imunifyAPI but expected"
                )
                return
            # Map from zendesk_id to ticket for easier lookup
            new_tickets_map = {
                str(ticket["id"]): ticket for ticket in new_tickets
            }

            semaphore = asyncio.Semaphore(self._MAX_CONCURRENCY)
            # return_exceptions=True: one failing request (e.g. key removal
            # error or malformed ticket) must not discard the results of all
            # the others.
            results = await asyncio.gather(
                *(
                    self._process(old_request, new_tickets_map, semaphore)
                    for old_request in current_requests
                ),
                return_exceptions=True,
            )
            rows = self._successful_rows(current_requests, results)

            # Update the ticket status in the database
            await asyncio.to_thread(self._update_db_statuses, rows)

        except IAIDTokenError as e:
            logger.error(f"IAIDTokenError: {e}")
        except Exception as e:
            logger.exception(f"Error updating analyst cleanup requests: {e}")
