cd_flux_service.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. import os
  2. import csv
  3. import logging
  4. import sys
  5. import pandas as pd
  6. import geopandas as gpd
  7. from shapely.geometry import Point
  8. from pathlib import Path
  9. from typing import Dict, Any
  10. from sqlalchemy import func
  11. from sqlalchemy.orm import Session
  12. from app.database import SessionLocal
  13. from app.models import FarmlandData, FluxCdOutputData, FluxCdInputData
  14. from app.utils.mapping_utils import MappingUtils, csv_to_raster_workflow
  15. # 配置日志
  16. logger = logging.getLogger(__name__)
  17. logger.setLevel(logging.INFO)
  18. handler = logging.StreamHandler()
  19. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  20. handler.setFormatter(formatter)
  21. logger.addHandler(handler)
  22. def get_base_dir():
  23. """获取基础目录路径(与土地数据处理函数一致)"""
  24. if getattr(sys, 'frozen', False):
  25. # 打包后的可执行文件
  26. return os.path.dirname(sys.executable)
  27. else:
  28. # 脚本运行模式
  29. return os.path.dirname(os.path.abspath(__file__))
  30. def get_static_dir():
  31. """获取静态资源目录"""
  32. base_dir = get_base_dir()
  33. return os.path.join(base_dir, "..", "static", "cd_flux")
  34. def get_default_boundary_shp():
  35. """获取默认边界SHP文件路径"""
  36. static_dir = get_static_dir()
  37. # 尝试几个可能的边界文件路径
  38. possible_paths = [
  39. os.path.join(static_dir, "lechang.shp"),
  40. ]
  41. for path in possible_paths:
  42. if os.path.exists(path):
  43. return path
  44. return None
  45. class FluxCdVisualizationService:
  46. """
  47. 农田Cd通量可视化服务类
  48. @description: 提供基于农田位置和Cd通量数据的空间分布图和直方图生成功能
  49. @version: 1.1.2 (修复绘图问题)
  50. """
  51. def __init__(self, db: Session = None):
  52. """
  53. 初始化可视化服务
  54. @param db: 可选的数据库会话对象
  55. """
  56. self.logger = logger
  57. self.db = db
  58. self.mapper = MappingUtils()
  59. def generate_cd_input_flux_map(self, output_dir: str = None, boundary_shp: str = None) -> Dict[str, Any]:
  60. """
  61. 生成输入镉通量(In_Cd)的空间分布图和直方图
  62. @param output_dir: 输出文件目录
  63. @param boundary_shp: 边界Shapefile文件路径
  64. @return: 包含输出文件路径的字典
  65. """
  66. try:
  67. # 如果未提供数据库会话,则创建新的会话
  68. db = self.db if self.db else SessionLocal()
  69. should_close = self.db is None
  70. # 1. 从数据库查询数据
  71. data = self._fetch_fluxcd_data(db)
  72. if not data:
  73. if should_close:
  74. db.close()
  75. return {
  76. "success": False,
  77. "message": "数据库中未找到任何农田Cd通量数据",
  78. "data": None
  79. }
  80. # 2. 设置输出目录
  81. static_dir = get_static_dir()
  82. if output_dir is None:
  83. output_dir = static_dir
  84. os.makedirs(output_dir, exist_ok=True)
  85. # 3. 设置边界SHP(如果未提供则使用默认)
  86. if boundary_shp is None:
  87. boundary_shp = get_default_boundary_shp()
  88. if boundary_shp:
  89. self.logger.info(f"使用默认边界文件: {boundary_shp}")
  90. else:
  91. self.logger.warning("未找到默认边界文件,将不使用边界裁剪")
  92. # 4. 生成CSV文件
  93. csv_path = os.path.join(output_dir, "fluxcd_input.csv")
  94. self._generate_csv(data, csv_path)
  95. # 5. 设置模板TIFF路径(使用土地数据处理中的模板)
  96. template_tif = os.path.join(static_dir, "meanTemp.tif")
  97. if not os.path.exists(template_tif):
  98. # 尝试其他可能的模板路径
  99. template_tif = os.path.join(static_dir, "template.tif")
  100. if not os.path.exists(template_tif):
  101. raise FileNotFoundError(f"未找到模板TIFF文件")
  102. # 6. 使用csv_to_raster_workflow将CSV转换为栅格
  103. base_name = "fluxcd_input"
  104. raster_path = os.path.join(output_dir, f"{base_name}_raster.tif")
  105. # 关键修改:确保传递边界文件
  106. workflow_result = csv_to_raster_workflow(
  107. csv_file=csv_path,
  108. template_tif=template_tif,
  109. output_dir=output_dir,
  110. resolution_factor=4.0,
  111. interpolation_method='linear',
  112. field_name='In_Cd',
  113. lon_col=0, # CSV中经度列索引
  114. lat_col=1, # CSV中纬度列索引
  115. value_col=2, # CSV中数值列索引
  116. enable_interpolation=True
  117. )
  118. # 获取栅格路径和统计信息
  119. raster_path = workflow_result['raster']
  120. stats = workflow_result['statistics']
  121. # 7. 创建栅格地图 - 关键修改:使用边界文件裁剪
  122. map_path = os.path.join(output_dir, f"{base_name}_map")
  123. map_file = self.mapper.create_raster_map(
  124. shp_path=boundary_shp, # 边界文件
  125. tif_path=raster_path, # 栅格文件
  126. output_path=map_path,
  127. colormap='green_yellow_red_purple',
  128. title="input Cd flux map",
  129. output_size=12,
  130. dpi=300,
  131. enable_interpolation=False,
  132. interpolation_method='linear'
  133. )
  134. # 8. 创建直方图
  135. histogram_path = self.mapper.create_histogram(
  136. raster_path,
  137. save_path=os.path.join(output_dir, f"{base_name}_histogram.jpg"),
  138. xlabel='input Cd flux(g/ha/a)',
  139. ylabel='frequency',
  140. title='input Cd flux histogram',
  141. bins=100
  142. )
  143. result = {
  144. "success": True,
  145. "message": "成功生成Cd通量可视化结果",
  146. "data": {
  147. "csv": csv_path,
  148. "raster": raster_path,
  149. "map": map_file,
  150. "histogram": histogram_path,
  151. "statistics": stats,
  152. "boundary_used": boundary_shp if boundary_shp else "无"
  153. }
  154. }
  155. if should_close:
  156. db.close()
  157. return result
  158. except Exception as e:
  159. self.logger.error(f"生成Cd通量可视化结果时发生错误: {str(e)}", exc_info=True)
  160. return {
  161. "success": False,
  162. "message": f"生成失败: {str(e)}",
  163. "data": None
  164. }
  165. def update_from_csv(self, csv_file_path: str):
  166. """包装更新服务方法"""
  167. update_service = FluxCdUpdateService(db=self.db)
  168. return update_service.update_from_csv(csv_file_path)
  169. def _fetch_fluxcd_data(self, db: Session) -> list:
  170. """
  171. 从数据库查询需要的数据
  172. @param db: 数据库会话
  173. @returns: 查询结果列表
  174. """
  175. try:
  176. # 查询Farmland_data和FluxCd_output_data的联合数据
  177. query = db.query(
  178. FarmlandData.lon,
  179. FarmlandData.lan,
  180. FluxCdOutputData.in_cd
  181. ).join(
  182. FluxCdOutputData,
  183. (FarmlandData.farmland_id == FluxCdOutputData.farmland_id) &
  184. (FarmlandData.sample_id == FluxCdOutputData.sample_id)
  185. )
  186. return query.all()
  187. except Exception as e:
  188. self.logger.error(f"查询农田Cd通量数据时发生错误: {str(e)}")
  189. return []
  190. def _generate_csv(self, data: list, output_path: str):
  191. """
  192. 将查询结果生成CSV文件
  193. @param data: 查询结果列表
  194. @param output_path: 输出CSV文件路径
  195. """
  196. try:
  197. with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
  198. writer = csv.writer(csvfile)
  199. # 写入表头
  200. writer.writerow(['lon', 'lan', 'In_Cd'])
  201. # 写入数据
  202. for row in data:
  203. writer.writerow(row)
  204. self.logger.info(f"CSV文件已生成: {output_path}")
  205. except Exception as e:
  206. self.logger.error(f"生成CSV文件时发生错误: {str(e)}")
  207. raise
  208. class FluxCdUpdateService:
  209. """
  210. 农田Cd通量数据更新服务类
  211. 处理用户上传的CSV文件并更新数据库中的通量值
  212. """
  213. def __init__(self, db: Session = None):
  214. self.db = db if db else SessionLocal()
  215. self.logger = logging.getLogger(__name__)
  216. def update_from_csv(self, csv_file_path: str) -> Dict[str, Any]:
  217. """
  218. 从CSV文件更新Cd通量数据
  219. @param csv_file_path: CSV文件路径
  220. @return: 更新结果字典
  221. """
  222. try:
  223. # 读取CSV文件
  224. df = self._read_csv(csv_file_path)
  225. # 验证CSV格式
  226. self._validate_csv(df)
  227. # 更新数据库
  228. update_count = 0
  229. create_count = 0
  230. error_count = 0
  231. for _, row in df.iterrows():
  232. try:
  233. # 查找匹配的农田样点
  234. farmland = self._find_farmland(row['lon'], row['lan'])
  235. if farmland:
  236. # 更新现有记录
  237. input_updated = self._update_input_data(farmland, row)
  238. if input_updated:
  239. self._update_output_data(farmland)
  240. update_count += 1
  241. else:
  242. # 创建新记录
  243. if self._create_new_records(row):
  244. create_count += 1
  245. else:
  246. error_count += 1
  247. except Exception as e:
  248. self.logger.error(f"处理记录失败: {str(e)}")
  249. error_count += 1
  250. self.db.commit()
  251. return {
  252. "success": True,
  253. "message": f"成功更新 {update_count} 条记录,创建 {create_count} 条新记录,{error_count} 条失败",
  254. "updated_count": update_count,
  255. "created_count": create_count,
  256. "error_count": error_count
  257. }
  258. except Exception as e:
  259. self.db.rollback()
  260. self.logger.error(f"更新失败: {str(e)}", exc_info=True)
  261. return {
  262. "success": False,
  263. "message": f"更新失败: {str(e)}"
  264. }
  265. finally:
  266. self.db.close()
  267. def _create_new_records(self, row: pd.Series):
  268. """为新的经纬度创建完整的记录"""
  269. try:
  270. self.logger.info(f"创建新农田记录: 经度={row['lon']}, 纬度={row['lan']}")
  271. # 获取下一个可用的 Farmland_ID
  272. max_id = self.db.query(func.max(FarmlandData.farmland_id)).scalar()
  273. new_farmland_id = (max_id or 0) + 1
  274. # 创建新的农田样点记录
  275. farmland = FarmlandData(
  276. farmland_id=new_farmland_id,
  277. lon=row['lon'],
  278. lan=row['lan'],
  279. type=0.0, # 默认为旱地
  280. geom=f"POINT({row['lon']} {row['lan']})"
  281. )
  282. self.db.add(farmland)
  283. self.db.flush() # 确保ID被分配
  284. self.logger.info(f"农田记录创建成功: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  285. # 创建新的输入通量记录
  286. input_data = FluxCdInputData(
  287. farmland_id=farmland.farmland_id,
  288. sample_id=farmland.sample_id,
  289. atmospheric_deposition=row['DQCJ_Cd'],
  290. irrigation_input=row['GGS_Cd'],
  291. agro_chemicals_input=row['NCP_Cd'],
  292. # 设置合理的默认值
  293. initial_cd=0.0,
  294. groundwater_leaching=0.023,
  295. surface_runoff=0.368,
  296. grain_removal=0.0,
  297. straw_removal=0.0
  298. )
  299. self.db.add(input_data)
  300. self.logger.info("输入通量记录创建成功")
  301. # 计算输入总通量
  302. in_cd = row['DQCJ_Cd'] + row['GGS_Cd'] + row['NCP_Cd']
  303. # 计算输出总通量(假设默认值)
  304. out_cd = 0.023 + 0.368 + 0.0 + 0.0
  305. # 创建新的输出通量记录
  306. output_data = FluxCdOutputData(
  307. farmland_id=farmland.farmland_id,
  308. sample_id=farmland.sample_id,
  309. in_cd=in_cd,
  310. out_cd=out_cd,
  311. net_cd=in_cd - out_cd,
  312. end_cd=0.0 # 默认当年Cd浓度
  313. )
  314. self.db.add(output_data)
  315. self.logger.info("输出通量记录创建成功")
  316. self.logger.info(f"所有记录创建完成: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  317. return True
  318. except Exception as e:
  319. self.logger.error(f"创建新记录失败: {str(e)}")
  320. return False
  321. def _read_csv(self, file_path: str) -> pd.DataFrame:
  322. """读取CSV文件并验证基本格式"""
  323. try:
  324. df = pd.read_csv(file_path)
  325. self.logger.info(f"成功读取CSV文件: {file_path}")
  326. return df
  327. except Exception as e:
  328. raise ValueError(f"CSV文件读取失败: {str(e)}")
  329. def _validate_csv(self, df: pd.DataFrame):
  330. """验证CSV文件包含必要的列"""
  331. required_columns = {'lon', 'lan', 'DQCJ_Cd', 'GGS_Cd', 'NCP_Cd'}
  332. if not required_columns.issubset(df.columns):
  333. missing = required_columns - set(df.columns)
  334. raise ValueError(f"CSV缺少必要列: {', '.join(missing)}")
  335. def _find_farmland(self, lon: float, lan: float) -> FarmlandData:
  336. """根据经纬度查找农田样点"""
  337. # 使用容差匹配(0.001度≈100米)
  338. tol = 0.001
  339. return self.db.query(FarmlandData).filter(
  340. FarmlandData.lon.between(lon - tol, lon + tol),
  341. FarmlandData.lan.between(lan - tol, lan + tol)
  342. ).first()
  343. def _update_input_data(self, farmland: FarmlandData, row: pd.Series) -> bool:
  344. """更新输入通量数据,返回是否有更新"""
  345. input_data = self.db.query(FluxCdInputData).filter(
  346. FluxCdInputData.farmland_id == farmland.farmland_id,
  347. FluxCdInputData.sample_id == farmland.sample_id
  348. ).first()
  349. if not input_data:
  350. self.logger.warning(f"未找到输入数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  351. return False
  352. # 检查是否需要更新
  353. updated = False
  354. if input_data.atmospheric_deposition != row['DQCJ_Cd']:
  355. input_data.atmospheric_deposition = row['DQCJ_Cd']
  356. updated = True
  357. if input_data.irrigation_input != row['GGS_Cd']:
  358. input_data.irrigation_input = row['GGS_Cd']
  359. updated = True
  360. if input_data.agro_chemicals_input != row['NCP_Cd']:
  361. input_data.agro_chemicals_input = row['NCP_Cd']
  362. updated = True
  363. if updated:
  364. self.logger.info(f"更新输入通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  365. return updated
  366. def _update_output_data(self, farmland: FarmlandData):
  367. """更新输出通量数据"""
  368. output_data = self.db.query(FluxCdOutputData).filter(
  369. FluxCdOutputData.farmland_id == farmland.farmland_id,
  370. FluxCdOutputData.sample_id == farmland.sample_id
  371. ).first()
  372. if not output_data:
  373. self.logger.warning(f"未找到输出数据: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  374. return
  375. # 重新获取更新后的输入数据
  376. input_data = self.db.query(FluxCdInputData).filter(
  377. FluxCdInputData.farmland_id == farmland.farmland_id,
  378. FluxCdInputData.sample_id == farmland.sample_id
  379. ).first()
  380. # 重新计算并更新
  381. output_data.in_cd = input_data.input_flux()
  382. # 注意:输出总通量out_cd不会由用户上传的CSV更新,所以我们保持原值
  383. # 重新计算净通量
  384. output_data.net_cd = output_data.in_cd - output_data.out_cd
  385. self.logger.info(f"更新输出通量: Farmland_ID={farmland.farmland_id}, Sample_ID={farmland.sample_id}")
  386. # 测试主函数
  387. # 测试主函数
  388. if __name__ == "__main__":
  389. import tempfile
  390. import shutil
  391. import json
  392. # 创建数据库会话
  393. db = SessionLocal()
  394. try:
  395. # 初始化服务
  396. service = FluxCdVisualizationService(db=db)
  397. print("=" * 50)
  398. print("测试Cd通量可视化服务")
  399. print("=" * 50)
  400. # 测试生成可视化结果
  401. print("\n>>> 测试生成Cd输入通量可视化地图")
  402. # 测试更新服务
  403. print("\n>>> 测试从CSV更新Cd通量数据")
  404. # 创建临时目录和CSV文件
  405. temp_dir = tempfile.mkdtemp()
  406. test_csv_path = os.path.join(temp_dir, "test_update.csv")
  407. # 创建测试CSV文件
  408. test_data = [
  409. "lon,lan,DQCJ_Cd,GGS_Cd,NCP_Cd",
  410. "113.123,25.456,1.24,4.56,7.89",
  411. "113.125,25.457,2.35,5.67,8.90"
  412. ]
  413. with open(test_csv_path, 'w', encoding='utf-8') as f:
  414. f.write("\n".join(test_data))
  415. # 调用更新服务
  416. update_result = service.update_from_csv(test_csv_path)
  417. print(json.dumps(update_result, indent=2, ensure_ascii=False))
  418. # 清理临时文件
  419. shutil.rmtree(temp_dir)
  420. print(f"\n清理临时目录: {temp_dir}")
  421. print("\n测试完成!")
  422. except Exception as e:
  423. logger.error(f"测试过程中发生错误: {str(e)}")
  424. import traceback
  425. traceback.print_exc()
  426. finally:
  427. # 关闭数据库会话
  428. db.close()