CTranslate2 Notes

CTranslate2 accelerates Transformer model inference and reduces the model's memory footprint.

Models currently supported by CTranslate2:


Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper


Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, LLaMA

How to use


pip install ctranslate2



import ctranslate2

# Translation (encoder-decoder models): tokens is a list of token lists.
translator = ctranslate2.Translator(translation_model_path)
translator.translate_batch(tokens)

# Generation (decoder-only models): start_tokens is a list of token lists.
generator = ctranslate2.Generator(generation_model_path)
generator.generate_batch(start_tokens)
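Note that both batch APIs take string tokens rather than integer token IDs. As a minimal sketch of the expected input shape (using naive whitespace splitting purely for illustration; a real model needs its own subword tokenizer):

```python
# A batch is a list of examples, and each example is a list of string tokens.
# Whitespace splitting here only illustrates the data structure; in practice
# use the tokenizer that matches the converted model (SentencePiece, BPE, ...).
def to_token_batch(texts):
    return [text.split() for text in texts]

batch = to_token_batch(["Hello world", "How are you"])
print(batch)  # [['Hello', 'world'], ['How', 'are', 'you']]
```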


Example: BLOOM


ct2-transformers-converter --model bigscience/bloom-560m --output_dir bloom-560m
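The converter can also quantize weights during conversion, which shrinks the model on disk and typically speeds up inference. For example, using the converter's `--quantization` option (values such as `int8` and `float16` are supported):

```shell
# Convert BLOOM-560m and quantize weights to 8-bit integers.
ct2-transformers-converter --model bigscience/bloom-560m \
    --output_dir bloom-560m-int8 --quantization int8
```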



import ctranslate2
import transformers

# Load the converted model and the matching Hugging Face tokenizer.
generator = ctranslate2.Generator("bloom-560m")
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m")

text = "Hello, I am"
# CTranslate2 expects string tokens, so convert the encoded IDs back to tokens.
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
results = generator.generate_batch([start_tokens], max_length=30, sampling_topk=10)
# sequences_ids holds the generated token IDs; decode them back to text.
print(tokenizer.decode(results[0].sequences_ids[0]))