import json
import logging
import os
import pwd
from abc import abstractmethod
from contextlib import suppress
from textwrap import dedent
from typing import Mapping, Optional, Protocol

import sentry_sdk
import yaml

from defence360agent.utils import atomic_rewrite

logger = logging.getLogger(__name__)

# Size limit, in bytes, for config files: ConfigReader.read_config_file
# raises ConfigError instead of reading a file larger than this.
_MAX_CONFIG_SIZE = 1024 * 1024  # 1 MiB


class IConfigProvider(Protocol):
    """Structural (``typing.Protocol``) interface for config providers.

    ``ConfigReader`` and its subclasses implement these methods.
    """

    @abstractmethod
    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ):
        """Return the current config mapping.

        :param force_read: re-read the backing storage even if it looks
            unchanged.
        :param ignore_errors: whether ConfigError should be suppressed
            (exact handling is up to the implementation).
        """
        raise NotImplementedError

    @abstractmethod
    def write_config_file(self, config: Mapping) -> Optional[str]:
        """Persist *config*.

        Implementations may return the text they wrote (ConfigReader does,
        and its subclasses rely on it) or ``None``.
        """
        raise NotImplementedError

    @abstractmethod
    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Whether the config changed after *timestamp*.

        :param timestamp: ``None`` means the config has never been read.
        """
        raise NotImplementedError


class ConfigError(Exception):
    """Raised when a config file is too large, undecodable or not valid."""


class JsonMessage:
    """Render *obj* as compact JSON with sorted keys when converted to str.

    Meant to be passed as a logging argument, so the conversion only runs
    when the record is actually formatted. Example:

      logging.info("object: %s", JsonMessage(obj))

    Rendering never raises:
    - Values JSON cannot encode natively (e.g. dates or sets produced by
      YAML) are stringified.
    - Objects JSON cannot represent at all (mixed-type keys that cannot be
      sorted, circular references) fall back to ``repr``.
    """

    def __init__(self, obj):
        self._obj = obj

    def __str__(self):
        try:
            # default=str only affects objects json would otherwise reject,
            # so output for JSON-native data is unchanged.
            return json.dumps(self._obj, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # TypeError: keys of mixed types cannot be sorted;
            # ValueError: circular reference.  An exception here would turn
            # a log call into a logging error, so degrade to repr().
            return repr(self._obj)


def diff_section(prev_section: Optional[dict], section: Optional[dict]):
    """Compare two config sections option by option.

    ``None`` (or any falsy value) is treated as an empty section.

    :return: ``{"-": removed, "+": added, "?": modified}``, where
        *removed* and *added* map option names to their values and
        *modified* maps option names to ``(old_value, new_value)``.
    """
    old = prev_section or {}
    new = section or {}
    old_keys, new_keys = old.keys(), new.keys()

    removed = {key: old[key] for key in old_keys - new_keys}
    added = {key: new[key] for key in new_keys - old_keys}
    modified = {}
    for key in old_keys & new_keys:
        if old[key] != new[key]:
            modified[key] = (old[key], new[key])

    return {"-": removed, "+": added, "?": modified}


def diff_config(prev_conf: dict, conf: dict):
    """Yield the differences between *prev_conf* and the current *conf*.

    Yields exactly three dicts, in this order:

    1. sections only in *prev_conf* (name -> previous contents);
    2. sections only in *conf* (name -> current contents);
    3. sections in both whose contents differ (name -> ``diff_section``
       result for that section).
    """
    old_sections, new_sections = prev_conf.keys(), conf.keys()

    yield {name: prev_conf[name] for name in old_sections - new_sections}

    yield {name: conf[name] for name in new_sections - old_sections}

    changed = {}
    for name in old_sections & new_sections:
        if prev_conf[name] != conf[name]:
            changed[name] = diff_section(prev_conf[name], conf[name])
    yield changed


def exclude_equals(*, main_conf: dict, base_conf: dict) -> dict:
    """Return the parts of *main_conf* that differ from *base_conf*.

    - Sections missing from *base_conf* are copied whole.
    - Sections present in both but not equal keep only the options that are
      new in *main_conf* or have a different value there.  A section whose
      only difference is options removed in *main_conf* yields ``{}``.
    - Sections equal in both configs are omitted.

    Result sections follow the order of *main_conf*.

    >>> base_conf = {
    ...     "SECTION1": {"OPTION1": "default", "OPTION2": "default"},
    ...     "SECTION2": {"OPTION1": "default"},
    ... }
    >>> main_conf = {
    ...     "SECTION1": {"OPTION1": "value", "OPTION2": "default"},
    ...     "SECTION2": {"OPTION1": "default"},
    ... }
    >>> exclude_equals(main_conf=main_conf, base_conf=base_conf)
    {'SECTION1': {'OPTION1': 'value'}}
    """
    _removed, added, changed = diff_config(base_conf, main_conf)
    result = {}
    for section, value in main_conf.items():
        if section in added:
            result[section] = value
        elif section in changed:
            # "+" holds options new in main_conf; "?" holds
            # (base_value, main_value) pairs for options whose value differs.
            section_diff = changed[section]
            overrides = dict(section_diff["+"])
            overrides.update(
                (option, new_value)
                for option, (_old_value, new_value) in section_diff["?"].items()
            )
            result[section] = overrides
    return result


class ConfigReader:
    """Read and write a YAML config file located at *path* (no caching).

    :param path: location of the config file.
    :param disclaimer: text that is dedented and written before the YAML
        body on every write (nothing is written if empty).
    :param permissions: forwarded to ``atomic_rewrite`` when writing.
    """

    def __init__(self, path, disclaimer="", permissions=None):
        self.path = path
        self.disclaimer = disclaimer
        self.permissions = permissions

    def __repr__(self):
        return f"<{self.__class__.__qualname__}({self.path})>"

    def __str__(self):
        return "ConfigReader at {}".format(self.path)

    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ) -> dict:
        """Read the config file and return its parsed contents.

        A missing file yields ``{}``.  *force_read* is unused here because
        this class keeps no cache.  If the body fails ``load_config_body``,
        the error is logged and ``{}`` is returned, unless *ignore_errors*
        is False, in which case the ConfigError is re-raised.

        :raises ConfigError: if the file is larger than ``_MAX_CONFIG_SIZE``
            bytes or cannot be decoded (regardless of *ignore_errors*), or
            if the body is invalid and *ignore_errors* is False.
        """
        try:
            if os.path.getsize(self.path) > _MAX_CONFIG_SIZE:
                raise ConfigError("Config file is too large")
            with open(self.path, "r") as stream:
                logger.info("Reading config file %s", self.path)
                raw = stream.read()
        except FileNotFoundError:
            return {}
        except UnicodeDecodeError as exc:
            raise ConfigError("Unable to decode config file") from exc

        try:
            return self.load_config_body(raw)
        except ConfigError as exc:
            logger.error(exc)
            if not ignore_errors:
                raise
            return {}

    def load_config_body(self, text: str) -> dict:
        """Parse *text* as YAML and return the top-level mapping.

        An empty document yields ``{}``.

        :raises ConfigError: if *text* is not valid YAML or its top level is
            not a mapping.
        """
        try:
            parsed = yaml.safe_load(text)
        except yaml.YAMLError as exc:
            raise ConfigError(
                f"Imunify360 config is not valid YAML document ({exc})"
            ) from exc

        if parsed is None:
            return {}
        if isinstance(parsed, dict):
            return parsed
        raise ConfigError(
            "Imunify360 config is invalid or empty"
            f": path={self.path!r}, text={text!r}"
        )

    def _pre_write(self):
        """Hook run before the file is written; no-op by default."""

    def _post_write(self):
        """Hook run after the file is written; no-op by default."""

    def write_config_file(self, config) -> str:
        """Serialize *config* as YAML and write it via ``atomic_rewrite``.

        The dedented disclaimer, if any, is placed before the YAML body, and
        the ``_pre_write``/``_post_write`` hooks run around the write.

        :return: the exact text that was written.
        """
        self._pre_write()
        header = dedent(self.disclaimer) + "\n" if self.disclaimer else ""
        config_text = header + yaml.dump(config, default_flow_style=False)
        atomic_rewrite(
            self.path, config_text, backup=False, permissions=self.permissions
        )
        self._post_write()
        return config_text

    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Always True: without a cache the file is always considered changed."""
        return True


