diff --git a/public/models/selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16.tflite b/public/models/selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16.tflite new file mode 100755 index 0000000..374c072 Binary files /dev/null and b/public/models/selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16.tflite differ diff --git a/src/App.tsx b/src/App.tsx index 2ef8f5c..0143b0b 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -31,7 +31,7 @@ function App() { ] = useState({ model: 'meet', backend: 'wasm', - inputResolution: '96p', + inputResolution: '160x96', pipeline: 'webgl2', }) const [ diff --git a/src/core/components/SegmentationConfigCard.tsx b/src/core/components/SegmentationConfigCard.tsx index b9e3a05..0c2cd5e 100644 --- a/src/core/components/SegmentationConfigCard.tsx +++ b/src/core/components/SegmentationConfigCard.tsx @@ -28,18 +28,35 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) { const model = event.target.value as SegmentationModel let backend = props.config.backend let inputResolution = props.config.inputResolution - if (model === 'meet') { - backend = 'wasm' - if (inputResolution === '360p') { - inputResolution = '144p' - } - } else if (model === 'bodyPix') { - backend = 'webgl' - inputResolution = '360p' - } let pipeline = props.config.pipeline - if (model === 'bodyPix' && pipeline === 'webgl2') { - pipeline = 'canvas2dCpu' + switch (model) { + case 'bodyPix': + backend = 'webgl' + inputResolution = '640x360' + pipeline = 'canvas2dCpu' + break + + case 'meet': + if ( + (backend !== 'wasm' && backend !== 'wasmSimd') || + (inputResolution !== '256x144' && inputResolution !== '160x96') + ) { + backend = props.isSIMDSupported ? 'wasmSimd' : 'wasm' + inputResolution = '160x96' + pipeline = 'webgl2' + } + break + + case 'mlkit': + if ( + (backend !== 'wasm' && backend !== 'wasmSimd') || + inputResolution !== '256x256' + ) { + backend = props.isSIMDSupported ? 'wasmSimd' : 'wasm' + inputResolution = '256x256' + pipeline = 'webgl2' + } + break } props.onChange({ ...props.config, @@ -86,6 +103,7 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) { onChange={handleModelChange} > Meet + ML Kit BodyPix @@ -122,17 +140,26 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) { value={props.config.inputResolution} onChange={handleInputResolutionChange} > - - 360p + + 640x360 + 256x256 + + - 144p + 256x144 - - 96p + + 160x96 diff --git a/src/core/helpers/segmentationHelper.ts b/src/core/helpers/segmentationHelper.ts index b2f1ae8..21a3ae7 100644 --- a/src/core/helpers/segmentationHelper.ts +++ b/src/core/helpers/segmentationHelper.ts @@ -1,13 +1,14 @@ -export type SegmentationModel = 'bodyPix' | 'meet' +export type SegmentationModel = 'bodyPix' | 'meet' | 'mlkit' export type SegmentationBackend = 'webgl' | 'wasm' | 'wasmSimd' -export type InputResolution = '360p' | '144p' | '96p' +export type InputResolution = '640x360' | '256x256' | '256x144' | '160x96' export const inputResolutions: { [resolution in InputResolution]: [number, number] } = { - '360p': [640, 360], - '144p': [256, 144], - '96p': [160, 96], + '640x360': [640, 360], + '256x256': [256, 256], + '256x144': [256, 144], + '160x96': [160, 96], } export type PipelineName = 'canvas2dCpu' | 'webgl2' @@ -18,3 +19,19 @@ export type SegmentationConfig = { inputResolution: InputResolution pipeline: PipelineName } + +export function getTFLiteModelFileName( + model: SegmentationModel, + inputResolution: InputResolution +) { + switch (model) { + case 'meet': + return inputResolution === '256x144' ? 'segm_full_v679' : 'segm_lite_v681' + + case 'mlkit': + return 'selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16' + + default: + throw new Error(`No TFLite file for this segmentation model: ${model}`) + } +} diff --git a/src/core/hooks/useTFLite.ts b/src/core/hooks/useTFLite.ts index 32869dc..03447a2 100644 --- a/src/core/hooks/useTFLite.ts +++ b/src/core/hooks/useTFLite.ts @@ -1,5 +1,8 @@ import { useEffect, useState } from 'react' -import { SegmentationConfig } from '../helpers/segmentationHelper' +import { + getTFLiteModelFileName, + SegmentationConfig, +} from '../helpers/segmentationHelper' declare function createTFLiteModule(): Promise declare function createTFLiteSIMDModule(): Promise @@ -40,12 +43,13 @@ function useTFLite(segmentationConfig: SegmentationConfig) { }, []) useEffect(() => { - async function loadMeetModel() { + async function loadTFLiteModel() { if ( !tflite || (isSIMDSupported && !tfliteSIMD) || (!isSIMDSupported && segmentationConfig.backend === 'wasmSimd') || - segmentationConfig.model !== 'meet' + (segmentationConfig.model !== 'meet' && + segmentationConfig.model !== 'mlkit') ) { return } @@ -61,11 +65,11 @@ function useTFLite(segmentationConfig: SegmentationConfig) { ) } - const modelFileName = - segmentationConfig.inputResolution === '144p' - ? 'segm_full_v679' - : 'segm_lite_v681' - console.log('Loading meet model:', modelFileName) + const modelFileName = getTFLiteModelFileName( + segmentationConfig.model, + segmentationConfig.inputResolution + ) + console.log('Loading tflite model:', modelFileName) const modelResponse = await fetch( `${process.env.PUBLIC_URL}/models/${modelFileName}.tflite` @@ -104,7 +108,7 @@ function useTFLite(segmentationConfig: SegmentationConfig) { setSelectedTFLite(newSelectedTFLite) } - loadMeetModel() + loadTFLiteModel() }, [ tflite, tfliteSIMD,