Эх сурвалжийг харах

添加省市模型及边界服务,支持通过名称获取边界GeoJSON,更新配置以允许环境变量覆盖边界文件路径。

drggboy 4 өдөр өмнө
parent
commit
465cc0f1fc

+ 2 - 1
Cd_Prediction_Integrated_System/config.py

@@ -59,7 +59,8 @@ DATA_PATHS = {
 # 分析模块配置
 ANALYSIS_CONFIG = {
     "template_tif": os.path.join(PROJECT_ROOT, "output", "raster", "meanTemp.tif"),
-    "boundary_shp": os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp"),
+    # 允许通过环境变量 CD_BOUNDARY_FILE 覆盖边界文件(支持 .geojson/.shp)
+    "boundary_shp": os.environ.get('CD_BOUNDARY_FILE', os.path.join(PROJECT_ROOT, "output", "raster", "lechang.shp")),
     "output_raster": os.path.join(PROJECT_ROOT, "output", "raster", "output.tif"),
     "temp_shapefile": os.path.join(PROJECT_ROOT, "output", "raster", "points666.shp")
 }

+ 10 - 1
app/api/vector.py

@@ -1,9 +1,10 @@
 # 矢量数据API
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
 from fastapi.responses import FileResponse
 from sqlalchemy.orm import Session
 from ..database import get_db
 from ..services import vector_service
+from ..services.admin_boundary_service import get_boundary_geojson_by_name
 import os
 import shutil
 from typing import List
@@ -19,6 +20,14 @@ async def import_vector_data(file: UploadFile = File(...), db: Session = Depends
     
     return await vector_service.import_vector_data(file, db)
 
+@router.get("/boundary", summary="按名称获取边界", description="按县/市/省名称获取边界GeoJSON")
+def get_boundary(name: str = Query(...), level: str = Query("auto"), db: Session = Depends(get_db)):
+    """按名称和层级获取边界。level: county/city/province/auto"""
+    try:
+        return get_boundary_geojson_by_name(db, name, level)
+    except Exception as e:
+        raise HTTPException(status_code=404, detail=str(e))
+
 @router.get("/{vector_id}", summary="获取矢量数据", description="根据ID获取一条矢量数据记录")
 def get_vector_data(vector_id: int, db: Session = Depends(get_db)):
     """获取指定ID的矢量数据"""

+ 2 - 0
app/models/__init__.py

@@ -19,3 +19,5 @@ from app.models.atmo_company import *
 from app.models.water_sample import *
 from app.models.agricultural import *
 from app.models.cross_section import *
+from app.models.province import *
+from app.models.city import *

+ 92 - 0
app/models/city.py

@@ -0,0 +1,92 @@
+from sqlalchemy import Column, Integer, String, Text, Index, JSON
+from sqlalchemy.dialects.postgresql import JSONB
+from geoalchemy2 import Geometry
+from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
+import json
+
+class City(Base):
+    """市级行政区划地理数据表"""
+    __tablename__ = "cities"
+
+    # 字段映射字典
+    FIELD_MAPPING = {
+        '市名': 'name',
+        '市代码': 'code',
+        '省名': 'province_name',
+        '省代码': 'province_code'
+    }
+
+    # 反向映射字典
+    REVERSE_FIELD_MAPPING = {v: k for k, v in FIELD_MAPPING.items()}
+
+    id = Column(Integer, primary_key=True, index=True)
+    name = Column(String(100), nullable=False, comment="市名")
+    code = Column(Integer, nullable=False, unique=True, comment="市代码")
+    province_name = Column(String(100), nullable=False, comment="省名") 
+    province_code = Column(Integer, nullable=False, comment="省代码")
+    
+    # 使用PostGIS的几何类型来存储多边形数据
+    geometry = Column(Geometry('MULTIPOLYGON', srid=4326), nullable=False, comment="市级行政区划的几何数据")
+    
+    # 存储完整的GeoJSON数据
+    geojson = Column(JSON, nullable=False, comment="完整的GeoJSON数据") 
+
+    # 显式定义索引
+    __table_args__ = (
+        Index('idx_cities_name', 'name'),
+        Index('idx_cities_code', 'code'),
+        Index('idx_cities_province_code', 'province_code'),
+        #Index('idx_cities_geometry', 'geometry', postgresql_using='gist'),
+    ) 
+
+    @classmethod
+    def from_geojson_feature(cls, feature):
+        """从GeoJSON Feature创建City实例
+        
+        Args:
+            feature: GeoJSON Feature对象
+            
+        Returns:
+            City: 新创建的City实例
+        """
+        properties = feature['properties']
+        # 转换中文字段名为英文
+        mapped_properties = {
+            cls.FIELD_MAPPING.get(k, k): v 
+            for k, v in properties.items()
+        }
+        
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
+        geometry = feature['geometry']
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
+        
+        # 创建实例
+        city = cls(
+            **mapped_properties,
+            geometry=wkt,
+            geojson=feature
+        )
+        return city
+
+    def to_geojson_feature(self):
+        """将City实例转换为GeoJSON Feature
+        
+        Returns:
+            dict: GeoJSON Feature对象
+        """
+        properties = {}
+        # 转换英文字段名为中文
+        for eng_field, cn_field in self.REVERSE_FIELD_MAPPING.items():
+            if hasattr(self, eng_field):
+                properties[cn_field] = getattr(self, eng_field)
+
+        return {
+            'type': 'Feature',
+            'properties': properties,
+            'geometry': self.geometry
+        }

+ 7 - 2
app/models/county.py

@@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer, String, Text, Index, JSON
 from sqlalchemy.dialects.postgresql import JSONB
 from geoalchemy2 import Geometry
 from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
 import json
 
 class County(Base):
@@ -59,9 +60,13 @@ class County(Base):
             for k, v in properties.items()
         }
         
-        # 将GeoJSON geometry转换为WKT格式
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
         geometry = feature['geometry']
-        wkt = f"SRID=4326;{json.dumps(geometry)}"
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
         
         # 创建实例
         county = cls(

+ 87 - 0
app/models/province.py

@@ -0,0 +1,87 @@
+from sqlalchemy import Column, Integer, String, Text, Index, JSON
+from sqlalchemy.dialects.postgresql import JSONB
+from geoalchemy2 import Geometry
+from app.database import Base
+from shapely.geometry import shape, MultiPolygon, Polygon
+import json
+
+class Province(Base):
+    """省级行政区划地理数据表"""
+    __tablename__ = "provinces"
+
+    # 字段映射字典
+    FIELD_MAPPING = {
+        '省名': 'name',
+        '省代码': 'code'
+    }
+
+    # 反向映射字典
+    REVERSE_FIELD_MAPPING = {v: k for k, v in FIELD_MAPPING.items()}
+
+    id = Column(Integer, primary_key=True, index=True)
+    name = Column(String(100), nullable=False, comment="省名")
+    code = Column(Integer, nullable=False, unique=True, comment="省代码")
+    
+    # 使用PostGIS的几何类型来存储多边形数据
+    geometry = Column(Geometry('MULTIPOLYGON', srid=4326), nullable=False, comment="省级行政区划的几何数据")
+    
+    # 存储完整的GeoJSON数据
+    geojson = Column(JSON, nullable=False, comment="完整的GeoJSON数据") 
+
+    # 显式定义索引
+    __table_args__ = (
+        Index('idx_provinces_name', 'name'),
+        Index('idx_provinces_code', 'code'),
+        #Index('idx_provinces_geometry', 'geometry', postgresql_using='gist'),
+    ) 
+
+    @classmethod
+    def from_geojson_feature(cls, feature):
+        """从GeoJSON Feature创建Province实例
+        
+        Args:
+            feature: GeoJSON Feature对象
+            
+        Returns:
+            Province: 新创建的Province实例
+        """
+        properties = feature['properties']
+        # 转换中文字段名为英文
+        mapped_properties = {
+            cls.FIELD_MAPPING.get(k, k): v 
+            for k, v in properties.items()
+        }
+        
+        # 将GeoJSON geometry转换为EWKT格式(SRID=4326)
+        geometry = feature['geometry']
+        geom = shape(geometry)
+        # 统一为 MultiPolygon 以匹配列类型
+        if isinstance(geom, Polygon):
+            geom = MultiPolygon([geom])
+        wkt = f"SRID=4326;{geom.wkt}"
+        
+        # 创建实例
+        province = cls(
+            **mapped_properties,
+            geometry=wkt,
+            geojson=feature
+        )
+        return province
+
+    def to_geojson_feature(self):
+        """将Province实例转换为GeoJSON Feature
+        
+        Returns:
+            dict: GeoJSON Feature对象
+        """
+        properties = {}
+        # 转换英文字段名为中文
+        for eng_field, cn_field in self.REVERSE_FIELD_MAPPING.items():
+            if hasattr(self, eng_field):
+                properties[cn_field] = getattr(self, eng_field)
+
+        return {
+            'type': 'Feature',
+            'properties': properties,
+            'geometry': self.geometry
+        }

+ 114 - 0
app/services/admin_boundary_service.py

@@ -0,0 +1,114 @@
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import text
+import json
+
+
+def get_boundary_geojson_by_name(db: Session, name: str, level: str = "auto") -> dict:
+    """根据名称获取边界GeoJSON Feature
+
+    Args:
+        db (Session): 数据库会话
+        name (str): 名称(县/市/省)
+        level (str): 层级,可选 "county"|"city"|"province"|"auto"
+
+    Returns:
+        dict: GeoJSON Feature 对象
+    """
+    # county 精确匹配
+    if level in ("county", "auto"):
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name, city_name, province_name
+            FROM counties WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "county",
+                    "name": r.name,
+                    "city": r.city_name,
+                    "province": r.province_name,
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    # city 直接查询或从counties聚合
+    if level in ("city", "auto"):
+        # 优先从cities表直接查询
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name, province_name
+            FROM cities WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "city",
+                    "name": r.name,
+                    "province": r.province_name
+                },
+                "geometry": json.loads(r.g)
+            }
+        
+        # 如果cities表没有,从counties表聚合
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(ST_UnaryUnion(ST_Union(geometry))) AS g,
+                   MIN(province_name) AS province_name
+            FROM counties WHERE city_name=:n
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "city",
+                    "name": name,
+                    "province": r.province_name
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    # province 直接查询或从counties聚合
+    if level in ("province", "auto"):
+        # 优先从provinces表直接查询
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(geometry) AS g, name
+            FROM provinces WHERE name=:n LIMIT 1
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "province",
+                    "name": r.name
+                },
+                "geometry": json.loads(r.g)
+            }
+        
+        # 如果provinces表没有,从counties表聚合
+        r = db.execute(text(
+            """
+            SELECT ST_AsGeoJSON(ST_UnaryUnion(ST_Union(geometry))) AS g
+            FROM counties WHERE province_name=:n
+            """
+        ), {"n": name}).fetchone()
+        if r and r.g:
+            return {
+                "type": "Feature",
+                "properties": {
+                    "level": "province",
+                    "name": name
+                },
+                "geometry": json.loads(r.g)
+            }
+
+    raise ValueError(f"未找到名称: {name}")
+
+

