Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

astria api #115

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 71 additions & 15 deletions edenai_apis/apis/astria/astria_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
AsyncBaseResponseType,
)

import requests

def load_image(file_path):
with open(file_path, "rb") as f:
return f.read()

class AstriaApi(ProviderInterface, ImageInterface):
provider_name = "astria"
Expand All @@ -26,30 +31,81 @@ def __init__(self, api_keys: Dict = {}) -> None:
self.headers = {"authorization": f"Bearer {self.api_key}"}

def image__generation_fine_tuning__create_project_async__launch_job(
self,
name: str,
description: str,
files: List[str],
files_url: List[str] = [],
base_project_id: Optional[int] = None,
self,
title: str,
class_name: str,
files: List[str] = [],
files_url: List[str] = [],
base_tune_id: Optional[int] = None,
) -> AsyncLaunchJobResponseType:
raise NotImplementedError
data = {
"tune[title]": title,
"tune[name]": class_name,
"tune[base_tune_id]": base_tune_id,
# "tune[callback]": 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1'
}
for image in files:
image_data = load_image(image) # Assuming image is a file path
files.append(("tune[images][]", image_data))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Append to files which create a problem because files list already exist

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, and that's a good thing :)

for image_url in files_url:
files.append(("tune[image_urls][]", image_url))

response = requests.post(f"{self.url}tunes", data=data, files=files, headers=self.headers)
response.raise_for_status()
return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

provider_job_id is a string not an integer

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


def image__generation_fine_tuning__create_project_async__get_job_result(
self, provider_job_id: str
) -> AsyncBaseResponseType[GenerationFineTuningCreateProjectAsyncDataClass]:
raise NotImplementedError
response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers)
response.raise_for_status()
data = response.json()
return AsyncBaseResponseType(
status="succeeded" if data['trained_at'] else "pending",
provider_job_id=provider_job_id,
original_response=data,
standardized_response=GenerationFineTuningCreateProjectAsyncDataClass(
project_id=data["id"],
name=data["name"],
description=data["title"],
),
)

def image__generation_fine_tuning__generate_image_async__launch_job(
self,
project_id: str,
prompt: str,
negative_prompt: Optional[str] = "",
num_images: Optional[int] = 1,
self,
project_id: str,
prompt: str,
negative_prompt: Optional[str] = "",
num_images: Optional[int] = 1,
input_image: Optional[str] = None,
) -> AsyncLaunchJobResponseType:
raise NotImplementedError
data = {
'prompt[text]': prompt,
'prompt[negative_prompt]': negative_prompt,
'prompt[num_images]': num_images,
'prompt[face_swap]': True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I want to create something else other than human return an error (try to create a dog)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added args

'prompt[inpaint_faces]': False,
'prompt[super_resolution]': True,
'prompt[face_correct]': False,
# 'prompt[callback]': 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1'
}
files = []
if input_image:
files.append((f"tune[prompts_attributes][{i}][input_image]", load_image(input_image)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i is not defined

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


response = requests.post(f"{self.url}/tunes/{project_id}", headers=self.headers, data=data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this link doesn't work use this link instead
https://api.astria.ai/tunes/project_id/prompts

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

response.raise_for_status()
return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"])

def image__generation_fine_tuning__generate_image_async__get_job_result(
self, provider_job_id: str
) -> AsyncBaseResponseType[GenerationFineTuningGenerateImageAsyncDataClass]:
raise NotImplementedError
response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers)
response.raise_for_status()
data = response.json()
return AsyncBaseResponseType(
status="succeeded" if data['trained_at'] else "pending",
provider_job_id=provider_job_id,
original_response=data,
standardized_response=GenerationFineTuningGenerateImageAsyncDataClass(**data),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**data can't work because difference between names

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

)