# backend/core/utils.py
import hashlib
import secrets
import string
import re
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
import random
import luhn
from passlib.context import CryptContext
from jose import JWTError, jwt
import shortuuid
import pytz
import logging

from config.settings import settings

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Password hashing context (passlib). bcrypt is the only configured scheme;
# SecurityUtils.verify_password / get_password_hash delegate to this object.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class SecurityUtils:
    """Security related utilities.

    Groups password hashing, JWT creation/decoding and small credential/ID
    helpers. Every method is a ``@staticmethod``, so the class is used purely
    as a namespace.
    """

    # Exclusive upper bound accepted by validate_telegram_id.
    _MAX_TELEGRAM_ID = 999999999999999

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Verify a plain password against a hash.

        Args:
            plain_password: Candidate password as typed by the user.
            hashed_password: Stored hash to check against.

        Returns:
            Whatever ``pwd_context.verify`` reports for the pair.
        """
        return pwd_context.verify(plain_password, hashed_password)

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Hash a password with the module-level ``pwd_context``."""
        return pwd_context.hash(password)

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create JWT access token.

        Args:
            data: Claims to embed. The mapping is shallow-copied, so the
                caller's dict is not modified.
            expires_delta: Token lifetime. When ``None`` the lifetime is
                ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

        Returns:
            The encoded token, carrying ``exp`` and ``type == "access"``.
        """
        to_encode = data.copy()
        # Test against None rather than truthiness: timedelta(0) is falsy, and
        # an explicitly requested zero lifetime must not silently fall back to
        # the (much longer) configured default.
        if expires_delta is not None:
            lifetime = expires_delta
        else:
            lifetime = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)

        to_encode.update({"exp": datetime.utcnow() + lifetime, "type": "access"})
        return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)

    @staticmethod
    def create_refresh_token(data: Dict[str, Any]) -> str:
        """Create JWT refresh token.

        Args:
            data: Claims to embed; shallow-copied before being extended.

        Returns:
            The encoded token, carrying ``exp`` (now plus
            ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS`` days) and
            ``type == "refresh"``.
        """
        to_encode = data.copy()
        expire = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
        to_encode.update({"exp": expire, "type": "refresh"})
        return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)

    @staticmethod
    def decode_token(token: str) -> Optional[Dict[str, Any]]:
        """Decode JWT token.

        Args:
            token: Encoded JWT.

        Returns:
            The decoded payload, or ``None`` when decoding raises ``JWTError``
            (the error is logged). The ``type`` claim is NOT checked here, so
            callers must tell access and refresh tokens apart themselves.
        """
        try:
            return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
        except JWTError as e:
            # Lazy %-formatting: the message is only built if it is emitted.
            logger.error("Token decode error: %s", e)
            return None

    @staticmethod
    def generate_api_key() -> str:
        """Generate random API key: ``tk_`` followed by a URL-safe token."""
        return "tk_" + secrets.token_urlsafe(32)

    @staticmethod
    def generate_api_secret() -> str:
        """Generate random API secret (URL-safe token from 48 random bytes)."""
        return secrets.token_urlsafe(48)

    @staticmethod
    def generate_reserve_id(prefix: str = "BOT", number: Optional[int] = None) -> str:
        """Generate reserve ID for admin.

        Args:
            prefix: Label placed before the dash; it is upper-cased.
            number: Numeric part. When ``None`` a value in 1..999 is drawn.

        Returns:
            ``"<PREFIX>-<number>"`` with the number zero-padded to at least
            three digits.
        """
        if number is None:
            # NOTE(review): uses the non-cryptographic ``random`` module; that
            # is only acceptable if reserve IDs need not be unguessable.
            number = random.randint(1, 999)
        return f"{prefix.upper()}-{number:03d}"

    @staticmethod
    def validate_telegram_id(telegram_id: int) -> bool:
        """Validate Telegram ID.

        Returns:
            True only for a real ``int`` strictly between 0 and
            999999999999999. ``None``, strings, floats and bools yield False
            instead of raising ``TypeError`` (or, for bools, passing).
        """
        if isinstance(telegram_id, bool) or not isinstance(telegram_id, int):
            return False
        return 0 < telegram_id < SecurityUtils._MAX_TELEGRAM_ID


class KeyGenerator:
    """Key generation utilities.

    Keys are produced from a pattern string containing placeholders such as
    ``{DATE}`` or ``{RANDOM8}``; everything else in the pattern is literal.
    Every method is a ``@staticmethod``.
    """

    # Placeholder -> zero-argument callable producing its replacement text.
    PATTERN_VARIABLES = {
        "{DATE}": lambda: datetime.now().strftime("%Y%m%d"),
        "{YYYYMMDD}": lambda: datetime.now().strftime("%Y%m%d"),
        "{DDMMYYYY}": lambda: datetime.now().strftime("%d%m%Y"),
        "{RANDOM4}": lambda: KeyGenerator._generate_random(4),
        "{RANDOM6}": lambda: KeyGenerator._generate_random(6),
        "{RANDOM8}": lambda: KeyGenerator._generate_random(8),
        "{RANDOM10}": lambda: KeyGenerator._generate_random(10),
        "{CHECKSUM}": lambda: "0",  # Will be replaced with actual checksum
    }

    # Placeholder -> regex fragment used when validating a key against a pattern.
    _FORMAT_FRAGMENTS = {
        "{DATE}": r"\d{8}",
        "{YYYYMMDD}": r"\d{8}",
        "{DDMMYYYY}": r"\d{8}",
        "{RANDOM4}": r"[A-Z0-9]{4}",
        "{RANDOM6}": r"[A-Z0-9]{6}",
        "{RANDOM8}": r"[A-Z0-9]{8}",
        "{RANDOM10}": r"[A-Z0-9]{10}",
        "{PREFIX}": r"[A-Z]+",
        "{CHECKSUM}": r"\d",
    }

    # Characters dropped from the random alphabet when exclude_similar is set.
    _SIMILAR_CHARS = "0O1I"

    # generate_batch gives up after this many duplicate keys in a row.
    _MAX_DUPLICATE_STREAK = 1000

    @staticmethod
    def _generate_random(length: int, exclude_similar: bool = True) -> str:
        """Generate random string of upper-case letters and digits.

        Args:
            length: Number of characters to produce.
            exclude_similar: When True, ``0``, ``O``, ``1`` and ``I`` are
                never used.
        """
        alphabet = string.ascii_uppercase + string.digits
        if exclude_similar:
            alphabet = "".join(ch for ch in alphabet if ch not in KeyGenerator._SIMILAR_CHARS)
        return "".join(secrets.choice(alphabet) for _ in range(length))

    @staticmethod
    def generate_key(pattern: str, prefix: str = "", exclude_similar: bool = True,
                     enable_checksum: bool = False) -> str:
        """Generate key based on pattern.

        Every placeholder occurrence is expanded independently, so a pattern
        such as ``{RANDOM4}-{RANDOM4}`` yields two different random groups
        (a plain ``str.replace`` would repeat one value in every position).

        Args:
            pattern: Template containing ``PATTERN_VARIABLES`` placeholders.
            prefix: When non-empty and the expanded key does not already start
                with it, ``"<prefix>-"`` is prepended.
            exclude_similar: Forwarded to the random generator for
                ``{RANDOM<n>}`` placeholders.
            enable_checksum: When True, ``luhn.generate`` is applied to the
                key (dashes and spaces removed) and the result is appended.

        Returns:
            The generated key. ``{CHECKSUM}`` inside the pattern expands to the
            literal ``"0"``; a computed check digit is only ever appended.
        """
        variables = KeyGenerator.PATTERN_VARIABLES

        def expand(match):
            token = match.group(0)
            random_spec = re.fullmatch(r"\{RANDOM(\d+)\}", token)
            if random_spec:
                # Random placeholders bypass the table's lambda so that the
                # caller's exclude_similar choice is honoured.
                return KeyGenerator._generate_random(int(random_spec.group(1)), exclude_similar)
            return variables[token]()

        key = pattern
        if variables:
            # Longest names first so a name can never shadow a longer one.
            names = sorted(variables, key=len, reverse=True)
            key = re.sub("|".join(map(re.escape, names)), expand, pattern)

        # Add prefix if not in pattern
        if prefix and not key.startswith(prefix):
            key = f"{prefix}-{key}"

        # Add checksum if enabled
        if enable_checksum:
            # NOTE(review): confirm the `luhn` package accepts letters; if it
            # only handles decimal digit strings this raises for keys that
            # contain random letters or a textual prefix.
            key_without_checksum = key.replace("-", "").replace(" ", "")
            checksum = luhn.generate(key_without_checksum)
            key = f"{key}{checksum}"

        return key

    @staticmethod
    def validate_key_format(key: str, pattern: str) -> bool:
        """Validate key against pattern.

        Placeholders become regex fragments; all other pattern text is matched
        literally (it is regex-escaped, so ``.`` or ``+`` in a pattern no
        longer act as wildcards). The whole key must match, including no
        trailing newline.

        Note that a prefix or check digit added by ``generate_key`` is not part
        of ``pattern`` and therefore not accounted for here.
        """
        fragments = KeyGenerator._FORMAT_FRAGMENTS
        splitter = "(" + "|".join(map(re.escape, fragments)) + ")"
        # The capturing group makes re.split keep the placeholders themselves.
        pieces = re.split(splitter, pattern)
        regex = "".join(
            fragments[piece] if piece in fragments else re.escape(piece)
            for piece in pieces
        )
        return re.fullmatch(regex, key) is not None

    @staticmethod
    def generate_batch(pattern: str, count: int, prefix: str = "",
                       exclude_similar: bool = True, enable_checksum: bool = False) -> List[str]:
        """Generate batch of unique keys.

        Args:
            pattern: Passed to ``generate_key``.
            count: Number of distinct keys wanted; values <= 0 give ``[]``.
            prefix: Passed to ``generate_key``.
            exclude_similar: Passed to ``generate_key``.
            enable_checksum: Passed to ``generate_key``.

        Returns:
            ``count`` distinct keys in generation order.

        Raises:
            ValueError: When too many consecutive duplicates are produced,
                i.e. the pattern cannot supply ``count`` distinct keys.
                Without this guard the loop would never terminate.
        """
        unique: Dict[str, None] = {}  # dict used as an insertion-ordered set
        duplicate_streak = 0
        while len(unique) < count:
            key = KeyGenerator.generate_key(pattern, prefix, exclude_similar, enable_checksum)
            if key in unique:
                duplicate_streak += 1
                if duplicate_streak >= KeyGenerator._MAX_DUPLICATE_STREAK:
                    raise ValueError(
                        f"Pattern {pattern!r} cannot produce {count} unique keys "
                        f"(stopped at {len(unique)})"
                    )
                continue
            unique[key] = None
            duplicate_streak = 0
        return list(unique)

    @staticmethod
    def calculate_checksum(key: str) -> int:
        """Calculate Luhn checksum for key.

        The key is upper-cased and stripped of everything except ``A-Z`` and
        ``0-9`` before being handed to ``luhn.checksum``.
        """
        # NOTE(review): letters survive the cleaning step; confirm the `luhn`
        # package supports them.
        clean_key = re.sub(r'[^A-Z0-9]', '', key.upper())
        return luhn.checksum(clean_key)

    @staticmethod
    def validate_checksum(key: str) -> bool:
        """Validate key using Luhn algorithm.

        Applies the same cleaning as ``calculate_checksum`` and returns the
        result of ``luhn.verify``.
        """
        clean_key = re.sub(r'[^A-Z0-9]', '', key.upper())
        return luhn.verify(clean_key)


class IDGenerator:
    """ID generation utilities.

    Every method is a ``@staticmethod``.
    """

    # 32 unambiguous symbols (no 0/O/1/I) used for referral codes.
    _REFERRAL_ALPHABET = "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"

    # Referral codes are left-padded to this many characters.
    _REFERRAL_LENGTH = 8

    @staticmethod
    def generate_transaction_id() -> str:
        """Generate unique transaction ID.

        Returns:
            ``TXN`` + local timestamp (``YYYYMMDDHHMMSS``) + 8 upper-case hex
            characters from ``secrets``.
        """
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        random_part = secrets.token_hex(4).upper()
        return f"TXN{timestamp}{random_part}"

    @staticmethod
    def generate_broadcast_id() -> str:
        """Generate unique broadcast ID (``BCST`` + 8 upper-cased shortuuid characters)."""
        return f"BCST{shortuuid.uuid()[:8].upper()}"

    @staticmethod
    def generate_ticket_id() -> str:
        """Generate unique support ticket ID (``TCKT`` + 8 upper-cased shortuuid characters)."""
        return f"TCKT{shortuuid.uuid()[:8].upper()}"

    @staticmethod
    def generate_batch_id() -> str:
        """Generate unique batch ID for key uploads.

        Returns:
            ``BATCH`` + local date (``YYYYMMDD``) + 6 upper-case hex characters.
        """
        timestamp = datetime.now().strftime("%Y%m%d")
        random_part = secrets.token_hex(3).upper()
        return f"BATCH{timestamp}{random_part}"

    @staticmethod
    def generate_referral_code(user_id: int) -> str:
        """Generate referral code for user.

        The code is ``user_id`` written in base 32 with an unambiguous
        upper-case alphabet, left-padded to 8 characters. It is deterministic
        and distinct for distinct IDs (IDs of 32**8 or more give longer codes).

        The previous implementation passed the integer straight to
        ``shortuuid.encode``, which expects a ``uuid.UUID`` object, so it could
        not produce a code for an ``int`` user ID.

        Args:
            user_id: Non-negative integer user identifier.

        Returns:
            Upper-case referral code of at least 8 characters.

        Raises:
            ValueError: If ``user_id`` is not a non-negative ``int``.
        """
        if isinstance(user_id, bool) or not isinstance(user_id, int) or user_id < 0:
            raise ValueError("user_id must be a non-negative integer")

        alphabet = IDGenerator._REFERRAL_ALPHABET
        base = len(alphabet)
        symbols = []
        remaining = user_id
        while remaining:
            remaining, index = divmod(remaining, base)
            symbols.append(alphabet[index])
        # Symbols were collected least-significant first.
        code = "".join(reversed(symbols))
        return code.rjust(IDGenerator._REFERRAL_LENGTH, alphabet[0])


class DateTimeUtils:
    """Date and time utilities.

    Helpers that take no timezone work with naive UTC datetimes
    (``datetime.utcnow()``). Every method is a ``@staticmethod``.
    """

    @staticmethod
    def _now_comparable_to(moment: datetime) -> datetime:
        """Return the current time in a form that can be compared with ``moment``.

        Aware datetimes get an aware "now" in the same timezone; naive ones
        get naive UTC. Mixing the two would raise ``TypeError``.
        """
        if moment.tzinfo is not None and moment.utcoffset() is not None:
            return datetime.now(moment.tzinfo)
        return datetime.utcnow()

    @staticmethod
    def now() -> datetime:
        """Get current UTC datetime (naive, no ``tzinfo``)."""
        return datetime.utcnow()

    @staticmethod
    def now_tz(timezone: str = "UTC") -> datetime:
        """Get current datetime in specified timezone.

        Args:
            timezone: Zone name handed to ``pytz.timezone``.

        Returns:
            An aware datetime in that zone.
        """
        return datetime.now(pytz.timezone(timezone))

    @staticmethod
    def format_datetime(dt: datetime, format: str = "%Y-%m-%d %H:%M:%S") -> str:
        """Format datetime to string using ``strftime`` with ``format``."""
        return dt.strftime(format)

    @staticmethod
    def parse_datetime(date_str: str, format: str = "%Y-%m-%d %H:%M:%S") -> Optional[datetime]:
        """Parse datetime from string.

        Returns:
            The parsed datetime, or ``None`` when ``date_str`` does not match
            ``format`` or is not a string at all (e.g. ``None``).
        """
        try:
            return datetime.strptime(date_str, format)
        except (ValueError, TypeError):
            # TypeError covers non-string input, which previously escaped.
            return None

    @staticmethod
    def get_expiry_date(days: int) -> datetime:
        """Get expiry date after specified days, as naive UTC."""
        return datetime.utcnow() + timedelta(days=days)

    @staticmethod
    def days_until(expiry_date: datetime) -> int:
        """Get number of whole days until expiry date.

        Args:
            expiry_date: Naive (interpreted as UTC) or timezone-aware moment.
                A falsy value such as ``None`` yields 0.

        Returns:
            Whole days remaining, never negative.
        """
        if not expiry_date:
            return 0
        remaining = expiry_date - DateTimeUtils._now_comparable_to(expiry_date)
        return max(0, remaining.days)

    @staticmethod
    def is_expired(expiry_date: datetime) -> bool:
        """Check if date is expired.

        Args:
            expiry_date: Naive (interpreted as UTC) or timezone-aware moment.
                A falsy value such as ``None`` is treated as "never expires".

        Returns:
            True when the current time is later than ``expiry_date``.
        """
        if not expiry_date:
            return False
        return DateTimeUtils._now_comparable_to(expiry_date) > expiry_date


class Validators:
    """Input validation utilities.

    The regex validators require the whole string to match (a trailing
    newline is rejected) and return False for non-string input instead of
    raising. Every method is a ``@staticmethod``.
    """

    # Patterns are compiled once; they are applied with fullmatch(), so no
    # ^/$ anchors are needed.
    _EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')
    _PHONE_RE = re.compile(r'\+?[1-9]\d{1,14}')
    _IPV4_RE = re.compile(
        r'(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
        r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)'
    )
    # Same language as the former `...([\/\w \.-]*)*\/?` pattern, but without
    # the nested quantifier, which backtracked exponentially on inputs that
    # almost match (regular-expression denial of service).
    _URL_RE = re.compile(r'(?:https?://)?[\da-z.-]+\.[a-z.]{2,6}[/\w .-]*/?')

    @staticmethod
    def _matches(compiled, value) -> bool:
        """Return True when ``value`` is a string fully matched by ``compiled``."""
        return isinstance(value, str) and compiled.fullmatch(value) is not None

    @staticmethod
    def validate_email(email: str) -> bool:
        """Validate email format (``local@domain.tld`` with a 2+ letter TLD)."""
        return Validators._matches(Validators._EMAIL_RE, email)

    @staticmethod
    def validate_phone(phone: str) -> bool:
        """Validate phone number format.

        Accepts an optional leading ``+`` followed by 2 to 15 digits, the
        first of which is not zero.
        """
        return Validators._matches(Validators._PHONE_RE, phone)

    @staticmethod
    def validate_ip(ip: str) -> bool:
        """Validate IP address (dotted-quad IPv4 only, each octet 0-255)."""
        return Validators._matches(Validators._IPV4_RE, ip)

    @staticmethod
    def validate_url(url: str) -> bool:
        """Validate URL.

        Accepts an optional ``http://`` or ``https://`` scheme, a lower-case
        host with a dot-separated suffix, and an optional path made of word
        characters, ``/``, ``.``, ``-`` and spaces.
        """
        return Validators._matches(Validators._URL_RE, url)

    @staticmethod
    def validate_amount(amount: float, min_amount: float = 0, max_amount: float = 1000000) -> bool:
        """Validate amount.

        Returns:
            True when ``min_amount <= amount <= max_amount``; False otherwise,
            including when ``amount`` cannot be compared with the bounds
            (e.g. ``None``).
        """
        try:
            return min_amount <= amount <= max_amount
        except TypeError:
            return False


class FileUtils:
    """File handling utilities.

    Every method is a ``@staticmethod``.
    """

    @staticmethod
    def get_file_extension(filename: str) -> str:
        """Get file extension.

        Only the last path component is inspected, so a dot inside a directory
        name (``"uploads.v2/README"``) is not mistaken for an extension.

        Returns:
            The lower-cased text after the final dot of the file name, without
            the dot, or ``""`` when the name has no dot.
        """
        basename = re.split(r"[\\/]", filename)[-1]
        _, dot, extension = basename.rpartition(".")
        return extension.lower() if dot else ""

    @staticmethod
    def generate_filename(prefix: str, extension: str) -> str:
        """Generate unique filename.

        Args:
            prefix: Leading part of the name.
            extension: File extension, with or without a leading dot.

        Returns:
            ``"<prefix>_<YYYYMMDD_HHMMSS>_<8 hex chars>.<extension>"`` using
            local time and ``secrets`` for the random part.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_part = secrets.token_hex(4)
        # Strip leading dots so ".csv" does not produce "name..csv".
        clean_extension = extension.lstrip(".")
        return f"{prefix}_{timestamp}_{random_part}.{clean_extension}"

    @staticmethod
    def validate_file_size(size: int, max_size: Optional[int] = None) -> bool:
        """Validate file size.

        Args:
            size: Size to check.
            max_size: Upper limit. When ``None`` the limit is read from
                ``settings.MAX_UPLOAD_SIZE`` at call time (not frozen when the
                module is imported), matching ``validate_file_extension``.

        Returns:
            True when ``0 <= size <= max_size``; negative sizes are rejected.
        """
        if max_size is None:
            max_size = settings.MAX_UPLOAD_SIZE
        return 0 <= size <= max_size

    @staticmethod
    def validate_file_extension(filename: str, allowed_extensions: Optional[List[str]] = None) -> bool:
        """Validate file extension.

        Args:
            filename: Name (or path) of the file to check.
            allowed_extensions: Permitted extensions; defaults to
                ``settings.ALLOWED_EXTENSIONS``. Entries are compared
                case-insensitively and may be written with or without the
                leading dot.

        Returns:
            True when the file has an extension and it is in the allow-list.
        """
        if allowed_extensions is None:
            allowed_extensions = settings.ALLOWED_EXTENSIONS
        extension = FileUtils.get_file_extension(filename)
        if not extension:
            return False
        permitted = {str(item).lower().lstrip(".") for item in allowed_extensions}
        return extension in permitted


