# backend/shared/middleware/rate_limit.py
import logging
import time
from collections import defaultdict
from typing import Dict, Tuple

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

from config.settings import settings

logger = logging.getLogger(__name__)

class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding-window rate limiter keyed by route group and client IP.

    Every request path is mapped to a route group ('auth', 'api' or
    'default'), each with its own ``(max_requests, window_seconds)`` limit.
    Request timestamps are stored per ``"<group>:<client_ip>"`` bucket so that
    traffic in one group does not consume the quota of another.

    State lives in this process only; it is not shared between workers.
    NOTE(review): the client address is taken from ``request.client``; if the
    app runs behind a reverse proxy this is presumably the proxy address --
    confirm how forwarded headers are handled in deployment.
    """

    def __init__(self, app: ASGIApp):
        """Wrap *app* and set up the limit table and the timestamp buckets."""
        super().__init__(app)
        # bucket key ("<group>:<client_ip>") -> timestamps of accepted requests
        self.rate_limits: Dict[str, list] = defaultdict(list)
        # route group -> (max requests, window length in seconds)
        self.limits: Dict[str, Tuple[int, int]] = {
            'default': (settings.RATE_LIMIT_DEFAULT, 60),
            'auth': (10, 60),  # 10 requests per minute
            'api': (100, 60),  # 100 requests per minute
        }
        # Time of the last stale-bucket sweep (see _sweep_stale_buckets).
        self._last_sweep = time.time()

    @staticmethod
    def _limit_key_for(path: str) -> str:
        """Return the route group ('auth', 'api' or 'default') for *path*."""
        # The auth prefix must be tested first: it is also an '/api/v1' path.
        if path.startswith('/api/v1/auth'):
            return 'auth'
        if path.startswith('/api/v1'):
            return 'api'
        return 'default'

    def _sweep_stale_buckets(self, now: float) -> None:
        """Drop buckets with no request inside the longest configured window.

        Without this, every client that was ever seen would keep an entry in
        ``rate_limits`` forever. The sweep runs at most once per longest
        window so the per-request cost stays negligible.
        """
        longest_window = max(period for _, period in self.limits.values())
        if now - self._last_sweep < longest_window:
            return
        self._last_sweep = now

        cutoff = now - longest_window
        stale_keys = [
            key
            for key, stamps in self.rate_limits.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in stale_keys:
            del self.rate_limits[key]

    async def dispatch(self, request: Request, call_next):
        """Forward the request, or answer 429 when the client is over its limit.

        Args:
            request: Incoming request.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response, or a 429 JSON response with body
            ``{"detail": "Rate limit exceeded"}`` and a ``Retry-After`` header.
        """
        path = request.url.path

        # Skip rate limiting for health checks
        if path.startswith('/health'):
            return await call_next(request)

        # request.client can be None; such requests share a single bucket
        # rather than crashing the middleware.
        client_ip = request.client.host if request.client else 'unknown'

        limit_key = self._limit_key_for(path)
        limit, period = self.limits.get(limit_key, self.limits['default'])

        now = time.time()
        self._sweep_stale_buckets(now)

        # One bucket per (route group, client) so each limit is counted
        # against its own traffic only.
        bucket_key = f"{limit_key}:{client_ip}"
        window_start = now - period
        recent = [t for t in self.rate_limits[bucket_key] if t > window_start]
        self.rate_limits[bucket_key] = recent

        if len(recent) >= limit:
            logger.warning("Rate limit exceeded for %s on %s", client_ip, path)
            # Seconds until the oldest counted request leaves the window;
            # `recent` is empty only when the configured limit is 0.
            retry_after = int(recent[0] + period - now) + 1 if recent else period
            # Return the response directly: an HTTPException raised from
            # middleware is not converted by the app's exception handlers and
            # would reach the client as a 500.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(retry_after)},
            )

        recent.append(now)

        return await call_next(request)