+ 63 - 2
app/services/cd_prediction_service.py

@@ -19,6 +19,9 @@ import io
 
 from ..config.cd_prediction_config import cd_config
 from ..utils.cd_prediction_wrapper import CdPredictionWrapper
+from ..database import SessionLocal
+from .admin_boundary_service import get_boundary_geojson_by_name
+import json
 
 class CdPredictionService:
     """
@@ -550,6 +553,25 @@ class CdPredictionService:
         @returns {Dict[str, Any]} 预测结果信息
         """
         try:
+            # 用数据库边界覆盖环境变量给集成系统
+            tmp_geojson = None
+            try:
+                db = SessionLocal()
+                feature = get_boundary_geojson_by_name(db, county_name, level="auto")
+                fc = {"type": "FeatureCollection", "features": [feature]}
+                tmp_dir = tempfile.mkdtemp()
+                tmp_geojson = os.path.join(tmp_dir, "boundary.geojson")
+                with open(tmp_geojson, 'w', encoding='utf-8') as f:
+                    json.dump(fc, f, ensure_ascii=False)
+                os.environ['CD_BOUNDARY_FILE'] = tmp_geojson
+            except Exception as _e:
+                self.logger.warning(f"从数据库获取边界失败,回退到默认配置: {str(_e)}")
+            finally:
+                try:
+                    db.close()
+                except Exception:
+                    pass
+
             # 运行作物Cd预测
             self.logger.info(f"为{county_name}执行作物Cd预测")
             prediction_result = self.wrapper.run_prediction_script("crop", raster_config_override)