class CacheKeys:
    """Cache key definitions.

    Each builder returns ``"<namespace>:<value>"`` (the rate-limit key has two
    value segments). Every method is a ``@staticmethod``.
    """

    @staticmethod
    def user_key(user_id: int) -> str:
        """Build the cache key for a user record."""
        return "user:{}".format(user_id)

    @staticmethod
    def admin_key(admin_id: int) -> str:
        """Build the cache key for an admin record."""
        return "admin:{}".format(admin_id)

    @staticmethod
    def key_key(key_value: str) -> str:
        """Build the cache key for a key value."""
        return "key:{}".format(key_value)

    @staticmethod
    def plan_key(duration: str) -> str:
        """Build the cache key for a plan, identified by its duration."""
        return "plan:{}".format(duration)

    @staticmethod
    def token_key(token: str) -> str:
        """Build the cache key for a token."""
        return "token:{}".format(token)

    @staticmethod
    def rate_limit_key(identifier: str, endpoint: str) -> str:
        """Build the rate-limit cache key for an identifier/endpoint pair."""
        return "ratelimit:{}:{}".format(identifier, endpoint)

    @staticmethod
    def session_key(session_id: str) -> str:
        """Build the cache key for a session."""
        return "session:{}".format(session_id)


# Initialize utilities
# One shared instance per helper class, created when this module is imported.
# All methods on these classes are static, so the instances hold no state;
# they only give callers a short module-level name to call through.
security = SecurityUtils()
key_generator = KeyGenerator()
id_generator = IDGenerator()
datetime_utils = DateTimeUtils()
validators = Validators()
file_utils = FileUtils()
cache_keys = CacheKeys()