Skip to content

Commit

Permalink
🐛 fix batch response
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed May 23, 2024
1 parent faaede7 commit cb37924
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 10 deletions.
95 changes: 95 additions & 0 deletions examples/nodejs/main-api-batch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import "./ensure-fetch";

import fs from "fs";
import { A1111StableDiffusionApi } from "../../dist/main";

import cliProgress from "cli-progress";
import open from "open";

// progress bar
let timer: any = null;
const bar1 = new cliProgress.SingleBar({}, cliProgress.Presets.shades_classic);
bar1.start(100, 0);
const logProgress = (api: A1111StableDiffusionApi) => {
if (timer) {
clearInterval(timer);
}
timer = setInterval(async () => {
const { progress } = await api.Service.progress();
if (!progress) {
return;
}
bar1.update(Math.ceil(progress * 100));
}, 300);
};

// call generate api from client
export const generate1Girl_batch = async (api: A1111StableDiffusionApi) => {
const batch1 = await api.Service.txt2imgBatch(
{
prompt: `photorealistic, RAW photo, best quality, 1girl, fashion orange top, half body, pastel grey background, highly detailed face, cold light`,
negative_prompt: `fake, paintings, error, bad art, NG_DeepNegative_V1_75T,`,
sampler_name: "DPM++ SDE Karras",
width: 512,
height: 768,
steps: 20,
},
{
batchSize: 2,
numBatches: 5,
}
);
return batch1;
};

// save base64 image
async function saveBase64Image(
base64String: string,
outputFilename = `image_${Math.floor(Math.random() * 1000000)}`
) {
let extension = "png";
let data = base64String;
if (data.startsWith("data:image/png;base64,")) {
[, data] = data.split(",");
}
const fileName = `${outputFilename}.${extension}`;
const buffer = Buffer.from(data, "base64");
await fs.promises.writeFile(`./outputs/${fileName}`, buffer);
return fileName;
}

const main = async () => {
const api = new A1111StableDiffusionApi({
client: {
BASE: "http://127.0.0.1:7860",
},
});
logProgress(api);

console.time("batch image generate");
const batch1 = await generate1Girl_batch(api);

batch1
.waitForComplete()
.catch((err) => {
console.error(err);
return [];
})
.finally(() => {
bar1.update(100);
bar1.stop();
console.timeEnd("batch image generate");
process.exit();
});

batch1.on("batch_complete", async ({ images }) => {
for (const image of images) {
const fileName = await saveBase64Image(image);
console.log("Generated image done: ", fileName);
}
});
};

if (require.main === module) {
main();
}
79 changes: 79 additions & 0 deletions examples/nodejs/main-api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import "./ensure-fetch";

import fs from "fs";
import { A1111StableDiffusionApi } from "../../dist/main";

import cliProgress from "cli-progress";
import open from "open";

// progress bar
let timer: any = null;
const bar1 = new cliProgress.SingleBar({}, cliProgress.Presets.shades_classic);
bar1.start(100, 0);
const logProgress = (api: A1111StableDiffusionApi) => {
if (timer) {
clearInterval(timer);
}
timer = setInterval(async () => {
const { progress } = await api.Service.progress();
if (!progress) {
return;
}
bar1.update(Math.ceil(progress * 100));
}, 300);
};

// call generate api from client
export const generate1Girl = async (api: A1111StableDiffusionApi) => {
const { image } = await api.Service.txt2img({
prompt: `photorealistic, RAW photo, best quality, 1girl, fashion orange top, half body, pastel grey background, highly detailed face, cold light`,
negative_prompt: `fake, paintings, error, bad art, NG_DeepNegative_V1_75T,`,
sampler_name: "DPM++ SDE Karras",
width: 512,
height: 768,
steps: 20,
});
return image;
};

// save base64 image
async function saveBase64Image(
base64String: string,
outputFilename = `image_${Math.floor(Math.random() * 1000000)}`
) {
let extension = "png";
let data = base64String;
if (data.startsWith("data:image/png;base64,")) {
[, data] = data.split(",");
}
const fileName = `${outputFilename}.${extension}`;
const buffer = Buffer.from(data, "base64");
await fs.promises.writeFile(`./outputs/${fileName}`, buffer);
return fileName;
}

const main = async () => {
const api = new A1111StableDiffusionApi({
client: {
BASE: "http://127.0.0.1:7860",
},
});
logProgress(api);

console.time("image generate");
const image = await generate1Girl(api);
bar1.update(100);
bar1.stop();
console.timeEnd("image generate");
if (!image) {
throw new Error("Failed to generate image");
}
const fileName = await saveBase64Image(image);
console.log("Generated image done: ", fileName);
await open(fileName, { wait: true });
process.exit();
};

if (require.main === module) {
main();
}
2 changes: 1 addition & 1 deletion examples/nodejs/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async function saveBase64Image(
}
const fileName = `${outputFilename}.${extension}`;
const buffer = Buffer.from(data, "base64");
await fs.promises.writeFile(fileName, buffer);
await fs.promises.writeFile(`./outputs/${fileName}`, buffer);
return fileName;
}

