楚彦

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

  利用sklearn执行SVM分类时速度很慢,采用了多进程机制。

  一般多进程用于独立文件操作,各进程之间最好不通信。但此处,单幅影像SVM分类就很慢,只能添加多进程,由于不同进程之间不能共用一个变量(即使共用一个变量,还需要添加变量锁),故将单幅影像分为小幅,每小幅对应一个进程,每个进程对该小幅数据分类完成后,将处理结果输出到临时路径的临时文件中,最好再将临时文件按照顺序合成一个完整的分类结果。

 

def sub_process(in_data, fill_value, classify_model, temp_dir, filename, y):
"""

:param in_data:
:param fill_value:
:param classify_model:
:param temp_dir:
:param filename:
:param y:
:return:
"""
try:
nb_, nl_ ,ns_ = in_data.shape
# fill_value = np.nan

if np.isnan(fill_value):
unvalid_index = np.where(in_data != in_data)
else:
unvalid_index = np.where(in_data == fill_value)
nb_index = unvalid_index[0]
nl_index = unvalid_index[1]
ns_index = unvalid_index[2]
for i,j in zip(nl_index, ns_index):
in_data[:, i, j] = fill_value

temp_data = in_data[0,:,:]
if np.isnan(fill_value):
valid_index = np.where(temp_data == temp_data)
else:
valid_index = np.where(temp_data != fill_value)

# 获取有效特征数据
valid_in_data = []
for i in range(nb_):
in_data_ = in_data[i, :, :]
valid_in_data.append(in_data_[valid_index])
del in_data_

valid_in_data = np.array(valid_in_data)
print(y, "Start predicting")
prediction = classify_model.predict(np.transpose(valid_in_data))
print(y, "Finish predicting")

arr = np.zeros([nl_, ns_], dtype=np.byte)
arr[valid_index] = prediction.astype("float").astype("int8")
print(y, np.min(arr), np.max(arr))

# out_band.WriteArray(arr, 0, y)
# 多进程间不能共享被修改的变量(即使实现共享,还需要添加变量锁,降低效率)
# class_arr[y:(y+nl_), :] = arr

outfile = os.path.join(temp_dir, filename+"_"+str(y)+".tif")
driver = gdal.GetDriverByName("GTiff")
outds = driver.Create(outfile, ns_, nl_, 1, gdal.GDT_Byte, options=["COMPRESS=LZW"])
outband = outds.GetRasterBand(1)
outband.WriteArray(arr)
del arr, outband, outds

del prediction, valid_in_data, valid_index, unvalid_index, in_data

except Exception as error_msg:
print(str(error_msg))

def classify():
  ......
  pools = Pool(self.num_process)
  for y in range(0, nl, block_ysize):
   if y + block_ysize < nl:
  rows = block_ysize
  else:
   rows = nl - y

  in_data = in_ds.ReadAsArray(0, y, block_xsize, rows)

  pools.apply_async(sub_process, args=(in_data, self.fill_value, classify_model, self.temp_dir, output_filename, y))
  pools.close()
  pools.join()

  # 整合各个进程的处理结果
  tempfiles = glob(os.path.join(self.temp_dir, output_filename+"*.tif"))
  for tempfile_ in tempfiles:
  temp_filename = os.path.splitext(os.path.basename(tempfile_))[0]
  strs = temp_filename.split("_")
  y = int(strs[-1])
  temp_ds = gdal.Open(tempfile_)
  temp_data = temp_ds.ReadAsArray()
  out_band.WriteArray(temp_data, 0, y)
  del temp_ds, temp_data

  del out_band, out_ds, in_ds

 

 

posted on 2021-11-03 21:53  楚彦  阅读(381)  评论(0编辑  收藏  举报