pytorch简单实现dropout
def dropout(X,drop_prob):
X=X.float()//将张量变成浮点数张量
assert 0<=drop_prob<=1//drop_prob不满足0-1则终止程序
keep_prob=1-drop_prob//对未丢弃的函数进行拉伸
if keep_prob==0:
return torch.zeros_like(X)//返回和X大小相同的全0矩阵
mask=(torch.randn(X.shape)<keep_prob).float()//如果该矩阵的元素小于keep_prob的值返回Fasle大于返回True用float让布尔值变为浮点数
return mask*x/keep_prob//让这俩个矩阵进行点对点乘积,再除以keep_prob。