torch:针对mask掉的位置不进行softmax
错误方式
希望在进行softmax之前,如果对被mask掉的位置加上一个特别小的数字,那么softmax之后就会变成0。
pad_mask = (1 - doc_token_mask) * (-1999999) # 把原本0的位置变成一个特别小的数字
qk = qk + pad_mask # 加到原来的上面去
qk_softmax = torch.softmax(qk, dim=-1)
但是这样有两个问题:
- 在fp16的情况下,如果自己随便写的数字特别小,会发生inf
- 在计算梯度的时候,如果是加法,会影响梯度计算。
正确方式
qk = qk.masked_fill_(1-doc_token_mask, -float('inf')) # 把原本0的位置直接变成一个特别小的数字,而且 -float('inf')和精度无关
测试:
a=torch.tensor([1,2,3,4]).float()
mask=torch.tensor([1,1,0,0])
b=a.masked_fill_(1-mask, -float('inf')) # tensor([1., 2., -inf, -inf])
torch.softmax(b, dim=0) # tensor([0.2689, 0.7311, 0.0000, 0.0000])
ps: softmax与inf
>>> import torch
>>> import torch.nn.functional as F
>>> F.softmax(torch.Tensor([0, float('-inf')]), -1)
tensor([ 1.0000, 0.0000])
>>> F.softmax(torch.Tensor([0, float('inf')]), -1) # should give [0.0, 1.0]
tensor([ nan, nan])
>>> F.log_softmax(torch.Tensor([0, float('-inf')]), -1)
tensor([ 0.0000, -inf])
>>> F.log_softmax(torch.Tensor([0, float('inf')]), -1)
tensor([ nan, nan])
>>> F.softmax(torch.Tensor([float('-inf'), 0, float('-inf')]), -1)
tensor([ 0.0000, 1.0000, 0.0000])
>>> F.softmax(torch.Tensor([0, float('inf'), 0]), -1) # should give [0.0, 1.0, 0.0]
tensor([ nan, nan, nan])
>>> F.softmax(torch.Tensor([float('-inf'), 0, float('inf')]), -1) # should give [0.0, 0.0, 1.0]
tensor([ nan, nan, nan])