Skip to content

Commit 103b758

Browse files
committed
General DiffusionPipeline class, readme updated
1 parent 666e670 commit 103b758

File tree

8 files changed

+41
-12
lines changed

8 files changed

+41
-12
lines changed

README.md

+8-7
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ npm i @aislamov/diffusers.js
1212

1313
Browser (see examples/react)
1414
```js
15-
const pipe = StableDiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx')
15+
import { DiffusionPipeline } from '@aislamov/diffusers.js'
16+
17+
const pipe = DiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx')
1618
const images = pipe.run({
1719
prompt: "an astronaut running a horse",
1820
numInferenceSteps: 30,
19-
width: 512,
20-
height: 512,
2121
})
2222

2323
const canvas = document.getElementById('canvas')
@@ -27,12 +27,13 @@ canvas.getContext('2d').putImageData(data, 0, 0);
2727

2828
Node.js (see examples/node)
2929
```js
30-
const pipe = StableDiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx')
30+
import { DiffusionPipeline } from '@aislamov/diffusers.js'
31+
import { PNG } from 'pngjs'
32+
33+
const pipe = DiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx')
3134
const images = pipe.run({
3235
prompt: "an astronaut running a horse",
3336
numInferenceSteps: 30,
34-
width: 512,
35-
height: 512,
3637
})
3738

3839
const data = await images[0].mul(255).round().clipByValue(0, 255).transpose(0, 2, 3, 1)
@@ -46,7 +47,7 @@ p.pack().pipe(fs.createWriteStream('output.png')).on('finish', () => {
4647

4748
'aislamov/stable-diffusion-2-1-base-onnx' model is optimized for GPU and will fail to load without CUDA/DML/WebGPU support. Please use 'cpu' revision on a machine without GPU.
4849
```js
49-
const pipe = StableDiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx', 'cpu')
50+
const pipe = DiffusionPipeline.fromPretrained('aislamov/stable-diffusion-2-1-base-onnx', { revision: 'cpu' })
5051
```
5152

5253
## Running examples

examples/node/package-lock.json

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/node/src/txt2img.mjs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import minimist from 'minimist';
2-
import { StableDiffusionPipeline } from '@aislamov/diffusers.js'
2+
import { DiffusionPipeline } from '@aislamov/diffusers.js'
33
import fs from 'fs'
44
import { PNG } from 'pngjs'
55
import progress from 'cli-progress'
@@ -19,7 +19,7 @@ function parseCommandLineArgs() {
1919

2020
async function main() {
2121
const args = parseCommandLineArgs();
22-
const pipe = await StableDiffusionPipeline.fromPretrained(
22+
const pipe = await DiffusionPipeline.fromPretrained(
2323
args.m,
2424
{
2525
revision: args.rev,

examples/react/src/App.tsx

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import React, { useEffect, useRef, useState } from 'react'
22
import './App.css';
33
import {
4+
DiffusionPipeline,
45
ProgressCallback,
56
ProgressCallbackPayload,
67
setModelCacheDir,
@@ -69,7 +70,7 @@ function App() {
6970
}
7071
setModelState('loading')
7172
try {
72-
pipeline.current = await StableDiffusionPipeline.fromPretrained(
73+
pipeline.current = await DiffusionPipeline.fromPretrained(
7374
'aislamov/stable-diffusion-2-1-base-onnx',
7475
{
7576
progressCallback

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@aislamov/diffusers.js",
3-
"version": "0.9.1",
3+
"version": "0.9.2",
44
"license": "MIT",
55
"author": {
66
"name": "Arthur Islamov",

src/index-node.ts

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import * as ORT from '@aislamov/onnxruntime-web64'
66

77
export * from './pipelines/StableDiffusionPipeline'
88
export * from './pipelines/StableDiffusionXLPipeline'
9+
export * from './pipelines/DiffusionPipeline'
910
export * from './pipelines/common'
1011
export * from './hub'
1112
export { setModelCacheDir } from '@/hub/browser'

src/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { setCacheImpl } from '@/hub'
44

55
export * from './pipelines/StableDiffusionPipeline'
66
export * from './pipelines/StableDiffusionXLPipeline'
7+
export * from './pipelines/DiffusionPipeline'
78
export * from './pipelines/common'
89
export * from './hub'
910
export { setModelCacheDir } from '@/hub/browser'

src/pipelines/DiffusionPipeline.ts

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import { PretrainedOptions } from '@/pipelines/common'
2+
import { GetModelFileOptions } from '@/hub/common'
3+
import { getModelJSON } from '@/hub'
4+
import { StableDiffusionPipeline } from '@/pipelines/StableDiffusionPipeline'
5+
import { StableDiffusionXLPipeline } from '@/pipelines/StableDiffusionXLPipeline'
6+
7+
export class DiffusionPipeline {
8+
static async fromPretrained (modelRepoOrPath: string, options?: PretrainedOptions) {
9+
const opts: GetModelFileOptions = {
10+
...options,
11+
}
12+
13+
const index = await getModelJSON(modelRepoOrPath, 'model_index.json', true, opts)
14+
15+
switch (index['_class_name']) {
16+
case 'StableDiffusionPipeline':
17+
case 'OnnxStableDiffusionPipeline':
18+
return StableDiffusionPipeline.fromPretrained(modelRepoOrPath, options)
19+
case 'StableDiffusionXLPipeline':
20+
return StableDiffusionXLPipeline.fromPretrained(modelRepoOrPath, options)
21+
default:
22+
throw new Error(`Unknown pipeline type ${index['_class_name']}`)
23+
}
24+
}
25+
}

0 commit comments

Comments
 (0)