pytorch简单实现dropout

def dropout(X,drop_prob):
X=X.float()//将张量变成浮点数张量

assert 0<=drop_prob<=1//drop_prob不满足0-1则终止程序

keep_prob=1-drop_prob//对未丢弃的函数进行拉伸

if keep_prob==0:

  return torch.zeros_like(X)//返回和X大小相同的全0矩阵

mask=(torch.randn(X.shape)<keep_prob).float()//如果该矩阵的元素小于keep_prob的值返回Fasle大于返回True用float让布尔值变为浮点数

return mask*x/keep_prob//让这俩个矩阵进行点对点乘积,再除以keep_prob。

 

posted @ 2021-07-26 15:45  祥瑞哈哈哈  阅读(674)  评论(0编辑  收藏  举报