CTranslate2 Notes
CTranslate2 accelerates Transformer model inference and reduces the models' memory footprint.
Models currently supported by CTranslate2:
Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, LLaMA
How to use
pip install ctranslate2

import ctranslate2

# Translation API: translate a batch of pre-tokenized sentences.
translator = ctranslate2.Translator(translation_model_path)
translator.translate_batch(tokens)

# Generation API: continue generating from a batch of token prefixes.
generator = ctranslate2.Generator(generation_model_path)
generator.generate_batch(start_tokens)
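Both APIs operate on tokens (lists of strings) rather than raw text, so tokenization and detokenization are your responsibility. A minimal sketch of building the List[List[str]] batch shape that translate_batch expects, using a naive whitespace split purely for illustration (a real model needs the tokenizer it was trained with, e.g. SentencePiece):

```python
# Illustration only: real CTranslate2 models require the subword tokenizer
# they were trained with, not a whitespace split.

def to_token_batch(sentences):
    """Turn raw sentences into the List[List[str]] shape translate_batch expects."""
    return [s.split() for s in sentences]

batch = to_token_batch(["Hello world", "How are you"])
print(batch)  # [['Hello', 'world'], ['How', 'are', 'you']]
```

On the way out you do the reverse: join (or detokenize) the token hypotheses in each result back into text.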
BLOOM model
First convert the Hugging Face checkpoint to the CTranslate2 format:
ct2-transformers-converter --model bigscience/bloom-560m --output_dir bloom-560m
import ctranslate2
import transformers

# Load the converted model and the original Hugging Face tokenizer.
generator = ctranslate2.Generator("bloom-560m")
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m")

text = "Hello, I am"
# generate_batch expects token strings, not token ids.
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
results = generator.generate_batch([start_tokens], max_length=30, sampling_topk=10)
print(tokenizer.decode(results[0].sequences_ids[0]))
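sampling_topk=10 restricts sampling at each decoding step to the 10 most probable tokens, then draws from that renormalized subset. A toy re-implementation of the rule with a made-up score table (not CTranslate2's actual internals):

```python
import math
import random

def sample_topk(scores, k, rng):
    """Sample one token from the k highest-scoring candidates.

    `scores` maps token -> log-probability; this mimics what a
    top-k sampling step does (illustration only, hypothetical scores).
    """
    top = sorted(scores, key=scores.get, reverse=True)[:k]
    weights = [math.exp(scores[t]) for t in top]
    return rng.choices(top, weights=weights, k=1)[0]

scores = {"the": -0.5, "a": -1.0, "cat": -2.0, "dog": -5.0}
token = sample_topk(scores, k=2, rng=random.Random(0))
# With k=2, only "the" or "a" can ever be drawn, however many times we sample.
```

A larger k makes the output more diverse; k=1 degenerates to greedy decoding.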