import logging
import peewee as pw

from defence360agent.utils.validate import IP, NumericIPVersion
from im360.internals import geo

# Integer with all 32 bits set (2**32 - 1). migrate() compares it against the
# ``netmask`` column to pick the IPv4 rows that should get a ``/32`` suffix.
NETMASK_32 = 0xFFFFFFFF

# Logger named after this module; used for all migration progress messages.
logger = logging.getLogger(__name__)


def migrate(migrator, database, fake=False, **kwargs):
    """Repair malformed rows in the ``ignored_by_port_proto`` table.

    Two passes run inside a single ``database.atomic()`` block:

    1. IPv4 rows whose ``ip`` contains no ``/`` and whose ``netmask`` equals
       ``NETMASK_32`` get ``/32`` appended, e.g. ``"1.2.3.5"`` becomes
       ``"1.2.3.5/32"``. When a row with the same ``port_proto`` already
       holds the ``/32`` form, the bare row is deleted instead.
    2. Rows whose ``country`` is longer than two characters and equal to the
       row's ``comment`` (i.e. comment text ended up in the country column)
       get ``country`` replaced by the geo lookup result for the row's
       ``ip``. Rows whose ``country`` differs from ``comment`` are left
       untouched, and ``comment`` itself is never modified.

    The migration is best-effort: any exception is logged with its
    traceback and swallowed, so this function never raises to its caller.

    Args:
        migrator: Object whose ``orm`` mapping provides the
            ``ignored_by_port_proto`` model.
        database: Database object providing ``atomic()``.
        fake: When true, return immediately without touching the database.
        **kwargs: Accepted for interface compatibility; unused.
    """
    if fake:
        # Nothing to do in fake mode
        return

    try:
        model = migrator.orm["ignored_by_port_proto"]

        # One transaction for both passes: an exception escaping the block
        # makes peewee roll back everything done so far.
        with database.atomic():
            _add_cidr_to_bare_ipv4(model)
            _fix_country_values(model)

    except Exception as e:
        # Deliberately swallowed (best-effort); logger.exception keeps the
        # traceback so the failure can still be diagnosed.
        logger.exception(
            "Error in migration fix_ignored_by_port_proto: %r",
            e,
        )


def _add_cidr_to_bare_ipv4(model):
    """Append ``/32`` to bare IPv4 rows, deleting rows that would collide."""
    # Materialise the result set before the loop: the loop updates/deletes
    # rows of the very table being selected from, and reading from a
    # still-open cursor while doing so is undefined on some backends
    # (e.g. SQLite may skip or revisit rows).
    bare_rows = list(
        model.select()
        .where(
            (model.version == NumericIPVersion[IP.V4])
            & ~model.ip.contains("/")
            & (model.netmask == NETMASK_32)
        )
        .dicts()
    )

    for record in bare_rows:
        try:
            cidr_ip = f"{record['ip']}/32"

            # Check if there's a duplicate with the CIDR notation
            duplicate_exists = (
                model.select()
                .where(
                    (model.port_proto == record["port_proto"])
                    & (model.ip == cidr_ip)
                )
                .exists()
            )

            if duplicate_exists:
                logger.info(
                    "Deleting duplicate record for IP %s (ID: %s) as"
                    " CIDR version exists",
                    record["ip"],
                    record["id"],
                )
                model.delete().where(model.id == record["id"]).execute()
            else:
                logger.info(
                    "Updating IP %s to %s (ID: %s)",
                    record["ip"],
                    cidr_ip,
                    record["id"],
                )
                model.update(ip=cidr_ip).where(
                    model.id == record["id"]
                ).execute()

        except pw.IntegrityError as e:
            # A constraint violation on one row must not stop the others.
            logger.warning("Error updating IP %s: %s", record["ip"], e)


def _fix_country_values(model):
    """Replace comment text stored in ``country`` with a geo lookup result."""
    # Materialised up front for the same reason as in the IPv4 pass: rows
    # of this table are updated while the results are being walked.
    suspect_rows = list(
        model.select()
        .where(
            (model.country.is_null(False))
            & (pw.fn.LENGTH(model.country) > 2)
        )
        .dicts()
    )

    with geo.reader() as geo_reader:
        for record in suspect_rows:
            try:
                country_value = record.get("country")
                comment_value = record.get("comment")

                # NOTE(review): the select filter already requires a
                # non-null country longer than two characters, so the
                # `not country_value` arm looks unreachable unless the row
                # dict has no "country" key -- confirm.
                if (
                    comment_value and comment_value == country_value
                ) or not country_value:
                    # NOTE(review): IPv4 values rewritten by the first pass
                    # now end in "/32"; presumably get_id() accepts that
                    # form -- confirm.
                    new_country = geo_reader.get_id(record["ip"])
                    logger.info(
                        "Updating country value '%s' to'%s'"
                        " for record ID: %s",
                        country_value,
                        new_country,
                        record["id"],
                    )
                    model.update(country=new_country).where(
                        model.id == record["id"]
                    ).execute()

            except pw.IntegrityError as e:
                # A constraint violation on one row must not stop the
                # others.
                logger.warning(
                    "Error fixing country field for record ID %s: %s",
                    record["id"],
                    e,
                )


def rollback(migrator, database, fake=False, **kwargs):
    """Do nothing: no rollback steps are defined for this migration.

    All arguments are accepted and ignored.
    """
    return None
