Browse Source

添加获取数据库数据接口

yes-yes-yes-k 6 days ago
parent
commit
5b8b8b66d1
3 changed files with 193 additions and 35 deletions
  1. 92 14
      api/app/__init__.py
  2. 94 6
      api/app/utils.py
  3. 7 15
      api/run.py

+ 92 - 14
api/app/__init__.py

@@ -1,33 +1,111 @@
+from flask import Flask, jsonify, request
 import os
 
 from flask import Flask
 from flask_cors import CORS
-from . import config
 from flask_sqlalchemy import SQLAlchemy
 from flask_migrate import Migrate
 import logging
+import pandas as pd
+from datetime import datetime
+import json
+from .utils import get_current_data, get_dataset_by_id, get_table_data ,get_table_metal_averages
 
-# 创建 SQLAlchemy 全局实例
+# 创建SQLAlchemy实例,确保其他模块可以导入
 db = SQLAlchemy()
 
-# 创建并配置 Flask 应用
+# 导入配置和工具函数(在db定义之后)
+from . import config
+from .utils import get_current_data, get_dataset_by_id
+
 def create_app():
     app = Flask(__name__)
-    CORS(app)
-    # 进行初始配置,加载配置文件等
+    
+    # 配置应用
+    app.config['SECRET_KEY'] = 'abcdef1234567890'
     app.config.from_object(config.Config)
     app.logger.setLevel(logging.DEBUG)
-    # 初始化 SQLAlchemy
+    
+    # 初始化CORS
+    CORS(app)
+    
+    # 初始化数据库
     db.init_app(app)
-
-    # 初始化 Flask-Migrate
+    
+    # 初始化迁移工具
     migrate = Migrate(app, db)
-
-    # 导入路由
-    from . import routes
-    from . import frontend
+    
+    # 注册蓝图
+    from . import routes, frontend
     app.register_blueprint(routes.bp)
     app.register_blueprint(frontend.bp)
-    app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'dev-secret-key-here')
-    
+
+    register_api_routes(app)
     return app
+
+def safe_data_conversion(df):
+    """安全转换DataFrame为可序列化的字典列表,处理特殊数据类型"""
+    # 处理日期时间类型
+    for col in df.columns:
+        if pd.api.types.is_datetime64_any_dtype(df[col]):
+            df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
+    
+    # 处理空值
+    df = df.fillna('')
+    
+    # 转换为字典列表
+    data_list = df.to_dict(orient='records')
+    
+    # 处理可能的嵌套结构
+    def serialize_item(item):
+        if isinstance(item, dict):
+            return {k: serialize_item(v) for k, v in item.items()}
+        elif isinstance(item, list):
+            return [serialize_item(v) for v in item]
+        elif isinstance(item, datetime):
+            return item.strftime('%Y-%m-%d %H:%M:%S')
+        else:
+            return item
+    
+    return [serialize_item(item) for item in data_list]
+
+def register_api_routes(app):
+    # 新增:通用表查询接口
+    @app.route('/api/table-data', methods=['GET'])
+    def api_table_data():
+        # 1. 获取前端传的表名
+        table_name = request.args.get('table_name')
+        if not table_name:
+            return jsonify({"error": "请传入 table_name 参数(如 ?table_name=dataset_35)"}), 400
+
+        try:
+            # 2. 调用工具函数查询数据
+            df = get_table_data(db.session, table_name)
+            # 3. 安全转换数据(复用之前的 safe_data_conversion 函数)
+            data = safe_data_conversion(df)
+            return jsonify({"success": True, "data": data})
+        except ValueError as e:  # 非法表名
+            return jsonify({"success": False, "error": str(e)}), 400
+        except Exception as e:  # 其他错误
+            app.logger.error(f"查询表 {table_name} 失败: {str(e)}", exc_info=True)
+            return jsonify({"success": False, "error": str(e)}), 500
+        
+        # 新增接口:查询金属平均值(/api/table-averages)
+    @app.route('/api/table-averages', methods=['GET'])
+    def api_table_averages():
+        # 1. 获取参数
+        table_name = request.args.get('table_name')
+        if not table_name:
+            return jsonify({"error": "缺少参数:table_name"}), 400
+
+        try:
+            # 2. 调用工具函数计算平均值(复用 db.session)
+            averages = get_table_metal_averages(db.session, table_name)
+            # 3. 返回结果
+            return jsonify({"success": True, "averages": averages})
+        except ValueError as e:  # 表名校验失败(预期错误)
+            return jsonify({"success": False, "error": str(e)}), 400
+        except Exception as e:  # 其他运行时错误(如数据库连接失败)
+            app.logger.error(f"计算 {table_name} 平均值失败: {str(e)}", exc_info=True)
+            return jsonify({"success": False, "error": str(e)}), 500
+   

+ 94 - 6
api/app/utils.py

@@ -7,9 +7,84 @@ import pandas as pd
 from .database_models import CurrentReduce, CurrentReflux
 from sqlalchemy.schema import MetaData, Table
 import geopandas as gpd
+import re
+import pandas as pd
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy import text
 
 Base = declarative_base()
 
