浏览代码

基本逻辑

tangbengaoyuan 3 月之前
父节点
当前提交
4d124cd35c
共有 1 个文件被更改,包括 148 次插入0 次删除
  1. 148 0
      point_unit_grouping_api.py

+ 148 - 0
point_unit_grouping_api.py

@@ -0,0 +1,148 @@
+import psycopg2
+from flask import Flask, jsonify
+from shapely import wkb
+from collections import Counter
+
+app = Flask(__name__)
+
+# 定义 h_xtfx 值的映射关系(数值用于插值计算)
+h_xtfx_mapping = {
+    "优先保护类": 1,
+    "安全利用类": 2,
+    "严格管控类": 3
+}
+reverse_h_xtfx_mapping = {v: k for k, v in h_xtfx_mapping.items()}
+
+
+def connect_to_database():
+    """建立数据库连接"""
+    try:
+        connection = psycopg2.connect(
+            user="postgres",
+            password="root",
+            host="localhost",
+            port="5432",
+            database="testdb1"
+        )
+        return connection
+    except Exception as error:
+        print("数据库连接失败:", error)
+        return None
+
+
+def fetch_unit_and_point_data(connection):
+    """获取单元和点位数据"""
+    if not connection:
+        return [], []
+    cursor = connection.cursor()
+    try:
+        # 查询单元信息(GID 和 几何图形)
+        unit_query = "SELECT gid, geom FROM unit_ceil"
+        cursor.execute(unit_query)
+        unit_rows = cursor.fetchall()
+        
+        # 查询点位信息(ID、几何图形、h_xtfx 值)
+        point_query = "SELECT id, geom, h_xtfx FROM fifty_thousand_survey_data"
+        cursor.execute(point_query)
+        point_rows = cursor.fetchall()
+        
+        return unit_rows, point_rows
+    except Exception as error:
+        print("数据查询失败:", error)
+        return [], []
+    finally:
+        cursor.close()
+
+
+def idw_interpolation(points, target_point):
+    """反距离加权插值函数"""
+    total_weight = 0
+    weighted_sum = 0
+    power = 2  # 距离权重的幂
+    for point, value in points:
+        distance = point.distance(target_point)
+        if distance == 0:
+            return value  # 距离为0时直接返回该点值
+        weight = 1 / (distance ** power)
+        total_weight += weight
+        weighted_sum += weight * value
+    return weighted_sum / total_weight if total_weight != 0 else None
+
+
+def check_h_xtfx_values(unit_rows, point_rows):
+    """核心逻辑:判断单元的 h_xtfx 值"""
+    unit_point_mapping = {}
+    for unit_id, unit_geom_wkb in unit_rows:
+        try:
+            unit_geom = wkb.loads(unit_geom_wkb, hex=True)
+            unit_points = []
+            for point_id, point_geom_wkb, h_xtfx in point_rows:
+                point_geom = wkb.loads(point_geom_wkb, hex=True)
+                if unit_geom.contains(point_geom):
+                    unit_points.append((point_geom, h_xtfx))  # 存储几何坐标和 h_xtfx 值
+            unit_point_mapping[unit_id] = unit_points
+        except Exception as e:
+            print(f"处理单元 {unit_id} 失败:", e)
+
+    result = {}
+    for unit_id, points in unit_point_mapping.items():
+        if not points:
+            result[unit_id] = None
+            continue  # 无点位,直接设为 None
+
+        has_strict_control = any(h_xtfx == "严格管控类" for _, h_xtfx in points)
+        h_xtfx_list = [h_xtfx for _, h_xtfx in points]
+
+        if not has_strict_control:
+            # 无严格管控类:先判断比例是否 ≥80%
+            counter = Counter(h_xtfx_list)
+            most_common, count = counter.most_common(1)[0]
+            if count / len(points) >= 0.8:
+                result[unit_id] = most_common
+                continue  # 比例达标,直接输出
+
+            # 比例不达标:对优先保护类和安全利用类进行插值
+            valid_points = [(geom, h_xtfx_mapping[h_xtfx]) 
+                           for geom, h_xtfx in points 
+                           if h_xtfx in ["优先保护类", "安全利用类"]]
+            if len(valid_points) < 2:
+                # 有效点位不足,取最常见值(避免插值错误)
+                result[unit_id] = most_common
+            else:
+                unit_center = unit_geom.centroid
+                interpolated = idw_interpolation(valid_points, unit_center)
+                result[unit_id] = (
+                    "优先保护类" if interpolated <= 1.5 
+                    else "安全利用类" if interpolated <= 2.5 
+                    else "严格管控类"  # 理论上不会出现,因无严格管控类点位
+                )
+        else:
+            # 存在严格管控类:对所有点位(含严格管控类)进行插值
+            point_coords = [(geom, h_xtfx_mapping[h_xtfx]) for geom, h_xtfx in points]
+            unit_center = unit_geom.centroid
+            interpolated = idw_interpolation(point_coords, unit_center)
+            result[unit_id] = (
+                "优先保护类" if interpolated <= 1.5 
+                else "安全利用类" if interpolated <= 2.5 
+                else "严格管控类"
+            )
+
+    return result
+
+
+@app.route('/get_h_xtfx_result', methods=['GET'])
+def get_h_xtfx_result():
+    """接口:返回单元 h_xtfx                                                                结果"""
+    connection = connect_to_database()
+    if not connection:
+        return jsonify({"error": "数据库连接失败"})
+    
+    unit_rows, point_rows = fetch_unit_and_point_data(connection)
+    result = check_h_xtfx_values(unit_rows, point_rows)
+    
+    connection.close()
+    return jsonify(result)
+
+
+if __name__ == '__main__':
+    app.run(debug=True)