93 lines
3.2 KiB
Python
93 lines
3.2 KiB
Python
import os
import secrets
import warnings
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session

from ..config import settings
from ..models.database import APIKey, get_db
|
|
|
|
# JWT configuration
def _load_jwt_secret() -> str:
    """Return the HMAC secret used to sign and verify JWTs.

    Reads the ``JWT_SECRET_KEY`` environment variable. If it is unset or
    empty, a random per-process secret is generated (with a RuntimeWarning).
    This replaces the old fallback to a fixed, publicly known string, which
    would let anyone forge tokens. With a random secret, tokens do not
    survive a restart and are not shared between worker processes, so real
    deployments must set ``JWT_SECRET_KEY``.
    """
    secret = os.getenv("JWT_SECRET_KEY")
    if secret:
        return secret
    warnings.warn(
        "JWT_SECRET_KEY is not set; using a random per-process secret. "
        "Issued tokens will be invalid after a restart or on other workers.",
        RuntimeWarning,
        stacklevel=2,
    )
    return secrets.token_urlsafe(64)


JWT_SECRET_KEY = _load_jwt_secret()
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # 7 days expiration
|
|
|
|
# Shared "Authorization: Bearer <credentials>" extractor used as a FastAPI
# dependency. auto_error=True (the HTTPBearer default) makes it raise an
# HTTPException itself when the header is missing or malformed.
security = HTTPBearer(auto_error=True)
|
|
|
|
class AuthMiddleware:
    """Issue/verify JWT access tokens and validate/generate API keys.

    Tokens are signed and verified with ``JWT_ALGORITHM`` using
    ``self.secret_key``.
    """

    def __init__(self, secret_key: Optional[str] = None):
        """Create the helper.

        Args:
            secret_key: HMAC signing secret. When omitted, the module-level
                ``JWT_SECRET_KEY`` is used, so zero-argument construction
                behaves as before.
        """
        self.secret_key = JWT_SECRET_KEY if secret_key is None else secret_key

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None):
        """Return a signed JWT containing ``data`` plus an ``exp`` claim.

        Args:
            data: Claims to embed. The dict is copied, so the caller's
                object is not modified.
            expires_delta: Token lifetime. ``None`` selects the default of
                ``JWT_EXPIRATION_HOURS``. Any other value, including
                ``timedelta(0)``, is honoured as given.

        Returns:
            Whatever ``jwt.encode`` returns (a ``str`` under PyJWT 2.x).
        """
        to_encode = data.copy()
        # Compare with None rather than relying on truthiness: timedelta(0)
        # is falsy and would otherwise silently become a 7-day token.
        if expires_delta is None:
            expires_delta = timedelta(hours=JWT_EXPIRATION_HOURS)
        # Timezone-aware "now"; datetime.utcnow() is deprecated since 3.12.
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
        return jwt.encode(to_encode, self.secret_key, algorithm=JWT_ALGORITHM)

    def verify_token(self, token: str):
        """Decode ``token`` and return its claims.

        Raises:
            HTTPException: 401 "Token has expired" when the ``exp`` claim has
                passed, or 401 "Invalid token" for any other invalid token
                (bad signature, malformed, etc.).
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[JWT_ALGORITHM])
        except jwt.ExpiredSignatureError as exc:
            # Must come first: ExpiredSignatureError subclasses InvalidTokenError.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
            ) from exc
        except jwt.InvalidTokenError as exc:
            # PyJWT has no `JWTError` (that name is python-jose's). Referencing
            # it raised AttributeError, which surfaced as HTTP 500.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            ) from exc

    def verify_api_key(self, api_key: str, db: Session) -> bool:
        """Return True if ``api_key`` matches a stored ``APIKey.key`` row.

        An empty key is rejected without querying the database.
        """
        if not api_key:
            return False
        return db.query(APIKey).filter(APIKey.key == api_key).first() is not None

    def generate_api_key(self) -> str:
        """Return a new URL-safe API key built from 32 random bytes."""
        return secrets.token_urlsafe(32)
|
|
|
|
# Single module-wide helper instance, shared by get_current_user and
# get_current_user_api_key.
auth_middleware: AuthMiddleware = AuthMiddleware()
|
|
|
|
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: authenticate the caller with a JWT bearer token.

    Returns the claims produced by ``auth_middleware.verify_token`` for the
    bearer credentials. Any exception that method raises propagates
    unchanged.
    """
    return auth_middleware.verify_token(credentials.credentials)
|
|
|
|
async def get_current_user_api_key(
    api_key: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
):
    """FastAPI dependency: authenticate the caller with an API key.

    The key is taken from the bearer credentials extracted by ``security``.
    It is then checked with ``auth_middleware.verify_api_key``. The parameter
    was previously annotated ``str``, but ``Depends(security)`` supplies an
    ``HTTPAuthorizationCredentials`` object.

    Returns:
        ``{"api_key": <key>}`` when the key is known.

    Raises:
        HTTPException: 401 "Invalid API key" when the key is not recognised.
    """
    key = api_key.credentials
    if not auth_middleware.verify_api_key(key, db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
    return {"api_key": key}
|
|
|
|
def require_auth(use_api_key: bool = False):
    """Select the authentication dependency to pass to ``Depends``.

    This is not a decorator: it returns one of the dependency callables
    itself, not the result of calling it.

    Args:
        use_api_key: If true, select API-key authentication
            (``get_current_user_api_key``). Otherwise select JWT
            authentication (``get_current_user``).
    """
    return get_current_user_api_key if use_api_key else get_current_user
|
|
|
|
def handle_auth_error(request, exc):
    """Build a client-appropriate response for an authentication failure.

    Requests carrying ``X-Requested-With: XMLHttpRequest`` receive a JSON
    401 body. All other requests are redirected to ``/login`` with 303 See
    Other. ``exc`` is not inspected; it is presumably accepted so the
    function matches an exception-handler signature (confirm at the
    registration site).
    """
    is_xhr = request.headers.get("x-requested-with") == "XMLHttpRequest"
    if not is_xhr:
        return RedirectResponse(url="/login", status_code=303)
    return JSONResponse(
        status_code=401,
        content={"error": "Invalid or expired token"},
    )