yangtaodemon 2 днів тому
батько
коміт
e6bb1ef881
3 змінених файлів з 452 додано та 2 видалено
  1. 448 0
      app/api/admin.py
  2. 3 1
      app/main.py
  3. 1 1
      config.env

+ 448 - 0
app/api/admin.py

@@ -0,0 +1,448 @@
+# admin.py
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query
+from fastapi.responses import StreamingResponse
+from sqlalchemy.orm import Session
+from sqlalchemy import text, inspect
+from app.database import get_db, engine
+import bcrypt
+import pandas as pd
+import io
+import math  # 添加 math 模块导入
+from pydantic import BaseModel
+from typing import Optional
+from functools import wraps
+import logging
+import sys
+import os
+from datetime import datetime
+
+router = APIRouter()
+
+
+# ---------------- 日志设置 ----------------
+def setup_logger():
+    # 创建日志目录
+    log_dir = "logs"
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+
+    logger = logging.getLogger("fastapi_app")
+    logger.setLevel(logging.DEBUG)
+
+    # 防止重复添加 handler
+    if not logger.handlers:
+        # 格式器
+        formatter = logging.Formatter(
+            fmt="%(asctime)s | %(levelname)-7s | %(funcName)-15s | %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S"
+        )
+
+        # 控制台输出(INFO 及以上)
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setLevel(logging.INFO)
+        console_handler.setFormatter(formatter)
+
+        # 文件输出(DEBUG 及以上),按天分文件
+        log_file = os.path.join(log_dir, f"backend_{datetime.now().strftime('%Y%m%d')}.log")
+        file_handler = logging.FileHandler(log_file, encoding='utf-8')
+        file_handler.setLevel(logging.DEBUG)
+        file_handler.setFormatter(formatter)
+
+        # 添加处理器
+        logger.addHandler(console_handler)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+# 全局日志器实例
+logger = setup_logger()
+
+
+# ---------------- 数据模型 ----------------
+class RegisterRequest(BaseModel):
+    name: str
+    password: str
+    userType: str = "user"
+
+
+class LoginRequest(BaseModel):
+    name: str
+    password: str
+
+
+class UpdateUserRequest(BaseModel):
+    name: Optional[str] = None
+    password: Optional[str] = None
+    userType: Optional[str] = None
+
+
+# ---------------- 初始化用户表 ----------------
+def init_user_db():
+    with engine.begin() as conn:
+        conn.execute(text("""
+        CREATE TABLE IF NOT EXISTS users (
+            id SERIAL PRIMARY KEY,
+            name VARCHAR(50) UNIQUE NOT NULL,
+            password VARCHAR(255) NOT NULL,
+            usertype VARCHAR(10) NOT NULL,
+            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+        )
+        """))
+        admin_user = conn.execute(text("SELECT * FROM users WHERE usertype='admin'")).fetchone()
+        if not admin_user:
+            hashed_pwd = bcrypt.hashpw("admin".encode(), bcrypt.gensalt()).decode()
+            conn.execute(
+                text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,'admin')"),
+                {"name": "admin", "password": hashed_pwd}
+            )
+
+
+init_user_db()
+
+
+# ---------------- 装饰器 ----------------
+def db_operation(func):
+    """
+    通用数据库操作装饰器
+    - 自动记录请求参数
+    - 记录成功/失败日志
+    - 捕获异常并返回 HTTP 400
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        func_name = func.__name__
+        logger.info(f"▶️ 调用接口: {func_name}")
+        logger.debug(f"  参数: args={args}, kwargs={kwargs}")
+
+        try:
+            result = func(*args, **kwargs)
+            logger.info(f"✅ 接口成功: {func_name} → {result}")
+            return result
+        except Exception as e:
+            logger.error(f"❌ 接口异常: {func_name}, 错误: {str(e)}", exc_info=True)
+            raise HTTPException(status_code=400, detail=str(e))
+
+    return wrapper
+
+
+# ---------------- 用户管理接口 ----------------
+@router.post("/register")
+def register(req: RegisterRequest, db: Session = Depends(get_db)):
+    existing = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
+    if existing:
+        raise HTTPException(status_code=400, detail="用户名已存在")
+    if req.userType not in ("user", "admin"):
+        raise HTTPException(status_code=400, detail="用户类型不合法")
+    hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
+    db.execute(text("INSERT INTO users (name,password,usertype) VALUES (:name,:password,:usertype)"),
+               {"name": req.name, "password": hashed_pwd, "usertype": req.userType})
+    db.commit()
+    return {"message": "注册成功"}
+
+
+@router.post("/login")
+def login(req: LoginRequest, db: Session = Depends(get_db)):
+    user = db.execute(text("SELECT * FROM users WHERE name=:name"), {"name": req.name}).fetchone()
+    if not user:
+        raise HTTPException(status_code=400, detail="用户不存在")
+    if not bcrypt.checkpw(req.password.encode(), user.password.encode()):
+        raise HTTPException(status_code=400, detail="密码错误")
+    return {"message": "登录成功", "user": {"id": user.id, "name": user.name, "userType": user.usertype}}
+
+
+@router.get("/list_users")
+def list_users(db: Session = Depends(get_db)):
+    users = db.execute(text("SELECT id,name,usertype,created_at FROM users ORDER BY id")).fetchall()
+    return {"users": [{"id": u.id, "name": u.name, "userType": u.usertype, "created_at": u.created_at} for u in users]}
+
+
+# 编辑用户信息
+@router.put("/update_user/{user_id}")
+def update_user(user_id: int, req: UpdateUserRequest, db: Session = Depends(get_db)):
+    user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
+    if not user:
+        raise HTTPException(status_code=404, detail="用户不存在")
+
+    update_data = {}
+    if req.name is not None:
+        update_data["name"] = req.name
+    if req.password is not None:
+        hashed_pwd = bcrypt.hashpw(req.password.encode(), bcrypt.gensalt()).decode()
+        update_data["password"] = hashed_pwd
+    if req.userType is not None:
+        update_data["usertype"] = req.userType
+
+    if update_data:
+        set_str = ", ".join([f"{k}=:{k}" for k in update_data.keys()])
+        update_data["id"] = user_id
+        db.execute(text(f"UPDATE users SET {set_str} WHERE id=:id"), update_data)
+        db.commit()
+
+    return {"message": "更新成功"}
+
+
+# 删除用户
+@router.delete("/delete_user/{user_id}")
+def delete_user(user_id: int, db: Session = Depends(get_db)):
+    user = db.execute(text("SELECT * FROM users WHERE id=:id"), {"id": user_id}).fetchone()
+    if not user:
+        raise HTTPException(status_code=404, detail="用户不存在")
+    db.execute(text("DELETE FROM users WHERE id=:id"), {"id": user_id})
+    db.commit()
+    return {"message": "删除成功"}
+
+
+# ---------------- CRUD 接口 ----------------
+# 获取表所有数据
+@router.get("/table")
+@db_operation
+def get_table(
+        table: str = Query(...),
+        db: Session = Depends(get_db)
+):
+    """
+    查询指定表的所有数据
+    参数:
+        table (str): 要查询的表名(必填)
+    返回:
+        列表形式的查询结果,每行是一个字典
+    安全性:
+        使用 SQLAlchemy 的 identifier_preparer 对表名进行转义,防止 SQL 注入
+    """
+
+    logger.debug(f"[get_table] 查询表: {table}")
+    preparer = inspect(db.bind).dialect.identifier_preparer
+    query = text(f"SELECT * FROM {preparer.quote(table)}")
+    rows = db.execute(query).fetchall()
+
+    # 修复:使用正确的方法将行转换为字典
+    result = []
+    for row in rows:
+        # 方法1: 使用 row._asdict() (如果可用)
+        if hasattr(row, '_asdict'):
+            row_dict = row._asdict()
+        # 方法2: 使用 row._mapping (SQLAlchemy 1.4+)
+        elif hasattr(row, '_mapping'):
+            row_dict = dict(row._mapping)
+        # 方法3: 手动构建字典
+        else:
+            row_dict = {column: getattr(row, column) for column in row.keys()}
+
+        # 修复:将 NaN 值转换为 None
+        for key, value in row_dict.items():
+            if isinstance(value, float) and math.isnan(value):
+                row_dict[key] = None
+
+        result.append(row_dict)
+
+    logger.info(f"[get_table] 查询完成: {table} → {len(result)} 条记录")
+    return result
+
+
+# 新增数据
+@router.post("/add_item")
+@db_operation
+def add_item(
+        table: str = Query(...),
+        data: dict = None,
+        db: Session = Depends(get_db)
+):
+    """
+    向指定表插入一条新记录
+    参数:
+        table (str): 目标表名(必填)
+        data (dict): 要插入的数据(键为字段名,值为对应值)
+    逻辑:
+        1. 检查是否提供了数据
+        2. 检查是否存在重复数据(基于所有字段值完全匹配)
+        3. 执行插入操作
+    返回:
+        插入成功的数据内容
+    """
+
+    logger.info(f"[add_item] 接收到请求 - 表: {table}, 数据: {data}")
+
+    if not data:
+        logger.warning("[add_item] 缺少数据")
+        raise HTTPException(status_code=400, detail="缺少数据")
+
+    cols = list(data.keys())
+    where_clause = " AND ".join([f"{c} = :{c}" for c in cols])
+    logger.debug(f"[add_item] 检查重复数据: 表={table}, 条件={where_clause}, 参数={data}")
+
+    try:
+        if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
+            logger.warning(f"[add_item] 发现重复数据: {data}")
+            raise HTTPException(status_code=400, detail="重复数据")
+
+        query = text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})")
+        db.execute(query, data)
+        db.commit()
+        logger.info(f"[add_item] 成功插入: {data}")
+        return {"inserted": data}
+    except Exception as e:
+        logger.error(f"[add_item] 插入失败 - 表: {table}, 数据: {data}, 错误: {str(e)}", exc_info=True)
+        raise
+
+
+# 更新数据
+@router.put("/update_item")
+@db_operation
+def update_item(
+        table: str = Query(...),
+        id: int = Query(...),
+        data: dict = None,
+        db: Session = Depends(get_db)
+):
+    """
+    更新指定 ID 的记录
+    参数:
+        table (str): 表名(必填)
+        id (int): 要更新的记录 ID(必填)
+        data (dict): 要更新的字段和值
+    逻辑:
+        1. 检查是否提供了更新数据
+        2. 检查更新后的数据是否与其他记录冲突(避免重复)
+        3. 执行更新
+    返回:
+        更新成功的记录 ID
+    """
+    if not data:
+        raise HTTPException(status_code=400, detail="缺少更新数据")
+
+    columns = [f"{k} = :{k}" for k in data.keys()]
+    params = {**data, "id": id}
+
+    # 检查更新后是否与其他记录重复(排除自身)
+    where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
+    if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause} AND id != :id"), params).fetchone():
+        raise HTTPException(status_code=400, detail="重复数据")
+
+    result = db.execute(text(f"UPDATE {table} SET {', '.join(columns)} WHERE id=:id"), params)
+    if result.rowcount == 0:
+        raise HTTPException(status_code=404, detail="未找到目标记录")
+
+    db.commit()
+    return {"updated_id": id}
+
+
+# 删除数据
+@router.delete("/delete_item")
+@db_operation
+def delete_item(
+        table: str = Query(...),
+        id: int = Query(...),
+        db: Session = Depends(get_db)
+):
+    """
+    删除指定 ID 的记录
+    参数:
+        table (str): 表名(必填)
+        id (int): 要删除的记录 ID(必填)
+    返回:
+        被删除的记录 ID
+    注意:
+        如果记录不存在,返回 404 错误
+    """
+    result = db.execute(text(f"DELETE FROM {table} WHERE id=:id"), {"id": id})
+    if result.rowcount == 0:
+        raise HTTPException(status_code=404, detail="未找到目标记录")
+    db.commit()
+    return {"deleted_id": id}
+
+
+# ---------------- 文件导入导出 ----------------
+# 导入数据
+@router.post("/import_data")
+@db_operation
+def import_data(
+        table: str = Query(...),
+        file: UploadFile = File(...),
+        db: Session = Depends(get_db)
+):
+    """
+    从上传的 Excel 或 CSV 文件导入数据到指定表
+    支持格式: .xlsx, .csv
+    逻辑:
+        1. 读取文件内容
+        2. 逐行检查是否已存在(避免重复)
+        3. 插入新数据
+    返回:
+        统计信息:成功插入数量、跳过数量(因重复)
+    """
+
+    logger.info(f"[import_data] 开始导入文件 → 表: {table}, 文件: {file.filename}")
+
+    content = file.file.read()
+    try:
+        if file.filename.endswith(".xlsx"):
+            df = pd.read_excel(io.BytesIO(content))
+            logger.debug(f"[import_data] 已读取 Excel,共 {len(df)} 行")
+        elif file.filename.endswith(".csv"):
+            df = pd.read_csv(io.BytesIO(content))
+            logger.debug(f"[import_data] 已读取 CSV,共 {len(df)} 行")
+        else:
+            logger.warning(f"[import_data] 不支持的格式: {file.filename}")
+            raise HTTPException(status_code=400, detail="仅支持 .xlsx / .csv 文件")
+
+        inserted, skipped = 0, 0
+        for idx, row in df.iterrows():
+            data = row.to_dict()
+            where_clause = " AND ".join([f"{c} = :{c}" for c in data.keys()])
+            if db.execute(text(f"SELECT id FROM {table} WHERE {where_clause}"), data).fetchone():
+                skipped += 1
+                continue
+            cols = list(data.keys())
+            db.execute(
+                text(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join([':' + c for c in cols])})"),
+                data
+            )
+            inserted += 1
+
+        db.commit()
+        logger.info(f"[import_data] 导入完成: 插入 {inserted} 条, 跳过 {skipped} 条")
+        return {"inserted": inserted, "skipped": skipped}
+
+    except Exception as e:
+        logger.error(f"[import_data] 导入失败: {str(e)}", exc_info=True)
+        raise
+
+
+# 导出数据
+@router.get("/export_data")
+@db_operation
+def export_data(
+        table: str = Query(...),
+        fmt: str = Query("xlsx"),
+        db: Session = Depends(get_db)
+):
+    """
+    将指定表的数据导出为 Excel 或 CSV 文件
+    参数:
+        table (str): 要导出的表名(必填)
+        fmt (str): 导出格式,支持 'xlsx' 或 'csv'(默认 xlsx)
+    返回:
+        StreamingResponse:包含文件流的响应,触发浏览器下载
+    """
+    df = pd.read_sql(text(f"SELECT * FROM {table}"), db.bind)
+    buf = io.BytesIO()
+
+    if fmt == "xlsx":
+        df.to_excel(buf, index=False, engine="openpyxl")
+        filename = f"{table}.xlsx"
+        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+    elif fmt == "csv":
+        df.to_csv(buf, index=False)
+        filename = f"{table}.csv"
+        media_type = "text/csv"
+    else:
+        raise HTTPException(status_code=400, detail="仅支持 xlsx/csv")
+
+    buf.seek(0)
+    return StreamingResponse(
+        buf,
+        headers={"Content-Disposition": f"attachment; filename={filename}"},
+        media_type=media_type
+    )

+ 3 - 1
app/main.py

@@ -2,7 +2,8 @@ import time
 import traceback
 
 from fastapi import FastAPI
-from .api import vector, raster, cd_prediction, unit_grouping, water, agricultural_input, cd_flux_removal, cd_flux
+from .api import (vector, raster, cd_prediction, unit_grouping, water,
+                  agricultural_input, cd_flux_removal, cd_flux,admin)
 from .database import engine, Base
 from fastapi.middleware.cors import CORSMiddleware
 import logging
@@ -67,6 +68,7 @@ app.include_router(agricultural_input.router, prefix="/api/agricultural-input",
 app.include_router(cd_flux_removal.router, prefix="/api/cd-flux-removal", tags=["cd-flux-removal"])
 app.include_router(cd_flux.router, prefix="/api/cd-flux", tags=["cd-flux"])
 app.include_router(error.router, prefix="/api/errors", tags=["errors"])
+app.include_router(admin.router, prefix="/admin", tags=["admin"])
 
 @app.get("/")
 async def root():

+ 1 - 1
config.env

@@ -1,5 +1,5 @@
 DB_HOST=localhost
 DB_PORT=5432
-DB_NAME=soilgd
+DB_NAME=data_db
 DB_USER=postgres
 DB_PASSWORD=scau2025