Expand Down
Empty file.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@stable-canvas/sd-webui-a1111-client",
"version": "1.1.2",
"version": "1.1.3",
"description": "API client for AUTOMATIC1111/stable-diffusion-webui for Node.js and Browser.",
"source": "src/main.ts",
"main": "dist/main.umd.js",
Expand Down
10 changes: 6 additions & 4 deletions src/api/cnet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { GenerationResponseInfo } from "./service.types";

class Img2imgBatchGeneration extends BatchGeneration<
Img2imgProcessParams,
{ image: string; info: GenerationResponseInfo }
{ images: string[]; info: GenerationResponseInfo }
> {
constructor(
readonly api: ControlNetApi,
Expand All @@ -44,7 +44,7 @@ class Img2imgBatchGeneration extends BatchGeneration<

class Txt2imgBatchGeneration extends BatchGeneration<
Txt2imgProcessParams,
{ image: string; info: GenerationResponseInfo }
{ images: string[]; info: GenerationResponseInfo }
> {
constructor(
readonly api: ControlNetApi,
Expand Down Expand Up @@ -199,7 +199,7 @@ export class ControlNetApi extends CachedApi {
* @param {Object} options - The options for the text to image processing.
* @param {Txt2imgProcessParams} options.params - The parameters for the text to image processing.
* @param {ControlNetUnitRequest[]} options.units - The control net units for the text to image processing.
* @return {Promise<{ image: string, info: GenerationResponseInfo }>} The processed image and information.
* @return {Promise<{ image: string, images: string[], info: GenerationResponseInfo }>} The processed image and information.
* @throws {Error} If no image is returned from the server.
*/
async txt2img({
Expand Down Expand Up @@ -228,6 +228,7 @@ export class ControlNetApi extends CachedApi {
}
return {
image,
images: resp.images || [],
info: JSON.parse(resp.info) as GenerationResponseInfo,
};
}
Expand All @@ -238,7 +239,7 @@ export class ControlNetApi extends CachedApi {
* @param {Object} options - The options for the image to image processing.
* @param {Img2imgProcessParams} options.params - The parameters for the image to image processing.
* @param {ControlNetUnitRequest[]} options.units - The control net units for the image to image processing.
* @return {Promise<{ image: string, info: GenerationResponseInfo }>} The processed image and information.
* @return {Promise<{ image: string, images: string[], info: GenerationResponseInfo }>} The processed image and information.
* @throws {Error} If no image is returned from the server.
*/
async img2img({
Expand Down Expand Up @@ -267,6 +268,7 @@ export class ControlNetApi extends CachedApi {
}
return {
image,
images: resp.images || [],
info: JSON.parse(resp.info) as GenerationResponseInfo,
};
}
Expand Down
10 changes: 6 additions & 4 deletions src/api/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import EventEmitter from "eventemitter3";

class Img2imgBatchGeneration extends BatchGeneration<
StableDiffusionProcessingImg2Img,
{ image: string; info: GenerationResponseInfo }
{ images: string[]; info: GenerationResponseInfo }
> {
constructor(
readonly serviceApi: ServiceApi,
Expand All @@ -32,7 +32,7 @@ class Img2imgBatchGeneration extends BatchGeneration<

class Txt2imgBatchGeneration extends BatchGeneration<
StableDiffusionProcessingTxt2Img,
{ image: string; info: GenerationResponseInfo }
{ images: string[]; info: GenerationResponseInfo }
> {
constructor(
readonly serviceApi: ServiceApi,
Expand Down Expand Up @@ -193,7 +193,7 @@ export class ServiceApi extends CachedApi {
* Asynchronously sends an image to the server for processing and returns the processed image and information.
*
* @param {StableDiffusionProcessingImg2Img} requestBody - The image to be processed.
* @return {Promise<{ image: string, info: GenerationResponseInfo }>} The processed image and information.
* @return {Promise<{ image: string, images: string[], info: GenerationResponseInfo }>} The processed image and information.
*/
async img2img(requestBody: StableDiffusionProcessingImg2Img) {
const resp = await this.client.default.img2ImgapiSdapiV1Img2ImgPost({
Expand All @@ -205,6 +205,7 @@ export class ServiceApi extends CachedApi {
}
return {
image,
images: resp.images || [],
info: JSON.parse(resp.info) as GenerationResponseInfo,
};
}
Expand All @@ -213,7 +214,7 @@ export class ServiceApi extends CachedApi {
* Asynchronously sends a text to the server for processing and returns the processed image and information.
*
* @param {StableDiffusionProcessingTxt2Img} requestBody - The text to be processed.
* @return {Promise<{ image: string, info: GenerationResponseInfo }>} The processed image and information.
* @return {Promise<{ image: string, images: string[], info: GenerationResponseInfo }>} The processed image and information.
*/
async txt2img(requestBody: StableDiffusionProcessingTxt2Img) {
const resp = await this.client.default.text2ImgapiSdapiV1Txt2ImgPost({
Expand All @@ -225,6 +226,7 @@ export class ServiceApi extends CachedApi {
}
return {
image,
images: resp.images || [],
info: JSON.parse(resp.info) as GenerationResponseInfo,
};
}
Expand Down

0 comments on commit cb37924

Please # to comment.