123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 单元分类功能演示脚本
- 该脚本帮助用户快速启动和测试单元分类功能
- """
- import os
- import sys
- import subprocess
- import time
- import threading
- from pathlib import Path
- def print_banner():
- """打印横幅"""
- print("=" * 70)
- print("单元分类功能演示")
- print("=" * 70)
- print()
- def check_dependencies():
- """检查依赖项"""
- print("🔍 检查依赖项...")
-
- required_packages = [
- 'fastapi',
- 'uvicorn',
- 'sqlalchemy',
- 'shapely',
- 'psycopg2',
- 'geoalchemy2'
- ]
-
- missing_packages = []
-
- for package in required_packages:
- try:
- __import__(package)
- print(f"✓ {package}")
- except ImportError:
- missing_packages.append(package)
- print(f"✗ {package} (缺失)")
-
- if missing_packages:
- print("\n❌ 缺少依赖项,请先安装:")
- print("conda env create -f environment.yml")
- print("或手动安装: pip install " + " ".join(missing_packages))
- return False
-
- print("✓ 所有依赖项已安装")
- return True
- def check_database_config():
- """检查数据库配置"""
- print("\n🔍 检查数据库配置...")
-
- # 回到项目根目录查找配置文件
- project_root = Path(__file__).parent.parent.parent
- config_file = project_root / "config.env"
-
- if not config_file.exists():
- print("❌ 配置文件 config.env 不存在")
- return False
-
- # 检查基本配置项
- required_vars = ['DB_HOST', 'DB_PORT', 'DB_NAME', 'DB_USER', 'DB_PASSWORD']
-
- try:
- from dotenv import load_dotenv # type: ignore
- load_dotenv(str(config_file))
-
- for var in required_vars:
- if not os.getenv(var):
- print(f"❌ 环境变量 {var} 未设置")
- return False
-
- print("✓ 数据库配置正确")
- return True
-
- except Exception as e:
- print(f"❌ 检查数据库配置时出错: {e}")
- return False
- def start_server():
- """启动服务器"""
- print("\n🚀 启动FastAPI服务器...")
- print("服务器将在 http://localhost:8000 启动")
- print("API文档地址: http://localhost:8000/docs")
- print("按 Ctrl+C 停止服务器")
- print("-" * 50)
-
- try:
- # 切换到项目根目录
- project_root = Path(__file__).parent.parent.parent
- os.chdir(project_root)
-
- # 使用subprocess启动uvicorn
- process = subprocess.Popen([
- sys.executable, '-m', 'uvicorn',
- 'main:app',
- '--reload',
- '--host', '0.0.0.0',
- '--port', '8000'
- ])
-
- # 等待服务器启动
- time.sleep(3)
-
- return process
-
- except Exception as e:
- print(f"❌ 启动服务器失败: {e}")
- return None
- def run_tests():
- """运行测试"""
- print("\n🧪 运行集成测试...")
- print("-" * 50)
-
- try:
- # 等待服务器完全启动
- time.sleep(2)
-
- # 运行测试脚本
- project_root = Path(__file__).parent.parent.parent
- test_file = project_root / "tests" / "integration" / "test_unit_grouping.py"
-
- if test_file.exists():
- subprocess.run([sys.executable, str(test_file)])
- else:
- print(f"❌ 测试文件不存在: {test_file}")
-
- except Exception as e:
- print(f"❌ 运行测试失败: {e}")
- def show_usage_info():
- """显示使用信息"""
- print("\n📖 使用说明:")
- print("-" * 50)
- print("1. 主要API端点:")
- print(" GET /api/unit-grouping/h_xtfx - 获取所有单元的h_xtfx分类")
- print(" GET /api/unit-grouping/statistics - 获取统计信息")
- print(" GET /api/unit-grouping/unit/{unit_id} - 获取特定单元的h_xtfx值")
- print()
- print("2. 测试示例:")
- print(" curl http://localhost:8000/api/unit-grouping/h_xtfx")
- print(" curl http://localhost:8000/api/unit-grouping/statistics")
- print()
- print("3. 查看完整API文档:")
- print(" http://localhost:8000/docs")
- print()
- print("4. 重新运行测试:")
- print(" python tests/integration/test_unit_grouping.py")
- print()
- print("5. 查看功能文档:")
- print(" docs/features/unit-grouping/README.md")
- def main():
- """主函数"""
- print_banner()
-
- # 检查依赖项
- if not check_dependencies():
- return
-
- # 检查数据库配置
- if not check_database_config():
- return
-
- # 启动服务器
- server_process = start_server()
- if not server_process:
- return
-
- try:
- # 运行测试
- test_thread = threading.Thread(target=run_tests)
- test_thread.daemon = True
- test_thread.start()
-
- # 等待测试完成
- test_thread.join(timeout=30)
-
- # 显示使用信息
- show_usage_info()
-
- print("\n🎉 服务器正在运行中...")
- print("按 Ctrl+C 退出")
-
- # 等待用户中断
- server_process.wait()
-
- except KeyboardInterrupt:
- print("\n\n👋 正在关闭服务器...")
- server_process.terminate()
- server_process.wait()
- print("服务器已关闭")
-
- except Exception as e:
- print(f"\n❌ 运行时出错: {e}")
- if server_process:
- server_process.terminate()
- if __name__ == "__main__":
- main()
|