# backend/api/tasks/backup_tasks.py
from celery import shared_task
from celery.utils.log import get_task_logger
from datetime import datetime, timedelta
import os
import subprocess
import gzip
import shutil
from pathlib import Path

from core.database import SessionLocal
from core.models import BackupHistory, Admin, TenantRegistry
from core.utils.datetime_utils import datetime_utils
from core.utils.id_generator import id_generator
from core.config import settings

logger = get_task_logger(__name__)


def _record_backup_failure(backup_id, error_msg, started_at):
    """Store a ``failed`` BackupHistory row for an automated backup.

    Any error raised while writing the row is logged and suppressed so that
    it cannot replace the backup error the caller is about to report.

    Args:
        backup_id: Identifier generated for this backup run.
        error_msg: Text stored in the row's ``error_message``.
        started_at: Time at which the backup run began.
    """
    try:
        db = SessionLocal()
        try:
            db.add(BackupHistory(
                backup_id=backup_id,
                backup_type='automated',
                status='failed',
                error_message=error_msg,
                started_at=started_at,
            ))
            db.commit()
        finally:
            db.close()
    except Exception:
        logger.exception(f"Could not record failure of backup {backup_id}")


@shared_task
def perform_database_backup():
    """Dump the database named by ``settings.DATABASE_URL`` and gzip it.

    Runs ``pg_dump --format=custom`` into an intermediate file next to the
    target, gzips that file into ``settings.BACKUP_PATH``, records the
    outcome in ``BackupHistory`` and enqueues ``cleanup_old_backups``.

    Returns:
        dict: On success ``backup_id``, ``filename``, ``size``, ``size_mb``,
        ``status='completed'`` and ``timestamp``. On failure ``backup_id``,
        ``status='failed'`` and ``error``.
    """
    from urllib.parse import unquote, urlsplit

    logger.info("Starting automated database backup...")

    started_at = datetime_utils.now()
    backup_id = id_generator.generate_batch_id()
    timestamp = started_at.strftime('%Y%m%d_%H%M%S')
    filename = f"backup_{timestamp}_{backup_id}.sql.gz"
    backup_path = os.path.join(settings.BACKUP_PATH, filename)
    # Strip only the trailing ".gz"; str.replace would also rewrite a ".gz"
    # occurring anywhere else in the path.
    temp_sql = os.path.splitext(backup_path)[0]

    # Create backup directory if it doesn't exist
    os.makedirs(settings.BACKUP_PATH, exist_ok=True)

    try:
        # urlsplit copes with driver schemes ("postgresql+psycopg2://"),
        # "@" / ":" inside the password and "?query" suffixes; credentials
        # are percent-decoded because URLs carry them encoded.
        url = urlsplit(settings.DATABASE_URL)
        db_user = unquote(url.username or "")
        db_password = unquote(url.password or "")
        db_host = url.hostname or ""
        db_port = str(url.port or 5432)
        db_name = unquote(url.path.lstrip("/"))

        # The password travels via the environment, never on the command
        # line, so the command can be logged safely below.
        env = os.environ.copy()
        if db_password:
            env["PGPASSWORD"] = db_password

        cmd = [
            "pg_dump",
            "-h", db_host,
            "-p", db_port,
            "-U", db_user,
            "-d", db_name,
            "--format=custom",
            "--file", temp_sql,
        ]

        logger.info(f"Running backup command: {' '.join(cmd)}")
        result = subprocess.run(cmd, env=env, capture_output=True)

        if result.returncode != 0:
            error_msg = (
                result.stderr.decode(errors="replace")
                if result.stderr else "Unknown error"
            )
            logger.error(f"Backup failed: {error_msg}")
            _record_backup_failure(backup_id, error_msg, started_at)
            return {
                'backup_id': backup_id,
                'status': 'failed',
                'error': error_msg
            }

        # Compress the dump
        try:
            with open(temp_sql, 'rb') as f_in, gzip.open(backup_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        except Exception:
            # A half-written archive would look like a real backup on disk.
            if os.path.exists(backup_path):
                os.remove(backup_path)
            raise

        file_size = os.path.getsize(backup_path)

        # Record backup in database
        db = SessionLocal()
        try:
            db.add(BackupHistory(
                backup_id=backup_id,
                backup_type='automated',
                file_path=backup_path,
                file_size=file_size,
                status='completed',
                started_at=started_at,
                completed_at=datetime_utils.now(),
            ))
            db.commit()
        finally:
            db.close()

        logger.info(f"Database backup completed: {filename} ({file_size / 1024 / 1024:.2f} MB)")

        # The backup is already stored and recorded; failing to enqueue the
        # cleanup must not turn it into a reported failure.
        try:
            cleanup_old_backups.delay()
        except Exception:
            logger.exception("Could not schedule cleanup of old backups")

        return {
            'backup_id': backup_id,
            'filename': filename,
            'size': file_size,
            'size_mb': file_size / 1024 / 1024,
            'status': 'completed',
            'timestamp': datetime_utils.now().isoformat()
        }

    except Exception as e:
        logger.error(f"Backup failed: {e}")
        _record_backup_failure(backup_id, str(e), started_at)
        return {
            'backup_id': backup_id,
            'status': 'failed',
            'error': str(e)
        }

    finally:
        # The intermediate dump must not outlive the task on any path.
        try:
            os.remove(temp_sql)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary dump {temp_sql}: {cleanup_error}")


@shared_task
def cleanup_old_backups():
    """Delete backup files and history rows older than the retention period.

    Rows whose ``created_at`` is older than
    ``settings.BACKUP_RETENTION_DAYS`` days are selected. For each one the
    file (if any) is removed and the row deleted. A row is kept when its file
    could not be removed, so the row still points at the leftover file and
    the next run can retry.

    Returns:
        dict: ``removed`` (files deleted), ``freed_space_mb`` and
        ``timestamp``.
    """
    logger.info("Cleaning up old backups...")

    cutoff = datetime_utils.now() - timedelta(days=settings.BACKUP_RETENTION_DAYS)
    removed = 0
    freed_space = 0

    db = SessionLocal()
    try:
        old_backups = db.query(BackupHistory).filter(
            BackupHistory.created_at < cutoff
        ).all()

        for backup in old_backups:
            if backup.file_path:
                # Attempt the removal and handle the outcome instead of
                # checking existence first, which races with other workers.
                try:
                    file_size = os.path.getsize(backup.file_path)
                    os.remove(backup.file_path)
                except FileNotFoundError:
                    # File already gone: the row is stale and is still dropped.
                    pass
                except OSError as e:
                    # One undeletable file must not abort the whole run and
                    # leave already-deleted files with uncommitted rows.
                    logger.warning(f"Could not remove backup file {backup.file_path}: {e}")
                    continue
                else:
                    freed_space += file_size
                    removed += 1
                    logger.info(f"Removed backup file: {backup.file_path}")

            # Delete record
            db.delete(backup)

        db.commit()

    finally:
        db.close()

    logger.info(f"Cleaned up {removed} old backups, freed {freed_space / 1024 / 1024:.2f} MB")
    return {
        'removed': removed,
        'freed_space_mb': freed_space / 1024 / 1024,
        'timestamp': datetime_utils.now().isoformat()
    }


@shared_task
def verify_backup_integrity(backup_id: str):
    """Check that a recorded backup file exists and is a readable gzip stream.

    The whole archive is decompressed (and discarded) so that truncation or
    checksum errors anywhere in the stream are detected. On success the
    backup row's metadata gets ``verified_at`` and ``verified``.

    Args:
        backup_id: ``BackupHistory.backup_id`` of the backup to verify.

    Returns:
        dict: Always contains ``backup_id`` and ``integrity``, which is one
        of ``'not_found'``, ``'file_missing'``, ``'corrupted'`` (with
        ``reason``) or ``'verified'`` (with ``size``, ``size_mb`` and
        ``verified_at``).
    """
    import zlib

    logger.info(f"Verifying backup integrity: {backup_id}")

    db = SessionLocal()
    try:
        backup = db.query(BackupHistory).filter(
            BackupHistory.backup_id == backup_id
        ).first()

        if not backup or not backup.file_path:
            return {'backup_id': backup_id, 'integrity': 'not_found'}

        if not os.path.exists(backup.file_path):
            return {'backup_id': backup_id, 'integrity': 'file_missing'}

        # Check file size
        file_size = os.path.getsize(backup.file_path)
        if file_size == 0:
            return {'backup_id': backup_id, 'integrity': 'corrupted', 'reason': 'empty_file'}

        # Read to end of stream: gzip's CRC/length trailer is only checked
        # there, so sampling the first bytes would miss a truncated or
        # damaged tail. Chunked reads keep memory use flat.
        try:
            with gzip.open(backup.file_path, 'rb') as f:
                while f.read(1024 * 1024):
                    pass
        except (OSError, EOFError, zlib.error) as e:
            return {'backup_id': backup_id, 'integrity': 'corrupted', 'reason': str(e)}

        verified_at = datetime_utils.now().isoformat()

        # Assign a new dict rather than mutating in place: in-place changes
        # to a plain JSON column value are not tracked by SQLAlchemy (unless
        # the column uses a Mutable type), so commit() could persist nothing.
        # NOTE(review): `metadata` is a reserved attribute on SQLAlchemy
        # declarative models - confirm BackupHistory really exposes a column
        # under this name.
        backup.metadata = {
            **(backup.metadata or {}),
            'verified_at': verified_at,
            'verified': True,
        }
        db.commit()

        logger.info(f"Backup {backup_id} integrity verified")

        return {
            'backup_id': backup_id,
            'integrity': 'verified',
            'size': file_size,
            'size_mb': file_size / 1024 / 1024,
            'verified_at': verified_at
        }

    finally:
        db.close()


@shared_task
def create_tenant_backup(admin_id: int, reserve_id: str):
    """Dump one tenant's database and gzip it under ``BACKUP_PATH/tenants``.

    The tenant's connection details are read from the ``TenantRegistry`` row
    matching ``admin_id``; ``reserve_id`` is only used for the file name,
    log messages and the returned dict.

    Args:
        admin_id: Value matched against ``TenantRegistry.admin_id``.
        reserve_id: Tenant label embedded in the backup file name.

    Returns:
        dict: ``{'error': 'Tenant not found', 'reserve_id': ...}`` when no
        registry row matches. Otherwise ``backup_id``, ``reserve_id`` and
        ``status``, plus ``filename`` and ``size_mb`` on success or
        ``error`` on failure.
    """
    logger.info(f"Creating tenant backup for {reserve_id}")

    backup_id = id_generator.generate_batch_id()
    timestamp = datetime_utils.now().strftime('%Y%m%d_%H%M%S')
    filename = f"tenant_{reserve_id}_{timestamp}_{backup_id}.sql.gz"
    tenants_dir = os.path.join(settings.BACKUP_PATH, 'tenants')
    backup_path = os.path.join(tenants_dir, filename)
    # Strip only the trailing ".gz"; str.replace would also rewrite a ".gz"
    # occurring anywhere else in the path.
    temp_sql = os.path.splitext(backup_path)[0]

    os.makedirs(tenants_dir, exist_ok=True)

    try:
        db = SessionLocal()
        try:
            tenant = db.query(TenantRegistry).filter(
                TenantRegistry.admin_id == admin_id
            ).first()

            if not tenant:
                return {'error': 'Tenant not found', 'reserve_id': reserve_id}

            # Use the registry fields directly. Formatting them into a URL
            # and splitting it again corrupts any value containing "@", ":"
            # or "/" and turns an unset value into the string "None".
            db_user = str(tenant.database_user) if tenant.database_user else ""
            db_password = str(tenant.database_password) if tenant.database_password else ""
            db_host = str(tenant.database_host) if tenant.database_host else ""
            db_port = str(tenant.database_port) if tenant.database_port else "5432"
            db_name = str(tenant.database_name)
        finally:
            # The session is only needed for the lookup; do not hold it open
            # for the duration of the dump.
            db.close()

        # Set environment
        env = os.environ.copy()
        if db_password:
            env["PGPASSWORD"] = db_password

        cmd = [
            "pg_dump",
            "-h", db_host,
            "-p", db_port,
            "-U", db_user,
            "-d", db_name,
            "--format=custom",
            "--file", temp_sql,
        ]

        result = subprocess.run(cmd, env=env, capture_output=True)

        if result.returncode != 0:
            error_msg = (
                result.stderr.decode(errors="replace")
                if result.stderr else 'Unknown error'
            )
            logger.error(f"Tenant backup failed for {reserve_id}: {error_msg}")
            return {
                'backup_id': backup_id,
                'reserve_id': reserve_id,
                'status': 'failed',
                'error': error_msg
            }

        # Compress
        try:
            with open(temp_sql, 'rb') as f_in, gzip.open(backup_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        except Exception:
            # A half-written archive would look like a real backup on disk.
            if os.path.exists(backup_path):
                os.remove(backup_path)
            raise

        file_size = os.path.getsize(backup_path)

        logger.info(f"Tenant backup completed for {reserve_id}: {file_size / 1024 / 1024:.2f} MB")

        return {
            'backup_id': backup_id,
            'reserve_id': reserve_id,
            'filename': filename,
            'size_mb': file_size / 1024 / 1024,
            'status': 'completed'
        }

    except Exception as e:
        logger.error(f"Tenant backup failed for {reserve_id}: {e}")
        return {
            'backup_id': backup_id,
            'reserve_id': reserve_id,
            'status': 'failed',
            'error': str(e)
        }

    finally:
        # The intermediate dump must not outlive the task on any path.
        try:
            os.remove(temp_sql)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary dump {temp_sql}: {cleanup_error}")


@shared_task
def restore_from_backup(backup_id: str):
    """Replace the database in ``settings.DATABASE_URL`` with a stored backup.

    DESTRUCTIVE: after the archive has been fully decompressed, the target
    database is dropped (``dropdb --if-exists``), recreated (``createdb``)
    and loaded with ``pg_restore``.

    Args:
        backup_id: ``BackupHistory.backup_id`` of the backup to restore.

    Returns:
        dict: ``{'error': ..., 'backup_id': ...}`` when the backup row or
        file is missing. Otherwise ``backup_id`` and ``status``
        (``'restored'`` with ``timestamp``, or ``'failed'`` with ``error``
        when ``pg_restore`` exits non-zero).

    Raises:
        subprocess.CalledProcessError: If ``dropdb`` or ``createdb`` fails;
            its stderr is logged before re-raising.
    """
    import tempfile
    from urllib.parse import unquote, urlsplit

    logger.info(f"Restoring from backup: {backup_id}")

    # Look the backup up and release the session right away: this task must
    # not itself be connected to the database it is about to drop.
    # NOTE(review): connections still held by a pool or by other processes
    # will also make dropdb fail - consider disposing the engine or
    # `dropdb --force` (PostgreSQL 13+).
    db = SessionLocal()
    try:
        backup = db.query(BackupHistory).filter(
            BackupHistory.backup_id == backup_id
        ).first()
        file_path = backup.file_path if backup else None
    finally:
        db.close()

    if not file_path:
        return {'error': 'Backup not found', 'backup_id': backup_id}

    if not os.path.exists(file_path):
        return {'error': 'Backup file missing', 'backup_id': backup_id}

    # urlsplit copes with driver schemes ("postgresql+psycopg2://"), "@" / ":"
    # inside the password and "?query" suffixes; credentials are
    # percent-decoded because URLs carry them encoded.
    url = urlsplit(settings.DATABASE_URL)
    db_user = unquote(url.username or "")
    db_password = unquote(url.password or "")
    db_host = url.hostname or ""
    db_port = str(url.port or 5432)
    db_name = unquote(url.path.lstrip("/"))

    # Set environment
    env = os.environ.copy()
    if db_password:
        env["PGPASSWORD"] = db_password

    # Decompress into a uniquely named file. Deriving the name with
    # file_path.replace('.gz', '') yields the backup's own path when it has
    # no ".gz" in it, which would truncate and then delete the backup.
    fd, temp_sql = tempfile.mkstemp(
        suffix='.dump', dir=os.path.dirname(file_path) or None
    )
    try:
        # Decompress fully BEFORE dropping anything, so a corrupt archive
        # fails while the live database is still intact.
        with os.fdopen(fd, 'wb') as f_out, gzip.open(file_path, 'rb') as f_in:
            shutil.copyfileobj(f_in, f_out)

        conn_args = ["-h", db_host, "-p", db_port, "-U", db_user]

        # Drop and recreate database
        try:
            subprocess.run(
                ["dropdb", *conn_args, "--if-exists", db_name],
                env=env, check=True, capture_output=True,
            )
            subprocess.run(
                ["createdb", *conn_args, db_name],
                env=env, check=True, capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # The exception text omits stderr; log it so the cause is not lost.
            stderr = e.stderr.decode(errors="replace") if e.stderr else "Unknown error"
            logger.error(f"Restore of backup {backup_id} aborted, {e.cmd[0]} failed: {stderr}")
            raise

        # Restore
        result = subprocess.run(
            ["pg_restore", *conn_args, "-d", db_name, temp_sql],
            env=env, capture_output=True,
        )

        if result.returncode == 0:
            logger.info(f"Restore completed successfully for backup {backup_id}")
            return {
                'backup_id': backup_id,
                'status': 'restored',
                'timestamp': datetime_utils.now().isoformat()
            }

        error_msg = (
            result.stderr.decode(errors="replace")
            if result.stderr else "Unknown error"
        )
        logger.error(f"Restore failed for backup {backup_id}: {error_msg}")
        return {
            'backup_id': backup_id,
            'status': 'failed',
            'error': error_msg
        }

    finally:
        # Clean up temp file
        try:
            os.remove(temp_sql)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary dump {temp_sql}: {cleanup_error}")