-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
astria api #115
base: master
Are you sure you want to change the base?
astria api #115
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,11 @@ | |
AsyncBaseResponseType, | ||
) | ||
|
||
import requests | ||
|
||
def load_image(file_path): | ||
with open(file_path, "rb") as f: | ||
return f.read() | ||
|
||
class AstriaApi(ProviderInterface, ImageInterface): | ||
provider_name = "astria" | ||
|
@@ -26,30 +31,81 @@ def __init__(self, api_keys: Dict = {}) -> None: | |
self.headers = {"authorization": f"Bearer {self.api_key}"} | ||
|
||
def image__generation_fine_tuning__create_project_async__launch_job( | ||
self, | ||
name: str, | ||
description: str, | ||
files: List[str], | ||
files_url: List[str] = [], | ||
base_project_id: Optional[int] = None, | ||
self, | ||
title: str, | ||
class_name: str, | ||
files: List[str] = [], | ||
files_url: List[str] = [], | ||
base_tune_id: Optional[int] = None, | ||
) -> AsyncLaunchJobResponseType: | ||
raise NotImplementedError | ||
data = { | ||
"tune[title]": title, | ||
"tune[name]": class_name, | ||
"tune[base_tune_id]": base_tune_id, | ||
# "tune[callback]": 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' | ||
} | ||
for image in files: | ||
image_data = load_image(image) # Assuming image is a file path | ||
files.append(("tune[images][]", image_data)) | ||
for image_url in files_url: | ||
files.append(("tune[image_urls][]", image_url)) | ||
|
||
response = requests.post(f"{self.url}tunes", data=data, files=files, headers=self.headers) | ||
response.raise_for_status() | ||
return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. provider_job_id is a string not an integer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
|
||
def image__generation_fine_tuning__create_project_async__get_job_result( | ||
self, provider_job_id: str | ||
) -> AsyncBaseResponseType[GenerationFineTuningCreateProjectAsyncDataClass]: | ||
raise NotImplementedError | ||
response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers) | ||
response.raise_for_status() | ||
data = response.json() | ||
return AsyncBaseResponseType( | ||
status="succeeded" if data['trained_at'] else "pending", | ||
provider_job_id=provider_job_id, | ||
original_response=data, | ||
standardized_response=GenerationFineTuningCreateProjectAsyncDataClass( | ||
project_id=data["id"], | ||
name=data["name"], | ||
description=data["title"], | ||
), | ||
) | ||
|
||
def image__generation_fine_tuning__generate_image_async__launch_job( | ||
self, | ||
project_id: str, | ||
prompt: str, | ||
negative_prompt: Optional[str] = "", | ||
num_images: Optional[int] = 1, | ||
self, | ||
project_id: str, | ||
prompt: str, | ||
negative_prompt: Optional[str] = "", | ||
num_images: Optional[int] = 1, | ||
input_image: Optional[str] = None, | ||
) -> AsyncLaunchJobResponseType: | ||
raise NotImplementedError | ||
data = { | ||
'prompt[text]': prompt, | ||
'prompt[negative_prompt]': negative_prompt, | ||
'prompt[num_images]': num_images, | ||
'prompt[face_swap]': True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I want to create something else other than human return an error (try to create a dog) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added args |
||
'prompt[inpaint_faces]': False, | ||
'prompt[super_resolution]': True, | ||
'prompt[face_correct]': False, | ||
# 'prompt[callback]': 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' | ||
} | ||
files = [] | ||
if input_image: | ||
files.append((f"tune[prompts_attributes][{i}][input_image]", load_image(input_image))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i is not defined There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
|
||
response = requests.post(f"{self.url}/tunes/{project_id}", headers=self.headers, data=data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this link doesn't work use this link instead There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
response.raise_for_status() | ||
return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) | ||
|
||
def image__generation_fine_tuning__generate_image_async__get_job_result( | ||
self, provider_job_id: str | ||
) -> AsyncBaseResponseType[GenerationFineTuningGenerateImageAsyncDataClass]: | ||
raise NotImplementedError | ||
response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers) | ||
response.raise_for_status() | ||
data = response.json() | ||
return AsyncBaseResponseType( | ||
status="succeeded" if data['trained_at'] else "pending", | ||
provider_job_id=provider_job_id, | ||
original_response=data, | ||
standardized_response=GenerationFineTuningGenerateImageAsyncDataClass(**data), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. **data can't work because difference between names There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Append to files which create a problem because files list already exist
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, and that's a good thing :)