Skip to content

Commit

Permalink
Init ML Kit Selfie Segmentation config
Browse files Browse the repository at this point in the history
Experiments for #2
  • Loading branch information
Volcomix committed Apr 24, 2021
1 parent e9c4065 commit ca1a6ca
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 33 deletions.
Binary file not shown.
2 changes: 1 addition & 1 deletion src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function App() {
] = useState<SegmentationConfig>({
model: 'meet',
backend: 'wasm',
inputResolution: '96p',
inputResolution: '160x96',
pipeline: 'webgl2',
})
const [
Expand Down
63 changes: 45 additions & 18 deletions src/core/components/SegmentationConfigCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,35 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) {
const model = event.target.value as SegmentationModel
let backend = props.config.backend
let inputResolution = props.config.inputResolution
if (model === 'meet') {
backend = 'wasm'
if (inputResolution === '360p') {
inputResolution = '144p'
}
} else if (model === 'bodyPix') {
backend = 'webgl'
inputResolution = '360p'
}
let pipeline = props.config.pipeline
if (model === 'bodyPix' && pipeline === 'webgl2') {
pipeline = 'canvas2dCpu'
switch (model) {
case 'bodyPix':
backend = 'webgl'
inputResolution = '640x360'
pipeline = 'canvas2dCpu'
break

case 'meet':
if (
(backend !== 'wasm' && backend !== 'wasmSimd') ||
(inputResolution !== '256x144' && inputResolution !== '160x96')
) {
backend = props.isSIMDSupported ? 'wasmSimd' : 'wasm'
inputResolution = '160x96'
pipeline = 'webgl2'
}
break

case 'mlkit':
if (
(backend !== 'wasm' && backend !== 'wasmSimd') ||
inputResolution !== '256x256'
) {
backend = props.isSIMDSupported ? 'wasmSimd' : 'wasm'
inputResolution = '256x256'
pipeline = 'webgl2'
}
break
}
props.onChange({
...props.config,
Expand Down Expand Up @@ -86,6 +103,7 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) {
onChange={handleModelChange}
>
<MenuItem value="meet">Meet</MenuItem>
<MenuItem value="mlkit">ML Kit</MenuItem>
<MenuItem value="bodyPix">BodyPix</MenuItem>
</Select>
</FormControl>
Expand Down Expand Up @@ -122,17 +140,26 @@ function SegmentationConfigCard(props: SegmentationConfigCardProps) {
value={props.config.inputResolution}
onChange={handleInputResolutionChange}
>
<MenuItem value="360p" disabled={props.config.model === 'meet'}>
360p
<MenuItem
value="640x360"
disabled={props.config.model !== 'bodyPix'}
>
640x360
</MenuItem>
<MenuItem
value="144p"
disabled={props.config.model === 'bodyPix'}
value="256x256"
disabled={props.config.model !== 'mlkit'}
>
256x256
</MenuItem>
<MenuItem
value="256x144"
disabled={props.config.model !== 'meet'}
>
144p
256x144
</MenuItem>
<MenuItem value="96p" disabled={props.config.model === 'bodyPix'}>
96p
<MenuItem value="160x96" disabled={props.config.model !== 'meet'}>
160x96
</MenuItem>
</Select>
</FormControl>
Expand Down
27 changes: 22 additions & 5 deletions src/core/helpers/segmentationHelper.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
export type SegmentationModel = 'bodyPix' | 'meet'
export type SegmentationModel = 'bodyPix' | 'meet' | 'mlkit'
export type SegmentationBackend = 'webgl' | 'wasm' | 'wasmSimd'
export type InputResolution = '360p' | '144p' | '96p'
export type InputResolution = '640x360' | '256x256' | '256x144' | '160x96'

export const inputResolutions: {
[resolution in InputResolution]: [number, number]
} = {
'360p': [640, 360],
'144p': [256, 144],
'96p': [160, 96],
'640x360': [640, 360],
'256x256': [256, 256],
'256x144': [256, 144],
'160x96': [160, 96],
}

export type PipelineName = 'canvas2dCpu' | 'webgl2'
Expand All @@ -18,3 +19,19 @@ export type SegmentationConfig = {
inputResolution: InputResolution
pipeline: PipelineName
}

export function getTFLiteModelFileName(
model: SegmentationModel,
inputResolution: InputResolution
) {
switch (model) {
case 'meet':
return inputResolution === '256x144' ? 'segm_full_v679' : 'segm_lite_v681'

case 'mlkit':
return 'selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16'

default:
throw new Error(`No TFLite file for this segmentation model: ${model}`)
}
}
22 changes: 13 additions & 9 deletions src/core/hooks/useTFLite.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { useEffect, useState } from 'react'
import { SegmentationConfig } from '../helpers/segmentationHelper'
import {
getTFLiteModelFileName,
SegmentationConfig,
} from '../helpers/segmentationHelper'

declare function createTFLiteModule(): Promise<TFLite>
declare function createTFLiteSIMDModule(): Promise<TFLite>
Expand Down Expand Up @@ -40,12 +43,13 @@ function useTFLite(segmentationConfig: SegmentationConfig) {
}, [])

useEffect(() => {
async function loadMeetModel() {
async function loadTFLiteModel() {
if (
!tflite ||
(isSIMDSupported && !tfliteSIMD) ||
(!isSIMDSupported && segmentationConfig.backend === 'wasmSimd') ||
segmentationConfig.model !== 'meet'
(segmentationConfig.model !== 'meet' &&
segmentationConfig.model !== 'mlkit')
) {
return
}
Expand All @@ -61,11 +65,11 @@ function useTFLite(segmentationConfig: SegmentationConfig) {
)
}

const modelFileName =
segmentationConfig.inputResolution === '144p'
? 'segm_full_v679'
: 'segm_lite_v681'
console.log('Loading meet model:', modelFileName)
const modelFileName = getTFLiteModelFileName(
segmentationConfig.model,
segmentationConfig.inputResolution
)
console.log('Loading tflite model:', modelFileName)

const modelResponse = await fetch(
`${process.env.PUBLIC_URL}/models/${modelFileName}.tflite`
Expand Down Expand Up @@ -104,7 +108,7 @@ function useTFLite(segmentationConfig: SegmentationConfig) {
setSelectedTFLite(newSelectedTFLite)
}

loadMeetModel()
loadTFLiteModel()
}, [
tflite,
tfliteSIMD,
Expand Down

0 comments on commit ca1a6ca

Please # to comment.