-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #116 from AnFreTh/main
Release v0.2.2
- Loading branch information
Showing
10 changed files
with
399 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
"""Version information.""" | ||
|
||
# The following line *must* be the last in the module, exactly as formatted: | ||
__version__ = "0.2.0" | ||
__version__ = "0.2.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Re-export the package's public LLM-helper entry points at package level.
from .topic_poster import movie_poster
from .topic_story import story_topic
from .topic_summary import topic_summaries

# Explicit public API of this subpackage.
__all__ = [
    "movie_poster",
    "story_topic",
    "topic_summaries",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from openai import OpenAI, OpenAIError | ||
from loguru import logger | ||
import os | ||
from IPython.display import Image, display | ||
|
||
# Image models accepted by the images endpoint.
ALLOWED_MODELS = ["dall-e-3", "dall-e-2"]


def movie_poster(
    topic,
    api_key,
    model="dall-e-3",
    quality="standard",
    prompt="Create a movie poster that depicts that topic best. Note that the words are ordered in decreasing order of their importance.",
    size="1024x1024",
    return_style="url",
):
    """
    Generate a movie-poster-style image for a topic using OpenAI's DALL-E models.

    Parameters:
    - topic: List of words/phrases, or list of (word, importance) tuples,
      ordered by decreasing importance.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Image model to use; must be one of ALLOWED_MODELS.
    - quality: Image quality passed to the API, default "standard".
    - prompt: Instruction appended after the topic description.
    - size: Size of the generated image, default "1024x1024".
    - return_style: "url" to return the image URL, "plot" to display the
      image inline (returns None in that case).

    Returns:
    - image_url: URL of the generated image when return_style == "url";
      None when return_style == "plot" (image is displayed instead).

    Raises:
    - ValueError: if no API key is available, or if model/return_style
      is not one of the allowed values.
    """
    # Load the API key from the environment if not provided explicitly.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    # Validate arguments *before* constructing the client so bad input
    # fails fast.  A raise (not assert) survives `python -O`.
    if return_style not in ("url", "plot"):
        raise ValueError("Invalid return style. Please choose 'url' or 'plot'")

    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    # Build a textual description of the topic.
    if isinstance(topic[0], tuple):
        # List of (word, importance) pairs: spell out each word's weight.
        topic_description = ", ".join(
            [f"{word} (importance: {importance})" for word, importance in topic]
        )
    else:
        # Plain list of words in descending importance: join them so the
        # prompt contains the words themselves, not a Python list repr.
        topic_description = ", ".join(topic)

    image_prompt = f"Given the following topic: {topic_description}. {prompt}"

    # Logging the operation
    logger.info(f"--- Generating image with model: {model} ---")
    response = client.images.generate(
        model=model,
        prompt=image_prompt,
        size=size,
        quality=quality,
        n=1,
    )

    # Fall back to a readable message if the API returned nothing.
    if response:
        image_url = response.data[0].url
    else:
        image_url = "No image generated. Please try again."

    if return_style == "url":
        return image_url

    elif return_style == "plot":
        display(Image(url=image_url))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from openai import OpenAI, OpenAIError | ||
from loguru import logger | ||
import os | ||
|
||
|
||
# Chat models accepted for story generation.
ALLOWED_MODELS = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
]


def story_topic(
    topic,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Create a creative story that includes the following words:",
    max_tokens=250,
    temperature=0.8,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a creative story containing the topic words using an OpenAI chat model.

    Parameters:
    - topic: List of words or phrases to include in the story.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Chat model to use; must be one of ALLOWED_MODELS.
    - content: System message that sets the assistant's role.
    - prompt: Prompt prefix for the story generation.
    - max_tokens: Maximum tokens for the response (also used as a rough
      word-count hint inside the prompt).
    - temperature: Creativity level for the model.
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for word frequency.
    - presence_penalty: Penalty for word presence.

    Returns:
    - story: Generated story as a string, or an "An error occurred: ..."
      string if the API call raises OpenAIError.

    Raises:
    - ValueError: if no API key is available or the model is invalid.
    """
    # Load the API key from the environment if not provided explicitly.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    if api_key is None:
        # Fail loudly here instead of letting the client raise a confusing
        # error later (consistent with movie_poster).
        raise ValueError("API key is missing. Please provide an API key.")

    # Validate the model *before* constructing the client so bad input fails fast.
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    # Build the user prompt under a separate name so the `prompt` parameter
    # keeps its original value.
    # NOTE(review): with the default prompt this produces a double colon
    # ("...following words:: cat, ..."); confirm whether the trailing ":" in
    # the default prompt is intended.
    user_prompt = (
        f"{prompt}: {', '.join(topic)}. Make it as short as {max_tokens} words."
    )

    # Logging the operation
    logger.info(f"--- Generating story with model: {model} ---")

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": content},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Guard against an empty/invalid response payload.
        if response and len(response.choices) > 0:
            story = response.choices[0].message.content
        else:
            story = "No story generated. Please try again."

        return story

    except OpenAIError as e:
        # Best-effort contract: surface API failures as a message string.
        return f"An error occurred: {str(e)}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from openai import OpenAI, OpenAIError | ||
from loguru import logger | ||
import os | ||
|
||
# Chat models accepted for summary generation.
ALLOWED_MODELS = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
]


def topic_summaries(
    topics,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Provide a 1-2 sentence summary for the following topic:",
    max_tokens=60,
    temperature=0.7,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a 1-2 sentence summary for each topic using an OpenAI chat model.

    Parameters:
    - topics: List of lists, where each sublist contains words/phrases
      representing one topic.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Chat model to use; must be one of ALLOWED_MODELS.
    - content: System message that sets the assistant's role.
    - prompt: Prompt prefix for the summary generation.
    - max_tokens: Maximum tokens for each summary.
    - temperature: Creativity level for the model.
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for word frequency.
    - presence_penalty: Penalty for word presence.

    Returns:
    - summaries: List of summary strings, one per topic; a failed API call
      for a topic yields an "An error occurred: ..." entry instead.

    Raises:
    - ValueError: if no API key is available or the model is invalid.
    """
    # Load the API key from the environment if not provided explicitly.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    if api_key is None:
        # Fail loudly here instead of letting the client raise a confusing
        # error later (consistent with movie_poster).
        raise ValueError("API key is missing. Please provide an API key.")

    # Validate the model *before* constructing the client so bad input fails fast.
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    summaries = []

    for idx, topic in enumerate(topics):
        # Create the prompt for this topic.
        topic_prompt = f"{prompt} {', '.join(topic)}."

        # Logging the operation
        logger.info(f"--- Generating summary for topic {idx} with model: {model} ---")

        # The try/except sits inside the loop on purpose: one failing topic
        # records an error entry without aborting the remaining topics.
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": content},
                    {"role": "user", "content": topic_prompt},
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
            )

            # Guard against an empty/invalid response payload.
            if response and len(response.choices) > 0:
                summary = response.choices[0].message.content
            else:
                summary = "No summary generated. Please try again."

            summaries.append(summary)

        except OpenAIError as e:
            summaries.append(f"An error occurred: {str(e)}")

    return summaries
Oops, something went wrong.