记一则 np.nan (np.average, np.argmin) 导致的死循环

设计算法的时候发现有时候算法无法结束,算法采用随机数据

a = np.arange(6).reshape(2, 3)
fail_to_chosen = np.average(a[np.zeros(2, dtype=bool)], axis=0)

fail_to_chosen = array([nan, nan, nan]) 且没有报错(连 RuntimeError/RuntimeWarning 都没有),但是如果不使用参数 axis 则有报错。

接下来对用 fail_to_chosena 进行计算,当然结果都是 nan,接下来我又使用了 np.argmin, np.argmax 这两个函数。

np.argmax([np.nan, 0.])  # 0
np.argmin([np.nan, 0.])  # 0

虽然函数含义相反但是结果都是 np.nan 的位置,于是整个算法进入死循环。

  1. 预分类 a,得到正类和负类,意外得到空的负类 []
  2. [] 使用带参数 axisnp.average 得到 avg = [nan, ...]
  3. (a - avg) ** 2 = [nan] 使用 np.argmin
  4. 更新 a 的分类,显然全为负类;
  5. 得到空的正类,情况和空的负类相同。

于是数据的类别在正类和负类中震荡,算法死循环了。

posted @ 2023-01-22 23:11  Violeshnv  阅读(28)  评论(0编辑  收藏  举报