import asyncio
import functools
import itertools
from typing import List, Tuple

from peewee import DoesNotExist

from defence360agent.model.simplification import run_in_executor
from im360.internals.core.ipset.country import ips_for_country
from im360.subsys import csf
from im360.utils.tree_cache import SourceInterface
from im360.model.country import Country, CountryList
from im360.model.firewall import IPList
from im360.model.global_whitelist import GlobalWhitelist
from im360.utils.net import (
    local_dns_from_resolv_conf,
    local_ip_addresses,
)
from defence360agent.utils import validate


class DBIPListCacheSource(SourceInterface):
    """Cache source backed by the ``IPList`` table for one list name.

    ``IPList`` change signals whose sender is ``listname`` are re-emitted
    as this source's own ``added`` / ``deleted`` / ``cleared`` /
    ``updated`` signals.
    """

    def __init__(self, listname):
        self.listname = listname

        subscriptions = (
            (IPList.Signals.added, self._on_added),
            (IPList.Signals.deleted, self._on_deleted),
            (IPList.Signals.cleared, self._on_cleared),
            (IPList.Signals.updated, self._on_updated),
        )
        # Only listen to events emitted for this particular list name.
        for signal, handler in subscriptions:
            signal.connect(handler, listname)

    def _on_added(self, sender, ip):
        # ``ip`` is a record exposing ``.ip`` and ``.expiration``.
        self.added.send(self, ip=ip.ip, expiration=ip.expiration)

    def _on_deleted(self, sender, ip):
        # The deleted payload is forwarded unchanged.
        self.deleted.send(self, ip=ip)

    def _on_cleared(self, sender):
        self.cleared.send(self)

    def _on_updated(self, sender, ip):
        # ``ip`` is a record exposing ``.ip`` and ``.expiration``.
        self.updated.send(self, ip=ip.ip, expiration=ip.expiration)

    async def fetch_all(self):
        """Return ``(ip, expiration)`` tuples for the non-expired rows of
        this list; the blocking query runs through ``run_in_executor``."""

        def _query_rows():
            query = IPList.select(IPList.ip, IPList.expiration).where(
                IPList.listname == self.listname
            )
            query = query.where(~IPList.is_expired())
            return list(query.tuples().execute())

        loop = asyncio.get_event_loop()
        return await run_in_executor(loop=loop, cb=_query_rows)


class CountryIPListCacheSource(DBIPListCacheSource):
    """IP list cache source that also expands countries into subnets.

    In addition to the individual addresses handled by the parent class,
    every country attached to ``listname`` through ``CountryList`` is
    resolved with ``ips_for_country`` and reported with expiration ``0``.
    """

    def __init__(self, listname):
        super().__init__(listname)

        country_signals = CountryList.Signals
        country_signals.added.connect(self.__on_added, listname)
        country_signals.deleted.connect(self.__on_deleted, listname)

    @staticmethod
    def __select(country_id, listname=None):
        """
        Return the subnets of the country whose id is ``country_id``, or an
        empty list when no such country is found. ``listname`` is unused.
        """
        try:
            row = (
                Country.select(Country.code)
                .where(Country.id == country_id)
                .tuples()
                .first()
            )
        except DoesNotExist:
            return []
        if row:
            return list(ips_for_country(row[0]))
        return []

    @classmethod
    def __on_signal(cls, country_id, action):
        # Signal handlers run in the database thread, so querying the
        # database here directly (without run_in_executor) is safe.
        subnets = cls.__select(country_id)
        for subnet in subnets:
            action(subnet)

    def __on_added(self, _sender, **kwargs):
        def emit_added(subnet):
            self.added.send(self, ip=subnet, expiration=0)

        self.__on_signal(kwargs["country_id"], emit_added)

    def __on_deleted(self, sender, **kwargs):
        def emit_deleted(subnet):
            self.deleted.send(self, ip=subnet)

        self.__on_signal(kwargs["country_id"], emit_deleted)

    def __fetch_all(self):
        """Return ``(subnet, 0)`` tuples for every country in this list."""
        rows = (
            CountryList.select(Country.code)
            .join(Country, on=(Country.id == CountryList.country_id))
            .where(CountryList.listname == self.listname)
            .tuples()
            .execute()
        )
        return [
            (subnet, 0)
            for (code,) in rows
            for subnet in ips_for_country(code)
        ]

    async def fetch_all(self):
        # Called from the main thread: the blocking country query goes to
        # an executor, then the parent's individual IP entries are appended.
        loop = asyncio.get_event_loop()
        country_subnets = await run_in_executor(loop, self.__fetch_all)
        own_entries = await super().fetch_all()
        return country_subnets + own_entries


class ConstantCacheSource(SourceInterface):
    """Cache source serving a fixed set of IPs given at construction time.

    Every entry is reported with a ``None`` expiration.
    """

    def __init__(self, *ips):
        self._ips = [(ip, None) for ip in ips]

    async def fetch_all(self):
        """Return ``(ip, None)`` tuples for the configured IPs.

        A fresh list is returned on each call so a caller that modifies the
        result cannot alter the constant set held by this source.
        """
        return list(self._ips)


class GlobalwhitelistCacheSource(SourceInterface):
    """Cache source exposing the addresses returned by ``GlobalWhitelist``."""

    async def fetch_all(self):
        """:rtype: iterable[(ip, expiration)]

        Every address from ``GlobalWhitelist.load()`` is reported with a
        ``None`` expiration.
        """
        ips = await GlobalWhitelist.load()
        return [(ip, None) for ip in ips]


class WhitelistCacheSourceFromSystemSettings(SourceInterface):
    """Whitelist source built from this host's own settings.

    Combines the addresses returned by ``local_dns_from_resolv_conf`` with
    the local interface addresses from ``local_ip_addresses`` (each passed
    through ``validate.IP.ipv6_to_64network`` and converted to ``str``).
    Duplicates are removed and every entry has a ``None`` expiration.
    """

    @classmethod
    @functools.lru_cache(maxsize=1)
    def _fetch_all(cls) -> List[Tuple[str, None]]:
        # Memoised by lru_cache, so the host settings are read once and the
        # same list object is returned on subsequent calls.
        local_ips = map(
            str, map(validate.IP.ipv6_to_64network, local_ip_addresses())
        )
        return [
            (ip, None)
            for ip in set(
                itertools.chain(local_dns_from_resolv_conf(), local_ips)
            )
        ]

    async def fetch_all(self):
        # ``_fetch_all`` hands back the list stored inside the lru_cache;
        # return a copy so a caller that mutates the result cannot corrupt
        # the cached value seen by every later call.
        return list(self._fetch_all())


class WhitelistCacheSourceFromCSF(SourceInterface):
    """Whitelist source built from the IPs listed in CSF files."""

    def __init__(self, *csf_files):
        self.csf_files = csf_files

    def _fetch(self):
        """Return de-duplicated ``(ip, None)`` tuples from all ``csf_files``.

        ``csf.ips_from_file`` yields pairs whose first element is the IP;
        the second element is ignored.
        """
        entries = itertools.chain.from_iterable(
            csf.ips_from_file(path) for path in self.csf_files
        )
        unique_ips = {ip for ip, _ in entries}
        return [(ip, None) for ip in unique_ips]

    async def fetch_all(self):
        return self._fetch()
