:)关于torch函数中dim的解释-读这篇就够了-|
关于torch函数中dim的解释-读这篇就够了
1 dim的取值范围
1)-1的作用
0,1,2,-1. 其中-1 最后一维 即 2
0,1,2,3,-1其中-1 最后一维 即3
2)维度
0,1,2,3表示 BCHW,常在CV任务中使用。
0,1,2 表示 CHW, 常在NLP任务中使用。
3)用图来说明
2 NLP代码中实战dim
from torch import nn
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
check = "distilbert-base-uncased-finetuned-sst-2-english"
raw_eng = ["i like this video", "i hate the food", "I dont like the apple"]
tokenizers = AutoTokenizer.from_pretrained(check)
model = AutoModelForSequenceClassification.from_pretrained(check)
# 打印模型结构
print(model)
inputs = tokenizers(raw_eng,
# 是否pad
padding=True,
# 是否截断
truncation=True,
# 返回torch.tensor
return_tensors="pt")
print(inputs)
# 使用toknizers.decode来解tok id 为 英文
eng_content = tokenizers.decode([101, 1045, 2066, 2023, 2678, 102])
print(eng_content)
# 开始推理
out = model(**inputs)
print(out)
# 输出为[2,2] 前面2 为batchsize,后面2为2分类
print(out.logits.shape)
predictions = nn.functional.softmax(out.logits, dim=-1)
print(predictions)
label_dict = model.config.id2label
res_label = predictions.argmax(dim=-1).tolist()
for i in range(len(res_label)):
print(" the %s sample is " % i, label_dict[res_label[i]])
print("---end---")
out 输出为2,2
需要对第一行 两个数据求softmax,概率值(置信度)
需要对第二行(样本2) 两个数据求softmax。
所以 softmax函数dim 应该取CHW中w, 也就是2, 为了统一方便,取-1最后一维。