numpy中比较两个数字的断言函数

比如在比较torch模型输出和onnxruntime输出,

import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

————————————————————————————————————

 

import numpy as  np
 

断言函数

 

单元测试,单元测试是对一部分代码进行测试,可以提高代码质量,可重复性测试等.单元测试通常使用断言函数,在计算时,通常要考虑浮点数比较问题,numpy.testing包中包含很多实用的工具函数.

assert_almost_equal断言精度近似相等

#指定精度为小数点后七位
a = 0.123456789
b = 0.123456780
print(np.testing.assert_almost_equal(a,b,decimal=8))
#None 表示没有异常

#指定精度为小数点后九位
# print(np.testing.assert_almost_equal(a,b,decimal=9))

# AssertionError: 
# Arrays are not almost equal to 9 decimals
#  ACTUAL: 0.123456789
#  DESIRED: 0.12345678

c = 0.122
d = 0.121
print(np.testing.assert_almost_equal(c,d,decimal=3))
 
None
None
 

需要注意的是,如果在指定为数上数值相差1则仍然不会报错,如c和d所示.同样的道理,若指定a=0.123456789和b=0.123456788则指定decimal=9是不会出现异常的.

assert_approx_equal断言有效位近似相等

#指定有效位为8
a = 0.123456789
b = 0.123456788
print(np.testing.assert_approx_equal(a,b,significant=8))

#指定有效位为9
# print(np.testing.assert_approx_equal(a,b,significant=9))

# AssertionError: 
# Items are not equal to 9 significant digits:
#  ACTUAL: 0.123456789
#  DESIRED: 0.12345678
 
None
None
 

与assert_almost_equal类似,如果在指定为数上数值相差1则仍然不会报错.上面两个函数是精度和有效位的差别,但在实际使用中并没有差别。

assert_array_almost_equal数组近似比较

assert_array_almost_equal数组会首先比较维度,然后再比较数值。

# 精度为8
a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_array_almost_equal(a,b,decimal=8))

# 精度为9
# print(np.testing.assert_array_almost_equal(a,b,decimal=9))

# Arrays are not almost equal to 9 decimals

# Mismatched elements: 1 / 2 (50%)
# Max absolute difference: 9.e-09
# Max relative difference: 7.29000059e-08
#  x: array([0.         , 0.123456789])
#  y: array([0.        , 0.12345678])

c = np.array([0,0.123456780,0]) #三维
# print(np.testing.assert_array_almost_equal(a,c,decimal=8))

# AssertionError: 
# Arrays are not almost equal to 8 decimals

# (shapes (2,), (3,) mismatch)
#  x: array([0.        , 0.12345679])
#  y: array([0.        , 0.12345678, 0.        ])
 
None
 

assert_array_equal比较数组相等

严格比较数组的维度与元素值

a = np.array([0,0.123456789])
b = np.array([0,0.123456789])
print(np.testing.assert_array_equal(a,b))
 
None
 

assert_allclose比较数组相等

与assert_array_equal不同的是,该函数有atol(绝对容差限)、rtol参数(相对容差限)。比如对于数组a,b,则将测试是否满足
∣ a − b ∣ ≤ ( a t o l + r t o l ∗ ∣ b ∣ ) |a-b| \leq (atol+rtol*|b|) ab(atol+rtolb)

a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_allclose(a,b,rtol=1e-7,atol=0))
 
None
 

assert_array_less比较数组大小

assert_array_less(a,b)严格比较数组a是否小于b

a = np.array([0,0.1])
b = np.array([0.1,0.2])
print(np.testing.assert_array_less(a,b))
 
None
 

assert_equal比较对象相等

这里的对象可以是数组、列表、元组以及字典

# print(np.testing.assert_equal((1,2),(1,3)))  #出现异常

print(np.testing.assert_equal((1,2),(1,2)))

print(np.testing.assert_equal([1,2],[1,2]))

print(np.testing.assert_equal({'1':1,'2':2},{'1':1,'2':2}))
None
None
None
 

