Skip to content

Commit

Permalink
Merge pull request #116 from AnFreTh/main
Browse files Browse the repository at this point in the history
Release v0.2.2
  • Loading branch information
AnFreTh authored Feb 17, 2025
2 parents bc3ddfb + 90dfedb commit 485e172
Show file tree
Hide file tree
Showing 10 changed files with 399 additions and 72 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ You can install STREAM directly from PyPI or from the GitHub repository:
pip install git+https://github.com/AnFreTh/STREAM.git
```

3. **Download necessary NLTK resources**:

To download all necessary NLTK resources required for some models, simply run:

```python
import nltk
def ensure_nltk_resources():
    """Download the NLTK resources STREAM models need, skipping any already present."""
    # nltk.data.find() requires the category-qualified path (e.g. "corpora/stopwords");
    # looking up the bare name always raises LookupError and forces a re-download.
    resources = {
        "stopwords": "corpora/stopwords",
        "wordnet": "corpora/wordnet",
        "punkt_tab": "tokenizers/punkt_tab",
        "brown": "corpora/brown",
        "averaged_perceptron_tagger": "taggers/averaged_perceptron_tagger",
    }
    for name, path in resources.items():
        try:
            nltk.data.find(path)
        except LookupError:
            try:
                print(f"Downloading NLTK resource: {name}")
                nltk.download(name)
            except Exception as e:
                print(f"Failed to download {name}: {e}")
ensure_nltk_resources()
```
4. **Install requirements for add-ons**:
To use STREAMS visualizations, simply run:
```bash
Expand Down
Binary file added assets/movie_poster_topic1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run(self):
"plotting": ["dash", "plotly", "matplotlib", "wordcloud"],
"bertopic": ["hdbscan"],
"dcte": ["pyarrow", "setfit"],
"experimental": ["openai"],
}


Expand Down
2 changes: 1 addition & 1 deletion stream_topic/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.2.0"
__version__ = "0.2.2"
10 changes: 10 additions & 0 deletions stream_topic/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .topic_poster import movie_poster
from .topic_story import story_topic
from .topic_summary import topic_summaries


# Public API of the experimental subpackage.
__all__ = [
    "movie_poster",
    "story_topic",
    "topic_summaries",
]
88 changes: 88 additions & 0 deletions stream_topic/experimental/topic_poster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from openai import OpenAI, OpenAIError
from loguru import logger
import os
from IPython.display import Image, display

# Image-generation models accepted by movie_poster().
ALLOWED_MODELS = ["dall-e-3", "dall-e-2"]


def movie_poster(
    topic,
    api_key,
    model="dall-e-3",
    quality="standard",
    prompt="Create a movie poster that depicts that topic best. Note that the words are ordered in decreasing order of their importance.",
    size="1024x1024",
    return_style="url",
):
    """
    Generate a movie-poster-style image for a topic using OpenAI's DALL-E models.

    Parameters:
    - topic: List of words/phrases, or list of (word, importance) tuples,
      ordered by decreasing importance.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Image model to use; must be one of ALLOWED_MODELS.
    - quality: Image quality passed to the API (e.g. "standard").
    - prompt: Instruction appended to the topic description.
    - size: Size of the generated image, default "1024x1024".
    - return_style: "url" to return the image URL, "plot" to display it inline.

    Returns:
    - The image URL when return_style == "url"; otherwise None after
      displaying the image (IPython environments).

    Raises:
    - ValueError: if no API key is available or the model is not allowed.
    """

    # Load the API key from environment if not provided.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    assert return_style in [
        "url",
        "plot",
    ], "Invalid return style. Please choose 'url' or 'plot'"

    # Validate the model before opening a client session.
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    # Build a plain-text description of the topic for the prompt.
    if isinstance(topic[0], tuple):
        # List of (word, importance) tuples.
        topic_description = ", ".join(
            f"{word} (importance: {importance})" for word, importance in topic
        )
    else:
        # Plain word list: join it too — interpolating the raw list would
        # leak Python repr syntax (brackets/quotes) into the prompt.
        topic_description = ", ".join(topic)

    image_prompt = f"Given the following topic: {topic_description}. {prompt}"

    # Logging the operation
    logger.info(f"--- Generating image with model: {model} ---")
    response = client.images.generate(
        model=model,
        prompt=image_prompt,
        size=size,
        quality=quality,
        n=1,
    )

    # Ensure the response is valid before dereferencing it.
    if response:
        image_url = response.data[0].url
    else:
        image_url = "No image generated. Please try again."

    if return_style == "url":
        return image_url

    elif return_style == "plot":
        display(Image(url=image_url))
