Model output not looking as expected (tfjs/tflite style transfer model)

Hey all, I’m trying to implement the style transfer model from TensorFlow Hub using tf.js, but I’m running into an issue. I think I’ve got everything hooked up, but the resulting image doesn’t look right. Any help would be truly appreciated!

Current result in the browser


import '@tensorflow/tfjs-backend-cpu';
import { browser, image, tidy, reshape, Tensor, Tensor3D, Tensor4D } from '@tensorflow/tfjs-core';
import * as tflite from '@tensorflow/tfjs-tflite';

const predictionUrl = '';
const transferUrl = '';

export async function run() {
  const predictionModel = await tflite.loadTFLiteModel(predictionUrl);
  const transferModel = await tflite.loadTFLiteModel(transferUrl);
  const styleImg = await browser.fromPixelsAsync(document.querySelector('.styleImage') as HTMLImageElement);
  const img = await browser.fromPixelsAsync(document.querySelector('.originalImage') as HTMLImageElement);
  // Style prediction model: 256x256 style image (batch of 1) -> style bottleneck.
  const resizedStyleImage = reshape(image.resizeBilinear(styleImg, [256, 256]), [1, 256, 256, 3]);
  const styleBottleNeck = tidy(() => predictionModel.predict([resizedStyleImage])) as Tensor;
  // Transfer model: 384x384 content image (batch of 1) + style bottleneck -> stylized image.
  const resizedImageTensor = reshape(image.resizeBilinear(img, [384, 384], true), [1, 384, 384, 3]);
  const output = tidy(() => transferModel.predict([resizedImageTensor, styleBottleNeck])) as Tensor4D;
  // Resize to the canvas size, drop the batch dimension and draw.
  const canvasImage = reshape(image.resizeBilinear(output, [540, 960], true), [540, 960, 3]) as Tensor3D;
  return browser.toPixels(canvasImage, document.querySelector('.canvas') as HTMLCanvasElement);
}
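On the output side, the transfer model’s float32 result is expected to be in [0, 1] (the Hub example displays it directly as a float image), and as far as I know `browser.toPixels` rejects float32 values outside that range. The sketch below shows the equivalent conversion done by hand into RGBA bytes for a canvas `ImageData`. `toRgbaBytes` is a hypothetical helper, and unlike `toPixels` it clamps out-of-range values instead of throwing.

```typescript
// Hypothetical helper: turns a float32 RGB buffer in [0, 1] (HWC order,
// batch dimension already removed) into RGBA bytes usable as ImageData.data.
function toRgbaBytes(rgb: ArrayLike<number>, width: number, height: number): Uint8ClampedArray {
  if (rgb.length !== width * height * 3) {
    throw new Error(`expected ${width * height * 3} values, got ${rgb.length}`);
  }
  const rgba = new Uint8ClampedArray(width * height * 4);
  for (let i = 0; i < width * height; i++) {
    for (let c = 0; c < 3; c++) {
      // Scale to 0..255; Uint8ClampedArray rounds and clamps out-of-range values.
      rgba[i * 4 + c] = rgb[i * 3 + c] * 255;
    }
    rgba[i * 4 + 3] = 255; // fully opaque alpha
  }
  return rgba;
}
```

For example, `new ImageData(toRgbaBytes(await canvasImage.data(), 960, 540), 960, 540)` could then be drawn with `putImageData`.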

Hi, can you share the result you get from TensorFlow using their Colab?

This is what I get using the Colab:

Is there maybe some preprocessing step I’m missing in my code?

In the TF Hub example I see this comment:

Convert to float32 numpy array, add batch dimension, and normalize to range [0, 1]

Yes, you need to normalize the input image tensor to [0, 1].
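For reference, here is a minimal sketch of what that normalization amounts to numerically, written in plain TypeScript without tf.js so it runs on its own. It assumes the pixels come from an RGBA byte buffer such as canvas `ImageData.data`; the helper name `toNormalizedNhwc` is hypothetical. It drops the alpha channel, converts each byte to float32 in [0, 1], and reports an NHWC shape with a batch dimension of 1, mirroring the three steps quoted from the Hub example.

```typescript
// Hypothetical helper: converts an RGBA byte buffer (e.g. ImageData.data)
// into a float32 RGB buffer normalized to [0, 1], plus its NHWC shape.
interface NormalizedImage {
  data: Float32Array;
  shape: [number, number, number, number]; // [batch, height, width, channels]
}

function toNormalizedNhwc(rgba: Uint8ClampedArray, width: number, height: number): NormalizedImage {
  if (rgba.length !== width * height * 4) {
    throw new Error(`expected ${width * height * 4} bytes, got ${rgba.length}`);
  }
  const data = new Float32Array(width * height * 3);
  for (let i = 0; i < width * height; i++) {
    // Copy R, G, B and skip A; dividing by 255 maps 0..255 onto 0..1.
    data[i * 3] = rgba[i * 4] / 255;
    data[i * 3 + 1] = rgba[i * 4 + 1] / 255;
    data[i * 3 + 2] = rgba[i * 4 + 2] / 255;
  }
  return { data, shape: [1, height, width, 3] };
}
```

In tf.js itself, `browser.fromPixels` already returns a 3-channel int32 tensor with values in 0..255, so dividing it by 255 and adding a batch dimension (for example with `expandDims`) gives the equivalent tensor.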


That seems to have vastly improved the result. I still need to validate that my results match the Colab’s for the exact same images, as I may be missing additional processing steps.
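One way to do that validation is to export both outputs as flat float arrays of the same shape (for example, the browser result via `await output.data()` and the Colab result flattened from NumPy) and compare them numerically. The sketch below assumes both arrays are already in the same NHWC order with values in [0, 1]; the helper `compareOutputs` is hypothetical. Some small differences are likely even with correct preprocessing, since resizing and floating-point details may not be identical between TensorFlow and tf.js.

```typescript
// Hypothetical helper for checking how closely two model outputs agree.
interface OutputDiff {
  maxAbsDiff: number;
  meanAbsDiff: number;
}

function compareOutputs(a: ArrayLike<number>, b: ArrayLike<number>): OutputDiff {
  if (a.length !== b.length || a.length === 0) {
    throw new Error(`length mismatch or empty input: ${a.length} vs ${b.length}`);
  }
  let maxAbsDiff = 0;
  let sumAbsDiff = 0;
  for (let i = 0; i < a.length; i++) {
    // Element-wise absolute difference between the two outputs.
    const d = Math.abs(a[i] - b[i]);
    sumAbsDiff += d;
    if (d > maxAbsDiff) maxAbsDiff = d;
  }
  return { maxAbsDiff, meanAbsDiff: sumAbsDiff / a.length };
}
```

A large `maxAbsDiff` concentrated in a few pixels points at resizing or cropping differences, while a uniformly large `meanAbsDiff` is more typical of a scaling mismatch like the missing normalization above.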

For anyone who runs into similar issues: to normalize, I divided the image tensors by 255, like so:

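import { browser, div } from '@tensorflow/tfjs-core';
// Scale 0..255 pixel values to [0, 1] before resizing and running the models.
const normalizedStyleImg = div(await browser.fromPixelsAsync(document.querySelector('.styleImage') as HTMLImageElement), 255);
const normalizedContentImg = div(await browser.fromPixelsAsync(document.querySelector('.originalImage') as HTMLImageElement), 255);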

Do let us know if you end up putting the demo somewhere! We would love to share it with the community if it is public. Just tag it on social media with #MadeWithTFJS so we can find it!