point_unit_grouping_api.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import psycopg2
  2. from flask import Flask, jsonify
  3. from shapely import wkb
  4. from collections import Counter
  5. app = Flask(__name__)
  6. # 定义 h_xtfx 值的映射关系(数值用于插值计算)
  7. h_xtfx_mapping = {
  8. "优先保护类": 1,
  9. "安全利用类": 2,
  10. "严格管控类": 3
  11. }
  12. reverse_h_xtfx_mapping = {v: k for k, v in h_xtfx_mapping.items()}
  13. def connect_to_database():
  14. """建立数据库连接"""
  15. try:
  16. connection = psycopg2.connect(
  17. user="postgres",
  18. password="root",
  19. host="localhost",
  20. port="5432",
  21. database="testdb1"
  22. )
  23. return connection
  24. except Exception as error:
  25. print("数据库连接失败:", error)
  26. return None
  27. def fetch_unit_and_point_data(connection):
  28. """获取单元和点位数据"""
  29. if not connection:
  30. return [], []
  31. cursor = connection.cursor()
  32. try:
  33. # 查询单元信息(GID 和 几何图形)
  34. unit_query = "SELECT gid, geom FROM unit_ceil"
  35. cursor.execute(unit_query)
  36. unit_rows = cursor.fetchall()
  37. # 查询点位信息(ID、几何图形、h_xtfx 值)
  38. point_query = "SELECT id, geom, h_xtfx FROM fifty_thousand_survey_data"
  39. cursor.execute(point_query)
  40. point_rows = cursor.fetchall()
  41. return unit_rows, point_rows
  42. except Exception as error:
  43. print("数据查询失败:", error)
  44. return [], []
  45. finally:
  46. cursor.close()
  47. def idw_interpolation(points, target_point):
  48. """反距离加权插值函数"""
  49. total_weight = 0
  50. weighted_sum = 0
  51. power = 2 # 距离权重的幂
  52. for point, value in points:
  53. distance = point.distance(target_point)
  54. if distance == 0:
  55. return value # 距离为0时直接返回该点值
  56. weight = 1 / (distance ** power)
  57. total_weight += weight
  58. weighted_sum += weight * value
  59. return weighted_sum / total_weight if total_weight != 0 else None
  60. def check_h_xtfx_values(unit_rows, point_rows):
  61. """核心逻辑:判断单元的 h_xtfx 值"""
  62. unit_point_mapping = {}
  63. for unit_id, unit_geom_wkb in unit_rows:
  64. try:
  65. unit_geom = wkb.loads(unit_geom_wkb, hex=True)
  66. unit_points = []
  67. for point_id, point_geom_wkb, h_xtfx in point_rows:
  68. point_geom = wkb.loads(point_geom_wkb, hex=True)
  69. if unit_geom.contains(point_geom):
  70. unit_points.append((point_geom, h_xtfx)) # 存储几何坐标和 h_xtfx 值
  71. unit_point_mapping[unit_id] = unit_points
  72. except Exception as e:
  73. print(f"处理单元 {unit_id} 失败:", e)
  74. result = {}
  75. for unit_id, points in unit_point_mapping.items():
  76. if not points:
  77. result[unit_id] = None
  78. continue # 无点位,直接设为 None
  79. has_strict_control = any(h_xtfx == "严格管控类" for _, h_xtfx in points)
  80. h_xtfx_list = [h_xtfx for _, h_xtfx in points]
  81. if not has_strict_control:
  82. # 无严格管控类:先判断比例是否 ≥80%
  83. counter = Counter(h_xtfx_list)
  84. most_common, count = counter.most_common(1)[0]
  85. if count / len(points) >= 0.8:
  86. result[unit_id] = most_common
  87. continue # 比例达标,直接输出
  88. # 比例不达标:对优先保护类和安全利用类进行插值
  89. valid_points = [(geom, h_xtfx_mapping[h_xtfx])
  90. for geom, h_xtfx in points
  91. if h_xtfx in ["优先保护类", "安全利用类"]]
  92. if len(valid_points) < 2:
  93. # 有效点位不足,取最常见值(避免插值错误)
  94. result[unit_id] = most_common
  95. else:
  96. unit_center = unit_geom.centroid
  97. interpolated = idw_interpolation(valid_points, unit_center)
  98. result[unit_id] = (
  99. "优先保护类" if interpolated <= 1.5
  100. else "安全利用类" if interpolated <= 2.5
  101. else "严格管控类" # 理论上不会出现,因无严格管控类点位
  102. )
  103. else:
  104. # 存在严格管控类:对所有点位(含严格管控类)进行插值
  105. point_coords = [(geom, h_xtfx_mapping[h_xtfx]) for geom, h_xtfx in points]
  106. unit_center = unit_geom.centroid
  107. interpolated = idw_interpolation(point_coords, unit_center)
  108. result[unit_id] = (
  109. "优先保护类" if interpolated <= 1.5
  110. else "安全利用类" if interpolated <= 2.5
  111. else "严格管控类"
  112. )
  113. return result
  114. @app.route('/get_h_xtfx_result', methods=['GET'])
  115. def get_h_xtfx_result():
  116. """接口:返回单元 h_xtfx 结果"""
  117. connection = connect_to_database()
  118. if not connection:
  119. return jsonify({"error": "数据库连接失败"})
  120. unit_rows, point_rows = fetch_unit_and_point_data(connection)
  121. result = check_h_xtfx_values(unit_rows, point_rows)
  122. connection.close()
  123. return jsonify(result)
  124. if __name__ == '__main__':
  125. app.run(debug=True)