admin.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # admin.py
  2. from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query
  3. from fastapi.responses import StreamingResponse
  4. from sqlalchemy.orm import Session
  5. from sqlalchemy import text, inspect
  6. from app.database import get_db, engine
  7. import bcrypt
  8. import pandas as pd
  9. import io
  10. import math # 添加 math 模块导入
  11. from pydantic import BaseModel
  12. from typing import Optional
  13. from functools import wraps
  14. import logging
  15. import sys
  16. import os
  17. from datetime import datetime
  18. router = APIRouter()
  19. # ---------------- 日志设置 ----------------
  20. def setup_logger():
  21. # 创建日志目录
  22. log_dir = "logs"
  23. if not os.path.exists(log_dir):
  24. os.makedirs(log_dir)
  25. logger = logging.getLogger("fastapi_app")
  26. logger.setLevel(logging.DEBUG)
  27. # 防止重复添加 handler
  28. if not logger.handlers:
  29. # 格式器
  30. formatter = logging.Formatter(
  31. fmt="%(asctime)s | %(levelname)-7s | %(funcName)-15s | %(message)s",
  32. datefmt="%Y-%m-%d %H:%M:%S"
  33. )
  34. # 控制台输出(INFO 及以上)
  35. console_handler = logging.StreamHandler(sys.stdout)
  36. console_handler.setLevel(logging.INFO)
  37. console_handler.setFormatter(formatter)
  38. # 文件输出(DEBUG 及以上),按天分文件
  39. log_file = os.path.join(log_dir, f"backend_{datetime.now().strftime('%Y%m%d')}.log")
  40. file_handler = logging.FileHandler(log_file, encoding='utf-8')
  41. file_handler.setLevel(logging.DEBUG)
  42. file_handler.setFormatter(formatter)
  43. # 添加处理器
  44. logger.addHandler(console_handler)
  45. logger.addHandler(file_handler)
  46. return logger
  47. # 全局日志器实例
  48. logger = setup_logger()
  49. # ---------------- 数据模型 ----------------
  50. class RegisterRequest(BaseModel):
  51. name: str
  52. password: str
  53. userType: str = "user"
  54. class LoginRequest(BaseModel):
  55. name: str
  56. password: str
  57. class UpdateUserRequest(BaseModel):
  58. name: Optional[str] = None
  59. password: Optional[str] = None
  60. userType: Optional[str] = None
  61. # ---------------- 初始化用户表 ----------------
  62. def init_user_db():
  63. with engine.begin() as conn:
  64. conn.execute(text("""
  65. CREATE TABLE IF NOT EXISTS users (
  66. id SERIAL PRIMARY KEY,
  67. name VARCHAR(50) UNIQUE NOT NULL,
  68. password VARCHAR(255) NOT NULL,
  69. usertype VARCHAR(10) NOT NULL,
  70. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  71. )
  72. """))
  73. admin_user = conn.execute(text("SELECT * FROM users WHERE usertype='admin'")).fetchone()
  74. if not admin_user:
  75. hashed_pwd = bcrypt.hashpw("admin".encode(), bcrypt.gensalt()).decode()
  76. conn.execute(
  77. text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,'admin')"),
  78. {"name": "admin", "password": hashed_pwd}
  79. )
  80. init_user_db()
  81. # ---------------- 装饰器 ----------------
  82. def db_operation(func):
  83. """
  84. 通用数据库操作装饰器
  85. - 自动记录请求参数
  86. - 记录成功/失败日志
  87. - 捕获异常并返回 HTTP 400
  88. """
  89. @wraps(func)
  90. def wrapper(*args, **kwargs):
  91. func_name = func.__name__
  92. logger.info(f"▶️ 调用接口: {func_name}")
  93. logger.debug(f" 参数: args={args}, kwargs={kwargs}")
  94. try:
  95. result = func(*args, **kwargs)
  96. logger.info(f"✅ 接口成功: {func_name} → {result}")
  97. return result
  98. except Exception as e:
  99. logger.error(f"❌ 接口异常: {func_name}, 错误: {str(e)}", exc_info=True)
  100. raise HTTPException(status_code=400, detail=str(e))
  101. return wrapper
  102. # ---------------- 用户管理接口 ----------------
  103. @router.post("/register")
  104. def register(req: RegisterRequest, db: Session = Depends(get_db)):
  105. existing = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
  106. if existing:
  107. raise HTTPException(status_code=400, detail="用户名已存在")
  108. if req.userType not in ("user", "admin"):
  109. raise HTTPException(status_code=400, detail="用户类型不合法")
  110. hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
  111. db.execute(text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,:usertype)"),
  112. {"name": req.name, "password": hashed_pwd, "usertype": req.userType})
  113. db.commit()
  114. return {"message": "注册成功"}
  115. @router.post("/login")
  116. def login(req: LoginRequest, db: Session = Depends(get_db)):
  117. user = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
  118. if not user:
  119. raise HTTPException(status_code=400, detail="用户不存在")
  120. if not bcrypt.checkpw(req.password.encode(), user.password.encode()):
  121. raise HTTPException(status_code=400, detail="密码错误")
  122. return {"message": "登录成功", "user": {"id": user.id, "name": user.name, "userType": user.usertype}}
  123. @router.get("/list_users")
  124. def list_users(db: Session = Depends(get_db)):
  125. users = db.execute(text("SELECT id,name,usertype,created_at FROM users ORDER BY id")).fetchall()
  126. return {"users": [{"id": u.id, "name": u.name, "userType": u.usertype, "created_at": u.created_at} for u in users]}
  127. # 编辑用户信息
  128. @router.put("/update_user/{user_id}")
  129. def update_user(user_id: int, req: UpdateUserRequest, db: Session = Depends(get_db)):
  130. user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
  131. if not user:
  132. raise HTTPException(status_code=404, detail="用户不存在")
  133. update_data = {}
  134. if req.name is not None:
  135. update_data["name"] = req.name
  136. if req.password is not None:
  137. hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
  138. update_data["password"] = hashed_pwd
  139. if req.userType is not None:
  140. update_data["usertype"] = req.userType
  141. if update_data:
  142. set_str = ", ".join([f"{k}=:{k}" for k in update_data.keys()])
  143. update_data["id"] = user_id
  144. db.execute(text(f"UPDATE users SET {set_str} WHERE id=:id"), update_data)
  145. db.commit()
  146. return {"message": "更新成功"}
  147. # 删除用户
  148. @router.delete("/delete_user/{user_id}")
  149. def delete_user(user_id: int, db: Session = Depends(get_db)):
  150. user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
  151. if not user:
  152. raise HTTPException(status_code=404, detail="用户不存在")
  153. db.execute(text("DELETE FROM users WHERE id=:id"), {"id": user_id})
  154. db.commit()
  155. return {"message": "删除成功"}
  156. # ---------------- CRUD 接口 ----------------
  157. # 获取表所有数据
  158. @router.get("/table")
  159. @db_operation
  160. def get_table(
  161. table: str = Query(...),
  162. db: Session = Depends(get_db)
  163. ):
  164. """
  165. 查询指定表的所有数据
  166. 参数:
  167. table (str): 要查询的表名(必填)
  168. 返回:
  169. 列表形式的查询结果,每行是一个字典
  170. 安全性:
  171. 使用 SQLAlchemy 的 identifier_preparer 对表名进行转义,防止 SQL 注入
  172. """
  173. logger.debug(f"[get_table] 查询表: {table}")
  174. preparer = inspect(db.bind).dialect.identifier_preparer
  175. query = text(f"SELECT * FROM {preparer.quote(table)}")
  176. rows = db.execute(query).fetchall()
  177. # 修复:使用正确的方法将行转换为字典
  178. result = []
  179. for row in rows:
  180. # 方法1: 使用 row._asdict() (如果可用)
  181. if hasattr(row, '_asdict'):
  182. row_dict = row._asdict()
  183. # 方法2: 使用 row._mapping (SQLAlchemy 1.4+)
  184. elif hasattr(row, '_mapping'):
  185. row_dict = dict(row._mapping)
  186. # 方法3: 手动构建字典
  187. else:
  188. row_dict = {column: getattr(row, column) for column in row.keys()}
  189. # 修复:将 NaN 值转换为 None
  190. for key, value in row_dict.items():
  191. if isinstance(value, float) and math.isnan(value):
  192. row_dict[key] = None
  193. result.append(row_dict)
  194. logger.info(f"[get_table] 查询完成: {table} → {len(result)} 条记录")
  195. return result
  196. # 新增数据
  197. @router.post("/add_item")
  198. @db_operation
  199. def add_item(
  200. table: str = Query(...),
  201. data: dict = None,
  202. db: Session = Depends(get_db)
  203. ):
  204. """
  205. 向指定表插入一条新记录
  206. 参数:
  207. table (str): 目标表名(必填)
  208. data (dict): 要插入的数据(键为字段名,值为对应值)
  209. 逻辑:
  210. 1. 检查是否提供了数据
  211. 2. 检查是否存在重复数据(基于所有字段值完全匹配)
  212. 3. 执行插入操作
  213. 返回:
  214. 插入成功的数据内容
  215. """
  216. logger.info(f"[add_item] 接收到请求 - 表: {table}, 数据: {data}")
  217. if not data:
  218. logger.warning("[add_item] 缺少数据")
  219. raise HTTPException(status_code=400, detail="缺少数据")
  220. cols = list(data.keys())
  221. where_clause = " AND ".join([f"{c} = :{c}" for c in cols])
  222. logger.debug(f"[add_item] 检查重复数据: 表={table}, 条件={where_clause}, 参数={data}")
  223. try:
  224. if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
  225. logger.warning(f"[add_item] 发现重复数据: {data}")
  226. raise HTTPException(status_code=400, detail="重复数据")
  227. query = text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})")
  228. db.execute(query, data)
  229. db.commit()
  230. logger.info(f"[add_item] 成功插入: {data}")
  231. return {"inserted": data}
  232. except Exception as e:
  233. logger.error(f"[add_item] 插入失败 - 表: {table}, 数据: {data}, 错误: {str(e)}", exc_info=True)
  234. raise
  235. # 更新数据
  236. @router.put("/update_item")
  237. @db_operation
  238. def update_item(
  239. table: str = Query(...),
  240. id: int = Query(...),
  241. data: dict = None,
  242. db: Session = Depends(get_db)
  243. ):
  244. """
  245. 更新指定 ID 的记录
  246. 参数:
  247. table (str): 表名(必填)
  248. id (int): 要更新的记录 ID(必填)
  249. data (dict): 要更新的字段和值
  250. 逻辑:
  251. 1. 检查是否提供了更新数据
  252. 2. 检查更新后的数据是否与其他记录冲突(避免重复)
  253. 3. 执行更新
  254. 返回:
  255. 更新成功的记录 ID
  256. """
  257. if not data:
  258. raise HTTPException(status_code=400, detail="缺少更新数据")
  259. columns = [f"{k} = :{k}" for k in data.keys()]
  260. params = {**data, "id": id}
  261. # 检查更新后是否与其他记录重复(排除自身)
  262. where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
  263. if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause} AND id != :id"), params).fetchone():
  264. raise HTTPException(status_code=400, detail="重复数据")
  265. result = db.execute(text(f"UPDATE {table} SET {', '.join(columns)} WHERE id=:id"), params)
  266. if result.rowcount == 0:
  267. raise HTTPException(status_code=404, detail="未找到目标记录")
  268. db.commit()
  269. return {"updated_id": id}
  270. # 删除数据
  271. @router.delete("/delete_item")
  272. @db_operation
  273. def delete_item(
  274. table: str = Query(...),
  275. id: int = Query(...),
  276. db: Session = Depends(get_db)
  277. ):
  278. """
  279. 删除指定 ID 的记录
  280. 参数:
  281. table (str): 表名(必填)
  282. id (int): 要删除的记录 ID(必填)
  283. 返回:
  284. 被删除的记录 ID
  285. 注意:
  286. 如果记录不存在,返回 404 错误
  287. """
  288. result = db.execute(text(f"DELETE FROM {table} WHERE id=:id"), {"id": id})
  289. if result.rowcount == 0:
  290. raise HTTPException(status_code=404, detail="未找到目标记录")
  291. db.commit()
  292. return {"deleted_id": id}
  293. # ---------------- 文件导入导出 ----------------
  294. # 导入数据
  295. @router.post("/import_data")
  296. @db_operation
  297. def import_data(
  298. table: str = Query(...),
  299. file: UploadFile = File(...),
  300. db: Session = Depends(get_db)
  301. ):
  302. """
  303. 从上传的 Excel 或 CSV 文件导入数据到指定表
  304. 支持格式: .xlsx, .csv
  305. 逻辑:
  306. 1. 读取文件内容
  307. 2. 逐行检查是否已存在(避免重复)
  308. 3. 插入新数据
  309. 返回:
  310. 统计信息:成功插入数量、跳过数量(因重复)
  311. """
  312. logger.info(f"[import_data] 开始导入文件 → 表: {table}, 文件: {file.filename}")
  313. content = file.file.read()
  314. try:
  315. if file.filename.endswith(".xlsx"):
  316. df = pd.read_excel(io.BytesIO(content))
  317. logger.debug(f"[import_data] 已读取 Excel,共 {len(df)} 行")
  318. elif file.filename.endswith(".csv"):
  319. df = pd.read_csv(io.BytesIO(content))
  320. logger.debug(f"[import_data] 已读取 CSV,共 {len(df)} 行")
  321. else:
  322. logger.warning(f"[import_data] 不支持的格式: {file.filename}")
  323. raise HTTPException(status_code=400, detail="仅支持 .xlsx / .csv 文件")
  324. inserted, skipped = 0, 0
  325. for idx, row in df.iterrows():
  326. data = row.to_dict()
  327. where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
  328. if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
  329. skipped += 1
  330. continue
  331. cols = list(data.keys())
  332. db.execute(
  333. text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})"),
  334. data
  335. )
  336. inserted += 1
  337. db.commit()
  338. logger.info(f"[import_data] 导入完成: 插入 {inserted} 条, 跳过 {skipped} 条")
  339. return {"inserted": inserted, "skipped": skipped}
  340. except Exception as e:
  341. logger.error(f"[import_data] 导入失败: {str(e)}", exc_info=True)
  342. raise
  343. # 导出数据
  344. @router.get("/export_data")
  345. @db_operation
  346. def export_data(
  347. table: str = Query(...),
  348. fmt: str = Query("xlsx"),
  349. db: Session = Depends(get_db)
  350. ):
  351. """
  352. 将指定表的数据导出为 Excel 或 CSV 文件
  353. 参数:
  354. table (str): 要导出的表名(必填)
  355. fmt (str): 导出格式,支持 'xlsx' 或 'csv'(默认 xlsx)
  356. 返回:
  357. StreamingResponse:包含文件流的响应,触发浏览器下载
  358. """
  359. df = pd.read_sql(text(f"SELECT * FROM {table}"), db.bind)
  360. buf = io.BytesIO()
  361. if fmt == "xlsx":
  362. df.to_excel(buf, index=False, engine="openpyxl")
  363. filename = f"{table}.xlsx"
  364. media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  365. elif fmt == "csv":
  366. df.to_csv(buf, index=False)
  367. filename = f"{table}.csv"
  368. media_type = "text/csv"
  369. else:
  370. raise HTTPException(status_code=400, detail="仅支持 xlsx/csv")
  371. buf.seek(0)
  372. return StreamingResponse(
  373. buf,
  374. headers={"Content-Disposition": f"attachment; filename={filename}"},
  375. media_type=media_type
  376. )