Merge pull request #22 from see--/master
Add DeepLab
shaqian authored Apr 7, 2019
2 parents 2a77224 + fcdf90f commit 9cee30d
Showing 6 changed files with 273 additions and 51 deletions.
91 changes: 91 additions & 0 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
@@ -28,6 +28,7 @@
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
@@ -43,6 +44,7 @@
import java.util.PriorityQueue;
import java.util.Vector;


public class TflitePlugin implements MethodCallHandler {
private final Registrar mRegistrar;
private Interpreter tfLite;
@@ -144,6 +146,16 @@ public void onMethodCall(MethodCall call, Result result) {
catch (Exception e) {
result.error("Failed to run model" , e.getMessage(), e);
}
} else if (call.method.equals("runSegmentationOnImage")) {
try {
byte[] res = runSegmentationOnImage((HashMap) call.arguments);
result.success(res);
}
catch (Exception e) {
result.error("Failed to run model" , e.getMessage(), e);
}
} else {
result.error("Invalid method", call.method.toString(), "");
}
}

@@ -706,6 +718,85 @@ public int compare(Map<String, Object> lhs, Map<String, Object> rhs) {
return results;
}

private byte[] runSegmentationOnImage(HashMap args) throws IOException {
String path = args.get("path").toString();
double mean = (double)(args.get("imageMean"));
float IMAGE_MEAN = (float)mean;
double std = (double)(args.get("imageStd"));
float IMAGE_STD = (float)std;
List<Long> labelColors = (ArrayList)args.get("labelColors");

long startTime = SystemClock.uptimeMillis();
ByteBuffer input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD);
ByteBuffer output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes());
output.order(ByteOrder.nativeOrder());
tfLite.run(input, output);
Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime));

if (input.limit() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (output.position() != output.limit()) throw new RuntimeException("Unexpected output position");

output.flip();
Bitmap outputArgmax = fetchArgmax(output, labelColors);
return compressPNG(outputArgmax);
}


Bitmap fetchArgmax(ByteBuffer output, List<Long> labelColors) {
Tensor outputTensor = tfLite.getOutputTensor(0);
int outputBatchSize = outputTensor.shape()[0];
assert outputBatchSize == 1;
int outputHeight = outputTensor.shape()[1];
int outputWidth = outputTensor.shape()[2];
int outputChannels = outputTensor.shape()[3];

Bitmap outputArgmax = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);

if (outputTensor.dataType() == DataType.FLOAT32) {
for (int i = 0; i < outputHeight; ++i) {
for (int j = 0; j < outputWidth; ++j) {
int maxIndex = 0;
float maxValue = 0.0f;
for (int c = 0; c < outputChannels; ++c) {
float outputValue = output.getFloat();
if (outputValue > maxValue) {
maxIndex = c;
maxValue = outputValue;
}
}
int labelColor = labelColors.get(maxIndex).intValue();
outputArgmax.setPixel(j, i, labelColor);
}
}
} else {
for (int i = 0; i < outputHeight; ++i) {
for (int j = 0; j < outputWidth; ++j) {
int maxIndex = 0;
int maxValue = 0;
for (int c = 0; c < outputChannels; ++c) {
int outputValue = output.get() & 0xFF; // quantized outputs are unsigned bytes
if (outputValue > maxValue) {
maxIndex = c;
maxValue = outputValue;
}
}
int labelColor = labelColors.get(maxIndex).intValue();
outputArgmax.setPixel(j, i, labelColor);
}
}
}
return outputArgmax;
}

byte[] compressPNG(Bitmap bitmap) {
// https://stackoverflow.com/questions/4989182/converting-java-bitmap-to-byte-array#4989543
ByteArrayOutputStream stream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.PNG, 100, stream);
byte[] byteArray = stream.toByteArray();
// bitmap.recycle();
return byteArray;
}

private float expit(final float x) {
return (float) (1. / (1. + Math.exp(-x)));
}
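Note: the Dart API that invokes this new channel method lives in lib/tflite.dart, whose diff did not load in this view. Below is a minimal sketch of what that wrapper presumably looks like, inferred from the argument keys the Java handler reads ("path", "imageMean", "imageStd", "labelColors") and the PNG byte array it returns; the channel name, signature, and defaults are assumptions, not the actual plugin source.

// Sketch only; the real lib/tflite.dart may differ. Pre-null-safety Dart,
// matching the style of the example app below.
import 'dart:typed_data';
import 'package:flutter/services.dart';

class Tflite {
  static const MethodChannel _channel = const MethodChannel('tflite');

  // Runs segmentation on an image file and returns a PNG-encoded mask
  // whose pixels are the per-class colors taken from `labelColors`.
  static Future<Uint8List> runSegmentationOnImage({
    String path,
    double imageMean = 127.5, // assumed default; the real value is unknown
    double imageStd = 127.5,  // assumed default; the real value is unknown
    List<int> labelColors,    // one ARGB color per output class
  }) async {
    return await _channel.invokeMethod('runSegmentationOnImage', {
      'path': path,
      'imageMean': imageMean,
      'imageStd': imageStd,
      'labelColors': labelColors,
    });
  }
}

The real wrapper presumably supplies a default palette as well, since the example app never passes labelColors while the Java handler above indexes the list unconditionally; the example calls it with imageMean: 127.5 and imageStd: 127.5.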
Binary file added example/assets/deeplabv3_257_mv_gpu.tflite
21 changes: 21 additions & 0 deletions example/assets/deeplabv3_257_mv_gpu.txt
@@ -0,0 +1,21 @@
background
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
potted plant
sheep
sofa
train
tv-monitor
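These 21 labels line up index-for-index with the labelColors list that the Java handler indexes by argmax class, so a custom palette passed from Dart needs exactly 21 entries in this order. A hedged example of building such a palette follows; the colors themselves are arbitrary placeholders, and keeping them fully opaque (alpha 0xFF) should let each value cross the platform channel as a Java Long, matching the List<Long> cast in the handler above.

// Hypothetical palette: one color per line of deeplabv3_257_mv_gpu.txt (21 classes).
// The values are arbitrary; any 21 fully-opaque ARGB ints would do.
final List<int> pascalVocColors = List<int>.generate(
  21,
  (c) => 0xFF000000 | ((c * 12) << 16) | ((c * 6) << 8) | (255 - c * 12),
);

// Passed alongside the other arguments (assuming the Dart wrapper exposes it):
// await Tflite.runSegmentationOnImage(
//     path: image.path, imageMean: 127.5, imageStd: 127.5,
//     labelColors: pascalVocColors);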
166 changes: 115 additions & 51 deletions example/lib/main.dart
@@ -13,6 +13,7 @@ void main() => runApp(new App());
const String mobile = "MobileNet";
const String ssd = "SSD MobileNet";
const String yolo = "Tiny YOLOv2";
const String deeplab = "DeepLab";

class App extends StatelessWidget {
@override
@@ -31,12 +32,18 @@ class MyApp extends StatefulWidget {
class _MyAppState extends State<MyApp> {
File _image;
List _recognitions;
String _model = "";
String _model = mobile;
double _imageHeight;
double _imageWidth;

Future getImage() async {
Future predictImagePicker() async {
var image = await ImagePicker.pickImage(source: ImageSource.gallery);
if (image == null) return;
predictImage(image);
}

Future predictImage(File image) async {
if (image == null) return;

switch (_model) {
case yolo:
@@ -45,6 +52,9 @@ class _MyAppState extends State<MyApp> {
case ssd:
ssdMobileNet(image);
break;
case deeplab:
segmentMobileNet(image);
break;
default:
recognizeImage(image);
// recognizeImageBinary(image);
@@ -67,6 +77,7 @@
@override
void initState() {
super.initState();
loadModel();
}

Future loadModel() async {
@@ -84,6 +95,11 @@
model: "assets/ssd_mobilenet.tflite",
labels: "assets/ssd_mobilenet.txt");
break;
case deeplab:
res = await Tflite.loadModel(
model: "assets/deeplabv3_257_mv_gpu.tflite",
labels: "assets/deeplabv3_257_mv_gpu.txt");
break;
default:
res = await Tflite.loadModel(
model: "assets/mobilenet_v1_1.0_224.tflite",
@@ -194,11 +210,25 @@ class _MyAppState extends State<MyApp> {
});
}

onSelect(model) {
Future segmentMobileNet(File image) async {
var recognitions = await Tflite.runSegmentationOnImage(
path: image.path,
imageMean: 127.5,
imageStd: 127.5,
);

setState(() {
_recognitions = recognitions;
});
}

onSelect(model) async {
setState(() {
_model = model;
_recognitions = null;
});
loadModel();
await loadModel();
predictImage(_image);
}

List<Widget> renderBoxes(Size screen) {
Expand All @@ -214,6 +244,7 @@ class _MyAppState extends State<MyApp> {
height: re["rect"]["h"] * factorY,
child: Container(
decoration: BoxDecoration(
borderRadius: BorderRadius.all(Radius.circular(8.0)),
border: Border.all(
color: blue,
width: 2,
@@ -235,58 +266,91 @@
@override
Widget build(BuildContext context) {
Size size = MediaQuery.of(context).size;
List<Widget> stackChildren = [];

if (_model == deeplab && _recognitions != null) {
stackChildren.add(Positioned(
top: 0.0,
left: 0.0,
width: size.width,
child: _image == null
? Text('No image selected.')
: Container(
decoration: BoxDecoration(
image: DecorationImage(
alignment: Alignment.topCenter,
image: MemoryImage(_recognitions),
fit: BoxFit.fill)),
child: Opacity(opacity: 0.3, child: Image.file(_image))),
));
} else {
stackChildren.add(Positioned(
top: 0.0,
left: 0.0,
width: size.width,
child: _image == null ? Text('No image selected.') : Image.file(_image),
));
}

if (_model == mobile) {
stackChildren.add(Center(
child: Column(
children: _recognitions != null
? _recognitions.map((res) {
return Text(
"${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
style: TextStyle(
color: Colors.black,
fontSize: 20.0,
background: Paint()..color = Colors.white,
),
);
}).toList()
: [],
),
));
} else if (_model == ssd || _model == yolo) {
stackChildren.addAll(renderBoxes(size));
}

return Scaffold(
appBar: AppBar(
title: const Text('tflite example app'),
),
body: _model == ""
? Center(
child: Column(
children: <Widget>[
RaisedButton(
child: const Text(mobile),
onPressed: () => onSelect(mobile),
),
RaisedButton(
child: const Text(ssd),
onPressed: () => onSelect(ssd),
),
RaisedButton(
child: const Text(yolo),
onPressed: () => onSelect(yolo),
),
],
),
)
: Stack(
children: <Widget>[
Container(
child: _image == null
? Text('No image selected.')
: Image.file(_image),
actions: <Widget>[
PopupMenuButton<String>(
onSelected: onSelect,
itemBuilder: (context) {
List<PopupMenuEntry<String>> menuEntries = [
const PopupMenuItem<String>(
child: Text(mobile),
value: mobile,
),
_model == mobile
? Center(
child: Column(
children: _recognitions != null
? _recognitions.map((res) {
return Text(
"${res["index"]} - ${res["label"]}: ${res["confidence"].toString()}",
style: TextStyle(
color: Colors.black,
fontSize: 20.0,
background: Paint()..color = Colors.white,
),
);
}).toList()
: [],
),
)
: Stack(children: renderBoxes(size)),
],
),
const PopupMenuItem<String>(
child: Text(ssd),
value: ssd,
),
const PopupMenuItem<String>(
child: Text(yolo),
value: yolo,
),
];

if (Platform.isAndroid) {
menuEntries.add(const PopupMenuItem<String>(
child: Text(deeplab),
value: deeplab,
));
}
return menuEntries;
},
)
],
),
body: Stack(
children: stackChildren,
),
floatingActionButton: FloatingActionButton(
onPressed: getImage,
onPressed: predictImagePicker,
tooltip: 'Pick Image',
child: Icon(Icons.image),
),
3 changes: 3 additions & 0 deletions example/pubspec.yaml
@@ -50,6 +50,9 @@ flutter:
- assets/yolov2_tiny.txt
- assets/ssd_mobilenet.tflite
- assets/ssd_mobilenet.txt
- assets/deeplabv3_257_mv_gpu.tflite
- assets/deeplabv3_257_mv_gpu.txt


# An image asset can refer to one or more resolution-specific "variants", see
# https://flutter.io/assets-and-images/#resolution-aware.
