Update openai api #24

Open · wants to merge 2 commits into main
31 changes: 15 additions & 16 deletions gpt.py
@@ -113,6 +113,7 @@ def generate(config, **kwargs):
     elif model_type in ('openai', 'openai-custom', 'gooseai', 'openai-chat'):
         # TODO OpenAI errors
         response, error = openAI_generate(model_type, **kwargs)
+        response = response.dict()
         #save_response_json(response, 'examples/openAI_response.json')
         formatted_response = format_openAI_response(response, kwargs['prompt'], echo=True)
         #save_response_json(formatted_response, 'examples/openAI_formatted_response.json')
@@ -227,7 +228,7 @@ def format_openAI_response(response, prompt, echo=True):
 
 
 @retry(n_tries=3, delay=1, backoff=2, on_failure=lambda *args, **kwargs: ("", None))
-def openAI_generate(model_type, prompt, length=150, num_continuations=1, logprobs=10, temperature=0.8, top_p=1, stop=None,
+def openAI_generate(model_type, prompt, length=150, num_continuations=1, logprobs=5, temperature=0.8, top_p=1, stop=None,
                     model='davinci', logit_bias=None, **kwargs):
     if not logit_bias:
         logit_bias = {}
@@ -240,29 +241,26 @@ def openAI_generate(model_type, prompt, length=150, num_continuations=1, logprobs=10, temperature=0.8, top_p=1, stop=None,
         'logit_bias': logit_bias,
         'n': num_continuations,
         'stop': stop,
+        'prompt': prompt,
         #**kwargs
     }
     if model_type == 'openai-custom':
         params['model'] = model
     else:
         params['engine'] = model
 
 
+    # hardcode to a non-chat model
     if model_type == 'openai-chat':
-        params['messages'] = [{ 'role': "assistant", 'content': prompt }]
-        response = openai.ChatCompletion.create(
-            **params
-        )
-    else:
-        params['prompt'] = prompt
-        response = openai.Completion.create(
-            **params
-        )
+        params['model'] = 'davinci'
 
+    # always use completions instead of chatcompletions for logprobs
+    response = openai.completions.create(
+        **params
+    )
 
     return response, None
 
 
 def search(query, documents, engine="curie"):
     # https://help.openai.com/en/articles/6272952-search-transition-guide
     # TODO use embeddings instead
     # this function is never used anyway
     return openai.Engine(engine).search(
         documents=documents,
         query=query
2 changes: 1 addition & 1 deletion requirements.txt
@@ -15,7 +15,7 @@ idna==2.10
 jsonlines==2.0.0
 kombu==5.0.2
 multiprocess==0.70.11.1
-openai>=0.27.2
+openai>=1.1.1
 pandas==1.3.3
 pathos==0.2.7
 pillow>=9.4.0
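This bump crosses the SDK's v1 rewrite, so it is a breaking change rather than a routine upgrade: resource classes such as `openai.Completion` become `openai.completions` (or methods on an explicit client object), and responses become typed objects. A sketch of the client-style equivalent, assuming `OPENAI_API_KEY` is set and using `davinci-002` purely as an illustrative legacy completions model:

```python
from openai import OpenAI  # requires openai>=1.1.1

client = OpenAI()  # reads OPENAI_API_KEY from the environment by default
completion = client.completions.create(
    model='davinci-002',  # hypothetical choice; any legacy completions model
    prompt='Hello,',
    max_tokens=5,
)
print(completion.choices[0].text)               # typed attribute access...
print(completion.dict()['choices'][0]['text'])  # ...or a plain-dict dump, as this PR does
```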
8 changes: 4 additions & 4 deletions util/gpt_util.py
@@ -27,7 +27,7 @@ def total_logprob(response):
 
 
 def tokenize_ada(prompt):
-    response = openai.Completion.create(
+    response = openai.completions.create(
         engine='ada',
         prompt=prompt,
         max_tokens=0,
@@ -41,7 +41,7 @@ def tokenize_ada(prompt):
 
 
 def prompt_probs(prompt, engine='ada'):
-    response = openai.Completion.create(
+    response = openai.completions.create(
         engine=engine,
         prompt=prompt,
         max_tokens=0,
@@ -57,7 +57,7 @@ def prompt_probs(prompt, engine='ada'):
 # evaluates logL(prompt+target | prompt)
 def conditional_logprob(prompt, target, engine='ada'):
     combined = prompt + target
-    response = openai.Completion.create(
+    response = openai.completions.create(
         engine=engine,
         prompt=combined,
         max_tokens=0,
@@ -135,7 +135,7 @@ def substring_probs(preprompt, content, target, engine='ada', quiet=0):
 # returns a list of substrings of content
 # logL(substring+target | substring) for each substring
 def token_conditional_logprob(content, target, engine='ada'):
-    response = openai.Completion.create(
+    response = openai.completions.create(
         engine=engine,
         prompt=content,
         max_tokens=0,
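All of these helpers lean on the same scoring trick: a completions request with `max_tokens=0`, `echo=True`, and `logprobs` set returns per-token logprobs for the prompt itself instead of generating anything. Note the calls still pass `engine=`, which the 1.x SDK no longer accepts, so a working port likely needs `model=` too. A hedged sketch of `conditional_logprob` under those assumptions (and assuming a legacy completions model that still honors echo-with-logprobs scoring; `davinci-002` is illustrative):

```python
import openai  # assumes openai>=1.1.1 with an API key configured

def conditional_logprob(prompt, target, model='davinci-002'):
    """Evaluate logL(target | prompt) by scoring prompt+target without sampling.
    `model` replaces the old `engine` keyword; the model name is an assumption."""
    combined = prompt + target
    response = openai.completions.create(
        model=model,
        prompt=combined,
        max_tokens=0,   # generate nothing; just score the prompt
        echo=True,      # return logprobs for the prompt tokens themselves
        logprobs=0,
    ).dict()
    lp = response['choices'][0]['logprobs']
    # token_logprobs[i] = logP(token_i | tokens_<i); the first entry is None.
    # Sum only the tokens whose character offset falls inside `target`.
    total = 0.0
    for offset, logprob in zip(lp['text_offset'], lp['token_logprobs']):
        if offset >= len(prompt) and logprob is not None:
            total += logprob
    return total
```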
3 changes: 2 additions & 1 deletion util/multiverse_util.py
@@ -14,12 +14,13 @@ def generate(prompt, engine, goose=False):
     openai.api_key = os.environ.get("OPENAI_API_KEY", None)
     #print('calling engine', engine, 'at endpoint', openai.api_base)
     #print('prompt:', prompt)
-    response = openai.Completion.create(prompt=prompt,
+    response = openai.completions.create(prompt=prompt,
                                         max_tokens=1,
                                         n=1,
                                         temperature=0,
                                         logprobs=100,
                                         model=engine)
+    response = response.dict()
     return response
 
 # TODO multiple "ground truth" trajectories
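This call requests a single token at temperature 0 purely to read the next-token distribution out of `top_logprobs`, which is what drives the multiverse branching. Note that api.openai.com caps `logprobs` at 5, so `logprobs=100` presumably targets a Goose AI-style endpoint. A sketch of how the dumped response might be turned into branch weights (the helper name and usage are illustrative, not part of this PR):

```python
import math

def next_token_distribution(response):
    """Map each candidate next token to its probability, given the
    .dict()-converted response returned by generate() above."""
    top_logprobs = response['choices'][0]['logprobs']['top_logprobs'][0]
    return {token: math.exp(logprob) for token, logprob in top_logprobs.items()}

# Hypothetical usage: inspect the most probable continuations first.
# dist = next_token_distribution(generate('Once upon a time', 'davinci-002'))
# for token, p in sorted(dist.items(), key=lambda kv: -kv[1])[:5]:
#     print(f'{token!r}: {p:.3f}')
```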