90 changes: 90 additions & 0 deletions stream_topic/experimental/topic_story.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from openai import OpenAI, OpenAIError
from loguru import logger
import os


# Chat-completion models accepted by story_topic().
ALLOWED_MODELS = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
]


def story_topic(
    topic,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Create a creative story that includes the following words:",
    max_tokens=250,
    temperature=0.8,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a creative story that weaves in the topic words, using an OpenAI chat model.

    Parameters:
    - topic: List of words or phrases to include in the story.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Chat model to use; must be one of ALLOWED_MODELS.
    - content: System message framing the assistant's persona.
    - prompt: Instruction prefixed to the joined topic words.
    - max_tokens: Maximum tokens for the response (also quoted in the prompt
      as a rough word-count target).
    - temperature: Sampling temperature (creativity level).
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for token frequency.
    - presence_penalty: Penalty for token presence.

    Returns:
    - story: Generated story as a string, or an "An error occurred: ..."
      message when the API call fails.

    Raises:
    - ValueError: if no API key is available or the model is not allowed.
    """

    # Load the API key from environment if not provided.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    # Fail fast instead of letting the client raise an opaque auth error
    # (keeps behavior consistent with movie_poster()).
    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    # Validate the model before opening a client session.
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    # Build the user prompt. Strip any trailing colon from the instruction
    # first — the default prompt already ends with ":", and appending another
    # produced a doubled "::" in the request.
    user_prompt = (
        f"{prompt.rstrip().rstrip(':')}: {', '.join(topic)}. "
        f"Make it as short as {max_tokens} words."
    )

    # Logging the operation
    logger.info(f"--- Generating story with model: {model} ---")

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": content},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Ensure the response is valid before dereferencing it.
        if response and len(response.choices) > 0:
            story = response.choices[0].message.content
        else:
            story = "No story generated. Please try again."

        return story

    except OpenAIError as e:
        return f"An error occurred: {str(e)}"
93 changes: 93 additions & 0 deletions stream_topic/experimental/topic_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from openai import OpenAI, OpenAIError
from loguru import logger
import os

# Chat-completion models accepted by topic_summaries().
ALLOWED_MODELS = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
]


def topic_summaries(
    topics,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Provide a 1-2 sentence summary for the following topic:",
    max_tokens=60,
    temperature=0.7,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a 1-2 sentence summary for each topic using an OpenAI chat model.

    Parameters:
    - topics: List of lists; each sublist holds the words/phrases of one topic.
    - api_key: API key for OpenAI; falls back to the OPENAI_API_KEY
      environment variable when None.
    - model: Chat model to use; must be one of ALLOWED_MODELS.
    - content: System message framing the assistant's persona.
    - prompt: Instruction prefixed to each topic's joined word list.
    - max_tokens: Maximum tokens for each summary.
    - temperature: Sampling temperature (creativity level).
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for token frequency.
    - presence_penalty: Penalty for token presence.

    Returns:
    - summaries: List of summary strings, one per topic, in input order.
      A per-topic API failure is recorded in-place as an
      "An error occurred: ..." string so the remaining topics still run.

    Raises:
    - ValueError: if no API key is available or the model is not allowed.
    """

    # Load the API key from environment if not provided.
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    # Fail fast instead of letting the client raise an opaque auth error
    # mid-loop (keeps behavior consistent with movie_poster()).
    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    # Validate the model before opening a client session.
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Initialize the OpenAI client with the resolved API key.
    client = OpenAI(api_key=api_key)

    summaries = []

    for idx, topic in enumerate(topics):
        # Create the per-topic prompt.
        topic_prompt = f"{prompt} {', '.join(topic)}."

        # Logging the operation
        logger.info(f"--- Generating summary for topic {idx} with model: {model} ---")

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": content},
                    {"role": "user", "content": topic_prompt},
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
            )

            # Ensure the response is valid before dereferencing it.
            if response and len(response.choices) > 0:
                summary = response.choices[0].message.content
            else:
                summary = "No summary generated. Please try again."

            summaries.append(summary)

        except OpenAIError as e:
            summaries.append(f"An error occurred: {str(e)}")

    return summaries
Loading

0 comments on commit 485e172

Please # to comment.