# backend/tasks/backup_tasks.py
from celery import shared_task
from celery.utils.log import get_task_logger
from datetime import datetime
import logging
import subprocess
import os

from core.database import db_session
from core.models import BackupHistory
from core.utils import datetime_utils, id_generator
from config.settings import settings

logger = get_task_logger(__name__)

def _parse_database_url(db_url):
    """Split a PostgreSQL URL into ``(user, password, host, port, dbname)`` strings.

    ``urllib.parse`` is used instead of manual splitting so that driver-qualified
    schemes (``postgresql+psycopg2://``), percent-encoded credentials, IPv6
    hosts and trailing query strings (``?sslmode=require``) are handled.
    Missing components come back as ``""``; a missing port defaults to ``"5432"``.

    Raises:
        ValueError: if the URL contains an invalid port.
    """
    from urllib.parse import unquote, urlsplit

    parts = urlsplit(db_url)
    user = unquote(parts.username) if parts.username else ""
    password = unquote(parts.password) if parts.password else ""
    host = parts.hostname or ""
    port = str(parts.port) if parts.port else "5432"
    dbname = unquote(parts.path.lstrip("/"))
    return user, password, host, port, dbname


def _run_pg_dump(db_url, backup_path):
    """Run ``pg_dump`` for *db_url*, writing a custom-format archive to *backup_path*.

    Returns the ``subprocess.CompletedProcess`` with stdout/stderr captured as bytes.
    """
    db_user, db_password, db_host, db_port, db_name = _parse_database_url(db_url)

    env = os.environ.copy()
    if db_password:
        # Pass the password through the environment so it never shows up in argv / ps.
        env["PGPASSWORD"] = db_password

    cmd = ["pg_dump", "-p", db_port]
    # Only pass connection options that are actually present; an empty host
    # lets libpq fall back to its default (local socket) connection.
    for flag, value in (("-h", db_host), ("-U", db_user), ("-d", db_name)):
        if value:
            cmd += [flag, value]
    cmd += ["--format=custom", "--file", backup_path]

    # pg_dump cannot create the target directory itself.
    os.makedirs(os.path.dirname(backup_path) or ".", exist_ok=True)
    return subprocess.run(cmd, env=env, capture_output=True)


def _remove_partial_file(path):
    """Best-effort removal of an incomplete dump file left behind by a failed run."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning(f"Could not remove partial backup file {path}: {exc}")


def _record_backup(**fields):
    """Persist a single ``BackupHistory`` row built from *fields* in its own session."""
    with db_session() as session:
        session.add(BackupHistory(**fields))
        session.commit()


def _record_failure(backup_id, started_at, error_message):
    """Best-effort recording of a failed backup; never raises, only logs."""
    try:
        _record_backup(
            backup_id=backup_id,
            backup_type='automated',
            status='failed',
            error_message=error_message,
            started_at=started_at,
        )
    except Exception:
        # Recording must not mask the original failure (the DB may be the cause).
        logger.exception(f"Could not record failed backup {backup_id}")


@shared_task
def perform_database_backup():
    """Perform an automated database backup with ``pg_dump``.

    Connection details are parsed from ``settings.DATABASE_URL``. The archive is
    written under ``settings.BACKUP_PATH``, and the outcome is recorded in
    ``BackupHistory``. On success, ``cleanup_old_backups`` is scheduled.

    Returns:
        dict: on success ``{'backup_id', 'filename', 'status': 'completed'}``;
        on failure ``{'backup_id', 'status': 'failed', 'error'}``.
    """
    logger.info("Starting automated database backup...")

    backup_id = id_generator.generate_batch_id()
    # Capture the start time *before* the dump so started_at is meaningful.
    started_at = datetime_utils.now()
    timestamp = started_at.strftime('%Y%m%d_%H%M%S')
    # NOTE(review): --format=custom produces a pg_dump archive (restore with
    # pg_restore), not a gzipped SQL script, despite the ".sql.gz" suffix. The
    # name is kept unchanged in case external tooling depends on it.
    filename = f"backup_{timestamp}_{backup_id}.sql.gz"
    backup_path = os.path.join(settings.BACKUP_PATH, filename)

    try:
        result = _run_pg_dump(settings.DATABASE_URL, backup_path)

        if result.returncode != 0:
            # Decode once; replace undecodable bytes rather than crash on them.
            error = result.stderr.decode(errors="replace").strip()
            logger.error(f"Backup failed: {error}")
            _remove_partial_file(backup_path)
            _record_failure(backup_id, started_at, error)
            return {
                'backup_id': backup_id,
                'status': 'failed',
                'error': error,
            }

        _record_backup(
            backup_id=backup_id,
            backup_type='automated',
            file_path=backup_path,
            status='completed',
            started_at=started_at,
            completed_at=datetime_utils.now(),
        )
    except Exception as e:
        logger.exception(f"Backup failed: {e}")
        _record_failure(backup_id, started_at, str(e))
        return {
            'backup_id': backup_id,
            'status': 'failed',
            'error': str(e),
        }

    logger.info(f"Database backup completed: {filename}")

    # Scheduling cleanup is outside the main try so a broker hiccup cannot
    # re-label an already-recorded successful backup as failed.
    try:
        cleanup_old_backups.delay()
    except Exception:
        logger.exception("Could not schedule cleanup of old backups")

    return {
        'backup_id': backup_id,
        'filename': filename,
        'status': 'completed',
    }

@shared_task
def cleanup_old_backups():
    """Remove backups older than the retention period.

    Every ``BackupHistory`` row whose ``created_at`` is older than
    ``settings.BACKUP_RETENTION_DAYS`` days has its backup file (if any) deleted,
    and then the row itself is deleted. If a file exists but cannot be removed,
    the error is logged and the row is kept, so a later run can retry it.
    """
    # The module imports the *class* ``datetime``, so ``datetime.timedelta``
    # does not exist; import timedelta explicitly.
    from datetime import timedelta

    logger.info("Cleaning up old backups...")

    retention_days = settings.BACKUP_RETENTION_DAYS
    cutoff = datetime_utils.now() - timedelta(days=retention_days)

    removed = 0
    with db_session() as session:
        old_backups = session.query(BackupHistory).filter(
            BackupHistory.created_at < cutoff
        ).all()

        for backup in old_backups:
            if backup.file_path:
                try:
                    os.remove(backup.file_path)
                except FileNotFoundError:
                    pass  # File already gone; still drop the stale record.
                except OSError as exc:
                    # Keep the record so the orphaned file is not forgotten.
                    logger.warning(
                        f"Could not delete backup file {backup.file_path}: {exc}"
                    )
                    continue

            session.delete(backup)
            removed += 1

        session.commit()

    logger.info(f"Cleaned up {removed} old backups")

@shared_task
def verify_backup_integrity(backup_id: str):
    """Report a backup as verified (placeholder implementation).

    NOTE(review): no check is actually performed — the result always says
    ``'verified'``. A real implementation should test restore capability
    (for example by listing the archive with ``pg_restore --list``). Until
    then, do not treat this result as evidence that the backup is usable.

    Args:
        backup_id: Identifier of the backup to verify.

    Returns:
        dict with ``backup_id``, ``integrity`` (always ``'verified'``) and
        ``verified_at`` (ISO-8601 timestamp from ``datetime_utils.now()``).
    """
    logger.info(f"Verifying backup integrity: {backup_id}")

    verified_at = datetime_utils.now().isoformat()
    return dict(
        backup_id=backup_id,
        integrity='verified',
        verified_at=verified_at,
    )