Camel AI Example: TensorFlow Serving

Requires: Apache Camel 4.10.0 or later

Introduction

This directory contains a collection of small examples that show what the Camel TensorFlow Serving component can do.

Preparation

Before you run these examples, you need a TensorFlow Serving server running locally. The easiest way to get one is to use a Docker image. These examples use Bitnami’s TensorFlow Serving image because it supports both amd64 and arm64 architectures, so macOS users can try it easily.

From the root of this project, you can run the Docker image with the following command:

docker run --rm -t -p 8500:8500 -p 8501:8501 --name tf-serving \
    -v ./tensorflow-serving/models/:/models/ \
    -v ./tensorflow-serving/models/models.pbtxt:/bitnami/tensorflow-serving/conf/tensorflow-serving.conf \
    bitnami/tensorflow-serving &

Note that this command also mounts the models directory into the container. The directory contains the two pre-trained models that the examples use: half_plus_two and mnist.

The examples

This set of examples shows you how you can interact with a TensorFlow Serving server using the Camel component.

Check model status

You can check the status of a model by specifying the model name (and optionally the version).

from("timer:model-status?repeatCount=1")
    .to("tensorflow-serving:model-status?modelName=mnist&modelVersion=1")
    .log("Status: ${body.getModelVersionStatus(0).state}");

To run this example, execute the following command from the project root:

camel run tensorflow-serving/model_status.java

You should get the following output in the console:

INFO 77956 --- [://model-status] model_status.java:16 : Status: AVAILABLE

Obtain model metadata

Before you call a model in TensorFlow Serving, you need to understand its metadata, because a call requires the names, data types, and shapes of the inputs and outputs.

If you’ve enabled the RESTful API for the TensorFlow Serving server, you can just hit the REST endpoint to get the metadata of a model.

curl http://localhost:8501/v1/models/mnist/versions/1/metadata

But you can also get the metadata of a model within a Camel route.

from("timer:model-metadata?repeatCount=1")
    .to("tensorflow-serving:model-metadata?modelName=mnist&modelVersion=1")
    .log("Metadata: ${body.getMetadataOrThrow('signature_def')}");

To run this example, execute the following command from the project root:

camel run tensorflow-serving/model_metadata.java

You should get the following output in the console:

INFO 78150 --- [/model-metadata] model_metadata.java:16 : Metadata: type_url: "type.googleapis.com/tensorflow.serving.SignatureDefMap"
value: "\n\245\001\n\005serve\022\233\001\n?\n\fkeras_tensor\022/\n\024serve_keras_tensor:0\020\001\032\025\022\v\b\377\377\377\377\377\377\377\377\377\001\022\002\b\034\022\002\b\034\022<\n\boutput_0\0220\n\031StatefulPartitionedCall:0\020\001\032\021\022\v\b\377\377\377\377\377\377\377\377\377\001\022\002\b\n\032\032tensorflow/serving/predict\n>\n\025__saved_model_init_op\022%\022#\n\025__saved_model_init_op\022\n\n\004NoOp\032\002\030\001\n\273\001\n\017serving_default\022\247\001\nI\n\fkeras_tensor\0229\n\036serving_default_keras_tensor:0\020\001\032\025\022\v\b\377\377\377\377\377\377\377\377\377\001\022\002\b\034\022\002\b\034\022>\n\boutput_0\0222\n\033StatefulPartitionedCall_1:0\020\001\032\021\022\v\b\377\377\377\377\377\377\377\377\377\001\022\002\b\n\032\032tensorflow/serving/predict"
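The route logs the metadata as a protobuf Any whose value is printed as escaped bytes, which is hard to read. When the REST API is enabled, the endpoint from the curl command returns the same information as JSON. The following is a minimal sketch using only the JDK (Java 11 or later); the class TfServingRest and its methods are hypothetical helpers for illustration and are not part of Camel or TensorFlow Serving.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical helper for illustration; not part of Camel or TensorFlow Serving.
final class TfServingRest {

    // Builds the metadata URL in the same shape as the curl command in this README.
    static URI metadataUri(String host, int port, String model, int version) {
        return URI.create("http://" + host + ":" + port
                + "/v1/models/" + model + "/versions/" + version + "/metadata");
    }

    // Sends a GET request and returns the response body as text.
    // This only works while a TensorFlow Serving server with REST enabled is reachable.
    static String fetch(URI uri) throws IOException, InterruptedException {
        HttpRequest request = HttpRequest.newBuilder(uri).GET().build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        return response.body();
    }

    public static void main(String[] args) {
        // Prints the URL only; call fetch(...) once the server is running.
        System.out.println(metadataUri("localhost", 8501, "mnist", 1));
    }
}
```

The host and port here assume the Docker command from the Preparation section, which publishes port 8501.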

Run prediction with a model

Once you know the metadata of the model, you can run predictions on it. Let’s use an MNIST saved model to recognise handwritten digits in a Camel route. The example route reads handwritten digit images from the data/mnist directory and sends them to the TensorFlow Serving server, which recognises the numbers with the MNIST model.

Figure 1. Recognition of handwritten numbers using the MNIST model
public void configure() throws Exception {
    from("file:data/mnist?noop=true&recursive=true&include=.*\\.png")
        .process(this::toPredictRequest)
        .to("tensorflow-serving:predict?modelName=mnist&modelVersion=1")
        .process(this::argmax)
        .log("${headers.camelFileName} => ${body}");
}

void toPredictRequest(Exchange exchange) {
    byte[] body = exchange.getMessage().getBody(byte[].class);
    List<Float> data = preprocess(body);
    TensorProto inputs = TensorProto.newBuilder()
            .setDtype(DataType.DT_FLOAT)
            .setTensorShape(TensorShapeProto.newBuilder()
                    .addDim(Dim.newBuilder().setSize(28))
                    .addDim(Dim.newBuilder().setSize(28)))
            .addAllFloatVal(data)
            .build();
    PredictRequest request = PredictRequest.newBuilder()
            .putInputs("keras_tensor", inputs)
            .build();
    exchange.getMessage().setBody(request);
}

List<Float> preprocess(byte[] data) {
    try {
        BufferedImage image = ImageIO.read(new ByteArrayInputStream(data));
        int width = image.getWidth();
        int height = image.getHeight();
        if (width != 28 || height != 28) {
            throw new RuntimeCamelException("Image size must be 28x28");
        }
        List<Float> normalised = new ArrayList<>(width * height);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
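                // The low 8 bits of the packed pixel are the blue channel, which equals the grey level for a greyscale image.
                int rgb = image.getRGB(x, y);
                normalised.add((rgb & 0xFF) / 255.0f); // scale 0..255 to 0.0..1.0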
            }
        }
        return normalised;
    } catch (IOException e) {
        throw new RuntimeCamelException("Error reading image", e);
    }
}

void argmax(Exchange exchange) {
    PredictResponse response = exchange.getMessage().getBody(PredictResponse.class);
    TensorProto tensor = response.getOutputsOrThrow("output_0");
    // The index of the highest score is the predicted digit.
    int result = IntStream.range(0, tensor.getFloatValCount())
            .reduce((max, i) -> tensor.getFloatVal(max) > tensor.getFloatVal(i) ? max : i)
            .orElseThrow();
    exchange.getMessage().setBody(result);
}
Tip
How to know the inputs and outputs of a model

As the example code shows, the hardest part of invoking a TensorFlow Serving model is constructing the input TensorProto object correctly (the toPredictRequest(Exchange) method in the example). The input key keras_tensor and the data type and shape of the inputs in the PredictRequest come from the model metadata: signature_def → serving_default → inputs. Similarly, the output key output_0 and the data type and shape of the outputs in the response (read in the argmax(Exchange) method) come from the metadata: signature_def → serving_default → outputs.
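The shape declared in the request has to agree with the number of values supplied: a 28 × 28 input needs 784 floats in row-major order. The following sketch reproduces only the normalisation step with plain JDK classes, so it can be tried without Camel or the TensorFlow protos. The class name MnistPreprocessSketch and the synthetic test image are assumptions made for illustration.

```java
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-alone version of the normalisation step; not the example's own class.
final class MnistPreprocessSketch {

    // Flattens the image row by row and scales the low 8 bits of each pixel to [0, 1].
    static List<Float> normalise(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        List<Float> values = new ArrayList<>(width * height);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                values.add((image.getRGB(x, y) & 0xFF) / 255.0f);
            }
        }
        return values;
    }

    public static void main(String[] args) {
        // Synthetic 28x28 black image with a single white pixel at x=3, y=1.
        BufferedImage image = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
        image.setRGB(3, 1, 0xFFFFFF);
        List<Float> values = normalise(image);
        System.out.println(values.size());          // 784 values for a 28x28 image
        System.out.println(values.get(1 * 28 + 3)); // the white pixel, at index y * width + x
    }
}
```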

To run this example, execute the following command from the project root:

camel run tensorflow-serving/predict.java

You should get the following output in the console:

INFO 50429 --- [le://data/mnist] predict.java:39 : 9/62.png => 9
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 0/71.png => 0
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 7/60.png => 7
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 6/88.png => 6
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 1/14.png => 1
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 8/177.png => 8
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 4/48.png => 4
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 3/63.png => 3
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 2/77.png => 2
...
INFO 50429 --- [le://data/mnist] predict.java:39 : 5/59.png => 5
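The digit on the right of each log line is the index of the highest score in the output_0 tensor. The following sketch shows the same argmax reduction on a plain float array, without the TensorProto type. The class name ArgmaxSketch and the score values are made up for illustration and are not output from the model.

```java
import java.util.stream.IntStream;

// Hypothetical stand-alone version of the argmax step; the scores below are invented.
final class ArgmaxSketch {

    // Returns the index of the largest element; throws if the array is empty.
    static int argmax(float[] scores) {
        return IntStream.range(0, scores.length)
                .reduce((max, i) -> scores[max] > scores[i] ? max : i)
                .orElseThrow();
    }

    public static void main(String[] args) {
        // One invented score per digit class 0..9.
        float[] scores = {0.01f, 0.02f, 0.05f, 0.02f, 0.10f, 0.03f, 0.02f, 0.70f, 0.03f, 0.02f};
        System.out.println(argmax(scores)); // prints 7
    }
}
```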

Classification

In addition to the generic Predict API, TensorFlow Serving provides two specialised inference APIs. One of them is the Classify API, which is dedicated to classification problems. This API sends examples as the input data to a classification model and returns the labels and scores of the inferred classes.

The MNIST model used in the previous example does not provide a signature for the classification problem, so for demonstration purposes here we will instead use a test model included in the TensorFlow Serving repository: half_plus_two. This is a minimal model that simply divides the input value by two and adds two.

Figure 2. Half plus two
public void configure() throws Exception {
    from("timer:classify?repeatCount=1")
        .setBody(constant(createInput("x", 1.0f)))
        .to("tensorflow-serving:classify?modelName=half_plus_two&modelVersion=123&signatureName=classify_x_to_y")
        .log("Result: ${body.result}");
}

Input createInput(String key, float f) {
    Feature feature = Feature.newBuilder()
            .setFloatList(FloatList.newBuilder().addValue(f))
            .build();
    Features features = Features.newBuilder()
            .putFeature(key, feature)
            .build();
    Example example = Example.newBuilder()
            .setFeatures(features)
            .build();
    ExampleList exampleList = ExampleList.newBuilder()
            .addExamples(example)
            .build();
    return Input.newBuilder()
            .setExampleList(exampleList)
            .build();
}
Tip
You can get the signature name classify_x_to_y from the model metadata.

To run this example, execute the following command from the project root:

camel run tensorflow-serving/classify.java

You should get the following output in the console:

INFO 94792 --- [imer://classify] classify.java:31 : Result: classifications {
  classes {
    score: 2.5
  }
}
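The score of 2.5 matches what half_plus_two is described as computing: the input 1.0 divided by two, plus two. The following sketch is a local stand-in for that arithmetic, useful for checking expected values before calling the server. The class HalfPlusTwoSketch is hypothetical and does not call TensorFlow Serving.

```java
// Hypothetical local stand-in for the half_plus_two model; it does not contact the server.
final class HalfPlusTwoSketch {

    // Divides the input by two and adds two, as the model is described in this README.
    static float halfPlusTwo(float x) {
        return x / 2.0f + 2.0f;
    }

    public static void main(String[] args) {
        System.out.println(halfPlusTwo(1.0f)); // prints 2.5, the score logged by the route
    }
}
```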

Regression

The other specialised inference API that TensorFlow Serving provides is the Regress API, which is dedicated to regression problems. This API sends examples as the input data to a regression model and returns a regressed value per example.

We will use the half_plus_two model again for demonstration purposes.

Figure 3. Half plus two
from("timer:regress?repeatCount=1")
    .setBody(constant(Input.newBuilder()
        .setExampleList(ExampleList.newBuilder()
            .addExamples(Example.newBuilder()
                .setFeatures(Features.newBuilder()
                    .putFeature("x", Feature.newBuilder()
                        .setFloatList(FloatList.newBuilder().addValue(1.0f))
                        .build()))))
        .build()))
    .to("tensorflow-serving:regress?modelName=half_plus_two&modelVersion=123&signatureName=regress_x_to_y")
    .log("Result: ${body.result}");
Tip
You can get the signature name regress_x_to_y from the model metadata.

To run this example, execute the following command from the project root:

camel run tensorflow-serving/regress.java

You should get the following output in the console:

INFO 96520 --- [timer://regress] regress.java:31 : Result: regressions {
  value: 2.5
}
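The route sends a single example, so the result holds a single value. To illustrate the "one value per example" behaviour of the Regress API, the following sketch applies the same divide-by-two-plus-two arithmetic to a list of inputs locally. The class RegressBatchSketch and the input values are assumptions made for illustration; the sketch does not call the server.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical local illustration of "one regressed value per example"; no server call is made.
final class RegressBatchSketch {

    // Maps each input x to x / 2 + 2, preserving the order of the examples.
    static List<Float> regressAll(List<Float> examples) {
        return examples.stream()
                .map(x -> x / 2.0f + 2.0f)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(regressAll(List.of(0.0f, 1.0f, 4.0f))); // prints [2.0, 2.5, 4.0]
    }
}
```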

Export to a project

You can export these examples to a project (for example Quarkus) using:

cd tensorflow-serving
camel export --runtime quarkus --gav=org.apache.camel.example:tensorflow-serving:1.0-SNAPSHOT predict.java

Help and contributions

If you hit any problem using Camel or have some feedback, then please let us know.

We also love contributors, so get involved :-)

The Camel riders!