@@ -565,7 +587,7 @@ class CdPredictionService:
             # 清理旧文件
             self._cleanup_old_files(model_type)
             
-            return {
+            result_obj = {
                 'map_path': copied_files.get('map_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'raster_path': copied_files.get('raster_path'),
@@ -574,6 +596,16 @@ class CdPredictionService:
                 'timestamp': timestamp,
                 'stats': self._get_file_stats(copied_files.get('map_path'))
             }
+            # 清理临时边界
+            try:
+                if tmp_geojson and os.path.exists(tmp_geojson):
+                    import shutil
+                    shutil.rmtree(os.path.dirname(tmp_geojson), ignore_errors=True)
+                if 'CD_BOUNDARY_FILE' in os.environ:
+                    del os.environ['CD_BOUNDARY_FILE']
+            except Exception:
+                pass
+            return result_obj
             
         except Exception as e:
             self.logger.error(f"为{county_name}执行作物Cd预测失败: {str(e)}")
@@ -590,6 +622,25 @@ class CdPredictionService:
         @returns {Dict[str, Any]} 预测结果信息
         """
         try:
+            # 用数据库边界覆盖环境变量给集成系统
+            tmp_geojson = None
+            try:
+                db = SessionLocal()
+                feature = get_boundary_geojson_by_name(db, county_name, level="auto")
+                fc = {"type": "FeatureCollection", "features": [feature]}
+                tmp_dir = tempfile.mkdtemp()
+                tmp_geojson = os.path.join(tmp_dir, "boundary.geojson")
+                with open(tmp_geojson, 'w', encoding='utf-8') as f:
+                    json.dump(fc, f, ensure_ascii=False)
+                os.environ['CD_BOUNDARY_FILE'] = tmp_geojson
+            except Exception as _e:
+                self.logger.warning(f"从数据库获取边界失败,回退到默认配置: {str(_e)}")
+            finally:
+                try:
+                    db.close()
+                except Exception:
+                    pass
+
             # 运行有效态Cd预测
             self.logger.info(f"为{county_name}执行有效态Cd预测")
             prediction_result = self.wrapper.run_prediction_script("effective", raster_config_override)
@@ -605,7 +656,7 @@ class CdPredictionService:
             # 清理旧文件
             self._cleanup_old_files(model_type)
             
-            return {
+            result_obj = {
                 'map_path': copied_files.get('map_path'),
                 'histogram_path': copied_files.get('histogram_path'),
                 'raster_path': copied_files.get('raster_path'),
@@ -614,6 +665,16 @@ class CdPredictionService:
                 'timestamp': timestamp,
                 'stats': self._get_file_stats(copied_files.get('map_path'))
             }
+            # 清理临时边界
+            try:
+                if tmp_geojson and os.path.exists(tmp_geojson):
+                    import shutil
+                    shutil.rmtree(os.path.dirname(tmp_geojson), ignore_errors=True)
+                if 'CD_BOUNDARY_FILE' in os.environ:
+                    del os.environ['CD_BOUNDARY_FILE']
+            except Exception:
+                pass
+            return result_obj
             
         except Exception as e:
             self.logger.error(f"为{county_name}执行有效态Cd预测失败: {str(e)}")

+ 66 - 0
migrations/versions/d89bb52ae481_add_provinces_and_cities_tables.py

@@ -0,0 +1,66 @@
+"""add_provinces_and_cities_tables
+
+Revision ID: d89bb52ae481
+Revises: e52e802fe61a
+Create Date: 2025-08-09 19:40:59.929438
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import geoalchemy2
+
+# revision identifiers, used by Alembic.
+revision = 'd89bb52ae481'
+down_revision = 'e52e802fe61a'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """升级数据库到当前版本"""
+    # 创建provinces表
+    op.create_table('provinces',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(length=100), nullable=False, comment='省名'),
+    sa.Column('code', sa.Integer(), nullable=False, comment='省代码'),
+    sa.Column('geometry', geoalchemy2.types.Geometry(geometry_type='MULTIPOLYGON', srid=4326, from_text='ST_GeomFromEWKT', name='geometry', nullable=False), nullable=False, comment='省级行政区划的几何数据'),
+    sa.Column('geojson', sa.JSON(), nullable=False, comment='完整的GeoJSON数据'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('code')
+    )
+    op.create_index('idx_provinces_code', 'provinces', ['code'], unique=False)
+    op.create_index('idx_provinces_name', 'provinces', ['name'], unique=False)
+    op.create_index(op.f('ix_provinces_id'), 'provinces', ['id'], unique=False)
+    
+    # 创建cities表
+    op.create_table('cities',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(length=100), nullable=False, comment='市名'),
+    sa.Column('code', sa.Integer(), nullable=False, comment='市代码'),
+    sa.Column('province_name', sa.String(length=100), nullable=False, comment='省名'),
+    sa.Column('province_code', sa.Integer(), nullable=False, comment='省代码'),
+    sa.Column('geometry', geoalchemy2.types.Geometry(geometry_type='MULTIPOLYGON', srid=4326, from_text='ST_GeomFromEWKT', name='geometry', nullable=False), nullable=False, comment='市级行政区划的几何数据'),
+    sa.Column('geojson', sa.JSON(), nullable=False, comment='完整的GeoJSON数据'),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('code')
+    )
+    op.create_index('idx_cities_code', 'cities', ['code'], unique=False)
+    op.create_index('idx_cities_name', 'cities', ['name'], unique=False)
+    op.create_index('idx_cities_province_code', 'cities', ['province_code'], unique=False)
+    op.create_index(op.f('ix_cities_id'), 'cities', ['id'], unique=False)
+
+
+def downgrade():
+    """降级数据库到上一个版本"""
+    # 删除cities表
+    op.drop_index(op.f('ix_cities_id'), table_name='cities')
+    op.drop_index('idx_cities_province_code', table_name='cities')
+    op.drop_index('idx_cities_name', table_name='cities')
+    op.drop_index('idx_cities_code', table_name='cities')
+    op.drop_table('cities')
+
+    # 删除provinces表
+    op.drop_index(op.f('ix_provinces_id'), table_name='provinces')
+    op.drop_index('idx_provinces_name', table_name='provinces')
+    op.drop_index('idx_provinces_code', table_name='provinces')
+    op.drop_table('provinces')