setup_data.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """
  2. 数据迁移脚本
  3. Data Migration Script
  4. 用于将现有的三个项目文件夹中的数据和模型文件复制到新的集成项目结构中
  5. """
  6. import os
  7. import shutil
  8. import logging
  9. from pathlib import Path
  10. def setup_logging():
  11. """设置日志"""
  12. logging.basicConfig(
  13. level=logging.INFO,
  14. format='%(asctime)s - %(levelname)s - %(message)s',
  15. handlers=[
  16. logging.FileHandler('setup_data.log', encoding='utf-8'),
  17. logging.StreamHandler()
  18. ]
  19. )
  20. def copy_file_safe(src, dst):
  21. """
  22. 安全复制文件
  23. @param src: 源文件路径
  24. @param dst: 目标文件路径
  25. """
  26. try:
  27. # 确保目标目录存在
  28. os.makedirs(os.path.dirname(dst), exist_ok=True)
  29. if os.path.exists(src):
  30. shutil.copy2(src, dst)
  31. logging.info(f"复制成功: {src} -> {dst}")
  32. return True
  33. else:
  34. logging.warning(f"源文件不存在: {src}")
  35. return False
  36. except Exception as e:
  37. logging.error(f"复制失败: {src} -> {dst}, 错误: {str(e)}")
  38. return False
  39. def setup_crop_cd_model():
  40. """设置作物Cd模型文件"""
  41. logging.info("=" * 50)
  42. logging.info("设置作物Cd模型文件...")
  43. # 源目录
  44. src_dir = "../作物Cd模型文件与数据/作物Cd模型文件与数据"
  45. # 目标目录
  46. model_files_dir = "models/crop_cd_model/model_files"
  47. data_dir = "models/crop_cd_model/data"
  48. # 复制模型文件
  49. model_files = [
  50. "cropCdNN.pth",
  51. "cropCd_mean.npy",
  52. "cropCd_scale.npy",
  53. "constrained_nn6.py"
  54. ]
  55. for file in model_files:
  56. src = os.path.join(src_dir, file)
  57. dst = os.path.join(model_files_dir, file)
  58. copy_file_safe(src, dst)
  59. # 复制数据文件
  60. data_files = [
  61. "areatest.csv"
  62. ]
  63. for file in data_files:
  64. src = os.path.join(src_dir, file)
  65. dst = os.path.join(data_dir, file)
  66. copy_file_safe(src, dst)
  67. # 复制坐标文件到共享数据目录
  68. coord_src = os.path.join(src_dir, "坐标.csv")
  69. coord_dst = "data/coordinates/坐标.csv"
  70. copy_file_safe(coord_src, coord_dst)
  71. def setup_effective_cd_model():
  72. """设置有效态Cd模型文件"""
  73. logging.info("=" * 50)
  74. logging.info("设置有效态Cd模型文件...")
  75. # 源目录
  76. src_dir = "../有效态Cd模型文件与数据/有效态Cd模型文件与数据"
  77. # 目标目录
  78. model_files_dir = "models/effective_cd_model/model_files"
  79. data_dir = "models/effective_cd_model/data"
  80. # 复制模型文件
  81. model_files = [
  82. "EffCdNN6C.pth",
  83. "EffCd_mean.npy",
  84. "EffCd_scale.npy",
  85. "constrained_nn6C.py"
  86. ]
  87. for file in model_files:
  88. src = os.path.join(src_dir, file)
  89. dst = os.path.join(model_files_dir, file)
  90. copy_file_safe(src, dst)
  91. # 复制数据文件
  92. data_files = [
  93. "areatest.csv"
  94. ]
  95. for file in data_files:
  96. src = os.path.join(src_dir, file)
  97. dst = os.path.join(data_dir, file)
  98. copy_file_safe(src, dst)
  99. def setup_irrigation_water_files():
  100. """设置灌溉水项目文件"""
  101. logging.info("=" * 50)
  102. logging.info("设置灌溉水项目文件...")
  103. # 源目录
  104. src_dir = "../Irrigation_Water/Irrigation_Water"
  105. # 复制栅格文件
  106. raster_files = [
  107. "Raster/meanTemp.tif",
  108. "Raster/lechang.shp",
  109. "Raster/lechang.shx",
  110. "Raster/lechang.dbf",
  111. "Raster/lechang.prj"
  112. ]
  113. for file in raster_files:
  114. src = os.path.join(src_dir, file)
  115. dst = os.path.join("output/raster", os.path.basename(file))
  116. copy_file_safe(src, dst)
  117. # 复制示例数据文件
  118. data_files = [
  119. "Data/Final_predictions.csv"
  120. ]
  121. for file in data_files:
  122. src = os.path.join(src_dir, file)
  123. dst = os.path.join("data/final", os.path.basename(file))
  124. copy_file_safe(src, dst)
  125. def create_directory_structure():
  126. """创建目录结构"""
  127. logging.info("=" * 50)
  128. logging.info("创建目录结构...")
  129. directories = [
  130. "models/crop_cd_model/model_files",
  131. "models/crop_cd_model/data",
  132. "models/effective_cd_model/model_files",
  133. "models/effective_cd_model/data",
  134. "data/coordinates",
  135. "data/predictions",
  136. "data/final",
  137. "output/raster",
  138. "output/figures",
  139. "output/reports",
  140. "analysis",
  141. "utils"
  142. ]
  143. for directory in directories:
  144. os.makedirs(directory, exist_ok=True)
  145. logging.info(f"创建目录: {directory}")
  146. def create_requirements_txt():
  147. """创建requirements.txt文件"""
  148. logging.info("=" * 50)
  149. logging.info("创建requirements.txt文件...")
  150. requirements = """# Cd预测集成系统依赖包
  151. numpy>=1.21.0
  152. pandas>=1.3.0
  153. torch>=1.9.0
  154. scikit-learn>=1.0.0
  155. geopandas>=0.10.0
  156. rasterio>=1.2.0
  157. matplotlib>=3.4.0
  158. seaborn>=0.11.0
  159. shapely>=1.7.0
  160. """
  161. with open("requirements.txt", "w", encoding="utf-8") as f:
  162. f.write(requirements)
  163. logging.info("requirements.txt 创建完成")
  164. def main():
  165. """主函数"""
  166. setup_logging()
  167. logging.info("开始设置Cd预测集成系统...")
  168. logging.info("=" * 60)
  169. try:
  170. # 创建目录结构
  171. create_directory_structure()
  172. # 设置各个模型的文件
  173. setup_crop_cd_model()
  174. setup_effective_cd_model()
  175. setup_irrigation_water_files()
  176. # 创建requirements.txt
  177. create_requirements_txt()
  178. logging.info("=" * 60)
  179. logging.info("🎉 Cd预测集成系统设置完成!")
  180. logging.info("=" * 60)
  181. print("\n" + "=" * 60)
  182. print("🎉 数据迁移完成!")
  183. print("=" * 60)
  184. print("下一步操作:")
  185. print("1. 安装依赖包:pip install -r requirements.txt")
  186. print("2. 运行主程序:python main.py")
  187. print("3. 查看输出结果在 output/ 目录中")
  188. print("=" * 60)
  189. except Exception as e:
  190. logging.error(f"设置过程中发生错误: {str(e)}")
  191. raise
  192. if __name__ == "__main__":
  193. main()