:)关于torch函数中dim的解释-读这篇就够了-|

关于torch函数中dim的解释-读这篇就够了

1 dim的取值范围

1)-1的作用

  0,1,2,-1.  其中-1 最后一维 即 2

  0,1,2,3,-1其中-1 最后一维 即3

2)维度

0,1,2,3表示 BCHW,常在CV任务中使用。

0,1,2 表示 CHW, 常在NLP任务中使用。

3)用图来说明

 

 2 NLP代码中实战dim

from torch import nn
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
check = "distilbert-base-uncased-finetuned-sst-2-english"
raw_eng = ["i like this video", "i hate the food", "I dont like the apple"]

tokenizers = AutoTokenizer.from_pretrained(check)
model = AutoModelForSequenceClassification.from_pretrained(check)
# 打印模型结构
print(model)
inputs = tokenizers(raw_eng,
# 是否pad
padding=True,
# 是否截断
truncation=True,
# 返回torch.tensor
return_tensors="pt")
print(inputs)
# 使用toknizers.decode来解tok id 为 英文
eng_content = tokenizers.decode([101, 1045, 2066, 2023, 2678, 102])
print(eng_content)

# 开始推理
out = model(**inputs)
print(out)
# 输出为[2,2] 前面2 为batchsize,后面2为2分类
print(out.logits.shape)

predictions = nn.functional.softmax(out.logits, dim=-1)
print(predictions)
label_dict = model.config.id2label
res_label = predictions.argmax(dim=-1).tolist()

for i in range(len(res_label)):
print(" the %s sample is " % i, label_dict[res_label[i]])
print("---end---")

  

out 输出为2,2 

 

需要对第一行 两个数据求softmax,概率值(置信度)

需要对第二行(样本2) 两个数据求softmax。

 

所以 softmax函数dim 应该取CHW中w, 也就是2, 为了统一方便,取-1最后一维。

 

posted on 2023-04-08 11:18  lexn  阅读(444)  评论(0编辑  收藏  举报

导航