- 新增完整的 Python 实现,替代 Go 版本 - 添加 Web 登录界面和仪表板 - 实现 JWT 认证和 API 密钥管理 - 添加数据库存储功能 - 保持与 Go 版本一致的目录结构和启动脚本 - 包含完整的文档和测试脚本
93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import HTTPException, Depends, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from ..config import settings
from ..models.database import get_db, APIKey
||
|
||
# JWT configuration.
# Fallback signing secret used when JWT_SECRET_KEY is unset or empty. It is
# committed to source, so anyone who has read this file can forge tokens
# signed with it -- it must never be relied on outside local development.
_INSECURE_DEFAULT_JWT_SECRET = "your-secret-key-change-in-production"

# ``or`` (not a getenv default) so an empty JWT_SECRET_KEY does not become an
# empty HMAC key, which would be trivially forgeable.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY") or _INSECURE_DEFAULT_JWT_SECRET
if JWT_SECRET_KEY == _INSECURE_DEFAULT_JWT_SECRET:
    # Keep the fallback so existing dev setups still start, but make the
    # insecure configuration visible instead of silent.
    logging.getLogger(__name__).warning(
        "JWT_SECRET_KEY is not set; using the insecure built-in default secret. "
        "Set JWT_SECRET_KEY to a long random value before deploying."
    )

JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # default token lifetime: 7 days

# Bearer-token extractor injected (via Depends) into the dependencies below.
security = HTTPBearer()
|
||
|
||
class AuthMiddleware:
    """Helper for issuing/verifying JWTs and checking/generating API keys.

    It defines no ASGI ``__call__``; the FastAPI dependency functions in this
    module call its methods directly through a shared instance.
    """

    def __init__(self):
        # Snapshot of the module-level signing secret taken at construction.
        self.secret_key = JWT_SECRET_KEY

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None):
        """Encode ``data`` as a signed JWT carrying an ``exp`` (expiry) claim.

        Args:
            data: Claims to embed. It is copied first, so the caller's dict is
                not modified.
            expires_delta: Token lifetime measured from now. ``None`` selects
                the default of ``JWT_EXPIRATION_HOURS`` hours. A zero or
                negative delta is honoured as given (the token is already
                expired) rather than silently replaced by the default.

        Returns:
            The token produced by ``jwt.encode`` using ``JWT_ALGORITHM``.
        """
        to_encode = data.copy()
        # ``is None`` rather than truthiness: timedelta(0) is falsy and would
        # otherwise fall back to the multi-day default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(hours=JWT_EXPIRATION_HOURS)
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
        return jwt.encode(to_encode, self.secret_key, algorithm=JWT_ALGORITHM)

    def verify_token(self, token: str):
        """Validate ``token`` and return its decoded claims.

        Raises:
            HTTPException: 401 with detail "Token has expired" when the ``exp``
                claim is in the past, or 401 with detail "Invalid token" for any
                other validation failure (bad signature, malformed token, ...).
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[JWT_ALGORITHM])
        except jwt.ExpiredSignatureError as exc:
            # Must precede the generic handler: ExpiredSignatureError is a
            # subclass of InvalidTokenError.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
            ) from exc
        except jwt.InvalidTokenError as exc:
            # PyJWT's base class for token-validation errors. ``jwt.JWTError``
            # is python-jose API and does not exist in PyJWT; referencing it
            # turned every non-expiry failure into an AttributeError (HTTP 500).
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            ) from exc

    def verify_api_key(self, api_key: str, db: Session) -> bool:
        """Return True if an ``APIKey`` row whose ``key`` equals ``api_key`` exists."""
        db_key = db.query(APIKey).filter(APIKey.key == api_key).first()
        return db_key is not None

    def generate_api_key(self) -> str:
        """Generate a new random URL-safe API key (32 bytes of entropy)."""
        return secrets.token_urlsafe(32)
|
||
|
||
# Module-wide AuthMiddleware shared by the authentication dependencies below
# (they call its verify_token / verify_api_key methods).
auth_middleware: AuthMiddleware = AuthMiddleware()
|
||
|
||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Authenticate the request with a Bearer JWT and return its claims.

    ``credentials`` is supplied by the module-level HTTPBearer scheme; its
    ``credentials`` attribute holds the raw token string. Any exception raised
    by ``auth_middleware.verify_token`` propagates unchanged.
    """
    return auth_middleware.verify_token(credentials.credentials)
|
||
|
||
async def get_current_user_api_key(
    api_key: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
):
    """Authenticate the request with an API key sent as a Bearer token.

    Args:
        api_key: Credentials object produced by the HTTPBearer scheme (not a
            plain ``str``); the raw key is its ``credentials`` attribute.
        db: Database session used to look the key up.

    Returns:
        ``{"api_key": <raw key>}`` when a matching key is stored.

    Raises:
        HTTPException: 401 with detail "Invalid API key" when no stored key
            matches.
    """
    raw_key = api_key.credentials
    if not auth_middleware.verify_api_key(raw_key, db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
    return {"api_key": raw_key}
|
||
|
||
def require_auth(use_api_key: bool = False):
    """Select the authentication dependency for a route.

    Returns ``get_current_user_api_key`` when ``use_api_key`` is truthy,
    otherwise the JWT-based ``get_current_user``. The result is a dependency
    callable rather than a decorator -- presumably callers pass it to
    ``Depends(...)``; confirm at the call sites.
    """
    return get_current_user_api_key if use_api_key else get_current_user
|
||
|
||
def handle_auth_error(request, exc):
    """Build the response for an authentication failure.

    Requests carrying ``X-Requested-With: XMLHttpRequest`` receive a JSON 401
    body; every other request is redirected to ``/login`` with a 303.
    ``exc`` is accepted but unused -- the signature presumably matches an
    exception-handler hook; verify at the registration site.
    """
    is_ajax = request.headers.get("x-requested-with") == "XMLHttpRequest"
    if not is_ajax:
        return RedirectResponse(url="/login", status_code=303)
    return JSONResponse(
        status_code=401,
        content={"error": "Invalid or expired token"},
    )