|
@@ -0,0 +1,448 @@
|
|
|
+# admin.py
|
|
|
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from sqlalchemy import text, inspect
|
|
|
+from app.database import get_db, engine
|
|
|
+import bcrypt
|
|
|
+import pandas as pd
|
|
|
+import io
|
|
|
+import math # 添加 math 模块导入
|
|
|
+from pydantic import BaseModel
|
|
|
+from typing import Optional
|
|
|
+from functools import wraps
|
|
|
+import logging
|
|
|
+import sys
|
|
|
+import os
|
|
|
+from datetime import datetime
|
|
|
+
|
|
|
+router = APIRouter()
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 日志设置 ----------------
|
|
|
+def setup_logger():
|
|
|
+ # 创建日志目录
|
|
|
+ log_dir = "logs"
|
|
|
+ if not os.path.exists(log_dir):
|
|
|
+ os.makedirs(log_dir)
|
|
|
+
|
|
|
+ logger = logging.getLogger("fastapi_app")
|
|
|
+ logger.setLevel(logging.DEBUG)
|
|
|
+
|
|
|
+ # 防止重复添加 handler
|
|
|
+ if not logger.handlers:
|
|
|
+ # 格式器
|
|
|
+ formatter = logging.Formatter(
|
|
|
+ fmt="%(asctime)s | %(levelname)-7s | %(funcName)-15s | %(message)s",
|
|
|
+ datefmt="%Y-%m-%d %H:%M:%S"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 控制台输出(INFO 及以上)
|
|
|
+ console_handler = logging.StreamHandler(sys.stdout)
|
|
|
+ console_handler.setLevel(logging.INFO)
|
|
|
+ console_handler.setFormatter(formatter)
|
|
|
+
|
|
|
+ # 文件输出(DEBUG 及以上),按天分文件
|
|
|
+ log_file = os.path.join(log_dir, f"backend_{datetime.now().strftime('%Y%m%d')}.log")
|
|
|
+ file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
|
|
+ file_handler.setLevel(logging.DEBUG)
|
|
|
+ file_handler.setFormatter(formatter)
|
|
|
+
|
|
|
+ # 添加处理器
|
|
|
+ logger.addHandler(console_handler)
|
|
|
+ logger.addHandler(file_handler)
|
|
|
+
|
|
|
+ return logger
|
|
|
+
|
|
|
+
|
|
|
+# 全局日志器实例
|
|
|
+logger = setup_logger()
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 数据模型 ----------------
|
|
|
+class RegisterRequest(BaseModel):
|
|
|
+ name: str
|
|
|
+ password: str
|
|
|
+ userType: str = "user"
|
|
|
+
|
|
|
+
|
|
|
+class LoginRequest(BaseModel):
|
|
|
+ name: str
|
|
|
+ password: str
|
|
|
+
|
|
|
+
|
|
|
+class UpdateUserRequest(BaseModel):
|
|
|
+ name: Optional[str] = None
|
|
|
+ password: Optional[str] = None
|
|
|
+ userType: Optional[str] = None
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 初始化用户表 ----------------
|
|
|
+def init_user_db():
|
|
|
+ with engine.begin() as conn:
|
|
|
+ conn.execute(text("""
|
|
|
+ CREATE TABLE IF NOT EXISTS users (
|
|
|
+ id SERIAL PRIMARY KEY,
|
|
|
+ name VARCHAR(50) UNIQUE NOT NULL,
|
|
|
+ password VARCHAR(255) NOT NULL,
|
|
|
+ usertype VARCHAR(10) NOT NULL,
|
|
|
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
|
+ )
|
|
|
+ """))
|
|
|
+ admin_user = conn.execute(text("SELECT * FROM users WHERE usertype='admin'")).fetchone()
|
|
|
+ if not admin_user:
|
|
|
+ hashed_pwd = bcrypt.hashpw("admin".encode(), bcrypt.gensalt()).decode()
|
|
|
+ conn.execute(
|
|
|
+ text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,'admin')"),
|
|
|
+ {"name": "admin", "password": hashed_pwd}
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+init_user_db()
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 装饰器 ----------------
|
|
|
+def db_operation(func):
|
|
|
+ """
|
|
|
+ 通用数据库操作装饰器
|
|
|
+ - 自动记录请求参数
|
|
|
+ - 记录成功/失败日志
|
|
|
+ - 捕获异常并返回 HTTP 400
|
|
|
+ """
|
|
|
+
|
|
|
+ @wraps(func)
|
|
|
+ def wrapper(*args, **kwargs):
|
|
|
+ func_name = func.__name__
|
|
|
+ logger.info(f"▶️ 调用接口: {func_name}")
|
|
|
+ logger.debug(f" 参数: args={args}, kwargs={kwargs}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ result = func(*args, **kwargs)
|
|
|
+ logger.info(f"✅ 接口成功: {func_name} → {result}")
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 接口异常: {func_name}, 错误: {str(e)}", exc_info=True)
|
|
|
+ raise HTTPException(status_code=400, detail=str(e))
|
|
|
+
|
|
|
+ return wrapper
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 用户管理接口 ----------------
|
|
|
+@router.post("/register")
|
|
|
+def register(req: RegisterRequest, db: Session = Depends(get_db)):
|
|
|
+ existing = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
|
|
|
+ if existing:
|
|
|
+ raise HTTPException(status_code=400, detail="用户名已存在")
|
|
|
+ if req.userType not in ("user", "admin"):
|
|
|
+ raise HTTPException(status_code=400, detail="用户类型不合法")
|
|
|
+ hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
|
|
|
+ db.execute(text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,:usertype)"),
|
|
|
+ {"name": req.name, "password": hashed_pwd, "usertype": req.userType})
|
|
|
+ db.commit()
|
|
|
+ return {"message": "注册成功"}
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/login")
|
|
|
+def login(req: LoginRequest, db: Session = Depends(get_db)):
|
|
|
+ user = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
|
|
|
+ if not user:
|
|
|
+ raise HTTPException(status_code=400, detail="用户不存在")
|
|
|
+ if not bcrypt.checkpw(req.password.encode(), user.password.encode()):
|
|
|
+ raise HTTPException(status_code=400, detail="密码错误")
|
|
|
+ return {"message": "登录成功", "user": {"id": user.id, "name": user.name, "userType": user.usertype}}
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/list_users")
|
|
|
+def list_users(db: Session = Depends(get_db)):
|
|
|
+ users = db.execute(text("SELECT id,name,usertype,created_at FROM users ORDER BY id")).fetchall()
|
|
|
+ return {"users": [{"id": u.id, "name": u.name, "userType": u.usertype, "created_at": u.created_at} for u in users]}
|
|
|
+
|
|
|
+
|
|
|
+# 编辑用户信息
|
|
|
+@router.put("/update_user/{user_id}")
|
|
|
+def update_user(user_id: int, req: UpdateUserRequest, db: Session = Depends(get_db)):
|
|
|
+ user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
|
|
|
+ if not user:
|
|
|
+ raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
+
|
|
|
+ update_data = {}
|
|
|
+ if req.name is not None:
|
|
|
+ update_data["name"] = req.name
|
|
|
+ if req.password is not None:
|
|
|
+ hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
|
|
|
+ update_data["password"] = hashed_pwd
|
|
|
+ if req.userType is not None:
|
|
|
+ update_data["usertype"] = req.userType
|
|
|
+
|
|
|
+ if update_data:
|
|
|
+ set_str = ", ".join([f"{k}=:{k}" for k in update_data.keys()])
|
|
|
+ update_data["id"] = user_id
|
|
|
+ db.execute(text(f"UPDATE users SET {set_str} WHERE id=:id"), update_data)
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ return {"message": "更新成功"}
|
|
|
+
|
|
|
+
|
|
|
+# 删除用户
|
|
|
+@router.delete("/delete_user/{user_id}")
|
|
|
+def delete_user(user_id: int, db: Session = Depends(get_db)):
|
|
|
+ user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
|
|
|
+ if not user:
|
|
|
+ raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
+ db.execute(text("DELETE FROM users WHERE id=:id"), {"id": user_id})
|
|
|
+ db.commit()
|
|
|
+ return {"message": "删除成功"}
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- CRUD 接口 ----------------
|
|
|
+# 获取表所有数据
|
|
|
+@router.get("/table")
|
|
|
+@db_operation
|
|
|
+def get_table(
|
|
|
+ table: str = Query(...),
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 查询指定表的所有数据
|
|
|
+ 参数:
|
|
|
+ table (str): 要查询的表名(必填)
|
|
|
+ 返回:
|
|
|
+ 列表形式的查询结果,每行是一个字典
|
|
|
+ 安全性:
|
|
|
+ 使用 SQLAlchemy 的 identifier_preparer 对表名进行转义,防止 SQL 注入
|
|
|
+ """
|
|
|
+
|
|
|
+ logger.debug(f"[get_table] 查询表: {table}")
|
|
|
+ preparer = inspect(db.bind).dialect.identifier_preparer
|
|
|
+ query = text(f"SELECT * FROM {preparer.quote(table)}")
|
|
|
+ rows = db.execute(query).fetchall()
|
|
|
+
|
|
|
+ # 修复:使用正确的方法将行转换为字典
|
|
|
+ result = []
|
|
|
+ for row in rows:
|
|
|
+ # 方法1: 使用 row._asdict() (如果可用)
|
|
|
+ if hasattr(row, '_asdict'):
|
|
|
+ row_dict = row._asdict()
|
|
|
+ # 方法2: 使用 row._mapping (SQLAlchemy 1.4+)
|
|
|
+ elif hasattr(row, '_mapping'):
|
|
|
+ row_dict = dict(row._mapping)
|
|
|
+ # 方法3: 手动构建字典
|
|
|
+ else:
|
|
|
+ row_dict = {column: getattr(row, column) for column in row.keys()}
|
|
|
+
|
|
|
+ # 修复:将 NaN 值转换为 None
|
|
|
+ for key, value in row_dict.items():
|
|
|
+ if isinstance(value, float) and math.isnan(value):
|
|
|
+ row_dict[key] = None
|
|
|
+
|
|
|
+ result.append(row_dict)
|
|
|
+
|
|
|
+ logger.info(f"[get_table] 查询完成: {table} → {len(result)} 条记录")
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+# 新增数据
|
|
|
+@router.post("/add_item")
|
|
|
+@db_operation
|
|
|
+def add_item(
|
|
|
+ table: str = Query(...),
|
|
|
+ data: dict = None,
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 向指定表插入一条新记录
|
|
|
+ 参数:
|
|
|
+ table (str): 目标表名(必填)
|
|
|
+ data (dict): 要插入的数据(键为字段名,值为对应值)
|
|
|
+ 逻辑:
|
|
|
+ 1. 检查是否提供了数据
|
|
|
+ 2. 检查是否存在重复数据(基于所有字段值完全匹配)
|
|
|
+ 3. 执行插入操作
|
|
|
+ 返回:
|
|
|
+ 插入成功的数据内容
|
|
|
+ """
|
|
|
+
|
|
|
+ logger.info(f"[add_item] 接收到请求 - 表: {table}, 数据: {data}")
|
|
|
+
|
|
|
+ if not data:
|
|
|
+ logger.warning("[add_item] 缺少数据")
|
|
|
+ raise HTTPException(status_code=400, detail="缺少数据")
|
|
|
+
|
|
|
+ cols = list(data.keys())
|
|
|
+ where_clause = " AND ".join([f"{c} = :{c}" for c in cols])
|
|
|
+ logger.debug(f"[add_item] 检查重复数据: 表={table}, 条件={where_clause}, 参数={data}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
|
|
|
+ logger.warning(f"[add_item] 发现重复数据: {data}")
|
|
|
+ raise HTTPException(status_code=400, detail="重复数据")
|
|
|
+
|
|
|
+ query = text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})")
|
|
|
+ db.execute(query, data)
|
|
|
+ db.commit()
|
|
|
+ logger.info(f"[add_item] 成功插入: {data}")
|
|
|
+ return {"inserted": data}
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[add_item] 插入失败 - 表: {table}, 数据: {data}, 错误: {str(e)}", exc_info=True)
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+# 更新数据
|
|
|
+@router.put("/update_item")
|
|
|
+@db_operation
|
|
|
+def update_item(
|
|
|
+ table: str = Query(...),
|
|
|
+ id: int = Query(...),
|
|
|
+ data: dict = None,
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 更新指定 ID 的记录
|
|
|
+ 参数:
|
|
|
+ table (str): 表名(必填)
|
|
|
+ id (int): 要更新的记录 ID(必填)
|
|
|
+ data (dict): 要更新的字段和值
|
|
|
+ 逻辑:
|
|
|
+ 1. 检查是否提供了更新数据
|
|
|
+ 2. 检查更新后的数据是否与其他记录冲突(避免重复)
|
|
|
+ 3. 执行更新
|
|
|
+ 返回:
|
|
|
+ 更新成功的记录 ID
|
|
|
+ """
|
|
|
+ if not data:
|
|
|
+ raise HTTPException(status_code=400, detail="缺少更新数据")
|
|
|
+
|
|
|
+ columns = [f"{k} = :{k}" for k in data.keys()]
|
|
|
+ params = {**data, "id": id}
|
|
|
+
|
|
|
+ # 检查更新后是否与其他记录重复(排除自身)
|
|
|
+ where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
|
|
|
+ if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause} AND id != :id"), params).fetchone():
|
|
|
+ raise HTTPException(status_code=400, detail="重复数据")
|
|
|
+
|
|
|
+ result = db.execute(text(f"UPDATE {table} SET {', '.join(columns)} WHERE id=:id"), params)
|
|
|
+ if result.rowcount == 0:
|
|
|
+ raise HTTPException(status_code=404, detail="未找到目标记录")
|
|
|
+
|
|
|
+ db.commit()
|
|
|
+ return {"updated_id": id}
|
|
|
+
|
|
|
+
|
|
|
+# 删除数据
|
|
|
+@router.delete("/delete_item")
|
|
|
+@db_operation
|
|
|
+def delete_item(
|
|
|
+ table: str = Query(...),
|
|
|
+ id: int = Query(...),
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 删除指定 ID 的记录
|
|
|
+ 参数:
|
|
|
+ table (str): 表名(必填)
|
|
|
+ id (int): 要删除的记录 ID(必填)
|
|
|
+ 返回:
|
|
|
+ 被删除的记录 ID
|
|
|
+ 注意:
|
|
|
+ 如果记录不存在,返回 404 错误
|
|
|
+ """
|
|
|
+ result = db.execute(text(f"DELETE FROM {table} WHERE id=:id"), {"id": id})
|
|
|
+ if result.rowcount == 0:
|
|
|
+ raise HTTPException(status_code=404, detail="未找到目标记录")
|
|
|
+ db.commit()
|
|
|
+ return {"deleted_id": id}
|
|
|
+
|
|
|
+
|
|
|
+# ---------------- 文件导入导出 ----------------
|
|
|
+# 导入数据
|
|
|
+@router.post("/import_data")
|
|
|
+@db_operation
|
|
|
+def import_data(
|
|
|
+ table: str = Query(...),
|
|
|
+ file: UploadFile = File(...),
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 从上传的 Excel 或 CSV 文件导入数据到指定表
|
|
|
+ 支持格式: .xlsx, .csv
|
|
|
+ 逻辑:
|
|
|
+ 1. 读取文件内容
|
|
|
+ 2. 逐行检查是否已存在(避免重复)
|
|
|
+ 3. 插入新数据
|
|
|
+ 返回:
|
|
|
+ 统计信息:成功插入数量、跳过数量(因重复)
|
|
|
+ """
|
|
|
+
|
|
|
+ logger.info(f"[import_data] 开始导入文件 → 表: {table}, 文件: {file.filename}")
|
|
|
+
|
|
|
+ content = file.file.read()
|
|
|
+ try:
|
|
|
+ if file.filename.endswith(".xlsx"):
|
|
|
+ df = pd.read_excel(io.BytesIO(content))
|
|
|
+ logger.debug(f"[import_data] 已读取 Excel,共 {len(df)} 行")
|
|
|
+ elif file.filename.endswith(".csv"):
|
|
|
+ df = pd.read_csv(io.BytesIO(content))
|
|
|
+ logger.debug(f"[import_data] 已读取 CSV,共 {len(df)} 行")
|
|
|
+ else:
|
|
|
+ logger.warning(f"[import_data] 不支持的格式: {file.filename}")
|
|
|
+ raise HTTPException(status_code=400, detail="仅支持 .xlsx / .csv 文件")
|
|
|
+
|
|
|
+ inserted, skipped = 0, 0
|
|
|
+ for idx, row in df.iterrows():
|
|
|
+ data = row.to_dict()
|
|
|
+ where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
|
|
|
+ if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
|
|
|
+ skipped += 1
|
|
|
+ continue
|
|
|
+ cols = list(data.keys())
|
|
|
+ db.execute(
|
|
|
+ text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})"),
|
|
|
+ data
|
|
|
+ )
|
|
|
+ inserted += 1
|
|
|
+
|
|
|
+ db.commit()
|
|
|
+ logger.info(f"[import_data] 导入完成: 插入 {inserted} 条, 跳过 {skipped} 条")
|
|
|
+ return {"inserted": inserted, "skipped": skipped}
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[import_data] 导入失败: {str(e)}", exc_info=True)
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+# 导出数据
|
|
|
+@router.get("/export_data")
|
|
|
+@db_operation
|
|
|
+def export_data(
|
|
|
+ table: str = Query(...),
|
|
|
+ fmt: str = Query("xlsx"),
|
|
|
+ db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 将指定表的数据导出为 Excel 或 CSV 文件
|
|
|
+ 参数:
|
|
|
+ table (str): 要导出的表名(必填)
|
|
|
+ fmt (str): 导出格式,支持 'xlsx' 或 'csv'(默认 xlsx)
|
|
|
+ 返回:
|
|
|
+ StreamingResponse:包含文件流的响应,触发浏览器下载
|
|
|
+ """
|
|
|
+ df = pd.read_sql(text(f"SELECT * FROM {table}"), db.bind)
|
|
|
+ buf = io.BytesIO()
|
|
|
+
|
|
|
+ if fmt == "xlsx":
|
|
|
+ df.to_excel(buf, index=False, engine="openpyxl")
|
|
|
+ filename = f"{table}.xlsx"
|
|
|
+ media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
|
|
+ elif fmt == "csv":
|
|
|
+ df.to_csv(buf, index=False)
|
|
|
+ filename = f"{table}.csv"
|
|
|
+ media_type = "text/csv"
|
|
|
+ else:
|
|
|
+ raise HTTPException(status_code=400, detail="仅支持 xlsx/csv")
|
|
|
+
|
|
|
+ buf.seek(0)
|
|
|
+ return StreamingResponse(
|
|
|
+ buf,
|
|
|
+ headers={"Content-Disposition": f"attachment; filename={filename}"},
|
|
|
+ media_type=media_type
|
|
|
+ )
|