# admin.py from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from sqlalchemy import text, inspect from app.database import get_db, engine import bcrypt import pandas as pd import io import math # 添加 math 模块导入 from pydantic import BaseModel from typing import Optional from functools import wraps import logging import sys import os from datetime import datetime router = APIRouter() # ---------------- 日志设置 ---------------- def setup_logger(): # 创建日志目录 log_dir = "logs" if not os.path.exists(log_dir): os.makedirs(log_dir) logger = logging.getLogger("fastapi_app") logger.setLevel(logging.DEBUG) # 防止重复添加 handler if not logger.handlers: # 格式器 formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)-7s | %(funcName)-15s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) # 控制台输出(INFO 及以上) console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) console_handler.setFormatter(formatter) # 文件输出(DEBUG 及以上),按天分文件 log_file = os.path.join(log_dir, f"backend_{datetime.now().strftime('%Y%m%d')}.log") file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) # 添加处理器 logger.addHandler(console_handler) logger.addHandler(file_handler) return logger # 全局日志器实例 logger = setup_logger() # ---------------- 数据模型 ---------------- class RegisterRequest(BaseModel): name: str password: str userType: str = "user" class LoginRequest(BaseModel): name: str password: str class UpdateUserRequest(BaseModel): name: Optional[str] = None password: Optional[str] = None userType: Optional[str] = None # ---------------- 初始化用户表 ---------------- def init_user_db(): with engine.begin() as conn: conn.execute(text(""" CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, name VARCHAR(50) UNIQUE NOT NULL, password VARCHAR(255) NOT NULL, usertype VARCHAR(10) NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """)) admin_user = conn.execute(text("SELECT * FROM users WHERE usertype='admin'")).fetchone() if not admin_user: hashed_pwd = bcrypt.hashpw("admin".encode(), bcrypt.gensalt()).decode() conn.execute( text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,'admin')"), {"name": "admin", "password": hashed_pwd} ) init_user_db() # ---------------- 装饰器 ---------------- def db_operation(func): """ 通用数据库操作装饰器 - 自动记录请求参数 - 记录成功/失败日志 - 捕获异常并返回 HTTP 400 """ @wraps(func) def wrapper(*args, **kwargs): func_name = func.__name__ logger.info(f"▶️ 调用接口: {func_name}") logger.debug(f" 参数: args={args}, kwargs={kwargs}") try: result = func(*args, **kwargs) logger.info(f"✅ 接口成功: {func_name} → {result}") return result except Exception as e: logger.error(f"❌ 接口异常: {func_name}, 错误: {str(e)}", exc_info=True) raise HTTPException(status_code=400, detail=str(e)) return wrapper # ---------------- 用户管理接口 ---------------- @router.post("/register") def register(req: RegisterRequest, db: Session = Depends(get_db)): existing = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone() if existing: raise HTTPException(status_code=400, detail="用户名已存在") if req.userType not in ("user", "admin"): raise HTTPException(status_code=400, detail="用户类型不合法") hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode() db.execute(text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,:usertype)"), {"name": req.name, "password": hashed_pwd, "usertype": req.userType}) db.commit() return {"message": "注册成功"} @router.post("/login") def login(req: LoginRequest, db: Session = Depends(get_db)): user = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone() if not user: raise HTTPException(status_code=400, detail="用户不存在") if not bcrypt.checkpw(req.password.encode(), user.password.encode()): raise HTTPException(status_code=400, detail="密码错误") return {"message": "登录成功", "user": {"id": user.id, "name": user.name, "userType": user.usertype}} @router.get("/list_users") def list_users(db: Session = Depends(get_db)): users = db.execute(text("SELECT id,name,usertype,created_at FROM users ORDER BY id")).fetchall() return {"users": [{"id": u.id, "name": u.name, "userType": u.usertype, "created_at": u.created_at} for u in users]} # 编辑用户信息 @router.put("/update_user/{user_id}") def update_user(user_id: int, req: UpdateUserRequest, db: Session = Depends(get_db)): user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone() if not user: raise HTTPException(status_code=404, detail="用户不存在") update_data = {} if req.name is not None: update_data["name"] = req.name if req.password is not None: hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode() update_data["password"] = hashed_pwd if req.userType is not None: update_data["usertype"] = req.userType if update_data: set_str = ", ".join([f"{k}=:{k}" for k in update_data.keys()]) update_data["id"] = user_id db.execute(text(f"UPDATE users SET {set_str} WHERE id=:id"), update_data) db.commit() return {"message": "更新成功"} # 删除用户 @router.delete("/delete_user/{user_id}") def delete_user(user_id: int, db: Session = Depends(get_db)): user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone() if not user: raise HTTPException(status_code=404, detail="用户不存在") db.execute(text("DELETE FROM users WHERE id=:id"), {"id": user_id}) db.commit() return {"message": "删除成功"} # ---------------- CRUD 接口 ---------------- # 获取表所有数据 @router.get("/table") @db_operation def get_table( table: str = Query(...), db: Session = Depends(get_db) ): """ 查询指定表的所有数据 参数: table (str): 要查询的表名(必填) 返回: 列表形式的查询结果,每行是一个字典 安全性: 使用 SQLAlchemy 的 identifier_preparer 对表名进行转义,防止 SQL 注入 """ logger.debug(f"[get_table] 查询表: {table}") preparer = inspect(db.bind).dialect.identifier_preparer query = text(f"SELECT * FROM {preparer.quote(table)}") rows = db.execute(query).fetchall() # 修复:使用正确的方法将行转换为字典 result = [] for row in rows: # 方法1: 使用 row._asdict() (如果可用) if hasattr(row, '_asdict'): row_dict = row._asdict() # 方法2: 使用 row._mapping (SQLAlchemy 1.4+) elif hasattr(row, '_mapping'): row_dict = dict(row._mapping) # 方法3: 手动构建字典 else: row_dict = {column: getattr(row, column) for column in row.keys()} # 修复:将 NaN 值转换为 None for key, value in row_dict.items(): if isinstance(value, float) and math.isnan(value): row_dict[key] = None result.append(row_dict) logger.info(f"[get_table] 查询完成: {table} → {len(result)} 条记录") return result # 新增数据 @router.post("/add_item") @db_operation def add_item( table: str = Query(...), data: dict = None, db: Session = Depends(get_db) ): """ 向指定表插入一条新记录 参数: table (str): 目标表名(必填) data (dict): 要插入的数据(键为字段名,值为对应值) 逻辑: 1. 检查是否提供了数据 2. 检查是否存在重复数据(基于所有字段值完全匹配) 3. 执行插入操作 返回: 插入成功的数据内容 """ logger.info(f"[add_item] 接收到请求 - 表: {table}, 数据: {data}") if not data: logger.warning("[add_item] 缺少数据") raise HTTPException(status_code=400, detail="缺少数据") cols = list(data.keys()) where_clause = " AND ".join([f"{c} = :{c}" for c in cols]) logger.debug(f"[add_item] 检查重复数据: 表={table}, 条件={where_clause}, 参数={data}") try: if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone(): logger.warning(f"[add_item] 发现重复数据: {data}") raise HTTPException(status_code=400, detail="重复数据") query = text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})") db.execute(query, data) db.commit() logger.info(f"[add_item] 成功插入: {data}") return {"inserted": data} except Exception as e: logger.error(f"[add_item] 插入失败 - 表: {table}, 数据: {data}, 错误: {str(e)}", exc_info=True) raise # 更新数据 @router.put("/update_item") @db_operation def update_item( table: str = Query(...), id: int = Query(...), data: dict = None, db: Session = Depends(get_db) ): """ 更新指定 ID 的记录 参数: table (str): 表名(必填) id (int): 要更新的记录 ID(必填) data (dict): 要更新的字段和值 逻辑: 1. 检查是否提供了更新数据 2. 检查更新后的数据是否与其他记录冲突(避免重复) 3. 执行更新 返回: 更新成功的记录 ID """ if not data: raise HTTPException(status_code=400, detail="缺少更新数据") columns = [f"{k} = :{k}" for k in data.keys()] params = {**data, "id": id} # 检查更新后是否与其他记录重复(排除自身) where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()]) if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause} AND id != :id"), params).fetchone(): raise HTTPException(status_code=400, detail="重复数据") result = db.execute(text(f"UPDATE {table} SET {', '.join(columns)} WHERE id=:id"), params) if result.rowcount == 0: raise HTTPException(status_code=404, detail="未找到目标记录") db.commit() return {"updated_id": id} # 删除数据 @router.delete("/delete_item") @db_operation def delete_item( table: str = Query(...), id: int = Query(...), db: Session = Depends(get_db) ): """ 删除指定 ID 的记录 参数: table (str): 表名(必填) id (int): 要删除的记录 ID(必填) 返回: 被删除的记录 ID 注意: 如果记录不存在,返回 404 错误 """ result = db.execute(text(f"DELETE FROM {table} WHERE id=:id"), {"id": id}) if result.rowcount == 0: raise HTTPException(status_code=404, detail="未找到目标记录") db.commit() return {"deleted_id": id} # ---------------- 文件导入导出 ---------------- # 导入数据 @router.post("/import_data") @db_operation def import_data( table: str = Query(...), file: UploadFile = File(...), db: Session = Depends(get_db) ): """ 从上传的 Excel 或 CSV 文件导入数据到指定表 支持格式: .xlsx, .csv 逻辑: 1. 读取文件内容 2. 逐行检查是否已存在(避免重复) 3. 插入新数据 返回: 统计信息:成功插入数量、跳过数量(因重复) """ logger.info(f"[import_data] 开始导入文件 → 表: {table}, 文件: {file.filename}") content = file.file.read() try: if file.filename.endswith(".xlsx"): df = pd.read_excel(io.BytesIO(content)) logger.debug(f"[import_data] 已读取 Excel,共 {len(df)} 行") elif file.filename.endswith(".csv"): df = pd.read_csv(io.BytesIO(content)) logger.debug(f"[import_data] 已读取 CSV,共 {len(df)} 行") else: logger.warning(f"[import_data] 不支持的格式: {file.filename}") raise HTTPException(status_code=400, detail="仅支持 .xlsx / .csv 文件") inserted, skipped = 0, 0 for idx, row in df.iterrows(): data = row.to_dict() where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()]) if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone(): skipped += 1 continue cols = list(data.keys()) db.execute( text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})"), data ) inserted += 1 db.commit() logger.info(f"[import_data] 导入完成: 插入 {inserted} 条, 跳过 {skipped} 条") return {"inserted": inserted, "skipped": skipped} except Exception as e: logger.error(f"[import_data] 导入失败: {str(e)}", exc_info=True) raise # 导出数据 @router.get("/export_data") @db_operation def export_data( table: str = Query(...), fmt: str = Query("xlsx"), db: Session = Depends(get_db) ): """ 将指定表的数据导出为 Excel 或 CSV 文件 参数: table (str): 要导出的表名(必填) fmt (str): 导出格式,支持 'xlsx' 或 'csv'(默认 xlsx) 返回: StreamingResponse:包含文件流的响应,触发浏览器下载 """ df = pd.read_sql(text(f"SELECT * FROM {table}"), db.bind) buf = io.BytesIO() if fmt == "xlsx": df.to_excel(buf, index=False, engine="openpyxl") filename = f"{table}.xlsx" media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" elif fmt == "csv": df.to_csv(buf, index=False) filename = f"{table}.csv" media_type = "text/csv" else: raise HTTPException(status_code=400, detail="仅支持 xlsx/csv") buf.seek(0) return StreamingResponse( buf, headers={"Content-Disposition": f"attachment; filename={filename}"}, media_type=media_type )