# utils.py — database, dataset and geostatistics helper utilities
import re
import uuid
from datetime import datetime, timezone

import geopandas as gpd
import pandas as pd
from pykrige import OrdinaryKriging
from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.schema import MetaData, Table

from .database_models import CurrentReduce, CurrentReflux
Base = declarative_base()  # shared declarative base; dynamic dataset models attach to its metadata
  15. def get_table_data(session, table_name):
  16. """
  17. 通用函数:根据表名查询数据(支持动态表)
  18. :param table_name: 数据库表名(如 current_reduce、dataset_35)
  19. :return: DataFrame 数据
  20. """
  21. # ===================== 安全校验:防止SQL注入 =====================
  22. # 允许的表名规则(根据你的数据库表结构定制):
  23. # - current_reduce / current_reflux
  24. # - dataset_数字(如 dataset_35、dataset_66)
  25. # - 其他固定表(Datasets、ModelParameters 等)
  26. allowed_pattern = re.compile(
  27. r'^(current_(reduce|reflux)|dataset_\d+|Datasets|ModelParameters|Models|software_intro|sqlite_sequence|users)$'
  28. )
  29. if not allowed_pattern.fullmatch(table_name):
  30. raise ValueError(f"非法表名:{table_name},仅允许以下格式:\n"
  31. "- current_reduce / current_reflux\n"
  32. "- dataset_数字(如 dataset_35)\n"
  33. "- 系统表(Datasets、ModelParameters 等)")
  34. # ===================== 执行数据库查询 =====================
  35. try:
  36. # 直接执行SQL(带字段名映射,确保转字典成功)
  37. sql = text(f"SELECT * FROM {table_name};")
  38. result = session.execute(sql).mappings().all() # 返回带字段名的行对象
  39. # 转换为DataFrame
  40. dataframe = pd.DataFrame([dict(row) for row in result])
  41. return dataframe
  42. except SQLAlchemyError as e:
  43. raise Exception(f"数据库查询失败(表 {table_name}):{str(e)}")
  44. except Exception as e:
  45. raise Exception(f"数据转换失败(表 {table_name}):{str(e)}")
  46. def get_table_metal_averages(session, table_name):
  47. """
  48. 计算表中所有数值型列(金属指标)的平均值
  49. :param table_name: 数据库表名
  50. :return: 字典,键为金属指标名,值为平均值
  51. """
  52. # 复用现有安全校验(防止SQL注入)
  53. allowed_pattern = re.compile(
  54. r'^(current_(reduce|reflux)|dataset_\d+|Datasets|ModelParameters|Models|software_intro|sqlite_sequence|users)$'
  55. )
  56. if not allowed_pattern.fullmatch(table_name):
  57. raise ValueError(f"非法表名:{table_name},仅允许特定格式表名")
  58. try:
  59. # 1. 获取原始数据(复用现有查询逻辑)
  60. df = get_table_data(session, table_name)
  61. # 2. 筛选数值型列(排除ID、字符串等非金属指标列)
  62. numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns.tolist()
  63. # 排除主键ID列(如果存在)
  64. numeric_cols = [col for col in numeric_cols if col.lower() != 'id']
  65. if not numeric_cols:
  66. raise ValueError(f"表 {table_name} 中未找到数值型金属指标列")
  67. # 3. 计算平均值(保留2位小数)
  68. averages = df[numeric_cols].mean().to_dict()
  69. return averages
  70. except SQLAlchemyError as e:
  71. raise Exception(f"数据库查询失败(表 {table_name}):{str(e)}")
  72. except Exception as e:
  73. raise Exception(f"平均值计算失败(表 {table_name}):{str(e)}")
  74. def create_dynamic_table(dataset_id, columns):
  75. """动态创建数据表"""
  76. # 动态构建列
  77. dynamic_columns = {
  78. 'id': Column(Integer, primary_key=True, autoincrement=True) # 为每个表添加一个主键
  79. }
  80. # 根据 columns 字典动态创建字段
  81. for col_name, col_type in columns.items():
  82. if col_type == 'str':
  83. dynamic_columns[col_name] = Column(String(255))
  84. elif col_type == 'int':
  85. dynamic_columns[col_name] = Column(Integer)
  86. elif col_type == 'float':
  87. dynamic_columns[col_name] = Column(Float)
  88. elif col_type == 'datetime':
  89. dynamic_columns[col_name] = Column(DateTime)
  90. # 动态生成模型类,表名使用 dataset_{dataset_id}
  91. table_name = f"dataset_{dataset_id}"
  92. # 在生成的类中添加 `__tablename__`
  93. dynamic_columns['__tablename__'] = table_name
  94. # 动态创建类
  95. dynamic_class = type(table_name, (Base,), dynamic_columns)
  96. # 打印调试信息
  97. print("table_name:", table_name)
  98. print("dynamic_columns:", dynamic_columns)
  99. # 创建数据库引擎
  100. engine = create_engine('sqlite:///SoilAcidification.db') # 这里需要替换为你的数据库引擎
  101. Base.metadata.create_all(engine) # 创建所有表格
  102. return dynamic_class
  103. # 判断文件类型是否允许
  104. def allowed_file(filename):
  105. ALLOWED_EXTENSIONS = {'xlsx', 'xls'}
  106. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  107. # 生成唯一文件名
  108. def generate_unique_filename(filename):
  109. # 获取文件的扩展名
  110. ext = filename.rsplit('.', 1)[1].lower()
  111. # 使用 UUID 和当前时间戳生成唯一文件名(使用 UTC 时区)
  112. unique_filename = f"{uuid.uuid4().hex}_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}.{ext}"
  113. return unique_filename
  114. def infer_column_types(df):
  115. type_map = {
  116. 'object': 'str',
  117. 'int64': 'int',
  118. 'float64': 'float',
  119. 'datetime64[ns]': 'datetime' # 适应Pandas datetime类型
  120. }
  121. # 提取列和其数据类型
  122. return {col: type_map.get(str(df[col].dtype), 'str') for col in df.columns}
  123. def clean_column_names(dataframe):
  124. # Strip whitespace and replace non-breaking spaces and other non-printable characters
  125. dataframe.columns = [col.strip().replace('\xa0', '') for col in dataframe.columns]
  126. return dataframe
  127. # 建立excel文件的列名和数据库模型字段之间的映射
  128. def rename_columns_for_model(dataframe, dataset_type):
  129. if dataset_type == 'reduce':
  130. rename_map = {
  131. '1/b': 'Q_over_b',
  132. 'pH': 'pH',
  133. 'OM': 'OM',
  134. 'CL': 'CL',
  135. 'H': 'H',
  136. 'Al': 'Al'
  137. }
  138. elif dataset_type == 'reflux':
  139. rename_map = {
  140. 'OM': 'OM',
  141. 'CL': 'CL',
  142. 'CEC': 'CEC',
  143. 'H+': 'H_plus',
  144. 'N': 'N',
  145. 'Al3+': 'Al3_plus',
  146. 'ΔpH': 'Delta_pH'
  147. }
  148. # 使用 rename() 方法更新列名
  149. dataframe = dataframe.rename(columns=rename_map)
  150. return dataframe
  151. # 建立前端参数和模型预测字段之间的映射
  152. def rename_columns_for_model_predict(dataframe, dataset_type):
  153. if dataset_type == 'reduce':
  154. rename_map = {
  155. 'init_pH': 'pH',
  156. 'OM': 'OM',
  157. 'CL': 'CL',
  158. 'H': 'H',
  159. 'Al': 'Al'
  160. }
  161. elif dataset_type == 'reflux':
  162. rename_map = {
  163. "OM": "OM",
  164. "CL": "CL",
  165. "CEC": "CEC",
  166. "H+": "H_plus",
  167. "N": "N",
  168. "Al3+": "Al3_plus"
  169. }
  170. # 使用 rename() 方法更新列名
  171. dataframe = dataframe.rename(columns=rename_map)
  172. return dataframe
  173. def insert_data_into_existing_table(session, dataframe, model_class):
  174. """Insert data from a DataFrame into an existing SQLAlchemy model table."""
  175. for index, row in dataframe.iterrows():
  176. record = model_class(**row.to_dict())
  177. session.add(record)
  178. def insert_data_into_dynamic_table(session, dataset_df, dynamic_table_class):
  179. for _, row in dataset_df.iterrows():
  180. record_data = row.to_dict()
  181. session.execute(dynamic_table_class.__table__.insert(), [record_data])
  182. def insert_data_by_type(session, dataset_df, dataset_type):
  183. if dataset_type == 'reduce':
  184. for _, row in dataset_df.iterrows():
  185. record = CurrentReduce(**row.to_dict())
  186. session.add(record)
  187. elif dataset_type == 'reflux':
  188. for _, row in dataset_df.iterrows():
  189. record = CurrentReflux(**row.to_dict())
  190. session.add(record)
  191. def get_current_data(session, data_type):
  192. # 根据数据类型选择相应的表模型
  193. if data_type == 'reduce':
  194. model = CurrentReduce
  195. elif data_type == 'reflux':
  196. model = CurrentReflux
  197. else:
  198. raise ValueError("Invalid data type provided. Choose 'reduce' or 'reflux'.")
  199. try:
  200. # 从数据库中查询所有记录并获取标量结果(模型实例)
  201. result = session.execute(select(model)).scalars().all()
  202. # 转换模型实例为字典列表(过滤SQLAlchemy内部属性)
  203. data_list = []
  204. for item in result:
  205. # 使用模型实例的__dict__属性,但排除内部属性
  206. item_dict = {key: value for key, value in item.__dict__.items()
  207. if not key.startswith('_sa_')}
  208. data_list.append(item_dict)
  209. # 将结果转换为DataFrame
  210. dataframe = pd.DataFrame(data_list)
  211. return dataframe
  212. except Exception as e:
  213. # 增加详细错误信息,方便调试
  214. raise Exception(f"获取{data_type}数据失败: {str(e)}") from e
  215. def get_dataset_by_id(session, dataset_id):
  216. # 动态获取表的元数据
  217. metadata = MetaData(bind=session.bind)
  218. dataset_table = Table(dataset_id, metadata, autoload=True, autoload_with=session.bind)
  219. # 从数据库中查询整个表的数据
  220. query = select(dataset_table)
  221. result = session.execute(query).fetchall()
  222. # 检查是否有数据返回
  223. if not result:
  224. raise ValueError(f"No data found for dataset {dataset_id}.")
  225. # 将结果转换为DataFrame
  226. dataframe = pd.DataFrame(result, columns=[column.name for column in dataset_table.columns])
  227. return dataframe
  228. def predict_to_Q(predictions, init_ph, target_ph):
  229. # 将预测结果转换为Q
  230. Q = predictions * (target_ph - init_ph)
  231. return Q
  232. # 说明:Q指生石灰投加量,单位是%,例如1%代表100g土壤中施加1g生石灰。
  233. # 其中,土壤是指表层20cm土壤。# 如果Q的单位换算为吨/公顷,即t/ha,则需要乘以25。
  234. # ΔpH=目标pH-初始pH
  235. def Q_to_t_ha(Q):
  236. return Q * 25
  237. def create_kriging(file_name, emission_column, points):
  238. # 从 Excel 读取数据
  239. df = pd.read_excel(file_name)
  240. print(df)
  241. # 转换为 GeoDataFrame
  242. gdf = gpd.GeoDataFrame(df, geometry=gpd.points_from_xy(df['longitude'], df['latitude']))
  243. print(gdf)
  244. # 初始化并运行克里金插值
  245. OK = OrdinaryKriging(
  246. gdf.geometry.x,
  247. gdf.geometry.y,
  248. gdf[emission_column],
  249. variogram_model='spherical',
  250. verbose=True,
  251. enable_plotting=False
  252. )
  253. # 提取输入点的经度和纬度
  254. input_lons = [point[0] for point in points]
  255. input_lats = [point[1] for point in points]
  256. # 对输入的点进行插值
  257. z, ss = OK.execute('points', input_lons, input_lats)
  258. result = {
  259. "message": "Kriging interpolation for points completed successfully",
  260. "interpolated_concentrations": z.tolist()
  261. }
  262. return result