Trying out open_llama_7b_v2 locally

open_llama_7b_v2 https://github.com/openlm-research/open_llama

With device_map='auto' across multiple GPUs, total VRAM first came to 906M + 3870M + 3870M + 762M == 9408M; that was most likely a torch/CUDA driver problem, and after some fiddling it returned to normal: 3296M + 3822M + 3822M + 3296M == 14236M.
With a single GPU (device_map='cuda:2'), total VRAM was 13266M.

Which makes sense: in float16 every parameter takes 2 bytes, so VRAM usage is roughly twice the parameter count, about 13 GiB of weights for a 7B model.
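A quick back-of-the-envelope check (a sketch; 7_000_000_000 is the nominal parameter count, the exact figure for open_llama_7b_v2 differs slightly):

# Rough weight-memory estimate for a 7B model in float16.
n_params = 7_000_000_000                    # nominal parameter count (assumption)
bytes_per_param = 2                         # float16 = 2 bytes per parameter
print(n_params * bytes_per_param / 2**30)   # ≈ 13.0 GiB, in line with the ~13266M observed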

Python 3.9.16
torch 2.0.1
transformers 4.39.1
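
Optionally, a minimal sanity check of the stack before loading the model (a sketch under two assumptions: sentencepiece is needed because LlamaTokenizer is SentencePiece-based, and accelerate is needed for device_map='auto'):

import torch
import transformers
import sentencepiece   # required by LlamaTokenizer
import accelerate      # required for device_map='auto'

print('torch', torch.__version__, '| transformers', transformers.__version__)
print('CUDA available:', torch.cuda.is_available(), '| GPU count:', torch.cuda.device_count())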

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

## v2 models
model_path = './'        # directory holding the downloaded open_llama_7b_v2 weights
device_map = 'cuda:2'    # 'auto' to shard across all GPUs, or a single device like 'cuda:2'
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map=device_map,
)

prompt = 'Q: What is the largest animal?\nA:'
if device_map == 'auto':
    # With device_map='auto', accelerate moves inputs to the right device itself.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
else:
    # With a single explicit device, move the inputs there manually.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device_map)
generation_output = model.generate(
    input_ids=input_ids, max_new_tokens=32
)
print(tokenizer.decode(generation_output[0]))

# import time      # keep the process alive to check memory usage, e.g. with nvidia-smi
# time.sleep(10)
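
As an alternative to sleeping and eyeballing nvidia-smi, PyTorch's allocator statistics can be queried directly (a sketch; these figures cover only memory allocated by PyTorch, so they sit somewhat below what nvidia-smi reports, which includes CUDA context overhead):

# Print per-GPU memory actually allocated by PyTorch after generation.
for i in range(torch.cuda.device_count()):
    mib = torch.cuda.memory_allocated(i) / 2**20
    print(f'cuda:{i}: {mib:.0f} MiB allocated')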