Background removal with TensorFlow.js*
Background removal is a technique used in video compositing to separate the subject of a video from its background. This article is part of a series [1] on removing the background from a webcam live stream, making it suitable for use in web conferencing, for example.
In the previous article, I used an Intel® RealSense™ Depth Camera with the Chrome* web browser to implement a working prototype. The article also explained what the depth frame (produced by the depth camera) looks like and how to decide, based on color similarity, whether pixels along the edge belong to the foreground or the background. Recall that the approach works only with a relatively uniform background. Another requirement is that the code must run locally, on the client side, in a web browser.
This article takes a different approach: instead of using a depth camera, I use a standard web camera and a convolutional neural network (CNN) to detect the pixels belonging to persons and segment them out from the rest. Note how the approach differs: previously, we used the distance provided by the depth camera and compared it to a defined threshold value to decide whether color camera pixels were in the foreground or background. Here, we show only persons in the foreground, regardless of how far they are from the camera.
This scenario has the same requirement: the code must not rely on a server for computation, but must run locally, on the client side, in a web browser.
For simplicity, this blog post contains minimal working source code, with explanations on how to get it running. It does not describe the handling of edge artefacts or performance optimization; those topics are mentioned in the previous article. The introduction to CNNs and segmentation is kept concise, but includes quotations and external references so you can do additional research and reuse the approach.
Semantic Segmentation using CNN
I got the idea to use a CNN from the Google AI blog post on Mobile Real-time Video Segmentation. Their idea for speeding up prediction is excellent: use the current frame's computed mask, along with the color frame, as input when predicting the next frame's mask. I wanted to try the same on a web page, using JavaScript, but I didn't want to invest time in preparing a dataset and training a model on my own, at least not before verifying that TensorFlow.js segmentation performance, using existing pretrained models, is acceptable.
After some searching, I decided to use the pretrained MobileNetV2 + DeepLabV3 model developed by Google Research for on-device semantic segmentation. It can segment twenty types of objects in an image, but here we limit its use to person segmentation.
Semantic segmentation means that, for every pixel, the CNN predicts the type of object it belongs to. It is computationally more complex than object detection (classification), which predicts what is in some area of the image and with what probability. This also becomes evident later, when we visualize the neural network model.
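To make the distinction concrete, here is a minimal, hypothetical sketch of the two kinds of output. The values are made up for illustration, but the label set is the standard PASCAL VOC set that this model was trained on, where class id 15 means person; this is the id the application code relies on later.

// Classification: one label (with a probability) for the whole image or region.
const classification = {label: 'person', probability: 0.97};

// Semantic segmentation: a class id for every pixel (0 = background).
// A tiny 4 x 4 example map; our model produces a 513 x 513 map.
const segmentationMap = [
  [0,  0, 15, 15],
  [0, 15, 15, 15],
  [0,  0, 15,  0],
  [0,  0,  0,  0],
];

// The standard PASCAL VOC label set (index 0 is background).
const PASCAL_VOC_LABELS = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'];
console.log(PASCAL_VOC_LABELS.indexOf('person')); // 15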
For a reference on DeepLabV3+, check the Google AI blog post (and the references at the bottom of the page) about Semantic Image Segmentation with DeepLab in TensorFlow. The model is built on top of the MobileNetV2 architecture, a lightweight network structure designed to run on mobile clients.
Before starting with scripts and code, let's download the frozen inference graph [2] mobilenetv2_coco_voc_trainaug from the set of pretrained models in the TensorFlow DeepLab Model Zoo. The frozen graph file is named deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz; save it and unpack it. The unpacked folder contains the frozen_inference_graph.pb file. That's the model we will visualize and convert to a form suitable for use on the web with TensorFlow.js.
tensorflowjs_converter
To use the downloaded model, we must convert it to the TensorFlow.js web-friendly format. The documentation for the conversion script includes the following example:
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model
While one could deduce the values of most of the required parameters, --output_node_names doesn't look obvious. The included example relates to MobilenetV1, and it shows one output node, with the path MobilenetV1/Predictions/Reshape_1.
Finding out DeepLab + MobileNetV2 input and output node names
To get the output node name(s) for the MobileNetV2 + DeepLabV3 model, we need to visualize the network structure, explore the network nodes, and identify the input and output node names. This is where the TensorBoard tool comes in.
To install it on your development machine, follow the Install TensorFlow instructions. Note that installing TensorFlow via pip also installs TensorBoard.
To import the frozen model into TensorBoard, we use the script import_pb_to_tensorboard.py. It is part of the TensorFlow GitHub repository; clone the project or download it from GitHub. I ran the following command in the previously unpacked folder containing the frozen_inference_graph.pb file:
python3 ~/tensorflow/tensorflow/python/tools/import_pb_to_tensorboard.py --model_dir frozen_inference_graph.pb --log_dir /tmp/tensorflow_logdir
The command printed this line:
Model Imported. Visualize by running: tensorboard --logdir=/tmp/tensorflow_logdir
That leads to the next command to run, along with its output:
$ tensorboard --logdir=/tmp/tensorflow_logdir
TensorBoard 1.11.0 at http://aleksandar-laptop:6006 (Press CTRL+C to quit)
Opening a web browser and navigating to http://localhost:6006 displays the graph. Double-click the nodes to expand them. The expanded graph looks like this:
The input node, called ImageTensor, is in the bottom-left area and the output node is at the top right. Zoomed in and selected, the output node looks like this:
From this image, we can see that the path to the output node, which is the path required for the conversion script, is SemanticPredictions/(SemanticPredictions).
Running tensorflowjs_converter and quantization
Running the command with that value for the output node returns this error:
KeyError: "The name 'SemanticPredictions(SemanticPredictions)' refers to an Operation not in the graph."
After some trial and error with different combinations, I learned that the output node name is “SemanticPredictions”. So, the command I used to convert the frozen model to the TensorFlow.js web-friendly format is:
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names="SemanticPredictions" \
    --saved_model_tags=serve \
    frozen_inference_graph.pb \
    deeplabv3_mnv2_pascal_train_aug_web_model
The produced deeplabv3_mnv2_pascal_train_aug_web_model folder is around 8.7 MB; note that the original frozen_inference_graph.pb file is about the same size. The size can be further reduced by adding the --quantization_bytes 1 flag to the command, as shown below:
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names="SemanticPredictions" \
    --saved_model_tags=serve \
    --quantization_bytes 1 \
    frozen_inference_graph.pb \
    deeplabv3_mnv2_pascal_train_aug_web_model
The size of the folder after quantization is 2.3 MB. I tested application behavior with both TensorFlow.js model folders, original and quantized, and didn't notice a difference in behavior. However, with the quantized model, there is about four times less data for the web page to fetch.
Application code
Now that we have a model suitable for use on the web, let's explore the web application code that uses it to identify the background in the webcam capture.
As a disclaimer, I am presenting minimal working code here. In a follow-up implementation, I will address performance optimization and move rendering from the 2D canvas to WebGL, because WebGL can efficiently handle border-area artefacts. However, displaying the mask in its original form lets us observe how precise the model's segmentation is. The application, running on my laptop [3], looks like this:
All application code is in one file, called index.html, which resides in the same folder as deeplabv3_mnv2_pascal_train_aug_web_model. If you used the --quantization_bytes option, the folder should look like this:
├── deeplabv3_mnv2_pascal_train_aug_web_model
│   ├── [2.1M] group1-shard1of1
│   ├── [ 83k] tensorflowjs_model.pb
│   └── [ 57k] weights_manifest.json
└── [2.9k] index.html
index.html starts with:
<html>
<head>
</head>
<body onload="onLoad()">
  <div id="container">
    <div id="show-background">Show background as magenta
      <input id="show-background-toggle" type="checkbox" checked>
    </div>
    <canvas id="canvas" width="640" height="480"></canvas>
  </div>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.2"></script>
This is our application page: we render video to a canvas sized 640 x 480, and a simple checkbox (see the screenshot below) is overlaid on the canvas to toggle rendering the detected background as solid magenta. This is in line with the approach in the previous article and demo.
The last line loads the TensorFlow.js library script, version 0.13.2, from the npm project mirror on cdn.jsdelivr.net.
When the page gets loaded, <body onload="onLoad()"> launches the onLoad handler. All the application script resides within this method. It starts by loading the model prepared in the previous section, using tf.loadFrozenModel, and prompts to start camera video capture using MediaDevices.getUserMedia.
<script>
async function onLoad() {
  const MODEL_URL = 'deeplabv3_mnv2_pascal_train_aug_web_model/tensorflowjs_model.pb';
  const WEIGHTS_URL = 'deeplabv3_mnv2_pascal_train_aug_web_model/weights_manifest.json';
  // Model's input and output have width and height of 513.
  const TENSOR_EDGE = 513;
  const [model, stream] = await Promise.all([
    tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL),
    navigator.mediaDevices.getUserMedia(
        {video: {facingMode: 'user', frameRate: 30, width: 640, height: 480}})
  ]);
We use await Promise.all to wait for both asynchronous operations, tf.loadFrozenModel and getUserMedia, to complete. The model takes RGB images of size 513 x 513 as input, packed as a 4-dimensional tensor. Written with the size of each dimension, the input tensor shape is 1 x 513 x 513 x 3: 1 (batch size) x 513 (height) x 513 (width) x 3 (RGB color components, as produced by tf.fromPixels).
The output is a batch with one matrix of size 513 x 513, where each value in the matrix is the id of the identified object class, or 0 for pixels that do not belong to any detected object. From this output, we build a background mask of size 513 x 513.
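As a quick sanity check, the sketch below runs the model once and counts how many pixels were classified as a person. Here, input stands for the 1 x 513 x 513 x 3 tensor built with tf.fromPixels, as in the application code that follows; the snippet is illustrative only.

const out = model.execute({'ImageTensor': input});
const data = out.dataSync(); // 513 * 513 class ids
let personPixels = 0;
for (const id of data) {
  if (id === 15) personPixels++; // 15 is the 'person' class id
}
console.log(`person pixels: ${personPixels} of ${data.length}`);
out.dispose(); // release the output tensor's memory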
It is very likely that your webcam doesn't natively support capture at a frame size of 513 x 513, so you must do some scaling. In our case, we scale from the chosen camera frame size (640 x 480) to the model input size (513 x 513), and then from the model output (513 x 513) to the canvas size (640 x 480). This is why we need to prepare the following canvases, video, and buffers:
  const video = document.createElement('video');
  video.autoplay = true;
  video.width = video.height = TENSOR_EDGE;
  const ctx = document.getElementById("canvas").getContext("2d");
  const videoCopy = ctx.canvas.cloneNode(false).getContext("2d");
  const maskContext = document.createElement('canvas').getContext("2d");
  maskContext.canvas.width = maskContext.canvas.height = TENSOR_EDGE;
  const img = maskContext.createImageData(TENSOR_EDGE, TENSOR_EDGE);
  let imgd = img.data;
  // Fill the mask with fully transparent magenta (RGBA bytes R=0xFF, G=0x00,
  // B=0xFF, A=0x00, i.e. 0x00FF00FF as a little-endian 32-bit value).
  new Uint32Array(imgd.buffer).fill(0x00FF00FF);
Later in the code, we set the HTMLVideoElement's source to the camera capture stream (video.srcObject = stream). The video element's width and height are set to 513; this is what is later fed to the model, using tf.fromPixels(video). Once the model computes the mask, we render both the video frame and the computed mask. Since the model execution takes some time, we need to save the video frame state just before the computation starts, so that we can render it later with the computed background mask. This is the purpose of videoCopy.
A word about this graphics setup of canvases and buffers: WebGL 2.0 should perform better, handling the copies and canvases in one render pass, and it would also let us handle border artefacts, but I'll leave that for a future article.
The computed mask is of size 513 x 513 and we must render it to the displayed canvas of size 640 x 480. To do that, the mask canvas (maskContext.canvas) of size 513 x 513 receives the pixels of the computed mask and is later drawn to the displayed canvas using ctx.drawImage(maskContext.canvas, ...). The last line above initializes all the values in the mask buffer (imgd) to magenta and, with the alpha channel set to 0, transparent. Later in the code, we set the alpha to 255 for pixels that are not part of a person segment (data[i] != 15), making the magenta color visible over the background in the canvas:
  const render = () => {
    // Keep a copy of the current video frame; the model run takes a while and
    // the frame would otherwise change before we draw it together with the mask.
    videoCopy.drawImage(video, 0, 0, ctx.canvas.width, ctx.canvas.height);
    const out = tf.tidy(() => {
      return model.execute({'ImageTensor': tf.fromPixels(video).expandDims(0)});
    });
    const data = out.dataSync();
    // Make the magenta mask visible (alpha 255) everywhere except person pixels (class id 15).
    for (let i = 0; i < data.length; i++) {
      imgd[i * 4 + 3] = data[i] == 15 ? 0 : 255;
    }
    out.dispose(); // Release the output tensor once its data has been read back.
    maskContext.putImageData(img, 0, 0);
    ctx.drawImage(videoCopy.canvas, 0, 0);
    if (document.getElementById("show-background-toggle").checked)
      ctx.drawImage(maskContext.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height);
    window.requestAnimationFrame(render);
  }
  video.oncanplay = render;
  video.srcObject = stream;
}
</script>
The performance-sensitive part is model.execute; on my laptop with a powerful discrete GPU [4], it takes 70-90 ms.
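As a rough sketch of how such a number can be measured (not necessarily how the figures above were obtained), one can bracket the execute-and-readback sequence with performance.now(). Because the WebGL work is asynchronous, the elapsed time is only meaningful once dataSync() has forced the result back to the CPU; the names model and video come from the code above.

const t0 = performance.now();
const out = tf.tidy(() => {
  return model.execute({'ImageTensor': tf.fromPixels(video).expandDims(0)});
});
const data = out.dataSync(); // blocks until the GPU work is done and the data is read back
console.log(`execute + readback: ${(performance.now() - t0).toFixed(1)} ms`);
out.dispose();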
Performance issues
Measured time is similar with both the Chrome* and Firefox* browsers, on Ubuntu* 18.04 and Microsoft Windows* 10. When executing on an integrated, less powerful GPU [3], the required time increases to around 800 ms. Since TensorFlow.js is designed as a WebGL-accelerated library, I expected the best performance on the powerful GPU [4]. Additionally, even with the powerful GPU, I noticed significant discrepancies in measured performance, depending on whether the browser runs on my main laptop screen or on an external screen attached to the HDMI port. I need to study this issue further.
During this time, the TensorFlow.js core performs the following tasks. After the video frame is uploaded to a texture, processing is done on the GPU under fenceSync. Once the sync is signaled as ready, the data can be read back to the CPU without blocking. A future performance optimization study could attempt [5] to pipeline the processing of the next frame soon after issuing the commands for the current frame, instead of waiting for the sync to be signaled as ready.
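A related, simpler idea (a sketch only, not the full pipelining described above) is to replace the blocking dataSync() call with the asynchronous data() API, so the main thread is not stalled while the GPU result is being produced and read back:

const out = tf.tidy(() => {
  return model.execute({'ImageTensor': tf.fromPixels(video).expandDims(0)});
});
out.data().then((data) => {
  // Update the mask from `data` exactly as in render() above, then draw.
  out.dispose();
});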
For local rendering, if we don't need to read the data back from the GPU but can instead consume the result on the GPU, we could use the same WebGL context for both computation and rendering, and issue the render call right after the prediction executes. I plan to check whether TensorFlow.js custom WebGL operations could be used for this purpose.
Conclusion and follow up work
Compared to the approach taken in my previous article, with TensorFlow.js we need neither a uniform background nor a special depth-sensing camera, but the approach works (works in the sense that processing a single frame takes less than 100 ms) only on machines with a powerful GPU [4].
The approach looks very promising and I'm eager to spend some time trying to optimize performance using the following approaches:
- Pipeline the next frame before the current frame's result is computed and available for read-back from the GPU.
- Use the TensorFlow.js canvas also for rendering, avoiding the need to read the data back to the CPU.
- Tweak hyperparameters, including a depth multiplier of 0.5 and the input resolution, to trade off precision against performance.
Footnotes
- The first prototype code and article use Intel® RealSense™ Depth Cameras with the Chrome* web browser to compute the distance of pixels from the camera.
- A frozen graph is a single, standalone GraphDef file that contains all of the content from the trained GraphDef and checkpoint files, with the checkpoint variables converted into const ops.
- Intel® Iris™ Pro 5200 on MacBook Pro (Retina, 15-inch, Mid 2014).
- NVIDIA® GeForce® GTX1070 on Asus ROG-GL702VS.
- There are four synchronous WebGL readPixels operations happening in every frame, stalling the inference pipeline, which prevents pipelining for now. This requires further investigation.
Performance Disclaimer
Software and workloads used in performance tests may have been optimized for performance only on Intel microprocessors. Performance tests, such as SYSmark and MobileMark, are measured using specific computer systems, components, software, operations and functions. Any change to any of those factors may cause the results to vary. You should consult other information and performance tests to assist you in fully evaluating your contemplated purchases, including the performance of that product when combined with other products.
Configurations: Intel® Iris™ Pro 5200 on MacBook Pro (Retina, 15-inch, Mid 2014) and NVIDIA® GeForce® GTX1070 on Asus ROG-GL702VS.
For more information go to www.intel.com/benchmarks