assert_string_equal 比较字符串相等

不仅比较字符,还比较大小写

print(np.testing.assert_string_equal('abc','abc'))

# print(np.testing.assert_string_equal('Abc','abc')) #出现异常
 

 

None

assert_array_almost_equal_nulp比较浮点数

机器精度(machine epsilon)是指浮点运算中的相对舍入误差上界。即机器允许在机器精度的范围下的误差。

#使用finfo函数确定机器精度
eps = np.finfo(float).eps
print(eps)

a = 1.0
b = a + eps #加上机器精度
c = a + 2*eps #加上2个机器精度 超出范围会出现异常
d = a + 1.4*eps
print(np.testing.assert_array_almost_equal_nulp(a,b))

# print(np.testing.assert_array_almost_equal_nulp(a,c))
# AssertionError: X and Y are not equal to 1 ULP (max is 2)

print(np.testing.assert_array_almost_equal_nulp(a,d))
 
2.220446049250313e-16
None
None
 

 

报错信息中的ULP(uint of Least Precision),指的是浮点数的最小精确度数。根据IEEE 754标准,四则运b算的标准必须保持在半个ULP内。在上面的c中,超过了的1个eps,实际测试中,不超过1.4倍的eps也是不会出现异常的。

assert_array_max_ulp多ULP浮点数比较

该函数可以通过maxulp参数设置多个ULP(默认为1)来增大浮点数比较的允许误差。如果两个浮点数的差距大于所设置或者默认的ULP,则函数assert_array_max_ulp会出现异常,若在误差范围内,则函数会返回两者所差的ULP个数(按第一个小数位四舍五入)。

a = 1.0
b = a + 2*eps
c = a + 1.499*eps
# print(np.testing.assert_array_max_ulp(a,b))
# AssertionError: Arrays are not almost equal up to 1 ULP

print(np.testing.assert_array_max_ulp(a,b,maxulp=2))

print(np.testing.assert_array_max_ulp(a,a))

print(np.testing.assert_array_max_ulp(a,c,maxulp=2))
 
2.0
0.0
1.0
 

单元测试

# python中的单元测试
#编写阶层函数
def function(n):
    if n==0:
        return 1
    elif n<0:
        raise ValueError("输入的值不合法")
    else:
        array = np.arange(1,n+1)
        return np.cumprod(array)[-1]
    

print(function(0))
print(function(9))
# function(-1)
# ValueError: 输入的值不合法
 
1
362880
#利用unittest模块进行单元测试
import unittest
import numpy as np
def function(n):
    if n==0:
        return 1
    elif n<0:
        raise ValueError("输入的值不合法")
    else:
        array = np.arange(1,n+1)
        return np.cumprod(array)[-1]
    

class FactoyiaTest(unittest.TestCase):
    """继承unittest.TestCase类"""
    def test_factorial(self):
        #计算3的阶层
        self.assertEqual(6,function(3))
        
    def test_zero(self):
        #计算0的阶层
        self.assertEqual(1,function(0))
    def test_negative(self):
        self.assertRaises(IndexError,function(-1))

if  __name__ =='__main__':
    unittest.main()
    
# .E.
# ======================================================================
# ERROR: test_negative (__main__.FactoyiaTest)
# ----------------------------------------------------------------------
# Traceback (most recent call last):
#   File "C:/Users/zhj/Desktop/untitled3.py", line 32, in test_negative
#     self.assertRaises(IndexError,function(-1))
#   File "C:/Users/zhj/Desktop/untitled3.py", line 16, in function
#     raise ValueError("输入的值不合法")
# ValueError: 输入的值不合法

# ----------------------------------------------------------------------
# Ran 3 tests in 0.010s

# FAILED (errors=1)
 

 

将上面的程序放在单独的.py文件,运行时提示错误提示,也就在使用function时出现了非法输入。

posted @ 2024-03-13 11:34  海_纳百川  阅读(64)  评论(0编辑  收藏  举报
本站总访问量