What is one-hot encoding?

如何在 libtorch 中实现 one-hot 编码？
torch 中 one_hot 函数的源代码是怎样的？


截至本文撰写时（2019 年 11 月），最新的 LibTorch 1.3 版本已经内置了该函数（torch::one_hot）。其实现源代码如下，后面附有调用示例：

#include <torch/torch.h>
#include <c10/util/StringUtil.h>
/// Converts a tensor of class indices into its one-hot representation.
///
/// Based on the ATen implementation of one_hot shipped with LibTorch 1.3.
///
/// @param self         Tensor of class indices; must have dtype torch::kLong.
/// @param num_classes  Size of the appended one-hot dimension. Pass -1 to
///                     infer it as max(self) + 1 (impossible for empty input).
/// @return A tensor with the same options as `self` (so dtype kLong) and shape
///         self.sizes() + {num_classes}; along the new last dimension it holds
///         1 at the position given by each index and 0 elsewhere. For an empty
///         `self` the result is uninitialized (at::empty) but correctly shaped.
/// @throws c10::Error if `self` is not kLong, contains a negative value,
///         contains a value >= num_classes, or is empty while num_classes <= 0.
torch::Tensor one_hot(const torch::Tensor &self, int64_t num_classes) {
	TORCH_CHECK(self.dtype() == torch::kLong, "one_hot is only applicable to index tensor.");
	auto shape = self.sizes().vec();

	// An empty tensor can still be converted to a one-hot representation,
	// but the number of classes cannot be inferred from its contents.
	if (self.numel() == 0) {
		if (num_classes <= 0) {
			AT_ERROR("Can not infer total number of classes from empty tensor.");
		}
		// The one-hot axis must be appended here too, otherwise the result has
		// the rank of the input instead of rank + 1.
		shape.push_back(num_classes);
		return at::empty(shape, self.options());
	}

	// Non-empty tensor: validate the class range before allocating the output.
	TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
	const int64_t max_class = self.max().item().toLong();
	if (num_classes == -1) {
		num_classes = max_class + 1;
	} else {
		TORCH_CHECK(num_classes > max_class, "Class values must be smaller than num_classes.");
	}

	// Append the one-hot axis as the new last dimension. scatter_ below writes
	// along dim -1 using an index of shape self.sizes() + {1}, so the output
	// must have exactly one more dimension than `self`.
	shape.push_back(num_classes);
	torch::Tensor ret = at::zeros(shape, self.options());
	ret.scatter_(-1, self.unsqueeze(-1), 1);
	return ret;
}


	torch::TensorOptions options(torch::kLong);
	auto tensor = torch::tensor({ 0,1,2 }, options);
	std::cout << tensor << std::endl;

		auto one_hot = torch::one_hot(tensor,4);
		std::cout << one_hot << std::endl;
	catch (const c10::Error& watch)
		std::cout << watch.msg() << std::endl;
posted @ 2019-11-16 17:10  學海無涯  阅读(963)  评论(0)  编辑  收藏  举报