- 新增完整的 Python 实现,替代 Go 版本 - 添加 Web 登录界面和仪表板 - 实现 JWT 认证和 API 密钥管理 - 添加数据库存储功能 - 保持与 Go 版本一致的目录结构和启动脚本 - 包含完整的文档和测试脚本
398 lines
14 KiB
Python
398 lines
14 KiB
Python
"""
|
|
数据库服务
|
|
实现项目映射、分支模式匹配等功能
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Optional, List, Dict, Any
|
|
from datetime import datetime
|
|
import structlog
|
|
import re
|
|
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text, ForeignKey
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker, relationship
|
|
from sqlalchemy.sql import text
|
|
|
|
from app.config import get_settings
|
|
|
|
# Module-level structlog logger used by DatabaseService for all of its log output.
logger = structlog.get_logger()

# Declarative base shared by the ORM models below; DatabaseService creates
# their tables via Base.metadata.create_all().
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base — consider switching.
Base = declarative_base()
|
|
|
|
|
|
# 数据库模型
|
|
class APIKey(Base):
    """API key model, stored in the ``api_keys`` table."""

    __tablename__ = "api_keys"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # The key value itself: required and unique (declared as String(255)).
    # NOTE(review): stored exactly as passed in; whether callers hash it before
    # storage is not visible here — confirm.
    key = Column(String(255), unique=True, nullable=False)
    # Optional free-text description of the key.
    description = Column(Text)
    # Naive UTC timestamps from datetime.utcnow: both set on insert, and
    # updated_at is refreshed on every ORM update (onupdate).
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
|
|
|
|
|
class ProjectMapping(Base):
    """Project mapping model: the job configuration for one repository.

    Holds the repository's default job plus its exact-branch mappings
    (``branch_jobs``) and branch-pattern mappings (``branch_patterns``).
    """

    __tablename__ = "project_mappings"

    # Surrogate integer primary key; referenced by branch_jobs / branch_patterns.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Repository name; unique, so there is at most one mapping per repository.
    repository_name = Column(String(255), unique=True, nullable=False)
    # Fallback job used by DatabaseService.determine_job_name when no exact or
    # pattern mapping matches; nullable.
    default_job = Column(String(255))
    # Naive UTC timestamps; updated_at is refreshed on every ORM update.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships. cascade="all, delete-orphan": child rows are saved and
    # deleted together with the project, and deleted when removed from the
    # collection.
    # NOTE(review): no order_by is declared, so the load order of these
    # collections is whatever the database returns.
    branch_jobs = relationship("BranchJob", back_populates="project", cascade="all, delete-orphan")
    branch_patterns = relationship("BranchPattern", back_populates="project", cascade="all, delete-orphan")
|
|
|
|
|
|
class BranchJob(Base):
    """Exact branch-name -> job mapping belonging to a ProjectMapping."""

    __tablename__ = "branch_jobs"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning project (project_mappings.id); required.
    project_id = Column(Integer, ForeignKey("project_mappings.id"), nullable=False)
    # Branch name, compared for exact equality by DatabaseService.determine_job_name.
    # NOTE(review): there is no unique constraint on (project_id, branch_name),
    # so duplicate entries are possible.
    branch_name = Column(String(255), nullable=False)
    # Job to trigger when the branch matches.
    job_name = Column(String(255), nullable=False)
    # Naive UTC timestamps; updated_at is refreshed on every ORM update.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationship: inverse side of ProjectMapping.branch_jobs.
    project = relationship("ProjectMapping", back_populates="branch_jobs")
|
|
|
|
|
|
class BranchPattern(Base):
    """Branch-pattern -> job mapping belonging to a ProjectMapping."""

    __tablename__ = "branch_patterns"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning project (project_mappings.id); required.
    project_id = Column(Integer, ForeignKey("project_mappings.id"), nullable=False)
    # Regular expression that DatabaseService.determine_job_name applies to
    # branch names with re.match (anchored at the start only).
    pattern = Column(String(255), nullable=False)
    # Job to trigger when the pattern matches.
    job_name = Column(String(255), nullable=False)
    # Naive UTC timestamps; updated_at is refreshed on every ORM update.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationship: inverse side of ProjectMapping.branch_patterns.
    project = relationship("ProjectMapping", back_populates="branch_patterns")
|
|
|
|
|
|
class TriggerLog(Base):
    """Trigger log entry, stored in the ``trigger_logs`` table."""

    __tablename__ = "trigger_logs"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Repository, branch and commit the trigger was for; all required.
    # (Plain strings, not foreign keys to project_mappings.)
    repository_name = Column(String(255), nullable=False)
    branch_name = Column(String(255), nullable=False)
    commit_sha = Column(String(255), nullable=False)
    # Job that was (or was meant to be) triggered.
    job_name = Column(String(255), nullable=False)
    # Status string (declared as String(50)); the set of values is defined by callers.
    status = Column(String(50), nullable=False)
    # Optional error details.
    error_message = Column(Text)
    # Insert time as a naive UTC datetime; log rows have no updated_at column.
    created_at = Column(DateTime, default=datetime.utcnow)
|
|
|
|
|
|
class DatabaseService:
    """Database service.

    Owns the SQLAlchemy engine and session factory (built at construction
    time) and exposes async helpers for project mappings, branch -> job
    resolution and trigger logs. Each async helper runs its blocking ORM work
    on the running event loop's default executor, opens its own session and
    always closes it. Errors in the async helpers are logged and reported via
    the return value (None / False / []) instead of being raised.
    """

    def __init__(self):
        self.settings = get_settings()
        self.engine = None
        self.SessionLocal = None
        self._init_database()

    def _init_database(self):
        """Create the engine, create missing tables and build the session factory.

        Raises:
            Exception: any error during setup is logged and re-raised, so a
                misconfigured database fails at construction time.
        """
        try:
            db_settings = self.settings.database
            engine_options = {"echo": db_settings.echo}
            # Depending on the SQLAlchemy version and the URL, SQLite is served
            # by NullPool or SingletonThreadPool, which reject pool_size and/or
            # max_overflow with a TypeError; pool sizing is only passed for
            # other backends.
            if not str(db_settings.url).startswith("sqlite"):
                engine_options["pool_size"] = db_settings.pool_size
                engine_options["max_overflow"] = db_settings.max_overflow
            self.engine = create_engine(db_settings.url, **engine_options)

            # Create tables for every model registered on Base.
            Base.metadata.create_all(bind=self.engine)

            # Session factory used by get_session().
            self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

            logger.info("Database initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize database", error=str(e))
            raise

    def get_session(self):
        """Return a new Session from the session factory; the caller must close it."""
        return self.SessionLocal()

    @staticmethod
    async def _run_in_executor(func):
        """Run blocking ``func`` on the running loop's default executor and await its result."""
        # get_running_loop() is the supported way to reach the loop from inside
        # a coroutine; get_event_loop() is deprecated for that purpose.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    async def get_project_mapping(self, repository_name: str) -> Optional[Dict[str, Any]]:
        """Load the mapping configured for a repository.

        Args:
            repository_name: Repository name to look up.

        Returns:
            A dict with keys "id", "repository_name", "default_job",
            "branch_jobs" (list of {"id", "branch_name", "job_name"}) and
            "branch_patterns" (list of {"id", "pattern", "job_name"}); both
            lists are ordered by row id. Returns None if the repository has no
            mapping or the lookup failed (the error is logged, not raised).
        """
        def _get_mapping():
            session = self.get_session()
            try:
                project = (
                    session.query(ProjectMapping)
                    .filter(ProjectMapping.repository_name == repository_name)
                    .first()
                )
                if project is None:
                    return None

                # The relationships declare no order_by, so sort by primary key
                # to make the order (and which pattern wins in
                # determine_job_name) deterministic: oldest entry first.
                branch_jobs = sorted(project.branch_jobs, key=lambda row: row.id)
                branch_patterns = sorted(project.branch_patterns, key=lambda row: row.id)

                return {
                    "id": project.id,
                    "repository_name": project.repository_name,
                    "default_job": project.default_job,
                    "branch_jobs": [
                        {
                            "id": branch_job.id,
                            "branch_name": branch_job.branch_name,
                            "job_name": branch_job.job_name,
                        }
                        for branch_job in branch_jobs
                    ],
                    "branch_patterns": [
                        {
                            "id": pattern.id,
                            "pattern": pattern.pattern,
                            "job_name": pattern.job_name,
                        }
                        for pattern in branch_patterns
                    ],
                }
            finally:
                session.close()

        try:
            return await self._run_in_executor(_get_mapping)
        except Exception as e:
            logger.error("Failed to get project mapping",
                         repository_name=repository_name, error=str(e))
            return None

    async def determine_job_name(self, repository_name: str, branch_name: str) -> Optional[str]:
        """Resolve which job to trigger for a branch of a repository.

        Resolution order:
            1. an exact ``branch_jobs`` entry for the branch (first by row id);
            2. the first ``branch_patterns`` entry (by row id) whose pattern
               matches via ``re.match`` — anchored at the start of the branch
               name only, not at the end; invalid patterns are logged and
               skipped;
            3. the project's ``default_job``, if set.

        Args:
            repository_name: Repository name.
            branch_name: Branch name.

        Returns:
            The job name, or None when there is no mapping, nothing matches and
            no default job is set, or an error occurred (logged).
        """
        try:
            project = await self.get_project_mapping(repository_name)
            if not project:
                return None

            # 1. Exact branch match.
            for branch_job in project["branch_jobs"]:
                if branch_job["branch_name"] == branch_name:
                    logger.debug("Found exact branch match",
                                 branch=branch_name, job=branch_job["job_name"])
                    return branch_job["job_name"]

            # 2. Pattern match.
            for pattern in project["branch_patterns"]:
                try:
                    matched = re.match(pattern["pattern"], branch_name)
                except re.error as e:
                    logger.error("Invalid regex pattern",
                                 pattern=pattern["pattern"], error=str(e))
                    continue
                if matched:
                    logger.debug("Branch matched pattern",
                                 branch=branch_name, pattern=pattern["pattern"],
                                 job=pattern["job_name"])
                    return pattern["job_name"]

            # 3. Default job.
            default_job = project["default_job"]
            if default_job:
                logger.debug("Using default job",
                             branch=branch_name, job=default_job)
                return default_job

            return None
        except Exception as e:
            logger.error("Failed to determine job name",
                         repository_name=repository_name, branch_name=branch_name,
                         error=str(e))
            return None

    async def log_trigger(self, log_data: Dict[str, Any]) -> bool:
        """Insert one TriggerLog row.

        Args:
            log_data: Must contain "repository_name", "branch_name",
                "commit_sha", "job_name" and "status"; "error_message" is
                optional.

        Returns:
            True if the row was committed, False otherwise (the transaction is
            rolled back and the error logged).
        """
        def _log_trigger():
            session = self.get_session()
            try:
                session.add(TriggerLog(
                    repository_name=log_data["repository_name"],
                    branch_name=log_data["branch_name"],
                    commit_sha=log_data["commit_sha"],
                    job_name=log_data["job_name"],
                    status=log_data["status"],
                    error_message=log_data.get("error_message"),
                ))
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                logger.error("Failed to log trigger", error=str(e))
                return False
            finally:
                session.close()

        try:
            return await self._run_in_executor(_log_trigger)
        except Exception as e:
            logger.error("Failed to log trigger", error=str(e))
            return False

    async def get_trigger_logs(self, repository_name: Optional[str] = None,
                               branch_name: Optional[str] = None,
                               limit: int = 100) -> List[Dict[str, Any]]:
        """Return the most recent trigger logs, newest first.

        Args:
            repository_name: Only include logs for this repository (optional).
            branch_name: Only include logs for this branch (optional).
            limit: Maximum number of rows to return.

        Returns:
            A list of dicts with the TriggerLog columns; "created_at" is an ISO
            8601 string, or None for rows without a timestamp. Returns [] on
            error (logged).
        """
        def _get_logs():
            session = self.get_session()
            try:
                query = session.query(TriggerLog)
                if repository_name:
                    query = query.filter(TriggerLog.repository_name == repository_name)
                if branch_name:
                    query = query.filter(TriggerLog.branch_name == branch_name)

                logs = query.order_by(TriggerLog.created_at.desc()).limit(limit).all()

                return [
                    {
                        "id": log.id,
                        "repository_name": log.repository_name,
                        "branch_name": log.branch_name,
                        "commit_sha": log.commit_sha,
                        "job_name": log.job_name,
                        "status": log.status,
                        "error_message": log.error_message,
                        # created_at is nullable; one NULL row must not make
                        # the whole listing fail.
                        "created_at": (
                            log.created_at.isoformat()
                            if log.created_at is not None else None
                        ),
                    }
                    for log in logs
                ]
            finally:
                session.close()

        try:
            return await self._run_in_executor(_get_logs)
        except Exception as e:
            logger.error("Failed to get trigger logs", error=str(e))
            return []

    async def create_project_mapping(self, mapping_data: Dict[str, Any]) -> bool:
        """Create a project mapping together with its branch and pattern entries.

        Args:
            mapping_data: Must contain "repository_name"; "default_job",
                "branch_jobs" (list of {"branch_name", "job_name"}) and
                "branch_patterns" (list of {"pattern", "job_name"}) are
                optional, and a None value is treated as empty.

        Returns:
            True if everything was committed in one transaction, False
            otherwise (for example a duplicate repository_name); on failure the
            transaction is rolled back and the error logged.
        """
        def _create_mapping():
            session = self.get_session()
            try:
                project = ProjectMapping(
                    repository_name=mapping_data["repository_name"],
                    default_job=mapping_data.get("default_job"),
                )
                session.add(project)
                # Flush so the database assigns project.id before children reference it.
                session.flush()

                for branch_job in mapping_data.get("branch_jobs") or []:
                    session.add(BranchJob(
                        project_id=project.id,
                        branch_name=branch_job["branch_name"],
                        job_name=branch_job["job_name"],
                    ))

                for pattern in mapping_data.get("branch_patterns") or []:
                    session.add(BranchPattern(
                        project_id=project.id,
                        pattern=pattern["pattern"],
                        job_name=pattern["job_name"],
                    ))

                session.commit()
                return True
            except Exception as e:
                session.rollback()
                logger.error("Failed to create project mapping", error=str(e))
                return False
            finally:
                session.close()

        try:
            return await self._run_in_executor(_create_mapping)
        except Exception as e:
            logger.error("Failed to create project mapping", error=str(e))
            return False
|
|
|
|
|
|
# Lazily created, module-wide DatabaseService instance (see get_database_service).
_database_service: Optional[DatabaseService] = None


def get_database_service() -> DatabaseService:
    """Return the shared DatabaseService, constructing it on first use.

    The instance is cached in the module-level ``_database_service`` and reused
    by later calls. No locking is done here; if construction raises, nothing
    is cached and the next call tries again.
    """
    global _database_service
    if _database_service is not None:
        return _database_service
    _database_service = DatabaseService()
    return _database_service