Implementing a simple straight-through estimator (STE) in PyTorch
Summary:
import torch
input = torch.randn(4, requires_grad=True)
output = torch.sign(input)
loss = output.mean()
loss.backward()
print(input)
print(input.grad)
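The snippet above backpropagates through `torch.sign`, whose gradient PyTorch defines as zero everywhere, so `input.grad` comes out as all zeros and nothing upstream can learn. An STE keeps `sign` in the forward pass but pretends it was the identity in the backward pass. Below is a minimal sketch of that identity-gradient variant, written two ways: a custom `torch.autograd.Function` and the common `detach()` trick. The names `SignSTE` and `sign_ste_detach` are hypothetical and not taken from the original post.

```python
import torch


class SignSTE(torch.autograd.Function):
    """Hypothetical helper: sign() forward, identity (straight-through) gradient backward."""

    @staticmethod
    def forward(ctx, x):
        # Forward pass: ordinary, non-differentiable sign().
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: pass the upstream gradient through unchanged,
        # as if forward() had been the identity function.
        return grad_output


def sign_ste_detach(x):
    # Hypothetical helper using the detach trick: the value equals sign(x)
    # (up to float rounding), but autograd only sees the "x +" term,
    # so d(output)/dx = 1.
    return x + (torch.sign(x) - x).detach()


torch.manual_seed(0)

# Plain sign(): the gradient is zero, same as in the snippet above.
x0 = torch.randn(4, requires_grad=True)
torch.sign(x0).mean().backward()
print(x0.grad)

# STE via autograd.Function: each element gets 1/4 from mean().
x1 = torch.randn(4, requires_grad=True)
SignSTE.apply(x1).mean().backward()
print(x1.grad)

# STE via detach(): same gradient as the autograd.Function version.
x2 = torch.randn(4, requires_grad=True)
sign_ste_detach(x2).mean().backward()
print(x2.grad)
```

Both STE versions should give a gradient of 0.25 per element here (the 1/4 comes from `mean()` over four values), while plain `sign` gives zeros. Some binarized-network work uses a clipped variant instead, zeroing the gradient where `|x| > 1`; that would need `ctx.save_for_backward(x)` in `forward` and a mask in `backward`.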
posted @ 2020-09-09 17:35