Python小练习:裁减函数(Clip Function)
作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/
本文介绍两种数据裁剪方法,将原始数据裁剪到某一指定范围内。
1. clip_function_test.py
1 # -*- coding: utf-8 -*- 2 # Author:凯鲁嘎吉 Coral Gajic 3 # https://www.cnblogs.com/kailugaji/ 4 # Python小练习:裁减函数(Clip Function) 5 import torch 6 import numpy as np 7 import matplotlib.pyplot as plt 8 plt.rc('font',family='Times New Roman') 9 # 裁剪范围 10 LOG_STD_MAX = 2 11 LOG_STD_MIN = -10 12 def clip_function( 13 x: torch.Tensor, 14 bound_mode: str 15 ) -> torch.Tensor: 16 if bound_mode == "clamp": # 将x裁剪到[-10, 2] 17 # 大于2的统一设为2,小于-10的统一设为-10 18 x = torch.clamp(x, LOG_STD_MIN, LOG_STD_MAX) 19 elif bound_mode == "tanh": # 将x裁剪到[-10, 2] 20 scale = (LOG_STD_MAX-LOG_STD_MIN) / 2 # 6 21 x = (torch.tanh(x)+1) * scale + LOG_STD_MIN 22 # tanh:[-1, 1], torch.tanh()+1:[0, 2] 23 # (torch.tanh(x)+1) * scale:[0, 12] 24 # (torch.tanh(x)+1) * scale + LOG_STD_MIN:[-10, 2] 25 elif bound_mode == "no": 26 x = x 27 else: 28 raise NotImplementedError 29 return x 30 31 torch.manual_seed(0) 32 x = torch.randn(2, 3)*10 33 print('原始数据:\n', x) 34 35 str1 = 'clamp' 36 print('裁剪算子:', str1) 37 y = clip_function(x, str1) 38 print('裁剪后:\n', y) 39 40 str2 = 'tanh' 41 print('裁剪算子:', str2) 42 y = clip_function(x, str2) 43 print('裁剪后:\n', y) 44 45 # --------------------画图------------------------ 46 num = 1000 47 a = torch.randn(num)*10.0 48 a, _ = torch.sort(a) 49 b1 = clip_function(a, str1) 50 b2 = clip_function(a, str2) 51 # 手动设置横纵坐标范围 52 plt.xlim([0, num]) 53 plt.ylim([a.min(), a.max()]) 54 aa = np.arange(0, num) 55 plt.plot(aa, a, color = 'green', ls = '-', label = 'data') 56 plt.plot(aa, b1, color = 'red', ls = '-', label = str1) 57 plt.plot(aa, b2, color = 'blue', ls = '-', label = str2) 58 # 画2条不起眼的虚线 59 plt.plot([0, num], [LOG_STD_MIN, LOG_STD_MIN], color = 'gray', ls = '--', alpha = 0.3) 60 plt.plot([0, num], [LOG_STD_MAX, LOG_STD_MAX], color = 'gray', ls = '--', alpha = 0.3) 61 # 横纵坐标轴 62 plt.xlabel('x') 63 plt.ylabel('clip(x)') 64 plt.legend(loc = 2) 65 plt.tight_layout() 66 plt.savefig('Clip Function.png', bbox_inches='tight', dpi=500) 67 plt.show()
2. 结果
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Other/clip_function_test.py" 原始数据: tensor([[ 15.4100, -2.9343, -21.7879], [ 5.6843, -10.8452, -13.9860]]) 裁剪算子: clamp 裁剪后: tensor([[ 2.0000, -2.9343, -10.0000], [ 2.0000, -10.0000, -10.0000]]) 裁剪算子: tanh 裁剪后: tensor([[ 2.0000, -9.9662, -10.0000], [ 1.9999, -10.0000, -10.0000]]) Process finished with exit code 0