Fix .generate(input_ids=...) #485

Merged

borzunov merged 6 commits into main from input-ids on Aug 30, 2023
Conversation

borzunov (Collaborator) commented on Aug 30, 2023

This PR fixes the following code (a popular way to run .generate()):

import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "Maykeye/TinyLlama-v0"

# Load the tokenizer and a distributed model client
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

# Tokenize a prompt and generate a few new tokens
inputs = tokenizer("A cat sat on", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=4)
print(tokenizer.decode(outputs[0]))
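
The PR title also names the explicit keyword form, .generate(input_ids=...). A minimal sketch of that call, assuming the tokenizer and model objects from the snippet above:

# Pass input_ids as an explicit keyword argument instead of unpacking **inputs
input_ids = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(input_ids=input_ids, max_new_tokens=4)
print(tokenizer.decode(outputs[0]))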

borzunov changed the title from "Fix .generate(input_ids=...) and .generate(inputs_embeds=...)" to "Fix .generate(input_ids=...)" on Aug 30, 2023
borzunov merged commit a26559f into main on Aug 30, 2023
borzunov deleted the input-ids branch on August 30, 2023 at 02:59