import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy.orm import Session

from ..models.database import get_db, APIKey
from ..config import settings

_logger = logging.getLogger(__name__)

# JWT configuration
_DEFAULT_JWT_SECRET = "your-secret-key-change-in-production"
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # tokens are valid for 7 days

if JWT_SECRET_KEY == _DEFAULT_JWT_SECRET:
    # SECURITY: anyone who knows the placeholder can forge tokens. The
    # fallback is kept so existing deployments keep starting, but it is
    # reported loudly.
    _logger.warning(
        "JWT_SECRET_KEY is not set; falling back to the built-in placeholder "
        "secret. Set JWT_SECRET_KEY in the environment before deploying."
    )

security = HTTPBearer()


class AuthMiddleware:
    """Create and verify JWT access tokens and validate API keys."""

    def __init__(self):
        self.secret_key = JWT_SECRET_KEY

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None):
        """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

        Args:
            data: Claims to embed in the token. The dict is copied, not mutated.
            expires_delta: Lifetime of the token. When omitted (or falsy, e.g.
                a zero ``timedelta``) the default of ``JWT_EXPIRATION_HOURS``
                hours is used.

        Returns:
            The encoded token as produced by ``jwt.encode``.
        """
        to_encode = data.copy()
        # Truthiness check (not ``is None``) is kept so a zero delta still
        # selects the default lifetime, as before.
        lifetime = expires_delta if expires_delta else timedelta(hours=JWT_EXPIRATION_HOURS)
        # Timezone-aware "now": datetime.utcnow() is deprecated and naive.
        to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
        return jwt.encode(to_encode, self.secret_key, algorithm=JWT_ALGORITHM)

    def verify_token(self, token: str):
        """Decode ``token`` and return its payload.

        Raises:
            HTTPException: 401 with detail ``"Token has expired"`` when the
                ``exp`` claim is in the past, or 401 with detail
                ``"Invalid token"`` for any other token validation failure.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[JWT_ALGORITHM])
        except jwt.ExpiredSignatureError as exc:
            # Must come first: ExpiredSignatureError subclasses InvalidTokenError.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired"
            ) from exc
        except jwt.InvalidTokenError as exc:
            # PyJWT has no ``jwt.JWTError``; InvalidTokenError is the base
            # class for malformed / badly-signed / otherwise invalid tokens.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token"
            ) from exc

    def verify_api_key(self, api_key: str, db: Session) -> bool:
        """Return True when a row whose ``key`` equals ``api_key`` exists."""
        db_key = db.query(APIKey).filter(APIKey.key == api_key).first()
        return db_key is not None

    def generate_api_key(self) -> str:
        """Generate a new random, URL-safe API key."""
        return secrets.token_urlsafe(32)


# Shared authentication helper instance used by the dependencies below.
auth_middleware = AuthMiddleware()


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the JWT payload of the caller (bearer-token authentication).

    Raises:
        HTTPException: 401 when the bearer token is expired or invalid.
    """
    token = credentials.credentials
    return auth_middleware.verify_token(token)


async def get_current_user_api_key(
    api_key: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
):
    """Return ``{"api_key": <key>}`` for the caller (API-key authentication).

    The key is read from the bearer credentials of the request.

    Raises:
        HTTPException: 401 with detail ``"Invalid API key"`` when the key is
            not found in the database.
    """
    if not auth_middleware.verify_api_key(api_key.credentials, db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return {"api_key": api_key.credentials}


def require_auth(use_api_key: bool = False):
    """Select the authentication dependency to use.

    Returns ``get_current_user_api_key`` when ``use_api_key`` is true,
    otherwise ``get_current_user``.
    """
    if use_api_key:
        return get_current_user_api_key
    return get_current_user


def handle_auth_error(request, exc):
    """Build the response for an authentication failure.

    Requests whose ``x-requested-with`` header is ``XMLHttpRequest`` get a
    401 JSON body; everything else is redirected to ``/login`` with a 303.
    """
    if request.headers.get("x-requested-with") == "XMLHttpRequest":
        return JSONResponse(
            status_code=401,
            content={"error": "Invalid or expired token"}
        )
    return RedirectResponse(url="/login", status_code=303)