+def get_table_data(session, table_name):
+    """
+    通用函数:根据表名查询数据(支持动态表)
+    :param table_name: 数据库表名(如 current_reduce、dataset_35)
+    :return: DataFrame 数据
+    """
+    # ===================== 安全校验:防止SQL注入 =====================
+    # 允许的表名规则(根据你的数据库表结构定制):
+    # - current_reduce / current_reflux
+    # - dataset_数字(如 dataset_35、dataset_66)
+    # - 其他固定表(Datasets、ModelParameters 等)
+    allowed_pattern = re.compile(
+        r'^(current_(reduce|reflux)|dataset_\d+|Datasets|ModelParameters|Models|software_intro|sqlite_sequence|users)$'
+    )
+    if not allowed_pattern.fullmatch(table_name):
+        raise ValueError(f"非法表名:{table_name},仅允许以下格式:\n"
+                         "- current_reduce / current_reflux\n"
+                         "- dataset_数字(如 dataset_35)\n"
+                         "- 系统表(Datasets、ModelParameters 等)")
+
+    # ===================== 执行数据库查询 =====================
+    try:
+        # 直接执行SQL(带字段名映射,确保转字典成功)
+        sql = text(f"SELECT * FROM {table_name};")
+        result = session.execute(sql).mappings().all()  # 返回带字段名的行对象
+        # 转换为DataFrame
+        dataframe = pd.DataFrame([dict(row) for row in result])
+        return dataframe
+    except SQLAlchemyError as e:
+        raise Exception(f"数据库查询失败(表 {table_name}):{str(e)}")
+    except Exception as e:
+        raise Exception(f"数据转换失败(表 {table_name}):{str(e)}")
+
+
+def get_table_metal_averages(session, table_name):
+    """
+    计算表中所有数值型列(金属指标)的平均值
+    :param table_name: 数据库表名
+    :return: 字典,键为金属指标名,值为平均值
+    """
+    # 复用现有安全校验(防止SQL注入)
+    allowed_pattern = re.compile(
+        r'^(current_(reduce|reflux)|dataset_\d+|Datasets|ModelParameters|Models|software_intro|sqlite_sequence|users)$'
+    )
+    if not allowed_pattern.fullmatch(table_name):
+        raise ValueError(f"非法表名:{table_name},仅允许特定格式表名")
+
+    try:
+        # 1. 获取原始数据(复用现有查询逻辑)
+        df = get_table_data(session, table_name)
+        
+        # 2. 筛选数值型列(排除ID、字符串等非金属指标列)
+        numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns.tolist()
+        # 排除主键ID列(如果存在)
+        numeric_cols = [col for col in numeric_cols if col.lower() != 'id']
+        
+        if not numeric_cols:
+            raise ValueError(f"表 {table_name} 中未找到数值型金属指标列")
+        
+        # 3. 计算平均值(保留2位小数)
+        averages = df[numeric_cols].mean().to_dict()
+        return averages
+        
+    except SQLAlchemyError as e:
+        raise Exception(f"数据库查询失败(表 {table_name}):{str(e)}")
+    except Exception as e:
+        raise Exception(f"平均值计算失败(表 {table_name}):{str(e)}")
+
+
+
+
 def create_dynamic_table(dataset_id, columns):
     """动态创建数据表"""
     # 动态构建列
@@ -158,12 +233,25 @@ def get_current_data(session, data_type):
     else:
         raise ValueError("Invalid data type provided. Choose 'reduce' or 'reflux'.")
 
-    # 从数据库中查询所有记录
-    result = session.execute(select(model))
-
-    # 将结果转换为DataFrame
-    dataframe = pd.DataFrame([dict(row) for row in result])
-    return dataframe
+    try:
+        # 从数据库中查询所有记录并获取标量结果(模型实例)
+        result = session.execute(select(model)).scalars().all()
+        
+        # 转换模型实例为字典列表(过滤SQLAlchemy内部属性)
+        data_list = []
+        for item in result:
+            # 使用模型实例的__dict__属性,但排除内部属性
+            item_dict = {key: value for key, value in item.__dict__.items() 
+                         if not key.startswith('_sa_')}
+            data_list.append(item_dict)
+        
+        # 将结果转换为DataFrame
+        dataframe = pd.DataFrame(data_list)
+        return dataframe
+    
+    except Exception as e:
+        # 增加详细错误信息,方便调试
+        raise Exception(f"获取{data_type}数据失败: {str(e)}") from e
 
 def get_dataset_by_id(session, dataset_id):
     # 动态获取表的元数据

+ 7 - 15
api/run.py

@@ -1,23 +1,15 @@
-from flask import request
+# run.py(项目根目录)
+from flask import request, jsonify  
+from app import create_app  # 从app包导入工厂函数  
+from app.utils import get_table_metal_averages  # 从app包的utils模块导入函数  
 from flask_cors import CORS  # 导入CORS
-
-from app import create_app
 import os
 
-# 创建 Flask 应用
-app = create_app()
-
-# 配置CORS
-CORS(app, resources={r"/*": {"origins": "*"}})  # 允许所有域名,生产环境应限制为前端域名
-
-# 使用 HTTPS
-context = ('ssl/cert.crt', 'ssl/cert.key')
-
-
-
+# 创建Flask应用  
+app = create_app()  
 
 
 # 启动服务器
 if __name__ == '__main__':
     app.run(host="0.0.0.0", port=5000, debug=True)  # 注意:这里添加了ssl_context参数来启用HTTPS
-    # app.run(debug=True)
+    # app.run(debug=True)