class CachedConfigReader(ConfigReader):
    """ConfigReader that keeps the last parsed config in memory.

    The file is re-read only when its modification time or size differs
    from the values recorded at the previous read, or when a re-read is
    forced.
    """

    def __init__(self, path, disclaimer="", permissions=None):
        super().__init__(path, disclaimer, permissions)
        # mtime/size of the file recorded at the last read: None until the
        # first read, 0.0 if the file was missing at that time.
        self.mtime: Optional[float] = None
        self.size: Optional[float] = None
        self._config = {}

    def __str__(self):
        return (
            "{classname} <'{path}', modified at {mtime}, {size} bytes>".format(
                classname=self.__class__.__qualname__,
                path=self.path,
                mtime=self.mtime,
                size=self.size,
            )
        )

    def _stat_file(self):
        """Return ``(mtime, size)`` of the file, or ``(0.0, 0.0)`` if missing."""
        try:
            stat = os.stat(self.path)
        except FileNotFoundError:
            return 0.0, 0.0
        return stat.st_mtime, stat.st_size

    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ):
        """Return the config, re-reading the file only if it changed.

        :param force_read: re-read even if mtime/size look unchanged.
        :param ignore_errors: forwarded to ``ConfigReader.read_config_file``.
            When False, a ConfigError is re-raised after being reported, and
            the recorded mtime/size are left untouched so the next call
            re-reads the file.
        :return: the cached config dict.  If reading raises ConfigError, the
            previously cached config is kept.
        """
        if not (force_read or self.modified_since(self.mtime)):
            return self._config

        # Snapshot the metadata *before* reading.  If the file is replaced
        # after this point, the stored values are stale and the next call
        # re-reads it.  Stat'ing after the read could record the new file's
        # mtime/size against the old content and miss the update.
        mtime, size = self._stat_file()
        prev_config = self._config
        try:
            self._config = super().read_config_file(
                ignore_errors=ignore_errors
            )
        except ConfigError as error:
            # NOTE(review): with ignore_errors=True, ConfigReader returns {}
            # for an unparsable body instead of raising.  So only size and
            # decoding errors reach this branch and keep the previous
            # settings; confirm this is intended.
            sentry_sdk.capture_exception(error)

            logger.warning(
                "%s is invalid, using previous settings: %s",
                self,
                JsonMessage(self._config),
            )
            if not ignore_errors:
                raise
        else:
            if self.mtime is not None:  # don't log on startup
                diffs = list(diff_config(prev_config, self._config))
                if any(diffs):
                    # content has changed, log it
                    logger.info(
                        "%s modified: removed=%s, added=%s, changed=%s",
                        self,
                        *map(JsonMessage, diffs),
                    )

        self.mtime, self.size = mtime, size
        return self._config

    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Whether the config file has changed since *timestamp*.

        True if the file's mtime is newer than *timestamp* or its size
        differs from the size recorded at the last read.  A missing file
        counts as mtime 0.0 and size 0.0.

        :param timestamp: None means that the file has never been read
            before (treated as 0.0).
        """
        if timestamp is None:
            timestamp = 0.0
        st_mtime, st_size = self._stat_file()
        return st_mtime > timestamp or st_size != self.size


class WriteOnlyConfigReader(CachedConfigReader):
    """Cached reader whose in-memory config is never refreshed from disk.

    The cached config changes only through ``write_config_file``, which
    re-parses the text it has just written.
    """

    def read_config_file(self, *_, **__):
        # Every argument (force_read, ignore_errors, ...) is accepted and
        # ignored; the file is never read.
        return self._config

    def write_config_file(self, config):
        written = super().write_config_file(config)
        self._config = self.load_config_body(written)
        return written


class UserConfigReader(CachedConfigReader):
    """Cached reader for the config file of user *username*.

    On write, the file's parent directory gets mode 0o750 and the file
    gets mode 0o640.  Both are owned by uid 0 and the user's primary group.
    """

    def __init__(self, path, username):
        super().__init__(path)
        self.username = username

    def __str__(self):
        return f"Config of user {self.username}"

    def _user_gid(self) -> int:
        """Primary group id of ``self.username`` (KeyError if unknown)."""
        return pwd.getpwnam(self.username).pw_gid

    def _pre_write(self):
        # Resolve the group first so an unknown user fails before anything
        # is created on disk.
        gid = self._user_gid()
        confdir = os.path.dirname(self.path)
        # Create the directory already restricted (the mode is further
        # masked by umask) rather than briefly exposing it with the default
        # 0o777 & ~umask permissions; the chmod below sets the exact mode.
        with suppress(FileExistsError):
            os.mkdir(confdir, 0o750)
        os.chown(confdir, 0, gid)
        os.chmod(confdir, 0o750)

    def _post_write(self):
        os.chown(self.path, 0, self._user_gid())
        os.chmod(self.path, 0o640)
